更换文档检测模型
This commit is contained in:
33
paddle_detection/ppdet/data/source/__init__.py
Normal file
33
paddle_detection/ppdet/data/source/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import coco
|
||||
from . import voc
|
||||
from . import widerface
|
||||
from . import category
|
||||
from . import keypoint_coco
|
||||
from . import mot
|
||||
from . import sniper_coco
|
||||
from . import culane
|
||||
|
||||
from .coco import *
|
||||
from .voc import *
|
||||
from .widerface import *
|
||||
from .category import *
|
||||
from .keypoint_coco import *
|
||||
from .mot import *
|
||||
from .sniper_coco import SniperCOCODataSet
|
||||
from .dataset import ImageFolder
|
||||
from .pose3d_cmb import *
|
||||
from .culane import *
|
||||
942
paddle_detection/ppdet/data/source/category.py
Normal file
942
paddle_detection/ppdet/data/source/category.py
Normal file
@@ -0,0 +1,942 @@
|
||||
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
|
||||
from ppdet.data.source.voc import pascalvoc_label
|
||||
from ppdet.data.source.widerface import widerface_label
|
||||
from ppdet.utils.logger import setup_logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
__all__ = ['get_categories']
|
||||
|
||||
|
||||
def get_categories(metric_type, anno_file=None, arch=None):
    """
    Get the class-id -> category-id map and the category-id -> category-name
    map used by a given evaluation metric.

    Args:
        metric_type (str): metric type; supports 'coco', 'rbox', 'snipercoco',
            'voc', 'oid', 'widerface', keypoint/pose metrics, and the MOT
            family ('mot', 'motdet', 'reid', 'kitti', 'bdd100kmot', 'mcmot').
        anno_file (str): annotation file path; when missing or unreadable, a
            hard-coded default category set is returned for the metric.
        arch (str): model architecture name; keypoint architectures bypass
            category loading entirely.

    Returns:
        tuple: (clsid2catid, catid2name). clsid2catid is None for
            keypoint/pose metrics.
    """
    # Keypoint architectures carry no detection categories at all.
    if arch == 'keypoint_arch':
        return (None, {'id': 'keypoint'})

    # NOTE(review): prefer `anno_file is None`; kept as-is (doc-only change).
    if anno_file == None or (not os.path.isfile(anno_file)):
        logger.warning(
            "anno_file '{}' is None or not set or not exist, "
            "please recheck TrainDataset/EvalDataset/TestDataset.anno_path, "
            "otherwise the default categories will be used by metric_type.".
            format(anno_file))

    if metric_type.lower() == 'coco' or metric_type.lower(
    ) == 'rbox' or metric_type.lower() == 'snipercoco':
        if anno_file and os.path.isfile(anno_file):
            if anno_file.endswith('json'):
                # lazy import pycocotools here
                from pycocotools.coco import COCO
                coco = COCO(anno_file)
                cats = coco.loadCats(coco.getCatIds())

                # Contiguous 0-based class ids onto sparse COCO category ids.
                clsid2catid = {i: cat['id'] for i, cat in enumerate(cats)}
                catid2name = {cat['id']: cat['name'] for cat in cats}

            elif anno_file.endswith('txt'):
                # Plain-text label list: one category name per line; an
                # optional leading 'background' line is dropped.
                cats = []
                with open(anno_file) as f:
                    for line in f.readlines():
                        cats.append(line.strip())
                if cats[0] == 'background': cats = cats[1:]

                clsid2catid = {i: i for i in range(len(cats))}
                catid2name = {i: name for i, name in enumerate(cats)}

            else:
                raise ValueError("anno_file {} should be json or txt.".format(
                    anno_file))
            return clsid2catid, catid2name

        # anno file not exist, load default categories of COCO17
        else:
            if metric_type.lower() == 'rbox':
                logger.warning(
                    "metric_type: {}, load default categories of DOTA.".format(
                        metric_type))
                return _dota_category()
            logger.warning("metric_type: {}, load default categories of COCO.".
                           format(metric_type))
            return _coco17_category()

    elif metric_type.lower() == 'voc':
        if anno_file and os.path.isfile(anno_file):
            # VOC label-list file: one name per line, optional 'background'.
            cats = []
            with open(anno_file) as f:
                for line in f.readlines():
                    cats.append(line.strip())

            if cats[0] == 'background':
                cats = cats[1:]

            clsid2catid = {i: i for i in range(len(cats))}
            catid2name = {i: name for i, name in enumerate(cats)}

            return clsid2catid, catid2name

        # anno file not exist, load default categories of
        # VOC all 20 categories
        else:
            logger.warning("metric_type: {}, load default categories of VOC.".
                           format(metric_type))
            return _vocall_category()

    elif metric_type.lower() == 'oid':
        # OID always uses the built-in 500-class list, even when an
        # annotation file was supplied.
        if anno_file and os.path.isfile(anno_file):
            logger.warning("only default categories support for OID19")
        return _oid19_category()

    elif metric_type.lower() == 'widerface':
        return _widerface_category()

    # NOTE(review): 'wholebady' looks like a typo for 'wholebody', but it
    # must match the metric name used elsewhere upstream — confirm before
    # changing this runtime string.
    elif metric_type.lower() in [
            'keypointtopdowncocoeval', 'keypointtopdownmpiieval',
            'keypointtopdowncocowholebadyhandeval'
    ]:
        return (None, {'id': 'keypoint'})

    elif metric_type.lower() == 'pose3deval':
        return (None, {'id': 'pose3d'})

    elif metric_type.lower() in ['mot', 'motdet', 'reid']:
        if anno_file and os.path.isfile(anno_file):
            # MOT label-list file, same shape as the VOC/coco txt lists.
            cats = []
            with open(anno_file) as f:
                for line in f.readlines():
                    cats.append(line.strip())
            if cats[0] == 'background':
                cats = cats[1:]
            clsid2catid = {i: i for i in range(len(cats))}
            catid2name = {i: name for i, name in enumerate(cats)}
            return clsid2catid, catid2name
        # anno file not exist, load default category 'pedestrian'.
        else:
            logger.warning(
                "metric_type: {}, load default categories of pedestrian MOT.".
                format(metric_type))
            return _mot_category(category='pedestrian')

    elif metric_type.lower() in ['kitti', 'bdd100kmot']:
        return _mot_category(category='vehicle')

    elif metric_type.lower() in ['mcmot']:
        if anno_file and os.path.isfile(anno_file):
            # Multi-class MOT label-list file.
            cats = []
            with open(anno_file) as f:
                for line in f.readlines():
                    cats.append(line.strip())
            if cats[0] == 'background':
                cats = cats[1:]
            clsid2catid = {i: i for i in range(len(cats))}
            catid2name = {i: name for i, name in enumerate(cats)}
            return clsid2catid, catid2name
        # anno file not exist, load default categories of visdrone all 10 categories
        else:
            logger.warning(
                "metric_type: {}, load default categories of VisDrone.".format(
                    metric_type))
            return _visdrone_category()

    else:
        raise ValueError("unknown metric type {}".format(metric_type))
|
||||
|
||||
|
||||
def _mot_category(category='pedestrian'):
|
||||
"""
|
||||
Get class id to category id map and category id
|
||||
to category name map of mot dataset
|
||||
"""
|
||||
label_map = {category: 0}
|
||||
label_map = sorted(label_map.items(), key=lambda x: x[1])
|
||||
cats = [l[0] for l in label_map]
|
||||
|
||||
clsid2catid = {i: i for i in range(len(cats))}
|
||||
catid2name = {i: name for i, name in enumerate(cats)}
|
||||
|
||||
return clsid2catid, catid2name
|
||||
|
||||
|
||||
def _coco17_category():
|
||||
"""
|
||||
Get class id to category id map and category id
|
||||
to category name map of COCO2017 dataset
|
||||
|
||||
"""
|
||||
clsid2catid = {
|
||||
1: 1,
|
||||
2: 2,
|
||||
3: 3,
|
||||
4: 4,
|
||||
5: 5,
|
||||
6: 6,
|
||||
7: 7,
|
||||
8: 8,
|
||||
9: 9,
|
||||
10: 10,
|
||||
11: 11,
|
||||
12: 13,
|
||||
13: 14,
|
||||
14: 15,
|
||||
15: 16,
|
||||
16: 17,
|
||||
17: 18,
|
||||
18: 19,
|
||||
19: 20,
|
||||
20: 21,
|
||||
21: 22,
|
||||
22: 23,
|
||||
23: 24,
|
||||
24: 25,
|
||||
25: 27,
|
||||
26: 28,
|
||||
27: 31,
|
||||
28: 32,
|
||||
29: 33,
|
||||
30: 34,
|
||||
31: 35,
|
||||
32: 36,
|
||||
33: 37,
|
||||
34: 38,
|
||||
35: 39,
|
||||
36: 40,
|
||||
37: 41,
|
||||
38: 42,
|
||||
39: 43,
|
||||
40: 44,
|
||||
41: 46,
|
||||
42: 47,
|
||||
43: 48,
|
||||
44: 49,
|
||||
45: 50,
|
||||
46: 51,
|
||||
47: 52,
|
||||
48: 53,
|
||||
49: 54,
|
||||
50: 55,
|
||||
51: 56,
|
||||
52: 57,
|
||||
53: 58,
|
||||
54: 59,
|
||||
55: 60,
|
||||
56: 61,
|
||||
57: 62,
|
||||
58: 63,
|
||||
59: 64,
|
||||
60: 65,
|
||||
61: 67,
|
||||
62: 70,
|
||||
63: 72,
|
||||
64: 73,
|
||||
65: 74,
|
||||
66: 75,
|
||||
67: 76,
|
||||
68: 77,
|
||||
69: 78,
|
||||
70: 79,
|
||||
71: 80,
|
||||
72: 81,
|
||||
73: 82,
|
||||
74: 84,
|
||||
75: 85,
|
||||
76: 86,
|
||||
77: 87,
|
||||
78: 88,
|
||||
79: 89,
|
||||
80: 90
|
||||
}
|
||||
|
||||
catid2name = {
|
||||
0: 'background',
|
||||
1: 'person',
|
||||
2: 'bicycle',
|
||||
3: 'car',
|
||||
4: 'motorcycle',
|
||||
5: 'airplane',
|
||||
6: 'bus',
|
||||
7: 'train',
|
||||
8: 'truck',
|
||||
9: 'boat',
|
||||
10: 'traffic light',
|
||||
11: 'fire hydrant',
|
||||
13: 'stop sign',
|
||||
14: 'parking meter',
|
||||
15: 'bench',
|
||||
16: 'bird',
|
||||
17: 'cat',
|
||||
18: 'dog',
|
||||
19: 'horse',
|
||||
20: 'sheep',
|
||||
21: 'cow',
|
||||
22: 'elephant',
|
||||
23: 'bear',
|
||||
24: 'zebra',
|
||||
25: 'giraffe',
|
||||
27: 'backpack',
|
||||
28: 'umbrella',
|
||||
31: 'handbag',
|
||||
32: 'tie',
|
||||
33: 'suitcase',
|
||||
34: 'frisbee',
|
||||
35: 'skis',
|
||||
36: 'snowboard',
|
||||
37: 'sports ball',
|
||||
38: 'kite',
|
||||
39: 'baseball bat',
|
||||
40: 'baseball glove',
|
||||
41: 'skateboard',
|
||||
42: 'surfboard',
|
||||
43: 'tennis racket',
|
||||
44: 'bottle',
|
||||
46: 'wine glass',
|
||||
47: 'cup',
|
||||
48: 'fork',
|
||||
49: 'knife',
|
||||
50: 'spoon',
|
||||
51: 'bowl',
|
||||
52: 'banana',
|
||||
53: 'apple',
|
||||
54: 'sandwich',
|
||||
55: 'orange',
|
||||
56: 'broccoli',
|
||||
57: 'carrot',
|
||||
58: 'hot dog',
|
||||
59: 'pizza',
|
||||
60: 'donut',
|
||||
61: 'cake',
|
||||
62: 'chair',
|
||||
63: 'couch',
|
||||
64: 'potted plant',
|
||||
65: 'bed',
|
||||
67: 'dining table',
|
||||
70: 'toilet',
|
||||
72: 'tv',
|
||||
73: 'laptop',
|
||||
74: 'mouse',
|
||||
75: 'remote',
|
||||
76: 'keyboard',
|
||||
77: 'cell phone',
|
||||
78: 'microwave',
|
||||
79: 'oven',
|
||||
80: 'toaster',
|
||||
81: 'sink',
|
||||
82: 'refrigerator',
|
||||
84: 'book',
|
||||
85: 'clock',
|
||||
86: 'vase',
|
||||
87: 'scissors',
|
||||
88: 'teddy bear',
|
||||
89: 'hair drier',
|
||||
90: 'toothbrush'
|
||||
}
|
||||
|
||||
clsid2catid = {k - 1: v for k, v in clsid2catid.items()}
|
||||
catid2name.pop(0)
|
||||
|
||||
return clsid2catid, catid2name
|
||||
|
||||
|
||||
def _dota_category():
|
||||
"""
|
||||
Get class id to category id map and category id
|
||||
to category name map of dota dataset
|
||||
"""
|
||||
catid2name = {
|
||||
0: 'background',
|
||||
1: 'plane',
|
||||
2: 'baseball-diamond',
|
||||
3: 'bridge',
|
||||
4: 'ground-track-field',
|
||||
5: 'small-vehicle',
|
||||
6: 'large-vehicle',
|
||||
7: 'ship',
|
||||
8: 'tennis-court',
|
||||
9: 'basketball-court',
|
||||
10: 'storage-tank',
|
||||
11: 'soccer-ball-field',
|
||||
12: 'roundabout',
|
||||
13: 'harbor',
|
||||
14: 'swimming-pool',
|
||||
15: 'helicopter'
|
||||
}
|
||||
catid2name.pop(0)
|
||||
clsid2catid = {i: i + 1 for i in range(len(catid2name))}
|
||||
return clsid2catid, catid2name
|
||||
|
||||
|
||||
def _vocall_category():
    """
    Return class-id/category-id and category-id/name maps for the full
    Pascal VOC label set (as supplied by ``pascalvoc_label``).
    """
    # pascalvoc_label() yields {name: label_id}; order names by label id so
    # the contiguous ids below agree with the VOC label numbering.
    ordered = sorted(pascalvoc_label().items(), key=lambda kv: kv[1])
    names = [name for name, _ in ordered]

    clsid2catid = {idx: idx for idx in range(len(names))}
    catid2name = dict(enumerate(names))

    return clsid2catid, catid2name
|
||||
|
||||
|
||||
def _widerface_category():
    """
    Return class-id/category-id and category-id/name maps for the
    WIDER FACE label set (as supplied by ``widerface_label``).
    """
    # widerface_label() yields {name: label_id}; order names by label id.
    ordered = sorted(widerface_label().items(), key=lambda kv: kv[1])
    names = [name for name, _ in ordered]

    clsid2catid = {idx: idx for idx in range(len(names))}
    catid2name = dict(enumerate(names))

    return clsid2catid, catid2name
