优化修正文档检测中的部分语法

This commit is contained in:
2024-08-20 14:27:51 +08:00
parent 299b762cad
commit 896d2aaf9b
3 changed files with 6 additions and 7 deletions

View File

@@ -8,12 +8,12 @@ extensions = [
Extension( Extension(
"bbox", "bbox",
["bbox.pyx"], ["bbox.pyx"],
extra_compile_args=["-Wno-cpp", "-Wno-unused-function"] extra_compile_args=[]
), ),
Extension( Extension(
"nms", "nms",
["nms.pyx"], ["nms.pyx"],
extra_compile_args=["-Wno-cpp", "-Wno-unused-function"] extra_compile_args=[]
) )
] ]

View File

@@ -119,7 +119,7 @@ class NetworkFactory(object):
def load_pretrained_params(self, pretrained_model): def load_pretrained_params(self, pretrained_model):
print("loading from {}".format(pretrained_model)) print("loading from {}".format(pretrained_model))
with open(pretrained_model, "rb") as f: with open(pretrained_model, "rb") as f:
params = torch.load(f) params = torch.load(f, weights_only=False)
self.model.load_state_dict(params) self.model.load_state_dict(params)
def load_params(self, iteration): def load_params(self, iteration):

View File

@@ -233,7 +233,7 @@ def prepare_images(db, image, locs, flipped=True):
input_size = db.configs["input_size"] input_size = db.configs["input_size"]
num_patches = locs.shape[0] num_patches = locs.shape[0]
images = torch.cuda.FloatTensor(num_patches, 3, input_size[0], input_size[1]).fill_(0) images = torch.zeros((num_patches, 3, input_size[0], input_size[1]), dtype=torch.float32, device='cuda')
offsets = np.zeros((num_patches, 2), dtype=np.float32) offsets = np.zeros((num_patches, 2), dtype=np.float32)
for ind, (y, x, scale) in enumerate(locs[:, :3]): for ind, (y, x, scale) in enumerate(locs[:, :3]):
crop_height = int(input_size[0] / scale) crop_height = int(input_size[0] / scale)
@@ -319,10 +319,9 @@ def cornernet_saccade_inference(db, nnet, image, decode_func=batch_decode):
num_iterations = len(att_thresholds) num_iterations = len(att_thresholds)
im_mean = torch.cuda.FloatTensor(db.mean).reshape(1, 3, 1, 1) im_mean = torch.tensor(db.mean, dtype=torch.float32, device='cuda').reshape(1, 3, 1, 1)
im_std = torch.cuda.FloatTensor(db.std).reshape(1, 3, 1, 1) im_std = torch.tensor(db.std, dtype=torch.float32, device='cuda').reshape(1, 3, 1, 1)
detections = []
height, width = image.shape[0:2] height, width = image.shape[0:2]
image = image / 255. image = image / 255.