Replace the document detection model
paddle_detection/ppdet/data/source/culane.py (Normal file, 206 additions)
@@ -0,0 +1,206 @@
from ppdet.core.workspace import register, serializable
import cv2
import os
import tarfile
import numpy as np
import os.path as osp
from imgaug.augmentables.lines import LineStringsOnImage
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
from ppdet.data.culane_utils import lane_to_linestrings
import pickle as pkl
from ppdet.utils.logger import setup_logger
try:
    from collections.abc import Sequence
except Exception:  # Python < 3.3 fallback
    from collections import Sequence
from .dataset import DetDataset, _make_dataset, _is_valid_file
from ppdet.utils.download import download_dataset

logger = setup_logger(__name__)


@register
@serializable
class CULaneDataSet(DetDataset):
    def __init__(self,
                 dataset_dir,
                 cut_height,
                 list_path,
                 split='train',
                 data_fields=['image'],
                 video_file=None,
                 frame_rate=-1):
        super(CULaneDataSet, self).__init__(
            dataset_dir=dataset_dir,
            cut_height=cut_height,
            split=split,
            data_fields=data_fields)
        self.dataset_dir = dataset_dir
        self.list_path = osp.join(dataset_dir, list_path)
        self.cut_height = cut_height
        self.data_fields = data_fields
        self.split = split
        self.training = 'train' in split
        self.data_infos = []
        self.video_file = video_file
        self.frame_rate = frame_rate
        self._imid2path = {}
        self.predict_dir = None
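
    # A minimal config sketch (assumption: standard PaddleDetection YAML
    # layout; dataset/culane, list/train_gt.txt and cut_height=270 are
    # illustrative values, not mandated by this class):
    #
    #   TrainDataset:
    #     !CULaneDataSet
    #       dataset_dir: dataset/culane
    #       list_path: list/train_gt.txt
    #       cut_height: 270
    #       split: train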

    def __len__(self):
        return len(self.data_infos)

    def check_or_download_dataset(self):
        if not osp.exists(self.dataset_dir):
            download_dataset("dataset", dataset="culane")
            # extract the downloaded .tar.gz archives into self.dataset_dir
            for fname in os.listdir(self.dataset_dir):
                # ignore hidden .* files
                if fname.startswith('.'):
                    continue
                if fname.find('.tar.gz') >= 0:
                    logger.info("Decompressing {}...".format(fname))
                    with tarfile.open(osp.join(self.dataset_dir, fname)) as tf:
                        tf.extractall(path=self.dataset_dir)
            logger.info("Dataset files are ready.")

    def parse_dataset(self):
        logger.info('Loading CULane annotations...')
        if self.predict_dir is not None:
            logger.info('switch to predict mode')
            return
        # Waiting for the dataset to load is tedious, so cache the parsed
        # annotations. NOTE: the cache file is keyed only by split, not by
        # list_path.
        os.makedirs('cache', exist_ok=True)
        cache_path = 'cache/culane_paddle_{}.pkl'.format(self.split)
        if os.path.exists(cache_path):
            with open(cache_path, 'rb') as cache_file:
                self.data_infos = pkl.load(cache_file)
            self.max_lanes = max(
                len(anno['lanes']) for anno in self.data_infos)
            return

        with open(self.list_path) as list_file:
            for line in list_file:
                infos = self.load_annotation(line.split())
                self.data_infos.append(infos)
        # keep max_lanes consistent with the cached branch above
        self.max_lanes = max(
            len(anno['lanes']) for anno in self.data_infos)

        # cache data infos to file
        with open(cache_path, 'wb') as cache_file:
            pkl.dump(self.data_infos, cache_file)
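
    # Cache behaviour in practice (paths taken directly from the code above):
    # the first run writes cache/culane_paddle_train.pkl relative to the
    # working directory; later runs load it as-is, so remove the cache/
    # directory after changing list_path to force a fresh parse.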

    def load_annotation(self, line):
        infos = {}
        img_line = line[0]
        # strip a leading '/' so the join below stays relative to dataset_dir
        img_line = img_line[1:] if img_line[0] == '/' else img_line
        img_path = os.path.join(self.dataset_dir, img_line)
        infos['img_name'] = img_line
        infos['img_path'] = img_path
        if len(line) > 1:
            mask_line = line[1]
            mask_line = mask_line[1:] if mask_line[0] == '/' else mask_line
            mask_path = os.path.join(self.dataset_dir, mask_line)
            infos['mask_path'] = mask_path

        if len(line) > 2:
            # per-lane existence flags, e.g. "1 1 1 1"
            exist_list = [int(l) for l in line[2:]]
            infos['lane_exist'] = np.array(exist_list)

        # replace the 'jpg' suffix with 'lines.txt' to locate the annotation
        anno_path = img_path[:-3] + 'lines.txt'
        with open(anno_path, 'r') as anno_file:
            data = [
                list(map(float, anno_line.split()))
                for anno_line in anno_file.readlines()
            ]
        # regroup the flat "x0 y0 x1 y1 ..." values into (x, y) pairs and
        # drop padding points with negative coordinates
        lanes = [[(lane[i], lane[i + 1]) for i in range(0, len(lane), 2)
                  if lane[i] >= 0 and lane[i + 1] >= 0] for lane in data]
        lanes = [list(set(lane)) for lane in lanes]  # remove duplicated points
        lanes = [lane for lane in lanes
                 if len(lane) > 2]  # drop lanes with fewer than 3 points

        lanes = [sorted(
            lane, key=lambda x: x[1]) for lane in lanes]  # sort by y
        infos['lanes'] = lanes

        return infos
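
    # Worked example (hedged: the format follows the public CULane lists, the
    # concrete paths are made up). A list/train_gt.txt line such as
    #
    #   /driver_23_30frame/05151649_0422.MP4/00000.jpg \
    #       /laneseg_label_w16/driver_23_30frame/05151649_0422.MP4/00000.png 1 1 1 1
    #
    # produces infos['img_path'], infos['mask_path'] and a 4-element
    # infos['lane_exist'] array, while infos['lanes'] is parsed from the
    # matching 00000.lines.txt file, one lane per line as "x0 y0 x1 y1 ...".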

    def set_images(self, images):
        self.predict_dir = images
        self.data_infos = self._load_images()

    def _find_images(self):
        predict_dir = self.predict_dir
        # wrap a single path; note str is itself a Sequence, so check it first
        if isinstance(predict_dir, str) or not isinstance(predict_dir,
                                                          Sequence):
            predict_dir = [predict_dir]
        images = []
        for im_dir in predict_dir:
            if os.path.isdir(im_dir):
                images.extend(_make_dataset(im_dir))
            elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
                images.append(im_dir)
        return images

    def _load_images(self):
        images = self._find_images()
        ct = 0
        records = []
        for image in images:
            assert image != '' and os.path.isfile(image), \
                "Image {} not found".format(image)
            # sample_num is inherited from DetDataset (-1 disables the limit)
            if self.sample_num > 0 and ct >= self.sample_num:
                break
            rec = {
                'im_id': np.array([ct]),
                "img_path": os.path.abspath(image),
                "img_name": os.path.basename(image),
                "lanes": []
            }
            self._imid2path[ct] = image
            ct += 1
            records.append(rec)
        assert len(records) > 0, "No image file found"
        return records

    def get_imid2path(self):
        return self._imid2path
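
    # Predict-mode sketch (hypothetical file names): set_images() bypasses
    # the CULane lists entirely, and parse_dataset() then returns early:
    #
    #   dataset.set_images(['demo/lane_0.jpg', 'demo/lane_1.jpg'])
    #   dataset.parse_dataset()  # logs 'switch to predict mode'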

    def __getitem__(self, idx):
        data_info = self.data_infos[idx]
        img = cv2.imread(data_info['img_path'])
        # crop cut_height pixels off the top of the frame (sky region)
        img = img[self.cut_height:, :, :]
        sample = data_info.copy()
        sample.update({'image': img})
        img_org = sample['image']

        if self.training:
            label = cv2.imread(sample['mask_path'], cv2.IMREAD_UNCHANGED)
            if len(label.shape) > 2:
                label = label[:, :, 0]
            label = label.squeeze()
            label = label[self.cut_height:, :]
            sample.update({'mask': label})
            if self.cut_height != 0:
                # shift lane points up to match the cropped image
                new_lanes = []
                for lane in sample['lanes']:
                    new_lanes.append(
                        [(p[0], p[1] - self.cut_height) for p in lane])
                sample.update({'lanes': new_lanes})

            sample['mask'] = SegmentationMapsOnImage(
                sample['mask'], shape=img_org.shape)

        sample['full_img_path'] = data_info['img_path']
        sample['img_name'] = data_info['img_name']
        sample['im_id'] = np.array([idx])

        # wrap lanes as imgaug LineStrings so geometric transforms apply to
        # image, mask and lanes consistently
        sample['image'] = sample['image'].copy().astype(np.uint8)
        sample['lanes'] = lane_to_linestrings(sample['lanes'])
        sample['lanes'] = LineStringsOnImage(
            sample['lanes'], shape=img_org.shape)
        sample['seg'] = np.zeros(img_org.shape)

        return sample
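

# A minimal end-to-end sketch, kept as a comment to avoid import-time side
# effects (assumptions: CULane already extracted under dataset/culane;
# cut_height=270 mirrors a common CLRNet setting and is illustrative):
#
#   dataset = CULaneDataSet(
#       dataset_dir='dataset/culane',
#       cut_height=270,
#       list_path='list/train_gt.txt',
#       split='train')
#   dataset.check_or_download_dataset()
#   dataset.parse_dataset()
#   sample = dataset[0]  # dict with 'image', 'mask', 'lanes', 'seg', ...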