|
||||
|
||||
|
||||
def _oid19_category():
    """
    Return the default Open Images Dataset V5 (OID19) maps: contiguous
    0-based class ids -> 1-based category ids (500 classes), and category
    ids -> names.

    NOTE(review): unlike _coco17_category()/_dota_category(), the
    'background' entry (id 0) is NOT removed from catid2name here, so the
    name map has 501 entries — confirm callers expect the extra key.
    """
    # Class ids 0..499 shift by one onto category ids 1..500.
    clsid2catid = {k: k + 1 for k in range(500)}

    catid2name = {
        0: "background",
        1: "Infant bed",
        2: "Rose",
        3: "Flag",
        4: "Flashlight",
        5: "Sea turtle",
        6: "Camera",
        7: "Animal",
        8: "Glove",
        9: "Crocodile",
        10: "Cattle",
        11: "House",
        12: "Guacamole",
        13: "Penguin",
        14: "Vehicle registration plate",
        15: "Bench",
        16: "Ladybug",
        17: "Human nose",
        18: "Watermelon",
        19: "Flute",
        20: "Butterfly",
        21: "Washing machine",
        22: "Raccoon",
        23: "Segway",
        24: "Taco",
        25: "Jellyfish",
        26: "Cake",
        27: "Pen",
        28: "Cannon",
        29: "Bread",
        30: "Tree",
        31: "Shellfish",
        32: "Bed",
        33: "Hamster",
        34: "Hat",
        35: "Toaster",
        36: "Sombrero",
        37: "Tiara",
        38: "Bowl",
        39: "Dragonfly",
        40: "Moths and butterflies",
        41: "Antelope",
        42: "Vegetable",
        43: "Torch",
        44: "Building",
        45: "Power plugs and sockets",
        46: "Blender",
        47: "Billiard table",
        48: "Cutting board",
        49: "Bronze sculpture",
        50: "Turtle",
        51: "Broccoli",
        52: "Tiger",
        53: "Mirror",
        54: "Bear",
        55: "Zucchini",
        56: "Dress",
        57: "Volleyball",
        58: "Guitar",
        59: "Reptile",
        60: "Golf cart",
        61: "Tart",
        62: "Fedora",
        63: "Carnivore",
        64: "Car",
        65: "Lighthouse",
        66: "Coffeemaker",
        67: "Food processor",
        68: "Truck",
        69: "Bookcase",
        70: "Surfboard",
        71: "Footwear",
        72: "Bench",
        73: "Necklace",
        74: "Flower",
        75: "Radish",
        76: "Marine mammal",
        77: "Frying pan",
        78: "Tap",
        79: "Peach",
        80: "Knife",
        81: "Handbag",
        82: "Laptop",
        83: "Tent",
        84: "Ambulance",
        85: "Christmas tree",
        86: "Eagle",
        87: "Limousine",
        88: "Kitchen & dining room table",
        89: "Polar bear",
        90: "Tower",
        91: "Football",
        92: "Willow",
        93: "Human head",
        94: "Stop sign",
        95: "Banana",
        96: "Mixer",
        97: "Binoculars",
        98: "Dessert",
        99: "Bee",
        100: "Chair",
        101: "Wood-burning stove",
        102: "Flowerpot",
        103: "Beaker",
        104: "Oyster",
        105: "Woodpecker",
        106: "Harp",
        107: "Bathtub",
        108: "Wall clock",
        109: "Sports uniform",
        110: "Rhinoceros",
        111: "Beehive",
        112: "Cupboard",
        113: "Chicken",
        114: "Man",
        115: "Blue jay",
        116: "Cucumber",
        117: "Balloon",
        118: "Kite",
        119: "Fireplace",
        120: "Lantern",
        121: "Missile",
        122: "Book",
        123: "Spoon",
        124: "Grapefruit",
        125: "Squirrel",
        126: "Orange",
        127: "Coat",
        128: "Punching bag",
        129: "Zebra",
        130: "Billboard",
        131: "Bicycle",
        132: "Door handle",
        133: "Mechanical fan",
        134: "Ring binder",
        135: "Table",
        136: "Parrot",
        137: "Sock",
        138: "Vase",
        139: "Weapon",
        140: "Shotgun",
        141: "Glasses",
        142: "Seahorse",
        143: "Belt",
        144: "Watercraft",
        145: "Window",
        146: "Giraffe",
        147: "Lion",
        148: "Tire",
        149: "Vehicle",
        150: "Canoe",
        151: "Tie",
        152: "Shelf",
        153: "Picture frame",
        154: "Printer",
        155: "Human leg",
        156: "Boat",
        157: "Slow cooker",
        158: "Croissant",
        159: "Candle",
        160: "Pancake",
        161: "Pillow",
        162: "Coin",
        163: "Stretcher",
        164: "Sandal",
        165: "Woman",
        166: "Stairs",
        167: "Harpsichord",
        168: "Stool",
        169: "Bus",
        170: "Suitcase",
        171: "Human mouth",
        172: "Juice",
        173: "Skull",
        174: "Door",
        175: "Violin",
        176: "Chopsticks",
        177: "Digital clock",
        178: "Sunflower",
        179: "Leopard",
        180: "Bell pepper",
        181: "Harbor seal",
        182: "Snake",
        183: "Sewing machine",
        184: "Goose",
        185: "Helicopter",
        186: "Seat belt",
        187: "Coffee cup",
        188: "Microwave oven",
        189: "Hot dog",
        190: "Countertop",
        191: "Serving tray",
        192: "Dog bed",
        193: "Beer",
        194: "Sunglasses",
        195: "Golf ball",
        196: "Waffle",
        197: "Palm tree",
        198: "Trumpet",
        199: "Ruler",
        200: "Helmet",
        201: "Ladder",
        202: "Office building",
        203: "Tablet computer",
        204: "Toilet paper",
        205: "Pomegranate",
        206: "Skirt",
        207: "Gas stove",
        208: "Cookie",
        209: "Cart",
        210: "Raven",
        211: "Egg",
        212: "Burrito",
        213: "Goat",
        214: "Kitchen knife",
        215: "Skateboard",
        216: "Salt and pepper shakers",
        217: "Lynx",
        218: "Boot",
        219: "Platter",
        220: "Ski",
        221: "Swimwear",
        222: "Swimming pool",
        223: "Drinking straw",
        224: "Wrench",
        225: "Drum",
        226: "Ant",
        227: "Human ear",
        228: "Headphones",
        229: "Fountain",
        230: "Bird",
        231: "Jeans",
        232: "Television",
        233: "Crab",
        234: "Microphone",
        235: "Home appliance",
        236: "Snowplow",
        237: "Beetle",
        238: "Artichoke",
        239: "Jet ski",
        240: "Stationary bicycle",
        241: "Human hair",
        242: "Brown bear",
        243: "Starfish",
        244: "Fork",
        245: "Lobster",
        246: "Corded phone",
        247: "Drink",
        248: "Saucer",
        249: "Carrot",
        250: "Insect",
        251: "Clock",
        252: "Castle",
        253: "Tennis racket",
        254: "Ceiling fan",
        255: "Asparagus",
        256: "Jaguar",
        257: "Musical instrument",
        258: "Train",
        259: "Cat",
        260: "Rifle",
        261: "Dumbbell",
        262: "Mobile phone",
        263: "Taxi",
        264: "Shower",
        265: "Pitcher",
        266: "Lemon",
        267: "Invertebrate",
        268: "Turkey",
        269: "High heels",
        270: "Bust",
        271: "Elephant",
        272: "Scarf",
        273: "Barrel",
        274: "Trombone",
        275: "Pumpkin",
        276: "Box",
        277: "Tomato",
        278: "Frog",
        279: "Bidet",
        280: "Human face",
        281: "Houseplant",
        282: "Van",
        283: "Shark",
        284: "Ice cream",
        285: "Swim cap",
        286: "Falcon",
        287: "Ostrich",
        288: "Handgun",
        289: "Whiteboard",
        290: "Lizard",
        291: "Pasta",
        292: "Snowmobile",
        293: "Light bulb",
        294: "Window blind",
        295: "Muffin",
        296: "Pretzel",
        297: "Computer monitor",
        298: "Horn",
        299: "Furniture",
        300: "Sandwich",
        301: "Fox",
        302: "Convenience store",
        303: "Fish",
        304: "Fruit",
        305: "Earrings",
        306: "Curtain",
        307: "Grape",
        308: "Sofa bed",
        309: "Horse",
        310: "Luggage and bags",
        311: "Desk",
        312: "Crutch",
        313: "Bicycle helmet",
        314: "Tick",
        315: "Airplane",
        316: "Canary",
        317: "Spatula",
        318: "Watch",
        319: "Lily",
        320: "Kitchen appliance",
        321: "Filing cabinet",
        322: "Aircraft",
        323: "Cake stand",
        324: "Candy",
        325: "Sink",
        326: "Mouse",
        327: "Wine",
        328: "Wheelchair",
        329: "Goldfish",
        330: "Refrigerator",
        331: "French fries",
        332: "Drawer",
        333: "Treadmill",
        334: "Picnic basket",
        335: "Dice",
        336: "Cabbage",
        337: "Football helmet",
        338: "Pig",
        339: "Person",
        340: "Shorts",
        341: "Gondola",
        342: "Honeycomb",
        343: "Doughnut",
        344: "Chest of drawers",
        345: "Land vehicle",
        346: "Bat",
        347: "Monkey",
        348: "Dagger",
        349: "Tableware",
        350: "Human foot",
        351: "Mug",
        352: "Alarm clock",
        353: "Pressure cooker",
        354: "Human hand",
        355: "Tortoise",
        356: "Baseball glove",
        357: "Sword",
        358: "Pear",
        359: "Miniskirt",
        360: "Traffic sign",
        361: "Girl",
        362: "Roller skates",
        363: "Dinosaur",
        364: "Porch",
        365: "Human beard",
        366: "Submarine sandwich",
        367: "Screwdriver",
        368: "Strawberry",
        369: "Wine glass",
        370: "Seafood",
        371: "Racket",
        372: "Wheel",
        373: "Sea lion",
        374: "Toy",
        375: "Tea",
        376: "Tennis ball",
        377: "Waste container",
        378: "Mule",
        379: "Cricket ball",
        380: "Pineapple",
        381: "Coconut",
        382: "Doll",
        383: "Coffee table",
        384: "Snowman",
        385: "Lavender",
        386: "Shrimp",
        387: "Maple",
        388: "Cowboy hat",
        389: "Goggles",
        390: "Rugby ball",
        391: "Caterpillar",
        392: "Poster",
        393: "Rocket",
        394: "Organ",
        395: "Saxophone",
        396: "Traffic light",
        397: "Cocktail",
        398: "Plastic bag",
        399: "Squash",
        400: "Mushroom",
        401: "Hamburger",
        402: "Light switch",
        403: "Parachute",
        404: "Teddy bear",
        405: "Winter melon",
        406: "Deer",
        407: "Musical keyboard",
        408: "Plumbing fixture",
        409: "Scoreboard",
        410: "Baseball bat",
        411: "Envelope",
        412: "Adhesive tape",
        413: "Briefcase",
        414: "Paddle",
        415: "Bow and arrow",
        416: "Telephone",
        417: "Sheep",
        418: "Jacket",
        419: "Boy",
        420: "Pizza",
        421: "Otter",
        422: "Office supplies",
        423: "Couch",
        424: "Cello",
        425: "Bull",
        426: "Camel",
        427: "Ball",
        428: "Duck",
        429: "Whale",
        430: "Shirt",
        431: "Tank",
        432: "Motorcycle",
        433: "Accordion",
        434: "Owl",
        435: "Porcupine",
        436: "Sun hat",
        437: "Nail",
        438: "Scissors",
        439: "Swan",
        440: "Lamp",
        441: "Crown",
        442: "Piano",
        443: "Sculpture",
        444: "Cheetah",
        445: "Oboe",
        446: "Tin can",
        447: "Mango",
        448: "Tripod",
        449: "Oven",
        450: "Mouse",
        451: "Barge",
        452: "Coffee",
        453: "Snowboard",
        454: "Common fig",
        455: "Salad",
        456: "Marine invertebrates",
        457: "Umbrella",
        458: "Kangaroo",
        459: "Human arm",
        460: "Measuring cup",
        461: "Snail",
        462: "Loveseat",
        463: "Suit",
        464: "Teapot",
        465: "Bottle",
        466: "Alpaca",
        467: "Kettle",
        468: "Trousers",
        469: "Popcorn",
        470: "Centipede",
        471: "Spider",
        472: "Sparrow",
        473: "Plate",
        474: "Bagel",
        475: "Personal care",
        476: "Apple",
        477: "Brassiere",
        478: "Bathroom cabinet",
        479: "studio couch",
        480: "Computer keyboard",
        481: "Table tennis racket",
        482: "Sushi",
        483: "Cabinetry",
        484: "Street light",
        485: "Towel",
        486: "Nightstand",
        487: "Rabbit",
        488: "Dolphin",
        489: "Dog",
        490: "Jug",
        491: "Wok",
        492: "Fire hydrant",
        493: "Human eye",
        494: "Skyscraper",
        495: "Backpack",
        496: "Potato",
        497: "Paper towel",
        498: "Lifejacket",
        499: "Bicycle wheel",
        500: "Toilet",
    }

    return clsid2catid, catid2name
|
||||
|
||||
|
||||
def _visdrone_category():
|
||||
clsid2catid = {i: i for i in range(10)}
|
||||
|
||||
catid2name = {
|
||||
0: 'pedestrian',
|
||||
1: 'people',
|
||||
2: 'bicycle',
|
||||
3: 'car',
|
||||
4: 'van',
|
||||
5: 'truck',
|
||||
6: 'tricycle',
|
||||
7: 'awning-tricycle',
|
||||
8: 'bus',
|
||||
9: 'motor'
|
||||
}
|
||||
return clsid2catid, catid2name
|
||||
596
paddle_detection/ppdet/data/source/coco.py
Normal file
596
paddle_detection/ppdet/data/source/coco.py
Normal file
@@ -0,0 +1,596 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import copy
|
||||
try:
|
||||
from collections.abc import Sequence
|
||||
except Exception:
|
||||
from collections import Sequence
|
||||
import numpy as np
|
||||
from ppdet.core.workspace import register, serializable
|
||||
from .dataset import DetDataset
|
||||
|
||||
from ppdet.utils.logger import setup_logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
__all__ = [
|
||||
'COCODataSet', 'SlicedCOCODataSet', 'SemiCOCODataSet', 'COCODetDataset'
|
||||
]
|
||||
|
||||
|
||||
@register
@serializable
class COCODataSet(DetDataset):
    """
    Load dataset with COCO format.

    Args:
        dataset_dir (str): root directory for dataset.
        image_dir (str): directory for images.
        anno_path (str): coco annotation file path.
        data_fields (list): key name of data dictionary, at least have 'image'.
        sample_num (int): number of samples to load, -1 means all.
        load_crowd (bool): whether to load crowded ground-truth.
            False as default
        allow_empty (bool): whether to load empty entry. False as default
        empty_ratio (float): the ratio of empty record number to total
            record's, if empty_ratio is out of [0. ,1.), do not sample the
            records and use all the empty entries. 1. as default
        repeat (int): repeat times for dataset, use in benchmark.
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 data_fields=['image'],
                 sample_num=-1,
                 load_crowd=False,
                 allow_empty=False,
                 empty_ratio=1.,
                 repeat=1):
        super(COCODataSet, self).__init__(
            dataset_dir,
            image_dir,
            anno_path,
            data_fields,
            sample_num,
            repeat=repeat)
        # Flipped to True by parse_dataset() when the json has no
        # 'annotations' section.
        self.load_image_only = False
        # Legacy semantic-segmentation-map support; see the TODO in
        # parse_dataset().
        self.load_semantic = False
        self.load_crowd = load_crowd
        self.allow_empty = allow_empty
        self.empty_ratio = empty_ratio

    def _sample_empty(self, records, num):
        """Sub-sample the empty (no ground-truth) `records` so that they make
        up at most `empty_ratio` of the final dataset; `num` is the count of
        non-empty records."""
        # if empty_ratio is out of [0. ,1.), do not sample the records
        if self.empty_ratio < 0. or self.empty_ratio >= 1.:
            return records
        import random
        # Solve num_empty / (num + num_empty) == empty_ratio for num_empty,
        # capped by how many empty records actually exist.
        sample_num = min(
            int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records))
        records = random.sample(records, sample_num)
        return records

    def parse_dataset(self):
        """Parse the COCO json annotation file into a list of per-image
        record dicts stored in self.roidbs."""
        anno_path = os.path.join(self.dataset_dir, self.anno_path)
        image_dir = os.path.join(self.dataset_dir, self.image_dir)

        assert anno_path.endswith('.json'), \
            'invalid coco annotation file: ' + anno_path
        from pycocotools.coco import COCO
        coco = COCO(anno_path)
        img_ids = coco.getImgIds()
        img_ids.sort()
        cat_ids = coco.getCatIds()
        records = []
        empty_records = []
        ct = 0

        # Map sparse COCO category ids to contiguous 0-based class ids, and
        # category names to those class ids.
        self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
        self.cname2cid = dict({
            coco.loadCats(catid)[0]['name']: clsid
            for catid, clsid in self.catid2clsid.items()
        })

        if 'annotations' not in coco.dataset:
            self.load_image_only = True
            logger.warning('Annotation file: {} does not contains ground truth '
                           'and load image information only.'.format(anno_path))

        for img_id in img_ids:
            img_anno = coco.loadImgs([img_id])[0]
            im_fname = img_anno['file_name']
            im_w = float(img_anno['width'])
            im_h = float(img_anno['height'])

            im_path = os.path.join(image_dir,
                                   im_fname) if image_dir else im_fname
            is_empty = False
            # Skip images that are missing on disk or have bad dimensions.
            if not os.path.exists(im_path):
                logger.warning('Illegal image file: {}, and it will be '
                               'ignored'.format(im_path))
                continue

            if im_w < 0 or im_h < 0:
                logger.warning('Illegal width: {} or height: {} in annotation, '
                               'and im_id: {} will be ignored'.format(
                                   im_w, im_h, img_id))
                continue

            coco_rec = {
                'im_file': im_path,
                'im_id': np.array([img_id]),
                'h': im_h,
                'w': im_w,
            } if 'image' in self.data_fields else {}

            if not self.load_image_only:
                # iscrowd=None loads crowd boxes too; False filters them out.
                ins_anno_ids = coco.getAnnIds(
                    imgIds=[img_id], iscrowd=None if self.load_crowd else False)
                instances = coco.loadAnns(ins_anno_ids)

                bboxes = []
                # unused here; presumably kept for rotated-box compatibility —
                # TODO confirm
                is_rbox_anno = False
                for inst in instances:
                    # check gt bbox
                    if inst.get('ignore', False):
                        continue
                    if 'bbox' not in inst.keys():
                        continue
                    else:
                        # all-zero boxes carry no information
                        if not any(np.array(inst['bbox'])):
                            continue

                    # COCO stores [x, y, w, h]; convert to [x1, y1, x2, y2].
                    x1, y1, box_w, box_h = inst['bbox']
                    x2 = x1 + box_w
                    y2 = y1 + box_h
                    eps = 1e-5
                    if inst['area'] > 0 and x2 - x1 > eps and y2 - y1 > eps:
                        inst['clean_bbox'] = [
                            round(float(x), 3) for x in [x1, y1, x2, y2]
                        ]
                        bboxes.append(inst)
                    else:
                        logger.warning(
                            'Found an invalid bbox in annotations: im_id: {}, '
                            'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
                                img_id, float(inst['area']), x1, y1, x2, y2))

                num_bbox = len(bboxes)
                if num_bbox <= 0 and not self.allow_empty:
                    continue
                elif num_bbox <= 0:
                    is_empty = True

                gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
                gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
                is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
                gt_poly = [None] * num_bbox
                # -1 marks "no track id" entries.
                gt_track_id = -np.ones((num_bbox, 1), dtype=np.int32)

                has_segmentation = False
                has_track_id = False
                for i, box in enumerate(bboxes):
                    catid = box['category_id']
                    gt_class[i][0] = self.catid2clsid[catid]
                    gt_bbox[i, :] = box['clean_bbox']
                    is_crowd[i][0] = box['iscrowd']
                    # check RLE format
                    if 'segmentation' in box and box['iscrowd'] == 1:
                        gt_poly[i] = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
                    elif 'segmentation' in box and box['segmentation']:
                        if not np.array(
                                box['segmentation'],
                                dtype=object).size > 0 and not self.allow_empty:
                            # NOTE(review): np.delete returns a new array and
                            # its result is discarded here, so the three calls
                            # below are no-ops; popping from bboxes/gt_poly
                            # while enumerating bboxes also skips the next
                            # element — confirm whether this branch is
                            # intended behavior.
                            bboxes.pop(i)
                            gt_poly.pop(i)
                            np.delete(is_crowd, i)
                            np.delete(gt_class, i)
                            np.delete(gt_bbox, i)
                        else:
                            gt_poly[i] = box['segmentation']
                            has_segmentation = True

                    if 'track_id' in box:
                        gt_track_id[i][0] = box['track_id']
                        has_track_id = True

                if has_segmentation and not any(
                        gt_poly) and not self.allow_empty:
                    continue

                gt_rec = {
                    'is_crowd': is_crowd,
                    'gt_class': gt_class,
                    'gt_bbox': gt_bbox,
                    'gt_poly': gt_poly,
                }
                if has_track_id:
                    gt_rec.update({'gt_track_id': gt_track_id})

                # Only keep the fields the pipeline asked for.
                for k, v in gt_rec.items():
                    if k in self.data_fields:
                        coco_rec[k] = v

                # TODO: remove load_semantic
                if self.load_semantic and 'semantic' in self.data_fields:
                    seg_path = os.path.join(self.dataset_dir, 'stuffthingmaps',
                                            'train2017', im_fname[:-3] + 'png')
                    coco_rec.update({'semantic': seg_path})

            logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
                im_path, img_id, im_h, im_w))
            if is_empty:
                empty_records.append(coco_rec)
            else:
                records.append(coco_rec)
                ct += 1
            # sample_num only counts non-empty records.
            if self.sample_num > 0 and ct >= self.sample_num:
                break
        assert ct > 0, 'not found any coco record in %s' % (anno_path)
        logger.info('Load [{} samples valid, {} samples invalid] in file {}.'.
                    format(ct, len(img_ids) - ct, anno_path))
        if self.allow_empty and len(empty_records) > 0:
            empty_records = self._sample_empty(empty_records, len(records))
            records += empty_records
        self.roidbs = records
|
||||
|
||||
|
||||
@register
@serializable
class SlicedCOCODataSet(COCODataSet):
    """Sliced COCODataSet.

    Loads a COCO-format dataset and slices every image into overlapping
    sub-images with the `sahi` library (SAHI-style small-object pipeline).
    Each record holds the sliced pixel data plus bookkeeping fields needed
    to merge predictions back onto the original image.

    Args:
        dataset_dir (str): root directory of the dataset.
        image_dir (str): directory holding the images, relative to dataset_dir.
        anno_path (str): COCO json annotation path, relative to dataset_dir.
        data_fields (list): record keys to keep; slicing only happens when
            'image' is requested.
        sample_num (int): max number of ORIGINAL images to load, -1 for all.
        load_crowd (bool): whether to load crowd annotations.
        allow_empty (bool): whether to keep images without ground truth.
        empty_ratio (float): ratio of kept empty samples to non-empty samples.
        repeat (int): dataset repeat times.
        sliced_size (list): [slice_height, slice_width] of each sub-image.
        overlap_ratio (list): [height_ratio, width_ratio] overlap between
            adjacent slices.
    """

    def __init__(
            self,
            dataset_dir=None,
            image_dir=None,
            anno_path=None,
            data_fields=['image'],
            sample_num=-1,
            load_crowd=False,
            allow_empty=False,
            empty_ratio=1.,
            repeat=1,
            sliced_size=[640, 640],
            overlap_ratio=[0.25, 0.25], ):
        super(SlicedCOCODataSet, self).__init__(
            dataset_dir=dataset_dir,
            image_dir=image_dir,
            anno_path=anno_path,
            data_fields=data_fields,
            sample_num=sample_num,
            load_crowd=load_crowd,
            allow_empty=allow_empty,
            empty_ratio=empty_ratio,
            repeat=repeat, )
        self.sliced_size = sliced_size
        self.overlap_ratio = overlap_ratio

    def parse_dataset(self):
        """Read the COCO annotation file, slice every image with sahi, and
        populate ``self.roidbs`` with one record per sub-image."""
        anno_path = os.path.join(self.dataset_dir, self.anno_path)
        image_dir = os.path.join(self.dataset_dir, self.image_dir)

        assert anno_path.endswith('.json'), \
            'invalid coco annotation file: ' + anno_path
        from pycocotools.coco import COCO
        coco = COCO(anno_path)
        img_ids = coco.getImgIds()
        img_ids.sort()
        cat_ids = coco.getCatIds()
        records = []
        empty_records = []
        # ct counts original images, ct_sub counts produced slices.
        ct = 0
        ct_sub = 0

        # Map original COCO category ids to contiguous class ids, and
        # category names to those class ids.
        self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
        self.cname2cid = dict({
            coco.loadCats(catid)[0]['name']: clsid
            for catid, clsid in self.catid2clsid.items()
        })

        if 'annotations' not in coco.dataset:
            self.load_image_only = True
            logger.warning('Annotation file: {} does not contains ground truth '
                           'and load image information only.'.format(anno_path))
        try:
            import sahi
            from sahi.slicing import slice_image
        except Exception as e:
            logger.error(
                'sahi not found, plaese install sahi. '
                'for example: `pip install sahi`, see https://github.com/obss/sahi.'
            )
            raise e

        # NOTE(review): sub_img_ids is never incremented, so the 'im_id' of
        # slices restarts at 0 for every original image — confirm whether
        # downstream merging relies on 'ori_im_id' instead.
        sub_img_ids = 0
        for img_id in img_ids:
            img_anno = coco.loadImgs([img_id])[0]
            im_fname = img_anno['file_name']
            im_w = float(img_anno['width'])
            im_h = float(img_anno['height'])

            im_path = os.path.join(image_dir,
                                   im_fname) if image_dir else im_fname
            is_empty = False
            if not os.path.exists(im_path):
                logger.warning('Illegal image file: {}, and it will be '
                               'ignored'.format(im_path))
                continue

            if im_w < 0 or im_h < 0:
                logger.warning('Illegal width: {} or height: {} in annotation, '
                               'and im_id: {} will be ignored'.format(
                                   im_w, im_h, img_id))
                continue

            # Slice the full image into overlapping tiles.
            slice_image_result = sahi.slicing.slice_image(
                image=im_path,
                slice_height=self.sliced_size[0],
                slice_width=self.sliced_size[1],
                overlap_height_ratio=self.overlap_ratio[0],
                overlap_width_ratio=self.overlap_ratio[1])

            sub_img_num = len(slice_image_result)
            for _ind in range(sub_img_num):
                im = slice_image_result.images[_ind]
                coco_rec = {
                    'image': im,
                    'im_id': np.array([sub_img_ids + _ind]),
                    'h': im.shape[0],
                    'w': im.shape[1],
                    # id of the source image this slice came from
                    'ori_im_id': np.array([img_id]),
                    # top-left pixel of this slice in the source image
                    'st_pix': np.array(
                        slice_image_result.starting_pixels[_ind],
                        dtype=np.float32),
                    # flag used to know when all slices of an image are done
                    'is_last': 1 if _ind == sub_img_num - 1 else 0,
                } if 'image' in self.data_fields else {}
                records.append(coco_rec)
            ct_sub += sub_img_num
            ct += 1
            if self.sample_num > 0 and ct >= self.sample_num:
                break
        assert ct > 0, 'not found any coco record in %s' % (anno_path)
        logger.info('{} samples and slice to {} sub_samples in file {}'.format(
            ct, ct_sub, anno_path))
        # NOTE(review): empty_records is never populated in this method, so
        # this branch is currently dead — kept for parity with COCODataSet.
        if self.allow_empty and len(empty_records) > 0:
            empty_records = self._sample_empty(empty_records, len(records))
            records += empty_records
        self.roidbs = records
|
||||
|
||||
|
||||
@register
@serializable
class SemiCOCODataSet(COCODataSet):
    """Semi-COCODataSet used for supervised and unsupervised dataSet.

    With ``supervised=True`` it behaves like a regular labeled COCODataSet;
    with ``supervised=False`` it loads image information only (no ground
    truth) and can be resampled to a fixed ``self.length`` decided by the
    supervised counterpart.
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 data_fields=['image'],
                 sample_num=-1,
                 load_crowd=False,
                 allow_empty=False,
                 empty_ratio=1.,
                 repeat=1,
                 supervised=True):
        super(SemiCOCODataSet, self).__init__(
            dataset_dir, image_dir, anno_path, data_fields, sample_num,
            load_crowd, allow_empty, empty_ratio, repeat)
        # Whether this split carries labels (LABELED) or not (UNLABELED).
        self.supervised = supervised
        self.length = -1  # default -1 means all

    def parse_dataset(self):
        """Parse the COCO json file and build ``self.roidbs``.

        Unsupervised mode (``supervised=False``) or a json without an
        'annotations' key switches to image-only loading.
        """
        anno_path = os.path.join(self.dataset_dir, self.anno_path)
        image_dir = os.path.join(self.dataset_dir, self.image_dir)

        assert anno_path.endswith('.json'), \
            'invalid coco annotation file: ' + anno_path
        from pycocotools.coco import COCO
        coco = COCO(anno_path)
        img_ids = coco.getImgIds()
        img_ids.sort()
        cat_ids = coco.getCatIds()
        records = []
        empty_records = []
        ct = 0

        # Map COCO category ids to contiguous class ids / names.
        self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
        self.cname2cid = dict({
            coco.loadCats(catid)[0]['name']: clsid
            for catid, clsid in self.catid2clsid.items()
        })

        if 'annotations' not in coco.dataset or self.supervised == False:
            self.load_image_only = True
            logger.warning('Annotation file: {} does not contains ground truth '
                           'and load image information only.'.format(anno_path))

        for img_id in img_ids:
            img_anno = coco.loadImgs([img_id])[0]
            im_fname = img_anno['file_name']
            im_w = float(img_anno['width'])
            im_h = float(img_anno['height'])

            im_path = os.path.join(image_dir,
                                   im_fname) if image_dir else im_fname
            is_empty = False
            if not os.path.exists(im_path):
                logger.warning('Illegal image file: {}, and it will be '
                               'ignored'.format(im_path))
                continue

            if im_w < 0 or im_h < 0:
                logger.warning('Illegal width: {} or height: {} in annotation, '
                               'and im_id: {} will be ignored'.format(
                                   im_w, im_h, img_id))
                continue

            coco_rec = {
                'im_file': im_path,
                'im_id': np.array([img_id]),
                'h': im_h,
                'w': im_w,
            } if 'image' in self.data_fields else {}

            if not self.load_image_only:
                ins_anno_ids = coco.getAnnIds(
                    imgIds=[img_id], iscrowd=None if self.load_crowd else False)
                instances = coco.loadAnns(ins_anno_ids)

                bboxes = []
                is_rbox_anno = False
                for inst in instances:
                    # check gt bbox
                    if inst.get('ignore', False):
                        continue
                    if 'bbox' not in inst.keys():
                        continue
                    else:
                        if not any(np.array(inst['bbox'])):
                            continue

                    # Convert xywh to xyxy and drop degenerate boxes.
                    x1, y1, box_w, box_h = inst['bbox']
                    x2 = x1 + box_w
                    y2 = y1 + box_h
                    eps = 1e-5
                    if inst['area'] > 0 and x2 - x1 > eps and y2 - y1 > eps:
                        inst['clean_bbox'] = [
                            round(float(x), 3) for x in [x1, y1, x2, y2]
                        ]
                        bboxes.append(inst)
                    else:
                        logger.warning(
                            'Found an invalid bbox in annotations: im_id: {}, '
                            'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
                                img_id, float(inst['area']), x1, y1, x2, y2))

                num_bbox = len(bboxes)
                if num_bbox <= 0 and not self.allow_empty:
                    continue
                elif num_bbox <= 0:
                    is_empty = True

                gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
                gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
                is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
                gt_poly = [None] * num_bbox

                has_segmentation = False
                for i, box in enumerate(bboxes):
                    catid = box['category_id']
                    gt_class[i][0] = self.catid2clsid[catid]
                    gt_bbox[i, :] = box['clean_bbox']
                    is_crowd[i][0] = box['iscrowd']
                    # check RLE format
                    if 'segmentation' in box and box['iscrowd'] == 1:
                        gt_poly[i] = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
                    elif 'segmentation' in box and box['segmentation']:
                        if not np.array(box['segmentation']
                                        ).size > 0 and not self.allow_empty:
                            # NOTE(review): popping from `bboxes` while
                            # enumerating it skips the following element, and
                            # np.delete returns a NEW array whose result is
                            # discarded here — is_crowd/gt_class/gt_bbox are
                            # NOT actually shrunk. Matches upstream behavior;
                            # confirm before changing.
                            bboxes.pop(i)
                            gt_poly.pop(i)
                            np.delete(is_crowd, i)
                            np.delete(gt_class, i)
                            np.delete(gt_bbox, i)
                        else:
                            gt_poly[i] = box['segmentation']
                            has_segmentation = True

                if has_segmentation and not any(
                        gt_poly) and not self.allow_empty:
                    continue

                gt_rec = {
                    'is_crowd': is_crowd,
                    'gt_class': gt_class,
                    'gt_bbox': gt_bbox,
                    'gt_poly': gt_poly,
                }

                # Only keep the ground-truth fields that were requested.
                for k, v in gt_rec.items():
                    if k in self.data_fields:
                        coco_rec[k] = v

                # TODO: remove load_semantic
                if self.load_semantic and 'semantic' in self.data_fields:
                    seg_path = os.path.join(self.dataset_dir, 'stuffthingmaps',
                                            'train2017', im_fname[:-3] + 'png')
                    coco_rec.update({'semantic': seg_path})

            logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
                im_path, img_id, im_h, im_w))
            if is_empty:
                empty_records.append(coco_rec)
            else:
                records.append(coco_rec)
            ct += 1
            if self.sample_num > 0 and ct >= self.sample_num:
                break
        assert ct > 0, 'not found any coco record in %s' % (anno_path)
        logger.info('Load [{} samples valid, {} samples invalid] in file {}.'.
                    format(ct, len(img_ids) - ct, anno_path))
        if self.allow_empty and len(empty_records) > 0:
            empty_records = self._sample_empty(empty_records, len(records))
            records += empty_records
        self.roidbs = records

        if self.supervised:
            logger.info(f'Use {len(self.roidbs)} sup_samples data as LABELED')
        else:
            if self.length > 0:  # unsup length will be decide by sup length
                # Resample (with replacement) to the requested length.
                all_roidbs = self.roidbs.copy()
                selected_idxs = [
                    np.random.choice(len(all_roidbs))
                    for _ in range(self.length)
                ]
                self.roidbs = [all_roidbs[i] for i in selected_idxs]
            logger.info(
                f'Use {len(self.roidbs)} unsup_samples data as UNLABELED')

    def __getitem__(self, idx):
        """Fetch one sample (or a list of samples for mixup/cutmix/mosaic)
        and run it through ``self.transform``.

        NOTE: relies on mixup/cutmix/mosaic epoch attributes being set via
        ``set_kwargs`` before the first access.
        """
        n = len(self.roidbs)
        if self.repeat > 1:
            idx %= n
        # data batch
        roidb = copy.deepcopy(self.roidbs[idx])
        if self.mixup_epoch == 0 or self._epoch < self.mixup_epoch:
            # mixup: pair with one random sample
            idx = np.random.randint(n)
            roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
        elif self.cutmix_epoch == 0 or self._epoch < self.cutmix_epoch:
            # cutmix: pair with one random sample
            idx = np.random.randint(n)
            roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
        elif self.mosaic_epoch == 0 or self._epoch < self.mosaic_epoch:
            # mosaic: this sample plus four random samples
            roidb = [roidb, ] + [
                copy.deepcopy(self.roidbs[np.random.randint(n)])
                for _ in range(4)
            ]
        if isinstance(roidb, Sequence):
            for r in roidb:
                r['curr_iter'] = self._curr_iter
        else:
            roidb['curr_iter'] = self._curr_iter
        self._curr_iter += 1

        return self.transform(roidb)
