Files
fcb_photo_review/paddle_detection/ppdet/data/source/pose3d_cmb.py
2024-08-27 14:42:45 +08:00

381 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import cv2
import numpy as np
import json
import copy
import pycocotools
from pycocotools.coco import COCO
from .dataset import DetDataset
from ppdet.core.workspace import register, serializable
from paddle.io import Dataset
@serializable
class Pose3DDataset(DetDataset):
    """Pose3D Dataset class.

    Args:
        dataset_dir (str): Root path to the dataset.
        image_dirs (list of str): Relative paths of the image folders;
            entry ``i`` is used as the image prefix for ``anno_list[i]``.
        anno_list (list of str): Relative paths of the annotation files.
        transform (composed(operators)): A sequence of data transforms.
            It is called with one record dict in ``__getitem__``.
        num_joints (int): Number of 2D joints kept per sample. Only
            24, 17 and 14 are supported by ``filterjoints``.
        test_mode (bool): Store True when building test or
            validation dataset. Default: False.

    24 joints order:
        0-2: 'R_Ankle', 'R_Knee', 'R_Hip',
        3-5:'L_Hip', 'L_Knee', 'L_Ankle',
        6-8:'R_Wrist', 'R_Elbow', 'R_Shoulder',
        9-11:'L_Shoulder','L_Elbow','L_Wrist',
        12-14:'Neck','Top_of_Head','Pelvis',
        15-18:'Thorax','Spine','Jaw','Head',
        19-23:'Nose','L_Eye','R_Eye','L_Ear','R_Ear'
    """

    def __init__(self,
                 dataset_dir,
                 image_dirs,
                 anno_list,
                 transform=None,
                 num_joints=24,
                 test_mode=False):
        super().__init__(dataset_dir, image_dirs, anno_list)
        self.image_info = {}
        self.ann_info = {}
        self.num_joints = num_joints
        # A fresh list per instance replaces the former shared mutable
        # default (``transform=[]``); the stored value is still ``[]`` when
        # the caller passes nothing.
        self.transform = [] if transform is None else transform
        self.test_mode = test_mode
        self.img_ids = []
        self.dataset_dir = dataset_dir
        self.image_dirs = image_dirs
        self.anno_list = anno_list

    def _random_row_mask(self, num_rows, max_ratio, dtype):
        """Return a ``(num_rows, 1)`` mask of ones with random rows zeroed.

        In test mode the mask is all ones. Otherwise a random number of
        rows, at most ``max_ratio * num_rows``, is set to zero.
        """
        mask = np.ones((num_rows, 1), dtype=dtype)
        if not self.test_mode:
            # random_sample() is in [0, 1), so at most max_ratio of the rows
            # can be masked.
            masked_num = int(np.random.random_sample() * max_ratio * num_rows)
            indices = np.random.choice(
                np.arange(num_rows), replace=False, size=masked_num)
            mask[indices, :] = 0.0
        return mask

    def get_mask(self, mvm_percent=0.3):
        """Build the combined joint/vertex mask for one sample.

        Args:
            mvm_percent (float): Upper bound on the fraction of rows that
                may be zeroed in each of the two parts of the mask.

        Returns:
            np.ndarray: Array of shape ``(self.num_joints + 10, 1)``; the
            first ``self.num_joints`` rows are the joint mask, the last 10
            rows are the vertex mask. Values are 1.0 (keep) or 0.0 (masked).
        """
        joint_mask = self._random_row_mask(self.num_joints, mvm_percent,
                                           np.float32)
        # The vertex part has a fixed size of 10 rows.
        # ``np.float`` was removed in NumPy 1.24; it was an alias of the
        # builtin ``float`` (float64), so float64 keeps the result dtype.
        vertex_mask = self._random_row_mask(10, mvm_percent, np.float64)
        return np.concatenate([joint_mask, vertex_mask], axis=0)

    def filterjoints(self, x):
        """Select the rows of ``x`` matching ``self.num_joints``.

        Args:
            x (np.ndarray): Joint array indexed along axis 0 in the
                24-joint order listed in the class docstring.

        Returns:
            np.ndarray: ``x`` itself for 24 joints, otherwise the selected
            14 or 17 rows.

        Raises:
            ValueError: If ``self.num_joints`` is not 24, 17 or 14.
        """
        if self.num_joints == 24:
            return x
        elif self.num_joints == 14:
            return x[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18], :]
        elif self.num_joints == 17:
            return x[
                [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 15, 18, 19], :]
        else:
            raise ValueError(
                "unsupported joint numbers, only [24 or 17 or 14] is supported!")

    def parse_dataset(self):
        """Load every annotation file and fill ``self.annos``.

        Each kept record holds ``im_id``, ``imageName`` (full path),
        ``bbox_center``, ``bbox_scale``, ``joints_2d``, ``joints_3d``,
        ``mjm_mask``, ``has_3d_joints`` and ``has_2d_joints``. Records whose
        image file cannot be found are reported and skipped. Note that
        ``im_id`` is consumed before that check, so ids of kept records may
        have gaps.
        """
        print("Loading annotations..., please wait")
        self.annos = []
        im_id = 0
        self.human36m_num = 0
        for idx, annof in enumerate(self.anno_list):
            img_prefix = os.path.join(self.dataset_dir, self.image_dirs[idx])
            dataf = os.path.join(self.dataset_dir, annof)
            with open(dataf, 'r') as rf:
                anno_data = json.load(rf)
            annos = anno_data['data']
            print("{} has annos numbers: {}".format(dataf, len(annos)))
            for anno in annos:
                new_anno = {}
                new_anno['im_id'] = im_id
                im_id += 1

                imagename = anno['imageName']
                # Strip the COCO 2014 file-name prefixes.
                if imagename.startswith("COCO_train2014_"):
                    imagename = imagename[len("COCO_train2014_"):]
                elif imagename.startswith("COCO_val2014_"):
                    imagename = imagename[len("COCO_val2014_"):]
                imagename = os.path.join(img_prefix, imagename)
                if not os.path.exists(imagename):
                    # An image missing under train2017 may live in val2017.
                    if "train2017" in imagename:
                        imagename = imagename.replace("train2017", "val2017")
                    if not os.path.exists(imagename):
                        print("cannot find imagepath:{}".format(imagename))
                        continue
                new_anno['imageName'] = imagename
                if 'human3.6m' in imagename:
                    self.human36m_num += 1

                new_anno['bbox_center'] = anno['bbox_center']
                new_anno['bbox_scale'] = anno['bbox_scale']
                joints_2d = np.array(anno['gt_keypoint_2d']).astype(np.float32)
                if joints_2d.shape[0] == 49:
                    # If joints_2d is in SPIN format (generated by eft),
                    # keep the last 24 public joints. For details see:
                    # https://github.com/nkolot/SPIN/blob/master/constants.py
                    joints_2d = joints_2d[25:]
                new_anno['joints_3d'] = np.array(anno[
                    'pose3d'])[:, :3].astype(np.float32)
                new_anno['mjm_mask'] = self.get_mask()
                if 'has_3d_joints' not in anno:
                    new_anno['has_3d_joints'] = 1
                    new_anno['has_2d_joints'] = 1
                else:
                    new_anno['has_3d_joints'] = int(anno['has_3d_joints'])
                    new_anno['has_2d_joints'] = int(anno['has_2d_joints'])
                new_anno['joints_2d'] = self.filterjoints(joints_2d)
                self.annos.append(new_anno)
            # Drop both references so the parsed JSON can be freed before
            # the next annotation file is loaded.
            del annos, anno_data

    def get_temp_num(self):
        """get temporal data number, like human3.6m"""
        return self.human36m_num

    def __len__(self):
        """Get dataset length."""
        return len(self.annos)

    def _get_imganno(self, idx):
        """Get anno for a single image."""
        return self.annos[idx]

    def __getitem__(self, idx):
        """Prepare image for training given the index.

        Works on a deep copy of the stored record, loads the image as RGB
        into ``records['image']`` and returns ``self.transform(records)``.
        """
        records = copy.deepcopy(self._get_imganno(idx))
        imgpath = records['imageName']
        assert os.path.exists(imgpath), "cannot find image {}".format(imgpath)
        records['image'] = cv2.imread(imgpath)
        records['image'] = cv2.cvtColor(records['image'], cv2.COLOR_BGR2RGB)
        records = self.transform(records)
        return records

    def check_or_download_dataset(self):
        """Verify that every image directory exists.

        Raises:
            ValueError: If at least one directory in ``self.image_dirs`` is
                missing under ``self.dataset_dir``. Nothing is downloaded.
        """
        alldatafind = True
        for image_dir in self.image_dirs:
            image_dir = os.path.join(self.dataset_dir, image_dir)
            if not os.path.isdir(image_dir):
                # Keep scanning so every missing directory is reported.
                print("dataset [{}] is not found".format(image_dir))
                alldatafind = False
        if not alldatafind:
            raise ValueError(
                "Some dataset is not valid and cannot download automatically now, please prepare the dataset first"
            )
