更换文档检测模型
This commit is contained in:
1
paddle_detection/configs/mot/pedestrian/README.md
Symbolic link
1
paddle_detection/configs/mot/pedestrian/README.md
Symbolic link
@@ -0,0 +1 @@
|
||||
README_cn.md
|
||||
135
paddle_detection/configs/mot/pedestrian/README_cn.md
Normal file
135
paddle_detection/configs/mot/pedestrian/README_cn.md
Normal file
@@ -0,0 +1,135 @@
|
||||
[English](README.md) | 简体中文
|
||||
# 特色垂类跟踪模型
|
||||
|
||||
## 大规模行人跟踪 (Pedestrian Tracking)
|
||||
|
||||
行人跟踪的主要应用之一是交通监控。
|
||||
|
||||
[PathTrack](https://www.trace.ethz.ch/publications/2017/pathtrack/index.html)包含720个视频序列,有着超过15000个行人的轨迹。包含了街景、舞蹈、体育运动、采访等各种场景的,大部分是移动摄像头拍摄场景。该数据集只有Pedestrian一类标注作为跟踪任务。
|
||||
|
||||
[VisDrone](http://aiskyeye.com)是无人机视角拍摄的数据集,是以俯视视角为主。该数据集涵盖不同位置(取自中国数千个相距数千公里的14个不同城市)、不同环境(城市和乡村)、不同物体(行人、车辆、自行车等)和不同密度(稀疏和拥挤的场景)。[VisDrone2019-MOT](https://github.com/VisDrone/VisDrone-Dataset)包含56个视频序列用于训练,7个视频序列用于验证。此处针对VisDrone2019-MOT多目标跟踪数据集进行提取,抽取出类别为pedestrian和people的数据组合成一个大的Pedestrian类别。
|
||||
|
||||
|
||||
## 模型库
|
||||
|
||||
### FairMOT在各个数据集val-set上Pedestrian类别的结果
|
||||
|
||||
| 数据集 | 骨干网络 | 输入尺寸 | MOTA | IDF1 | FPS | 下载链接 | 配置文件 |
|
||||
| :-------------| :-------- | :------- | :----: | :----: | :----: | :-----: |:------: |
|
||||
| PathTrack | DLA-34 | 1088x608 | 44.9 | 59.3 | - |[下载链接](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_pathtrack.pdparams) | [配置文件](./fairmot_dla34_30e_1088x608_pathtrack.yml) |
|
||||
| VisDrone | DLA-34 | 1088x608 | 49.2 | 63.1 | - | [下载链接](https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_visdrone_pedestrian.pdparams) | [配置文件](./fairmot_dla34_30e_1088x608_visdrone_pedestrian.yml) |
|
||||
| VisDrone | HRNetv2-W18| 1088x608 | 40.5 | 54.7 | - | [下载链接](https://paddledet.bj.bcebos.com/models/mot/fairmot_hrnetv2_w18_dlafpn_30e_1088x608_visdrone_pedestrian.pdparams) | [配置文件](./fairmot_hrnetv2_w18_dlafpn_30e_1088x608_visdrone_pedestrian.yml) |
|
||||
| VisDrone | HRNetv2-W18| 864x480 | 38.6 | 50.9 | - | [下载链接](https://paddledet.bj.bcebos.com/models/mot/fairmot_hrnetv2_w18_dlafpn_30e_864x480_visdrone_pedestrian.pdparams) | [配置文件](./fairmot_hrnetv2_w18_dlafpn_30e_864x480_visdrone_pedestrian.yml) |
|
||||
| VisDrone | HRNetv2-W18| 576x320 | 30.6 | 47.2 | - | [下载链接](https://paddledet.bj.bcebos.com/models/mot/fairmot_hrnetv2_w18_dlafpn_30e_576x320_visdrone_pedestrian.pdparams) | [配置文件](./fairmot_hrnetv2_w18_dlafpn_30e_576x320_visdrone_pedestrian.yml) |
|
||||
|
||||
**注意:**
|
||||
- FairMOT均使用DLA-34为骨干网络,4个GPU进行训练,每个GPU上batch size为6,训练30个epoch。
|
||||
|
||||
|
||||
## 数据集准备和处理
|
||||
|
||||
### 1、数据集处理代码说明
|
||||
代码统一都在tools目录下
|
||||
```
|
||||
# visdrone
|
||||
tools/visdrone/visdrone2mot.py: 生成visdrone_pedestrian据集
|
||||
```
|
||||
|
||||
### 2、visdrone_pedestrian数据集处理
|
||||
```
|
||||
# 复制tool/visdrone/visdrone2mot.py到数据集目录下
|
||||
# 生成visdrone_pedestrian MOT格式的数据,抽取类别classes=1,2 (pedestrian, people)
|
||||
<<--生成前目录-->>
|
||||
├── VisDrone2019-MOT-val
|
||||
│ ├── annotations
|
||||
│ ├── sequences
|
||||
│ ├── visdrone2mot.py
|
||||
<<--生成后目录-->>
|
||||
├── VisDrone2019-MOT-val
|
||||
│ ├── annotations
|
||||
│ ├── sequences
|
||||
│ ├── visdrone2mot.py
|
||||
│ ├── visdrone_pedestrian
|
||||
│ │ ├── images
|
||||
│ │ │ ├── train
|
||||
│ │ │ ├── val
|
||||
│ │ ├── labels_with_ids
|
||||
│ │ │ ├── train
|
||||
│ │ │ ├── val
|
||||
# 执行
|
||||
python visdrone2mot.py --transMot=True --data_name=visdrone_pedestrian --phase=val
|
||||
python visdrone2mot.py --transMot=True --data_name=visdrone_pedestrian --phase=train
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 1. 训练
|
||||
使用2个GPU通过如下命令一键式启动训练
|
||||
```bash
|
||||
python -m paddle.distributed.launch --log_dir=./fairmot_dla34_30e_1088x608_visdrone_pedestrian/ --gpus 0,1 tools/train.py -c configs/mot/pedestrian/fairmot_dla34_30e_1088x608_visdrone_pedestrian.yml
|
||||
```
|
||||
|
||||
### 2. 评估
|
||||
使用单张GPU通过如下命令一键式启动评估
|
||||
```bash
|
||||
# 使用PaddleDetection发布的权重
|
||||
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/pedestrian/fairmot_dla34_30e_1088x608_visdrone_pedestrian.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_visdrone_pedestrian.pdparams
|
||||
|
||||
# 使用训练保存的checkpoint
|
||||
CUDA_VISIBLE_DEVICES=0 python tools/eval_mot.py -c configs/mot/pedestrian/fairmot_dla34_30e_1088x608_visdrone_pedestrian.yml -o weights=output/fairmot_dla34_30e_1088x608_visdrone_pedestrian/model_final.pdparams
|
||||
```
|
||||
|
||||
### 3. 预测
|
||||
使用单个GPU通过如下命令预测一个视频,并保存为视频
|
||||
```bash
|
||||
# 预测一个视频
|
||||
CUDA_VISIBLE_DEVICES=0 python tools/infer_mot.py -c configs/mot/pedestrian/fairmot_dla34_30e_1088x608_visdrone_pedestrian.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_visdrone_pedestrian.pdparams --video_file={your video name}.mp4 --save_videos
|
||||
```
|
||||
**注意:**
|
||||
- 请先确保已经安装了[ffmpeg](https://ffmpeg.org/ffmpeg.html), Linux(Ubuntu)平台可以直接用以下命令安装:`apt-get update && apt-get install -y ffmpeg`。
|
||||
|
||||
### 4. 导出预测模型
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 python tools/export_model.py -c configs/mot/pedestrian/fairmot_dla34_30e_1088x608_visdrone_pedestrian.yml -o weights=https://paddledet.bj.bcebos.com/models/mot/fairmot_dla34_30e_1088x608_visdrone_pedestrian.pdparams
|
||||
```
|
||||
|
||||
### 5. 用导出的模型基于Python去预测
|
||||
```bash
|
||||
python deploy/pptracking/python/mot_jde_infer.py --model_dir=output_inference/fairmot_dla34_30e_1088x608_visdrone_pedestrian --video_file={your video name}.mp4 --device=GPU --save_mot_txts
|
||||
```
|
||||
**注意:**
|
||||
- 跟踪模型是对视频进行预测,不支持单张图的预测,默认保存跟踪结果可视化后的视频,可添加`--save_mot_txts`表示保存跟踪结果的txt文件,或`--save_images`表示保存跟踪结果可视化图片。
|
||||
- 跟踪结果txt文件每行信息是`frame,id,x1,y1,w,h,score,-1,-1,-1`。
|
||||
|
||||
## 引用
|
||||
```
|
||||
@article{zhang2020fair,
|
||||
title={FairMOT: On the Fairness of Detection and Re-Identification in Multiple Object Tracking},
|
||||
author={Zhang, Yifu and Wang, Chunyu and Wang, Xinggang and Zeng, Wenjun and Liu, Wenyu},
|
||||
journal={arXiv preprint arXiv:2004.01888},
|
||||
year={2020}
|
||||
}
|
||||
|
||||
@INPROCEEDINGS{8237302,
|
||||
author={S. {Manen} and M. {Gygli} and D. {Dai} and L. V. {Gool}},
|
||||
booktitle={2017 IEEE International Conference on Computer Vision (ICCV)},
|
||||
title={PathTrack: Fast Trajectory Annotation with Path Supervision},
|
||||
year={2017},
|
||||
volume={},
|
||||
number={},
|
||||
pages={290-299},
|
||||
doi={10.1109/ICCV.2017.40},
|
||||
ISSN={2380-7504},
|
||||
month={Oct},}
|
||||
|
||||
@ARTICLE{9573394,
|
||||
author={Zhu, Pengfei and Wen, Longyin and Du, Dawei and Bian, Xiao and Fan, Heng and Hu, Qinghua and Ling, Haibin},
|
||||
journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
|
||||
title={Detection and Tracking Meet Drones Challenge},
|
||||
year={2021},
|
||||
volume={},
|
||||
number={},
|
||||
pages={1-1},
|
||||
doi={10.1109/TPAMI.2021.3119563}
|
||||
}
|
||||
```
|
||||
@@ -0,0 +1,26 @@
|
||||
_BASE_: [
|
||||
'../fairmot/fairmot_dla34_30e_1088x608.yml'
|
||||
]
|
||||
|
||||
weights: output/fairmot_dla34_30e_1088x608_pathtrack/model_final
|
||||
|
||||
# for MOT training
|
||||
TrainDataset:
|
||||
!MOTDataSet
|
||||
dataset_dir: dataset/mot
|
||||
image_lists: ['pathtrack.train']
|
||||
data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide']
|
||||
|
||||
# for MOT evaluation
|
||||
# If you want to change the MOT evaluation dataset, please modify 'data_root'
|
||||
EvalMOTDataset:
|
||||
!MOTImageFolder
|
||||
dataset_dir: dataset/mot
|
||||
data_root: pathtrack/images/test
|
||||
keep_ori_im: False # set True if save visualization images or video, or used in DeepSORT
|
||||
|
||||
# for MOT video inference
|
||||
TestMOTDataset:
|
||||
!MOTImageFolder
|
||||
dataset_dir: dataset/mot
|
||||
keep_ori_im: True # set True if save visualization images or video
|
||||
@@ -0,0 +1,26 @@
|
||||
_BASE_: [
|
||||
'../fairmot/fairmot_dla34_30e_1088x608.yml'
|
||||
]
|
||||
|
||||
weights: output/fairmot_dla34_30e_1088x608_visdrone_pedestrian/model_final
|
||||
|
||||
# for MOT training
|
||||
TrainDataset:
|
||||
!MOTDataSet
|
||||
dataset_dir: dataset/mot
|
||||
image_lists: ['visdrone_pedestrian.train']
|
||||
data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide']
|
||||
|
||||
# for MOT evaluation
|
||||
# If you want to change the MOT evaluation dataset, please modify 'data_root'
|
||||
EvalMOTDataset:
|
||||
!MOTImageFolder
|
||||
dataset_dir: dataset/mot
|
||||
data_root: visdrone_pedestrian/images/val
|
||||
keep_ori_im: False # set True if save visualization images or video, or used in DeepSORT
|
||||
|
||||
# for MOT video inference
|
||||
TestMOTDataset:
|
||||
!MOTImageFolder
|
||||
dataset_dir: dataset/mot
|
||||
keep_ori_im: True # set True if save visualization images or video
|
||||
@@ -0,0 +1,26 @@
|
||||
_BASE_: [
|
||||
'../fairmot/fairmot_hrnetv2_w18_dlafpn_30e_1088x608.yml'
|
||||
]
|
||||
|
||||
weights: output/fairmot_hrnetv2_w18_dlafpn_30e_1088x608_visdrone_pedestrian/model_final
|
||||
|
||||
# for MOT training
|
||||
TrainDataset:
|
||||
!MOTDataSet
|
||||
dataset_dir: dataset/mot
|
||||
image_lists: ['visdrone_pedestrian.train']
|
||||
data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide']
|
||||
|
||||
# for MOT evaluation
|
||||
# If you want to change the MOT evaluation dataset, please modify 'data_root'
|
||||
EvalMOTDataset:
|
||||
!MOTImageFolder
|
||||
dataset_dir: dataset/mot
|
||||
data_root: visdrone_pedestrian/images/val
|
||||
keep_ori_im: False # set True if save visualization images or video, or used in DeepSORT
|
||||
|
||||
# for MOT video inference
|
||||
TestMOTDataset:
|
||||
!MOTImageFolder
|
||||
dataset_dir: dataset/mot
|
||||
keep_ori_im: True # set True if save visualization images or video
|
||||
@@ -0,0 +1,26 @@
|
||||
_BASE_: [
|
||||
'../fairmot/fairmot_hrnetv2_w18_dlafpn_30e_576x320.yml'
|
||||
]
|
||||
|
||||
weights: output/fairmot_hrnetv2_w18_dlafpn_30e_576x320_visdrone_pedestrian/model_final
|
||||
|
||||
# for MOT training
|
||||
TrainDataset:
|
||||
!MOTDataSet
|
||||
dataset_dir: dataset/mot
|
||||
image_lists: ['visdrone_pedestrian.train']
|
||||
data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide']
|
||||
|
||||
# for MOT evaluation
|
||||
# If you want to change the MOT evaluation dataset, please modify 'data_root'
|
||||
EvalMOTDataset:
|
||||
!MOTImageFolder
|
||||
dataset_dir: dataset/mot
|
||||
data_root: visdrone_pedestrian/images/val
|
||||
keep_ori_im: False # set True if save visualization images or video, or used in DeepSORT
|
||||
|
||||
# for MOT video inference
|
||||
TestMOTDataset:
|
||||
!MOTImageFolder
|
||||
dataset_dir: dataset/mot
|
||||
keep_ori_im: True # set True if save visualization images or video
|
||||
@@ -0,0 +1,26 @@
|
||||
_BASE_: [
|
||||
'../fairmot/fairmot_hrnetv2_w18_dlafpn_30e_864x480.yml'
|
||||
]
|
||||
|
||||
weights: output/fairmot_hrnetv2_w18_dlafpn_30e_864x480_visdrone_pedestrian/model_final
|
||||
|
||||
# for MOT training
|
||||
TrainDataset:
|
||||
!MOTDataSet
|
||||
dataset_dir: dataset/mot
|
||||
image_lists: ['visdrone_pedestrian.train']
|
||||
data_fields: ['image', 'gt_bbox', 'gt_class', 'gt_ide']
|
||||
|
||||
# for MOT evaluation
|
||||
# If you want to change the MOT evaluation dataset, please modify 'data_root'
|
||||
EvalMOTDataset:
|
||||
!MOTImageFolder
|
||||
dataset_dir: dataset/mot
|
||||
data_root: visdrone_pedestrian/images/val
|
||||
keep_ori_im: False # set True if save visualization images or video, or used in DeepSORT
|
||||
|
||||
# for MOT video inference
|
||||
TestMOTDataset:
|
||||
!MOTImageFolder
|
||||
dataset_dir: dataset/mot
|
||||
keep_ori_im: True # set True if save visualization images or video
|
||||
@@ -0,0 +1,299 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import glob
|
||||
import os
|
||||
import os.path as osp
|
||||
import cv2
|
||||
import argparse
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
# The object category indicates the type of annotated object,
|
||||
# (i.e., ignored regions(0), pedestrian(1), people(2), bicycle(3), car(4), van(5), truck(6), tricycle(7), awning-tricycle(8), bus(9), motor(10),others(11))
|
||||
|
||||
# Extract single class or multi class
|
||||
isExtractMultiClass = False
|
||||
# These sequences are excluded because there are too few pedestrians
|
||||
exclude_seq = [
|
||||
"uav0000117_02622_v", "uav0000182_00000_v", "uav0000268_05773_v",
|
||||
"uav0000305_00000_v"
|
||||
]
|
||||
|
||||
|
||||
def mkdir_if_missing(d):
|
||||
if not osp.exists(d):
|
||||
os.makedirs(d)
|
||||
|
||||
|
||||
def genGtFile(seqPath, outPath, classes=[]):
|
||||
id_idx = 0
|
||||
old_idx = -1
|
||||
with open(seqPath, 'r') as singleSeqFile:
|
||||
motLine = []
|
||||
allLines = singleSeqFile.readlines()
|
||||
for line in allLines:
|
||||
line = line.replace('\n', '')
|
||||
line = line.split(',')
|
||||
# exclude occlusion!='2'
|
||||
if line[-1] != '2' and line[7] in classes:
|
||||
if old_idx != int(line[1]):
|
||||
id_idx += 1
|
||||
old_idx = int(line[1])
|
||||
newLine = line[0:6]
|
||||
newLine[1] = str(id_idx)
|
||||
newLine.append('1')
|
||||
if (len(classes) > 1 and isExtractMultiClass):
|
||||
class_index = str(classes.index(line[7]) + 1)
|
||||
newLine.append(class_index)
|
||||
else:
|
||||
newLine.append('1') # use permanent class '1'
|
||||
newLine.append('1')
|
||||
motLine.append(newLine)
|
||||
mkdir_if_missing(outPath)
|
||||
gtFilePath = osp.join(outPath, 'gt.txt')
|
||||
with open(gtFilePath, 'w') as gtFile:
|
||||
motLine = list(map(lambda x: str.join(',', x), motLine))
|
||||
motLineStr = str.join('\n', motLine)
|
||||
gtFile.write(motLineStr)
|
||||
|
||||
|
||||
def genSeqInfo(img1Path, seqName):
|
||||
imgPaths = glob.glob(img1Path + '/*.jpg')
|
||||
seqLength = len(imgPaths)
|
||||
if seqLength > 0:
|
||||
image1 = cv2.imread(imgPaths[0])
|
||||
imgHeight = image1.shape[0]
|
||||
imgWidth = image1.shape[1]
|
||||
else:
|
||||
imgHeight = 0
|
||||
imgWidth = 0
|
||||
seqInfoStr = f'''[Sequence]\nname={seqName}\nimDir=img1\nframeRate=30\nseqLength={seqLength}\nimWidth={imgWidth}\nimHeight={imgHeight}\nimExt=.jpg'''
|
||||
seqInfoPath = img1Path.replace('/img1', '')
|
||||
with open(seqInfoPath + '/seqinfo.ini', 'w') as seqFile:
|
||||
seqFile.write(seqInfoStr)
|
||||
|
||||
|
||||
def copyImg(img1Path, gtTxtPath, outputFileName):
|
||||
with open(gtTxtPath, 'r') as gtFile:
|
||||
allLines = gtFile.readlines()
|
||||
imgList = []
|
||||
for line in allLines:
|
||||
imgIdx = int(line.split(',')[0])
|
||||
if imgIdx not in imgList:
|
||||
imgList.append(imgIdx)
|
||||
seqName = gtTxtPath.replace('./{}/'.format(outputFileName),
|
||||
'').replace('/gt/gt.txt', '')
|
||||
sourceImgPath = osp.join('./sequences', seqName,
|
||||
'{:07d}.jpg'.format(imgIdx))
|
||||
os.system(f'cp {sourceImgPath} {img1Path}')
|
||||
|
||||
|
||||
def genMotLabels(datasetPath, outputFileName, classes=['2']):
|
||||
mkdir_if_missing(osp.join(datasetPath, outputFileName))
|
||||
annotationsPath = osp.join(datasetPath, 'annotations')
|
||||
annotationsList = glob.glob(osp.join(annotationsPath, '*.txt'))
|
||||
for annotationPath in annotationsList:
|
||||
seqName = annotationPath.split('/')[-1].replace('.txt', '')
|
||||
if seqName in exclude_seq:
|
||||
continue
|
||||
mkdir_if_missing(osp.join(datasetPath, outputFileName, seqName, 'gt'))
|
||||
mkdir_if_missing(osp.join(datasetPath, outputFileName, seqName, 'img1'))
|
||||
genGtFile(annotationPath,
|
||||
osp.join(datasetPath, outputFileName, seqName, 'gt'), classes)
|
||||
img1Path = osp.join(datasetPath, outputFileName, seqName, 'img1')
|
||||
gtTxtPath = osp.join(datasetPath, outputFileName, seqName, 'gt/gt.txt')
|
||||
copyImg(img1Path, gtTxtPath, outputFileName)
|
||||
genSeqInfo(img1Path, seqName)
|
||||
|
||||
|
||||
def deleteFileWhichImg1IsEmpty(mot16Path, dataType='train'):
|
||||
path = mot16Path
|
||||
data_images_train = osp.join(path, 'images', f'{dataType}')
|
||||
data_images_train_seqs = glob.glob(data_images_train + '/*')
|
||||
if (len(data_images_train_seqs) == 0):
|
||||
print('dataset is empty!')
|
||||
for data_images_train_seq in data_images_train_seqs:
|
||||
data_images_train_seq_img1 = osp.join(data_images_train_seq, 'img1')
|
||||
if len(glob.glob(data_images_train_seq_img1 + '/*.jpg')) == 0:
|
||||
print(f"os.system(rm -rf {data_images_train_seq})")
|
||||
os.system(f'rm -rf {data_images_train_seq}')
|
||||
|
||||
|
||||
def formatMot16Path(dataPath, pathType='train'):
|
||||
train_path = osp.join(dataPath, 'images', pathType)
|
||||
mkdir_if_missing(train_path)
|
||||
os.system(f'mv {dataPath}/* {train_path}')
|
||||
|
||||
|
||||
def VisualGt(dataPath, phase='train'):
|
||||
seqList = sorted(glob.glob(osp.join(dataPath, 'images', phase) + '/*'))
|
||||
seqIndex = random.randint(0, len(seqList) - 1)
|
||||
seqPath = seqList[seqIndex]
|
||||
gt_path = osp.join(seqPath, 'gt', 'gt.txt')
|
||||
img_list_path = sorted(glob.glob(osp.join(seqPath, 'img1', '*.jpg')))
|
||||
imgIndex = random.randint(0, len(img_list_path))
|
||||
img_Path = img_list_path[imgIndex]
|
||||
frame_value = int(img_Path.split('/')[-1].replace('.jpg', ''))
|
||||
gt_value = np.loadtxt(gt_path, dtype=int, delimiter=',')
|
||||
gt_value = gt_value[gt_value[:, 0] == frame_value]
|
||||
get_list = gt_value.tolist()
|
||||
img = cv2.imread(img_Path)
|
||||
colors = [[255, 0, 0], [255, 255, 0], [255, 0, 255], [0, 255, 0],
|
||||
[0, 255, 255], [0, 0, 255]]
|
||||
for seq, _id, pl, pt, w, h, _, bbox_class, _ in get_list:
|
||||
cv2.putText(img,
|
||||
str(bbox_class), (pl, pt), cv2.FONT_HERSHEY_PLAIN, 2,
|
||||
colors[bbox_class - 1])
|
||||
cv2.rectangle(
|
||||
img, (pl, pt), (pl + w, pt + h),
|
||||
colors[bbox_class - 1],
|
||||
thickness=2)
|
||||
cv2.imwrite('testGt.jpg', img)
|
||||
|
||||
|
||||
def VisualDataset(datasetPath, phase='train', seqName='', frameId=1):
|
||||
trainPath = osp.join(datasetPath, 'labels_with_ids', phase)
|
||||
seq1Paths = osp.join(trainPath, seqName)
|
||||
seq_img1_path = osp.join(seq1Paths, 'img1')
|
||||
label_with_idPath = osp.join(seq_img1_path, '%07d' % frameId) + '.txt'
|
||||
image_path = label_with_idPath.replace('labels_with_ids', 'images').replace(
|
||||
'.txt', '.jpg')
|
||||
seqInfoPath = str.join('/', image_path.split('/')[:-2])
|
||||
seqInfoPath = seqInfoPath + '/seqinfo.ini'
|
||||
seq_info = open(seqInfoPath).read()
|
||||
width = int(seq_info[seq_info.find('imWidth=') + 8:seq_info.find(
|
||||
'\nimHeight')])
|
||||
height = int(seq_info[seq_info.find('imHeight=') + 9:seq_info.find(
|
||||
'\nimExt')])
|
||||
|
||||
with open(label_with_idPath, 'r') as label:
|
||||
allLines = label.readlines()
|
||||
images = cv2.imread(image_path)
|
||||
for line in allLines:
|
||||
line = line.split(' ')
|
||||
line = list(map(lambda x: float(x), line))
|
||||
c1, c2, w, h = line[2:6]
|
||||
x1 = c1 - w / 2
|
||||
x2 = c2 - h / 2
|
||||
x3 = c1 + w / 2
|
||||
x4 = c2 + h / 2
|
||||
cv2.rectangle(
|
||||
images, (int(x1 * width), int(x2 * height)),
|
||||
(int(x3 * width), int(x4 * height)), (255, 0, 0),
|
||||
thickness=2)
|
||||
cv2.imwrite('test.jpg', images)
|
||||
|
||||
|
||||
def gen_image_list(dataPath, datType):
|
||||
inputPath = f'{dataPath}/images/{datType}'
|
||||
pathList = glob.glob(inputPath + '/*')
|
||||
pathList = sorted(pathList)
|
||||
allImageList = []
|
||||
for pathSingle in pathList:
|
||||
imgList = sorted(glob.glob(osp.join(pathSingle, 'img1', '*.jpg')))
|
||||
for imgPath in imgList:
|
||||
allImageList.append(imgPath)
|
||||
with open(f'{dataPath}.{datType}', 'w') as image_list_file:
|
||||
allImageListStr = str.join('\n', allImageList)
|
||||
image_list_file.write(allImageListStr)
|
||||
|
||||
|
||||
def gen_labels_mot(MOT_data, phase='train'):
|
||||
seq_root = './{}/images/{}'.format(MOT_data, phase)
|
||||
label_root = './{}/labels_with_ids/{}'.format(MOT_data, phase)
|
||||
mkdir_if_missing(label_root)
|
||||
seqs = [s for s in os.listdir(seq_root)]
|
||||
print('seqs => ', seqs)
|
||||
tid_curr = 0
|
||||
tid_last = -1
|
||||
for seq in seqs:
|
||||
seq_info = open(osp.join(seq_root, seq, 'seqinfo.ini')).read()
|
||||
seq_width = int(seq_info[seq_info.find('imWidth=') + 8:seq_info.find(
|
||||
'\nimHeight')])
|
||||
seq_height = int(seq_info[seq_info.find('imHeight=') + 9:seq_info.find(
|
||||
'\nimExt')])
|
||||
|
||||
gt_txt = osp.join(seq_root, seq, 'gt', 'gt.txt')
|
||||
gt = np.loadtxt(gt_txt, dtype=np.float64, delimiter=',')
|
||||
|
||||
seq_label_root = osp.join(label_root, seq, 'img1')
|
||||
mkdir_if_missing(seq_label_root)
|
||||
|
||||
for fid, tid, x, y, w, h, mark, label, _ in gt:
|
||||
# if mark == 0 or not label == 1:
|
||||
# continue
|
||||
fid = int(fid)
|
||||
tid = int(tid)
|
||||
if not tid == tid_last:
|
||||
tid_curr += 1
|
||||
tid_last = tid
|
||||
x += w / 2
|
||||
y += h / 2
|
||||
label_fpath = osp.join(seq_label_root, '{:07d}.txt'.format(fid))
|
||||
label_str = '0 {:d} {:.6f} {:.6f} {:.6f} {:.6f}\n'.format(
|
||||
tid_curr, x / seq_width, y / seq_height, w / seq_width,
|
||||
h / seq_height)
|
||||
with open(label_fpath, 'a') as f:
|
||||
f.write(label_str)
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
parser = argparse.ArgumentParser(description='input method')
|
||||
parser.add_argument("--transMot", type=bool, default=False)
|
||||
parser.add_argument("--genMot", type=bool, default=False)
|
||||
parser.add_argument("--formatMotPath", type=bool, default=False)
|
||||
parser.add_argument("--deleteEmpty", type=bool, default=False)
|
||||
parser.add_argument("--genLabelsMot", type=bool, default=False)
|
||||
parser.add_argument("--genImageList", type=bool, default=False)
|
||||
parser.add_argument("--visualImg", type=bool, default=False)
|
||||
parser.add_argument("--visualGt", type=bool, default=False)
|
||||
parser.add_argument("--data_name", type=str, default='visdrone_pedestrian')
|
||||
parser.add_argument("--phase", type=str, default='train')
|
||||
parser.add_argument(
|
||||
"--classes", type=str, default='1,2') # pedestrian and people
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_arguments()
|
||||
classes = args.classes.split(',')
|
||||
datasetPath = './'
|
||||
dataName = args.data_name
|
||||
phase = args.phase
|
||||
if args.transMot:
|
||||
genMotLabels(datasetPath, dataName, classes)
|
||||
formatMot16Path(dataName, pathType=phase)
|
||||
mot16Path = f'./{dataName}'
|
||||
deleteFileWhichImg1IsEmpty(mot16Path, dataType=phase)
|
||||
gen_labels_mot(dataName, phase=phase)
|
||||
gen_image_list(dataName, phase)
|
||||
if args.genMot:
|
||||
genMotLabels(datasetPath, dataName, classes)
|
||||
if args.formatMotPath:
|
||||
formatMot16Path(dataName, pathType=phase)
|
||||
if args.deleteEmpty:
|
||||
mot16Path = f'./{dataName}'
|
||||
deleteFileWhichImg1IsEmpty(mot16Path, dataType=phase)
|
||||
if args.genLabelsMot:
|
||||
gen_labels_mot(dataName, phase=phase)
|
||||
if args.genImageList:
|
||||
gen_image_list(dataName, phase)
|
||||
if args.visualGt:
|
||||
VisualGt(f'./{dataName}', phase)
|
||||
if args.visualImg:
|
||||
seqName = 'uav0000137_00458_v'
|
||||
frameId = 43
|
||||
VisualDataset(
|
||||
f'./{dataName}', phase=phase, seqName=seqName, frameId=frameId)
|
||||
Reference in New Issue
Block a user