|
||||
|
||||
|
||||
# for PaddleX
@register
@serializable
class COCODetDataset(COCODataSet):
    """Alias of :class:`COCODataSet`, kept under this name for PaddleX
    compatibility; adds no behavior of its own."""
|
||||
206
paddle_detection/ppdet/data/source/culane.py
Normal file
206
paddle_detection/ppdet/data/source/culane.py
Normal file
@@ -0,0 +1,206 @@
|
||||
from ppdet.core.workspace import register, serializable
|
||||
import cv2
|
||||
import os
|
||||
import tarfile
|
||||
import numpy as np
|
||||
import os.path as osp
|
||||
from ppdet.data.source.dataset import DetDataset
|
||||
from imgaug.augmentables.lines import LineStringsOnImage
|
||||
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
|
||||
from ppdet.data.culane_utils import lane_to_linestrings
|
||||
import pickle as pkl
|
||||
from ppdet.utils.logger import setup_logger
|
||||
try:
|
||||
from collections.abc import Sequence
|
||||
except Exception:
|
||||
from collections import Sequence
|
||||
from .dataset import DetDataset, _make_dataset, _is_valid_file
|
||||
from ppdet.utils.download import download_dataset
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
@register
@serializable
class CULaneDataSet(DetDataset):
    """CULane lane-detection dataset.

    Reads a CULane list file (image path, optional mask path, optional
    per-lane existence flags) plus the per-image ``*.lines.txt`` lane
    annotations. Parsed annotations are cached to ``cache/`` as a pickle.

    Args:
        dataset_dir (str): root directory of the dataset.
        cut_height (int): number of top rows cropped off every image/mask.
        list_path (str): split list file, relative to dataset_dir.
        split (str): split name; 'train' in the name enables training mode.
        data_fields (list): record keys to keep.
        video_file (str): optional video source (unused here; stored only).
        frame_rate (int): optional frame rate (unused here; stored only).
    """

    def __init__(
            self,
            dataset_dir,
            cut_height,
            list_path,
            split='train',
            data_fields=['image'],
            video_file=None,
            frame_rate=-1, ):
        super(CULaneDataSet, self).__init__(
            dataset_dir=dataset_dir,
            cut_height=cut_height,
            split=split,
            data_fields=data_fields)
        self.dataset_dir = dataset_dir
        self.list_path = osp.join(dataset_dir, list_path)
        self.cut_height = cut_height
        self.data_fields = data_fields
        self.split = split
        self.training = 'train' in split
        self.data_infos = []
        self.video_file = video_file
        self.frame_rate = frame_rate
        self._imid2path = {}
        # When set (via set_images), the dataset runs in predict mode.
        self.predict_dir = None

    def __len__(self):
        return len(self.data_infos)

    def check_or_download_dataset(self):
        """Download and unpack the CULane archives if the dataset directory
        does not exist yet."""
        if not osp.exists(self.dataset_dir):
            download_dataset("dataset", dataset="culane")
            # extract .tar files in self.dataset_dir
            for fname in os.listdir(self.dataset_dir):
                logger.info("Decompressing {}...".format(fname))
                # ignore .* files
                if fname.startswith('.'):
                    continue
                if fname.find('.tar.gz') >= 0:
                    with tarfile.open(osp.join(self.dataset_dir, fname)) as tf:
                        tf.extractall(path=self.dataset_dir)
            logger.info("Dataset files are ready.")

    def parse_dataset(self):
        """Load (or restore from cache) all annotation records for the split."""
        logger.info('Loading CULane annotations...')
        if self.predict_dir is not None:
            logger.info('switch to predict mode')
            return
        # Waiting for the dataset to load is tedious, let's cache it
        os.makedirs('cache', exist_ok=True)
        cache_path = 'cache/culane_paddle_{}.pkl'.format(self.split)
        if os.path.exists(cache_path):
            # NOTE(review): cache is loaded with pickle — only safe for
            # caches this process itself wrote.
            with open(cache_path, 'rb') as cache_file:
                self.data_infos = pkl.load(cache_file)
                self.max_lanes = max(
                    len(anno['lanes']) for anno in self.data_infos)
                return

        with open(self.list_path) as list_file:
            for line in list_file:
                infos = self.load_annotation(line.split())
                self.data_infos.append(infos)

        # cache data infos to file
        with open(cache_path, 'wb') as cache_file:
            pkl.dump(self.data_infos, cache_file)

    def load_annotation(self, line):
        """Parse a single list-file row into an annotation dict.

        Args:
            line (list[str]): whitespace-split row: image path, optional
                mask path, optional lane-existence flags.

        Returns:
            dict: keys 'img_name', 'img_path', 'lanes', and optionally
            'mask_path' and 'lane_exist'.
        """
        infos = {}
        img_line = line[0]
        # Strip a leading '/' so the path joins under dataset_dir.
        img_line = img_line[1 if img_line[0] == '/' else 0::]
        img_path = os.path.join(self.dataset_dir, img_line)
        infos['img_name'] = img_line
        infos['img_path'] = img_path
        if len(line) > 1:
            mask_line = line[1]
            mask_line = mask_line[1 if mask_line[0] == '/' else 0::]
            mask_path = os.path.join(self.dataset_dir, mask_line)
            infos['mask_path'] = mask_path

        if len(line) > 2:
            exist_list = [int(l) for l in line[2:]]
            infos['lane_exist'] = np.array(exist_list)

        anno_path = img_path[:
                             -3] + 'lines.txt'  # replace the 'jpg' suffix with 'lines.txt'
        with open(anno_path, 'r') as anno_file:
            data = [
                list(map(float, line.split())) for line in anno_file.readlines()
            ]
        # Each row is a flat x0 y0 x1 y1 ... list; keep non-negative points.
        lanes = [[(lane[i], lane[i + 1]) for i in range(0, len(lane), 2)
                  if lane[i] >= 0 and lane[i + 1] >= 0] for lane in data]
        lanes = [list(set(lane)) for lane in lanes]  # remove duplicated points
        lanes = [lane for lane in lanes
                 if len(lane) > 2]  # drop lanes with 2 or fewer points

        lanes = [sorted(
            lane, key=lambda x: x[1]) for lane in lanes]  # sort by y
        infos['lanes'] = lanes

        return infos

    def set_images(self, images):
        """Switch to predict mode over the given image dir/file(s)."""
        self.predict_dir = images
        self.data_infos = self._load_images()

    def _find_images(self):
        """Collect image paths from predict_dir (dir, file, or a sequence)."""
        predict_dir = self.predict_dir
        if not isinstance(predict_dir, Sequence):
            predict_dir = [predict_dir]
        images = []
        for im_dir in predict_dir:
            if os.path.isdir(im_dir):
                # NOTE(review): joining self.predict_dir with im_dir (which
                # already came from predict_dir) looks like a copy/paste slip
                # from ImageFolder._parse — verify for the sequence case.
                im_dir = os.path.join(self.predict_dir, im_dir)
                images.extend(_make_dataset(im_dir))
            elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
                images.append(im_dir)
        return images

    def _load_images(self):
        """Build predict-mode records (no ground-truth lanes)."""
        images = self._find_images()
        ct = 0
        records = []
        for image in images:
            assert image != '' and os.path.isfile(image), \
                "Image {} not found".format(image)
            if self.sample_num > 0 and ct >= self.sample_num:
                break
            rec = {
                'im_id': np.array([ct]),
                "img_path": os.path.abspath(image),
                "img_name": os.path.basename(image),
                "lanes": []
            }
            self._imid2path[ct] = image
            ct += 1
            records.append(rec)
        assert len(records) > 0, "No image file found"
        return records

    def get_imid2path(self):
        # Mapping of im_id -> original image path (predict mode only).
        return self._imid2path

    def __getitem__(self, idx):
        """Load one image (and, in training, its mask), crop cut_height rows
        from the top, shift lane y-coordinates accordingly, and wrap lanes /
        mask in imgaug containers for the augmentation pipeline."""
        data_info = self.data_infos[idx]
        img = cv2.imread(data_info['img_path'])
        img = img[self.cut_height:, :, :]
        sample = data_info.copy()
        sample.update({'image': img})
        img_org = sample['image']

        if self.training:
            label = cv2.imread(sample['mask_path'], cv2.IMREAD_UNCHANGED)
            if len(label.shape) > 2:
                label = label[:, :, 0]
            label = label.squeeze()
            label = label[self.cut_height:, :]
            sample.update({'mask': label})
            if self.cut_height != 0:
                # Shift lane points up by the cropped amount.
                new_lanes = []
                for i in sample['lanes']:
                    lanes = []
                    for p in i:
                        lanes.append((p[0], p[1] - self.cut_height))
                    new_lanes.append(lanes)
                sample.update({'lanes': new_lanes})

            sample['mask'] = SegmentationMapsOnImage(
                sample['mask'], shape=img_org.shape)

        sample['full_img_path'] = data_info['img_path']
        sample['img_name'] = data_info['img_name']
        sample['im_id'] = np.array([idx])

        sample['image'] = sample['image'].copy().astype(np.uint8)
        sample['lanes'] = lane_to_linestrings(sample['lanes'])
        sample['lanes'] = LineStringsOnImage(
            sample['lanes'], shape=img_org.shape)
        sample['seg'] = np.zeros(img_org.shape)

        return sample