@register
@serializable
class Keypoint3DMultiFramesDataset(Dataset):
    """24 keypoints 3D dataset for pose estimation.

    each item is a list of images

    The dataset loads raw features and apply specified transforms
    to return a dict containing the image tensors and other information.

    Args:
        dataset_dir (str): Root path to the dataset. Every sub-directory
            whose name contains no '.' is treated as one action.
        image_dir (str): Name of the image folder inside each action folder.
        p3d_dir (str): Name of the 3D keypoint (.obj) folder inside each
            action folder.
        json_path (str): Path of the annotation json, relative to each
            action folder.
        img_size: Image resize size. Stored only; not used in this class.
        num_frames (int): Length of each frame sequence.
        anno_path (str|None): Optional annotation path relative to
            ``dataset_dir``, returned by ``get_anno``.

    ``set_transform`` must be called before items are fetched, because
    ``__getitem__`` passes every record through ``self.transform``.
    """

    def __init__(
            self,
            dataset_dir,  # dataset root directory
            image_dir,  # image folder
            p3d_dir,  # 3D keypoint folder
            json_path,
            img_size,  # image resize size
            num_frames,  # frame sequence length
            anno_path=None, ):
        self.dataset_dir = dataset_dir
        self.image_dir = image_dir
        self.p3d_dir = p3d_dir
        self.json_path = json_path
        self.img_size = img_size
        self.num_frames = num_frames
        self.anno_path = anno_path
        self.data_labels, self.mf_inds = self._generate_multi_frames_list()

    def _generate_multi_frames_list(self):
        """Merge the per-action annotations and build the frame windows.

        Returns:
            tuple: ``(annos_dict, mf_list)``. ``annos_dict`` has the keys
            'images', 'annotations' and 'act_inds' (the action name of each
            image), concatenated over all actions. ``mf_list`` holds every
            window of ``num_frames`` consecutive global image indices;
            windows never span two actions.
        """
        act_list = os.listdir(self.dataset_dir)  # action list
        offset = 0  # global index of the first image of the current action
        mf_list = []
        annos_dict = {'images': [], 'annotations': [], 'act_inds': []}
        for act in act_list:  # build the frame sequences of each action
            if '.' in act:
                # Entries with a dot in the name are skipped.
                continue
            act_json = os.path.join(self.dataset_dir, act, self.json_path)
            with open(act_json, 'r') as jf:
                annos = json.load(jf)
            length = len(annos['images'])
            for key, value in annos.items():
                if key in annos_dict:
                    annos_dict[key].extend(value)
            annos_dict['act_inds'].extend([act] * length)
            # Sliding window with stride 1; nothing is produced when the
            # action has fewer than num_frames images.
            for start in range(length - self.num_frames + 1):
                first = offset + start
                mf_list.append(list(range(first, first + self.num_frames)))
            offset += length
        print("total data number:", len(mf_list))
        return annos_dict, mf_list

    def __call__(self, *args, **kwargs):
        """Return the dataset itself; all arguments are ignored."""
        return self

    def __getitem__(self, index):  # fetch one consecutive sequence
        """Load the frames and 3D keypoints of one window.

        Returns:
            The result of ``self.transform`` applied to a dict with keys
            'kps3d', 'kps3d_vis', 'image', 'act', 'names' and 'im_id'.
        """
        # e.g. [568, 569, 570, 571, 572, 573], length num_frames
        inds = self.mf_inds[index]
        images = self.data_labels['images']  # all images
        # action name (folder name)
        act = self.data_labels['act_inds'][inds[0]]
        # image names, including the extension
        names = [images[ind]['file_name'] for ind in inds]

        kps3d_list = []
        kps3d_vis_list = []
        for name in names:
            # The keypoint file shares the image's name up to the first '.'.
            kps3d_name = name.split('.')[0] + '.obj'
            kps3d_path = os.path.join(self.dataset_dir, act, self.p3d_dir,
                                      kps3d_name)
            joints, joints_vis = self.kps3d_process(kps3d_path)
            kps3d_list.append(joints)
            kps3d_vis_list.append(np.array(joints_vis, dtype=np.float32))
        # e.g. (6, 24, 3), i.e. (num_frames, joints_num, 3)
        kps3d = np.array(kps3d_list)
        kps3d_vis = np.array(kps3d_vis_list)

        # read image
        frames = []
        for name in names:
            img_path = os.path.join(self.dataset_dir, act, self.image_dir,
                                    name)
            image = cv2.imread(img_path, cv2.IMREAD_COLOR |
                               cv2.IMREAD_IGNORE_ORIENTATION)
            frames.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        # e.g. (6, 1080, 1920, 3), i.e. (num_frames, h, w, c)
        imgs = np.stack(frames, axis=0).astype(np.float32)

        # note (from the original author): at this point the images and the
        # annotations are mirrored
        records = {
            'kps3d': kps3d,
            'kps3d_vis': kps3d_vis,
            "image": imgs,
            'act': act,
            'names': names,
            'im_id': index
        }
        return self.transform(records)

    def kps3d_process(self, kps3d_path):
        """Read 24 root-relative 3D keypoints from a Wavefront .obj file.

        Only geometric vertex records (``v x y z``) are read. ``vn``, ``vt``
        and ``vp`` records also begin with the letter 'v' but are not
        points, so they are ignored.

        Args:
            kps3d_path (str): Path to the .obj file.

        Returns:
            tuple: ``(kps, kps_vis)``. ``kps`` holds vertices 0-22 and 37,
            relative to vertex 0 and scaled. ``kps_vis`` is an all-ones
            array with 24 rows of 3.
        """
        kps = []
        kps_vis = []
        with open(kps3d_path, 'r') as f:
            for line in f:
                # split() copes with repeated blanks and '\r\n' endings.
                fields = line.split()
                if not fields or fields[0] != 'v':
                    continue
                kps.append([float(value) for value in fields[1:]])
                kps_vis.append([1, 1, 1])

        kps = np.array(kps)  # all vertices of the file (e.g. 523)
        kps_vis = np.array(kps_vis)
        kps *= 10  # scale points
        kps -= kps[[0], :]  # set root point to zero
        kps = np.concatenate((kps[0:23], kps[[37]]), axis=0)  # 24,3
        # Second scaling step kept from the original (net factor 100).
        # NOTE(review): confirm the double scaling is intended.
        kps *= 10
        kps_vis = np.concatenate((kps_vis[0:23], kps_vis[[37]]), axis=0)  # 24,3
        return kps, kps_vis

    def __len__(self):
        """Return the number of frame windows."""
        return len(self.mf_inds)

    def get_anno(self):
        """Return the annotation path under ``dataset_dir``, or None if unset."""
        if self.anno_path is None:
            return
        return os.path.join(self.dataset_dir, self.anno_path)

    def check_or_download_dataset(self):
        """No-op here."""
        return

    def parse_dataset(self, ):
        """No-op; the annotations are already parsed in ``__init__``."""
        return

    def set_transform(self, transform):
        """Set the callable applied to every record in ``__getitem__``."""
        self.transform = transform

    def set_epoch(self, epoch_id):
        """Store the current epoch id."""
        self._epoch = epoch_id

    def set_kwargs(self, **kwargs):
        """Store the mixup/cutmix/mosaic epoch settings (default -1 each)."""
        self.mixup_epoch = kwargs.get('mixup_epoch', -1)
        self.cutmix_epoch = kwargs.get('cutmix_epoch', -1)
        self.mosaic_epoch = kwargs.get('mosaic_epoch', -1)