更换文档检测模型

This commit is contained in:
2024-08-27 14:42:45 +08:00
parent aea6f19951
commit 1514e09c40
2072 changed files with 254336 additions and 4967 deletions

View File

@@ -0,0 +1,33 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import coco
from . import voc
from . import widerface
from . import category
from . import keypoint_coco
from . import mot
from . import sniper_coco
from . import culane
from .coco import *
from .voc import *
from .widerface import *
from .category import *
from .keypoint_coco import *
from .mot import *
from .sniper_coco import SniperCOCODataSet
from .dataset import ImageFolder
from .pose3d_cmb import *
from .culane import *

View File

@@ -0,0 +1,942 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
from ppdet.data.source.voc import pascalvoc_label
from ppdet.data.source.widerface import widerface_label
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = ['get_categories']
def get_categories(metric_type, anno_file=None, arch=None):
"""
Get class id to category id map and category id
to category name map from annotation file.
Args:
metric_type (str): metric type, currently support 'coco', 'voc', 'oid'
and 'widerface'.
anno_file (str): annotation file path
"""
if arch == 'keypoint_arch':
return (None, {'id': 'keypoint'})
if anno_file == None or (not os.path.isfile(anno_file)):
logger.warning(
"anno_file '{}' is None or not set or not exist, "
"please recheck TrainDataset/EvalDataset/TestDataset.anno_path, "
"otherwise the default categories will be used by metric_type.".
format(anno_file))
if metric_type.lower() == 'coco' or metric_type.lower(
) == 'rbox' or metric_type.lower() == 'snipercoco':
if anno_file and os.path.isfile(anno_file):
if anno_file.endswith('json'):
# lazy import pycocotools here
from pycocotools.coco import COCO
coco = COCO(anno_file)
cats = coco.loadCats(coco.getCatIds())
clsid2catid = {i: cat['id'] for i, cat in enumerate(cats)}
catid2name = {cat['id']: cat['name'] for cat in cats}
elif anno_file.endswith('txt'):
cats = []
with open(anno_file) as f:
for line in f.readlines():
cats.append(line.strip())
if cats[0] == 'background': cats = cats[1:]
clsid2catid = {i: i for i in range(len(cats))}
catid2name = {i: name for i, name in enumerate(cats)}
else:
raise ValueError("anno_file {} should be json or txt.".format(
anno_file))
return clsid2catid, catid2name
# anno file not exist, load default categories of COCO17
else:
if metric_type.lower() == 'rbox':
logger.warning(
"metric_type: {}, load default categories of DOTA.".format(
metric_type))
return _dota_category()
logger.warning("metric_type: {}, load default categories of COCO.".
format(metric_type))
return _coco17_category()
elif metric_type.lower() == 'voc':
if anno_file and os.path.isfile(anno_file):
cats = []
with open(anno_file) as f:
for line in f.readlines():
cats.append(line.strip())
if cats[0] == 'background':
cats = cats[1:]
clsid2catid = {i: i for i in range(len(cats))}
catid2name = {i: name for i, name in enumerate(cats)}
return clsid2catid, catid2name
# anno file not exist, load default categories of
# VOC all 20 categories
else:
logger.warning("metric_type: {}, load default categories of VOC.".
format(metric_type))
return _vocall_category()
elif metric_type.lower() == 'oid':
if anno_file and os.path.isfile(anno_file):
logger.warning("only default categories support for OID19")
return _oid19_category()
elif metric_type.lower() == 'widerface':
return _widerface_category()
elif metric_type.lower() in [
'keypointtopdowncocoeval', 'keypointtopdownmpiieval',
'keypointtopdowncocowholebadyhandeval'
]:
return (None, {'id': 'keypoint'})
elif metric_type.lower() == 'pose3deval':
return (None, {'id': 'pose3d'})
elif metric_type.lower() in ['mot', 'motdet', 'reid']:
if anno_file and os.path.isfile(anno_file):
cats = []
with open(anno_file) as f:
for line in f.readlines():
cats.append(line.strip())
if cats[0] == 'background':
cats = cats[1:]
clsid2catid = {i: i for i in range(len(cats))}
catid2name = {i: name for i, name in enumerate(cats)}
return clsid2catid, catid2name
# anno file not exist, load default category 'pedestrian'.
else:
logger.warning(
"metric_type: {}, load default categories of pedestrian MOT.".
format(metric_type))
return _mot_category(category='pedestrian')
elif metric_type.lower() in ['kitti', 'bdd100kmot']:
return _mot_category(category='vehicle')
elif metric_type.lower() in ['mcmot']:
if anno_file and os.path.isfile(anno_file):
cats = []
with open(anno_file) as f:
for line in f.readlines():
cats.append(line.strip())
if cats[0] == 'background':
cats = cats[1:]
clsid2catid = {i: i for i in range(len(cats))}
catid2name = {i: name for i, name in enumerate(cats)}
return clsid2catid, catid2name
# anno file not exist, load default categories of visdrone all 10 categories
else:
logger.warning(
"metric_type: {}, load default categories of VisDrone.".format(
metric_type))
return _visdrone_category()
else:
raise ValueError("unknown metric type {}".format(metric_type))
def _mot_category(category='pedestrian'):
"""
Get class id to category id map and category id
to category name map of mot dataset
"""
label_map = {category: 0}
label_map = sorted(label_map.items(), key=lambda x: x[1])
cats = [l[0] for l in label_map]
clsid2catid = {i: i for i in range(len(cats))}
catid2name = {i: name for i, name in enumerate(cats)}
return clsid2catid, catid2name
def _coco17_category():
"""
Get class id to category id map and category id
to category name map of COCO2017 dataset
"""
clsid2catid = {
1: 1,
2: 2,
3: 3,
4: 4,
5: 5,
6: 6,
7: 7,
8: 8,
9: 9,
10: 10,
11: 11,
12: 13,
13: 14,
14: 15,
15: 16,
16: 17,
17: 18,
18: 19,
19: 20,
20: 21,
21: 22,
22: 23,
23: 24,
24: 25,
25: 27,
26: 28,
27: 31,
28: 32,
29: 33,
30: 34,
31: 35,
32: 36,
33: 37,
34: 38,
35: 39,
36: 40,
37: 41,
38: 42,
39: 43,
40: 44,
41: 46,
42: 47,
43: 48,
44: 49,
45: 50,
46: 51,
47: 52,
48: 53,
49: 54,
50: 55,
51: 56,
52: 57,
53: 58,
54: 59,
55: 60,
56: 61,
57: 62,
58: 63,
59: 64,
60: 65,
61: 67,
62: 70,
63: 72,
64: 73,
65: 74,
66: 75,
67: 76,
68: 77,
69: 78,
70: 79,
71: 80,
72: 81,
73: 82,
74: 84,
75: 85,
76: 86,
77: 87,
78: 88,
79: 89,
80: 90
}
catid2name = {
0: 'background',
1: 'person',
2: 'bicycle',
3: 'car',
4: 'motorcycle',
5: 'airplane',
6: 'bus',
7: 'train',
8: 'truck',
9: 'boat',
10: 'traffic light',
11: 'fire hydrant',
13: 'stop sign',
14: 'parking meter',
15: 'bench',
16: 'bird',
17: 'cat',
18: 'dog',
19: 'horse',
20: 'sheep',
21: 'cow',
22: 'elephant',
23: 'bear',
24: 'zebra',
25: 'giraffe',
27: 'backpack',
28: 'umbrella',
31: 'handbag',
32: 'tie',
33: 'suitcase',
34: 'frisbee',
35: 'skis',
36: 'snowboard',
37: 'sports ball',
38: 'kite',
39: 'baseball bat',
40: 'baseball glove',
41: 'skateboard',
42: 'surfboard',
43: 'tennis racket',
44: 'bottle',
46: 'wine glass',
47: 'cup',
48: 'fork',
49: 'knife',
50: 'spoon',
51: 'bowl',
52: 'banana',
53: 'apple',
54: 'sandwich',
55: 'orange',
56: 'broccoli',
57: 'carrot',
58: 'hot dog',
59: 'pizza',
60: 'donut',
61: 'cake',
62: 'chair',
63: 'couch',
64: 'potted plant',
65: 'bed',
67: 'dining table',
70: 'toilet',
72: 'tv',
73: 'laptop',
74: 'mouse',
75: 'remote',
76: 'keyboard',
77: 'cell phone',
78: 'microwave',
79: 'oven',
80: 'toaster',
81: 'sink',
82: 'refrigerator',
84: 'book',
85: 'clock',
86: 'vase',
87: 'scissors',
88: 'teddy bear',
89: 'hair drier',
90: 'toothbrush'
}
clsid2catid = {k - 1: v for k, v in clsid2catid.items()}
catid2name.pop(0)
return clsid2catid, catid2name
def _dota_category():
"""
Get class id to category id map and category id
to category name map of dota dataset
"""
catid2name = {
0: 'background',
1: 'plane',
2: 'baseball-diamond',
3: 'bridge',
4: 'ground-track-field',
5: 'small-vehicle',
6: 'large-vehicle',
7: 'ship',
8: 'tennis-court',
9: 'basketball-court',
10: 'storage-tank',
11: 'soccer-ball-field',
12: 'roundabout',
13: 'harbor',
14: 'swimming-pool',
15: 'helicopter'
}
catid2name.pop(0)
clsid2catid = {i: i + 1 for i in range(len(catid2name))}
return clsid2catid, catid2name
def _vocall_category():
"""
Get class id to category id map and category id
to category name map of mixup voc dataset
"""
label_map = pascalvoc_label()
label_map = sorted(label_map.items(), key=lambda x: x[1])
cats = [l[0] for l in label_map]
clsid2catid = {i: i for i in range(len(cats))}
catid2name = {i: name for i, name in enumerate(cats)}
return clsid2catid, catid2name
def _widerface_category():
label_map = widerface_label()
label_map = sorted(label_map.items(), key=lambda x: x[1])
cats = [l[0] for l in label_map]
clsid2catid = {i: i for i in range(len(cats))}
catid2name = {i: name for i, name in enumerate(cats)}
return clsid2catid, catid2name
def _oid19_category():
clsid2catid = {k: k + 1 for k in range(500)}
catid2name = {
0: "background",
1: "Infant bed",
2: "Rose",
3: "Flag",
4: "Flashlight",
5: "Sea turtle",
6: "Camera",
7: "Animal",
8: "Glove",
9: "Crocodile",
10: "Cattle",
11: "House",
12: "Guacamole",
13: "Penguin",
14: "Vehicle registration plate",
15: "Bench",
16: "Ladybug",
17: "Human nose",
18: "Watermelon",
19: "Flute",
20: "Butterfly",
21: "Washing machine",
22: "Raccoon",
23: "Segway",
24: "Taco",
25: "Jellyfish",
26: "Cake",
27: "Pen",
28: "Cannon",
29: "Bread",
30: "Tree",
31: "Shellfish",
32: "Bed",
33: "Hamster",
34: "Hat",
35: "Toaster",
36: "Sombrero",
37: "Tiara",
38: "Bowl",
39: "Dragonfly",
40: "Moths and butterflies",
41: "Antelope",
42: "Vegetable",
43: "Torch",
44: "Building",
45: "Power plugs and sockets",
46: "Blender",
47: "Billiard table",
48: "Cutting board",
49: "Bronze sculpture",
50: "Turtle",
51: "Broccoli",
52: "Tiger",
53: "Mirror",
54: "Bear",
55: "Zucchini",
56: "Dress",
57: "Volleyball",
58: "Guitar",
59: "Reptile",
60: "Golf cart",
61: "Tart",
62: "Fedora",
63: "Carnivore",
64: "Car",
65: "Lighthouse",
66: "Coffeemaker",
67: "Food processor",
68: "Truck",
69: "Bookcase",
70: "Surfboard",
71: "Footwear",
72: "Bench",
73: "Necklace",
74: "Flower",
75: "Radish",
76: "Marine mammal",
77: "Frying pan",
78: "Tap",
79: "Peach",
80: "Knife",
81: "Handbag",
82: "Laptop",
83: "Tent",
84: "Ambulance",
85: "Christmas tree",
86: "Eagle",
87: "Limousine",
88: "Kitchen & dining room table",
89: "Polar bear",
90: "Tower",
91: "Football",
92: "Willow",
93: "Human head",
94: "Stop sign",
95: "Banana",
96: "Mixer",
97: "Binoculars",
98: "Dessert",
99: "Bee",
100: "Chair",
101: "Wood-burning stove",
102: "Flowerpot",
103: "Beaker",
104: "Oyster",
105: "Woodpecker",
106: "Harp",
107: "Bathtub",
108: "Wall clock",
109: "Sports uniform",
110: "Rhinoceros",
111: "Beehive",
112: "Cupboard",
113: "Chicken",
114: "Man",
115: "Blue jay",
116: "Cucumber",
117: "Balloon",
118: "Kite",
119: "Fireplace",
120: "Lantern",
121: "Missile",
122: "Book",
123: "Spoon",
124: "Grapefruit",
125: "Squirrel",
126: "Orange",
127: "Coat",
128: "Punching bag",
129: "Zebra",
130: "Billboard",
131: "Bicycle",
132: "Door handle",
133: "Mechanical fan",
134: "Ring binder",
135: "Table",
136: "Parrot",
137: "Sock",
138: "Vase",
139: "Weapon",
140: "Shotgun",
141: "Glasses",
142: "Seahorse",
143: "Belt",
144: "Watercraft",
145: "Window",
146: "Giraffe",
147: "Lion",
148: "Tire",
149: "Vehicle",
150: "Canoe",
151: "Tie",
152: "Shelf",
153: "Picture frame",
154: "Printer",
155: "Human leg",
156: "Boat",
157: "Slow cooker",
158: "Croissant",
159: "Candle",
160: "Pancake",
161: "Pillow",
162: "Coin",
163: "Stretcher",
164: "Sandal",
165: "Woman",
166: "Stairs",
167: "Harpsichord",
168: "Stool",
169: "Bus",
170: "Suitcase",
171: "Human mouth",
172: "Juice",
173: "Skull",
174: "Door",
175: "Violin",
176: "Chopsticks",
177: "Digital clock",
178: "Sunflower",
179: "Leopard",
180: "Bell pepper",
181: "Harbor seal",
182: "Snake",
183: "Sewing machine",
184: "Goose",
185: "Helicopter",
186: "Seat belt",
187: "Coffee cup",
188: "Microwave oven",
189: "Hot dog",
190: "Countertop",
191: "Serving tray",
192: "Dog bed",
193: "Beer",
194: "Sunglasses",
195: "Golf ball",
196: "Waffle",
197: "Palm tree",
198: "Trumpet",
199: "Ruler",
200: "Helmet",
201: "Ladder",
202: "Office building",
203: "Tablet computer",
204: "Toilet paper",
205: "Pomegranate",
206: "Skirt",
207: "Gas stove",
208: "Cookie",
209: "Cart",
210: "Raven",
211: "Egg",
212: "Burrito",
213: "Goat",
214: "Kitchen knife",
215: "Skateboard",
216: "Salt and pepper shakers",
217: "Lynx",
218: "Boot",
219: "Platter",
220: "Ski",
221: "Swimwear",
222: "Swimming pool",
223: "Drinking straw",
224: "Wrench",
225: "Drum",
226: "Ant",
227: "Human ear",
228: "Headphones",
229: "Fountain",
230: "Bird",
231: "Jeans",
232: "Television",
233: "Crab",
234: "Microphone",
235: "Home appliance",
236: "Snowplow",
237: "Beetle",
238: "Artichoke",
239: "Jet ski",
240: "Stationary bicycle",
241: "Human hair",
242: "Brown bear",
243: "Starfish",
244: "Fork",
245: "Lobster",
246: "Corded phone",
247: "Drink",
248: "Saucer",
249: "Carrot",
250: "Insect",
251: "Clock",
252: "Castle",
253: "Tennis racket",
254: "Ceiling fan",
255: "Asparagus",
256: "Jaguar",
257: "Musical instrument",
258: "Train",
259: "Cat",
260: "Rifle",
261: "Dumbbell",
262: "Mobile phone",
263: "Taxi",
264: "Shower",
265: "Pitcher",
266: "Lemon",
267: "Invertebrate",
268: "Turkey",
269: "High heels",
270: "Bust",
271: "Elephant",
272: "Scarf",
273: "Barrel",
274: "Trombone",
275: "Pumpkin",
276: "Box",
277: "Tomato",
278: "Frog",
279: "Bidet",
280: "Human face",
281: "Houseplant",
282: "Van",
283: "Shark",
284: "Ice cream",
285: "Swim cap",
286: "Falcon",
287: "Ostrich",
288: "Handgun",
289: "Whiteboard",
290: "Lizard",
291: "Pasta",
292: "Snowmobile",
293: "Light bulb",
294: "Window blind",
295: "Muffin",
296: "Pretzel",
297: "Computer monitor",
298: "Horn",
299: "Furniture",
300: "Sandwich",
301: "Fox",
302: "Convenience store",
303: "Fish",
304: "Fruit",
305: "Earrings",
306: "Curtain",
307: "Grape",
308: "Sofa bed",
309: "Horse",
310: "Luggage and bags",
311: "Desk",
312: "Crutch",
313: "Bicycle helmet",
314: "Tick",
315: "Airplane",
316: "Canary",
317: "Spatula",
318: "Watch",
319: "Lily",
320: "Kitchen appliance",
321: "Filing cabinet",
322: "Aircraft",
323: "Cake stand",
324: "Candy",
325: "Sink",
326: "Mouse",
327: "Wine",
328: "Wheelchair",
329: "Goldfish",
330: "Refrigerator",
331: "French fries",
332: "Drawer",
333: "Treadmill",
334: "Picnic basket",
335: "Dice",
336: "Cabbage",
337: "Football helmet",
338: "Pig",
339: "Person",
340: "Shorts",
341: "Gondola",
342: "Honeycomb",
343: "Doughnut",
344: "Chest of drawers",
345: "Land vehicle",
346: "Bat",
347: "Monkey",
348: "Dagger",
349: "Tableware",
350: "Human foot",
351: "Mug",
352: "Alarm clock",
353: "Pressure cooker",
354: "Human hand",
355: "Tortoise",
356: "Baseball glove",
357: "Sword",
358: "Pear",
359: "Miniskirt",
360: "Traffic sign",
361: "Girl",
362: "Roller skates",
363: "Dinosaur",
364: "Porch",
365: "Human beard",
366: "Submarine sandwich",
367: "Screwdriver",
368: "Strawberry",
369: "Wine glass",
370: "Seafood",
371: "Racket",
372: "Wheel",
373: "Sea lion",
374: "Toy",
375: "Tea",
376: "Tennis ball",
377: "Waste container",
378: "Mule",
379: "Cricket ball",
380: "Pineapple",
381: "Coconut",
382: "Doll",
383: "Coffee table",
384: "Snowman",
385: "Lavender",
386: "Shrimp",
387: "Maple",
388: "Cowboy hat",
389: "Goggles",
390: "Rugby ball",
391: "Caterpillar",
392: "Poster",
393: "Rocket",
394: "Organ",
395: "Saxophone",
396: "Traffic light",
397: "Cocktail",
398: "Plastic bag",
399: "Squash",
400: "Mushroom",
401: "Hamburger",
402: "Light switch",
403: "Parachute",
404: "Teddy bear",
405: "Winter melon",
406: "Deer",
407: "Musical keyboard",
408: "Plumbing fixture",
409: "Scoreboard",
410: "Baseball bat",
411: "Envelope",
412: "Adhesive tape",
413: "Briefcase",
414: "Paddle",
415: "Bow and arrow",
416: "Telephone",
417: "Sheep",
418: "Jacket",
419: "Boy",
420: "Pizza",
421: "Otter",
422: "Office supplies",
423: "Couch",
424: "Cello",
425: "Bull",
426: "Camel",
427: "Ball",
428: "Duck",
429: "Whale",
430: "Shirt",
431: "Tank",
432: "Motorcycle",
433: "Accordion",
434: "Owl",
435: "Porcupine",
436: "Sun hat",
437: "Nail",
438: "Scissors",
439: "Swan",
440: "Lamp",
441: "Crown",
442: "Piano",
443: "Sculpture",
444: "Cheetah",
445: "Oboe",
446: "Tin can",
447: "Mango",
448: "Tripod",
449: "Oven",
450: "Mouse",
451: "Barge",
452: "Coffee",
453: "Snowboard",
454: "Common fig",
455: "Salad",
456: "Marine invertebrates",
457: "Umbrella",
458: "Kangaroo",
459: "Human arm",
460: "Measuring cup",
461: "Snail",
462: "Loveseat",
463: "Suit",
464: "Teapot",
465: "Bottle",
466: "Alpaca",
467: "Kettle",
468: "Trousers",
469: "Popcorn",
470: "Centipede",
471: "Spider",
472: "Sparrow",
473: "Plate",
474: "Bagel",
475: "Personal care",
476: "Apple",
477: "Brassiere",
478: "Bathroom cabinet",
479: "studio couch",
480: "Computer keyboard",
481: "Table tennis racket",
482: "Sushi",
483: "Cabinetry",
484: "Street light",
485: "Towel",
486: "Nightstand",
487: "Rabbit",
488: "Dolphin",
489: "Dog",
490: "Jug",
491: "Wok",
492: "Fire hydrant",
493: "Human eye",
494: "Skyscraper",
495: "Backpack",
496: "Potato",
497: "Paper towel",
498: "Lifejacket",
499: "Bicycle wheel",
500: "Toilet",
}
return clsid2catid, catid2name
def _visdrone_category():
clsid2catid = {i: i for i in range(10)}
catid2name = {
0: 'pedestrian',
1: 'people',
2: 'bicycle',
3: 'car',
4: 'van',
5: 'truck',
6: 'tricycle',
7: 'awning-tricycle',
8: 'bus',
9: 'motor'
}
return clsid2catid, catid2name