|
||||
307
paddle_detection/ppdet/data/source/dataset.py
Normal file
307
paddle_detection/ppdet/data/source/dataset.py
Normal file
@@ -0,0 +1,307 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import copy
|
||||
import numpy as np
|
||||
try:
|
||||
from collections.abc import Sequence
|
||||
except Exception:
|
||||
from collections import Sequence
|
||||
from paddle.io import Dataset
|
||||
from ppdet.core.workspace import register, serializable
|
||||
from ppdet.utils.download import get_dataset_path
|
||||
from ppdet.data import source
|
||||
|
||||
from ppdet.utils.logger import setup_logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
@serializable
class DetDataset(Dataset):
    """
    Load detection dataset.

    Args:
        dataset_dir (str): root directory for dataset.
        image_dir (str): directory for images.
        anno_path (str): annotation file path.
        data_fields (list): key name of data dictionary, at least have 'image'.
        sample_num (int): number of samples to load, -1 means all.
        use_default_label (bool): whether to load default label list.
        repeat (int): repeat times for dataset, use in benchmark.

    NOTE: subclasses must fill ``self.roidbs`` in ``parse_dataset``, and the
    runner must call ``set_kwargs`` and ``set_transform`` before indexing —
    ``__getitem__`` reads mixup/cutmix/mosaic/pre_img epoch attributes that
    only ``set_kwargs`` creates.
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 data_fields=['image'],
                 sample_num=-1,
                 use_default_label=None,
                 repeat=1,
                 **kwargs):
        super(DetDataset, self).__init__()
        self.dataset_dir = dataset_dir if dataset_dir is not None else ''
        self.anno_path = anno_path
        self.image_dir = image_dir if image_dir is not None else ''
        self.data_fields = data_fields
        self.sample_num = sample_num
        self.use_default_label = use_default_label
        self.repeat = repeat
        # Epoch / iteration counters advanced by the training loop.
        self._epoch = 0
        self._curr_iter = 0

    def __len__(self, ):
        # Logical length includes the repeat factor used for benchmarking.
        return len(self.roidbs) * self.repeat

    def __call__(self, *args, **kwargs):
        # Allows a configured dataset instance to be used as a factory.
        return self

    def __getitem__(self, idx):
        """Return one transformed sample; for mixup/cutmix/mosaic/CenterTrack
        schedules, a list of samples is handed to the transform instead."""
        n = len(self.roidbs)
        if self.repeat > 1:
            idx %= n
        # data batch
        roidb = copy.deepcopy(self.roidbs[idx])
        if self.mixup_epoch == 0 or self._epoch < self.mixup_epoch:
            # mixup: pair with one random sample
            idx = np.random.randint(n)
            roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
        elif self.cutmix_epoch == 0 or self._epoch < self.cutmix_epoch:
            # cutmix: pair with one random sample
            idx = np.random.randint(n)
            roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
        elif self.mosaic_epoch == 0 or self._epoch < self.mosaic_epoch:
            # mosaic: this sample plus four random samples
            roidb = [roidb, ] + [
                copy.deepcopy(self.roidbs[np.random.randint(n)])
                for _ in range(4)
            ]
        elif self.pre_img_epoch == 0 or self._epoch < self.pre_img_epoch:
            # Add previous image as input, only used in CenterTrack
            idx_pre_img = idx - 1
            if idx_pre_img < 0:
                idx_pre_img = idx + 1
            roidb = [roidb, ] + [copy.deepcopy(self.roidbs[idx_pre_img])]
        if isinstance(roidb, Sequence):
            for r in roidb:
                r['curr_iter'] = self._curr_iter
        else:
            roidb['curr_iter'] = self._curr_iter
        self._curr_iter += 1

        return self.transform(roidb)

    def check_or_download_dataset(self):
        # Resolves (and downloads if necessary) the dataset root.
        self.dataset_dir = get_dataset_path(self.dataset_dir, self.anno_path,
                                            self.image_dir)

    def set_kwargs(self, **kwargs):
        """Install the augmentation epoch schedule (-1 disables each one)."""
        self.mixup_epoch = kwargs.get('mixup_epoch', -1)
        self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
        self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)
        self.pre_img_epoch = kwargs.get('pre_img_epoch', -1)

    def set_transform(self, transform):
        # Transform pipeline applied to every sample in __getitem__.
        self.transform = transform

    def set_epoch(self, epoch_id):
        # Called by the trainer so the augmentation schedule can switch off.
        self._epoch = epoch_id

    def parse_dataset(self, ):
        raise NotImplementedError(
            "Need to implement parse_dataset method of Dataset")

    def get_anno(self):
        """Return the absolute annotation path, or None if unset."""
        if self.anno_path is None:
            return
        return os.path.join(self.dataset_dir, self.anno_path)
|
||||
|
||||
|
||||
def _is_valid_file(f, extensions=('.jpg', '.jpeg', '.png', '.bmp')):
|
||||
return f.lower().endswith(extensions)
|
||||
|
||||
|
||||
def _make_dataset(dir):
    """Recursively collect image file paths under *dir*.

    Args:
        dir (str): directory to walk; ``~`` is expanded. Symlinked
            directories are followed.

    Returns:
        list[str]: paths (in sorted walk order) accepted by ``_is_valid_file``.

    Raises:
        ValueError: if *dir* is not an existing directory.
    """
    dir = os.path.expanduser(dir)
    if not os.path.isdir(dir):
        # Fixed: the original `raise ('... should be a dir')` raised a bare
        # string, which is itself a TypeError in Python 3; raise a proper
        # exception type so callers get a meaningful error.
        raise ValueError('{} should be a dir'.format(dir))
    images = []
    for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
        for fname in sorted(fnames):
            path = os.path.join(root, fname)
            if _is_valid_file(path):
                images.append(path)
    return images
|
||||
|
||||
|
||||
@register
@serializable
class ImageFolder(DetDataset):
    """Annotation-free dataset over a directory (or list) of image files,
    used for inference. ``set_slice_images`` additionally slices each image
    with the `sahi` library for SAHI-style prediction.

    Args:
        dataset_dir (str): optional root joined onto image_dir entries.
        image_dir (str or list): directory, file, or sequence of either.
        anno_path (str): optional label-list path (VOC label list only).
        sample_num (int): max number of images to load, -1 for all.
        use_default_label (bool): whether to load default label list.
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 sample_num=-1,
                 use_default_label=None,
                 **kwargs):
        super(ImageFolder, self).__init__(
            dataset_dir,
            image_dir,
            anno_path,
            sample_num=sample_num,
            use_default_label=use_default_label)
        self._imid2path = {}
        self.roidbs = None
        self.sample_num = sample_num

    def check_or_download_dataset(self):
        # Nothing to download for a plain image folder.
        return

    def get_anno(self):
        """Return the annotation path (absolute if dataset_dir is set)."""
        if self.anno_path is None:
            return
        if self.dataset_dir:
            return os.path.join(self.dataset_dir, self.anno_path)
        else:
            return self.anno_path

    def parse_dataset(self, ):
        # Lazily build the records once.
        if not self.roidbs:
            self.roidbs = self._load_images()

    def _parse(self):
        """Expand image_dir (a dir, a file, or a sequence) into file paths."""
        image_dir = self.image_dir
        if not isinstance(image_dir, Sequence):
            image_dir = [image_dir]
        images = []
        for im_dir in image_dir:
            if os.path.isdir(im_dir):
                im_dir = os.path.join(self.dataset_dir, im_dir)
                images.extend(_make_dataset(im_dir))
            elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
                images.append(im_dir)
        return images

    def _load_images(self):
        """Build one record per image: {'im_id', 'im_file'}."""
        images = self._parse()
        ct = 0
        records = []
        for image in images:
            assert image != '' and os.path.isfile(image), \
                "Image {} not found".format(image)
            if self.sample_num > 0 and ct >= self.sample_num:
                break
            rec = {'im_id': np.array([ct]), 'im_file': image}
            self._imid2path[ct] = image
            ct += 1
            records.append(rec)
        assert len(records) > 0, "No image file found"
        return records

    def get_imid2path(self):
        # Mapping of im_id -> original image path.
        return self._imid2path

    def set_images(self, images):
        """Point the dataset at new image source(s) and rebuild records."""
        self.image_dir = images
        self.roidbs = self._load_images()

    def set_slice_images(self,
                         images,
                         slice_size=[640, 640],
                         overlap_ratio=[0.25, 0.25]):
        """Point the dataset at new images and slice each into overlapping
        sub-images via sahi (for SAHI-style sliced prediction).

        Args:
            images: directory/file/sequence accepted by ``set_images``.
            slice_size (list): [slice_height, slice_width].
            overlap_ratio (list): [height_ratio, width_ratio] overlap.
        """
        self.image_dir = images
        ori_records = self._load_images()
        try:
            import sahi
            from sahi.slicing import slice_image
        except Exception as e:
            logger.error(
                'sahi not found, plaese install sahi. '
                'for example: `pip install sahi`, see https://github.com/obss/sahi.'
            )
            raise e

        # NOTE(review): sub_img_ids is never advanced, so slice im_ids
        # restart at 0 for each source image (same as SlicedCOCODataSet).
        sub_img_ids = 0
        ct = 0
        ct_sub = 0
        records = []
        for i, ori_rec in enumerate(ori_records):
            im_path = ori_rec['im_file']
            slice_image_result = sahi.slicing.slice_image(
                image=im_path,
                slice_height=slice_size[0],
                slice_width=slice_size[1],
                overlap_height_ratio=overlap_ratio[0],
                overlap_width_ratio=overlap_ratio[1])

            sub_img_num = len(slice_image_result)
            for _ind in range(sub_img_num):
                im = slice_image_result.images[_ind]
                rec = {
                    'image': im,
                    'im_id': np.array([sub_img_ids + _ind]),
                    'h': im.shape[0],
                    'w': im.shape[1],
                    # id of the source image this slice came from
                    'ori_im_id': np.array([ori_rec['im_id'][0]]),
                    # top-left pixel of this slice in the source image
                    'st_pix': np.array(
                        slice_image_result.starting_pixels[_ind],
                        dtype=np.float32),
                    'is_last': 1 if _ind == sub_img_num - 1 else 0,
                } if 'image' in self.data_fields else {}
                records.append(rec)
            ct_sub += sub_img_num
            ct += 1
        logger.info('{} samples and slice to {} sub_samples.'.format(ct,
                                                                     ct_sub))
        self.roidbs = records

    def get_label_list(self):
        # Only VOC dataset needs label list in ImageFold
        return self.anno_path
|
||||
|
||||
|
||||
@register
class CommonDataset(object):
    """Factory wrapper that builds a dataset from ``ppdet.data.source`` by
    its registered class name.

    Keyword Args:
        name (str): attribute name looked up on the ``source`` module; all
            remaining keyword arguments are forwarded to that class's
            constructor.
    """

    def __init__(self, **dataset_args):
        super(CommonDataset, self).__init__()
        # Deep-copy so the caller's config dict is never mutated.
        args = copy.deepcopy(dataset_args)
        cls_name = args.pop("name")
        dataset_cls = getattr(source, cls_name)
        self.dataset = dataset_cls(**args)

    def __call__(self):
        # Calling the wrapper yields the constructed dataset instance.
        return self.dataset
|
||||
|
||||
|
||||
@register
class TrainDataset(CommonDataset):
    """Config-registered wrapper for the training dataset."""


@register
class EvalMOTDataset(CommonDataset):
    """Config-registered wrapper for the MOT evaluation dataset."""


@register
class TestMOTDataset(CommonDataset):
    """Config-registered wrapper for the MOT test dataset."""


@register
class EvalDataset(CommonDataset):
    """Config-registered wrapper for the evaluation dataset."""


@register
class TestDataset(CommonDataset):
    """Config-registered wrapper for the test dataset."""
