优化修正文档检测中的部分语法
This commit is contained in:
@@ -233,7 +233,7 @@ def prepare_images(db, image, locs, flipped=True):
|
||||
input_size = db.configs["input_size"]
|
||||
num_patches = locs.shape[0]
|
||||
|
||||
images = torch.cuda.FloatTensor(num_patches, 3, input_size[0], input_size[1]).fill_(0)
|
||||
images = torch.zeros((num_patches, 3, input_size[0], input_size[1]), dtype=torch.float32, device='cuda')
|
||||
offsets = np.zeros((num_patches, 2), dtype=np.float32)
|
||||
for ind, (y, x, scale) in enumerate(locs[:, :3]):
|
||||
crop_height = int(input_size[0] / scale)
|
||||
@@ -319,10 +319,9 @@ def cornernet_saccade_inference(db, nnet, image, decode_func=batch_decode):
|
||||
|
||||
num_iterations = len(att_thresholds)
|
||||
|
||||
im_mean = torch.cuda.FloatTensor(db.mean).reshape(1, 3, 1, 1)
|
||||
im_std = torch.cuda.FloatTensor(db.std).reshape(1, 3, 1, 1)
|
||||
im_mean = torch.tensor(db.mean, dtype=torch.float32, device='cuda').reshape(1, 3, 1, 1)
|
||||
im_std = torch.tensor(db.std, dtype=torch.float32, device='cuda').reshape(1, 3, 1, 1)
|
||||
|
||||
detections = []
|
||||
height, width = image.shape[0:2]
|
||||
|
||||
image = image / 255.
|
||||
|
||||
Reference in New Issue
Block a user