View File

@@ -0,0 +1,596 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
try:
from collections.abc import Sequence
except Exception:
from collections import Sequence
import numpy as np
from ppdet.core.workspace import register, serializable
from .dataset import DetDataset
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
__all__ = [
'COCODataSet', 'SlicedCOCODataSet', 'SemiCOCODataSet', 'COCODetDataset'
]
@register
@serializable
class COCODataSet(DetDataset):
"""
Load dataset with COCO format.
Args:
dataset_dir (str): root directory for dataset.
image_dir (str): directory for images.
anno_path (str): coco annotation file path.
data_fields (list): key name of data dictionary, at least have 'image'.
sample_num (int): number of samples to load, -1 means all.
load_crowd (bool): whether to load crowded ground-truth.
False as default
allow_empty (bool): whether to load empty entry. False as default
empty_ratio (float): the ratio of empty record number to total
record's, if empty_ratio is out of [0. ,1.), do not sample the
records and use all the empty entries. 1. as default
repeat (int): repeat times for dataset, use in benchmark.
"""
def __init__(self,
dataset_dir=None,
image_dir=None,
anno_path=None,
data_fields=['image'],
sample_num=-1,
load_crowd=False,
allow_empty=False,
empty_ratio=1.,
repeat=1):
super(COCODataSet, self).__init__(
dataset_dir,
image_dir,
anno_path,
data_fields,
sample_num,
repeat=repeat)
self.load_image_only = False
self.load_semantic = False
self.load_crowd = load_crowd
self.allow_empty = allow_empty
self.empty_ratio = empty_ratio
def _sample_empty(self, records, num):
# if empty_ratio is out of [0. ,1.), do not sample the records
if self.empty_ratio < 0. or self.empty_ratio >= 1.:
return records
import random
sample_num = min(
int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records))
records = random.sample(records, sample_num)
return records
def parse_dataset(self):
anno_path = os.path.join(self.dataset_dir, self.anno_path)
image_dir = os.path.join(self.dataset_dir, self.image_dir)
assert anno_path.endswith('.json'), \
'invalid coco annotation file: ' + anno_path
from pycocotools.coco import COCO
coco = COCO(anno_path)
img_ids = coco.getImgIds()
img_ids.sort()
cat_ids = coco.getCatIds()
records = []
empty_records = []
ct = 0
self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
self.cname2cid = dict({
coco.loadCats(catid)[0]['name']: clsid
for catid, clsid in self.catid2clsid.items()
})
if 'annotations' not in coco.dataset:
self.load_image_only = True
logger.warning('Annotation file: {} does not contains ground truth '
'and load image information only.'.format(anno_path))
for img_id in img_ids:
img_anno = coco.loadImgs([img_id])[0]
im_fname = img_anno['file_name']
im_w = float(img_anno['width'])
im_h = float(img_anno['height'])
im_path = os.path.join(image_dir,
im_fname) if image_dir else im_fname
is_empty = False
if not os.path.exists(im_path):
logger.warning('Illegal image file: {}, and it will be '
'ignored'.format(im_path))
continue
if im_w < 0 or im_h < 0:
logger.warning('Illegal width: {} or height: {} in annotation, '
'and im_id: {} will be ignored'.format(
im_w, im_h, img_id))
continue
coco_rec = {
'im_file': im_path,
'im_id': np.array([img_id]),
'h': im_h,
'w': im_w,
} if 'image' in self.data_fields else {}
if not self.load_image_only:
ins_anno_ids = coco.getAnnIds(
imgIds=[img_id], iscrowd=None if self.load_crowd else False)
instances = coco.loadAnns(ins_anno_ids)
bboxes = []
is_rbox_anno = False
for inst in instances:
# check gt bbox
if inst.get('ignore', False):
continue
if 'bbox' not in inst.keys():
continue
else:
if not any(np.array(inst['bbox'])):
continue
x1, y1, box_w, box_h = inst['bbox']
x2 = x1 + box_w
y2 = y1 + box_h
eps = 1e-5
if inst['area'] > 0 and x2 - x1 > eps and y2 - y1 > eps:
inst['clean_bbox'] = [
round(float(x), 3) for x in [x1, y1, x2, y2]
]
bboxes.append(inst)
else:
logger.warning(
'Found an invalid bbox in annotations: im_id: {}, '
'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
img_id, float(inst['area']), x1, y1, x2, y2))
num_bbox = len(bboxes)
if num_bbox <= 0 and not self.allow_empty:
continue
elif num_bbox <= 0:
is_empty = True
gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
gt_poly = [None] * num_bbox
gt_track_id = -np.ones((num_bbox, 1), dtype=np.int32)
has_segmentation = False
has_track_id = False
for i, box in enumerate(bboxes):
catid = box['category_id']
gt_class[i][0] = self.catid2clsid[catid]
gt_bbox[i, :] = box['clean_bbox']
is_crowd[i][0] = box['iscrowd']
# check RLE format
if 'segmentation' in box and box['iscrowd'] == 1:
gt_poly[i] = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
elif 'segmentation' in box and box['segmentation']:
if not np.array(
box['segmentation'],
dtype=object).size > 0 and not self.allow_empty:
bboxes.pop(i)
gt_poly.pop(i)
np.delete(is_crowd, i)
np.delete(gt_class, i)
np.delete(gt_bbox, i)
else:
gt_poly[i] = box['segmentation']
has_segmentation = True
if 'track_id' in box:
gt_track_id[i][0] = box['track_id']
has_track_id = True
if has_segmentation and not any(
gt_poly) and not self.allow_empty:
continue
gt_rec = {
'is_crowd': is_crowd,
'gt_class': gt_class,
'gt_bbox': gt_bbox,
'gt_poly': gt_poly,
}
if has_track_id:
gt_rec.update({'gt_track_id': gt_track_id})
for k, v in gt_rec.items():
if k in self.data_fields:
coco_rec[k] = v
# TODO: remove load_semantic
if self.load_semantic and 'semantic' in self.data_fields:
seg_path = os.path.join(self.dataset_dir, 'stuffthingmaps',
'train2017', im_fname[:-3] + 'png')
coco_rec.update({'semantic': seg_path})
logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
im_path, img_id, im_h, im_w))
if is_empty:
empty_records.append(coco_rec)
else:
records.append(coco_rec)
ct += 1
if self.sample_num > 0 and ct >= self.sample_num:
break
assert ct > 0, 'not found any coco record in %s' % (anno_path)
logger.info('Load [{} samples valid, {} samples invalid] in file {}.'.
format(ct, len(img_ids) - ct, anno_path))
if self.allow_empty and len(empty_records) > 0:
empty_records = self._sample_empty(empty_records, len(records))
records += empty_records
self.roidbs = records
@register
@serializable
class SlicedCOCODataSet(COCODataSet):
"""Sliced COCODataSet"""
def __init__(
self,
dataset_dir=None,
image_dir=None,
anno_path=None,
data_fields=['image'],
sample_num=-1,
load_crowd=False,
allow_empty=False,
empty_ratio=1.,
repeat=1,
sliced_size=[640, 640],
overlap_ratio=[0.25, 0.25], ):
super(SlicedCOCODataSet, self).__init__(
dataset_dir=dataset_dir,
image_dir=image_dir,
anno_path=anno_path,
data_fields=data_fields,
sample_num=sample_num,
load_crowd=load_crowd,
allow_empty=allow_empty,
empty_ratio=empty_ratio,
repeat=repeat, )
self.sliced_size = sliced_size
self.overlap_ratio = overlap_ratio
def parse_dataset(self):
anno_path = os.path.join(self.dataset_dir, self.anno_path)
image_dir = os.path.join(self.dataset_dir, self.image_dir)
assert anno_path.endswith('.json'), \
'invalid coco annotation file: ' + anno_path
from pycocotools.coco import COCO
coco = COCO(anno_path)
img_ids = coco.getImgIds()
img_ids.sort()
cat_ids = coco.getCatIds()
records = []
empty_records = []
ct = 0
ct_sub = 0
self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
self.cname2cid = dict({
coco.loadCats(catid)[0]['name']: clsid
for catid, clsid in self.catid2clsid.items()
})
if 'annotations' not in coco.dataset:
self.load_image_only = True
logger.warning('Annotation file: {} does not contains ground truth '
'and load image information only.'.format(anno_path))
try:
import sahi
from sahi.slicing import slice_image
except Exception as e:
logger.error(
'sahi not found, plaese install sahi. '
'for example: `pip install sahi`, see https://github.com/obss/sahi.'
)
raise e
sub_img_ids = 0
for img_id in img_ids:
img_anno = coco.loadImgs([img_id])[0]
im_fname = img_anno['file_name']
im_w = float(img_anno['width'])
im_h = float(img_anno['height'])
im_path = os.path.join(image_dir,
im_fname) if image_dir else im_fname
is_empty = False
if not os.path.exists(im_path):
logger.warning('Illegal image file: {}, and it will be '
'ignored'.format(im_path))
continue
if im_w < 0 or im_h < 0:
logger.warning('Illegal width: {} or height: {} in annotation, '
'and im_id: {} will be ignored'.format(
im_w, im_h, img_id))
continue
slice_image_result = sahi.slicing.slice_image(
image=im_path,
slice_height=self.sliced_size[0],
slice_width=self.sliced_size[1],
overlap_height_ratio=self.overlap_ratio[0],
overlap_width_ratio=self.overlap_ratio[1])
sub_img_num = len(slice_image_result)
for _ind in range(sub_img_num):
im = slice_image_result.images[_ind]
coco_rec = {
'image': im,
'im_id': np.array([sub_img_ids + _ind]),
'h': im.shape[0],
'w': im.shape[1],
'ori_im_id': np.array([img_id]),
'st_pix': np.array(
slice_image_result.starting_pixels[_ind],
dtype=np.float32),
'is_last': 1 if _ind == sub_img_num - 1 else 0,
} if 'image' in self.data_fields else {}
records.append(coco_rec)
ct_sub += sub_img_num
ct += 1
if self.sample_num > 0 and ct >= self.sample_num:
break
assert ct > 0, 'not found any coco record in %s' % (anno_path)
logger.info('{} samples and slice to {} sub_samples in file {}'.format(
ct, ct_sub, anno_path))
if self.allow_empty and len(empty_records) > 0:
empty_records = self._sample_empty(empty_records, len(records))
records += empty_records
self.roidbs = records
@register
@serializable
class SemiCOCODataSet(COCODataSet):
"""Semi-COCODataSet used for supervised and unsupervised dataSet"""
def __init__(self,
dataset_dir=None,
image_dir=None,
anno_path=None,
data_fields=['image'],
sample_num=-1,
load_crowd=False,
allow_empty=False,
empty_ratio=1.,
repeat=1,
supervised=True):
super(SemiCOCODataSet, self).__init__(
dataset_dir, image_dir, anno_path, data_fields, sample_num,
load_crowd, allow_empty, empty_ratio, repeat)
self.supervised = supervised
self.length = -1 # defalut -1 means all
def parse_dataset(self):
anno_path = os.path.join(self.dataset_dir, self.anno_path)
image_dir = os.path.join(self.dataset_dir, self.image_dir)
assert anno_path.endswith('.json'), \
'invalid coco annotation file: ' + anno_path
from pycocotools.coco import COCO
coco = COCO(anno_path)
img_ids = coco.getImgIds()
img_ids.sort()
cat_ids = coco.getCatIds()
records = []
empty_records = []
ct = 0
self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
self.cname2cid = dict({
coco.loadCats(catid)[0]['name']: clsid
for catid, clsid in self.catid2clsid.items()
})
if 'annotations' not in coco.dataset or self.supervised == False:
self.load_image_only = True
logger.warning('Annotation file: {} does not contains ground truth '
'and load image information only.'.format(anno_path))
for img_id in img_ids:
img_anno = coco.loadImgs([img_id])[0]
im_fname = img_anno['file_name']
im_w = float(img_anno['width'])
im_h = float(img_anno['height'])
im_path = os.path.join(image_dir,
im_fname) if image_dir else im_fname
is_empty = False
if not os.path.exists(im_path):
logger.warning('Illegal image file: {}, and it will be '
'ignored'.format(im_path))
continue
if im_w < 0 or im_h < 0:
logger.warning('Illegal width: {} or height: {} in annotation, '
'and im_id: {} will be ignored'.format(
im_w, im_h, img_id))
continue
coco_rec = {
'im_file': im_path,
'im_id': np.array([img_id]),
'h': im_h,
'w': im_w,
} if 'image' in self.data_fields else {}
if not self.load_image_only:
ins_anno_ids = coco.getAnnIds(
imgIds=[img_id], iscrowd=None if self.load_crowd else False)
instances = coco.loadAnns(ins_anno_ids)
bboxes = []
is_rbox_anno = False
for inst in instances:
# check gt bbox
if inst.get('ignore', False):
continue
if 'bbox' not in inst.keys():
continue
else:
if not any(np.array(inst['bbox'])):
continue
x1, y1, box_w, box_h = inst['bbox']
x2 = x1 + box_w
y2 = y1 + box_h
eps = 1e-5
if inst['area'] > 0 and x2 - x1 > eps and y2 - y1 > eps:
inst['clean_bbox'] = [
round(float(x), 3) for x in [x1, y1, x2, y2]
]
bboxes.append(inst)
else:
logger.warning(
'Found an invalid bbox in annotations: im_id: {}, '
'area: {} x1: {}, y1: {}, x2: {}, y2: {}.'.format(
img_id, float(inst['area']), x1, y1, x2, y2))
num_bbox = len(bboxes)
if num_bbox <= 0 and not self.allow_empty:
continue
elif num_bbox <= 0:
is_empty = True
gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
is_crowd = np.zeros((num_bbox, 1), dtype=np.int32)
gt_poly = [None] * num_bbox
has_segmentation = False
for i, box in enumerate(bboxes):
catid = box['category_id']
gt_class[i][0] = self.catid2clsid[catid]
gt_bbox[i, :] = box['clean_bbox']
is_crowd[i][0] = box['iscrowd']
# check RLE format
if 'segmentation' in box and box['iscrowd'] == 1:
gt_poly[i] = [[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
elif 'segmentation' in box and box['segmentation']:
if not np.array(box['segmentation']
).size > 0 and not self.allow_empty:
bboxes.pop(i)
gt_poly.pop(i)
np.delete(is_crowd, i)
np.delete(gt_class, i)
np.delete(gt_bbox, i)
else:
gt_poly[i] = box['segmentation']
has_segmentation = True
if has_segmentation and not any(
gt_poly) and not self.allow_empty:
continue
gt_rec = {
'is_crowd': is_crowd,
'gt_class': gt_class,
'gt_bbox': gt_bbox,
'gt_poly': gt_poly,
}
for k, v in gt_rec.items():
if k in self.data_fields:
coco_rec[k] = v
# TODO: remove load_semantic
if self.load_semantic and 'semantic' in self.data_fields:
seg_path = os.path.join(self.dataset_dir, 'stuffthingmaps',
'train2017', im_fname[:-3] + 'png')
coco_rec.update({'semantic': seg_path})
logger.debug('Load file: {}, im_id: {}, h: {}, w: {}.'.format(
im_path, img_id, im_h, im_w))
if is_empty:
empty_records.append(coco_rec)
else:
records.append(coco_rec)
ct += 1
if self.sample_num > 0 and ct >= self.sample_num:
break
assert ct > 0, 'not found any coco record in %s' % (anno_path)
logger.info('Load [{} samples valid, {} samples invalid] in file {}.'.
format(ct, len(img_ids) - ct, anno_path))
if self.allow_empty and len(empty_records) > 0:
empty_records = self._sample_empty(empty_records, len(records))
records += empty_records
self.roidbs = records
if self.supervised:
logger.info(f'Use {len(self.roidbs)} sup_samples data as LABELED')
else:
if self.length > 0: # unsup length will be decide by sup length
all_roidbs = self.roidbs.copy()
selected_idxs = [
np.random.choice(len(all_roidbs))
for _ in range(self.length)
]
self.roidbs = [all_roidbs[i] for i in selected_idxs]
logger.info(
f'Use {len(self.roidbs)} unsup_samples data as UNLABELED')
def __getitem__(self, idx):
n = len(self.roidbs)
if self.repeat > 1:
idx %= n
# data batch
roidb = copy.deepcopy(self.roidbs[idx])
if self.mixup_epoch == 0 or self._epoch < self.mixup_epoch:
idx = np.random.randint(n)
roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
elif self.cutmix_epoch == 0 or self._epoch < self.cutmix_epoch:
idx = np.random.randint(n)
roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
elif self.mosaic_epoch == 0 or self._epoch < self.mosaic_epoch:
roidb = [roidb, ] + [
copy.deepcopy(self.roidbs[np.random.randint(n)])
for _ in range(4)
]
if isinstance(roidb, Sequence):
for r in roidb:
r['curr_iter'] = self._curr_iter
else:
roidb['curr_iter'] = self._curr_iter
self._curr_iter += 1
return self.transform(roidb)
# for PaddleX
@register
@serializable
class COCODetDataset(COCODataSet):
pass

View File

@@ -0,0 +1,206 @@
from ppdet.core.workspace import register, serializable
import cv2
import os
import tarfile
import numpy as np
import os.path as osp
from ppdet.data.source.dataset import DetDataset
from imgaug.augmentables.lines import LineStringsOnImage
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from ppdet.data.culane_utils import lane_to_linestrings
import pickle as pkl
from ppdet.utils.logger import setup_logger
try:
from collections.abc import Sequence
except Exception:
from collections import Sequence
from .dataset import DetDataset, _make_dataset, _is_valid_file
from ppdet.utils.download import download_dataset
logger = setup_logger(__name__)
@register
@serializable
class CULaneDataSet(DetDataset):
def __init__(
self,
dataset_dir,
cut_height,
list_path,
split='train',
data_fields=['image'],
video_file=None,
frame_rate=-1, ):
super(CULaneDataSet, self).__init__(
dataset_dir=dataset_dir,
cut_height=cut_height,
split=split,
data_fields=data_fields)
self.dataset_dir = dataset_dir
self.list_path = osp.join(dataset_dir, list_path)
self.cut_height = cut_height
self.data_fields = data_fields
self.split = split
self.training = 'train' in split
self.data_infos = []
self.video_file = video_file
self.frame_rate = frame_rate
self._imid2path = {}
self.predict_dir = None
def __len__(self):
return len(self.data_infos)
def check_or_download_dataset(self):
if not osp.exists(self.dataset_dir):
download_dataset("dataset", dataset="culane")
# extract .tar files in self.dataset_dir
for fname in os.listdir(self.dataset_dir):
logger.info("Decompressing {}...".format(fname))
# ignore .* files
if fname.startswith('.'):
continue
if fname.find('.tar.gz') >= 0:
with tarfile.open(osp.join(self.dataset_dir, fname)) as tf:
tf.extractall(path=self.dataset_dir)
logger.info("Dataset files are ready.")
def parse_dataset(self):
logger.info('Loading CULane annotations...')
if self.predict_dir is not None:
logger.info('switch to predict mode')
return
# Waiting for the dataset to load is tedious, let's cache it
os.makedirs('cache', exist_ok=True)
cache_path = 'cache/culane_paddle_{}.pkl'.format(self.split)
if os.path.exists(cache_path):
with open(cache_path, 'rb') as cache_file:
self.data_infos = pkl.load(cache_file)
self.max_lanes = max(
len(anno['lanes']) for anno in self.data_infos)
return
with open(self.list_path) as list_file:
for line in list_file:
infos = self.load_annotation(line.split())
self.data_infos.append(infos)
# cache data infos to file
with open(cache_path, 'wb') as cache_file:
pkl.dump(self.data_infos, cache_file)
def load_annotation(self, line):
infos = {}
img_line = line[0]
img_line = img_line[1 if img_line[0] == '/' else 0::]
img_path = os.path.join(self.dataset_dir, img_line)
infos['img_name'] = img_line
infos['img_path'] = img_path
if len(line) > 1:
mask_line = line[1]
mask_line = mask_line[1 if mask_line[0] == '/' else 0::]
mask_path = os.path.join(self.dataset_dir, mask_line)
infos['mask_path'] = mask_path
if len(line) > 2:
exist_list = [int(l) for l in line[2:]]
infos['lane_exist'] = np.array(exist_list)
anno_path = img_path[:
-3] + 'lines.txt' # remove sufix jpg and add lines.txt
with open(anno_path, 'r') as anno_file:
data = [
list(map(float, line.split())) for line in anno_file.readlines()
]
lanes = [[(lane[i], lane[i + 1]) for i in range(0, len(lane), 2)
if lane[i] >= 0 and lane[i + 1] >= 0] for lane in data]
lanes = [list(set(lane)) for lane in lanes] # remove duplicated points
lanes = [lane for lane in lanes
if len(lane) > 2] # remove lanes with less than 2 points
lanes = [sorted(
lane, key=lambda x: x[1]) for lane in lanes] # sort by y
infos['lanes'] = lanes
return infos
def set_images(self, images):
self.predict_dir = images
self.data_infos = self._load_images()
def _find_images(self):
predict_dir = self.predict_dir
if not isinstance(predict_dir, Sequence):
predict_dir = [predict_dir]
images = []
for im_dir in predict_dir:
if os.path.isdir(im_dir):
im_dir = os.path.join(self.predict_dir, im_dir)
images.extend(_make_dataset(im_dir))
elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
images.append(im_dir)
return images
def _load_images(self):
images = self._find_images()
ct = 0
records = []
for image in images:
assert image != '' and os.path.isfile(image), \
"Image {} not found".format(image)
if self.sample_num > 0 and ct >= self.sample_num:
break
rec = {
'im_id': np.array([ct]),
"img_path": os.path.abspath(image),
"img_name": os.path.basename(image),
"lanes": []
}
self._imid2path[ct] = image
ct += 1
records.append(rec)
assert len(records) > 0, "No image file found"
return records
def get_imid2path(self):
return self._imid2path
def __getitem__(self, idx):
data_info = self.data_infos[idx]
img = cv2.imread(data_info['img_path'])
img = img[self.cut_height:, :, :]
sample = data_info.copy()
sample.update({'image': img})
img_org = sample['image']
if self.training:
label = cv2.imread(sample['mask_path'], cv2.IMREAD_UNCHANGED)
if len(label.shape) > 2:
label = label[:, :, 0]
label = label.squeeze()
label = label[self.cut_height:, :]
sample.update({'mask': label})
if self.cut_height != 0:
new_lanes = []
for i in sample['lanes']:
lanes = []
for p in i:
lanes.append((p[0], p[1] - self.cut_height))
new_lanes.append(lanes)
sample.update({'lanes': new_lanes})
sample['mask'] = SegmentationMapsOnImage(
sample['mask'], shape=img_org.shape)
sample['full_img_path'] = data_info['img_path']
sample['img_name'] = data_info['img_name']
sample['im_id'] = np.array([idx])
sample['image'] = sample['image'].copy().astype(np.uint8)
sample['lanes'] = lane_to_linestrings(sample['lanes'])
sample['lanes'] = LineStringsOnImage(
sample['lanes'], shape=img_org.shape)
sample['seg'] = np.zeros(img_org.shape)
return sample

View File

@@ -0,0 +1,307 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import copy
import numpy as np
try:
from collections.abc import Sequence
except Exception:
from collections import Sequence
from paddle.io import Dataset
from ppdet.core.workspace import register, serializable
from ppdet.utils.download import get_dataset_path
from ppdet.data import source
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
@serializable
class DetDataset(Dataset):
"""
Load detection dataset.
Args:
dataset_dir (str): root directory for dataset.
image_dir (str): directory for images.
anno_path (str): annotation file path.
data_fields (list): key name of data dictionary, at least have 'image'.
sample_num (int): number of samples to load, -1 means all.
use_default_label (bool): whether to load default label list.
repeat (int): repeat times for dataset, use in benchmark.
"""
def __init__(self,
dataset_dir=None,
image_dir=None,
anno_path=None,
data_fields=['image'],
sample_num=-1,
use_default_label=None,
repeat=1,
**kwargs):
super(DetDataset, self).__init__()
self.dataset_dir = dataset_dir if dataset_dir is not None else ''
self.anno_path = anno_path
self.image_dir = image_dir if image_dir is not None else ''
self.data_fields = data_fields
self.sample_num = sample_num
self.use_default_label = use_default_label
self.repeat = repeat
self._epoch = 0
self._curr_iter = 0
def __len__(self, ):
return len(self.roidbs) * self.repeat
def __call__(self, *args, **kwargs):
return self
def __getitem__(self, idx):
n = len(self.roidbs)
if self.repeat > 1:
idx %= n
# data batch
roidb = copy.deepcopy(self.roidbs[idx])
if self.mixup_epoch == 0 or self._epoch < self.mixup_epoch:
idx = np.random.randint(n)
roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
elif self.cutmix_epoch == 0 or self._epoch < self.cutmix_epoch:
idx = np.random.randint(n)
roidb = [roidb, copy.deepcopy(self.roidbs[idx])]
elif self.mosaic_epoch == 0 or self._epoch < self.mosaic_epoch:
roidb = [roidb, ] + [
copy.deepcopy(self.roidbs[np.random.randint(n)])
for _ in range(4)
]
elif self.pre_img_epoch == 0 or self._epoch < self.pre_img_epoch:
# Add previous image as input, only used in CenterTrack
idx_pre_img = idx - 1
if idx_pre_img < 0:
idx_pre_img = idx + 1
roidb = [roidb, ] + [copy.deepcopy(self.roidbs[idx_pre_img])]
if isinstance(roidb, Sequence):
for r in roidb:
r['curr_iter'] = self._curr_iter
else:
roidb['curr_iter'] = self._curr_iter
self._curr_iter += 1
return self.transform(roidb)
def check_or_download_dataset(self):
self.dataset_dir = get_dataset_path(self.dataset_dir, self.anno_path,
self.image_dir)
def set_kwargs(self, **kwargs):
self.mixup_epoch = kwargs.get('mixup_epoch', -1)
self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)
self.pre_img_epoch = kwargs.get('pre_img_epoch', -1)
def set_transform(self, transform):
self.transform = transform
def set_epoch(self, epoch_id):
self._epoch = epoch_id
def parse_dataset(self, ):
raise NotImplementedError(
"Need to implement parse_dataset method of Dataset")
def get_anno(self):
if self.anno_path is None:
return
return os.path.join(self.dataset_dir, self.anno_path)
def _is_valid_file(f, extensions=('.jpg', '.jpeg', '.png', '.bmp')):
return f.lower().endswith(extensions)
def _make_dataset(dir):
dir = os.path.expanduser(dir)
if not os.path.isdir(dir):
raise ('{} should be a dir'.format(dir))
images = []
for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
for fname in sorted(fnames):
path = os.path.join(root, fname)
if _is_valid_file(path):
images.append(path)
return images
@register
@serializable
class ImageFolder(DetDataset):
def __init__(self,
dataset_dir=None,
image_dir=None,
anno_path=None,
sample_num=-1,
use_default_label=None,
**kwargs):
super(ImageFolder, self).__init__(
dataset_dir,
image_dir,
anno_path,
sample_num=sample_num,
use_default_label=use_default_label)
self._imid2path = {}
self.roidbs = None
self.sample_num = sample_num
def check_or_download_dataset(self):
return
def get_anno(self):
if self.anno_path is None:
return
if self.dataset_dir:
return os.path.join(self.dataset_dir, self.anno_path)
else:
return self.anno_path
def parse_dataset(self, ):
if not self.roidbs:
self.roidbs = self._load_images()
def _parse(self):
image_dir = self.image_dir
if not isinstance(image_dir, Sequence):
image_dir = [image_dir]
images = []
for im_dir in image_dir:
if os.path.isdir(im_dir):
im_dir = os.path.join(self.dataset_dir, im_dir)
images.extend(_make_dataset(im_dir))
elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
images.append(im_dir)
return images
def _load_images(self):
images = self._parse()
ct = 0
records = []
for image in images:
assert image != '' and os.path.isfile(image), \
"Image {} not found".format(image)
if self.sample_num > 0 and ct >= self.sample_num:
break
rec = {'im_id': np.array([ct]), 'im_file': image}
self._imid2path[ct] = image
ct += 1
records.append(rec)
assert len(records) > 0, "No image file found"
return records
def get_imid2path(self):
return self._imid2path
def set_images(self, images):
self.image_dir = images
self.roidbs = self._load_images()
def set_slice_images(self,
images,
slice_size=[640, 640],
overlap_ratio=[0.25, 0.25]):
self.image_dir = images
ori_records = self._load_images()
try:
import sahi
from sahi.slicing import slice_image
except Exception as e:
logger.error(
'sahi not found, plaese install sahi. '
'for example: `pip install sahi`, see https://github.com/obss/sahi.'
)
raise e
sub_img_ids = 0
ct = 0
ct_sub = 0
records = []
for i, ori_rec in enumerate(ori_records):
im_path = ori_rec['im_file']
slice_image_result = sahi.slicing.slice_image(
image=im_path,
slice_height=slice_size[0],
slice_width=slice_size[1],
overlap_height_ratio=overlap_ratio[0],
overlap_width_ratio=overlap_ratio[1])
sub_img_num = len(slice_image_result)
for _ind in range(sub_img_num):
im = slice_image_result.images[_ind]
rec = {
'image': im,
'im_id': np.array([sub_img_ids + _ind]),
'h': im.shape[0],
'w': im.shape[1],
'ori_im_id': np.array([ori_rec['im_id'][0]]),
'st_pix': np.array(
slice_image_result.starting_pixels[_ind],
dtype=np.float32),
'is_last': 1 if _ind == sub_img_num - 1 else 0,
} if 'image' in self.data_fields else {}
records.append(rec)
ct_sub += sub_img_num
ct += 1
logger.info('{} samples and slice to {} sub_samples.'.format(ct,
ct_sub))
self.roidbs = records
def get_label_list(self):
# Only VOC dataset needs label list in ImageFold
return self.anno_path
@register
class CommonDataset(object):
def __init__(self, **dataset_args):
super(CommonDataset, self).__init__()
dataset_args = copy.deepcopy(dataset_args)
type = dataset_args.pop("name")
self.dataset = getattr(source, type)(**dataset_args)
def __call__(self):
return self.dataset
@register
class TrainDataset(CommonDataset):
pass
@register
class EvalMOTDataset(CommonDataset):
pass
@register
class TestMOTDataset(CommonDataset):
pass
@register
class EvalDataset(CommonDataset):
pass
@register
class TestDataset(CommonDataset):
pass

View File

@@ -0,0 +1,845 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
this code is base on https://github.com/open-mmlab/mmpose
"""
import os
import cv2
import numpy as np
import json
import copy
import pycocotools
from pycocotools.coco import COCO
from .dataset import DetDataset
from ppdet.core.workspace import register, serializable
@serializable
class KeypointBottomUpBaseDataset(DetDataset):
"""Base class for bottom-up datasets.
All datasets should subclass it.
All subclasses should overwrite:
Methods:`_get_imganno`
Args:
dataset_dir (str): Root path to the dataset.
anno_path (str): Relative path to the annotation file.
image_dir (str): Path to a directory where images are held.
Default: None.
num_joints (int): keypoint numbers
transform (composed(operators)): A sequence of data transforms.
shard (list): [rank, worldsize], the distributed env params
test_mode (bool): Store True when building test or
validation dataset. Default: False.
"""
def __init__(self,
dataset_dir,
image_dir,
anno_path,
num_joints,
transform=[],
shard=[0, 1],
test_mode=False):
super().__init__(dataset_dir, image_dir, anno_path)
self.image_info = {}
self.ann_info = {}
self.img_prefix = os.path.join(dataset_dir, image_dir)
self.transform = transform
self.test_mode = test_mode
self.ann_info['num_joints'] = num_joints
self.img_ids = []
def parse_dataset(self):
pass
def __len__(self):
"""Get dataset length."""
return len(self.img_ids)
def _get_imganno(self, idx):
"""Get anno for a single image."""
raise NotImplementedError
def __getitem__(self, idx):
"""Prepare image for training given the index."""
records = copy.deepcopy(self._get_imganno(idx))
records['image'] = cv2.imread(records['image_file'])
records['image'] = cv2.cvtColor(records['image'], cv2.COLOR_BGR2RGB)
if 'mask' in records:
records['mask'] = (records['mask'] + 0).astype('uint8')
records = self.transform(records)
return records
def parse_dataset(self):
return
@register
@serializable
class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
"""COCO dataset for bottom-up pose estimation.
The dataset loads raw features and apply specified transforms
to return a dict containing the image tensors and other information.
COCO keypoint indexes::
0: 'nose',
1: 'left_eye',
2: 'right_eye',
3: 'left_ear',
4: 'right_ear',
5: 'left_shoulder',
6: 'right_shoulder',
7: 'left_elbow',
8: 'right_elbow',
9: 'left_wrist',
10: 'right_wrist',
11: 'left_hip',
12: 'right_hip',
13: 'left_knee',
14: 'right_knee',
15: 'left_ankle',
16: 'right_ankle'
Args:
dataset_dir (str): Root path to the dataset.
anno_path (str): Relative path to the annotation file.
image_dir (str): Path to a directory where images are held.
Default: None.
num_joints (int): keypoint numbers
transform (composed(operators)): A sequence of data transforms.
shard (list): [rank, worldsize], the distributed env params
test_mode (bool): Store True when building test or
validation dataset. Default: False.
"""
def __init__(self,
dataset_dir,
image_dir,
anno_path,
num_joints,
transform=[],
shard=[0, 1],
test_mode=False,
return_mask=True,
return_bbox=True,
return_area=True,
return_class=True):
super().__init__(dataset_dir, image_dir, anno_path, num_joints,
transform, shard, test_mode)
self.ann_file = os.path.join(dataset_dir, anno_path)
self.shard = shard
self.test_mode = test_mode
self.return_mask = return_mask
self.return_bbox = return_bbox
self.return_area = return_area
self.return_class = return_class
def parse_dataset(self):
self.coco = COCO(self.ann_file)
self.img_ids = self.coco.getImgIds()
if not self.test_mode:
self.img_ids_tmp = []
for img_id in self.img_ids:
ann_ids = self.coco.getAnnIds(imgIds=img_id)
anno = self.coco.loadAnns(ann_ids)
anno = [obj for obj in anno if obj['iscrowd'] == 0]
if len(anno) == 0:
continue
self.img_ids_tmp.append(img_id)
self.img_ids = self.img_ids_tmp
blocknum = int(len(self.img_ids) / self.shard[1])
self.img_ids = self.img_ids[(blocknum * self.shard[0]):(blocknum * (
self.shard[0] + 1))]
self.num_images = len(self.img_ids)
self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
self.dataset_name = 'coco'
cat_ids = self.coco.getCatIds()
self.catid2clsid = dict({catid: i for i, catid in enumerate(cat_ids)})
print('=> num_images: {}'.format(self.num_images))
@staticmethod
def _get_mapping_id_name(imgs):
"""
Args:
imgs (dict): dict of image info.
Returns:
tuple: Image name & id mapping dicts.
- id2name (dict): Mapping image id to name.
- name2id (dict): Mapping image name to id.
"""
id2name = {}
name2id = {}
for image_id, image in imgs.items():
file_name = image['file_name']
id2name[image_id] = file_name
name2id[file_name] = image_id
return id2name, name2id
def _get_imganno(self, idx):
"""Get anno for a single image.
Args:
idx (int): image idx
Returns:
dict: info for model training
"""
coco = self.coco
img_id = self.img_ids[idx]
ann_ids = coco.getAnnIds(imgIds=img_id)
anno = coco.loadAnns(ann_ids)
anno = [
obj for obj in anno
if obj['iscrowd'] == 0 and obj['num_keypoints'] > 0
]
db_rec = {}
joints, orgsize = self._get_joints(anno, idx)
db_rec['gt_joints'] = joints
db_rec['im_shape'] = orgsize
if self.return_bbox:
db_rec['gt_bbox'] = self._get_bboxs(anno, idx)
if self.return_class:
db_rec['gt_class'] = self._get_labels(anno, idx)
if self.return_area:
db_rec['gt_areas'] = self._get_areas(anno, idx)
if self.return_mask:
db_rec['mask'] = self._get_mask(anno, idx)
db_rec['im_id'] = img_id
db_rec['image_file'] = os.path.join(self.img_prefix,
self.id2name[img_id])
return db_rec
def _get_joints(self, anno, idx):
"""Get joints for all people in an image."""
num_people = len(anno)
joints = np.zeros(
(num_people, self.ann_info['num_joints'], 3), dtype=np.float32)
for i, obj in enumerate(anno):
joints[i, :self.ann_info['num_joints'], :3] = \
np.array(obj['keypoints']).reshape([-1, 3])
img_info = self.coco.loadImgs(self.img_ids[idx])[0]
orgsize = np.array([img_info['height'], img_info['width'], 1])
return joints, orgsize
def _get_bboxs(self, anno, idx):
num_people = len(anno)
gt_bboxes = np.zeros((num_people, 4), dtype=np.float32)
for idx, obj in enumerate(anno):
if 'bbox' in obj:
gt_bboxes[idx, :] = obj['bbox']
gt_bboxes[:, 2] += gt_bboxes[:, 0]
gt_bboxes[:, 3] += gt_bboxes[:, 1]
return gt_bboxes
def _get_labels(self, anno, idx):
num_people = len(anno)
gt_labels = np.zeros((num_people, 1), dtype=np.float32)
for idx, obj in enumerate(anno):
if 'category_id' in obj:
catid = obj['category_id']
gt_labels[idx, 0] = self.catid2clsid[catid]
return gt_labels
def _get_areas(self, anno, idx):
num_people = len(anno)
gt_areas = np.zeros((num_people, ), dtype=np.float32)
for idx, obj in enumerate(anno):
if 'area' in obj:
gt_areas[idx, ] = obj['area']
return gt_areas
def _get_mask(self, anno, idx):
"""Get ignore masks to mask out losses."""
coco = self.coco
img_info = coco.loadImgs(self.img_ids[idx])[0]
m = np.zeros((img_info['height'], img_info['width']), dtype=np.float32)
for obj in anno:
if 'segmentation' in obj:
if obj['iscrowd']:
rle = pycocotools.mask.frPyObjects(obj['segmentation'],
img_info['height'],
img_info['width'])
m += pycocotools.mask.decode(rle)
elif obj['num_keypoints'] == 0:
rles = pycocotools.mask.frPyObjects(obj['segmentation'],
img_info['height'],
img_info['width'])
for rle in rles:
m += pycocotools.mask.decode(rle)
return m < 0.5
@register
@serializable
class KeypointBottomUpCrowdPoseDataset(KeypointBottomUpCocoDataset):
"""CrowdPose dataset for bottom-up pose estimation.
The dataset loads raw features and apply specified transforms
to return a dict containing the image tensors and other information.
CrowdPose keypoint indexes::
0: 'left_shoulder',
1: 'right_shoulder',
2: 'left_elbow',
3: 'right_elbow',
4: 'left_wrist',
5: 'right_wrist',
6: 'left_hip',
7: 'right_hip',
8: 'left_knee',
9: 'right_knee',
10: 'left_ankle',
11: 'right_ankle',
12: 'top_head',
13: 'neck'
Args:
dataset_dir (str): Root path to the dataset.
anno_path (str): Relative path to the annotation file.
image_dir (str): Path to a directory where images are held.
Default: None.
num_joints (int): keypoint numbers
transform (composed(operators)): A sequence of data transforms.
shard (list): [rank, worldsize], the distributed env params
test_mode (bool): Store True when building test or
validation dataset. Default: False.
"""
def __init__(self,
dataset_dir,
image_dir,
anno_path,
num_joints,
transform=[],
shard=[0, 1],
test_mode=False):
super().__init__(dataset_dir, image_dir, anno_path, num_joints,
transform, shard, test_mode)
self.ann_file = os.path.join(dataset_dir, anno_path)
self.shard = shard
self.test_mode = test_mode
def parse_dataset(self):
self.coco = COCO(self.ann_file)
self.img_ids = self.coco.getImgIds()
if not self.test_mode:
self.img_ids = [
img_id for img_id in self.img_ids
if len(self.coco.getAnnIds(
imgIds=img_id, iscrowd=None)) > 0
]
blocknum = int(len(self.img_ids) / self.shard[1])
self.img_ids = self.img_ids[(blocknum * self.shard[0]):(blocknum * (
self.shard[0] + 1))]
self.num_images = len(self.img_ids)
self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
self.dataset_name = 'crowdpose'
print('=> num_images: {}'.format(self.num_images))
@serializable
class KeypointTopDownBaseDataset(DetDataset):
"""Base class for top_down datasets.
All datasets should subclass it.
All subclasses should overwrite:
Methods:`_get_db`
Args:
dataset_dir (str): Root path to the dataset.
image_dir (str): Path to a directory where images are held.
anno_path (str): Relative path to the annotation file.
num_joints (int): keypoint numbers
transform (composed(operators)): A sequence of data transforms.
"""
def __init__(self,
dataset_dir,
image_dir,
anno_path,
num_joints,
transform=[]):
super().__init__(dataset_dir, image_dir, anno_path)
self.image_info = {}
self.ann_info = {}
self.img_prefix = os.path.join(dataset_dir, image_dir)
self.transform = transform
self.ann_info['num_joints'] = num_joints
self.db = []
def __len__(self):
"""Get dataset length."""
return len(self.db)
def _get_db(self):
"""Get a sample"""
raise NotImplementedError
def __getitem__(self, idx):
"""Prepare sample for training given the index."""
records = copy.deepcopy(self.db[idx])
records['image'] = cv2.imread(records['image_file'], cv2.IMREAD_COLOR |
cv2.IMREAD_IGNORE_ORIENTATION)
records['image'] = cv2.cvtColor(records['image'], cv2.COLOR_BGR2RGB)
records['score'] = records['score'] if 'score' in records else 1
records = self.transform(records)
# print('records', records)
return records
@register
@serializable
class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset):
"""COCO dataset for top-down pose estimation.
The dataset loads raw features and apply specified transforms
to return a dict containing the image tensors and other information.
COCO keypoint indexes:
0: 'nose',
1: 'left_eye',
2: 'right_eye',
3: 'left_ear',
4: 'right_ear',
5: 'left_shoulder',
6: 'right_shoulder',
7: 'left_elbow',
8: 'right_elbow',
9: 'left_wrist',
10: 'right_wrist',
11: 'left_hip',
12: 'right_hip',
13: 'left_knee',
14: 'right_knee',
15: 'left_ankle',
16: 'right_ankle'
Args:
dataset_dir (str): Root path to the dataset.
image_dir (str): Path to a directory where images are held.
anno_path (str): Relative path to the annotation file.
num_joints (int): Keypoint numbers
trainsize (list):[w, h] Image target size
transform (composed(operators)): A sequence of data transforms.
bbox_file (str): Path to a detection bbox file
Default: None.
use_gt_bbox (bool): Whether to use ground truth bbox
Default: True.
pixel_std (int): The pixel std of the scale
Default: 200.
image_thre (float): The threshold to filter the detection box
Default: 0.0.
"""
def __init__(self,
dataset_dir,
image_dir,
anno_path,
num_joints,
trainsize,
transform=[],
bbox_file=None,
use_gt_bbox=True,
pixel_std=200,
image_thre=0.0,
center_scale=None):
super().__init__(dataset_dir, image_dir, anno_path, num_joints,
transform)
self.bbox_file = bbox_file
self.use_gt_bbox = use_gt_bbox
self.trainsize = trainsize
self.pixel_std = pixel_std
self.image_thre = image_thre
self.center_scale = center_scale
self.dataset_name = 'coco'
def parse_dataset(self):
if self.use_gt_bbox:
self.db = self._load_coco_keypoint_annotations()
else:
self.db = self._load_coco_person_detection_results()
def _load_coco_keypoint_annotations(self):
coco = COCO(self.get_anno())
img_ids = coco.getImgIds()
gt_db = []
for index in img_ids:
im_ann = coco.loadImgs(index)[0]
width = im_ann['width']
height = im_ann['height']
file_name = im_ann['file_name']
im_id = int(im_ann["id"])
annIds = coco.getAnnIds(imgIds=index, iscrowd=False)
objs = coco.loadAnns(annIds)
valid_objs = []
for obj in objs:
x, y, w, h = obj['bbox']
x1 = np.max((0, x))
y1 = np.max((0, y))
x2 = np.min((width - 1, x1 + np.max((0, w - 1))))
y2 = np.min((height - 1, y1 + np.max((0, h - 1))))
if obj['area'] > 0 and x2 >= x1 and y2 >= y1:
obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
valid_objs.append(obj)
objs = valid_objs
rec = []
for obj in objs:
if max(obj['keypoints']) == 0:
continue
joints = np.zeros(
(self.ann_info['num_joints'], 3), dtype=np.float32)
joints_vis = np.zeros(
(self.ann_info['num_joints'], 3), dtype=np.float32)
for ipt in range(self.ann_info['num_joints']):
joints[ipt, 0] = obj['keypoints'][ipt * 3 + 0]
joints[ipt, 1] = obj['keypoints'][ipt * 3 + 1]
joints[ipt, 2] = 0
t_vis = obj['keypoints'][ipt * 3 + 2]
if t_vis > 1:
t_vis = 1
joints_vis[ipt, 0] = t_vis
joints_vis[ipt, 1] = t_vis
joints_vis[ipt, 2] = 0
center, scale = self._box2cs(obj['clean_bbox'][:4])
rec.append({
'image_file': os.path.join(self.img_prefix, file_name),
'center': center,
'scale': scale,
'gt_joints': joints,
'joints_vis': joints_vis,
'im_id': im_id,
})
gt_db.extend(rec)
return gt_db
def _box2cs(self, box):
x, y, w, h = box[:4]
center = np.zeros((2), dtype=np.float32)
center[0] = x + w * 0.5
center[1] = y + h * 0.5
aspect_ratio = self.trainsize[0] * 1.0 / self.trainsize[1]
if self.center_scale is not None and np.random.rand() < 0.3:
center += self.center_scale * (np.random.rand(2) - 0.5) * [w, h]
if w > aspect_ratio * h:
h = w * 1.0 / aspect_ratio
elif w < aspect_ratio * h:
w = h * aspect_ratio
scale = np.array(
[w * 1.0 / self.pixel_std, h * 1.0 / self.pixel_std],
dtype=np.float32)
if center[0] != -1:
scale = scale * 1.25
return center, scale
def _load_coco_person_detection_results(self):
all_boxes = None
bbox_file_path = os.path.join(self.dataset_dir, self.bbox_file)
with open(bbox_file_path, 'r') as f:
all_boxes = json.load(f)
if not all_boxes:
print('=> Load %s fail!' % bbox_file_path)
return None
kpt_db = []
for n_img in range(0, len(all_boxes)):
det_res = all_boxes[n_img]
if det_res['category_id'] != 1:
continue
file_name = det_res[
'filename'] if 'filename' in det_res else '%012d.jpg' % det_res[
'image_id']
img_name = os.path.join(self.img_prefix, file_name)
box = det_res['bbox']
score = det_res['score']
im_id = int(det_res['image_id'])
if score < self.image_thre:
continue
center, scale = self._box2cs(box)
joints = np.zeros(
(self.ann_info['num_joints'], 3), dtype=np.float32)
joints_vis = np.ones(
(self.ann_info['num_joints'], 3), dtype=np.float32)
kpt_db.append({
'image_file': img_name,
'im_id': im_id,
'center': center,
'scale': scale,
'score': score,
'gt_joints': joints,
'joints_vis': joints_vis,
})
return kpt_db
@register
@serializable
class KeypointTopDownCocoWholeBodyHandDataset(KeypointTopDownBaseDataset):
"""CocoWholeBody dataset for top-down hand pose estimation.
The dataset loads raw features and apply specified transforms
to return a dict containing the image tensors and other information.
COCO-WholeBody Hand keypoint indexes:
0: 'wrist',
1: 'thumb1',
2: 'thumb2',
3: 'thumb3',
4: 'thumb4',
5: 'forefinger1',
6: 'forefinger2',
7: 'forefinger3',
8: 'forefinger4',
9: 'middle_finger1',
10: 'middle_finger2',
11: 'middle_finger3',
12: 'middle_finger4',
13: 'ring_finger1',
14: 'ring_finger2',
15: 'ring_finger3',
16: 'ring_finger4',
17: 'pinky_finger1',
18: 'pinky_finger2',
19: 'pinky_finger3',
20: 'pinky_finger4'
Args:
dataset_dir (str): Root path to the dataset.
image_dir (str): Path to a directory where images are held.
anno_path (str): Relative path to the annotation file.
num_joints (int): Keypoint numbers
trainsize (list):[w, h] Image target size
transform (composed(operators)): A sequence of data transforms.
pixel_std (int): The pixel std of the scale
Default: 200.
"""
def __init__(self,
dataset_dir,
image_dir,
anno_path,
num_joints,
trainsize,
transform=[],
pixel_std=200):
super().__init__(dataset_dir, image_dir, anno_path, num_joints,
transform)
self.trainsize = trainsize
self.pixel_std = pixel_std
self.dataset_name = 'coco_wholebady_hand'
def _box2cs(self, box):
x, y, w, h = box[:4]
center = np.zeros((2), dtype=np.float32)
center[0] = x + w * 0.5
center[1] = y + h * 0.5
aspect_ratio = self.trainsize[0] * 1.0 / self.trainsize[1]
if w > aspect_ratio * h:
h = w * 1.0 / aspect_ratio
elif w < aspect_ratio * h:
w = h * aspect_ratio
scale = np.array(
[w * 1.0 / self.pixel_std, h * 1.0 / self.pixel_std],
dtype=np.float32)
if center[0] != -1:
scale = scale * 1.25
return center, scale
def parse_dataset(self):
gt_db = []
num_joints = self.ann_info['num_joints']
coco = COCO(self.get_anno())
img_ids = list(coco.imgs.keys())
for img_id in img_ids:
im_ann = coco.loadImgs(img_id)[0]
image_file = os.path.join(self.img_prefix, im_ann['file_name'])
im_id = int(im_ann["id"])
ann_ids = coco.getAnnIds(imgIds=img_id, iscrowd=False)
objs = coco.loadAnns(ann_ids)
for obj in objs:
for type in ['left', 'right']:
if (obj[f'{type}hand_valid'] and
max(obj[f'{type}hand_kpts']) > 0):
joints = np.zeros((num_joints, 3), dtype=np.float32)
joints_vis = np.zeros((num_joints, 3), dtype=np.float32)
keypoints = np.array(obj[f'{type}hand_kpts'])
keypoints = keypoints.reshape(-1, 3)
joints[:, :2] = keypoints[:, :2]
joints_vis[:, :2] = np.minimum(1, keypoints[:, 2:3])
center, scale = self._box2cs(obj[f'{type}hand_box'][:4])
gt_db.append({
'image_file': image_file,
'center': center,
'scale': scale,
'gt_joints': joints,
'joints_vis': joints_vis,
'im_id': im_id,
})
self.db = gt_db
@register
@serializable
class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset):
"""MPII dataset for topdown pose estimation.
The dataset loads raw features and apply specified transforms
to return a dict containing the image tensors and other information.
MPII keypoint indexes::
0: 'right_ankle',
1: 'right_knee',
2: 'right_hip',
3: 'left_hip',
4: 'left_knee',
5: 'left_ankle',
6: 'pelvis',
7: 'thorax',
8: 'upper_neck',
9: 'head_top',
10: 'right_wrist',
11: 'right_elbow',
12: 'right_shoulder',
13: 'left_shoulder',
14: 'left_elbow',
15: 'left_wrist',
Args:
dataset_dir (str): Root path to the dataset.
image_dir (str): Path to a directory where images are held.
anno_path (str): Relative path to the annotation file.
num_joints (int): Keypoint numbers
trainsize (list):[w, h] Image target size
transform (composed(operators)): A sequence of data transforms.
"""
def __init__(self,
dataset_dir,
image_dir,
anno_path,
num_joints,
transform=[]):
super().__init__(dataset_dir, image_dir, anno_path, num_joints,
transform)
self.dataset_name = 'mpii'
def parse_dataset(self):
with open(self.get_anno()) as anno_file:
anno = json.load(anno_file)
gt_db = []
for a in anno:
image_name = a['image']
im_id = a['image_id'] if 'image_id' in a else int(
os.path.splitext(image_name)[0])
c = np.array(a['center'], dtype=np.float32)
s = np.array([a['scale'], a['scale']], dtype=np.float32)
# Adjust center/scale slightly to avoid cropping limbs
if c[0] != -1:
c[1] = c[1] + 15 * s[1]
s = s * 1.25
c = c - 1
joints = np.zeros(
(self.ann_info['num_joints'], 3), dtype=np.float32)
joints_vis = np.zeros(
(self.ann_info['num_joints'], 3), dtype=np.float32)
if 'gt_joints' in a:
joints_ = np.array(a['gt_joints'])
joints_[:, 0:2] = joints_[:, 0:2] - 1
joints_vis_ = np.array(a['joints_vis'])
assert len(joints_) == self.ann_info[
'num_joints'], 'joint num diff: {} vs {}'.format(
len(joints_), self.ann_info['num_joints'])
joints[:, 0:2] = joints_[:, 0:2]
joints_vis[:, 0] = joints_vis_[:]
joints_vis[:, 1] = joints_vis_[:]
gt_db.append({
'image_file': os.path.join(self.img_prefix, image_name),
'im_id': im_id,
'center': c,
'scale': s,
'gt_joints': joints,
'joints_vis': joints_vis
})
print("number length: {}".format(len(gt_db)))
self.db = gt_db

View File

@@ -0,0 +1,638 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import cv2
import glob
import numpy as np
from collections import OrderedDict, defaultdict
try:
from collections.abc import Sequence
except Exception:
from collections import Sequence
from .dataset import DetDataset, _make_dataset, _is_valid_file
from ppdet.core.workspace import register, serializable
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
@register
@serializable
class MOTDataSet(DetDataset):
"""
Load dataset with MOT format, only support single class MOT.
Args:
dataset_dir (str): root directory for dataset.
image_lists (str|list): mot data image lists, muiti-source mot dataset.
data_fields (list): key name of data dictionary, at least have 'image'.
sample_num (int): number of samples to load, -1 means all.
repeat (int): repeat times for dataset, use in benchmark.
Notes:
MOT datasets root directory following this:
dataset/mot
|——————image_lists
| |——————caltech.train
| |——————caltech.val
| |——————mot16.train
| |——————mot17.train
| ......
|——————Caltech
|——————MOT17
|——————......
All the MOT datasets have the following structure:
Caltech
|——————images
| └——————00001.jpg
| |—————— ...
| └——————0000N.jpg
└——————labels_with_ids
└——————00001.txt
|—————— ...
└——————0000N.txt
or
MOT17
|——————images
| └——————train
| └——————test
└——————labels_with_ids
└——————train
"""
def __init__(self,
dataset_dir=None,
image_lists=[],
data_fields=['image'],
sample_num=-1,
repeat=1):
super(MOTDataSet, self).__init__(
dataset_dir=dataset_dir,
data_fields=data_fields,
sample_num=sample_num,
repeat=repeat)
self.dataset_dir = dataset_dir
self.image_lists = image_lists
if isinstance(self.image_lists, str):
self.image_lists = [self.image_lists]
self.roidbs = None
self.cname2cid = None
def get_anno(self):
if self.image_lists == []:
return
# only used to get categories and metric
# only check first data, but the label_list of all data should be same.
first_mot_data = self.image_lists[0].split('.')[0]
anno_file = os.path.join(self.dataset_dir, first_mot_data,
'label_list.txt')
return anno_file
def parse_dataset(self):
self.img_files = OrderedDict()
self.img_start_index = OrderedDict()
self.label_files = OrderedDict()
self.tid_num = OrderedDict()
self.tid_start_index = OrderedDict()
img_index = 0
for data_name in self.image_lists:
# check every data image list
image_lists_dir = os.path.join(self.dataset_dir, 'image_lists')
assert os.path.isdir(image_lists_dir), \
"The {} is not a directory.".format(image_lists_dir)
list_path = os.path.join(image_lists_dir, data_name)
assert os.path.exists(list_path), \
"The list path {} does not exist.".format(list_path)
# record img_files, filter out empty ones
with open(list_path, 'r') as file:
self.img_files[data_name] = file.readlines()
self.img_files[data_name] = [
os.path.join(self.dataset_dir, x.strip())
for x in self.img_files[data_name]
]
self.img_files[data_name] = list(
filter(lambda x: len(x) > 0, self.img_files[data_name]))
self.img_start_index[data_name] = img_index
img_index += len(self.img_files[data_name])
# record label_files
self.label_files[data_name] = [
x.replace('images', 'labels_with_ids').replace(
'.png', '.txt').replace('.jpg', '.txt')
for x in self.img_files[data_name]
]
for data_name, label_paths in self.label_files.items():
max_index = -1
for lp in label_paths:
lb = np.loadtxt(lp)
if len(lb) < 1:
continue
if len(lb.shape) < 2:
img_max = lb[1]
else:
img_max = np.max(lb[:, 1])
if img_max > max_index:
max_index = img_max
self.tid_num[data_name] = int(max_index + 1)
last_index = 0
for i, (k, v) in enumerate(self.tid_num.items()):
self.tid_start_index[k] = last_index
last_index += v
self.num_identities_dict = defaultdict(int)
self.num_identities_dict[0] = int(last_index + 1) # single class
self.num_imgs_each_data = [len(x) for x in self.img_files.values()]
self.total_imgs = sum(self.num_imgs_each_data)
logger.info('MOT dataset summary: ')
logger.info(self.tid_num)
logger.info('Total images: {}'.format(self.total_imgs))
logger.info('Image start index: {}'.format(self.img_start_index))
logger.info('Total identities: {}'.format(self.num_identities_dict[0]))
logger.info('Identity start index: {}'.format(self.tid_start_index))
records = []
cname2cid = mot_label()
for img_index in range(self.total_imgs):
for i, (k, v) in enumerate(self.img_start_index.items()):
if img_index >= v:
data_name = list(self.label_files.keys())[i]
start_index = v
img_file = self.img_files[data_name][img_index - start_index]
lbl_file = self.label_files[data_name][img_index - start_index]
if not os.path.exists(img_file):
logger.warning('Illegal image file: {}, and it will be ignored'.
format(img_file))
continue
if not os.path.isfile(lbl_file):
logger.warning('Illegal label file: {}, and it will be ignored'.
format(lbl_file))
continue
labels = np.loadtxt(lbl_file, dtype=np.float32).reshape(-1, 6)
# each row in labels (N, 6) is [gt_class, gt_identity, cx, cy, w, h]
cx, cy = labels[:, 2], labels[:, 3]
w, h = labels[:, 4], labels[:, 5]
gt_bbox = np.stack((cx, cy, w, h)).T.astype('float32')
gt_class = labels[:, 0:1].astype('int32')
gt_score = np.ones((len(labels), 1)).astype('float32')
gt_ide = labels[:, 1:2].astype('int32')
for i, _ in enumerate(gt_ide):
if gt_ide[i] > -1:
gt_ide[i] += self.tid_start_index[data_name]
mot_rec = {
'im_file': img_file,
'im_id': img_index,
} if 'image' in self.data_fields else {}
gt_rec = {
'gt_class': gt_class,
'gt_score': gt_score,
'gt_bbox': gt_bbox,
'gt_ide': gt_ide,
}
for k, v in gt_rec.items():
if k in self.data_fields:
mot_rec[k] = v
records.append(mot_rec)
if self.sample_num > 0 and img_index >= self.sample_num:
break
assert len(records) > 0, 'not found any mot record in %s' % (
self.image_lists)
self.roidbs, self.cname2cid = records, cname2cid
@register
@serializable
class MCMOTDataSet(DetDataset):
"""
Load dataset with MOT format, support multi-class MOT.
Args:
dataset_dir (str): root directory for dataset.
image_lists (list(str)): mcmot data image lists, muiti-source mcmot dataset.
data_fields (list): key name of data dictionary, at least have 'image'.
label_list (str): if use_default_label is False, will load
mapping between category and class index.
sample_num (int): number of samples to load, -1 means all.
Notes:
MCMOT datasets root directory following this:
dataset/mot
|——————image_lists
| |——————visdrone_mcmot.train
| |——————visdrone_mcmot.val
visdrone_mcmot
|——————images
| └——————train
| └——————val
└——————labels_with_ids
└——————train
"""
def __init__(self,
dataset_dir=None,
image_lists=[],
data_fields=['image'],
label_list=None,
sample_num=-1):
super(MCMOTDataSet, self).__init__(
dataset_dir=dataset_dir,
data_fields=data_fields,
sample_num=sample_num)
self.dataset_dir = dataset_dir
self.image_lists = image_lists
if isinstance(self.image_lists, str):
self.image_lists = [self.image_lists]
self.label_list = label_list
self.roidbs = None
self.cname2cid = None
def get_anno(self):
if self.image_lists == []:
return
# only used to get categories and metric
# only check first data, but the label_list of all data should be same.
first_mot_data = self.image_lists[0].split('.')[0]
anno_file = os.path.join(self.dataset_dir, first_mot_data,
'label_list.txt')
return anno_file
def parse_dataset(self):
self.img_files = OrderedDict()
self.img_start_index = OrderedDict()
self.label_files = OrderedDict()
self.tid_num = OrderedDict()
self.tid_start_idx_of_cls_ids = defaultdict(dict) # for MCMOT
img_index = 0
for data_name in self.image_lists:
# check every data image list
image_lists_dir = os.path.join(self.dataset_dir, 'image_lists')
assert os.path.isdir(image_lists_dir), \
"The {} is not a directory.".format(image_lists_dir)
list_path = os.path.join(image_lists_dir, data_name)
assert os.path.exists(list_path), \
"The list path {} does not exist.".format(list_path)
# record img_files, filter out empty ones
with open(list_path, 'r') as file:
self.img_files[data_name] = file.readlines()
self.img_files[data_name] = [
os.path.join(self.dataset_dir, x.strip())
for x in self.img_files[data_name]
]
self.img_files[data_name] = list(
filter(lambda x: len(x) > 0, self.img_files[data_name]))
self.img_start_index[data_name] = img_index
img_index += len(self.img_files[data_name])
# record label_files
self.label_files[data_name] = [
x.replace('images', 'labels_with_ids').replace(
'.png', '.txt').replace('.jpg', '.txt')
for x in self.img_files[data_name]
]
for data_name, label_paths in self.label_files.items():
# using max_ids_dict rather than max_index
max_ids_dict = defaultdict(int)
for lp in label_paths:
lb = np.loadtxt(lp)
if len(lb) < 1:
continue
lb = lb.reshape(-1, 6)
for item in lb:
if item[1] > max_ids_dict[int(item[0])]:
# item[0]: cls_id
# item[1]: track id
max_ids_dict[int(item[0])] = int(item[1])
# track id number
self.tid_num[data_name] = max_ids_dict
last_idx_dict = defaultdict(int)
for i, (k, v) in enumerate(self.tid_num.items()): # each sub dataset
for cls_id, id_num in v.items(): # v is a max_ids_dict
self.tid_start_idx_of_cls_ids[k][cls_id] = last_idx_dict[cls_id]
last_idx_dict[cls_id] += id_num
self.num_identities_dict = defaultdict(int)
for k, v in last_idx_dict.items():
self.num_identities_dict[k] = int(v) # total ids of each category
self.num_imgs_each_data = [len(x) for x in self.img_files.values()]
self.total_imgs = sum(self.num_imgs_each_data)
# cname2cid and cid2cname
cname2cid = {}
if self.label_list is not None:
# if use label_list for multi source mix dataset,
# please make sure label_list in the first sub_dataset at least.
sub_dataset = self.image_lists[0].split('.')[0]
label_path = os.path.join(self.dataset_dir, sub_dataset,
self.label_list)
if not os.path.exists(label_path):
logger.info(
"Note: label_list {} does not exists, use VisDrone 10 classes labels as default.".
format(label_path))
cname2cid = visdrone_mcmot_label()
else:
with open(label_path, 'r') as fr:
label_id = 0
for line in fr.readlines():
cname2cid[line.strip()] = label_id
label_id += 1
else:
cname2cid = visdrone_mcmot_label()
cid2cname = dict([(v, k) for (k, v) in cname2cid.items()])
logger.info('MCMOT dataset summary: ')
logger.info(self.tid_num)
logger.info('Total images: {}'.format(self.total_imgs))
logger.info('Image start index: {}'.format(self.img_start_index))
logger.info('Total identities of each category: ')
num_identities_dict = sorted(
self.num_identities_dict.items(), key=lambda x: x[0])
total_IDs_all_cats = 0
for (k, v) in num_identities_dict:
logger.info('Category {} [{}] has {} IDs.'.format(k, cid2cname[k],
v))
total_IDs_all_cats += v
logger.info('Total identities of all categories: {}'.format(
total_IDs_all_cats))
logger.info('Identity start index of each category: ')
for k, v in self.tid_start_idx_of_cls_ids.items():
sorted_v = sorted(v.items(), key=lambda x: x[0])
for (cls_id, start_idx) in sorted_v:
logger.info('Start index of dataset {} category {:d} is {:d}'
.format(k, cls_id, start_idx))
records = []
for img_index in range(self.total_imgs):
for i, (k, v) in enumerate(self.img_start_index.items()):
if img_index >= v:
data_name = list(self.label_files.keys())[i]
start_index = v
img_file = self.img_files[data_name][img_index - start_index]
lbl_file = self.label_files[data_name][img_index - start_index]
if not os.path.exists(img_file):
logger.warning('Illegal image file: {}, and it will be ignored'.
format(img_file))
continue
if not os.path.isfile(lbl_file):
logger.warning('Illegal label file: {}, and it will be ignored'.
format(lbl_file))
continue
labels = np.loadtxt(lbl_file, dtype=np.float32).reshape(-1, 6)
# each row in labels (N, 6) is [gt_class, gt_identity, cx, cy, w, h]
cx, cy = labels[:, 2], labels[:, 3]
w, h = labels[:, 4], labels[:, 5]
gt_bbox = np.stack((cx, cy, w, h)).T.astype('float32')
gt_class = labels[:, 0:1].astype('int32')
gt_score = np.ones((len(labels), 1)).astype('float32')
gt_ide = labels[:, 1:2].astype('int32')
for i, _ in enumerate(gt_ide):
if gt_ide[i] > -1:
cls_id = int(gt_class[i])
start_idx = self.tid_start_idx_of_cls_ids[data_name][cls_id]
gt_ide[i] += start_idx
mot_rec = {
'im_file': img_file,
'im_id': img_index,
} if 'image' in self.data_fields else {}
gt_rec = {
'gt_class': gt_class,
'gt_score': gt_score,
'gt_bbox': gt_bbox,
'gt_ide': gt_ide,
}
for k, v in gt_rec.items():
if k in self.data_fields:
mot_rec[k] = v
records.append(mot_rec)
if self.sample_num > 0 and img_index >= self.sample_num:
break
assert len(records) > 0, 'not found any mot record in %s' % (
self.image_lists)
self.roidbs, self.cname2cid = records, cname2cid
@register
@serializable
class MOTImageFolder(DetDataset):
"""
Load MOT dataset with MOT format from image folder or video .
Args:
video_file (str): path of the video file, default ''.
frame_rate (int): frame rate of the video, use cv2 VideoCapture if not set.
dataset_dir (str): root directory for dataset.
keep_ori_im (bool): whether to keep original image, default False.
Set True when used during MOT model inference while saving
images or video, or used in DeepSORT.
"""
def __init__(self,
video_file=None,
frame_rate=-1,
dataset_dir=None,
data_root=None,
image_dir=None,
sample_num=-1,
keep_ori_im=False,
anno_path=None,
**kwargs):
super(MOTImageFolder, self).__init__(
dataset_dir, image_dir, sample_num=sample_num)
self.video_file = video_file
self.data_root = data_root
self.keep_ori_im = keep_ori_im
self._imid2path = {}
self.roidbs = None
self.frame_rate = frame_rate
self.anno_path = anno_path
def check_or_download_dataset(self):
return
def parse_dataset(self, ):
if not self.roidbs:
if self.video_file is None:
self.frame_rate = 30 # set as default if infer image folder
self.roidbs = self._load_images()
else:
self.roidbs = self._load_video_images()
def _load_video_images(self):
if self.frame_rate == -1:
# if frame_rate is not set for video, use cv2.VideoCapture
cap = cv2.VideoCapture(self.video_file)
self.frame_rate = int(cap.get(cv2.CAP_PROP_FPS))
extension = self.video_file.split('.')[-1]
output_path = self.video_file.replace('.{}'.format(extension), '')
frames_path = video2frames(self.video_file, output_path,
self.frame_rate)
self.video_frames = sorted(
glob.glob(os.path.join(frames_path, '*.png')))
self.video_length = len(self.video_frames)
logger.info('Length of the video: {:d} frames.'.format(
self.video_length))
ct = 0
records = []
for image in self.video_frames:
assert image != '' and os.path.isfile(image), \
"Image {} not found".format(image)
if self.sample_num > 0 and ct >= self.sample_num:
break
rec = {'im_id': np.array([ct]), 'im_file': image}
if self.keep_ori_im:
rec.update({'keep_ori_im': 1})
self._imid2path[ct] = image
ct += 1
records.append(rec)
assert len(records) > 0, "No image file found"
return records
def _find_images(self):
image_dir = self.image_dir
if not isinstance(image_dir, Sequence):
image_dir = [image_dir]
images = []
for im_dir in image_dir:
if os.path.isdir(im_dir):
im_dir = os.path.join(self.dataset_dir, im_dir)
images.extend(_make_dataset(im_dir))
elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
images.append(im_dir)
return images
def _load_images(self):
images = self._find_images()
ct = 0
records = []
for image in images:
assert image != '' and os.path.isfile(image), \
"Image {} not found".format(image)
if self.sample_num > 0 and ct >= self.sample_num:
break
rec = {'im_id': np.array([ct]), 'im_file': image}
if self.keep_ori_im:
rec.update({'keep_ori_im': 1})
self._imid2path[ct] = image
ct += 1
records.append(rec)
assert len(records) > 0, "No image file found"
return records
def get_imid2path(self):
return self._imid2path
def set_images(self, images):
self.image_dir = images
self.roidbs = self._load_images()
def set_video(self, video_file, frame_rate):
# update video_file and frame_rate by command line of tools/infer_mot.py
self.video_file = video_file
self.frame_rate = frame_rate
assert os.path.isfile(self.video_file) and _is_valid_video(self.video_file), \
"wrong or unsupported file format: {}".format(self.video_file)
self.roidbs = self._load_video_images()
def get_anno(self):
return self.anno_path
def _is_valid_video(f, extensions=('.mp4', '.avi', '.mov', '.rmvb', 'flv')):
return f.lower().endswith(extensions)
def video2frames(video_path, outpath, frame_rate, **kargs):
def _dict2str(kargs):
cmd_str = ''
for k, v in kargs.items():
cmd_str += (' ' + str(k) + ' ' + str(v))
return cmd_str
ffmpeg = ['ffmpeg ', ' -y -loglevel ', ' error ']
vid_name = os.path.basename(video_path).split('.')[0]
out_full_path = os.path.join(outpath, vid_name)
if not os.path.exists(out_full_path):
os.makedirs(out_full_path)
# video file name
outformat = os.path.join(out_full_path, '%08d.png')
cmd = ffmpeg
cmd = ffmpeg + [
' -i ', video_path, ' -r ', str(frame_rate), ' -f image2 ', outformat
]
cmd = ''.join(cmd) + _dict2str(kargs)
if os.system(cmd) != 0:
raise RuntimeError('ffmpeg process video: {} error'.format(video_path))
sys.exit(-1)
sys.stdout.flush()
return out_full_path
def mot_label():
labels_map = {'person': 0}
return labels_map
def visdrone_mcmot_label():
labels_map = {
'pedestrian': 0,
'people': 1,
'bicycle': 2,
'car': 3,
'van': 4,
'truck': 5,
'tricycle': 6,
'awning-tricycle': 7,
'bus': 8,
'motor': 9,
}
return labels_map

View File

@@ -0,0 +1,380 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import numpy as np
import json
import copy
import pycocotools
from pycocotools.coco import COCO
from .dataset import DetDataset
from ppdet.core.workspace import register, serializable
from paddle.io import Dataset
@serializable
class Pose3DDataset(DetDataset):
"""Pose3D Dataset class.
Args:
dataset_dir (str): Root path to the dataset.
anno_list (list of str): each of the element is a relative path to the annotation file.
image_dirs (list of str): each of path is a relative path where images are held.
transform (composed(operators)): A sequence of data transforms.
test_mode (bool): Store True when building test or
validation dataset. Default: False.
24 joints order:
0-2: 'R_Ankle', 'R_Knee', 'R_Hip',
3-5:'L_Hip', 'L_Knee', 'L_Ankle',
6-8:'R_Wrist', 'R_Elbow', 'R_Shoulder',
9-11:'L_Shoulder','L_Elbow','L_Wrist',
12-14:'Neck','Top_of_Head','Pelvis',
15-18:'Thorax','Spine','Jaw','Head',
19-23:'Nose','L_Eye','R_Eye','L_Ear','R_Ear'
"""
def __init__(self,
dataset_dir,
image_dirs,
anno_list,
transform=[],
num_joints=24,
test_mode=False):
super().__init__(dataset_dir, image_dirs, anno_list)
self.image_info = {}
self.ann_info = {}
self.num_joints = num_joints
self.transform = transform
self.test_mode = test_mode
self.img_ids = []
self.dataset_dir = dataset_dir
self.image_dirs = image_dirs
self.anno_list = anno_list
def get_mask(self, mvm_percent=0.3):
num_joints = self.num_joints
mjm_mask = np.ones((num_joints, 1)).astype(np.float32)
if self.test_mode == False:
pb = np.random.random_sample()
masked_num = int(
pb * mvm_percent *
num_joints) # at most x% of the joints could be masked
indices = np.random.choice(
np.arange(num_joints), replace=False, size=masked_num)
mjm_mask[indices, :] = 0.0
# return mjm_mask
num_joints = 10
mvm_mask = np.ones((num_joints, 1)).astype(np.float)
if self.test_mode == False:
num_vertices = num_joints
pb = np.random.random_sample()
masked_num = int(
pb * mvm_percent *
num_vertices) # at most x% of the vertices could be masked
indices = np.random.choice(
np.arange(num_vertices), replace=False, size=masked_num)
mvm_mask[indices, :] = 0.0
mjm_mask = np.concatenate([mjm_mask, mvm_mask], axis=0)
return mjm_mask
def filterjoints(self, x):
if self.num_joints == 24:
return x
elif self.num_joints == 14:
return x[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18], :]
elif self.num_joints == 17:
return x[
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 18, 19], :]
else:
raise ValueError(
"unsupported joint numbers, only [24 or 17 or 14] is supported!")
def parse_dataset(self):
print("Loading annotations..., please wait")
self.annos = []
im_id = 0
self.human36m_num = 0
for idx, annof in enumerate(self.anno_list):
img_prefix = os.path.join(self.dataset_dir, self.image_dirs[idx])
dataf = os.path.join(self.dataset_dir, annof)
with open(dataf, 'r') as rf:
anno_data = json.load(rf)
annos = anno_data['data']
new_annos = []
print("{} has annos numbers: {}".format(dataf, len(annos)))
for anno in annos:
new_anno = {}
new_anno['im_id'] = im_id
im_id += 1
imagename = anno['imageName']
if imagename.startswith("COCO_train2014_"):
imagename = imagename[len("COCO_train2014_"):]
elif imagename.startswith("COCO_val2014_"):
imagename = imagename[len("COCO_val2014_"):]
imagename = os.path.join(img_prefix, imagename)
if not os.path.exists(imagename):
if "train2017" in imagename:
imagename = imagename.replace("train2017",
"val2017")
if not os.path.exists(imagename):
print("cannot find imagepath:{}".format(
imagename))
continue
else:
print("cannot find imagepath:{}".format(imagename))
continue
new_anno['imageName'] = imagename
if 'human3.6m' in imagename:
self.human36m_num += 1
new_anno['bbox_center'] = anno['bbox_center']
new_anno['bbox_scale'] = anno['bbox_scale']
new_anno['joints_2d'] = np.array(anno[
'gt_keypoint_2d']).astype(np.float32)
if new_anno['joints_2d'].shape[0] == 49:
#if the joints_2d is in SPIN format(which generated by eft), choose the last 24 public joints
#for detail please refer: https://github.com/nkolot/SPIN/blob/master/constants.py
new_anno['joints_2d'] = new_anno['joints_2d'][25:]
new_anno['joints_3d'] = np.array(anno[
'pose3d'])[:, :3].astype(np.float32)
new_anno['mjm_mask'] = self.get_mask()
if not 'has_3d_joints' in anno:
new_anno['has_3d_joints'] = int(1)
new_anno['has_2d_joints'] = int(1)
else:
new_anno['has_3d_joints'] = int(anno['has_3d_joints'])
new_anno['has_2d_joints'] = int(anno['has_2d_joints'])
new_anno['joints_2d'] = self.filterjoints(new_anno[
'joints_2d'])
self.annos.append(new_anno)
del annos
def get_temp_num(self):
"""get temporal data number, like human3.6m"""
return self.human36m_num
def __len__(self):
"""Get dataset length."""
return len(self.annos)
def _get_imganno(self, idx):
"""Get anno for a single image."""
return self.annos[idx]
def __getitem__(self, idx):
"""Prepare image for training given the index."""
records = copy.deepcopy(self._get_imganno(idx))
imgpath = records['imageName']
assert os.path.exists(imgpath), "cannot find image {}".format(imgpath)
records['image'] = cv2.imread(imgpath)
records['image'] = cv2.cvtColor(records['image'], cv2.COLOR_BGR2RGB)
records = self.transform(records)
return records
def check_or_download_dataset(self):
alldatafind = True
for image_dir in self.image_dirs:
image_dir = os.path.join(self.dataset_dir, image_dir)
if not os.path.isdir(image_dir):
print("dataset [{}] is not found".format(image_dir))
alldatafind = False
if not alldatafind:
raise ValueError(
"Some dataset is not valid and cannot download automatically now, please prepare the dataset first"
)
@register
@serializable
class Keypoint3DMultiFramesDataset(Dataset):
"""24 keypoints 3D dataset for pose estimation.
each item is a list of images
The dataset loads raw features and apply specified transforms
to return a dict containing the image tensors and other information.
Args:
dataset_dir (str): Root path to the dataset.
image_dir (str): Path to a directory where images are held.
"""
def __init__(
self,
dataset_dir, # 数据集根目录
image_dir, # 图像文件夹
p3d_dir, # 3D关键点文件夹
json_path,
img_size, #图像resize大小
num_frames, # 帧序列长度
anno_path=None, ):
self.dataset_dir = dataset_dir
self.image_dir = image_dir
self.p3d_dir = p3d_dir
self.json_path = json_path
self.img_size = img_size
self.num_frames = num_frames
self.anno_path = anno_path
self.data_labels, self.mf_inds = self._generate_multi_frames_list()
def _generate_multi_frames_list(self):
act_list = os.listdir(self.dataset_dir) # 动作列表
count = 0
mf_list = []
annos_dict = {'images': [], 'annotations': [], 'act_inds': []}
for act in act_list: #对每个动作,生成帧序列
if '.' in act:
continue
json_path = os.path.join(self.dataset_dir, act, self.json_path)
with open(json_path, 'r') as j:
annos = json.load(j)
length = len(annos['images'])
for k, v in annos.items():
if k in annos_dict:
annos_dict[k].extend(v)
annos_dict['act_inds'].extend([act] * length)
mf = [[i + j + count for j in range(self.num_frames)]
for i in range(0, length - self.num_frames + 1)]
mf_list.extend(mf)
count += length
print("total data number:", len(mf_list))
return annos_dict, mf_list
def __call__(self, *args, **kwargs):
return self
def __getitem__(self, index): # 拿一个连续的序列
inds = self.mf_inds[
index] # 如[568, 569, 570, 571, 572, 573]长度为num_frames
images = self.data_labels['images'] # all images
annots = self.data_labels['annotations'] # all annots
act = self.data_labels['act_inds'][inds[0]] # 动作名(文件夹名)
kps3d_list = []
kps3d_vis_list = []
names = []
h, w = 0, 0
for ind in inds: # one image
height = float(images[ind]['height'])
width = float(images[ind]['width'])
name = images[ind]['file_name'] # 图像名称,带有后缀
kps3d_name = name.split('.')[0] + '.obj'
kps3d_path = os.path.join(self.dataset_dir, act, self.p3d_dir,
kps3d_name)
joints, joints_vis = self.kps3d_process(kps3d_path)
joints_vis = np.array(joints_vis, dtype=np.float32)
kps3d_list.append(joints)
kps3d_vis_list.append(joints_vis)
names.append(name)
kps3d = np.array(kps3d_list) # (6, 24, 3),(num_frames, joints_num, 3)
kps3d_vis = np.array(kps3d_vis_list)
# read image
imgs = []
for name in names:
img_path = os.path.join(self.dataset_dir, act, self.image_dir, name)
image = cv2.imread(img_path, cv2.IMREAD_COLOR |
cv2.IMREAD_IGNORE_ORIENTATION)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
imgs.append(np.expand_dims(image, axis=0))
imgs = np.concatenate(imgs, axis=0)
imgs = imgs.astype(
np.float32) # (6, 1080, 1920, 3),(num_frames, h, w, c)
# attention: 此时图像和标注是镜像的
records = {
'kps3d': kps3d,
'kps3d_vis': kps3d_vis,
"image": imgs,
'act': act,
'names': names,
'im_id': index
}
return self.transform(records)
def kps3d_process(self, kps3d_path):
count = 0
kps = []
kps_vis = []
with open(kps3d_path, 'r') as f:
lines = f.readlines()
for line in lines:
if line[0] == 'v':
kps.append([])
line = line.strip('\n').split(' ')[1:]
for kp in line:
kps[-1].append(float(kp))
count += 1
kps_vis.append([1, 1, 1])
kps = np.array(kps) # 523
kps_vis = np.array(kps_vis)
kps *= 10 # scale points
kps -= kps[[0], :] # set root point to zero
kps = np.concatenate((kps[0:23], kps[[37]]), axis=0) # 24,3
kps *= 10
kps_vis = np.concatenate((kps_vis[0:23], kps_vis[[37]]), axis=0) # 24,3
return kps, kps_vis
def __len__(self):
return len(self.mf_inds)
def get_anno(self):
if self.anno_path is None:
return
return os.path.join(self.dataset_dir, self.anno_path)
def check_or_download_dataset(self):
return
def parse_dataset(self, ):
return
def set_transform(self, transform):
self.transform = transform
def set_epoch(self, epoch_id):
self._epoch = epoch_id
def set_kwargs(self, **kwargs):
self.mixup_epoch = kwargs.get('mixup_epoch', -1)
self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)

View File

@@ -0,0 +1,194 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import json
import copy
import numpy as np
try:
from collections.abc import Sequence
except Exception:
from collections import Sequence
from ppdet.core.workspace import register, serializable
from ppdet.data.crop_utils.annotation_cropper import AnnoCropper
from .coco import COCODataSet
from .dataset import _make_dataset, _is_valid_file
from ppdet.utils.logger import setup_logger
logger = setup_logger('sniper_coco_dataset')
@register
@serializable
class SniperCOCODataSet(COCODataSet):
"""SniperCOCODataSet"""
def __init__(self,
dataset_dir=None,
image_dir=None,
anno_path=None,
proposals_file=None,
data_fields=['image'],
sample_num=-1,
load_crowd=False,
allow_empty=True,
empty_ratio=1.,
is_trainset=True,
image_target_sizes=[2000, 1000],
valid_box_ratio_ranges=[[-1, 0.1],[0.08, -1]],
chip_target_size=500,
chip_target_stride=200,
use_neg_chip=False,
max_neg_num_per_im=8,
max_per_img=-1,
nms_thresh=0.5):
super(SniperCOCODataSet, self).__init__(
dataset_dir=dataset_dir,
image_dir=image_dir,
anno_path=anno_path,
data_fields=data_fields,
sample_num=sample_num,
load_crowd=load_crowd,
allow_empty=allow_empty,
empty_ratio=empty_ratio
)
self.proposals_file = proposals_file
self.proposals = None
self.anno_cropper = None
self.is_trainset = is_trainset
self.image_target_sizes = image_target_sizes
self.valid_box_ratio_ranges = valid_box_ratio_ranges
self.chip_target_size = chip_target_size
self.chip_target_stride = chip_target_stride
self.use_neg_chip = use_neg_chip
self.max_neg_num_per_im = max_neg_num_per_im
self.max_per_img = max_per_img
self.nms_thresh = nms_thresh
def parse_dataset(self):
if not hasattr(self, "roidbs"):
super(SniperCOCODataSet, self).parse_dataset()
if self.is_trainset:
self._parse_proposals()
self._merge_anno_proposals()
self.ori_roidbs = copy.deepcopy(self.roidbs)
self.init_anno_cropper()
self.roidbs = self.generate_chips_roidbs(self.roidbs, self.is_trainset)
def set_proposals_file(self, file_path):
self.proposals_file = file_path
def init_anno_cropper(self):
logger.info("Init AnnoCropper...")
self.anno_cropper = AnnoCropper(
image_target_sizes=self.image_target_sizes,
valid_box_ratio_ranges=self.valid_box_ratio_ranges,
chip_target_size=self.chip_target_size,
chip_target_stride=self.chip_target_stride,
use_neg_chip=self.use_neg_chip,
max_neg_num_per_im=self.max_neg_num_per_im,
max_per_img=self.max_per_img,
nms_thresh=self.nms_thresh
)
def generate_chips_roidbs(self, roidbs, is_trainset):
if is_trainset:
roidbs = self.anno_cropper.crop_anno_records(roidbs)
else:
roidbs = self.anno_cropper.crop_infer_anno_records(roidbs)
return roidbs
def _parse_proposals(self):
if self.proposals_file:
self.proposals = {}
logger.info("Parse proposals file:{}".format(self.proposals_file))
with open(self.proposals_file, 'r') as f:
proposals = json.load(f)
for prop in proposals:
image_id = prop["image_id"]
if image_id not in self.proposals:
self.proposals[image_id] = []
x, y, w, h = prop["bbox"]
self.proposals[image_id].append([x, y, x + w, y + h])
def _merge_anno_proposals(self):
assert self.roidbs
if self.proposals and len(self.proposals.keys()) > 0:
logger.info("merge proposals to annos")
for id, record in enumerate(self.roidbs):
image_id = int(record["im_id"])
if image_id not in self.proposals.keys():
logger.info("image id :{} no proposals".format(image_id))
record["proposals"] = np.array(self.proposals.get(image_id, []), dtype=np.float32)
self.roidbs[id] = record
def get_ori_roidbs(self):
if not hasattr(self, "ori_roidbs"):
return None
return self.ori_roidbs
def get_roidbs(self):
if not hasattr(self, "roidbs"):
self.parse_dataset()
return self.roidbs
def set_roidbs(self, roidbs):
self.roidbs = roidbs
def check_or_download_dataset(self):
return
def _parse(self):
image_dir = self.image_dir
if not isinstance(image_dir, Sequence):
image_dir = [image_dir]
images = []
for im_dir in image_dir:
if os.path.isdir(im_dir):
im_dir = os.path.join(self.dataset_dir, im_dir)
images.extend(_make_dataset(im_dir))
elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
images.append(im_dir)
return images
def _load_images(self):
images = self._parse()
ct = 0
records = []
for image in images:
assert image != '' and os.path.isfile(image), \
"Image {} not found".format(image)
if self.sample_num > 0 and ct >= self.sample_num:
break
im = cv2.imread(image)
h, w, c = im.shape
rec = {'im_id': np.array([ct]), 'im_file': image, "h": h, "w": w}
self._imid2path[ct] = image
ct += 1
records.append(rec)
assert len(records) > 0, "No image file found"
return records
def get_imid2path(self):
return self._imid2path
def set_images(self, images):
self._imid2path = {}
self.image_dir = images
self.roidbs = self._load_images()

View File

@@ -0,0 +1,234 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import xml.etree.ElementTree as ET
from ppdet.core.workspace import register, serializable
from .dataset import DetDataset
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
@register
@serializable
class VOCDataSet(DetDataset):
"""
Load dataset with PascalVOC format.
Notes:
`anno_path` must contains xml file and image file path for annotations.
Args:
dataset_dir (str): root directory for dataset.
image_dir (str): directory for images.
anno_path (str): voc annotation file path.
data_fields (list): key name of data dictionary, at least have 'image'.
sample_num (int): number of samples to load, -1 means all.
label_list (str): if use_default_label is False, will load
mapping between category and class index.
allow_empty (bool): whether to load empty entry. False as default
empty_ratio (float): the ratio of empty record number to total
record's, if empty_ratio is out of [0. ,1.), do not sample the
records and use all the empty entries. 1. as default
repeat (int): repeat times for dataset, use in benchmark.
"""
def __init__(self,
dataset_dir=None,
image_dir=None,
anno_path=None,
data_fields=['image'],
sample_num=-1,
label_list=None,
allow_empty=False,
empty_ratio=1.,
repeat=1):
super(VOCDataSet, self).__init__(
dataset_dir=dataset_dir,
image_dir=image_dir,
anno_path=anno_path,
data_fields=data_fields,
sample_num=sample_num,
repeat=repeat)
self.label_list = label_list
self.allow_empty = allow_empty
self.empty_ratio = empty_ratio
def _sample_empty(self, records, num):
# if empty_ratio is out of [0. ,1.), do not sample the records
if self.empty_ratio < 0. or self.empty_ratio >= 1.:
return records
import random
sample_num = min(
int(num * self.empty_ratio / (1 - self.empty_ratio)), len(records))
records = random.sample(records, sample_num)
return records
def parse_dataset(self, ):
anno_path = os.path.join(self.dataset_dir, self.anno_path)
image_dir = os.path.join(self.dataset_dir, self.image_dir)
# mapping category name to class id
# first_class:0, second_class:1, ...
records = []
empty_records = []
ct = 0
cname2cid = {}
if self.label_list:
label_path = os.path.join(self.dataset_dir, self.label_list)
if not os.path.exists(label_path):
raise ValueError("label_list {} does not exists".format(
label_path))
with open(label_path, 'r') as fr:
label_id = 0
for line in fr.readlines():
cname2cid[line.strip()] = label_id
label_id += 1
else:
cname2cid = pascalvoc_label()
with open(anno_path, 'r') as fr:
while True:
line = fr.readline()
if not line:
break
img_file, xml_file = [os.path.join(image_dir, x) \
for x in line.strip().split()[:2]]
if not os.path.exists(img_file):
logger.warning(
'Illegal image file: {}, and it will be ignored'.format(
img_file))
continue
if not os.path.isfile(xml_file):
logger.warning(
'Illegal xml file: {}, and it will be ignored'.format(
xml_file))
continue
tree = ET.parse(xml_file)
if tree.find('id') is None:
im_id = np.array([ct])
else:
im_id = np.array([int(tree.find('id').text)])
objs = tree.findall('object')
im_w = float(tree.find('size').find('width').text)
im_h = float(tree.find('size').find('height').text)
if im_w < 0 or im_h < 0:
logger.warning(
'Illegal width: {} or height: {} in annotation, '
'and {} will be ignored'.format(im_w, im_h, xml_file))
continue
num_bbox, i = len(objs), 0
gt_bbox = np.zeros((num_bbox, 4), dtype=np.float32)
gt_class = np.zeros((num_bbox, 1), dtype=np.int32)
gt_score = np.zeros((num_bbox, 1), dtype=np.float32)
difficult = np.zeros((num_bbox, 1), dtype=np.int32)
for obj in objs:
cname = obj.find('name').text
# user dataset may not contain difficult field
_difficult = obj.find('difficult')
_difficult = int(
_difficult.text) if _difficult is not None else 0
x1 = float(obj.find('bndbox').find('xmin').text)
y1 = float(obj.find('bndbox').find('ymin').text)
x2 = float(obj.find('bndbox').find('xmax').text)
y2 = float(obj.find('bndbox').find('ymax').text)
x1 = max(0, x1)
y1 = max(0, y1)
x2 = min(im_w - 1, x2)
y2 = min(im_h - 1, y2)
if x2 > x1 and y2 > y1:
gt_bbox[i, :] = [x1, y1, x2, y2]
gt_class[i, 0] = cname2cid[cname]
gt_score[i, 0] = 1.
difficult[i, 0] = _difficult
i += 1
else:
logger.warning(
'Found an invalid bbox in annotations: xml_file: {}'
', x1: {}, y1: {}, x2: {}, y2: {}.'.format(
xml_file, x1, y1, x2, y2))
gt_bbox = gt_bbox[:i, :]
gt_class = gt_class[:i, :]
gt_score = gt_score[:i, :]
difficult = difficult[:i, :]
voc_rec = {
'im_file': img_file,
'im_id': im_id,
'h': im_h,
'w': im_w
} if 'image' in self.data_fields else {}
gt_rec = {
'gt_class': gt_class,
'gt_score': gt_score,
'gt_bbox': gt_bbox,
'difficult': difficult
}
for k, v in gt_rec.items():
if k in self.data_fields:
voc_rec[k] = v
if len(objs) == 0:
empty_records.append(voc_rec)
else:
records.append(voc_rec)
ct += 1
if self.sample_num > 0 and ct >= self.sample_num:
break
assert ct > 0, 'not found any voc record in %s' % (self.anno_path)
logger.debug('{} samples in file {}'.format(ct, anno_path))
if self.allow_empty and len(empty_records) > 0:
empty_records = self._sample_empty(empty_records, len(records))
records += empty_records
self.roidbs, self.cname2cid = records, cname2cid
def get_label_list(self):
return os.path.join(self.dataset_dir, self.label_list)
def pascalvoc_label():
labels_map = {
'aeroplane': 0,
'bicycle': 1,
'bird': 2,
'boat': 3,
'bottle': 4,
'bus': 5,
'car': 6,
'cat': 7,
'chair': 8,
'cow': 9,
'diningtable': 10,
'dog': 11,
'horse': 12,
'motorbike': 13,
'person': 14,
'pottedplant': 15,
'sheep': 16,
'sofa': 17,
'train': 18,
'tvmonitor': 19
}
return labels_map

View File

@@ -0,0 +1,180 @@
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
from ppdet.core.workspace import register, serializable
from .dataset import DetDataset
from ppdet.utils.logger import setup_logger
logger = setup_logger(__name__)
@register
@serializable
class WIDERFaceDataSet(DetDataset):
"""
Load WiderFace records with 'anno_path'
Args:
dataset_dir (str): root directory for dataset.
image_dir (str): directory for images.
anno_path (str): WiderFace annotation data.
data_fields (list): key name of data dictionary, at least have 'image'.
sample_num (int): number of samples to load, -1 means all.
with_lmk (bool): whether to load face landmark keypoint labels.
"""
def __init__(self,
dataset_dir=None,
image_dir=None,
anno_path=None,
data_fields=['image'],
sample_num=-1,
with_lmk=False):
super(WIDERFaceDataSet, self).__init__(
dataset_dir=dataset_dir,
image_dir=image_dir,
anno_path=anno_path,
data_fields=data_fields,
sample_num=sample_num,
with_lmk=with_lmk)
self.anno_path = anno_path
self.sample_num = sample_num
self.roidbs = None
self.cname2cid = None
self.with_lmk = with_lmk
def parse_dataset(self):
anno_path = os.path.join(self.dataset_dir, self.anno_path)
image_dir = os.path.join(self.dataset_dir, self.image_dir)
txt_file = anno_path
records = []
ct = 0
file_lists = self._load_file_list(txt_file)
cname2cid = widerface_label()
for item in file_lists:
im_fname = item[0]
im_id = np.array([ct])
gt_bbox = np.zeros((len(item) - 1, 4), dtype=np.float32)
gt_class = np.zeros((len(item) - 1, 1), dtype=np.int32)
gt_lmk_labels = np.zeros((len(item) - 1, 10), dtype=np.float32)
lmk_ignore_flag = np.zeros((len(item) - 1, 1), dtype=np.int32)
for index_box in range(len(item)):
if index_box < 1:
continue
gt_bbox[index_box - 1] = item[index_box][0]
if self.with_lmk:
gt_lmk_labels[index_box - 1] = item[index_box][1]
lmk_ignore_flag[index_box - 1] = item[index_box][2]
im_fname = os.path.join(image_dir,
im_fname) if image_dir else im_fname
widerface_rec = {
'im_file': im_fname,
'im_id': im_id,
} if 'image' in self.data_fields else {}
gt_rec = {
'gt_bbox': gt_bbox,
'gt_class': gt_class,
}
for k, v in gt_rec.items():
if k in self.data_fields:
widerface_rec[k] = v
if self.with_lmk:
widerface_rec['gt_keypoint'] = gt_lmk_labels
widerface_rec['keypoint_ignore'] = lmk_ignore_flag
if len(item) != 0:
records.append(widerface_rec)
ct += 1
if self.sample_num > 0 and ct >= self.sample_num:
break
assert len(records) > 0, 'not found any widerface in %s' % (anno_path)
logger.debug('{} samples in file {}'.format(ct, anno_path))
self.roidbs, self.cname2cid = records, cname2cid
def _load_file_list(self, input_txt):
with open(input_txt, 'r') as f_dir:
lines_input_txt = f_dir.readlines()
file_dict = {}
num_class = 0
exts = ['jpg', 'jpeg', 'png', 'bmp']
exts += [ext.upper() for ext in exts]
for i in range(len(lines_input_txt)):
line_txt = lines_input_txt[i].strip('\n\t\r')
split_str = line_txt.split(' ')
if len(split_str) == 1:
img_file_name = os.path.split(split_str[0])[1]
split_txt = img_file_name.split('.')
if len(split_txt) < 2:
continue
elif split_txt[-1] in exts:
if i != 0:
num_class += 1
file_dict[num_class] = [line_txt]
else:
if len(line_txt) <= 6:
continue
result_boxs = []
xmin = float(split_str[0])
ymin = float(split_str[1])
w = float(split_str[2])
h = float(split_str[3])
# Filter out wrong labels
if w < 0 or h < 0:
logger.warning('Illegal box with w: {}, h: {} in '
'img: {}, and it will be ignored'.format(
w, h, file_dict[num_class][0]))
continue
xmin = max(0, xmin)
ymin = max(0, ymin)
xmax = xmin + w
ymax = ymin + h
gt_bbox = [xmin, ymin, xmax, ymax]
result_boxs.append(gt_bbox)
if self.with_lmk:
assert len(split_str) > 18, 'When `with_lmk=True`, the number' \
'of characters per line in the annotation file should' \
'exceed 18.'
lmk0_x = float(split_str[5])
lmk0_y = float(split_str[6])
lmk1_x = float(split_str[8])
lmk1_y = float(split_str[9])
lmk2_x = float(split_str[11])
lmk2_y = float(split_str[12])
lmk3_x = float(split_str[14])
lmk3_y = float(split_str[15])
lmk4_x = float(split_str[17])
lmk4_y = float(split_str[18])
lmk_ignore_flag = 0 if lmk0_x == -1 else 1
gt_lmk_label = [
lmk0_x, lmk0_y, lmk1_x, lmk1_y, lmk2_x, lmk2_y, lmk3_x,
lmk3_y, lmk4_x, lmk4_y
]
result_boxs.append(gt_lmk_label)
result_boxs.append(lmk_ignore_flag)
file_dict[num_class].append(result_boxs)
return list(file_dict.values())
def widerface_label():
labels_map = {'face': 0}
return labels_map