|
||||
845
paddle_detection/ppdet/data/source/keypoint_coco.py
Normal file
845
paddle_detection/ppdet/data/source/keypoint_coco.py
Normal file
@@ -0,0 +1,845 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
this code is base on https://github.com/open-mmlab/mmpose
|
||||
"""
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
import copy
|
||||
import pycocotools
|
||||
from pycocotools.coco import COCO
|
||||
from .dataset import DetDataset
|
||||
from ppdet.core.workspace import register, serializable
|
||||
|
||||
|
||||
@serializable
class KeypointBottomUpBaseDataset(DetDataset):
    """Base class for bottom-up datasets.

    All datasets should subclass it.
    All subclasses should overwrite:
        Methods:`_get_imganno`

    Args:
        dataset_dir (str): Root path to the dataset.
        image_dir (str): Path to a directory where images are held.
        anno_path (str): Relative path to the annotation file.
        num_joints (int): keypoint numbers
        transform (composed(operators)): A sequence of data transforms.
        shard (list): [rank, worldsize], the distributed env params.
            Stored by subclasses; not used by this base class.
        test_mode (bool): Store True when building test or
            validation dataset. Default: False.
    """

    def __init__(self,
                 dataset_dir,
                 image_dir,
                 anno_path,
                 num_joints,
                 transform=[],
                 shard=[0, 1],
                 test_mode=False):
        super().__init__(dataset_dir, image_dir, anno_path)
        self.image_info = {}
        self.ann_info = {}

        self.img_prefix = os.path.join(dataset_dir, image_dir)
        self.transform = transform
        self.test_mode = test_mode

        self.ann_info['num_joints'] = num_joints
        self.img_ids = []

    def parse_dataset(self):
        # Fixed: this method was defined twice in the class body (a `pass`
        # stub and a later `return` duplicate that silently shadowed it);
        # a single no-op definition is kept. Subclasses override this.
        return

    def __len__(self):
        """Get dataset length."""
        return len(self.img_ids)

    def _get_imganno(self, idx):
        """Get anno for a single image."""
        raise NotImplementedError

    def __getitem__(self, idx):
        """Prepare image for training given the index."""
        records = copy.deepcopy(self._get_imganno(idx))
        # OpenCV loads BGR; convert to RGB for the model pipeline.
        records['image'] = cv2.imread(records['image_file'])
        records['image'] = cv2.cvtColor(records['image'], cv2.COLOR_BGR2RGB)
        if 'mask' in records:
            # `+ 0` forces a copy before the dtype cast.
            records['mask'] = (records['mask'] + 0).astype('uint8')
        records = self.transform(records)
        return records
|
||||
|
||||
|
||||
@register
@serializable
class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
    """COCO dataset for bottom-up pose estimation.

    The dataset loads raw features and apply specified transforms
    to return a dict containing the image tensors and other information.

    COCO keypoint indexes::

        0: 'nose',
        1: 'left_eye',
        2: 'right_eye',
        3: 'left_ear',
        4: 'right_ear',
        5: 'left_shoulder',
        6: 'right_shoulder',
        7: 'left_elbow',
        8: 'right_elbow',
        9: 'left_wrist',
        10: 'right_wrist',
        11: 'left_hip',
        12: 'right_hip',
        13: 'left_knee',
        14: 'right_knee',
        15: 'left_ankle',
        16: 'right_ankle'

    Args:
        dataset_dir (str): Root path to the dataset.
        anno_path (str): Relative path to the annotation file.
        image_dir (str): Path to a directory where images are held.
            Default: None.
        num_joints (int): keypoint numbers
        transform (composed(operators)): A sequence of data transforms.
        shard (list): [rank, worldsize], the distributed env params
        test_mode (bool): Store True when building test or
            validation dataset. Default: False.
        return_mask (bool): Include the crowd/unlabeled ignore mask.
        return_bbox (bool): Include per-instance xyxy boxes.
        return_area (bool): Include per-instance areas.
        return_class (bool): Include per-instance class labels.
    """

    def __init__(self,
                 dataset_dir,
                 image_dir,
                 anno_path,
                 num_joints,
                 transform=[],
                 shard=[0, 1],
                 test_mode=False,
                 return_mask=True,
                 return_bbox=True,
                 return_area=True,
                 return_class=True):
        super().__init__(dataset_dir, image_dir, anno_path, num_joints,
                         transform, shard, test_mode)

        self.ann_file = os.path.join(dataset_dir, anno_path)
        self.shard = shard
        self.test_mode = test_mode
        # Toggles for which optional ground-truth fields _get_imganno emits.
        self.return_mask = return_mask
        self.return_bbox = return_bbox
        self.return_area = return_area
        self.return_class = return_class

    def parse_dataset(self):
        """Load the COCO annotation file and build the image-id list.

        In training mode, images whose non-crowd annotation list is empty are
        dropped. The id list is then sharded evenly across distributed
        workers using self.shard = [rank, worldsize].
        """
        self.coco = COCO(self.ann_file)

        self.img_ids = self.coco.getImgIds()
        if not self.test_mode:
            # Keep only images with at least one non-crowd annotation.
            self.img_ids_tmp = []
            for img_id in self.img_ids:
                ann_ids = self.coco.getAnnIds(imgIds=img_id)
                anno = self.coco.loadAnns(ann_ids)
                anno = [obj for obj in anno if obj['iscrowd'] == 0]
                if len(anno) == 0:
                    continue
                self.img_ids_tmp.append(img_id)
            self.img_ids = self.img_ids_tmp

        # Even split across workers; trailing images beyond
        # blocknum * worldsize are dropped on every rank.
        blocknum = int(len(self.img_ids) / self.shard[1])
        self.img_ids = self.img_ids[(blocknum * self.shard[0]):(blocknum * (
            self.shard[0] + 1))]
        self.num_images = len(self.img_ids)
        self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
        self.dataset_name = 'coco'

        # Map sparse COCO category ids to contiguous class indices.
        cat_ids = self.coco.getCatIds()
        self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
        print('=> num_images: {}'.format(self.num_images))

    @staticmethod
    def _get_mapping_id_name(imgs):
        """
        Args:
            imgs (dict): dict of image info.

        Returns:
            tuple: Image name & id mapping dicts.

            - id2name (dict): Mapping image id to name.
            - name2id (dict): Mapping image name to id.
        """
        id2name = {}
        name2id = {}
        for image_id, image in imgs.items():
            file_name = image['file_name']
            id2name[image_id] = file_name
            name2id[file_name] = image_id

        return id2name, name2id

    def _get_imganno(self, idx):
        """Get anno for a single image.

        Args:
            idx (int): image idx

        Returns:
            dict: info for model training
        """
        coco = self.coco
        img_id = self.img_ids[idx]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        anno = coco.loadAnns(ann_ids)

        # Only non-crowd instances with at least one labeled keypoint
        # contribute ground-truth joints/boxes/areas/classes.
        anno = [
            obj for obj in anno
            if obj['iscrowd'] == 0 and obj['num_keypoints'] > 0
        ]

        db_rec = {}
        joints, orgsize = self._get_joints(anno, idx)
        db_rec['gt_joints'] = joints
        db_rec['im_shape'] = orgsize

        if self.return_bbox:
            db_rec['gt_bbox'] = self._get_bboxs(anno, idx)

        if self.return_class:
            db_rec['gt_class'] = self._get_labels(anno, idx)

        if self.return_area:
            db_rec['gt_areas'] = self._get_areas(anno, idx)

        if self.return_mask:
            # Note: the mask is built from the *unfiltered* semantics of
            # _get_mask (crowd regions and keypoint-less instances).
            db_rec['mask'] = self._get_mask(anno, idx)

        db_rec['im_id'] = img_id
        db_rec['image_file'] = os.path.join(self.img_prefix,
                                            self.id2name[img_id])

        return db_rec

    def _get_joints(self, anno, idx):
        """Get joints for all people in an image.

        Returns:
            tuple: (joints (num_people, num_joints, 3) [x, y, visibility],
                    orgsize [height, width, 1] of the original image).
        """
        num_people = len(anno)

        joints = np.zeros(
            (num_people, self.ann_info['num_joints'], 3), dtype=np.float32)

        for i, obj in enumerate(anno):
            # COCO stores keypoints flat as [x1, y1, v1, x2, y2, v2, ...].
            joints[i, :self.ann_info['num_joints'], :3] = \
                np.array(obj['keypoints']).reshape([-1, 3])

        img_info = self.coco.loadImgs(self.img_ids[idx])[0]
        orgsize = np.array([img_info['height'], img_info['width'], 1])

        return joints, orgsize

    def _get_bboxs(self, anno, idx):
        """Return per-instance boxes converted from xywh to xyxy.

        NOTE(review): the loop variable shadows the `idx` parameter, which
        is otherwise unused here.
        """
        num_people = len(anno)
        gt_bboxes = np.zeros((num_people, 4), dtype=np.float32)

        for idx, obj in enumerate(anno):
            if 'bbox' in obj:
                gt_bboxes[idx, :] = obj['bbox']

        # Convert [x, y, w, h] -> [x1, y1, x2, y2].
        gt_bboxes[:, 2] += gt_bboxes[:, 0]
        gt_bboxes[:, 3] += gt_bboxes[:, 1]
        return gt_bboxes

    def _get_labels(self, anno, idx):
        """Return per-instance contiguous class ids (shape (N, 1)).

        NOTE(review): the loop variable shadows the `idx` parameter.
        """
        num_people = len(anno)
        gt_labels = np.zeros((num_people, 1), dtype=np.float32)

        for idx, obj in enumerate(anno):
            if 'category_id' in obj:
                catid = obj['category_id']
                gt_labels[idx, 0] = self.catid2clsid[catid]
        return gt_labels

    def _get_areas(self, anno, idx):
        """Return per-instance segmentation areas (shape (N,)).

        NOTE(review): the loop variable shadows the `idx` parameter.
        """
        num_people = len(anno)
        gt_areas = np.zeros((num_people, ), dtype=np.float32)

        for idx, obj in enumerate(anno):
            if 'area' in obj:
                gt_areas[idx, ] = obj['area']
        return gt_areas

    def _get_mask(self, anno, idx):
        """Get ignore masks to mask out losses.

        Crowd regions and keypoint-less instances are rasterized into `m`;
        the returned boolean mask is True where the pixel should be *kept*
        (m < 0.5), i.e. False marks ignored regions.
        """
        coco = self.coco
        img_info = coco.loadImgs(self.img_ids[idx])[0]

        m = np.zeros((img_info['height'], img_info['width']), dtype=np.float32)

        for obj in anno:
            if 'segmentation' in obj:
                if obj['iscrowd']:
                    # Crowd segmentations are already RLE-encoded.
                    rle = pycocotools.mask.frPyObjects(obj['segmentation'],
                                                       img_info['height'],
                                                       img_info['width'])
                    m += pycocotools.mask.decode(rle)
                elif obj['num_keypoints'] == 0:
                    # Polygon segmentations yield one RLE per polygon.
                    rles = pycocotools.mask.frPyObjects(obj['segmentation'],
                                                        img_info['height'],
                                                        img_info['width'])
                    for rle in rles:
                        m += pycocotools.mask.decode(rle)

        return m < 0.5
||||
|
||||
|
||||
@register
@serializable
class KeypointBottomUpCrowdPoseDataset(KeypointBottomUpCocoDataset):
    """CrowdPose dataset for bottom-up pose estimation.

    The dataset loads raw features and apply specified transforms
    to return a dict containing the image tensors and other information.

    CrowdPose keypoint indexes::

        0: 'left_shoulder',
        1: 'right_shoulder',
        2: 'left_elbow',
        3: 'right_elbow',
        4: 'left_wrist',
        5: 'right_wrist',
        6: 'left_hip',
        7: 'right_hip',
        8: 'left_knee',
        9: 'right_knee',
        10: 'left_ankle',
        11: 'right_ankle',
        12: 'top_head',
        13: 'neck'

    Args:
        dataset_dir (str): Root path to the dataset.
        anno_path (str): Relative path to the annotation file.
        image_dir (str): Path to a directory where images are held.
            Default: None.
        num_joints (int): keypoint numbers
        transform (composed(operators)): A sequence of data transforms.
        shard (list): [rank, worldsize], the distributed env params
        test_mode (bool): Store True when building test or
            validation dataset. Default: False.
    """

    def __init__(self,
                 dataset_dir,
                 image_dir,
                 anno_path,
                 num_joints,
                 transform=[],
                 shard=[0, 1],
                 test_mode=False):
        super().__init__(dataset_dir, image_dir, anno_path, num_joints,
                         transform, shard, test_mode)

        # Re-assigned after super().__init__ for clarity; same values.
        self.ann_file = os.path.join(dataset_dir, anno_path)
        self.shard = shard
        self.test_mode = test_mode

    def parse_dataset(self):
        """Load CrowdPose annotations and shard image ids across workers.

        Unlike the COCO parent, training-mode filtering keeps any image that
        has at least one annotation (crowd or not), and no catid->clsid map
        is built.
        """
        self.coco = COCO(self.ann_file)

        self.img_ids = self.coco.getImgIds()
        if not self.test_mode:
            self.img_ids = [
                img_id for img_id in self.img_ids
                if len(self.coco.getAnnIds(
                    imgIds=img_id, iscrowd=None)) > 0
            ]
        # Even split across workers, same scheme as the COCO parent.
        blocknum = int(len(self.img_ids) / self.shard[1])
        self.img_ids = self.img_ids[(blocknum * self.shard[0]):(blocknum * (
            self.shard[0] + 1))]
        self.num_images = len(self.img_ids)
        self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)

        self.dataset_name = 'crowdpose'
        print('=> num_images: {}'.format(self.num_images))
||||
|
||||
|
||||
@serializable
class KeypointTopDownBaseDataset(DetDataset):
    """Base class for top_down datasets.

    All datasets should subclass it.
    All subclasses should overwrite:
        Methods:`_get_db`

    Args:
        dataset_dir (str): Root path to the dataset.
        image_dir (str): Path to a directory where images are held.
        anno_path (str): Relative path to the annotation file.
        num_joints (int): keypoint numbers
        transform (composed(operators)): A sequence of data transforms.
    """

    def __init__(self,
                 dataset_dir,
                 image_dir,
                 anno_path,
                 num_joints,
                 transform=[]):
        super().__init__(dataset_dir, image_dir, anno_path)
        self.image_info = {}
        self.ann_info = {}

        # Absolute prefix prepended to per-image file names when loading.
        self.img_prefix = os.path.join(dataset_dir, image_dir)
        self.transform = transform

        self.ann_info['num_joints'] = num_joints
        # Flat record list; populated by subclasses (via _get_db /
        # parse_dataset) and indexed by __getitem__.
        self.db = []

    def __len__(self):
        """Get dataset length."""
        return len(self.db)

    def _get_db(self):
        """Get a sample.

        Abstract hook: subclasses must build and return the record list.
        """
        raise NotImplementedError

    def __getitem__(self, idx):
        """Prepare sample for training given the index."""
        # Deep copy so transforms cannot mutate the cached record.
        records = copy.deepcopy(self.db[idx])
        # IMREAD_IGNORE_ORIENTATION prevents EXIF-based auto-rotation from
        # desynchronizing the image with its annotated keypoints.
        records['image'] = cv2.imread(records['image_file'], cv2.IMREAD_COLOR |
                                      cv2.IMREAD_IGNORE_ORIENTATION)
        records['image'] = cv2.cvtColor(records['image'], cv2.COLOR_BGR2RGB)
        # Detection-score field defaults to 1 for ground-truth boxes.
        records['score'] = records['score'] if 'score' in records else 1
        records = self.transform(records)
        return records
||||
|
||||
|
||||
@register
@serializable
class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset):
    """COCO dataset for top-down pose estimation.

    The dataset loads raw features and apply specified transforms
    to return a dict containing the image tensors and other information.

    COCO keypoint indexes:

        0: 'nose',
        1: 'left_eye',
        2: 'right_eye',
        3: 'left_ear',
        4: 'right_ear',
        5: 'left_shoulder',
        6: 'right_shoulder',
        7: 'left_elbow',
        8: 'right_elbow',
        9: 'left_wrist',
        10: 'right_wrist',
        11: 'left_hip',
        12: 'right_hip',
        13: 'left_knee',
        14: 'right_knee',
        15: 'left_ankle',
        16: 'right_ankle'

    Args:
        dataset_dir (str): Root path to the dataset.
        image_dir (str): Path to a directory where images are held.
        anno_path (str): Relative path to the annotation file.
        num_joints (int): Keypoint numbers
        trainsize (list):[w, h] Image target size
        transform (composed(operators)): A sequence of data transforms.
        bbox_file (str): Path to a detection bbox file
            Default: None.
        use_gt_bbox (bool): Whether to use ground truth bbox
            Default: True.
        pixel_std (int): The pixel std of the scale
            Default: 200.
        image_thre (float): The threshold to filter the detection box
            Default: 0.0.
        center_scale (float|None): If set, randomly jitter the box center by
            up to +/- center_scale/2 of the box size (30% of the time).
    """

    def __init__(self,
                 dataset_dir,
                 image_dir,
                 anno_path,
                 num_joints,
                 trainsize,
                 transform=[],
                 bbox_file=None,
                 use_gt_bbox=True,
                 pixel_std=200,
                 image_thre=0.0,
                 center_scale=None):
        super().__init__(dataset_dir, image_dir, anno_path, num_joints,
                         transform)

        self.bbox_file = bbox_file
        self.use_gt_bbox = use_gt_bbox
        self.trainsize = trainsize
        self.pixel_std = pixel_std
        self.image_thre = image_thre
        self.center_scale = center_scale
        self.dataset_name = 'coco'

    def parse_dataset(self):
        """Build self.db from GT annotations or from a detector's boxes."""
        if self.use_gt_bbox:
            self.db = self._load_coco_keypoint_annotations()
        else:
            self.db = self._load_coco_person_detection_results()

    def _load_coco_keypoint_annotations(self):
        """Build per-instance records from COCO ground-truth keypoints.

        Returns:
            list[dict]: one record per kept instance with keys
            image_file / center / scale / gt_joints / joints_vis / im_id.
        """
        coco = COCO(self.get_anno())
        img_ids = coco.getImgIds()
        gt_db = []
        for index in img_ids:
            im_ann = coco.loadImgs(index)[0]
            width = im_ann['width']
            height = im_ann['height']
            file_name = im_ann['file_name']
            im_id = int(im_ann["id"])

            annIds = coco.getAnnIds(imgIds=index, iscrowd=False)
            objs = coco.loadAnns(annIds)

            # Clamp boxes to the image and drop degenerate/zero-area ones.
            valid_objs = []
            for obj in objs:
                x, y, w, h = obj['bbox']
                x1 = np.max((0, x))
                y1 = np.max((0, y))
                x2 = np.min((width - 1, x1 + np.max((0, w - 1))))
                y2 = np.min((height - 1, y1 + np.max((0, h - 1))))
                if obj['area'] > 0 and x2 >= x1 and y2 >= y1:
                    obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
                    valid_objs.append(obj)
            objs = valid_objs

            rec = []
            for obj in objs:
                # Skip instances with no labeled keypoints at all.
                if max(obj['keypoints']) == 0:
                    continue

                joints = np.zeros(
                    (self.ann_info['num_joints'], 3), dtype=np.float32)
                joints_vis = np.zeros(
                    (self.ann_info['num_joints'], 3), dtype=np.float32)
                for ipt in range(self.ann_info['num_joints']):
                    # COCO keypoints are flat [x, y, v] triples.
                    joints[ipt, 0] = obj['keypoints'][ipt * 3 + 0]
                    joints[ipt, 1] = obj['keypoints'][ipt * 3 + 1]
                    joints[ipt, 2] = 0
                    # COCO visibility 2 ("visible") is folded into 1.
                    t_vis = obj['keypoints'][ipt * 3 + 2]
                    if t_vis > 1:
                        t_vis = 1
                    joints_vis[ipt, 0] = t_vis
                    joints_vis[ipt, 1] = t_vis
                    joints_vis[ipt, 2] = 0

                center, scale = self._box2cs(obj['clean_bbox'][:4])
                rec.append({
                    'image_file': os.path.join(self.img_prefix, file_name),
                    'center': center,
                    'scale': scale,
                    'gt_joints': joints,
                    'joints_vis': joints_vis,
                    'im_id': im_id,
                })
            gt_db.extend(rec)

        return gt_db

    def _box2cs(self, box):
        """Convert an xywh box to (center, scale) in pixel_std units.

        The box is expanded to match the training aspect ratio, scaled by
        1.25 for context, and optionally center-jittered for augmentation.
        """
        x, y, w, h = box[:4]
        center = np.zeros((2), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        aspect_ratio = self.trainsize[0] * 1.0 / self.trainsize[1]

        # Random center jitter (30% of calls) — augmentation; makes this
        # method non-deterministic when center_scale is set.
        if self.center_scale is not None and np.random.rand() < 0.3:
            center += self.center_scale * (np.random.rand(2) - 0.5) * [w, h]

        # Pad the shorter side so the crop matches the training aspect ratio.
        if w > aspect_ratio * h:
            h = w * 1.0 / aspect_ratio
        elif w < aspect_ratio * h:
            w = h * aspect_ratio
        scale = np.array(
            [w * 1.0 / self.pixel_std, h * 1.0 / self.pixel_std],
            dtype=np.float32)
        # center[0] == -1 is the MPII-style "invalid" sentinel; skip the
        # 1.25x context expansion for it.
        if center[0] != -1:
            scale = scale * 1.25

        return center, scale

    def _load_coco_person_detection_results(self):
        """Build records from an external person-detector result file.

        Only category_id == 1 (person) detections scoring at least
        self.image_thre are kept; joints are zero-initialized placeholders.

        Returns:
            list[dict] | None: records, or None if the file could not be
            loaded (callers then see a None db — presumably only used in
            eval flows; TODO confirm).
        """
        all_boxes = None
        bbox_file_path = os.path.join(self.dataset_dir, self.bbox_file)
        with open(bbox_file_path, 'r') as f:
            all_boxes = json.load(f)

        if not all_boxes:
            print('=> Load %s fail!' % bbox_file_path)
            return None

        kpt_db = []
        for n_img in range(0, len(all_boxes)):
            det_res = all_boxes[n_img]
            if det_res['category_id'] != 1:
                continue
            # Fall back to the zero-padded COCO naming scheme when the
            # detection record carries no explicit filename.
            file_name = det_res[
                'filename'] if 'filename' in det_res else '%012d.jpg' % det_res[
                    'image_id']
            img_name = os.path.join(self.img_prefix, file_name)
            box = det_res['bbox']
            score = det_res['score']
            im_id = int(det_res['image_id'])

            if score < self.image_thre:
                continue

            center, scale = self._box2cs(box)
            joints = np.zeros(
                (self.ann_info['num_joints'], 3), dtype=np.float32)
            joints_vis = np.ones(
                (self.ann_info['num_joints'], 3), dtype=np.float32)
            kpt_db.append({
                'image_file': img_name,
                'im_id': im_id,
                'center': center,
                'scale': scale,
                'score': score,
                'gt_joints': joints,
                'joints_vis': joints_vis,
            })

        return kpt_db
||||
|
||||
|
||||
@register
@serializable
class KeypointTopDownCocoWholeBodyHandDataset(KeypointTopDownBaseDataset):
    """CocoWholeBody dataset for top-down hand pose estimation.

    The dataset loads raw features and apply specified transforms
    to return a dict containing the image tensors and other information.

    COCO-WholeBody Hand keypoint indexes:

        0: 'wrist',
        1: 'thumb1',
        2: 'thumb2',
        3: 'thumb3',
        4: 'thumb4',
        5: 'forefinger1',
        6: 'forefinger2',
        7: 'forefinger3',
        8: 'forefinger4',
        9: 'middle_finger1',
        10: 'middle_finger2',
        11: 'middle_finger3',
        12: 'middle_finger4',
        13: 'ring_finger1',
        14: 'ring_finger2',
        15: 'ring_finger3',
        16: 'ring_finger4',
        17: 'pinky_finger1',
        18: 'pinky_finger2',
        19: 'pinky_finger3',
        20: 'pinky_finger4'

    Args:
        dataset_dir (str): Root path to the dataset.
        image_dir (str): Path to a directory where images are held.
        anno_path (str): Relative path to the annotation file.
        num_joints (int): Keypoint numbers
        trainsize (list):[w, h] Image target size
        transform (composed(operators)): A sequence of data transforms.
        pixel_std (int): The pixel std of the scale
            Default: 200.
    """

    def __init__(self,
                 dataset_dir,
                 image_dir,
                 anno_path,
                 num_joints,
                 trainsize,
                 transform=[],
                 pixel_std=200):
        super().__init__(dataset_dir, image_dir, anno_path, num_joints,
                         transform)

        self.trainsize = trainsize
        self.pixel_std = pixel_std
        # NOTE(review): 'coco_wholebady_hand' looks like a typo of
        # "wholebody", but it is a runtime identifier that downstream metric
        # code may match on — do not rename without checking consumers.
        self.dataset_name = 'coco_wholebady_hand'

    def _box2cs(self, box):
        """Convert an xywh box to (center, scale) in pixel_std units.

        Same scheme as KeypointTopDownCocoDataset._box2cs but without the
        random center jitter.
        """
        x, y, w, h = box[:4]
        center = np.zeros((2), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        aspect_ratio = self.trainsize[0] * 1.0 / self.trainsize[1]

        # Pad the shorter side so the crop matches the training aspect ratio.
        if w > aspect_ratio * h:
            h = w * 1.0 / aspect_ratio
        elif w < aspect_ratio * h:
            w = h * aspect_ratio
        scale = np.array(
            [w * 1.0 / self.pixel_std, h * 1.0 / self.pixel_std],
            dtype=np.float32)
        if center[0] != -1:
            # Expand 1.25x for context unless center carries the -1 sentinel.
            scale = scale * 1.25

        return center, scale

    def parse_dataset(self):
        """Build self.db: one record per valid left/right hand instance."""
        gt_db = []
        num_joints = self.ann_info['num_joints']
        coco = COCO(self.get_anno())
        img_ids = list(coco.imgs.keys())
        for img_id in img_ids:
            im_ann = coco.loadImgs(img_id)[0]
            image_file = os.path.join(self.img_prefix, im_ann['file_name'])
            im_id = int(im_ann["id"])

            ann_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
            objs = coco.loadAnns(ann_ids)

            for obj in objs:
                # Each annotation can contribute up to two records
                # (left hand and right hand).
                for type in ['left', 'right']:
                    # Keep only valid hands with at least one labeled point.
                    if (obj[f'{type}hand_valid'] and
                            max(obj[f'{type}hand_kpts']) > 0):

                        joints = np.zeros((num_joints, 3), dtype=np.float32)
                        joints_vis = np.zeros((num_joints, 3), dtype=np.float32)

                        keypoints = np.array(obj[f'{type}hand_kpts'])
                        keypoints = keypoints.reshape(-1, 3)
                        joints[:, :2] = keypoints[:, :2]
                        # Clamp visibility flag 2 ("visible") down to 1.
                        joints_vis[:, :2] = np.minimum(1, keypoints[:, 2:3])

                        center, scale = self._box2cs(obj[f'{type}hand_box'][:4])
                        gt_db.append({
                            'image_file': image_file,
                            'center': center,
                            'scale': scale,
                            'gt_joints': joints,
                            'joints_vis': joints_vis,
                            'im_id': im_id,
                        })

        self.db = gt_db
||||
|
||||
|
||||
@register
@serializable
class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset):
    """MPII dataset for topdown pose estimation.

    The dataset loads raw features and apply specified transforms
    to return a dict containing the image tensors and other information.

    MPII keypoint indexes::

        0: 'right_ankle',
        1: 'right_knee',
        2: 'right_hip',
        3: 'left_hip',
        4: 'left_knee',
        5: 'left_ankle',
        6: 'pelvis',
        7: 'thorax',
        8: 'upper_neck',
        9: 'head_top',
        10: 'right_wrist',
        11: 'right_elbow',
        12: 'right_shoulder',
        13: 'left_shoulder',
        14: 'left_elbow',
        15: 'left_wrist',

    Args:
        dataset_dir (str): Root path to the dataset.
        image_dir (str): Path to a directory where images are held.
        anno_path (str): Relative path to the annotation file.
        num_joints (int): Keypoint numbers
        transform (composed(operators)): A sequence of data transforms.
    """

    def __init__(self,
                 dataset_dir,
                 image_dir,
                 anno_path,
                 num_joints,
                 transform=[]):
        super().__init__(dataset_dir, image_dir, anno_path, num_joints,
                         transform)

        self.dataset_name = 'mpii'

    def parse_dataset(self):
        """Build self.db from the MPII JSON annotation file.

        MPII provides center/scale per person directly (no boxes); joint
        coordinates are converted from 1-based to 0-based indexing.
        """
        with open(self.get_anno()) as anno_file:
            anno = json.load(anno_file)

        gt_db = []
        for a in anno:
            image_name = a['image']
            # Derive a numeric image id from the file stem when the
            # annotation record carries none.
            im_id = a['image_id'] if 'image_id' in a else int(
                os.path.splitext(image_name)[0])

            c = np.array(a['center'], dtype=np.float32)
            # MPII scale is a single scalar; duplicate it for (w, h).
            s = np.array([a['scale'], a['scale']], dtype=np.float32)

            # Adjust center/scale slightly to avoid cropping limbs
            if c[0] != -1:
                c[1] = c[1] + 15 * s[1]
                s = s * 1.25

            # MPII uses matlab format, 1-based index -> 0-based.
            c = c - 1

            joints = np.zeros(
                (self.ann_info['num_joints'], 3), dtype=np.float32)
            joints_vis = np.zeros(
                (self.ann_info['num_joints'], 3), dtype=np.float32)
            if 'gt_joints' in a:
                joints_ = np.array(a['gt_joints'])
                joints_[:, 0:2] = joints_[:, 0:2] - 1
                joints_vis_ = np.array(a['joints_vis'])
                assert len(joints_) == self.ann_info[
                    'num_joints'], 'joint num diff: {} vs {}'.format(
                        len(joints_), self.ann_info['num_joints'])

                joints[:, 0:2] = joints_[:, 0:2]
                # MPII visibility is a flat per-joint flag; copy to x and y.
                joints_vis[:, 0] = joints_vis_[:]
                joints_vis[:, 1] = joints_vis_[:]

            gt_db.append({
                'image_file': os.path.join(self.img_prefix, image_name),
                'im_id': im_id,
                'center': c,
                'scale': s,
                'gt_joints': joints,
                'joints_vis': joints_vis
            })
        print("number length: {}".format(len(gt_db)))
        self.db = gt_db
||||
638
paddle_detection/ppdet/data/source/mot.py
Normal file
638
paddle_detection/ppdet/data/source/mot.py
Normal file
@@ -0,0 +1,638 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import cv2
|
||||
import glob
|
||||
import numpy as np
|
||||
from collections import OrderedDict, defaultdict
|
||||
try:
|
||||
from collections.abc import Sequence
|
||||
except Exception:
|
||||
from collections import Sequence
|
||||
from .dataset import DetDataset, _make_dataset, _is_valid_file
|
||||
from ppdet.core.workspace import register, serializable
|
||||
from ppdet.utils.logger import setup_logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
@register
@serializable
class MOTDataSet(DetDataset):
    """
    Load dataset with MOT format, only support single class MOT.

    Args:
        dataset_dir (str): root directory for dataset.
        image_lists (str|list): mot data image lists, muiti-source mot dataset.
        data_fields (list): key name of data dictionary, at least have 'image'.
        sample_num (int): number of samples to load, -1 means all.
        repeat (int): repeat times for dataset, use in benchmark.

    Notes:
        MOT datasets root directory following this:
            dataset/mot
            |——————image_lists
            |        |——————caltech.train
            |        |——————caltech.val
            |        |——————mot16.train
            |        |——————mot17.train
            |        ......
            |——————Caltech
            |——————MOT17
            |——————......

        All the MOT datasets have the following structure:
            Caltech
            |——————images
            |        └——————00001.jpg
            |        |—————— ...
            |        └——————0000N.jpg
            └——————labels_with_ids
                        └——————00001.txt
                        |—————— ...
                        └——————0000N.txt
            or
            MOT17
            |——————images
            |        └——————train
            |        └——————test
            └——————labels_with_ids
                        └——————train
    """

    def __init__(self,
                 dataset_dir=None,
                 image_lists=[],
                 data_fields=['image'],
                 sample_num=-1,
                 repeat=1):
        super(MOTDataSet, self).__init__(
            dataset_dir=dataset_dir,
            data_fields=data_fields,
            sample_num=sample_num,
            repeat=repeat)
        self.dataset_dir = dataset_dir
        self.image_lists = image_lists
        # Normalize a single list-file name to a one-element list.
        if isinstance(self.image_lists, str):
            self.image_lists = [self.image_lists]
        # Filled by parse_dataset().
        self.roidbs = None
        self.cname2cid = None

    def get_anno(self):
        """Return the label_list.txt path of the first configured sub-dataset,
        or None when no image lists are configured."""
        if self.image_lists == []:
            return
        # only used to get categories and metric
        # only check first data, but the label_list of all data should be same.
        first_mot_data = self.image_lists[0].split('.')[0]
        anno_file = os.path.join(self.dataset_dir, first_mot_data,
                                 'label_list.txt')
        return anno_file

    def parse_dataset(self):
        """Read all image lists and label files, build per-image records.

        Also computes per-sub-dataset track-id offsets so that identities
        from different sources map into one global, non-overlapping id space.
        """
        self.img_files = OrderedDict()
        self.img_start_index = OrderedDict()
        self.label_files = OrderedDict()
        self.tid_num = OrderedDict()
        self.tid_start_index = OrderedDict()

        img_index = 0
        for data_name in self.image_lists:
            # check every data image list
            image_lists_dir = os.path.join(self.dataset_dir, 'image_lists')
            assert os.path.isdir(image_lists_dir), \
                "The {} is not a directory.".format(image_lists_dir)

            list_path = os.path.join(image_lists_dir, data_name)
            assert os.path.exists(list_path), \
                "The list path {} does not exist.".format(list_path)

            # record img_files, filter out empty ones
            with open(list_path, 'r') as file:
                self.img_files[data_name] = file.readlines()
                self.img_files[data_name] = [
                    os.path.join(self.dataset_dir, x.strip())
                    for x in self.img_files[data_name]
                ]
                self.img_files[data_name] = list(
                    filter(lambda x: len(x) > 0, self.img_files[data_name]))

            self.img_start_index[data_name] = img_index
            img_index += len(self.img_files[data_name])

            # record label_files
            # Label path mirrors the image path with images->labels_with_ids
            # and the image extension replaced by .txt.
            self.label_files[data_name] = [
                x.replace('images', 'labels_with_ids').replace(
                    '.png', '.txt').replace('.jpg', '.txt')
                for x in self.img_files[data_name]
            ]

        # Count identities per sub-dataset: max track id (column 1) + 1.
        for data_name, label_paths in self.label_files.items():
            max_index = -1
            for lp in label_paths:
                lb = np.loadtxt(lp)
                if len(lb) < 1:
                    continue
                if len(lb.shape) < 2:
                    # Single-row label file: loadtxt returns a 1-D array.
                    img_max = lb[1]
                else:
                    img_max = np.max(lb[:, 1])
                if img_max > max_index:
                    max_index = img_max
            self.tid_num[data_name] = int(max_index + 1)

        # Cumulative identity offsets so global track ids do not collide.
        last_index = 0
        for i, (k, v) in enumerate(self.tid_num.items()):
            self.tid_start_index[k] = last_index
            last_index += v

        self.num_identities_dict = defaultdict(int)
        self.num_identities_dict[0] = int(last_index + 1)  # single class
        self.num_imgs_each_data = [len(x) for x in self.img_files.values()]
        self.total_imgs = sum(self.num_imgs_each_data)

        logger.info('MOT dataset summary: ')
        logger.info(self.tid_num)
        logger.info('Total images: {}'.format(self.total_imgs))
        logger.info('Image start index: {}'.format(self.img_start_index))
        logger.info('Total identities: {}'.format(self.num_identities_dict[0]))
        logger.info('Identity start index: {}'.format(self.tid_start_index))

        records = []
        # mot_label() is defined elsewhere in this module — presumably the
        # single-class name->id map; TODO confirm.
        cname2cid = mot_label()

        for img_index in range(self.total_imgs):
            # Find which sub-dataset this global index falls into: the last
            # entry whose start index is <= img_index wins.
            for i, (k, v) in enumerate(self.img_start_index.items()):
                if img_index >= v:
                    data_name = list(self.label_files.keys())[i]
                    start_index = v
            img_file = self.img_files[data_name][img_index - start_index]
            lbl_file = self.label_files[data_name][img_index - start_index]

            if not os.path.exists(img_file):
                logger.warning('Illegal image file: {}, and it will be ignored'.
                               format(img_file))
                continue
            if not os.path.isfile(lbl_file):
                logger.warning('Illegal label file: {}, and it will be ignored'.
                               format(lbl_file))
                continue

            labels = np.loadtxt(lbl_file, dtype=np.float32).reshape(-1, 6)
            # each row in labels (N, 6) is [gt_class, gt_identity, cx, cy, w, h]

            cx, cy = labels[:, 2], labels[:, 3]
            w, h = labels[:, 4], labels[:, 5]
            gt_bbox = np.stack((cx, cy, w, h)).T.astype('float32')
            gt_class = labels[:, 0:1].astype('int32')
            gt_score = np.ones((len(labels), 1)).astype('float32')
            gt_ide = labels[:, 1:2].astype('int32')
            # Shift valid (>-1) track ids into the global identity space.
            for i, _ in enumerate(gt_ide):
                if gt_ide[i] > -1:
                    gt_ide[i] += self.tid_start_index[data_name]

            mot_rec = {
                'im_file': img_file,
                'im_id': img_index,
            } if 'image' in self.data_fields else {}

            gt_rec = {
                'gt_class': gt_class,
                'gt_score': gt_score,
                'gt_bbox': gt_bbox,
                'gt_ide': gt_ide,
            }

            # Keep only the ground-truth fields the config asked for.
            for k, v in gt_rec.items():
                if k in self.data_fields:
                    mot_rec[k] = v

            records.append(mot_rec)
            if self.sample_num > 0 and img_index >= self.sample_num:
                break
        assert len(records) > 0, 'not found any mot record in %s' % (
            self.image_lists)
        self.roidbs, self.cname2cid = records, cname2cid
||||
|
||||
|
||||
@register
@serializable
class MCMOTDataSet(DetDataset):
    """
    Load dataset with MOT format, support multi-class MOT.

    Args:
        dataset_dir (str): root directory for dataset.
        image_lists (list(str)): mcmot data image lists, muiti-source mcmot dataset.
        data_fields (list): key name of data dictionary, at least have 'image'.
        label_list (str): if use_default_label is False, will load
            mapping between category and class index.
        sample_num (int): number of samples to load, -1 means all.

    Notes:
        MCMOT datasets root directory following this:
            dataset/mot
            |——————image_lists
            |        |——————visdrone_mcmot.train
            |        |——————visdrone_mcmot.val
            visdrone_mcmot
            |——————images
            |        └——————train
            |        └——————val
            └——————labels_with_ids
                     └——————train
    """

    def __init__(self,
                 dataset_dir=None,
                 image_lists=[],
                 data_fields=['image'],
                 label_list=None,
                 sample_num=-1):
        super(MCMOTDataSet, self).__init__(
            dataset_dir=dataset_dir,
            data_fields=data_fields,
            sample_num=sample_num)
        self.dataset_dir = dataset_dir
        self.image_lists = image_lists
        # Accept a single list file given as a plain string.
        if isinstance(self.image_lists, str):
            self.image_lists = [self.image_lists]
        self.label_list = label_list
        # Filled in by parse_dataset().
        self.roidbs = None
        self.cname2cid = None

    def get_anno(self):
        """Return the path of the category file ('label_list.txt').

        Only the first sub-dataset is checked; the label_list of all
        sub-datasets is assumed to be identical.
        """
        if self.image_lists == []:
            return
        # only used to get categories and metric
        # only check first data, but the label_list of all data should be same.
        first_mot_data = self.image_lists[0].split('.')[0]
        anno_file = os.path.join(self.dataset_dir, first_mot_data,
                                 'label_list.txt')
        return anno_file

    def parse_dataset(self):
        """Build self.roidbs / self.cname2cid from the MCMOT image lists.

        For every sub-dataset this records its images, its per-category
        maximum track id, and a per-category global id offset so that track
        ids from different sub-datasets do not collide.
        """
        self.img_files = OrderedDict()
        self.img_start_index = OrderedDict()
        self.label_files = OrderedDict()
        self.tid_num = OrderedDict()
        self.tid_start_idx_of_cls_ids = defaultdict(dict)  # for MCMOT

        img_index = 0
        for data_name in self.image_lists:
            # check every data image list
            image_lists_dir = os.path.join(self.dataset_dir, 'image_lists')
            assert os.path.isdir(image_lists_dir), \
                "The {} is not a directory.".format(image_lists_dir)

            list_path = os.path.join(image_lists_dir, data_name)
            assert os.path.exists(list_path), \
                "The list path {} does not exist.".format(list_path)

            # record img_files, filter out empty ones
            with open(list_path, 'r') as file:
                self.img_files[data_name] = file.readlines()
                self.img_files[data_name] = [
                    os.path.join(self.dataset_dir, x.strip())
                    for x in self.img_files[data_name]
                ]
                self.img_files[data_name] = list(
                    filter(lambda x: len(x) > 0, self.img_files[data_name]))

            # global index of the first image of this sub-dataset
            self.img_start_index[data_name] = img_index
            img_index += len(self.img_files[data_name])

            # record label_files: labels mirror the image tree under
            # 'labels_with_ids' with a '.txt' suffix
            self.label_files[data_name] = [
                x.replace('images', 'labels_with_ids').replace(
                    '.png', '.txt').replace('.jpg', '.txt')
                for x in self.img_files[data_name]
            ]

        for data_name, label_paths in self.label_files.items():
            # using max_ids_dict rather than max_index
            # max track id seen per category inside this sub-dataset
            max_ids_dict = defaultdict(int)
            for lp in label_paths:
                lb = np.loadtxt(lp)
                if len(lb) < 1:
                    continue
                lb = lb.reshape(-1, 6)
                for item in lb:
                    if item[1] > max_ids_dict[int(item[0])]:
                        # item[0]: cls_id
                        # item[1]: track id
                        max_ids_dict[int(item[0])] = int(item[1])
            # track id number
            self.tid_num[data_name] = max_ids_dict

        # Accumulate per-category offsets so ids stay unique across
        # sub-datasets: dataset k's category c starts after all earlier
        # datasets' category-c ids.
        last_idx_dict = defaultdict(int)
        for i, (k, v) in enumerate(self.tid_num.items()):  # each sub dataset
            for cls_id, id_num in v.items():  # v is a max_ids_dict
                self.tid_start_idx_of_cls_ids[k][cls_id] = last_idx_dict[cls_id]
                last_idx_dict[cls_id] += id_num

        self.num_identities_dict = defaultdict(int)
        for k, v in last_idx_dict.items():
            self.num_identities_dict[k] = int(v)  # total ids of each category

        self.num_imgs_each_data = [len(x) for x in self.img_files.values()]
        self.total_imgs = sum(self.num_imgs_each_data)

        # cname2cid and cid2cname
        cname2cid = {}
        if self.label_list is not None:
            # if use label_list for multi source mix dataset,
            # please make sure label_list in the first sub_dataset at least.
            sub_dataset = self.image_lists[0].split('.')[0]
            label_path = os.path.join(self.dataset_dir, sub_dataset,
                                      self.label_list)
            if not os.path.exists(label_path):
                logger.info(
                    "Note: label_list {} does not exists, use VisDrone 10 classes labels as default.".
                    format(label_path))
                cname2cid = visdrone_mcmot_label()
            else:
                # one category name per line; ids assigned by line order
                with open(label_path, 'r') as fr:
                    label_id = 0
                    for line in fr.readlines():
                        cname2cid[line.strip()] = label_id
                        label_id += 1
        else:
            cname2cid = visdrone_mcmot_label()

        cid2cname = dict([(v, k) for (k, v) in cname2cid.items()])

        logger.info('MCMOT dataset summary: ')
        logger.info(self.tid_num)
        logger.info('Total images: {}'.format(self.total_imgs))
        logger.info('Image start index: {}'.format(self.img_start_index))

        logger.info('Total identities of each category: ')
        num_identities_dict = sorted(
            self.num_identities_dict.items(), key=lambda x: x[0])
        total_IDs_all_cats = 0
        for (k, v) in num_identities_dict:
            logger.info('Category {} [{}] has {} IDs.'.format(k, cid2cname[k],
                                                              v))
            total_IDs_all_cats += v
        logger.info('Total identities of all categories: {}'.format(
            total_IDs_all_cats))

        logger.info('Identity start index of each category: ')
        for k, v in self.tid_start_idx_of_cls_ids.items():
            sorted_v = sorted(v.items(), key=lambda x: x[0])
            for (cls_id, start_idx) in sorted_v:
                logger.info('Start index of dataset {} category {:d} is {:d}'
                            .format(k, cls_id, start_idx))

        records = []
        for img_index in range(self.total_imgs):
            # Find which sub-dataset owns this global image index: the last
            # entry whose start index is <= img_index wins.
            for i, (k, v) in enumerate(self.img_start_index.items()):
                if img_index >= v:
                    data_name = list(self.label_files.keys())[i]
                    start_index = v
            img_file = self.img_files[data_name][img_index - start_index]
            lbl_file = self.label_files[data_name][img_index - start_index]

            if not os.path.exists(img_file):
                logger.warning('Illegal image file: {}, and it will be ignored'.
                               format(img_file))
                continue
            if not os.path.isfile(lbl_file):
                logger.warning('Illegal label file: {}, and it will be ignored'.
                               format(lbl_file))
                continue

            labels = np.loadtxt(lbl_file, dtype=np.float32).reshape(-1, 6)
            # each row in labels (N, 6) is [gt_class, gt_identity, cx, cy, w, h]

            cx, cy = labels[:, 2], labels[:, 3]
            w, h = labels[:, 4], labels[:, 5]
            gt_bbox = np.stack((cx, cy, w, h)).T.astype('float32')
            gt_class = labels[:, 0:1].astype('int32')
            gt_score = np.ones((len(labels), 1)).astype('float32')
            gt_ide = labels[:, 1:2].astype('int32')
            # Shift local track ids (>-1 means a real identity) by this
            # sub-dataset's per-category offset to get globally unique ids.
            for i, _ in enumerate(gt_ide):
                if gt_ide[i] > -1:
                    cls_id = int(gt_class[i])
                    start_idx = self.tid_start_idx_of_cls_ids[data_name][cls_id]
                    gt_ide[i] += start_idx

            mot_rec = {
                'im_file': img_file,
                'im_id': img_index,
            } if 'image' in self.data_fields else {}

            gt_rec = {
                'gt_class': gt_class,
                'gt_score': gt_score,
                'gt_bbox': gt_bbox,
                'gt_ide': gt_ide,
            }

            # only keep the fields the config asked for
            for k, v in gt_rec.items():
                if k in self.data_fields:
                    mot_rec[k] = v

            records.append(mot_rec)
            if self.sample_num > 0 and img_index >= self.sample_num:
                break
        assert len(records) > 0, 'not found any mot record in %s' % (
            self.image_lists)
        self.roidbs, self.cname2cid = records, cname2cid
|
||||
|
||||
|
||||
@register
@serializable
class MOTImageFolder(DetDataset):
    """
    Load MOT dataset with MOT format from image folder or video .

    Args:
        video_file (str): path of the video file, default ''.
        frame_rate (int): frame rate of the video, use cv2 VideoCapture if not set.
        dataset_dir (str): root directory for dataset.
        keep_ori_im (bool): whether to keep original image, default False.
            Set True when used during MOT model inference while saving
            images or video, or used in DeepSORT.
    """

    def __init__(self,
                 video_file=None,
                 frame_rate=-1,
                 dataset_dir=None,
                 data_root=None,
                 image_dir=None,
                 sample_num=-1,
                 keep_ori_im=False,
                 anno_path=None,
                 **kwargs):
        super(MOTImageFolder, self).__init__(
            dataset_dir, image_dir, sample_num=sample_num)
        self.video_file = video_file
        self.data_root = data_root
        self.keep_ori_im = keep_ori_im
        # mapping from im_id -> image path, filled while loading
        self._imid2path = {}
        self.roidbs = None
        self.frame_rate = frame_rate
        self.anno_path = anno_path

    def check_or_download_dataset(self):
        # Inference-only dataset: nothing to download or validate up front.
        return

    def parse_dataset(self, ):
        """Lazily build self.roidbs from either an image folder or a video."""
        if not self.roidbs:
            if self.video_file is None:
                self.frame_rate = 30  # set as default if infer image folder
                self.roidbs = self._load_images()
            else:
                self.roidbs = self._load_video_images()

    def _load_video_images(self):
        """Split self.video_file into frames and build one record per frame."""
        if self.frame_rate == -1:
            # if frame_rate is not set for video, use cv2.VideoCapture
            cap = cv2.VideoCapture(self.video_file)
            self.frame_rate = int(cap.get(cv2.CAP_PROP_FPS))

        # frames go to a directory named after the video (extension stripped)
        extension = self.video_file.split('.')[-1]
        output_path = self.video_file.replace('.{}'.format(extension), '')
        frames_path = video2frames(self.video_file, output_path,
                                   self.frame_rate)
        self.video_frames = sorted(
            glob.glob(os.path.join(frames_path, '*.png')))

        self.video_length = len(self.video_frames)
        logger.info('Length of the video: {:d} frames.'.format(
            self.video_length))
        ct = 0
        records = []
        for image in self.video_frames:
            assert image != '' and os.path.isfile(image), \
                "Image {} not found".format(image)
            if self.sample_num > 0 and ct >= self.sample_num:
                break
            rec = {'im_id': np.array([ct]), 'im_file': image}
            if self.keep_ori_im:
                # downstream keeps an unmodified copy of the frame
                rec.update({'keep_ori_im': 1})
            self._imid2path[ct] = image
            ct += 1
            records.append(rec)
        assert len(records) > 0, "No image file found"
        return records

    def _find_images(self):
        """Collect image paths from self.image_dir (dir(s) or single files)."""
        image_dir = self.image_dir
        if not isinstance(image_dir, Sequence):
            image_dir = [image_dir]
        images = []
        for im_dir in image_dir:
            if os.path.isdir(im_dir):
                # NOTE(review): the join to dataset_dir happens only after
                # the isdir() check on the bare path — confirm callers pass
                # directories resolvable from the CWD.
                im_dir = os.path.join(self.dataset_dir, im_dir)
                images.extend(_make_dataset(im_dir))
            elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
                images.append(im_dir)
        return images

    def _load_images(self):
        """Build one record per image found under self.image_dir."""
        images = self._find_images()
        ct = 0
        records = []
        for image in images:
            assert image != '' and os.path.isfile(image), \
                "Image {} not found".format(image)
            if self.sample_num > 0 and ct >= self.sample_num:
                break
            rec = {'im_id': np.array([ct]), 'im_file': image}
            if self.keep_ori_im:
                rec.update({'keep_ori_im': 1})
            self._imid2path[ct] = image
            ct += 1
            records.append(rec)
        assert len(records) > 0, "No image file found"
        return records

    def get_imid2path(self):
        """Return the im_id -> image path mapping built during loading."""
        return self._imid2path

    def set_images(self, images):
        """Point the dataset at new image path(s) and reload records."""
        self.image_dir = images
        self.roidbs = self._load_images()

    def set_video(self, video_file, frame_rate):
        # update video_file and frame_rate by command line of tools/infer_mot.py
        self.video_file = video_file
        self.frame_rate = frame_rate
        assert os.path.isfile(self.video_file) and _is_valid_video(self.video_file), \
            "wrong or unsupported file format: {}".format(self.video_file)
        self.roidbs = self._load_video_images()

    def get_anno(self):
        """Return the annotation path given at construction (may be None)."""
        return self.anno_path
|
||||
|
||||
|
||||
def _is_valid_video(f, extensions=('.mp4', '.avi', '.mov', '.rmvb', 'flv')):
|
||||
return f.lower().endswith(extensions)
|
||||
|
||||
|
||||
def video2frames(video_path, outpath, frame_rate, **kargs):
    """Extract PNG frames from a video with the external ffmpeg binary.

    Args:
        video_path (str): path of the input video file.
        outpath (str): directory under which a sub-directory named after the
            video (extension stripped) is created to hold the frames.
        frame_rate (int): sampling rate passed to ffmpeg's `-r` option.
        **kargs: extra ffmpeg options appended as ' key value' pairs.

    Returns:
        str: the directory containing the extracted '%08d.png' frames.

    Raises:
        RuntimeError: if the ffmpeg command exits with a non-zero status.
    """

    def _dict2str(kargs):
        # Render extra options as ' key value' suffixes for the command line.
        cmd_str = ''
        for k, v in kargs.items():
            cmd_str += (' ' + str(k) + ' ' + str(v))
        return cmd_str

    ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
    vid_name = os.path.basename(video_path).split('.')[0]
    out_full_path = os.path.join(outpath, vid_name)

    if not os.path.exists(out_full_path):
        os.makedirs(out_full_path)

    # frames are written as zero-padded PNGs, e.g. 00000001.png
    outformat = os.path.join(out_full_path, '%08d.png')

    # BUGFIX: removed the dead `cmd = ffmpeg` assignment that was immediately
    # overwritten by the full command below.
    cmd = ffmpeg + [
        ' -i ', video_path, ' -r ', str(frame_rate), ' -f image2 ', outformat
    ]
    cmd = ''.join(cmd) + _dict2str(kargs)

    if os.system(cmd) != 0:
        # BUGFIX: the original had an unreachable `sys.exit(-1)` after this
        # raise; the exception alone signals the failure.
        raise RuntimeError('ffmpeg process video: {} error'.format(video_path))

    sys.stdout.flush()
    return out_full_path
|
||||
|
||||
|
||||
def mot_label():
    """Return the category-name to class-id mapping for single-class MOT.

    Single-class MOT tracks people only, so the map contains exactly one
    entry: 'person' -> 0.
    """
    return {'person': 0}
|
||||
|
||||
|
||||
def visdrone_mcmot_label():
    """Return the default VisDrone MCMOT category-name to class-id mapping.

    The 10 VisDrone categories are numbered 0..9 in their canonical order.
    """
    category_names = (
        'pedestrian',
        'people',
        'bicycle',
        'car',
        'van',
        'truck',
        'tricycle',
        'awning-tricycle',
        'bus',
        'motor',
    )
    return {name: cls_id for cls_id, name in enumerate(category_names)}
|
||||
380
paddle_detection/ppdet/data/source/pose3d_cmb.py
Normal file
380
paddle_detection/ppdet/data/source/pose3d_cmb.py
Normal file
@@ -0,0 +1,380 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
import copy
|
||||
import pycocotools
|
||||
from pycocotools.coco import COCO
|
||||
from .dataset import DetDataset
|
||||
from ppdet.core.workspace import register, serializable
|
||||
from paddle.io import Dataset
|
||||
|
||||
|
||||
@serializable
class Pose3DDataset(DetDataset):
    """Pose3D Dataset class.

    Args:
        dataset_dir (str): Root path to the dataset.
        image_dirs (list of str): each of path is a relative path where images are held.
        anno_list (list of str): each of the element is a relative path to the annotation file.
        transform (composed(operators)): A sequence of data transforms.
        num_joints (int): number of joints kept per sample (24, 17 or 14).
        test_mode (bool): Store True when building test or
            validation dataset. Default: False.

    24 joints order:
        0-2: 'R_Ankle', 'R_Knee', 'R_Hip',
        3-5:'L_Hip', 'L_Knee', 'L_Ankle',
        6-8:'R_Wrist', 'R_Elbow', 'R_Shoulder',
        9-11:'L_Shoulder','L_Elbow','L_Wrist',
        12-14:'Neck','Top_of_Head','Pelvis',
        15-18:'Thorax','Spine','Jaw','Head',
        19-23:'Nose','L_Eye','R_Eye','L_Ear','R_Ear'
    """

    def __init__(self,
                 dataset_dir,
                 image_dirs,
                 anno_list,
                 transform=[],
                 num_joints=24,
                 test_mode=False):
        super().__init__(dataset_dir, image_dirs, anno_list)
        self.image_info = {}
        self.ann_info = {}
        self.num_joints = num_joints

        self.transform = transform
        self.test_mode = test_mode

        self.img_ids = []
        self.dataset_dir = dataset_dir
        self.image_dirs = image_dirs
        self.anno_list = anno_list

    def get_mask(self, mvm_percent=0.3):
        """Build a random visibility mask of shape (num_joints + 10, 1).

        During training, up to ``mvm_percent`` of the joint entries and of
        the 10 vertex entries are independently zeroed; in test mode the
        mask is all ones.

        Args:
            mvm_percent (float): upper bound on the fraction of entries masked.

        Returns:
            np.ndarray: float32 mask, 1.0 = visible, 0.0 = masked.
        """
        num_joints = self.num_joints
        mjm_mask = np.ones((num_joints, 1)).astype(np.float32)
        if not self.test_mode:
            pb = np.random.random_sample()
            masked_num = int(
                pb * mvm_percent *
                num_joints)  # at most x% of the joints could be masked
            indices = np.random.choice(
                np.arange(num_joints), replace=False, size=masked_num)
            mjm_mask[indices, :] = 0.0

        num_joints = 10
        # BUGFIX: `np.float` was removed in NumPy >= 1.24 (deprecated since
        # 1.20) and crashed here; use np.float32 to match mjm_mask's dtype.
        mvm_mask = np.ones((num_joints, 1)).astype(np.float32)
        if not self.test_mode:
            num_vertices = num_joints
            pb = np.random.random_sample()
            masked_num = int(
                pb * mvm_percent *
                num_vertices)  # at most x% of the vertices could be masked
            indices = np.random.choice(
                np.arange(num_vertices), replace=False, size=masked_num)
            mvm_mask[indices, :] = 0.0

        mjm_mask = np.concatenate([mjm_mask, mvm_mask], axis=0)
        return mjm_mask

    def filterjoints(self, x):
        """Select the subset of the 24 canonical joints matching num_joints.

        Args:
            x (np.ndarray): joints array with 24 rows.

        Returns:
            np.ndarray: x restricted to num_joints rows.

        Raises:
            ValueError: if num_joints is not one of 24, 17, 14.
        """
        if self.num_joints == 24:
            return x
        elif self.num_joints == 14:
            return x[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18], :]
        elif self.num_joints == 17:
            return x[
                [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 18, 19], :]
        else:
            raise ValueError(
                "unsupported joint numbers, only [24 or 17 or 14] is supported!")

    def parse_dataset(self):
        """Load every annotation file in anno_list into self.annos.

        Normalizes image paths (stripping COCO 2014 prefixes, retrying
        train2017 -> val2017), converts keypoints to float32 arrays, and
        counts human3.6m samples into self.human36m_num.
        """
        print("Loading annotations..., please wait")
        self.annos = []
        im_id = 0
        self.human36m_num = 0
        for idx, annof in enumerate(self.anno_list):
            img_prefix = os.path.join(self.dataset_dir, self.image_dirs[idx])
            dataf = os.path.join(self.dataset_dir, annof)
            with open(dataf, 'r') as rf:
                anno_data = json.load(rf)
            annos = anno_data['data']
            print("{} has annos numbers: {}".format(dataf, len(annos)))
            for anno in annos:
                new_anno = {}
                new_anno['im_id'] = im_id
                im_id += 1
                # strip COCO 2014 file-name prefixes before joining paths
                imagename = anno['imageName']
                if imagename.startswith("COCO_train2014_"):
                    imagename = imagename[len("COCO_train2014_"):]
                elif imagename.startswith("COCO_val2014_"):
                    imagename = imagename[len("COCO_val2014_"):]
                imagename = os.path.join(img_prefix, imagename)
                if not os.path.exists(imagename):
                    # some annotations point at train2017 while the image
                    # actually lives in val2017; retry before giving up
                    if "train2017" in imagename:
                        imagename = imagename.replace("train2017", "val2017")
                        if not os.path.exists(imagename):
                            print("cannot find imagepath:{}".format(imagename))
                            continue
                    else:
                        print("cannot find imagepath:{}".format(imagename))
                        continue
                new_anno['imageName'] = imagename
                if 'human3.6m' in imagename:
                    self.human36m_num += 1
                new_anno['bbox_center'] = anno['bbox_center']
                new_anno['bbox_scale'] = anno['bbox_scale']
                new_anno['joints_2d'] = np.array(anno[
                    'gt_keypoint_2d']).astype(np.float32)
                if new_anno['joints_2d'].shape[0] == 49:
                    # if the joints_2d is in SPIN format (which generated by
                    # eft), choose the last 24 public joints
                    # for detail please refer:
                    # https://github.com/nkolot/SPIN/blob/master/constants.py
                    new_anno['joints_2d'] = new_anno['joints_2d'][25:]
                new_anno['joints_3d'] = np.array(anno[
                    'pose3d'])[:, :3].astype(np.float32)
                new_anno['mjm_mask'] = self.get_mask()
                if 'has_3d_joints' not in anno:
                    new_anno['has_3d_joints'] = int(1)
                    new_anno['has_2d_joints'] = int(1)
                else:
                    new_anno['has_3d_joints'] = int(anno['has_3d_joints'])
                    new_anno['has_2d_joints'] = int(anno['has_2d_joints'])
                new_anno['joints_2d'] = self.filterjoints(new_anno[
                    'joints_2d'])
                self.annos.append(new_anno)
            # free the raw annotation list before loading the next file
            del annos

    def get_temp_num(self):
        """get temporal data number, like human3.6m"""
        return self.human36m_num

    def __len__(self):
        """Get dataset length."""
        return len(self.annos)

    def _get_imganno(self, idx):
        """Get anno for a single image."""
        return self.annos[idx]

    def __getitem__(self, idx):
        """Prepare image for training given the index."""
        records = copy.deepcopy(self._get_imganno(idx))
        imgpath = records['imageName']
        assert os.path.exists(imgpath), "cannot find image {}".format(imgpath)
        records['image'] = cv2.imread(imgpath)
        records['image'] = cv2.cvtColor(records['image'], cv2.COLOR_BGR2RGB)
        records = self.transform(records)
        return records

    def check_or_download_dataset(self):
        """Verify every image directory exists; raise if any is missing."""
        alldatafind = True
        for image_dir in self.image_dirs:
            image_dir = os.path.join(self.dataset_dir, image_dir)
            if not os.path.isdir(image_dir):
                print("dataset [{}] is not found".format(image_dir))
                alldatafind = False
        if not alldatafind:
            raise ValueError(
                "Some dataset is not valid and cannot download automatically now, please prepare the dataset first"
            )
