优化修正文档检测中的部分语法
This commit is contained in:
4
object_detection/core/external/setup.py
vendored
4
object_detection/core/external/setup.py
vendored
@@ -8,12 +8,12 @@ extensions = [
|
|||||||
Extension(
|
Extension(
|
||||||
"bbox",
|
"bbox",
|
||||||
["bbox.pyx"],
|
["bbox.pyx"],
|
||||||
extra_compile_args=["-Wno-cpp", "-Wno-unused-function"]
|
extra_compile_args=[]
|
||||||
),
|
),
|
||||||
Extension(
|
Extension(
|
||||||
"nms",
|
"nms",
|
||||||
["nms.pyx"],
|
["nms.pyx"],
|
||||||
extra_compile_args=["-Wno-cpp", "-Wno-unused-function"]
|
extra_compile_args=[]
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ class NetworkFactory(object):
|
|||||||
def load_pretrained_params(self, pretrained_model):
|
def load_pretrained_params(self, pretrained_model):
|
||||||
print("loading from {}".format(pretrained_model))
|
print("loading from {}".format(pretrained_model))
|
||||||
with open(pretrained_model, "rb") as f:
|
with open(pretrained_model, "rb") as f:
|
||||||
params = torch.load(f)
|
params = torch.load(f, weights_only=False)
|
||||||
self.model.load_state_dict(params)
|
self.model.load_state_dict(params)
|
||||||
|
|
||||||
def load_params(self, iteration):
|
def load_params(self, iteration):
|
||||||
|
|||||||
@@ -233,7 +233,7 @@ def prepare_images(db, image, locs, flipped=True):
|
|||||||
input_size = db.configs["input_size"]
|
input_size = db.configs["input_size"]
|
||||||
num_patches = locs.shape[0]
|
num_patches = locs.shape[0]
|
||||||
|
|
||||||
images = torch.cuda.FloatTensor(num_patches, 3, input_size[0], input_size[1]).fill_(0)
|
images = torch.zeros((num_patches, 3, input_size[0], input_size[1]), dtype=torch.float32, device='cuda')
|
||||||
offsets = np.zeros((num_patches, 2), dtype=np.float32)
|
offsets = np.zeros((num_patches, 2), dtype=np.float32)
|
||||||
for ind, (y, x, scale) in enumerate(locs[:, :3]):
|
for ind, (y, x, scale) in enumerate(locs[:, :3]):
|
||||||
crop_height = int(input_size[0] / scale)
|
crop_height = int(input_size[0] / scale)
|
||||||
@@ -319,10 +319,9 @@ def cornernet_saccade_inference(db, nnet, image, decode_func=batch_decode):
|
|||||||
|
|
||||||
num_iterations = len(att_thresholds)
|
num_iterations = len(att_thresholds)
|
||||||
|
|
||||||
im_mean = torch.cuda.FloatTensor(db.mean).reshape(1, 3, 1, 1)
|
im_mean = torch.tensor(db.mean, dtype=torch.float32, device='cuda').reshape(1, 3, 1, 1)
|
||||||
im_std = torch.cuda.FloatTensor(db.std).reshape(1, 3, 1, 1)
|
im_std = torch.tensor(db.std, dtype=torch.float32, device='cuda').reshape(1, 3, 1, 1)
|
||||||
|
|
||||||
detections = []
|
|
||||||
height, width = image.shape[0:2]
|
height, width = image.shape[0:2]
|
||||||
|
|
||||||
image = image / 255.
|
image = image / 255.
|
||||||
|
|||||||
Reference in New Issue
Block a user