移动paddle_detection
This commit is contained in:
@@ -0,0 +1,13 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -0,0 +1,343 @@
|
||||
import numpy as np
|
||||
import cv2
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
from ppdet.utils.logger import setup_logger
|
||||
import copy
|
||||
logger = setup_logger('ppdet_cam')
|
||||
|
||||
import paddle
|
||||
from ppdet.engine import Trainer
|
||||
|
||||
|
||||
def get_test_images(infer_dir, infer_img):
|
||||
"""
|
||||
Get image path list in TEST mode
|
||||
"""
|
||||
assert infer_img is not None or infer_dir is not None, \
|
||||
"--infer_img or --infer_dir should be set"
|
||||
assert infer_img is None or os.path.isfile(infer_img), \
|
||||
"{} is not a file".format(infer_img)
|
||||
assert infer_dir is None or os.path.isdir(infer_dir), \
|
||||
"{} is not a directory".format(infer_dir)
|
||||
|
||||
# infer_img has a higher priority
|
||||
if infer_img and os.path.isfile(infer_img):
|
||||
return [infer_img]
|
||||
|
||||
images = set()
|
||||
infer_dir = os.path.abspath(infer_dir)
|
||||
assert os.path.isdir(infer_dir), \
|
||||
"infer_dir {} is not a directory".format(infer_dir)
|
||||
exts = ['jpg', 'jpeg', 'png', 'bmp']
|
||||
exts += [ext.upper() for ext in exts]
|
||||
for ext in exts:
|
||||
images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
|
||||
images = list(images)
|
||||
|
||||
assert len(images) > 0, "no image found in {}".format(infer_dir)
|
||||
logger.info("Found {} inference images in total.".format(len(images)))
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def compute_ious(boxes1, boxes2):
|
||||
"""[Compute pairwise IOU matrix for given two sets of boxes]
|
||||
|
||||
Args:
|
||||
boxes1 ([numpy ndarray with shape N,4]): [representing bounding boxes with format (xmin,ymin,xmax,ymax)]
|
||||
boxes2 ([numpy ndarray with shape M,4]): [representing bounding boxes with format (xmin,ymin,xmax,ymax)]
|
||||
Returns:
|
||||
pairwise IOU maxtrix with shape (N,M),where the value at ith row jth column hold the iou between ith
|
||||
box and jth box from box1 and box2 respectively.
|
||||
"""
|
||||
lu = np.maximum(
|
||||
boxes1[:, None, :2], boxes2[:, :2]
|
||||
) # lu with shape N,M,2 ; boxes1[:,None,:2] with shape (N,1,2) boxes2 with shape(M,2)
|
||||
rd = np.minimum(boxes1[:, None, 2:], boxes2[:, 2:]) # rd same to lu
|
||||
intersection_wh = np.maximum(0.0, rd - lu)
|
||||
intersection_area = intersection_wh[:, :,
|
||||
0] * intersection_wh[:, :,
|
||||
1] # with shape (N,M)
|
||||
boxes1_wh = np.maximum(0.0, boxes1[:, 2:] - boxes1[:, :2])
|
||||
boxes1_area = boxes1_wh[:, 0] * boxes1_wh[:, 1] # with shape (N,)
|
||||
boxes2_wh = np.maximum(0.0, boxes2[:, 2:] - boxes2[:, :2])
|
||||
boxes2_area = boxes2_wh[:, 0] * boxes2_wh[:, 1] # with shape (M,)
|
||||
union_area = np.maximum(
|
||||
boxes1_area[:, None] + boxes2_area - intersection_area,
|
||||
1e-8) # with shape (N,M)
|
||||
ious = np.clip(intersection_area / union_area, 0.0, 1.0)
|
||||
return ious
|
||||
|
||||
|
||||
def grad_cam(feat, grad):
|
||||
"""
|
||||
|
||||
Args:
|
||||
feat: CxHxW
|
||||
grad: CxHxW
|
||||
|
||||
Returns:
|
||||
cam: HxW
|
||||
"""
|
||||
exp = (feat * grad.mean((1, 2), keepdims=True)).mean(axis=0)
|
||||
exp = np.maximum(-exp, 0)
|
||||
return exp
|
||||
|
||||
|
||||
def resize_cam(explanation, resize_shape) -> np.ndarray:
|
||||
"""
|
||||
|
||||
Args:
|
||||
explanation: (width, height)
|
||||
resize_shape: (width, height)
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
assert len(explanation.shape) == 2, f"{explanation.shape}. " \
|
||||
f"Currently support 2D explanation results for visualization. " \
|
||||
"Reduce higher dimensions to 2D for visualization."
|
||||
|
||||
explanation = (explanation - explanation.min()) / (
|
||||
explanation.max() - explanation.min())
|
||||
|
||||
explanation = cv2.resize(explanation, resize_shape)
|
||||
explanation = np.uint8(255 * explanation)
|
||||
explanation = cv2.applyColorMap(explanation, cv2.COLORMAP_JET)
|
||||
explanation = cv2.cvtColor(explanation, cv2.COLOR_BGR2RGB)
|
||||
|
||||
return explanation
|
||||
|
||||
|
||||
class BBoxCAM:
|
||||
def __init__(self, FLAGS, cfg):
|
||||
self.FLAGS = FLAGS
|
||||
self.cfg = cfg
|
||||
# build model
|
||||
self.trainer = self.build_trainer(cfg)
|
||||
# num_class
|
||||
self.num_class = cfg.num_classes
|
||||
# set hook for extraction of featuremaps and grads
|
||||
self.set_hook(cfg)
|
||||
self.nms_idx_need_divid_numclass_arch = ['FasterRCNN', 'MaskRCNN', 'CascadeRCNN']
|
||||
"""
|
||||
In these networks, the bbox array shape before nms contain num_class,
|
||||
the nms_keep_idx of the bbox need to divide the num_class;
|
||||
"""
|
||||
|
||||
# cam image output_dir
|
||||
try:
|
||||
os.makedirs(FLAGS.cam_out)
|
||||
except:
|
||||
print('Path already exists.')
|
||||
pass
|
||||
|
||||
def build_trainer(self, cfg):
|
||||
# build trainer
|
||||
trainer = Trainer(cfg, mode='test')
|
||||
# load weights
|
||||
trainer.load_weights(cfg.weights)
|
||||
|
||||
# set for get extra_data before nms
|
||||
trainer.model.use_extra_data=True
|
||||
# set for record the bbox index before nms
|
||||
if cfg.architecture in ['FasterRCNN', 'MaskRCNN']:
|
||||
trainer.model.bbox_post_process.nms.return_index = True
|
||||
elif cfg.architecture in ['YOLOv3', 'PPYOLOE', 'PPYOLOEWithAuxHead']:
|
||||
if trainer.model.post_process is not None:
|
||||
# anchor based YOLOs: YOLOv3,PP-YOLO
|
||||
trainer.model.post_process.nms.return_index = True
|
||||
else:
|
||||
# anchor free YOLOs: PP-YOLOE, PP-YOLOE+
|
||||
trainer.model.yolo_head.nms.return_index = True
|
||||
elif cfg.architecture=='BlazeFace' or cfg.architecture=='SSD':
|
||||
trainer.model.post_process.nms.return_index = True
|
||||
elif cfg.architecture=='RetinaNet':
|
||||
trainer.model.head.nms.return_index = True
|
||||
else:
|
||||
print(
|
||||
cfg.architecture+' is not supported for cam temporarily!'
|
||||
)
|
||||
sys.exit()
|
||||
# Todo: Unify the head/post_process name in each model
|
||||
|
||||
return trainer
|
||||
|
||||
def set_hook(self, cfg):
|
||||
# set hook for extraction of featuremaps and grads
|
||||
self.target_feats = {}
|
||||
self.target_layer_name = cfg.target_feature_layer_name
|
||||
# such as trainer.model.backbone, trainer.model.bbox_head.roi_extractor
|
||||
|
||||
def hook(layer, input, output):
|
||||
self.target_feats[layer._layer_name_for_hook] = output
|
||||
|
||||
try:
|
||||
exec('self.trainer.'+self.target_layer_name+'._layer_name_for_hook = self.target_layer_name')
|
||||
# self.trainer.target_layer_name._layer_name_for_hook = self.target_layer_name
|
||||
exec('self.trainer.'+self.target_layer_name+'.register_forward_post_hook(hook)')
|
||||
# self.trainer.target_layer_name.register_forward_post_hook(hook)
|
||||
except:
|
||||
print("Error! "
|
||||
"The target_layer_name--"+self.target_layer_name+" is not in model! "
|
||||
"Please check the spelling and "
|
||||
"the network's architecture!")
|
||||
sys.exit()
|
||||
|
||||
def get_bboxes(self):
|
||||
# get inference images
|
||||
images = get_test_images(self.FLAGS.infer_dir, self.FLAGS.infer_img)
|
||||
|
||||
# inference
|
||||
result = self.trainer.predict(
|
||||
images,
|
||||
draw_threshold=self.FLAGS.draw_threshold,
|
||||
output_dir=self.FLAGS.output_dir,
|
||||
save_results=self.FLAGS.save_results,
|
||||
visualize=False)[0]
|
||||
return result
|
||||
|
||||
def get_bboxes_cams(self):
|
||||
# Get the bboxes prediction(after nms result) of the input
|
||||
inference_result = self.get_bboxes()
|
||||
|
||||
# read input image
|
||||
# Todo: Support folder multi-images process
|
||||
from PIL import Image
|
||||
img = np.array(Image.open(self.cfg.infer_img))
|
||||
|
||||
# data for calaulate bbox grad_cam
|
||||
extra_data = inference_result['extra_data']
|
||||
"""
|
||||
Example of Faster_RCNN based architecture:
|
||||
extra_data: {'scores': tensor with shape [num_of_bboxes_before_nms, num_classes], for example: [1000, 80]
|
||||
'nms_keep_idx': tensor with shape [num_of_bboxes_after_nms, 1], for example: [300, 1]
|
||||
}
|
||||
Example of YOLOv3 based architecture:
|
||||
extra_data: {'scores': tensor with shape [1, num_classes, num_of_yolo_bboxes_before_nms], #for example: [1, 80, 8400]
|
||||
'nms_keep_idx': tensor with shape [num_of_yolo_bboxes_after_nms, 1], # for example: [300, 1]
|
||||
}
|
||||
"""
|
||||
|
||||
# array index of the predicted bbox before nms
|
||||
if self.cfg.architecture in self.nms_idx_need_divid_numclass_arch:
|
||||
# some network's bbox array shape before nms may be like [num_of_bboxes_before_nms, num_classes, 4],
|
||||
# we need to divide num_classes to get the before_nms_index;
|
||||
# currently, only include the rcnn architectures (fasterrcnn, maskrcnn, cascadercnn);
|
||||
before_nms_indexes = extra_data['nms_keep_idx'].cpu().numpy(
|
||||
) // self.num_class # num_class
|
||||
else :
|
||||
before_nms_indexes = extra_data['nms_keep_idx'].cpu().numpy()
|
||||
|
||||
# Calculate and visualize the heatmap of per predict bbox
|
||||
for index, target_bbox in enumerate(inference_result['bbox']):
|
||||
# target_bbox: [cls, score, x1, y1, x2, y2]
|
||||
# filter bboxes with low predicted scores
|
||||
if target_bbox[1] < self.FLAGS.draw_threshold:
|
||||
continue
|
||||
|
||||
target_bbox_before_nms = int(before_nms_indexes[index])
|
||||
|
||||
if len(extra_data['scores'].shape)==2:
|
||||
score_out = extra_data['scores'][target_bbox_before_nms]
|
||||
else:
|
||||
score_out = extra_data['scores'][0, :, target_bbox_before_nms]
|
||||
"""
|
||||
There are two kinds array shape of bbox score output :
|
||||
1) [num_of_bboxes_before_nms, num_classes], for example: [1000, 80]
|
||||
2) [num_of_image, num_classes, num_of_yolo_bboxes_before_nms], for example: [1, 80, 1000]
|
||||
"""
|
||||
|
||||
|
||||
# construct one_hot label and do backward to get the gradients
|
||||
predicted_label = paddle.argmax(score_out)
|
||||
label_onehot = paddle.nn.functional.one_hot(
|
||||
predicted_label, num_classes=len(score_out))
|
||||
label_onehot = label_onehot.squeeze()
|
||||
target = paddle.sum(score_out * label_onehot)
|
||||
target.backward(retain_graph=True)
|
||||
|
||||
|
||||
if 'backbone' in self.target_layer_name or \
|
||||
'neck' in self.target_layer_name: # backbone/neck level feature
|
||||
if isinstance(self.target_feats[self.target_layer_name], list):
|
||||
# when the featuremap contains of multiple scales,
|
||||
# take the featuremap of the last scale
|
||||
# Todo: fuse the cam result from multisclae featuremaps
|
||||
if self.target_feats[self.target_layer_name][
|
||||
-1].shape[-1]==1:
|
||||
"""
|
||||
if the last level featuremap is 1x1 size,
|
||||
we take the second last one
|
||||
"""
|
||||
cam_grad = self.target_feats[self.target_layer_name][
|
||||
-2].grad.squeeze().cpu().numpy()
|
||||
cam_feat = self.target_feats[self.target_layer_name][
|
||||
-2].squeeze().cpu().numpy()
|
||||
else:
|
||||
cam_grad = self.target_feats[self.target_layer_name][
|
||||
-1].grad.squeeze().cpu().numpy()
|
||||
cam_feat = self.target_feats[self.target_layer_name][
|
||||
-1].squeeze().cpu().numpy()
|
||||
else:
|
||||
cam_grad = self.target_feats[
|
||||
self.target_layer_name].grad.squeeze().cpu().numpy()
|
||||
cam_feat = self.target_feats[
|
||||
self.target_layer_name].squeeze().cpu().numpy()
|
||||
else: # roi level feature
|
||||
cam_grad = self.target_feats[
|
||||
self.target_layer_name].grad.squeeze().cpu().numpy()[target_bbox_before_nms]
|
||||
cam_feat = self.target_feats[
|
||||
self.target_layer_name].squeeze().cpu().numpy()[target_bbox_before_nms]
|
||||
|
||||
# grad_cam:
|
||||
exp = grad_cam(cam_feat, cam_grad)
|
||||
|
||||
if 'backbone' in self.target_layer_name or \
|
||||
'neck' in self.target_layer_name:
|
||||
"""
|
||||
when use backbone/neck featuremap,
|
||||
we first do the cam on whole image,
|
||||
and then set the area outside the predic bbox to 0
|
||||
"""
|
||||
# reshape the cam image to the input image size
|
||||
resized_exp = resize_cam(exp, (img.shape[1], img.shape[0]))
|
||||
mask = np.zeros((img.shape[0], img.shape[1], 3))
|
||||
mask[int(target_bbox[3]):int(target_bbox[5]), int(target_bbox[2]):
|
||||
int(target_bbox[4]), :] = 1
|
||||
resized_exp = resized_exp * mask
|
||||
# add the bbox cam back to the input image
|
||||
overlay_vis = np.uint8(resized_exp * 0.4 + img * 0.6)
|
||||
elif 'roi' in self.target_layer_name:
|
||||
# get the bbox part of the image
|
||||
bbox_img = copy.deepcopy(img[int(target_bbox[3]):int(target_bbox[5]),
|
||||
int(target_bbox[2]):int(target_bbox[4]), :])
|
||||
# reshape the cam image to the bbox size
|
||||
resized_exp = resize_cam(exp, (bbox_img.shape[1], bbox_img.shape[0]))
|
||||
# add the bbox cam back to the bbox image
|
||||
bbox_overlay_vis = np.uint8(resized_exp * 0.4 + bbox_img * 0.6)
|
||||
# put the bbox_cam image to the original image
|
||||
overlay_vis = copy.deepcopy(img)
|
||||
overlay_vis[int(target_bbox[3]):int(target_bbox[5]),
|
||||
int(target_bbox[2]):int(target_bbox[4]), :] = bbox_overlay_vis
|
||||
else:
|
||||
print(
|
||||
'Only supported cam for backbone/neck feature and roi feature, the others are not supported temporarily!'
|
||||
)
|
||||
sys.exit()
|
||||
|
||||
# put the bbox rectangle on image
|
||||
cv2.rectangle(
|
||||
overlay_vis, (int(target_bbox[2]), int(target_bbox[3])),
|
||||
(int(target_bbox[4]), int(target_bbox[5])), (0, 0, 255), 2)
|
||||
|
||||
# save visualization result
|
||||
cam_image = Image.fromarray(overlay_vis)
|
||||
cam_image.save(self.FLAGS.cam_out + '/' + str(index) + '.jpg')
|
||||
|
||||
# clear gradients after each bbox grad_cam
|
||||
target.clear_gradient()
|
||||
for n, v in self.trainer.model.named_sublayers():
|
||||
v.clear_gradients()
|
||||
156
services/paddle_services/paddle_detection/ppdet/utils/check.py
Normal file
156
services/paddle_services/paddle_detection/ppdet/utils/check.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
import paddle
|
||||
import six
|
||||
import paddle.version as paddle_version
|
||||
|
||||
from .logger import setup_logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
__all__ = [
|
||||
'check_gpu', 'check_npu', 'check_xpu', 'check_mlu', 'check_version',
|
||||
'check_config'
|
||||
]
|
||||
|
||||
|
||||
def check_mlu(use_mlu):
|
||||
"""
|
||||
Log error and exit when set use_mlu=true in paddlepaddle
|
||||
cpu/gpu/xpu/npu version.
|
||||
"""
|
||||
err = "Config use_mlu cannot be set as true while you are " \
|
||||
"using paddlepaddle cpu/gpu/xpu/npu version ! \nPlease try: \n" \
|
||||
"\t1. Install paddlepaddle-mlu to run model on MLU \n" \
|
||||
"\t2. Set use_mlu as false in config file to run " \
|
||||
"model on CPU/GPU/XPU/NPU"
|
||||
|
||||
try:
|
||||
if use_mlu and not paddle.is_compiled_with_mlu():
|
||||
logger.error(err)
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
def check_npu(use_npu):
|
||||
"""
|
||||
Log error and exit when set use_npu=true in paddlepaddle
|
||||
version without paddle-custom-npu installed.
|
||||
"""
|
||||
err = "Config use_npu cannot be set as true while you are " \
|
||||
"using paddlepaddle version without paddle-custom-npu " \
|
||||
"installed! \nPlease try: \n" \
|
||||
"\t1. Install paddle-custom-npu to run model on NPU \n" \
|
||||
"\t2. Set use_npu as false in config file to run " \
|
||||
"model on other devices supported."
|
||||
|
||||
try:
|
||||
if use_npu and not 'npu' in paddle.device.get_all_custom_device_type():
|
||||
logger.error(err)
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
def check_xpu(use_xpu):
|
||||
"""
|
||||
Log error and exit when set use_xpu=true in paddlepaddle
|
||||
cpu/gpu/npu version.
|
||||
"""
|
||||
err = "Config use_xpu cannot be set as true while you are " \
|
||||
"using paddlepaddle cpu/gpu/npu version ! \nPlease try: \n" \
|
||||
"\t1. Install paddlepaddle-xpu to run model on XPU \n" \
|
||||
"\t2. Set use_xpu as false in config file to run " \
|
||||
"model on CPU/GPU/NPU"
|
||||
|
||||
try:
|
||||
if use_xpu and not paddle.is_compiled_with_xpu():
|
||||
logger.error(err)
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
def check_gpu(use_gpu):
|
||||
"""
|
||||
Log error and exit when set use_gpu=true in paddlepaddle
|
||||
cpu version.
|
||||
"""
|
||||
err = "Config use_gpu cannot be set as true while you are " \
|
||||
"using paddlepaddle cpu version ! \nPlease try: \n" \
|
||||
"\t1. Install paddlepaddle-gpu to run model on GPU \n" \
|
||||
"\t2. Set use_gpu as false in config file to run " \
|
||||
"model on CPU"
|
||||
|
||||
try:
|
||||
if use_gpu and not paddle.is_compiled_with_cuda():
|
||||
logger.error(err)
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
def check_version(version='2.2'):
|
||||
"""
|
||||
Log error and exit when the installed version of paddlepaddle is
|
||||
not satisfied.
|
||||
"""
|
||||
err = "PaddlePaddle version {} or higher is required, " \
|
||||
"or a suitable develop version is satisfied as well. \n" \
|
||||
"Please make sure the version is good with your code.".format(version)
|
||||
|
||||
version_installed = [
|
||||
paddle_version.major, paddle_version.minor, paddle_version.patch,
|
||||
paddle_version.rc
|
||||
]
|
||||
|
||||
if version_installed == ['0', '0', '0', '0']:
|
||||
return
|
||||
|
||||
version_split = version.split('.')
|
||||
|
||||
length = min(len(version_installed), len(version_split))
|
||||
for i in six.moves.range(length):
|
||||
if version_installed[i] > version_split[i]:
|
||||
return
|
||||
if version_installed[i] < version_split[i]:
|
||||
raise Exception(err)
|
||||
|
||||
|
||||
def check_config(cfg):
|
||||
"""
|
||||
Check the correctness of the configuration file. Log error and exit
|
||||
when Config is not compliant.
|
||||
"""
|
||||
err = "'{}' not specified in config file. Please set it in config file."
|
||||
check_list = ['architecture', 'num_classes']
|
||||
try:
|
||||
for var in check_list:
|
||||
if not var in cfg:
|
||||
logger.error(err.format(var))
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
if 'log_iter' not in cfg:
|
||||
cfg.log_iter = 20
|
||||
|
||||
return cfg
|
||||
@@ -0,0 +1,377 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
from .download import get_weights_path
|
||||
|
||||
from .logger import setup_logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
def is_url(path):
|
||||
"""
|
||||
Whether path is URL.
|
||||
Args:
|
||||
path (string): URL string or not.
|
||||
"""
|
||||
return path.startswith('http://') \
|
||||
or path.startswith('https://') \
|
||||
or path.startswith('ppdet://')
|
||||
|
||||
|
||||
def _strip_postfix(path):
|
||||
path, ext = os.path.splitext(path)
|
||||
assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \
|
||||
"Unknown postfix {} from weights".format(ext)
|
||||
return path
|
||||
|
||||
|
||||
def load_weight(model, weight, optimizer=None, ema=None, exchange=True):
|
||||
if is_url(weight):
|
||||
weight = get_weights_path(weight)
|
||||
|
||||
path = _strip_postfix(weight)
|
||||
pdparam_path = path + '.pdparams'
|
||||
if not os.path.exists(pdparam_path):
|
||||
raise ValueError("Model pretrain path {} does not "
|
||||
"exists.".format(pdparam_path))
|
||||
|
||||
if ema is not None and os.path.exists(path + '.pdema'):
|
||||
if exchange:
|
||||
# Exchange model and ema_model to load
|
||||
logger.info('Exchange model and ema_model to load:')
|
||||
ema_state_dict = paddle.load(pdparam_path)
|
||||
logger.info('Loading ema_model weights from {}'.format(path +
|
||||
'.pdparams'))
|
||||
param_state_dict = paddle.load(path + '.pdema')
|
||||
logger.info('Loading model weights from {}'.format(path + '.pdema'))
|
||||
else:
|
||||
ema_state_dict = paddle.load(path + '.pdema')
|
||||
logger.info('Loading ema_model weights from {}'.format(path +
|
||||
'.pdema'))
|
||||
param_state_dict = paddle.load(pdparam_path)
|
||||
logger.info('Loading model weights from {}'.format(path +
|
||||
'.pdparams'))
|
||||
else:
|
||||
ema_state_dict = None
|
||||
param_state_dict = paddle.load(pdparam_path)
|
||||
|
||||
if hasattr(model, 'modelTeacher') and hasattr(model, 'modelStudent'):
|
||||
print('Loading pretrain weights for Teacher-Student framework.')
|
||||
print('Loading pretrain weights for Student model.')
|
||||
student_model_dict = model.modelStudent.state_dict()
|
||||
student_param_state_dict = match_state_dict(
|
||||
student_model_dict, param_state_dict, mode='student')
|
||||
model.modelStudent.set_dict(student_param_state_dict)
|
||||
print('Loading pretrain weights for Teacher model.')
|
||||
teacher_model_dict = model.modelTeacher.state_dict()
|
||||
|
||||
teacher_param_state_dict = match_state_dict(
|
||||
teacher_model_dict, param_state_dict, mode='teacher')
|
||||
model.modelTeacher.set_dict(teacher_param_state_dict)
|
||||
|
||||
else:
|
||||
model_dict = model.state_dict()
|
||||
model_weight = {}
|
||||
incorrect_keys = 0
|
||||
for key in model_dict.keys():
|
||||
if key in param_state_dict.keys():
|
||||
model_weight[key] = param_state_dict[key]
|
||||
else:
|
||||
logger.info('Unmatched key: {}'.format(key))
|
||||
incorrect_keys += 1
|
||||
assert incorrect_keys == 0, "Load weight {} incorrectly, \
|
||||
{} keys unmatched, please check again.".format(weight,
|
||||
incorrect_keys)
|
||||
logger.info('Finish resuming model weights: {}'.format(pdparam_path))
|
||||
model.set_dict(model_weight)
|
||||
|
||||
last_epoch = 0
|
||||
if optimizer is not None and os.path.exists(path + '.pdopt'):
|
||||
optim_state_dict = paddle.load(path + '.pdopt')
|
||||
# to solve resume bug, will it be fixed in paddle 2.0
|
||||
for key in optimizer.state_dict().keys():
|
||||
if not key in optim_state_dict.keys():
|
||||
optim_state_dict[key] = optimizer.state_dict()[key]
|
||||
if 'last_epoch' in optim_state_dict:
|
||||
last_epoch = optim_state_dict.pop('last_epoch')
|
||||
optimizer.set_state_dict(optim_state_dict)
|
||||
|
||||
if ema_state_dict is not None:
|
||||
ema.resume(ema_state_dict,
|
||||
optim_state_dict['LR_Scheduler']['last_epoch'])
|
||||
elif ema_state_dict is not None:
|
||||
ema.resume(ema_state_dict)
|
||||
return last_epoch
|
||||
|
||||
|
||||
def match_state_dict(model_state_dict, weight_state_dict, mode='default'):
|
||||
"""
|
||||
Match between the model state dict and pretrained weight state dict.
|
||||
Return the matched state dict.
|
||||
|
||||
The method supposes that all the names in pretrained weight state dict are
|
||||
subclass of the names in models`, if the prefix 'backbone.' in pretrained weight
|
||||
keys is stripped. And we could get the candidates for each model key. Then we
|
||||
select the name with the longest matched size as the final match result. For
|
||||
example, the model state dict has the name of
|
||||
'backbone.res2.res2a.branch2a.conv.weight' and the pretrained weight as
|
||||
name of 'res2.res2a.branch2a.conv.weight' and 'branch2a.conv.weight'. We
|
||||
match the 'res2.res2a.branch2a.conv.weight' to the model key.
|
||||
"""
|
||||
|
||||
model_keys = sorted(model_state_dict.keys())
|
||||
weight_keys = sorted(weight_state_dict.keys())
|
||||
|
||||
def teacher_match(a, b):
|
||||
# skip student params
|
||||
if b.startswith('modelStudent'):
|
||||
return False
|
||||
return a == b or a.endswith("." + b) or b.endswith("." + a)
|
||||
|
||||
def student_match(a, b):
|
||||
# skip teacher params
|
||||
if b.startswith('modelTeacher'):
|
||||
return False
|
||||
return a == b or a.endswith("." + b) or b.endswith("." + a)
|
||||
|
||||
def match(a, b):
|
||||
if b.startswith('backbone.res5'):
|
||||
b = b[9:]
|
||||
return a == b or a.endswith("." + b)
|
||||
|
||||
if mode == 'student':
|
||||
match_op = student_match
|
||||
elif mode == 'teacher':
|
||||
match_op = teacher_match
|
||||
else:
|
||||
match_op = match
|
||||
|
||||
match_matrix = np.zeros([len(model_keys), len(weight_keys)])
|
||||
for i, m_k in enumerate(model_keys):
|
||||
for j, w_k in enumerate(weight_keys):
|
||||
if match_op(m_k, w_k):
|
||||
match_matrix[i, j] = len(w_k)
|
||||
max_id = match_matrix.argmax(1)
|
||||
max_len = match_matrix.max(1)
|
||||
max_id[max_len == 0] = -1
|
||||
load_id = set(max_id)
|
||||
load_id.discard(-1)
|
||||
not_load_weight_name = []
|
||||
if weight_keys[0].startswith('modelStudent') or weight_keys[0].startswith(
|
||||
'modelTeacher'):
|
||||
for match_idx in range(len(max_id)):
|
||||
if max_id[match_idx] == -1:
|
||||
not_load_weight_name.append(model_keys[match_idx])
|
||||
if len(not_load_weight_name) > 0:
|
||||
logger.info('{} in model is not matched with pretrained weights, '
|
||||
'and its will be trained from scratch'.format(
|
||||
not_load_weight_name))
|
||||
|
||||
else:
|
||||
for idx in range(len(weight_keys)):
|
||||
if idx not in load_id:
|
||||
not_load_weight_name.append(weight_keys[idx])
|
||||
|
||||
if len(not_load_weight_name) > 0:
|
||||
logger.info('{} in pretrained weight is not used in the model, '
|
||||
'and its will not be loaded'.format(
|
||||
not_load_weight_name))
|
||||
matched_keys = {}
|
||||
result_state_dict = {}
|
||||
for model_id, weight_id in enumerate(max_id):
|
||||
if weight_id == -1:
|
||||
continue
|
||||
model_key = model_keys[model_id]
|
||||
weight_key = weight_keys[weight_id]
|
||||
weight_value = weight_state_dict[weight_key]
|
||||
model_value_shape = list(model_state_dict[model_key].shape)
|
||||
|
||||
if list(weight_value.shape) != model_value_shape:
|
||||
logger.info(
|
||||
'The shape {} in pretrained weight {} is unmatched with '
|
||||
'the shape {} in model {}. And the weight {} will not be '
|
||||
'loaded'.format(weight_value.shape, weight_key,
|
||||
model_value_shape, model_key, weight_key))
|
||||
continue
|
||||
|
||||
assert model_key not in result_state_dict
|
||||
result_state_dict[model_key] = weight_value
|
||||
if weight_key in matched_keys:
|
||||
raise ValueError('Ambiguity weight {} loaded, it matches at least '
|
||||
'{} and {} in the model'.format(
|
||||
weight_key, model_key, matched_keys[
|
||||
weight_key]))
|
||||
matched_keys[weight_key] = model_key
|
||||
return result_state_dict
|
||||
|
||||
|
||||
def load_pretrain_weight(model, pretrain_weight, ARSL_eval=False):
|
||||
if is_url(pretrain_weight):
|
||||
pretrain_weight = get_weights_path(pretrain_weight)
|
||||
|
||||
path = _strip_postfix(pretrain_weight)
|
||||
if not (os.path.isdir(path) or os.path.isfile(path) or
|
||||
os.path.exists(path + '.pdparams')):
|
||||
raise ValueError("Model pretrain path `{}` does not exists. "
|
||||
"If you don't want to load pretrain model, "
|
||||
"please delete `pretrain_weights` field in "
|
||||
"config file.".format(path))
|
||||
teacher_student_flag = False
|
||||
if not ARSL_eval:
|
||||
if hasattr(model, 'modelTeacher') and hasattr(model, 'modelStudent'):
|
||||
print('Loading pretrain weights for Teacher-Student framework.')
|
||||
print(
|
||||
'Assert Teacher model has the same structure with Student model.'
|
||||
)
|
||||
model_dict = model.modelStudent.state_dict()
|
||||
teacher_student_flag = True
|
||||
else:
|
||||
model_dict = model.state_dict()
|
||||
|
||||
weights_path = path + '.pdparams'
|
||||
param_state_dict = paddle.load(weights_path)
|
||||
param_state_dict = match_state_dict(model_dict, param_state_dict)
|
||||
for k, v in param_state_dict.items():
|
||||
if isinstance(v, np.ndarray):
|
||||
v = paddle.to_tensor(v)
|
||||
if model_dict[k].dtype != v.dtype:
|
||||
param_state_dict[k] = v.astype(model_dict[k].dtype)
|
||||
|
||||
if teacher_student_flag:
|
||||
model.modelStudent.set_dict(param_state_dict)
|
||||
model.modelTeacher.set_dict(param_state_dict)
|
||||
else:
|
||||
model.set_dict(param_state_dict)
|
||||
logger.info('Finish loading model weights: {}'.format(weights_path))
|
||||
|
||||
else:
|
||||
weights_path = path + '.pdparams'
|
||||
param_state_dict = paddle.load(weights_path)
|
||||
student_model_dict = model.modelStudent.state_dict()
|
||||
student_param_state_dict = match_state_dict(
|
||||
student_model_dict, param_state_dict, mode='student')
|
||||
model.modelStudent.set_dict(student_param_state_dict)
|
||||
print('Loading pretrain weights for Teacher model.')
|
||||
teacher_model_dict = model.modelTeacher.state_dict()
|
||||
|
||||
teacher_param_state_dict = match_state_dict(
|
||||
teacher_model_dict, param_state_dict, mode='teacher')
|
||||
model.modelTeacher.set_dict(teacher_param_state_dict)
|
||||
logger.info('Finish loading model weights: {}'.format(weights_path))
|
||||
|
||||
|
||||
def save_model(model,
|
||||
optimizer,
|
||||
save_dir,
|
||||
save_name,
|
||||
last_epoch,
|
||||
ema_model=None):
|
||||
"""
|
||||
save model into disk.
|
||||
|
||||
Args:
|
||||
model (dict): the model state_dict to save parameters.
|
||||
optimizer (paddle.optimizer.Optimizer): the Optimizer instance to
|
||||
save optimizer states.
|
||||
save_dir (str): the directory to be saved.
|
||||
save_name (str): the path to be saved.
|
||||
last_epoch (int): the epoch index.
|
||||
ema_model (dict|None): the ema_model state_dict to save parameters.
|
||||
"""
|
||||
if paddle.distributed.get_rank() != 0:
|
||||
return
|
||||
|
||||
save_dir = os.path.normpath(save_dir)
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
|
||||
if save_name == "best_model":
|
||||
best_model_path = os.path.join(save_dir, 'best_model')
|
||||
if not os.path.exists(best_model_path):
|
||||
os.makedirs(best_model_path)
|
||||
|
||||
save_path = os.path.join(save_dir, save_name)
|
||||
# save model
|
||||
if isinstance(model, nn.Layer):
|
||||
paddle.save(model.state_dict(), save_path + ".pdparams")
|
||||
best_model = model.state_dict()
|
||||
else:
|
||||
assert isinstance(model,
|
||||
dict), 'model is not a instance of nn.layer or dict'
|
||||
if ema_model is None:
|
||||
paddle.save(model, save_path + ".pdparams")
|
||||
best_model = model
|
||||
else:
|
||||
assert isinstance(ema_model,
|
||||
dict), ("ema_model is not a instance of dict, "
|
||||
"please call model.state_dict() to get.")
|
||||
# Exchange model and ema_model to save
|
||||
paddle.save(ema_model, save_path + ".pdparams")
|
||||
paddle.save(model, save_path + ".pdema")
|
||||
best_model = ema_model
|
||||
|
||||
if save_name == 'best_model':
|
||||
best_model_path = os.path.join(best_model_path, 'model')
|
||||
paddle.save(best_model, best_model_path + ".pdparams")
|
||||
# save optimizer
|
||||
state_dict = optimizer.state_dict()
|
||||
state_dict['last_epoch'] = last_epoch
|
||||
paddle.save(state_dict, save_path + ".pdopt")
|
||||
logger.info("Save checkpoint: {}".format(save_dir))
|
||||
|
||||
|
||||
def save_semi_model(teacher_model, student_model, optimizer, save_dir,
|
||||
save_name, last_epoch, last_iter):
|
||||
"""
|
||||
save teacher and student model into disk.
|
||||
Args:
|
||||
teacher_model (dict): the teacher_model state_dict to save parameters.
|
||||
student_model (dict): the student_model state_dict to save parameters.
|
||||
optimizer (paddle.optimizer.Optimizer): the Optimizer instance to
|
||||
save optimizer states.
|
||||
save_dir (str): the directory to be saved.
|
||||
save_name (str): the path to be saved.
|
||||
last_epoch (int): the epoch index.
|
||||
last_iter (int): the iter index.
|
||||
"""
|
||||
if paddle.distributed.get_rank() != 0:
|
||||
return
|
||||
assert isinstance(teacher_model, dict), (
|
||||
"teacher_model is not a instance of dict, "
|
||||
"please call teacher_model.state_dict() to get.")
|
||||
assert isinstance(student_model, dict), (
|
||||
"student_model is not a instance of dict, "
|
||||
"please call student_model.state_dict() to get.")
|
||||
if not os.path.exists(save_dir):
|
||||
os.makedirs(save_dir)
|
||||
save_path = os.path.join(save_dir, save_name)
|
||||
# save model
|
||||
paddle.save(teacher_model, save_path + str(last_epoch) + "epoch_t.pdparams")
|
||||
paddle.save(student_model, save_path + str(last_epoch) + "epoch_s.pdparams")
|
||||
|
||||
# save optimizer
|
||||
state_dict = optimizer.state_dict()
|
||||
state_dict['last_epoch'] = last_epoch
|
||||
state_dict['last_iter'] = last_iter
|
||||
paddle.save(state_dict, save_path + str(last_epoch) + "epoch.pdopt")
|
||||
logger.info("Save checkpoint: {}".format(save_dir))
|
||||
158
services/paddle_services/paddle_detection/ppdet/utils/cli.py
Normal file
158
services/paddle_services/paddle_detection/ppdet/utils/cli.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from argparse import ArgumentParser, RawDescriptionHelpFormatter
|
||||
|
||||
import yaml
|
||||
import re
|
||||
from ppdet.core.workspace import get_registered_modules, dump_value
|
||||
|
||||
__all__ = ['ColorTTY', 'ArgsParser']
|
||||
|
||||
|
||||
class ColorTTY(object):
|
||||
def __init__(self):
|
||||
super(ColorTTY, self).__init__()
|
||||
self.colors = ['red', 'green', 'yellow', 'blue', 'magenta', 'cyan']
|
||||
|
||||
def __getattr__(self, attr):
|
||||
if attr in self.colors:
|
||||
color = self.colors.index(attr) + 31
|
||||
|
||||
def color_message(message):
|
||||
return "[{}m{}[0m".format(color, message)
|
||||
|
||||
setattr(self, attr, color_message)
|
||||
return color_message
|
||||
|
||||
def bold(self, message):
|
||||
return self.with_code('01', message)
|
||||
|
||||
def with_code(self, code, message):
|
||||
return "[{}m{}[0m".format(code, message)
|
||||
|
||||
|
||||
class ArgsParser(ArgumentParser):
|
||||
def __init__(self):
|
||||
super(ArgsParser, self).__init__(
|
||||
formatter_class=RawDescriptionHelpFormatter)
|
||||
self.add_argument("-c", "--config", help="configuration file to use")
|
||||
self.add_argument(
|
||||
"-o", "--opt", nargs='*', help="set configuration options")
|
||||
|
||||
def parse_args(self, argv=None):
|
||||
args = super(ArgsParser, self).parse_args(argv)
|
||||
assert args.config is not None, \
|
||||
"Please specify --config=configure_file_path."
|
||||
args.opt = self._parse_opt(args.opt)
|
||||
return args
|
||||
|
||||
def _parse_opt(self, opts):
|
||||
config = {}
|
||||
if not opts:
|
||||
return config
|
||||
for s in opts:
|
||||
s = s.strip()
|
||||
k, v = s.split('=', 1)
|
||||
if '.' not in k:
|
||||
config[k] = yaml.load(v, Loader=yaml.Loader)
|
||||
else:
|
||||
keys = k.split('.')
|
||||
if keys[0] not in config:
|
||||
config[keys[0]] = {}
|
||||
cur = config[keys[0]]
|
||||
for idx, key in enumerate(keys[1:]):
|
||||
if idx == len(keys) - 2:
|
||||
cur[key] = yaml.load(v, Loader=yaml.Loader)
|
||||
else:
|
||||
cur[key] = {}
|
||||
cur = cur[key]
|
||||
return config
|
||||
|
||||
|
||||
def merge_args(config, args, exclude_args=['config', 'opt', 'slim_config']):
|
||||
for k, v in vars(args).items():
|
||||
if k not in exclude_args:
|
||||
config[k] = v
|
||||
return config
|
||||
|
||||
|
||||
def print_total_cfg(config):
|
||||
modules = get_registered_modules()
|
||||
color_tty = ColorTTY()
|
||||
green = '___{}___'.format(color_tty.colors.index('green') + 31)
|
||||
|
||||
styled = {}
|
||||
for key in config.keys():
|
||||
if not config[key]: # empty schema
|
||||
continue
|
||||
|
||||
if key not in modules and not hasattr(config[key], '__dict__'):
|
||||
styled[key] = config[key]
|
||||
continue
|
||||
elif key in modules:
|
||||
module = modules[key]
|
||||
else:
|
||||
type_name = type(config[key]).__name__
|
||||
if type_name in modules:
|
||||
module = modules[type_name].copy()
|
||||
module.update({
|
||||
k: v
|
||||
for k, v in config[key].__dict__.items()
|
||||
if k in module.schema
|
||||
})
|
||||
key += " ({})".format(type_name)
|
||||
default = module.find_default_keys()
|
||||
missing = module.find_missing_keys()
|
||||
mismatch = module.find_mismatch_keys()
|
||||
extra = module.find_extra_keys()
|
||||
dep_missing = []
|
||||
for dep in module.inject:
|
||||
if isinstance(module[dep], str) and module[dep] != '<value>':
|
||||
if module[dep] not in modules: # not a valid module
|
||||
dep_missing.append(dep)
|
||||
else:
|
||||
dep_mod = modules[module[dep]]
|
||||
# empty dict but mandatory
|
||||
if not dep_mod and dep_mod.mandatory():
|
||||
dep_missing.append(dep)
|
||||
override = list(
|
||||
set(module.keys()) - set(default) - set(extra) - set(dep_missing))
|
||||
replacement = {}
|
||||
for name in set(override + default + extra + mismatch + missing):
|
||||
new_name = name
|
||||
if name in missing:
|
||||
value = "<missing>"
|
||||
else:
|
||||
value = module[name]
|
||||
|
||||
if name in extra:
|
||||
value = dump_value(value) + " <extraneous>"
|
||||
elif name in mismatch:
|
||||
value = dump_value(value) + " <type mismatch>"
|
||||
elif name in dep_missing:
|
||||
value = dump_value(value) + " <module config missing>"
|
||||
elif name in override and value != '<missing>':
|
||||
mark = green
|
||||
new_name = mark + name
|
||||
replacement[new_name] = value
|
||||
styled[key] = replacement
|
||||
buffer = yaml.dump(styled, default_flow_style=False, default_style='')
|
||||
buffer = (re.sub(r"<missing>", r"[31m<missing>[0m", buffer))
|
||||
buffer = (re.sub(r"<extraneous>", r"[33m<extraneous>[0m", buffer))
|
||||
buffer = (re.sub(r"<type mismatch>", r"[31m<type mismatch>[0m", buffer))
|
||||
buffer = (re.sub(r"<module config missing>",
|
||||
r"[31m<module config missing>[0m", buffer))
|
||||
buffer = re.sub(r"___(\d+)___(.*?):", r"[\1m\2[0m:", buffer)
|
||||
print(buffer)
|
||||
@@ -0,0 +1,58 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def colormap(rgb=False):
|
||||
"""
|
||||
Get colormap
|
||||
|
||||
The code of this function is copied from https://github.com/facebookresearch/Detectron/blob/main/detectron/utils/colormap.py
|
||||
"""
|
||||
color_list = np.array([
|
||||
0.000, 0.447, 0.741, 0.850, 0.325, 0.098, 0.929, 0.694, 0.125, 0.494,
|
||||
0.184, 0.556, 0.466, 0.674, 0.188, 0.301, 0.745, 0.933, 0.635, 0.078,
|
||||
0.184, 0.300, 0.300, 0.300, 0.600, 0.600, 0.600, 1.000, 0.000, 0.000,
|
||||
1.000, 0.500, 0.000, 0.749, 0.749, 0.000, 0.000, 1.000, 0.000, 0.000,
|
||||
0.000, 1.000, 0.667, 0.000, 1.000, 0.333, 0.333, 0.000, 0.333, 0.667,
|
||||
0.000, 0.333, 1.000, 0.000, 0.667, 0.333, 0.000, 0.667, 0.667, 0.000,
|
||||
0.667, 1.000, 0.000, 1.000, 0.333, 0.000, 1.000, 0.667, 0.000, 1.000,
|
||||
1.000, 0.000, 0.000, 0.333, 0.500, 0.000, 0.667, 0.500, 0.000, 1.000,
|
||||
0.500, 0.333, 0.000, 0.500, 0.333, 0.333, 0.500, 0.333, 0.667, 0.500,
|
||||
0.333, 1.000, 0.500, 0.667, 0.000, 0.500, 0.667, 0.333, 0.500, 0.667,
|
||||
0.667, 0.500, 0.667, 1.000, 0.500, 1.000, 0.000, 0.500, 1.000, 0.333,
|
||||
0.500, 1.000, 0.667, 0.500, 1.000, 1.000, 0.500, 0.000, 0.333, 1.000,
|
||||
0.000, 0.667, 1.000, 0.000, 1.000, 1.000, 0.333, 0.000, 1.000, 0.333,
|
||||
0.333, 1.000, 0.333, 0.667, 1.000, 0.333, 1.000, 1.000, 0.667, 0.000,
|
||||
1.000, 0.667, 0.333, 1.000, 0.667, 0.667, 1.000, 0.667, 1.000, 1.000,
|
||||
1.000, 0.000, 1.000, 1.000, 0.333, 1.000, 1.000, 0.667, 1.000, 0.167,
|
||||
0.000, 0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000,
|
||||
0.000, 0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000,
|
||||
0.000, 0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000,
|
||||
0.833, 0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.167, 0.000, 0.000,
|
||||
0.333, 0.000, 0.000, 0.500, 0.000, 0.000, 0.667, 0.000, 0.000, 0.833,
|
||||
0.000, 0.000, 1.000, 0.000, 0.000, 0.000, 0.143, 0.143, 0.143, 0.286,
|
||||
0.286, 0.286, 0.429, 0.429, 0.429, 0.571, 0.571, 0.571, 0.714, 0.714,
|
||||
0.714, 0.857, 0.857, 0.857, 1.000, 1.000, 1.000
|
||||
]).astype(np.float32)
|
||||
color_list = color_list.reshape((-1, 3)) * 255
|
||||
if not rgb:
|
||||
color_list = color_list[:, ::-1]
|
||||
return color_list.astype('int32')
|
||||
@@ -0,0 +1,11 @@
|
||||
import PIL
|
||||
|
||||
def imagedraw_textsize_c(draw, text, font=None):
|
||||
if int(PIL.__version__.split('.')[0]) < 10:
|
||||
tw, th = draw.textsize(text, font=font)
|
||||
else:
|
||||
left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
|
||||
tw, th = right - left, bottom - top
|
||||
|
||||
return tw, th
|
||||
|
||||
@@ -0,0 +1,560 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
import sys
|
||||
import yaml
|
||||
import time
|
||||
import shutil
|
||||
import requests
|
||||
import tqdm
|
||||
import hashlib
|
||||
import base64
|
||||
import binascii
|
||||
import tarfile
|
||||
import zipfile
|
||||
import errno
|
||||
|
||||
from paddle.utils.download import _get_unique_endpoints
|
||||
from ppdet.core.workspace import BASE_KEY
|
||||
from .logger import setup_logger
|
||||
from .voc_utils import create_list
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
__all__ = [
|
||||
'get_weights_path', 'get_dataset_path', 'get_config_path',
|
||||
'download_dataset', 'create_voc_list'
|
||||
]
|
||||
|
||||
WEIGHTS_HOME = osp.expanduser("~/.cache/paddle/weights")
|
||||
DATASET_HOME = osp.expanduser("~/.cache/paddle/dataset")
|
||||
CONFIGS_HOME = osp.expanduser("~/.cache/paddle/configs")
|
||||
|
||||
# dict of {dataset_name: (download_info, sub_dirs)}
|
||||
# download info: [(url, md5sum)]
|
||||
DATASETS = {
|
||||
'coco': ([
|
||||
(
|
||||
'http://images.cocodataset.org/zips/train2017.zip',
|
||||
'cced6f7f71b7629ddf16f17bbcfab6b2', ),
|
||||
(
|
||||
'http://images.cocodataset.org/zips/val2017.zip',
|
||||
'442b8da7639aecaf257c1dceb8ba8c80', ),
|
||||
(
|
||||
'http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
|
||||
'f4bbac642086de4f52a3fdda2de5fa2c', ),
|
||||
], ["annotations", "train2017", "val2017"]),
|
||||
'voc': ([
|
||||
(
|
||||
'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
|
||||
'6cd6e144f989b92b3379bac3b3de84fd', ),
|
||||
(
|
||||
'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
|
||||
'c52e279531787c972589f7e41ab4ae64', ),
|
||||
(
|
||||
'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar',
|
||||
'b6e924de25625d8de591ea690078ad9f', ),
|
||||
(
|
||||
'https://paddledet.bj.bcebos.com/data/label_list.txt',
|
||||
'5ae5d62183cfb6f6d3ac109359d06a1b', ),
|
||||
], ["VOCdevkit/VOC2012", "VOCdevkit/VOC2007"]),
|
||||
'wider_face': ([
|
||||
(
|
||||
'https://dataset.bj.bcebos.com/wider_face/WIDER_train.zip',
|
||||
'3fedf70df600953d25982bcd13d91ba2', ),
|
||||
(
|
||||
'https://dataset.bj.bcebos.com/wider_face/WIDER_val.zip',
|
||||
'dfa7d7e790efa35df3788964cf0bbaea', ),
|
||||
(
|
||||
'https://dataset.bj.bcebos.com/wider_face/wider_face_split.zip',
|
||||
'a4a898d6193db4b9ef3260a68bad0dc7', ),
|
||||
], ["WIDER_train", "WIDER_val", "wider_face_split"]),
|
||||
'fruit': ([(
|
||||
'https://dataset.bj.bcebos.com/PaddleDetection_demo/fruit.tar',
|
||||
'baa8806617a54ccf3685fa7153388ae6', ), ],
|
||||
['Annotations', 'JPEGImages']),
|
||||
'roadsign_voc': ([(
|
||||
'https://paddlemodels.bj.bcebos.com/object_detection/roadsign_voc.tar',
|
||||
'8d629c0f880dd8b48de9aeff44bf1f3e', ), ], ['annotations', 'images']),
|
||||
'roadsign_coco': ([(
|
||||
'https://paddlemodels.bj.bcebos.com/object_detection/roadsign_coco.tar',
|
||||
'49ce5a9b5ad0d6266163cd01de4b018e', ), ], ['annotations', 'images']),
|
||||
'spine_coco': ([(
|
||||
'https://paddledet.bj.bcebos.com/data/spine.tar',
|
||||
'8a3a353c2c54a2284ad7d2780b65f6a6', ), ], ['annotations', 'images']),
|
||||
'coco_ce': ([(
|
||||
'https://paddledet.bj.bcebos.com/data/coco_ce.tar',
|
||||
'eadd1b79bc2f069f2744b1dd4e0c0329', ), ], []),
|
||||
'culane': ([('https://bj.bcebos.com/v1/paddledet/data/culane.tar', None, ), ], [])
|
||||
}
|
||||
|
||||
DOWNLOAD_DATASETS_LIST = DATASETS.keys()
|
||||
|
||||
DOWNLOAD_RETRY_LIMIT = 3
|
||||
|
||||
PPDET_WEIGHTS_DOWNLOAD_URL_PREFIX = 'https://paddledet.bj.bcebos.com/'
|
||||
|
||||
|
||||
# When running unit tests, there could be multiple processes that
|
||||
# trying to create DATA_HOME directory simultaneously, so we cannot
|
||||
# use a if condition to check for the existence of the directory;
|
||||
# instead, we use the filesystem as the synchronization mechanism by
|
||||
# catching returned errors.
|
||||
def must_mkdirs(path):
|
||||
try:
|
||||
os.makedirs(path)
|
||||
except OSError as exc:
|
||||
if exc.errno != errno.EEXIST:
|
||||
raise
|
||||
pass
|
||||
|
||||
|
||||
def parse_url(url):
|
||||
url = url.replace("ppdet://", PPDET_WEIGHTS_DOWNLOAD_URL_PREFIX)
|
||||
return url
|
||||
|
||||
|
||||
def get_weights_path(url):
|
||||
"""Get weights path from WEIGHTS_HOME, if not exists,
|
||||
download it from url.
|
||||
"""
|
||||
url = parse_url(url)
|
||||
path, _ = get_path(url, WEIGHTS_HOME)
|
||||
return path
|
||||
|
||||
|
||||
def get_config_path(url):
|
||||
"""Get weights path from CONFIGS_HOME, if not exists,
|
||||
download it from url.
|
||||
"""
|
||||
url = parse_url(url)
|
||||
path = map_path(url, CONFIGS_HOME, path_depth=2)
|
||||
if os.path.isfile(path):
|
||||
return path
|
||||
|
||||
# config file not found, try download
|
||||
# 1. clear configs directory
|
||||
if osp.isdir(CONFIGS_HOME):
|
||||
shutil.rmtree(CONFIGS_HOME)
|
||||
|
||||
# 2. get url
|
||||
try:
|
||||
from ppdet import __version__ as version
|
||||
except ImportError:
|
||||
version = None
|
||||
|
||||
cfg_url = "ppdet://configs/{}/configs.tar".format(version) \
|
||||
if version else "ppdet://configs/configs.tar"
|
||||
cfg_url = parse_url(cfg_url)
|
||||
|
||||
# 3. download and decompress
|
||||
cfg_fullname = _download_dist(cfg_url, osp.dirname(CONFIGS_HOME))
|
||||
_decompress_dist(cfg_fullname)
|
||||
|
||||
# 4. check config file existing
|
||||
if os.path.isfile(path):
|
||||
return path
|
||||
else:
|
||||
logger.error("Get config {} failed after download, please contact us on " \
|
||||
"https://github.com/PaddlePaddle/PaddleDetection/issues".format(path))
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def get_dataset_path(path, annotation, image_dir):
|
||||
"""
|
||||
If path exists, return path.
|
||||
Otherwise, get dataset path from DATASET_HOME, if not exists,
|
||||
download it.
|
||||
"""
|
||||
if _dataset_exists(path, annotation, image_dir):
|
||||
return path
|
||||
|
||||
data_name = os.path.split(path.strip().lower())[-1]
|
||||
if data_name not in DOWNLOAD_DATASETS_LIST:
|
||||
raise ValueError(
|
||||
"Dataset {} is not valid for reason above, please check again.".
|
||||
format(osp.realpath(path)))
|
||||
else:
|
||||
logger.warning(
|
||||
"Dataset {} is not valid for reason above, try searching {} or "
|
||||
"downloading dataset...".format(osp.realpath(path), DATASET_HOME))
|
||||
|
||||
for name, dataset in DATASETS.items():
|
||||
if data_name == name:
|
||||
logger.debug("Parse dataset_dir {} as dataset "
|
||||
"{}".format(path, name))
|
||||
data_dir = osp.join(DATASET_HOME, name)
|
||||
|
||||
if name == "spine_coco":
|
||||
if _dataset_exists(data_dir, annotation, image_dir):
|
||||
return data_dir
|
||||
|
||||
# For voc, only check dir VOCdevkit/VOC2012, VOCdevkit/VOC2007
|
||||
if name in ['voc', 'fruit', 'roadsign_voc']:
|
||||
exists = True
|
||||
for sub_dir in dataset[1]:
|
||||
check_dir = osp.join(data_dir, sub_dir)
|
||||
if osp.exists(check_dir):
|
||||
logger.info("Found {}".format(check_dir))
|
||||
else:
|
||||
exists = False
|
||||
if exists:
|
||||
return data_dir
|
||||
|
||||
# voc exist is checked above, voc is not exist here
|
||||
check_exist = name != 'voc' and name != 'fruit' and name != 'roadsign_voc'
|
||||
for url, md5sum in dataset[0]:
|
||||
get_path(url, data_dir, md5sum, check_exist)
|
||||
|
||||
# voc should create list after download
|
||||
if name == 'voc':
|
||||
create_voc_list(data_dir)
|
||||
return data_dir
|
||||
|
||||
raise ValueError("Dataset automaticly downloading Error.")
|
||||
|
||||
|
||||
def create_voc_list(data_dir, devkit_subdir='VOCdevkit'):
|
||||
logger.debug("Create voc file list...")
|
||||
devkit_dir = osp.join(data_dir, devkit_subdir)
|
||||
years = ['2007', '2012']
|
||||
|
||||
# NOTE: since using auto download VOC
|
||||
# dataset, VOC default label list should be used,
|
||||
# do not generate label_list.txt here. For default
|
||||
# label, see ../data/source/voc.py
|
||||
create_list(devkit_dir, years, data_dir)
|
||||
logger.debug("Create voc file list finished")
|
||||
|
||||
|
||||
def map_path(url, root_dir, path_depth=1):
|
||||
# parse path after download to decompress under root_dir
|
||||
assert path_depth > 0, "path_depth should be a positive integer"
|
||||
dirname = url
|
||||
for _ in range(path_depth):
|
||||
dirname = osp.dirname(dirname)
|
||||
fpath = osp.relpath(url, dirname)
|
||||
|
||||
zip_formats = ['.zip', '.tar', '.gz']
|
||||
for zip_format in zip_formats:
|
||||
fpath = fpath.replace(zip_format, '')
|
||||
return osp.join(root_dir, fpath)
|
||||
|
||||
|
||||
def get_path(url, root_dir, md5sum=None, check_exist=True):
|
||||
""" Download from given url to root_dir.
|
||||
if file or directory specified by url is exists under
|
||||
root_dir, return the path directly, otherwise download
|
||||
from url and decompress it, return the path.
|
||||
|
||||
url (str): download url
|
||||
root_dir (str): root dir for downloading, it should be
|
||||
WEIGHTS_HOME or DATASET_HOME
|
||||
md5sum (str): md5 sum of download package
|
||||
"""
|
||||
# parse path after download to decompress under root_dir
|
||||
fullpath = map_path(url, root_dir)
|
||||
|
||||
# For same zip file, decompressed directory name different
|
||||
# from zip file name, rename by following map
|
||||
decompress_name_map = {
|
||||
"VOCtrainval_11-May-2012": "VOCdevkit/VOC2012",
|
||||
"VOCtrainval_06-Nov-2007": "VOCdevkit/VOC2007",
|
||||
"VOCtest_06-Nov-2007": "VOCdevkit/VOC2007",
|
||||
"annotations_trainval": "annotations"
|
||||
}
|
||||
for k, v in decompress_name_map.items():
|
||||
if fullpath.find(k) >= 0:
|
||||
fullpath = osp.join(osp.split(fullpath)[0], v)
|
||||
|
||||
if osp.exists(fullpath) and check_exist:
|
||||
if not osp.isfile(fullpath) or \
|
||||
_check_exist_file_md5(fullpath, md5sum, url):
|
||||
logger.debug("Found {}".format(fullpath))
|
||||
return fullpath, True
|
||||
else:
|
||||
os.remove(fullpath)
|
||||
|
||||
fullname = _download_dist(url, root_dir, md5sum)
|
||||
|
||||
# new weights format which postfix is 'pdparams' not
|
||||
# need to decompress
|
||||
if osp.splitext(fullname)[-1] not in ['.pdparams', '.yml', '.ttf']:
|
||||
_decompress_dist(fullname)
|
||||
|
||||
return fullpath, False
|
||||
|
||||
|
||||
def download_dataset(path, dataset=None):
|
||||
if dataset not in DATASETS.keys():
|
||||
logger.error("Unknown dataset {}, it should be "
|
||||
"{}".format(dataset, DATASETS.keys()))
|
||||
return
|
||||
dataset_info = DATASETS[dataset][0]
|
||||
for info in dataset_info:
|
||||
get_path(info[0], path, info[1], False)
|
||||
logger.debug("Download dataset {} finished.".format(dataset))
|
||||
|
||||
|
||||
def _dataset_exists(path, annotation, image_dir):
|
||||
"""
|
||||
Check if user define dataset exists
|
||||
"""
|
||||
if not osp.exists(path):
|
||||
logger.warning("Config dataset_dir {} is not exits, "
|
||||
"dataset config is not valid".format(path))
|
||||
return False
|
||||
|
||||
if annotation:
|
||||
annotation_path = osp.join(path, annotation)
|
||||
if not osp.isfile(annotation_path):
|
||||
logger.warning("Config annotation {} is not a "
|
||||
"file, dataset config is not "
|
||||
"valid".format(annotation_path))
|
||||
return False
|
||||
if image_dir:
|
||||
image_path = osp.join(path, image_dir)
|
||||
if not osp.isdir(image_path):
|
||||
logger.warning("Config image_dir {} is not a "
|
||||
"directory, dataset config is not "
|
||||
"valid".format(image_path))
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _download(url, path, md5sum=None):
|
||||
"""
|
||||
Download from url, save to path.
|
||||
|
||||
url (str): download url
|
||||
path (str): download to given path
|
||||
"""
|
||||
must_mkdirs(path)
|
||||
|
||||
fname = osp.split(url)[-1]
|
||||
fullname = osp.join(path, fname)
|
||||
retry_cnt = 0
|
||||
|
||||
while not (osp.exists(fullname) and _check_exist_file_md5(fullname, md5sum,
|
||||
url)):
|
||||
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
|
||||
retry_cnt += 1
|
||||
else:
|
||||
raise RuntimeError("Download from {} failed. "
|
||||
"Retry limit reached".format(url))
|
||||
|
||||
logger.info("Downloading {} from {}".format(fname, url))
|
||||
|
||||
# NOTE: windows path join may incur \, which is invalid in url
|
||||
if sys.platform == "win32":
|
||||
url = url.replace('\\', '/')
|
||||
|
||||
req = requests.get(url, stream=True)
|
||||
if req.status_code != 200:
|
||||
raise RuntimeError("Downloading from {} failed with code "
|
||||
"{}!".format(url, req.status_code))
|
||||
|
||||
# For protecting download interupted, download to
|
||||
# tmp_fullname firstly, move tmp_fullname to fullname
|
||||
# after download finished
|
||||
tmp_fullname = fullname + "_tmp"
|
||||
total_size = req.headers.get('content-length')
|
||||
with open(tmp_fullname, 'wb') as f:
|
||||
if total_size:
|
||||
for chunk in tqdm.tqdm(
|
||||
req.iter_content(chunk_size=1024),
|
||||
total=(int(total_size) + 1023) // 1024,
|
||||
unit='KB'):
|
||||
f.write(chunk)
|
||||
else:
|
||||
for chunk in req.iter_content(chunk_size=1024):
|
||||
if chunk:
|
||||
f.write(chunk)
|
||||
shutil.move(tmp_fullname, fullname)
|
||||
return fullname
|
||||
|
||||
|
||||
def _download_dist(url, path, md5sum=None):
|
||||
env = os.environ
|
||||
if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
|
||||
# Mainly used to solve the problem of downloading data from
|
||||
# different machines in the case of multiple machines.
|
||||
# Different nodes will download data, and the same node
|
||||
# will only download data once.
|
||||
# Reference https://github.com/PaddlePaddle/PaddleClas/blob/develop/ppcls/utils/download.py#L108
|
||||
rank_id_curr_node = int(os.environ.get("PADDLE_RANK_IN_NODE", 0))
|
||||
num_trainers = int(env['PADDLE_TRAINERS_NUM'])
|
||||
if num_trainers <= 1:
|
||||
return _download(url, path, md5sum)
|
||||
else:
|
||||
fname = osp.split(url)[-1]
|
||||
fullname = osp.join(path, fname)
|
||||
lock_path = fullname + '.download.lock'
|
||||
|
||||
must_mkdirs(path)
|
||||
|
||||
if not osp.exists(fullname):
|
||||
with open(lock_path, 'w'): # touch
|
||||
os.utime(lock_path, None)
|
||||
if rank_id_curr_node == 0:
|
||||
_download(url, path, md5sum)
|
||||
os.remove(lock_path)
|
||||
else:
|
||||
while os.path.exists(lock_path):
|
||||
time.sleep(0.5)
|
||||
return fullname
|
||||
else:
|
||||
return _download(url, path, md5sum)
|
||||
|
||||
|
||||
def _check_exist_file_md5(filename, md5sum, url):
|
||||
# if md5sum is None, and file to check is weights file,
|
||||
# read md5um from url and check, else check md5sum directly
|
||||
return _md5check_from_url(filename, url) if md5sum is None \
|
||||
and filename.endswith('pdparams') \
|
||||
else _md5check(filename, md5sum)
|
||||
|
||||
|
||||
def _md5check_from_url(filename, url):
|
||||
# For weights in bcebos URLs, MD5 value is contained
|
||||
# in request header as 'content_md5'
|
||||
req = requests.get(url, stream=True)
|
||||
content_md5 = req.headers.get('content-md5')
|
||||
req.close()
|
||||
if not content_md5 or _md5check(
|
||||
filename,
|
||||
binascii.hexlify(base64.b64decode(content_md5.strip('"'))).decode(
|
||||
)):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def _md5check(fullname, md5sum=None):
|
||||
if md5sum is None:
|
||||
return True
|
||||
|
||||
logger.debug("File {} md5 checking...".format(fullname))
|
||||
md5 = hashlib.md5()
|
||||
with open(fullname, 'rb') as f:
|
||||
for chunk in iter(lambda: f.read(4096), b""):
|
||||
md5.update(chunk)
|
||||
calc_md5sum = md5.hexdigest()
|
||||
|
||||
if calc_md5sum != md5sum:
|
||||
logger.warning("File {} md5 check failed, {}(calc) != "
|
||||
"{}(base)".format(fullname, calc_md5sum, md5sum))
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _decompress(fname):
|
||||
"""
|
||||
Decompress for zip and tar file
|
||||
"""
|
||||
logger.info("Decompressing {}...".format(fname))
|
||||
|
||||
# For protecting decompressing interupted,
|
||||
# decompress to fpath_tmp directory firstly, if decompress
|
||||
# successed, move decompress files to fpath and delete
|
||||
# fpath_tmp and remove download compress file.
|
||||
fpath = osp.split(fname)[0]
|
||||
fpath_tmp = osp.join(fpath, 'tmp')
|
||||
if osp.isdir(fpath_tmp):
|
||||
shutil.rmtree(fpath_tmp)
|
||||
os.makedirs(fpath_tmp)
|
||||
|
||||
if fname.find('tar') >= 0:
|
||||
with tarfile.open(fname) as tf:
|
||||
tf.extractall(path=fpath_tmp)
|
||||
elif fname.find('zip') >= 0:
|
||||
with zipfile.ZipFile(fname) as zf:
|
||||
zf.extractall(path=fpath_tmp)
|
||||
elif fname.find('.txt') >= 0:
|
||||
return
|
||||
else:
|
||||
raise TypeError("Unsupport compress file type {}".format(fname))
|
||||
|
||||
for f in os.listdir(fpath_tmp):
|
||||
src_dir = osp.join(fpath_tmp, f)
|
||||
dst_dir = osp.join(fpath, f)
|
||||
_move_and_merge_tree(src_dir, dst_dir)
|
||||
|
||||
shutil.rmtree(fpath_tmp)
|
||||
os.remove(fname)
|
||||
|
||||
|
||||
def _decompress_dist(fname):
|
||||
env = os.environ
|
||||
if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
|
||||
trainer_id = int(env['PADDLE_TRAINER_ID'])
|
||||
num_trainers = int(env['PADDLE_TRAINERS_NUM'])
|
||||
if num_trainers <= 1:
|
||||
_decompress(fname)
|
||||
else:
|
||||
lock_path = fname + '.decompress.lock'
|
||||
from paddle.distributed import ParallelEnv
|
||||
unique_endpoints = _get_unique_endpoints(ParallelEnv()
|
||||
.trainer_endpoints[:])
|
||||
# NOTE(dkp): _decompress_dist always performed after
|
||||
# _download_dist, in _download_dist sub-trainers is waiting
|
||||
# for download lock file release with sleeping, if decompress
|
||||
# prograss is very fast and finished with in the sleeping gap
|
||||
# time, e.g in tiny dataset such as coco_ce, spine_coco, main
|
||||
# trainer may finish decompress and release lock file, so we
|
||||
# only craete lock file in main trainer and all sub-trainer
|
||||
# wait 1s for main trainer to create lock file, for 1s is
|
||||
# twice as sleeping gap, this waiting time can keep all
|
||||
# trainer pipeline in order
|
||||
# **change this if you have more elegent methods**
|
||||
if ParallelEnv().current_endpoint in unique_endpoints:
|
||||
with open(lock_path, 'w'): # touch
|
||||
os.utime(lock_path, None)
|
||||
_decompress(fname)
|
||||
os.remove(lock_path)
|
||||
else:
|
||||
time.sleep(1)
|
||||
while os.path.exists(lock_path):
|
||||
time.sleep(0.5)
|
||||
else:
|
||||
_decompress(fname)
|
||||
|
||||
|
||||
def _move_and_merge_tree(src, dst):
|
||||
"""
|
||||
Move src directory to dst, if dst is already exists,
|
||||
merge src to dst
|
||||
"""
|
||||
if not osp.exists(dst):
|
||||
shutil.move(src, dst)
|
||||
elif osp.isfile(src):
|
||||
shutil.move(src, dst)
|
||||
else:
|
||||
for fp in os.listdir(src):
|
||||
src_fp = osp.join(src, fp)
|
||||
dst_fp = osp.join(dst, fp)
|
||||
if osp.isdir(src_fp):
|
||||
if osp.isdir(dst_fp):
|
||||
_move_and_merge_tree(src_fp, dst_fp)
|
||||
else:
|
||||
shutil.move(src_fp, dst_fp)
|
||||
elif osp.isfile(src_fp) and \
|
||||
not osp.isfile(dst_fp):
|
||||
shutil.move(src_fp, dst_fp)
|
||||
@@ -0,0 +1,179 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
__all__ = ['fuse_conv_bn']
|
||||
|
||||
|
||||
def fuse_conv_bn(model):
|
||||
is_train = False
|
||||
if model.training:
|
||||
model.eval()
|
||||
is_train = True
|
||||
fuse_list = []
|
||||
tmp_pair = [None, None]
|
||||
for name, layer in model.named_sublayers():
|
||||
if isinstance(layer, nn.Conv2D):
|
||||
tmp_pair[0] = name
|
||||
if isinstance(layer, nn.BatchNorm2D):
|
||||
tmp_pair[1] = name
|
||||
|
||||
if tmp_pair[0] and tmp_pair[1] and len(tmp_pair) == 2:
|
||||
fuse_list.append(tmp_pair)
|
||||
tmp_pair = [None, None]
|
||||
model = fuse_layers(model, fuse_list)
|
||||
if is_train:
|
||||
model.train()
|
||||
return model
|
||||
|
||||
|
||||
def find_parent_layer_and_sub_name(model, name):
|
||||
"""
|
||||
Given the model and the name of a layer, find the parent layer and
|
||||
the sub_name of the layer.
|
||||
For example, if name is 'block_1/convbn_1/conv_1', the parent layer is
|
||||
'block_1/convbn_1' and the sub_name is `conv_1`.
|
||||
Args:
|
||||
model(paddle.nn.Layer): the model to be quantized.
|
||||
name(string): the name of a layer
|
||||
|
||||
Returns:
|
||||
parent_layer, subname
|
||||
"""
|
||||
assert isinstance(model, nn.Layer), \
|
||||
"The model must be the instance of paddle.nn.Layer."
|
||||
assert len(name) > 0, "The input (name) should not be empty."
|
||||
|
||||
last_idx = 0
|
||||
idx = 0
|
||||
parent_layer = model
|
||||
while idx < len(name):
|
||||
if name[idx] == '.':
|
||||
sub_name = name[last_idx:idx]
|
||||
if hasattr(parent_layer, sub_name):
|
||||
parent_layer = getattr(parent_layer, sub_name)
|
||||
last_idx = idx + 1
|
||||
idx += 1
|
||||
sub_name = name[last_idx:idx]
|
||||
return parent_layer, sub_name
|
||||
|
||||
|
||||
class Identity(nn.Layer):
|
||||
'''a layer to replace bn or relu layers'''
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(Identity, self).__init__()
|
||||
|
||||
def forward(self, input):
|
||||
return input
|
||||
|
||||
|
||||
def fuse_layers(model, layers_to_fuse, inplace=False):
|
||||
'''
|
||||
fuse layers in layers_to_fuse
|
||||
|
||||
Args:
|
||||
model(nn.Layer): The model to be fused.
|
||||
layers_to_fuse(list): The layers' names to be fused. For
|
||||
example,"fuse_list = [["conv1", "bn1"], ["conv2", "bn2"]]".
|
||||
A TypeError would be raised if "fuse" was set as
|
||||
True but "fuse_list" was None.
|
||||
Default: None.
|
||||
inplace(bool): Whether apply fusing to the input model.
|
||||
Default: False.
|
||||
|
||||
Return
|
||||
fused_model(paddle.nn.Layer): The fused model.
|
||||
'''
|
||||
if not inplace:
|
||||
model = copy.deepcopy(model)
|
||||
for layers_list in layers_to_fuse:
|
||||
layer_list = []
|
||||
for layer_name in layers_list:
|
||||
parent_layer, sub_name = find_parent_layer_and_sub_name(model,
|
||||
layer_name)
|
||||
layer_list.append(getattr(parent_layer, sub_name))
|
||||
new_layers = _fuse_func(layer_list)
|
||||
for i, item in enumerate(layers_list):
|
||||
parent_layer, sub_name = find_parent_layer_and_sub_name(model, item)
|
||||
setattr(parent_layer, sub_name, new_layers[i])
|
||||
return model
|
||||
|
||||
|
||||
def _fuse_func(layer_list):
|
||||
'''choose the fuser method and fuse layers'''
|
||||
types = tuple(type(m) for m in layer_list)
|
||||
fusion_method = types_to_fusion_method.get(types, None)
|
||||
new_layers = [None] * len(layer_list)
|
||||
fused_layer = fusion_method(*layer_list)
|
||||
for handle_id, pre_hook_fn in layer_list[0]._forward_pre_hooks.items():
|
||||
fused_layer.register_forward_pre_hook(pre_hook_fn)
|
||||
del layer_list[0]._forward_pre_hooks[handle_id]
|
||||
for handle_id, hook_fn in layer_list[-1]._forward_post_hooks.items():
|
||||
fused_layer.register_forward_post_hook(hook_fn)
|
||||
del layer_list[-1]._forward_post_hooks[handle_id]
|
||||
new_layers[0] = fused_layer
|
||||
for i in range(1, len(layer_list)):
|
||||
identity = Identity()
|
||||
identity.training = layer_list[0].training
|
||||
new_layers[i] = identity
|
||||
return new_layers
|
||||
|
||||
|
||||
def _fuse_conv_bn(conv, bn):
|
||||
'''fuse conv and bn for train or eval'''
|
||||
assert(conv.training == bn.training),\
|
||||
"Conv and BN both must be in the same mode (train or eval)."
|
||||
if conv.training:
|
||||
assert bn._num_features == conv._out_channels, 'Output channel of Conv2d must match num_features of BatchNorm2d'
|
||||
raise NotImplementedError
|
||||
else:
|
||||
return _fuse_conv_bn_eval(conv, bn)
|
||||
|
||||
|
||||
def _fuse_conv_bn_eval(conv, bn):
|
||||
'''fuse conv and bn for eval'''
|
||||
assert (not (conv.training or bn.training)), "Fusion only for eval!"
|
||||
fused_conv = copy.deepcopy(conv)
|
||||
|
||||
fused_weight, fused_bias = _fuse_conv_bn_weights(
|
||||
fused_conv.weight, fused_conv.bias, bn._mean, bn._variance, bn._epsilon,
|
||||
bn.weight, bn.bias)
|
||||
fused_conv.weight.set_value(fused_weight)
|
||||
if fused_conv.bias is None:
|
||||
fused_conv.bias = paddle.create_parameter(
|
||||
shape=[fused_conv._out_channels], is_bias=True, dtype=bn.bias.dtype)
|
||||
fused_conv.bias.set_value(fused_bias)
|
||||
return fused_conv
|
||||
|
||||
|
||||
def _fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
|
||||
'''fuse weights and bias of conv and bn'''
|
||||
if conv_b is None:
|
||||
conv_b = paddle.zeros_like(bn_rm)
|
||||
if bn_w is None:
|
||||
bn_w = paddle.ones_like(bn_rm)
|
||||
if bn_b is None:
|
||||
bn_b = paddle.zeros_like(bn_rm)
|
||||
bn_var_rsqrt = paddle.rsqrt(bn_rv + bn_eps)
|
||||
conv_w = conv_w * \
|
||||
(bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
|
||||
conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b
|
||||
return conv_w, conv_b
|
||||
|
||||
|
||||
types_to_fusion_method = {(nn.Conv2D, nn.BatchNorm2D): _fuse_conv_bn, }
|
||||
@@ -0,0 +1,70 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import paddle.distributed as dist
|
||||
|
||||
__all__ = ['setup_logger']
|
||||
|
||||
logger_initialized = []
|
||||
|
||||
|
||||
def setup_logger(name="ppdet", output=None):
|
||||
"""
|
||||
Initialize logger and set its verbosity level to INFO.
|
||||
Args:
|
||||
output (str): a file name or a directory to save log. If None, will not save log file.
|
||||
If ends with ".txt" or ".log", assumed to be a file name.
|
||||
Otherwise, logs will be saved to `output/log.txt`.
|
||||
name (str): the root module name of this logger
|
||||
|
||||
Returns:
|
||||
logging.Logger: a logger
|
||||
"""
|
||||
logger = logging.getLogger(name)
|
||||
if name in logger_initialized:
|
||||
return logger
|
||||
|
||||
logger.setLevel(logging.INFO)
|
||||
logger.propagate = False
|
||||
|
||||
formatter = logging.Formatter(
|
||||
"[%(asctime)s] %(name)s %(levelname)s: %(message)s",
|
||||
datefmt="%m/%d %H:%M:%S")
|
||||
# stdout logging: master only
|
||||
local_rank = dist.get_rank()
|
||||
if local_rank == 0:
|
||||
ch = logging.StreamHandler(stream=sys.stdout)
|
||||
ch.setLevel(logging.DEBUG)
|
||||
ch.setFormatter(formatter)
|
||||
logger.addHandler(ch)
|
||||
|
||||
# file logging: all workers
|
||||
if output is not None:
|
||||
if output.endswith(".txt") or output.endswith(".log"):
|
||||
filename = output
|
||||
else:
|
||||
filename = os.path.join(output, "log.txt")
|
||||
if local_rank > 0:
|
||||
filename = filename + ".rank{}".format(local_rank)
|
||||
os.makedirs(os.path.dirname(filename))
|
||||
fh = logging.FileHandler(filename, mode='a')
|
||||
fh.setLevel(logging.DEBUG)
|
||||
fh.setFormatter(logging.Formatter())
|
||||
logger.addHandler(fh)
|
||||
logger_initialized.append(name)
|
||||
return logger
|
||||
@@ -0,0 +1,129 @@
|
||||
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import paddle
|
||||
import paddle.profiler as profiler
|
||||
|
||||
# A global variable to record the number of calling times for profiler
|
||||
# functions. It is used to specify the tracing range of training steps.
|
||||
_profiler_step_id = 0
|
||||
|
||||
# A global variable to avoid parsing from string every time.
|
||||
_profiler_options = None
|
||||
_prof = None
|
||||
|
||||
class ProfilerOptions(object):
|
||||
'''
|
||||
Use a string to initialize a ProfilerOptions.
|
||||
The string should be in the format: "key1=value1;key2=value;key3=value3".
|
||||
For example:
|
||||
"profile_path=model.profile"
|
||||
"batch_range=[50, 60]; profile_path=model.profile"
|
||||
"batch_range=[50, 60]; tracer_option=OpDetail; profile_path=model.profile"
|
||||
|
||||
ProfilerOptions supports following key-value pair:
|
||||
batch_range - a integer list, e.g. [100, 110].
|
||||
state - a string, the optional values are 'CPU', 'GPU' or 'All'.
|
||||
sorted_key - a string, the optional values are 'calls', 'total',
|
||||
'max', 'min' or 'ave.
|
||||
tracer_option - a string, the optional values are 'Default', 'OpDetail',
|
||||
'AllOpDetail'.
|
||||
profile_path - a string, the path to save the serialized profile data,
|
||||
which can be used to generate a timeline.
|
||||
exit_on_finished - a boolean.
|
||||
'''
|
||||
|
||||
def __init__(self, options_str):
|
||||
assert isinstance(options_str, str)
|
||||
|
||||
self._options = {
|
||||
'batch_range': [10, 20],
|
||||
'state': 'All',
|
||||
'sorted_key': 'total',
|
||||
'tracer_option': 'Default',
|
||||
'profile_path': '/tmp/profile',
|
||||
'exit_on_finished': True,
|
||||
'timer_only': True
|
||||
}
|
||||
self._parse_from_string(options_str)
|
||||
|
||||
def _parse_from_string(self, options_str):
|
||||
for kv in options_str.replace(' ', '').split(';'):
|
||||
key, value = kv.split('=')
|
||||
if key == 'batch_range':
|
||||
value_list = value.replace('[', '').replace(']', '').split(',')
|
||||
value_list = list(map(int, value_list))
|
||||
if len(value_list) >= 2 and value_list[0] >= 0 and value_list[
|
||||
1] > value_list[0]:
|
||||
self._options[key] = value_list
|
||||
elif key == 'exit_on_finished':
|
||||
self._options[key] = value.lower() in ("yes", "true", "t", "1")
|
||||
elif key in [
|
||||
'state', 'sorted_key', 'tracer_option', 'profile_path'
|
||||
]:
|
||||
self._options[key] = value
|
||||
elif key == 'timer_only':
|
||||
self._options[key] = value
|
||||
|
||||
def __getitem__(self, name):
|
||||
if self._options.get(name, None) is None:
|
||||
raise ValueError(
|
||||
"ProfilerOptions does not have an option named %s." % name)
|
||||
return self._options[name]
|
||||
|
||||
|
||||
def add_profiler_step(options_str=None):
|
||||
'''
|
||||
Enable the operator-level timing using PaddlePaddle's profiler.
|
||||
The profiler uses a independent variable to count the profiler steps.
|
||||
One call of this function is treated as a profiler step.
|
||||
Args:
|
||||
profiler_options - a string to initialize the ProfilerOptions.
|
||||
Default is None, and the profiler is disabled.
|
||||
'''
|
||||
if options_str is None:
|
||||
return
|
||||
|
||||
global _prof
|
||||
global _profiler_step_id
|
||||
global _profiler_options
|
||||
|
||||
if _profiler_options is None:
|
||||
_profiler_options = ProfilerOptions(options_str)
|
||||
# profile : https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/performance_improving/profiling_model.html#chakanxingnengshujudetongjibiaodan
|
||||
# timer_only = True only the model's throughput and time overhead are displayed
|
||||
# timer_only = False calling summary can print a statistical form that presents performance data from different perspectives.
|
||||
# timer_only = False the output Timeline information can be found in the profiler_log directory
|
||||
if _prof is None:
|
||||
_timer_only = str(_profiler_options['timer_only']) == str(True)
|
||||
_prof = profiler.Profiler(
|
||||
scheduler = (_profiler_options['batch_range'][0], _profiler_options['batch_range'][1]),
|
||||
on_trace_ready = profiler.export_chrome_tracing('./profiler_log'),
|
||||
timer_only = _timer_only)
|
||||
_prof.start()
|
||||
else:
|
||||
_prof.step()
|
||||
|
||||
if _profiler_step_id == _profiler_options['batch_range'][1]:
|
||||
_prof.stop()
|
||||
_prof.summary(
|
||||
op_detail=True,
|
||||
thread_sep=False,
|
||||
time_unit='ms')
|
||||
_prof = None
|
||||
if _profiler_options['exit_on_finished']:
|
||||
sys.exit(0)
|
||||
|
||||
_profiler_step_id += 1
|
||||
Binary file not shown.
@@ -0,0 +1,94 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import collections
|
||||
import numpy as np
|
||||
|
||||
__all__ = ['SmoothedValue', 'TrainingStats']
|
||||
|
||||
|
||||
class SmoothedValue(object):
|
||||
"""Track a series of values and provide access to smoothed values over a
|
||||
window or the global series average.
|
||||
"""
|
||||
|
||||
def __init__(self, window_size=20, fmt=None):
|
||||
if fmt is None:
|
||||
fmt = "{median:.4f} ({avg:.4f})"
|
||||
self.deque = collections.deque(maxlen=window_size)
|
||||
self.fmt = fmt
|
||||
self.total = 0.
|
||||
self.count = 0
|
||||
|
||||
def update(self, value, n=1):
|
||||
self.deque.append(value)
|
||||
self.count += n
|
||||
self.total += value * n
|
||||
|
||||
@property
|
||||
def median(self):
|
||||
return np.median(self.deque)
|
||||
|
||||
@property
|
||||
def avg(self):
|
||||
return np.mean(self.deque)
|
||||
|
||||
@property
|
||||
def max(self):
|
||||
return np.max(self.deque)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.deque[-1]
|
||||
|
||||
@property
|
||||
def global_avg(self):
|
||||
return self.total / self.count
|
||||
|
||||
def __str__(self):
|
||||
return self.fmt.format(
|
||||
median=self.median, avg=self.avg, max=self.max, value=self.value)
|
||||
|
||||
|
||||
class TrainingStats(object):
|
||||
def __init__(self, window_size, delimiter=' '):
|
||||
self.meters = None
|
||||
self.window_size = window_size
|
||||
self.delimiter = delimiter
|
||||
|
||||
def update(self, stats):
|
||||
if self.meters is None:
|
||||
self.meters = {
|
||||
k: SmoothedValue(self.window_size)
|
||||
for k in stats.keys()
|
||||
}
|
||||
for k, v in self.meters.items():
|
||||
v.update(float(stats[k]))
|
||||
|
||||
def get(self, extras=None):
|
||||
stats = collections.OrderedDict()
|
||||
if extras:
|
||||
for k, v in extras.items():
|
||||
stats[k] = v
|
||||
for k, v in self.meters.items():
|
||||
stats[k] = format(v.median, '.6f')
|
||||
|
||||
return stats
|
||||
|
||||
def log(self, extras=None):
|
||||
d = self.get(extras)
|
||||
strs = []
|
||||
for k, v in d.items():
|
||||
strs.append("{}: {}".format(k, str(v)))
|
||||
return self.delimiter.join(strs)
|
||||
@@ -0,0 +1,465 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
import cv2
|
||||
import math
|
||||
|
||||
from .colormap import colormap
|
||||
from ppdet.utils.logger import setup_logger
|
||||
from ppdet.utils.compact import imagedraw_textsize_c
|
||||
from ppdet.utils.download import get_path
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
__all__ = ['visualize_results']
|
||||
|
||||
|
||||
def visualize_results(image,
|
||||
bbox_res,
|
||||
mask_res,
|
||||
segm_res,
|
||||
keypoint_res,
|
||||
pose3d_res,
|
||||
im_id,
|
||||
catid2name,
|
||||
threshold=0.5):
|
||||
"""
|
||||
Visualize bbox and mask results
|
||||
"""
|
||||
if bbox_res is not None:
|
||||
image = draw_bbox(image, im_id, catid2name, bbox_res, threshold)
|
||||
if mask_res is not None:
|
||||
image = draw_mask(image, im_id, mask_res, threshold)
|
||||
if segm_res is not None:
|
||||
image = draw_segm(image, im_id, catid2name, segm_res, threshold)
|
||||
if keypoint_res is not None:
|
||||
image = draw_pose(image, keypoint_res, threshold)
|
||||
if pose3d_res is not None:
|
||||
pose3d = np.array(pose3d_res[0]['pose3d']) * 1000
|
||||
image = draw_pose3d(image, pose3d, visual_thread=threshold)
|
||||
return image
|
||||
|
||||
|
||||
def draw_mask(image, im_id, segms, threshold, alpha=0.7):
|
||||
"""
|
||||
Draw mask on image
|
||||
"""
|
||||
mask_color_id = 0
|
||||
w_ratio = .4
|
||||
color_list = colormap(rgb=True)
|
||||
img_array = np.array(image).astype('float32')
|
||||
for dt in np.array(segms):
|
||||
if im_id != dt['image_id']:
|
||||
continue
|
||||
segm, score = dt['segmentation'], dt['score']
|
||||
if score < threshold:
|
||||
continue
|
||||
import pycocotools.mask as mask_util
|
||||
mask = mask_util.decode(segm) * 255
|
||||
color_mask = color_list[mask_color_id % len(color_list), 0:3]
|
||||
mask_color_id += 1
|
||||
for c in range(3):
|
||||
color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
|
||||
idx = np.nonzero(mask)
|
||||
img_array[idx[0], idx[1], :] *= 1.0 - alpha
|
||||
img_array[idx[0], idx[1], :] += alpha * color_mask
|
||||
return Image.fromarray(img_array.astype('uint8'))
|
||||
|
||||
|
||||
def draw_bbox(image, im_id, catid2name, bboxes, threshold):
|
||||
"""
|
||||
Draw bbox on image
|
||||
"""
|
||||
font_url = "https://paddledet.bj.bcebos.com/simfang.ttf"
|
||||
font_path , _ = get_path(font_url, "~/.cache/paddle/")
|
||||
font_size = 18
|
||||
font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
|
||||
|
||||
draw = ImageDraw.Draw(image)
|
||||
|
||||
catid2color = {}
|
||||
color_list = colormap(rgb=True)[:40]
|
||||
for dt in np.array(bboxes):
|
||||
if im_id != dt['image_id']:
|
||||
continue
|
||||
catid, bbox, score = dt['category_id'], dt['bbox'], dt['score']
|
||||
if score < threshold:
|
||||
continue
|
||||
|
||||
if catid not in catid2color:
|
||||
idx = np.random.randint(len(color_list))
|
||||
catid2color[catid] = color_list[idx]
|
||||
color = tuple(catid2color[catid])
|
||||
|
||||
# draw bbox
|
||||
if len(bbox) == 4:
|
||||
# draw bbox
|
||||
xmin, ymin, w, h = bbox
|
||||
xmax = xmin + w
|
||||
ymax = ymin + h
|
||||
draw.line(
|
||||
[(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
|
||||
(xmin, ymin)],
|
||||
width=2,
|
||||
fill=color)
|
||||
elif len(bbox) == 8:
|
||||
x1, y1, x2, y2, x3, y3, x4, y4 = bbox
|
||||
draw.line(
|
||||
[(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x1, y1)],
|
||||
width=2,
|
||||
fill=color)
|
||||
xmin = min(x1, x2, x3, x4)
|
||||
ymin = min(y1, y2, y3, y4)
|
||||
else:
|
||||
logger.error('the shape of bbox must be [M, 4] or [M, 8]!')
|
||||
|
||||
# draw label
|
||||
text = "{} {:.2f}".format(catid2name[catid], score)
|
||||
tw, th = imagedraw_textsize_c(draw, text, font=font)
|
||||
draw.rectangle(
|
||||
[(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
|
||||
draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255), font=font)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def save_result(save_path, results, catid2name, threshold):
|
||||
"""
|
||||
save result as txt
|
||||
"""
|
||||
img_id = int(results["im_id"])
|
||||
with open(save_path, 'w') as f:
|
||||
if "bbox_res" in results:
|
||||
for dt in results["bbox_res"]:
|
||||
catid, bbox, score = dt['category_id'], dt['bbox'], dt['score']
|
||||
if score < threshold:
|
||||
continue
|
||||
# each bbox result as a line
|
||||
# for rbox: classname score x1 y1 x2 y2 x3 y3 x4 y4
|
||||
# for bbox: classname score x1 y1 w h
|
||||
bbox_pred = '{} {} '.format(catid2name[catid],
|
||||
score) + ' '.join(
|
||||
[str(e) for e in bbox])
|
||||
f.write(bbox_pred + '\n')
|
||||
elif "keypoint_res" in results:
|
||||
for dt in results["keypoint_res"]:
|
||||
kpts = dt['keypoints']
|
||||
scores = dt['score']
|
||||
keypoint_pred = [img_id, scores, kpts]
|
||||
print(keypoint_pred, file=f)
|
||||
else:
|
||||
print("No valid results found, skip txt save")
|
||||
|
||||
|
||||
def draw_segm(image,
|
||||
im_id,
|
||||
catid2name,
|
||||
segms,
|
||||
threshold,
|
||||
alpha=0.7,
|
||||
draw_box=True):
|
||||
"""
|
||||
Draw segmentation on image
|
||||
"""
|
||||
mask_color_id = 0
|
||||
w_ratio = .4
|
||||
color_list = colormap(rgb=True)
|
||||
img_array = np.array(image).astype('float32')
|
||||
for dt in np.array(segms):
|
||||
if im_id != dt['image_id']:
|
||||
continue
|
||||
segm, score, catid = dt['segmentation'], dt['score'], dt['category_id']
|
||||
if score < threshold:
|
||||
continue
|
||||
import pycocotools.mask as mask_util
|
||||
mask = mask_util.decode(segm) * 255
|
||||
color_mask = color_list[mask_color_id % len(color_list), 0:3]
|
||||
mask_color_id += 1
|
||||
for c in range(3):
|
||||
color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
|
||||
idx = np.nonzero(mask)
|
||||
img_array[idx[0], idx[1], :] *= 1.0 - alpha
|
||||
img_array[idx[0], idx[1], :] += alpha * color_mask
|
||||
|
||||
if not draw_box:
|
||||
center_y, center_x = ndimage.measurements.center_of_mass(mask)
|
||||
label_text = "{}".format(catid2name[catid])
|
||||
vis_pos = (max(int(center_x) - 10, 0), int(center_y))
|
||||
cv2.putText(img_array, label_text, vis_pos,
|
||||
cv2.FONT_HERSHEY_COMPLEX, 0.3, (255, 255, 255))
|
||||
else:
|
||||
mask = mask_util.decode(segm) * 255
|
||||
sum_x = np.sum(mask, axis=0)
|
||||
x = np.where(sum_x > 0.5)[0]
|
||||
sum_y = np.sum(mask, axis=1)
|
||||
y = np.where(sum_y > 0.5)[0]
|
||||
x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1]
|
||||
cv2.rectangle(img_array, (x0, y0), (x1, y1),
|
||||
tuple(color_mask.astype('int32').tolist()), 1)
|
||||
bbox_text = '%s %.2f' % (catid2name[catid], score)
|
||||
t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0]
|
||||
cv2.rectangle(img_array, (x0, y0), (x0 + t_size[0],
|
||||
y0 - t_size[1] - 3),
|
||||
tuple(color_mask.astype('int32').tolist()), -1)
|
||||
cv2.putText(
|
||||
img_array,
|
||||
bbox_text, (x0, y0 - 2),
|
||||
cv2.FONT_HERSHEY_SIMPLEX,
|
||||
0.3, (0, 0, 0),
|
||||
1,
|
||||
lineType=cv2.LINE_AA)
|
||||
|
||||
return Image.fromarray(img_array.astype('uint8'))
|
||||
|
||||
|
||||
def draw_pose(image,
|
||||
results,
|
||||
visual_thread=0.6,
|
||||
save_name='pose.jpg',
|
||||
save_dir='output',
|
||||
returnimg=False,
|
||||
ids=None):
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
plt.switch_backend('agg')
|
||||
except Exception as e:
|
||||
logger.error('Matplotlib not found, please install matplotlib.'
|
||||
'for example: `pip install matplotlib`.')
|
||||
raise e
|
||||
|
||||
skeletons = np.array([item['keypoints'] for item in results])
|
||||
kpt_nums = 17
|
||||
if len(skeletons) > 0:
|
||||
kpt_nums = int(skeletons.shape[1] / 3)
|
||||
skeletons = skeletons.reshape(-1, kpt_nums, 3)
|
||||
if kpt_nums == 17: #plot coco keypoint
|
||||
EDGES = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 7), (6, 8),
|
||||
(7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14),
|
||||
(13, 15), (14, 16), (11, 12)]
|
||||
else: #plot mpii keypoint
|
||||
EDGES = [(0, 1), (1, 2), (3, 4), (4, 5), (2, 6), (3, 6), (6, 7), (7, 8),
|
||||
(8, 9), (10, 11), (11, 12), (13, 14), (14, 15), (8, 12),
|
||||
(8, 13)]
|
||||
NUM_EDGES = len(EDGES)
|
||||
|
||||
colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
|
||||
[0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
|
||||
[170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
|
||||
cmap = matplotlib.cm.get_cmap('hsv')
|
||||
plt.figure()
|
||||
|
||||
img = np.array(image).astype('float32')
|
||||
|
||||
color_set = results['colors'] if 'colors' in results else None
|
||||
|
||||
if 'bbox' in results and ids is None:
|
||||
bboxs = results['bbox']
|
||||
for j, rect in enumerate(bboxs):
|
||||
xmin, ymin, xmax, ymax = rect
|
||||
color = colors[0] if color_set is None else colors[color_set[j] %
|
||||
len(colors)]
|
||||
cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 1)
|
||||
|
||||
canvas = img.copy()
|
||||
for i in range(kpt_nums):
|
||||
for j in range(len(skeletons)):
|
||||
if skeletons[j][i, 2] < visual_thread:
|
||||
continue
|
||||
if ids is None:
|
||||
color = colors[i] if color_set is None else colors[color_set[j]
|
||||
%
|
||||
len(colors)]
|
||||
else:
|
||||
color = get_color(ids[j])
|
||||
|
||||
cv2.circle(
|
||||
canvas,
|
||||
tuple(skeletons[j][i, 0:2].astype('int32')),
|
||||
2,
|
||||
color,
|
||||
thickness=-1)
|
||||
|
||||
to_plot = cv2.addWeighted(img, 0.3, canvas, 0.7, 0)
|
||||
fig = matplotlib.pyplot.gcf()
|
||||
|
||||
stickwidth = 2
|
||||
|
||||
for i in range(NUM_EDGES):
|
||||
for j in range(len(skeletons)):
|
||||
edge = EDGES[i]
|
||||
if skeletons[j][edge[0], 2] < visual_thread or skeletons[j][edge[
|
||||
1], 2] < visual_thread:
|
||||
continue
|
||||
|
||||
cur_canvas = canvas.copy()
|
||||
X = [skeletons[j][edge[0], 1], skeletons[j][edge[1], 1]]
|
||||
Y = [skeletons[j][edge[0], 0], skeletons[j][edge[1], 0]]
|
||||
mX = np.mean(X)
|
||||
mY = np.mean(Y)
|
||||
length = ((X[0] - X[1])**2 + (Y[0] - Y[1])**2)**0.5
|
||||
angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
|
||||
polygon = cv2.ellipse2Poly((int(mY), int(mX)),
|
||||
(int(length / 2), stickwidth),
|
||||
int(angle), 0, 360, 1)
|
||||
if ids is None:
|
||||
color = colors[i] if color_set is None else colors[color_set[j]
|
||||
%
|
||||
len(colors)]
|
||||
else:
|
||||
color = get_color(ids[j])
|
||||
cv2.fillConvexPoly(cur_canvas, polygon, color)
|
||||
canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
|
||||
image = Image.fromarray(canvas.astype('uint8'))
|
||||
plt.close()
|
||||
return image
|
||||
|
||||
|
||||
def draw_pose3d(image,
|
||||
pose3d,
|
||||
pose2d=None,
|
||||
visual_thread=0.6,
|
||||
save_name='pose3d.jpg',
|
||||
returnimg=True):
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib
|
||||
plt.switch_backend('agg')
|
||||
except Exception as e:
|
||||
logger.error('Matplotlib not found, please install matplotlib.'
|
||||
'for example: `pip install matplotlib`.')
|
||||
raise e
|
||||
|
||||
if pose3d.shape[0] == 24:
|
||||
joints_connectivity_dict = [
|
||||
[0, 1, 0], [1, 2, 0], [5, 4, 1], [4, 3, 1], [2, 3, 0], [2, 14, 1],
|
||||
[3, 14, 1], [14, 16, 1], [15, 16, 1], [15, 12, 1], [6, 7, 0],
|
||||
[7, 8, 0], [11, 10, 1], [10, 9, 1], [8, 12, 0], [9, 12, 1],
|
||||
[12, 19, 1], [19, 18, 1], [19, 20, 0], [19, 21, 1], [22, 20, 0],
|
||||
[23, 21, 1]
|
||||
]
|
||||
elif pose3d.shape[0] == 14:
|
||||
joints_connectivity_dict = [
|
||||
[0, 1, 0], [1, 2, 0], [5, 4, 1], [4, 3, 1], [2, 3, 0], [2, 12, 0],
|
||||
[3, 12, 1], [6, 7, 0], [7, 8, 0], [11, 10, 1], [10, 9, 1],
|
||||
[8, 12, 0], [9, 12, 1], [12, 13, 1]
|
||||
]
|
||||
else:
|
||||
print(
|
||||
"not defined joints number :{}, cannot visualize because unknown of joint connectivity".
|
||||
format(pose.shape[0]))
|
||||
return
|
||||
|
||||
def draw3Dpose(pose3d,
|
||||
ax,
|
||||
lcolor="#3498db",
|
||||
rcolor="#e74c3c",
|
||||
add_labels=False):
|
||||
# pose3d = orthographic_projection(pose3d, cam)
|
||||
for i in joints_connectivity_dict:
|
||||
x, y, z = [
|
||||
np.array([pose3d[i[0], j], pose3d[i[1], j]]) for j in range(3)
|
||||
]
|
||||
ax.plot(-x, -z, -y, lw=2, c=lcolor if i[2] else rcolor)
|
||||
|
||||
RADIUS = 1000
|
||||
center_xy = 2 if pose3d.shape[0] == 14 else 14
|
||||
x, y, z = pose3d[center_xy, 0], pose3d[center_xy, 1], pose3d[center_xy,
|
||||
2]
|
||||
ax.set_xlim3d([-RADIUS + x, RADIUS + x])
|
||||
ax.set_ylim3d([-RADIUS + y, RADIUS + y])
|
||||
ax.set_zlim3d([-RADIUS + z, RADIUS + z])
|
||||
|
||||
ax.set_xlabel("x")
|
||||
ax.set_ylabel("y")
|
||||
ax.set_zlabel("z")
|
||||
|
||||
def draw2Dpose(pose2d,
|
||||
ax,
|
||||
lcolor="#3498db",
|
||||
rcolor="#e74c3c",
|
||||
add_labels=False):
|
||||
for i in joints_connectivity_dict:
|
||||
if pose2d[i[0], 2] and pose2d[i[1], 2]:
|
||||
x, y = [
|
||||
np.array([pose2d[i[0], j], pose2d[i[1], j]])
|
||||
for j in range(2)
|
||||
]
|
||||
ax.plot(x, y, 0, lw=2, c=lcolor if i[2] else rcolor)
|
||||
|
||||
def draw_img_pose(pose3d,
|
||||
pose2d=None,
|
||||
frame=None,
|
||||
figsize=(12, 12),
|
||||
savepath=None):
|
||||
fig = plt.figure(figsize=figsize, dpi=80)
|
||||
# fig.clear()
|
||||
fig.tight_layout()
|
||||
|
||||
ax = fig.add_subplot(221)
|
||||
if frame is not None:
|
||||
ax.imshow(frame, interpolation='nearest')
|
||||
if pose2d is not None:
|
||||
draw2Dpose(pose2d, ax)
|
||||
|
||||
ax = fig.add_subplot(222, projection='3d')
|
||||
ax.view_init(45, 45)
|
||||
draw3Dpose(pose3d, ax)
|
||||
ax = fig.add_subplot(223, projection='3d')
|
||||
ax.view_init(0, 0)
|
||||
draw3Dpose(pose3d, ax)
|
||||
ax = fig.add_subplot(224, projection='3d')
|
||||
ax.view_init(0, 90)
|
||||
draw3Dpose(pose3d, ax)
|
||||
|
||||
if savepath is not None:
|
||||
plt.savefig(savepath)
|
||||
plt.close()
|
||||
else:
|
||||
return fig
|
||||
|
||||
def fig2data(fig):
|
||||
"""
|
||||
fig = plt.figure()
|
||||
image = fig2data(fig)
|
||||
@brief Convert a Matplotlib figure to a 4D numpy array with RGBA channels and return it
|
||||
@param fig a matplotlib figure
|
||||
@return a numpy 3D array of RGBA values
|
||||
"""
|
||||
# draw the renderer
|
||||
fig.canvas.draw()
|
||||
|
||||
# Get the RGBA buffer from the figure
|
||||
w, h = fig.canvas.get_width_height()
|
||||
buf = np.fromstring(fig.canvas.tostring_argb(), dtype=np.uint8)
|
||||
buf.shape = (w, h, 4)
|
||||
|
||||
# canvas.tostring_argb give pixmap in ARGB mode. Roll the ALPHA channel to have it in RGBA mode
|
||||
buf = np.roll(buf, 3, axis=2)
|
||||
image = Image.frombytes("RGBA", (w, h), buf.tostring())
|
||||
return image.convert("RGB")
|
||||
|
||||
fig = draw_img_pose(pose3d, pose2d, frame=image)
|
||||
data = fig2data(fig)
|
||||
if returnimg is False:
|
||||
data.save(save_name)
|
||||
else:
|
||||
return data
|
||||
@@ -0,0 +1,86 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import os.path as osp
|
||||
import re
|
||||
import random
|
||||
|
||||
__all__ = ['create_list']
|
||||
|
||||
|
||||
def create_list(devkit_dir, years, output_dir):
|
||||
"""
|
||||
create following list:
|
||||
1. trainval.txt
|
||||
2. test.txt
|
||||
"""
|
||||
trainval_list = []
|
||||
test_list = []
|
||||
for year in years:
|
||||
trainval, test = _walk_voc_dir(devkit_dir, year, output_dir)
|
||||
trainval_list.extend(trainval)
|
||||
test_list.extend(test)
|
||||
|
||||
random.shuffle(trainval_list)
|
||||
with open(osp.join(output_dir, 'trainval.txt'), 'w') as ftrainval:
|
||||
for item in trainval_list:
|
||||
ftrainval.write(item[0] + ' ' + item[1] + '\n')
|
||||
|
||||
with open(osp.join(output_dir, 'test.txt'), 'w') as fval:
|
||||
ct = 0
|
||||
for item in test_list:
|
||||
ct += 1
|
||||
fval.write(item[0] + ' ' + item[1] + '\n')
|
||||
|
||||
|
||||
def _get_voc_dir(devkit_dir, year, type):
|
||||
return osp.join(devkit_dir, 'VOC' + year, type)
|
||||
|
||||
|
||||
def _walk_voc_dir(devkit_dir, year, output_dir):
|
||||
filelist_dir = _get_voc_dir(devkit_dir, year, 'ImageSets/Main')
|
||||
annotation_dir = _get_voc_dir(devkit_dir, year, 'Annotations')
|
||||
img_dir = _get_voc_dir(devkit_dir, year, 'JPEGImages')
|
||||
trainval_list = []
|
||||
test_list = []
|
||||
added = set()
|
||||
|
||||
for _, _, files in os.walk(filelist_dir):
|
||||
for fname in files:
|
||||
img_ann_list = []
|
||||
if re.match(r'[a-z]+_trainval\.txt', fname):
|
||||
img_ann_list = trainval_list
|
||||
elif re.match(r'[a-z]+_test\.txt', fname):
|
||||
img_ann_list = test_list
|
||||
else:
|
||||
continue
|
||||
fpath = osp.join(filelist_dir, fname)
|
||||
for line in open(fpath):
|
||||
name_prefix = line.strip().split()[0]
|
||||
if name_prefix in added:
|
||||
continue
|
||||
added.add(name_prefix)
|
||||
ann_path = osp.join(
|
||||
osp.relpath(annotation_dir, output_dir),
|
||||
name_prefix + '.xml')
|
||||
img_path = osp.join(
|
||||
osp.relpath(img_dir, output_dir), name_prefix + '.jpg')
|
||||
img_ann_list.append((img_path, ann_path))
|
||||
|
||||
return trainval_list, test_list
|
||||
Reference in New Issue
Block a user