|
||||
|
||||
|
||||
@register
@serializable
class Keypoint3DMultiFramesDataset(Dataset):
    """24 keypoints 3D dataset for pose estimation.

    each item is a list of images

    The dataset loads raw features and apply specified transforms
    to return a dict containing the image tensors and other information.

    Args:
        dataset_dir (str): Root path to the dataset.
        image_dir (str): Path to a directory where images are held.
        p3d_dir (str): relative directory holding per-frame 3D keypoint .obj files.
        json_path (str): per-action annotation JSON file name.
        img_size (int): target image resize size.
        num_frames (int): length of each consecutive frame sequence.
        anno_path (str): optional annotation path returned by get_anno().
    """

    def __init__(
            self,
            dataset_dir,  # root directory of the dataset
            image_dir,  # image folder
            p3d_dir,  # 3D keypoint folder
            json_path,
            img_size,  # image resize size
            num_frames,  # length of the frame sequence
            anno_path=None, ):

        self.dataset_dir = dataset_dir
        self.image_dir = image_dir
        self.p3d_dir = p3d_dir
        self.json_path = json_path
        self.img_size = img_size
        self.num_frames = num_frames
        self.anno_path = anno_path

        # NOTE(review): self.transform is not set here but is used in
        # __getitem__ — callers must invoke set_transform() first; confirm.
        self.data_labels, self.mf_inds = self._generate_multi_frames_list()

    def _generate_multi_frames_list(self):
        """Merge per-action annotations and build sliding frame windows.

        Returns:
            tuple: (annos_dict, mf_list) where annos_dict merges 'images',
            'annotations' and per-frame 'act_inds' across actions, and
            mf_list holds lists of num_frames consecutive global indices.
        """
        act_list = os.listdir(self.dataset_dir)  # list of actions
        count = 0
        mf_list = []
        annos_dict = {'images': [], 'annotations': [], 'act_inds': []}
        for act in act_list:  # build a frame sequence for every action
            # skip plain files; only action sub-directories are expected
            if '.' in act:
                continue

            json_path = os.path.join(self.dataset_dir, act, self.json_path)
            with open(json_path, 'r') as j:
                annos = json.load(j)
            length = len(annos['images'])
            for k, v in annos.items():
                if k in annos_dict:
                    annos_dict[k].extend(v)
            # remember which action every frame belongs to
            annos_dict['act_inds'].extend([act] * length)

            # sliding windows of num_frames consecutive global indices
            mf = [[i + j + count for j in range(self.num_frames)]
                  for i in range(0, length - self.num_frames + 1)]
            mf_list.extend(mf)
            count += length

        print("total data number:", len(mf_list))
        return annos_dict, mf_list

    def __call__(self, *args, **kwargs):
        # config system instantiates then calls; the dataset is its own result
        return self

    def __getitem__(self, index):  # fetch one consecutive frame sequence
        inds = self.mf_inds[
            index]  # e.g. [568, 569, 570, 571, 572, 573], length == num_frames

        images = self.data_labels['images']  # all images
        annots = self.data_labels['annotations']  # all annots

        act = self.data_labels['act_inds'][inds[0]]  # action name (folder name)

        kps3d_list = []
        kps3d_vis_list = []
        names = []

        h, w = 0, 0
        for ind in inds:  # one image
            height = float(images[ind]['height'])
            width = float(images[ind]['width'])
            name = images[ind]['file_name']  # image file name, with suffix

            # the 3D keypoints live in an .obj file named after the image
            kps3d_name = name.split('.')[0] + '.obj'
            kps3d_path = os.path.join(self.dataset_dir, act, self.p3d_dir,
                                      kps3d_name)

            joints, joints_vis = self.kps3d_process(kps3d_path)
            joints_vis = np.array(joints_vis, dtype=np.float32)

            kps3d_list.append(joints)
            kps3d_vis_list.append(joints_vis)
            names.append(name)

        kps3d = np.array(kps3d_list)  # (6, 24, 3), (num_frames, joints_num, 3)
        kps3d_vis = np.array(kps3d_vis_list)

        # read image
        imgs = []
        for name in names:
            img_path = os.path.join(self.dataset_dir, act, self.image_dir, name)

            image = cv2.imread(img_path, cv2.IMREAD_COLOR |
                               cv2.IMREAD_IGNORE_ORIENTATION)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            imgs.append(np.expand_dims(image, axis=0))

        imgs = np.concatenate(imgs, axis=0)
        imgs = imgs.astype(
            np.float32)  # (6, 1080, 1920, 3), (num_frames, h, w, c)

        # attention: images and annotations are mirrored at this point
        records = {
            'kps3d': kps3d,
            'kps3d_vis': kps3d_vis,
            "image": imgs,
            'act': act,
            'names': names,
            'im_id': index
        }

        return self.transform(records)

    def kps3d_process(self, kps3d_path):
        """Parse an .obj file into 24 3D keypoints plus visibility flags.

        Reads every 'v x y z' vertex line, scales and root-centers the
        points, then keeps vertices 0..22 plus vertex 37 as the 24 joints.

        Args:
            kps3d_path (str): path to the .obj keypoint file.

        Returns:
            tuple: (kps, kps_vis) arrays of shape (24, 3).
        """
        count = 0
        kps = []
        kps_vis = []

        with open(kps3d_path, 'r') as f:
            lines = f.readlines()
            for line in lines:
                if line[0] == 'v':
                    kps.append([])
                    line = line.strip('\n').split(' ')[1:]
                    for kp in line:
                        kps[-1].append(float(kp))
                    count += 1

                    # every parsed vertex is treated as fully visible
                    kps_vis.append([1, 1, 1])

        kps = np.array(kps)  # (52, 3)
        kps_vis = np.array(kps_vis)

        kps *= 10  # scale points
        kps -= kps[[0], :]  # set root point to zero

        # keep the first 23 vertices plus vertex 37 as the 24 joints
        kps = np.concatenate((kps[0:23], kps[[37]]), axis=0)  # (24, 3)

        # NOTE(review): points are scaled by 10 twice (100x overall) —
        # looks intentional but confirm against the consumer of kps3d.
        kps *= 10

        kps_vis = np.concatenate((kps_vis[0:23], kps_vis[[37]]), axis=0)  # (24, 3)

        return kps, kps_vis

    def __len__(self):
        # one item per sliding frame window
        return len(self.mf_inds)

    def get_anno(self):
        """Return the absolute annotation path, or None if not configured."""
        if self.anno_path is None:
            return
        return os.path.join(self.dataset_dir, self.anno_path)

    def check_or_download_dataset(self):
        # nothing to download; data must already be on disk
        return

    def parse_dataset(self, ):
        # parsing is done eagerly in __init__ via _generate_multi_frames_list
        return

    def set_transform(self, transform):
        # set the transform pipeline applied in __getitem__
        self.transform = transform

    def set_epoch(self, epoch_id):
        # record the current epoch (used by epoch-dependent transforms)
        self._epoch = epoch_id

    def set_kwargs(self, **kwargs):
        # epochs after which each mix augmentation is disabled (-1 = never on)
        self.mixup_epoch = kwargs.get('mixup_epoch', -1)
        self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
        self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)
