From 896d2aaf9b78980cc20036884e31918df48e3ce2 Mon Sep 17 00:00:00 2001 From: liuyebo <1515783401@qq.com> Date: Tue, 20 Aug 2024 14:27:51 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BF=AE=E6=AD=A3=E6=96=87?= =?UTF-8?q?=E6=A1=A3=E6=A3=80=E6=B5=8B=E4=B8=AD=E7=9A=84=E9=83=A8=E5=88=86?= =?UTF-8?q?=E8=AF=AD=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- object_detection/core/external/setup.py | 4 ++-- object_detection/core/nnet/py_factory.py | 2 +- object_detection/core/test/cornernet_saccade.py | 7 +++---- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/object_detection/core/external/setup.py b/object_detection/core/external/setup.py index cf55a25..f4af595 100644 --- a/object_detection/core/external/setup.py +++ b/object_detection/core/external/setup.py @@ -8,12 +8,12 @@ extensions = [ Extension( "bbox", ["bbox.pyx"], - extra_compile_args=["-Wno-cpp", "-Wno-unused-function"] + extra_compile_args=[] ), Extension( "nms", ["nms.pyx"], - extra_compile_args=["-Wno-cpp", "-Wno-unused-function"] + extra_compile_args=[] ) ] diff --git a/object_detection/core/nnet/py_factory.py b/object_detection/core/nnet/py_factory.py index ead8da0..70c6e03 100644 --- a/object_detection/core/nnet/py_factory.py +++ b/object_detection/core/nnet/py_factory.py @@ -119,7 +119,7 @@ class NetworkFactory(object): def load_pretrained_params(self, pretrained_model): print("loading from {}".format(pretrained_model)) with open(pretrained_model, "rb") as f: - params = torch.load(f) + params = torch.load(f, weights_only=False) self.model.load_state_dict(params) def load_params(self, iteration): diff --git a/object_detection/core/test/cornernet_saccade.py b/object_detection/core/test/cornernet_saccade.py index 75adc2d..75b978d 100644 --- a/object_detection/core/test/cornernet_saccade.py +++ b/object_detection/core/test/cornernet_saccade.py @@ -233,7 +233,7 @@ def prepare_images(db, image, locs, flipped=True): input_size = db.configs["input_size"] num_patches = locs.shape[0] - images = torch.cuda.FloatTensor(num_patches, 3, input_size[0], input_size[1]).fill_(0) + images = torch.zeros((num_patches, 3, input_size[0], input_size[1]), dtype=torch.float32, device='cuda') offsets = np.zeros((num_patches, 2), dtype=np.float32) for ind, (y, x, scale) in enumerate(locs[:, :3]): crop_height = int(input_size[0] / scale) @@ -319,10 +319,9 @@ def cornernet_saccade_inference(db, nnet, image, decode_func=batch_decode): num_iterations = len(att_thresholds) - im_mean = torch.cuda.FloatTensor(db.mean).reshape(1, 3, 1, 1) - im_std = torch.cuda.FloatTensor(db.std).reshape(1, 3, 1, 1) + im_mean = torch.tensor(db.mean, dtype=torch.float32, device='cuda').reshape(1, 3, 1, 1) + im_std = torch.tensor(db.std, dtype=torch.float32, device='cuda').reshape(1, 3, 1, 1) - detections = [] height, width = image.shape[0:2] image = image / 255.