|
||||
194
paddle_detection/ppdet/data/source/sniper_coco.py
Normal file
194
paddle_detection/ppdet/data/source/sniper_coco.py
Normal file
@@ -0,0 +1,194 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import cv2
|
||||
import json
|
||||
import copy
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from collections.abc import Sequence
|
||||
except Exception:
|
||||
from collections import Sequence
|
||||
|
||||
from ppdet.core.workspace import register, serializable
|
||||
from ppdet.data.crop_utils.annotation_cropper import AnnoCropper
|
||||
from .coco import COCODataSet
|
||||
from .dataset import _make_dataset, _is_valid_file
|
||||
from ppdet.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger('sniper_coco_dataset')
|
||||
|
||||
|
||||
@register
|
||||
@serializable
|
||||
class SniperCOCODataSet(COCODataSet):
|
||||
"""SniperCOCODataSet"""
|
||||
|
||||
def __init__(self,
|
||||
dataset_dir=None,
|
||||
image_dir=None,
|
||||
anno_path=None,
|
||||
proposals_file=None,
|
||||
data_fields=['image'],
|
||||
sample_num=-1,
|
||||
load_crowd=False,
|
||||
allow_empty=True,
|
||||
empty_ratio=1.,
|
||||
is_trainset=True,
|
||||
image_target_sizes=[2000, 1000],
|
||||
valid_box_ratio_ranges=[[-1, 0.1],[0.08, -1]],
|
||||
chip_target_size=500,
|
||||
chip_target_stride=200,
|
||||
use_neg_chip=False,
|
||||
max_neg_num_per_im=8,
|
||||
max_per_img=-1,
|
||||
nms_thresh=0.5):
|
||||
super(SniperCOCODataSet, self).__init__(
|
||||
dataset_dir=dataset_dir,
|
||||
image_dir=image_dir,
|
||||
anno_path=anno_path,
|
||||
data_fields=data_fields,
|
||||
sample_num=sample_num,
|
||||
load_crowd=load_crowd,
|
||||
allow_empty=allow_empty,
|
||||
empty_ratio=empty_ratio
|
||||
)
|
||||
self.proposals_file = proposals_file
|
||||
self.proposals = None
|
||||
self.anno_cropper = None
|
||||
self.is_trainset = is_trainset
|
||||
self.image_target_sizes = image_target_sizes
|
||||
self.valid_box_ratio_ranges = valid_box_ratio_ranges
|
||||
self.chip_target_size = chip_target_size
|
||||
self.chip_target_stride = chip_target_stride
|
||||
self.use_neg_chip = use_neg_chip
|
||||
self.max_neg_num_per_im = max_neg_num_per_im
|
||||
self.max_per_img = max_per_img
|
||||
self.nms_thresh = nms_thresh
|
||||
|
||||
|
||||
def parse_dataset(self):
|
||||
if not hasattr(self, "roidbs"):
|
||||
super(SniperCOCODataSet, self).parse_dataset()
|
||||
if self.is_trainset:
|
||||
self._parse_proposals()
|
||||
self._merge_anno_proposals()
|
||||
self.ori_roidbs = copy.deepcopy(self.roidbs)
|
||||
self.init_anno_cropper()
|
||||
self.roidbs = self.generate_chips_roidbs(self.roidbs, self.is_trainset)
|
||||
|
||||
def set_proposals_file(self, file_path):
|
||||
self.proposals_file = file_path
|
||||
|
||||
def init_anno_cropper(self):
|
||||
logger.info("Init AnnoCropper...")
|
||||
self.anno_cropper = AnnoCropper(
|
||||
image_target_sizes=self.image_target_sizes,
|
||||
valid_box_ratio_ranges=self.valid_box_ratio_ranges,
|
||||
chip_target_size=self.chip_target_size,
|
||||
chip_target_stride=self.chip_target_stride,
|
||||
use_neg_chip=self.use_neg_chip,
|
||||
max_neg_num_per_im=self.max_neg_num_per_im,
|
||||
max_per_img=self.max_per_img,
|
||||
nms_thresh=self.nms_thresh
|
||||
)
|
||||
|
||||
def generate_chips_roidbs(self, roidbs, is_trainset):
|
||||
if is_trainset:
|
||||
roidbs = self.anno_cropper.crop_anno_records(roidbs)
|
||||
else:
|
||||
roidbs = self.anno_cropper.crop_infer_anno_records(roidbs)
|
||||
return roidbs
|
||||
|
||||
def _parse_proposals(self):
|
||||
if self.proposals_file:
|
||||
self.proposals = {}
|
||||
logger.info("Parse proposals file:{}".format(self.proposals_file))
|
||||
with open(self.proposals_file, 'r') as f:
|
||||
proposals = json.load(f)
|
||||
for prop in proposals:
|
||||
image_id = prop["image_id"]
|
||||
if image_id not in self.proposals:
|
||||
self.proposals[image_id] = []
|
||||
x, y, w, h = prop["bbox"]
|
||||
self.proposals[image_id].append([x, y, x + w, y + h])
|
||||
|
||||
def _merge_anno_proposals(self):
|
||||
assert self.roidbs
|
||||
if self.proposals and len(self.proposals.keys()) > 0:
|
||||
logger.info("merge proposals to annos")
|
||||
for id, record in enumerate(self.roidbs):
|
||||
image_id = int(record["im_id"])
|
||||
if image_id not in self.proposals.keys():
|
||||
logger.info("image id :{} no proposals".format(image_id))
|
||||
record["proposals"] = np.array(self.proposals.get(image_id, []), dtype=np.float32)
|
||||
self.roidbs[id] = record
|
||||
|
||||
def get_ori_roidbs(self):
    """Return the uncropped records, or None if parsing has not run yet."""
    return getattr(self, "ori_roidbs", None)
|
||||
|
||||
def get_roidbs(self):
    """Return the chip records, parsing the dataset lazily on first use."""
    try:
        return self.roidbs
    except AttributeError:
        pass
    self.parse_dataset()
    return self.roidbs
|
||||
|
||||
def set_roidbs(self, roidbs):
    """Install pre-built records directly, bypassing parsing."""
    self.roidbs = roidbs
|
||||
|
||||
def check_or_download_dataset(self):
    """No-op override: sniper datasets are prepared manually, never auto-downloaded."""
    return None
|
||||
|
||||
def _parse(self):
|
||||
image_dir = self.image_dir
|
||||
if not isinstance(image_dir, Sequence):
|
||||
image_dir = [image_dir]
|
||||
images = []
|
||||
for im_dir in image_dir:
|
||||
if os.path.isdir(im_dir):
|
||||
im_dir = os.path.join(self.dataset_dir, im_dir)
|
||||
images.extend(_make_dataset(im_dir))
|
||||
elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
|
||||
images.append(im_dir)
|
||||
return images
|
||||
|
||||
def _load_images(self):
|
||||
images = self._parse()
|
||||
ct = 0
|
||||
records = []
|
||||
for image in images:
|
||||
assert image != '' and os.path.isfile(image), \
|
||||
"Image {} not found".format(image)
|
||||
if self.sample_num > 0 and ct >= self.sample_num:
|
||||
break
|
||||
im = cv2.imread(image)
|
||||
h, w, c = im.shape
|
||||
rec = {'im_id': np.array([ct]), 'im_file': image, "h": h, "w": w}
|
||||
self._imid2path[ct] = image
|
||||
ct += 1
|
||||
records.append(rec)
|
||||
assert len(records) > 0, "No image file found"
|
||||
return records
|
||||
|
||||
def get_imid2path(self):
    """Return the image-id -> file-path mapping built during loading."""
    return self._imid2path
|
||||
|
||||
def set_images(self, images):
    """Point the dataset at new image path(s) and rebuild the records."""
    self._imid2path = {}  # reset the id -> path map before reloading
    self.image_dir = images
    self.roidbs = self._load_images()
|
||||
|
||||
234
paddle_detection/ppdet/data/source/voc.py
Normal file
234
paddle_detection/ppdet/data/source/voc.py
Normal file
@@ -0,0 +1,234 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from ppdet.core.workspace import register, serializable
|
||||
|
||||
from .dataset import DetDataset
|
||||
|
||||
from ppdet.utils.logger import setup_logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
@register
@serializable
class VOCDataSet(DetDataset):
    """
    Load dataset with PascalVOC format.

    Notes:
    `anno_path` must contain xml file and image file path for annotations.

    Args:
        dataset_dir (str): root directory for dataset.
        image_dir (str): directory for images.
        anno_path (str): voc annotation file path.
        data_fields (list): key name of data dictionary, at least have 'image'.
        sample_num (int): number of samples to load, -1 means all.
        label_list (str): if use_default_label is False, will load
            mapping between category and class index.
        allow_empty (bool): whether to load empty entry. False as default
        empty_ratio (float): the ratio of empty record number to total
            record's, if empty_ratio is out of [0. ,1.), do not sample the
            records and use all the empty entries. 1. as default
        repeat (int): repeat times for dataset, use in benchmark.
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 data_fields=['image'],
                 sample_num=-1,
                 label_list=None,
                 allow_empty=False,
                 empty_ratio=1.,
                 repeat=1):
        """Store VOC-specific options; path handling is done by DetDataset."""
        super(VOCDataSet, self).__init__(
            dataset_dir=dataset_dir,
            image_dir=image_dir,
            anno_path=anno_path,
            data_fields=data_fields,
            sample_num=sample_num,
            repeat=repeat)
        self.label_list = label_list
        self.allow_empty = allow_empty
        self.empty_ratio = empty_ratio

    def _sample_empty(self, records, num):
        """Subsample background-only records so they make up at most
        ``empty_ratio`` of the final dataset (``num`` = non-empty count)."""
        # if empty_ratio is out of [0. ,1.), do not sample the records
        if self.empty_ratio < 0. or self.empty_ratio >= 1.:
            return records
        import random
        # Solve empty/(empty+num) == empty_ratio for the target empty count.
        sample_num = min(
            int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records))
        records = random.sample(records, sample_num)
        return records

    def parse_dataset(self, ):
        """Parse the VOC list file + per-image XMLs into self.roidbs.

        Each line of ``anno_path`` holds an image path and an xml path;
        invalid images/xmls/sizes are skipped with a warning. Populates
        ``self.roidbs`` (records) and ``self.cname2cid`` (name -> id).
        """
        anno_path = os.path.join(self.dataset_dir, self.anno_path)
        image_dir = os.path.join(self.dataset_dir, self.image_dir)

        # mapping category name to class id
        # first_class:0, second_class:1, ...
        records = []
        empty_records = []
        ct = 0
        cname2cid = {}
        if self.label_list:
            # Custom label file: one class name per line, ids by line order.
            label_path = os.path.join(self.dataset_dir, self.label_list)
            if not os.path.exists(label_path):
                raise ValueError("label_list {} does not exists".format(
                    label_path))
            with open(label_path, 'r') as fr:
                label_id = 0
                for line in fr.readlines():
                    cname2cid[line.strip()] = label_id
                    label_id += 1
        else:
            # Fall back to the canonical 20-class PASCAL VOC mapping.
            cname2cid = pascalvoc_label()

        with open(anno_path, 'r') as fr:
            while True:
                line = fr.readline()
                if not line:
                    break
                # Each line: "<image path> <xml path>" relative to image_dir.
                img_file, xml_file = [os.path.join(image_dir, x) \
                        for x in line.strip().split()[:2]]
                if not os.path.exists(img_file):
                    logger.warning(
                        'Illegal image file: {}, and it will be ignored'.format(
                            img_file))
                    continue
                if not os.path.isfile(xml_file):
                    logger.warning(
                        'Illegal xml file: {}, and it will be ignored'.format(
                            xml_file))
                    continue
                tree = ET.parse(xml_file)
                # Prefer the xml's own <id>; otherwise use the running count.
                if tree.find('id') is None:
                    im_id = np.array([ct])
                else:
                    im_id = np.array([int(tree.find('id').text)])

                objs = tree.findall('object')
                im_w = float(tree.find('size').find('width').text)
                im_h = float(tree.find('size').find('height').text)
                if im_w < 0 or im_h < 0:
                    logger.warning(
                        'Illegal width: {} or height: {} in annotation, '
                        'and {} will be ignored'.format(im_w, im_h, xml_file))
                    continue

                # i counts only the valid boxes; arrays are trimmed to i below.
                num_bbox, i = len(objs), 0
                gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
                gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
                gt_score = np.zeros((num_bbox, 1), dtype=np.float32)
                difficult = np.zeros((num_bbox, 1), dtype=np.int32)
                for obj in objs:
                    cname = obj.find('name').text

                    # user dataset may not contain difficult field
                    _difficult = obj.find('difficult')
                    _difficult = int(
                        _difficult.text) if _difficult is not None else 0

                    x1 = float(obj.find('bndbox').find('xmin').text)
                    y1 = float(obj.find('bndbox').find('ymin').text)
                    x2 = float(obj.find('bndbox').find('xmax').text)
                    y2 = float(obj.find('bndbox').find('ymax').text)
                    # Clip the box to the image extent.
                    x1 = max(0, x1)
                    y1 = max(0, y1)
                    x2 = min(im_w - 1, x2)
                    y2 = min(im_h - 1, y2)
                    if x2 > x1 and y2 > y1:
                        gt_bbox[i, :] = [x1, y1, x2, y2]
                        gt_class[i, 0] = cname2cid[cname]
                        gt_score[i, 0] = 1.
                        difficult[i, 0] = _difficult
                        i += 1
                    else:
                        logger.warning(
                            'Found an invalid bbox in annotations: xml_file: {}'
                            ', x1: {}, y1: {}, x2: {}, y2: {}.'.format(
                                xml_file, x1, y1, x2, y2))
                # Drop the unused tail left by invalid boxes.
                gt_bbox = gt_bbox[:i, :]
                gt_class = gt_class[:i, :]
                gt_score = gt_score[:i, :]
                difficult = difficult[:i, :]

                voc_rec = {
                    'im_file': img_file,
                    'im_id': im_id,
                    'h': im_h,
                    'w': im_w
                } if 'image' in self.data_fields else {}

                gt_rec = {
                    'gt_class': gt_class,
                    'gt_score': gt_score,
                    'gt_bbox': gt_bbox,
                    'difficult': difficult
                }
                # Only keep the ground-truth fields the caller asked for.
                for k, v in gt_rec.items():
                    if k in self.data_fields:
                        voc_rec[k] = v

                # Images with no <object> entries are kept apart so they can
                # be subsampled by _sample_empty.
                if len(objs) == 0:
                    empty_records.append(voc_rec)
                else:
                    records.append(voc_rec)

                ct += 1
                if self.sample_num > 0 and ct >= self.sample_num:
                    break
        assert ct > 0, 'not found any voc record in %s' % (self.anno_path)
        logger.debug('{} samples in file {}'.format(ct, anno_path))
        if self.allow_empty and len(empty_records) > 0:
            empty_records = self._sample_empty(empty_records, len(records))
            records += empty_records
        self.roidbs, self.cname2cid = records, cname2cid

    def get_label_list(self):
        """Return the absolute path of the custom label list file.

        NOTE(review): assumes ``label_list`` was provided; returns the bare
        dataset_dir join of None otherwise — confirm callers guard this.
        """
        return os.path.join(self.dataset_dir, self.label_list)
|
||||
|
||||
|
||||
def pascalvoc_label():
    """Return the canonical PASCAL VOC name -> class-id mapping (20 classes)."""
    names = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
             'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
             'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
             'tvmonitor')
    return {name: idx for idx, name in enumerate(names)}
|
||||
180
paddle_detection/ppdet/data/source/widerface.py
Normal file
180
paddle_detection/ppdet/data/source/widerface.py
Normal file
@@ -0,0 +1,180 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from ppdet.core.workspace import register, serializable
|
||||
from .dataset import DetDataset
|
||||
|
||||
from ppdet.utils.logger import setup_logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
@register
@serializable
class WIDERFaceDataSet(DetDataset):
    """
    Load WiderFace records with 'anno_path'

    Args:
        dataset_dir (str): root directory for dataset.
        image_dir (str): directory for images.
        anno_path (str): WiderFace annotation data.
        data_fields (list): key name of data dictionary, at least have 'image'.
        sample_num (int): number of samples to load, -1 means all.
        with_lmk (bool): whether to load face landmark keypoint labels.
    """

    def __init__(self,
                 dataset_dir=None,
                 image_dir=None,
                 anno_path=None,
                 data_fields=['image'],
                 sample_num=-1,
                 with_lmk=False):
        """Store WIDER FACE options; roidbs are filled by parse_dataset."""
        super(WIDERFaceDataSet, self).__init__(
            dataset_dir=dataset_dir,
            image_dir=image_dir,
            anno_path=anno_path,
            data_fields=data_fields,
            sample_num=sample_num,
            with_lmk=with_lmk)
        # NOTE(review): the assignments below repeat what super().__init__
        # already received — presumably kept for explicitness; verify.
        self.anno_path = anno_path
        self.sample_num = sample_num
        self.roidbs = None
        self.cname2cid = None
        self.with_lmk = with_lmk

    def parse_dataset(self):
        """Parse the WIDER FACE txt annotation into self.roidbs.

        Each parsed item is ``[image_line, box_entry, box_entry, ...]``;
        every record gets xyxy face boxes and, when ``with_lmk`` is set,
        5-point landmark labels plus their ignore flags.
        """
        anno_path = os.path.join(self.dataset_dir, self.anno_path)
        image_dir = os.path.join(self.dataset_dir, self.image_dir)

        txt_file = anno_path

        records = []
        ct = 0
        file_lists = self._load_file_list(txt_file)
        cname2cid = widerface_label()

        for item in file_lists:
            im_fname = item[0]
            im_id = np.array([ct])
            # One row per annotated face (item[0] is the image line).
            gt_bbox = np.zeros((len(item) - 1, 4), dtype=np.float32)
            gt_class = np.zeros((len(item) - 1, 1), dtype=np.int32)
            gt_lmk_labels = np.zeros((len(item) - 1, 10), dtype=np.float32)
            lmk_ignore_flag = np.zeros((len(item) - 1, 1), dtype=np.int32)
            for index_box in range(len(item)):
                if index_box < 1:
                    continue
                gt_bbox[index_box - 1] = item[index_box][0]
                if self.with_lmk:
                    gt_lmk_labels[index_box - 1] = item[index_box][1]
                    lmk_ignore_flag[index_box - 1] = item[index_box][2]
            im_fname = os.path.join(image_dir,
                                    im_fname) if image_dir else im_fname
            widerface_rec = {
                'im_file': im_fname,
                'im_id': im_id,
            } if 'image' in self.data_fields else {}
            gt_rec = {
                'gt_bbox': gt_bbox,
                'gt_class': gt_class,
            }
            # Only keep the ground-truth fields the caller asked for.
            for k, v in gt_rec.items():
                if k in self.data_fields:
                    widerface_rec[k] = v
            if self.with_lmk:
                widerface_rec['gt_keypoint'] = gt_lmk_labels
                widerface_rec['keypoint_ignore'] = lmk_ignore_flag

            if len(item) != 0:
                records.append(widerface_rec)

            ct += 1
            if self.sample_num > 0 and ct >= self.sample_num:
                break
        assert len(records) > 0, 'not found any widerface in %s' % (anno_path)
        logger.debug('{} samples in file {}'.format(ct, anno_path))
        self.roidbs, self.cname2cid = records, cname2cid

    def _load_file_list(self, input_txt):
        """Group the flat annotation txt into per-image lists.

        A line with a single token starting a known image extension opens a
        new image group; multi-token lines are box annotations
        ``x y w h [...landmarks...]`` appended to the current group.
        Returns a list of ``[image_line, box_entry, ...]`` lists.
        """
        with open(input_txt, 'r') as f_dir:
            lines_input_txt = f_dir.readlines()

        file_dict = {}
        num_class = 0
        exts = ['jpg', 'jpeg', 'png', 'bmp']
        exts += [ext.upper() for ext in exts]
        for i in range(len(lines_input_txt)):
            line_txt = lines_input_txt[i].strip('\n\t\r')
            split_str = line_txt.split(' ')
            if len(split_str) == 1:
                # Single token: candidate image path line.
                img_file_name = os.path.split(split_str[0])[1]
                split_txt = img_file_name.split('.')
                if len(split_txt) < 2:
                    continue
                elif split_txt[-1] in exts:
                    # Start a new per-image group (first line keeps key 0).
                    if i != 0:
                        num_class += 1
                    file_dict[num_class] = [line_txt]
            else:
                if len(line_txt) <= 6:
                    continue
                result_boxs = []
                xmin = float(split_str[0])
                ymin = float(split_str[1])
                w = float(split_str[2])
                h = float(split_str[3])
                # Filter out wrong labels
                if w < 0 or h < 0:
                    logger.warning('Illegal box with w: {}, h: {} in '
                                   'img: {}, and it will be ignored'.format(
                                       w, h, file_dict[num_class][0]))
                    continue
                xmin = max(0, xmin)
                ymin = max(0, ymin)
                # xywh -> xyxy.
                xmax = xmin + w
                ymax = ymin + h
                gt_bbox = [xmin, ymin, xmax, ymax]
                result_boxs.append(gt_bbox)
                if self.with_lmk:
                    assert len(split_str) > 18, 'When `with_lmk=True`, the number' \
                        'of characters per line in the annotation file should' \
                        'exceed 18.'
                    # Landmark fields sit at offsets 5..18 with a stride of 3.
                    lmk0_x = float(split_str[5])
                    lmk0_y = float(split_str[6])
                    lmk1_x = float(split_str[8])
                    lmk1_y = float(split_str[9])
                    lmk2_x = float(split_str[11])
                    lmk2_y = float(split_str[12])
                    lmk3_x = float(split_str[14])
                    lmk3_y = float(split_str[15])
                    lmk4_x = float(split_str[17])
                    lmk4_y = float(split_str[18])
                    # x == -1 marks faces without valid landmarks.
                    lmk_ignore_flag = 0 if lmk0_x == -1 else 1
                    gt_lmk_label = [
                        lmk0_x, lmk0_y, lmk1_x, lmk1_y, lmk2_x, lmk2_y, lmk3_x,
                        lmk3_y, lmk4_x, lmk4_y
                    ]
                    result_boxs.append(gt_lmk_label)
                    result_boxs.append(lmk_ignore_flag)
                file_dict[num_class].append(result_boxs)

        return list(file_dict.values())
|
||||
|
||||
|
||||
def widerface_label():
    """Return the single-class WIDER FACE name -> id mapping."""
    return {'face': 0}
|
||||
Reference in New Issue
Block a user