更换文档检测模型
This commit is contained in:
21
paddle_detection/ppdet/modeling/mot/matching/__init__.py
Normal file
21
paddle_detection/ppdet/modeling/mot/matching/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import jde_matching
|
||||
from . import deepsort_matching
|
||||
from . import ocsort_matching
|
||||
|
||||
from .jde_matching import *
|
||||
from .deepsort_matching import *
|
||||
from .ocsort_matching import *
|
||||
@@ -0,0 +1,379 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is based on https://github.com/nwojke/deep_sort/tree/master/deep_sort
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
from ..motion import kalman_filter
|
||||
|
||||
INFTY_COST = 1e+5
|
||||
|
||||
__all__ = [
|
||||
'iou_1toN',
|
||||
'iou_cost',
|
||||
'_nn_euclidean_distance',
|
||||
'_nn_cosine_distance',
|
||||
'NearestNeighborDistanceMetric',
|
||||
'min_cost_matching',
|
||||
'matching_cascade',
|
||||
'gate_cost_matrix',
|
||||
]
|
||||
|
||||
|
||||
def iou_1toN(bbox, candidates):
    """
    Compute the intersection over union (IoU) of one box against N candidates.

    Args:
        bbox (ndarray): A bounding box in format
            `(top left x, top left y, width, height)`.
        candidates (ndarray): A matrix of candidate bounding boxes (one per
            row) in the same format as `bbox`.

    Returns:
        ious (ndarray): One IoU value per candidate row. No guard is applied
            against a zero union (two zero-area boxes).
    """
    # Convert (x, y, w, h) into corner coordinates.
    box_min = bbox[:2]
    box_max = bbox[:2] + bbox[2:]
    cand_min = candidates[:, :2]
    cand_max = candidates[:, :2] + candidates[:, 2:]

    # Corners of the overlap rectangle; broadcasting pairs the single box
    # with every candidate row.
    overlap_min = np.maximum(box_min, cand_min)
    overlap_max = np.minimum(box_max, cand_max)
    # Negative extents mean the boxes do not overlap on that axis.
    overlap_wh = np.maximum(0., overlap_max - overlap_min)

    intersection = overlap_wh.prod(axis=1)
    box_area = bbox[2:].prod()
    cand_areas = candidates[:, 2:].prod(axis=1)
    return intersection / (box_area + cand_areas - intersection)
|
||||
|
||||
|
||||
def iou_cost(tracks, detections, track_indices=None, detection_indices=None):
    """
    IoU distance metric.

    Args:
        tracks (list[Track]): A list of tracks. Each used track must expose
            `time_since_update` and `to_tlwh()`.
        detections (list[Detection]): A list of detections. Each used
            detection must expose `tlwh`.
        track_indices (Optional[list[int]]): A list of indices to tracks that
            should be matched. Defaults to all `tracks`.
        detection_indices (Optional[list[int]]): A list of indices to
            detections that should be matched. Defaults to all `detections`.

    Returns:
        cost_matrix (ndarray): A cost matrix of shape
            (len(track_indices), len(detection_indices)) where entry (i, j) is
            `1 - iou(tracks[track_indices[i]], detections[detection_indices[j]])`.
            Rows of tracks with `time_since_update > 1` are filled with 1e+5.
    """
    if track_indices is None:
        track_indices = np.arange(len(tracks))
    if detection_indices is None:
        detection_indices = np.arange(len(detections))

    cost_matrix = np.zeros((len(track_indices), len(detection_indices)))

    # The candidate boxes do not depend on the track, so they are built at
    # most once (lazily, on the first track that actually needs them) instead
    # of once per row.
    candidates = None
    for row, track_idx in enumerate(track_indices):
        track = tracks[track_idx]
        if track.time_since_update > 1:
            # Tracks not updated in the previous step are never matched by
            # IoU; 1e+5 is the same value as the module's INFTY_COST.
            cost_matrix[row, :] = 1e+5
            continue

        if candidates is None:
            candidates = np.asarray(
                [detections[i].tlwh for i in detection_indices])
        cost_matrix[row, :] = 1. - iou_1toN(track.to_tlwh(), candidates)
    return cost_matrix
|
||||
|
||||
|
||||
def _nn_euclidean_distance(s, q):
    """
    Nearest-neighbor squared Euclidean distance from each query to the samples.

    Args:
        s (ndarray): Sample points: an NxM matrix of N samples of
            dimensionality M.
        q (ndarray): Query points: an LxM matrix of L samples of
            dimensionality M.

    Returns:
        distances (ndarray): A vector of length L that contains, for each
            entry in `q`, the smallest squared Euclidean distance to a sample
            in `s`. If either input is empty, an all-zero array of shape
            (N, L) is returned instead.
    """
    samples = np.asarray(s)
    queries = np.asarray(q)
    if len(samples) == 0 or len(queries) == 0:
        return np.zeros((len(samples), len(queries)))

    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, evaluated for all pairs.
    sample_sq = np.square(samples).sum(axis=1)
    query_sq = np.square(queries).sum(axis=1)
    pairwise = -2. * np.dot(samples, queries.T) + sample_sq[:, None] + query_sq[None, :]
    # Rounding can push tiny distances below zero; clamp them.
    pairwise = np.clip(pairwise, 0., float(np.inf))

    nearest = pairwise.min(axis=0)
    return np.maximum(0.0, nearest)
|
||||
|
||||
|
||||
def _nn_cosine_distance(s, q):
    """
    Nearest-neighbor cosine distance from each query to the samples.

    Args:
        s (ndarray): Sample points: an NxM matrix of N samples of
            dimensionality M.
        q (ndarray): Query points: an LxM matrix of L samples of
            dimensionality M.

    Returns:
        distances (ndarray): A vector of length L that contains, for each
            entry in `q`, the smallest cosine distance (1 - cosine
            similarity) to a sample in `s`.
    """
    # L2-normalize every row so the dot product equals the cosine similarity.
    unit_samples = np.asarray(s) / np.linalg.norm(s, axis=1, keepdims=True)
    unit_queries = np.asarray(q) / np.linalg.norm(q, axis=1, keepdims=True)
    pairwise = 1. - np.dot(unit_samples, unit_queries.T)
    return pairwise.min(axis=0)
|
||||
|
||||
|
||||
class NearestNeighborDistanceMetric(object):
    """
    Nearest-neighbor distance metric: for each target, the distance to a query
    is the smallest distance to any sample stored for that target.

    Args:
        metric (str): Either "euclidean" or "cosine".
        matching_threshold (float): The matching threshold. Samples with larger
            distance are considered an invalid match. Only stored here; this
            class does not apply it.
        budget (Optional[int]): If not None, keep at most this many samples
            per target, dropping the oldest ones first.

    Attributes:
        samples (Dict[int -> List[ndarray]]): Maps target identities to the
            list of samples observed so far.
    """

    def __init__(self, metric, matching_threshold, budget=None):
        if metric == "euclidean":
            chosen = _nn_euclidean_distance
        elif metric == "cosine":
            chosen = _nn_cosine_distance
        else:
            raise ValueError(
                "Invalid metric; must be either 'euclidean' or 'cosine'")
        self._metric = chosen
        self.matching_threshold = matching_threshold
        self.budget = budget
        self.samples = {}

    def partial_fit(self, features, targets, active_targets):
        """
        Update the distance metric with new data.

        Args:
            features (ndarray): An NxM matrix of N features of dimensionality M.
            targets (ndarray): An integer array of associated target identities.
            active_targets (List[int]): A list of targets that are currently
                present in the scene. Every entry must already have at least
                one stored sample, otherwise a KeyError is raised.
        """
        for feature, target in zip(features, targets):
            history = self.samples.setdefault(target, [])
            history.append(feature)
            if self.budget is not None:
                # Keep only the newest `budget` samples.
                self.samples[target] = history[-self.budget:]
        # Forget every target that is no longer active.
        self.samples = {
            target: self.samples[target]
            for target in active_targets
        }

    def distance(self, features, targets):
        """
        Compute distance between features and targets.

        Args:
            features (ndarray): An NxM matrix of N features of dimensionality M.
            targets (list[int]): A list of targets to match the given
                `features` against.

        Returns:
            cost_matrix (ndarray): A cost matrix of shape
                (len(targets), len(features)), where element (i, j) is the
                smallest distance, under the configured metric, between the
                samples of `targets[i]` and `features[j]`.
        """
        cost_matrix = np.zeros((len(targets), len(features)))
        for row, target in enumerate(targets):
            stored = self.samples[target]
            cost_matrix[row, :] = self._metric(stored, features)
        return cost_matrix
|
||||
|
||||
|
||||
def min_cost_matching(distance_metric,
                      max_distance,
                      tracks,
                      detections,
                      track_indices=None,
                      detection_indices=None):
    """
    Solve linear assignment problem.

    Args:
        distance_metric :
            Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
            The distance metric is given a list of tracks and detections as
            well as a list of N track indices and M detection indices. The
            metric should return the NxM dimensional cost matrix, where element
            (i, j) is the association cost between the i-th track in the given
            track indices and the j-th detection in the given detection_indices.
            The returned matrix is modified in place (costs are capped).
        max_distance (float): Gating threshold. Associations with cost larger
            than this value are disregarded.
        tracks (list[Track]): A list of predicted tracks at the current time
            step.
        detections (list[Detection]): A list of detections at the current time
            step.
        track_indices (list[int]): List of track indices that maps rows in
            `cost_matrix` to tracks in `tracks`. Defaults to all tracks.
        detection_indices (List[int]): List of detection indices that maps
            columns in `cost_matrix` to detections in `detections`. Defaults
            to all detections.

    Returns:
        A tuple (List[(int, int)], List[int], List[int]) with the following
        three entries:
            * A list of matched track and detection indices.
            * A list of unmatched track indices.
            * A list of unmatched detection indices.
        When there is nothing to match, the index sequences are returned
        as given (or as the default `np.arange` arrays).
    """
    if track_indices is None:
        track_indices = np.arange(len(tracks))
    if detection_indices is None:
        detection_indices = np.arange(len(detections))

    if len(detection_indices) == 0 or len(track_indices) == 0:
        return [], track_indices, detection_indices  # Nothing to match.

    cost_matrix = distance_metric(tracks, detections, track_indices,
                                  detection_indices)

    # Cap every gated cost just above the threshold so the solver treats all
    # infeasible pairs alike; they are rejected again after the assignment.
    cost_matrix[cost_matrix > max_distance] = max_distance + 1e-5
    assigned_rows, assigned_cols = linear_sum_assignment(cost_matrix)

    # Sets give O(1) membership tests; scanning the index arrays for every
    # row/column would make this step quadratic.
    row_lookup = set(assigned_rows.tolist())
    col_lookup = set(assigned_cols.tolist())

    unmatched_detections = [
        detection_idx for col, detection_idx in enumerate(detection_indices)
        if col not in col_lookup
    ]
    unmatched_tracks = [
        track_idx for row, track_idx in enumerate(track_indices)
        if row not in row_lookup
    ]

    matches = []
    for row, col in zip(assigned_rows, assigned_cols):
        track_idx = track_indices[row]
        detection_idx = detection_indices[col]
        if cost_matrix[row, col] > max_distance:
            # Assigned by the solver but beyond the gate: treat as unmatched.
            unmatched_tracks.append(track_idx)
            unmatched_detections.append(detection_idx)
        else:
            matches.append((track_idx, detection_idx))
    return matches, unmatched_tracks, unmatched_detections
|
||||
|
||||
|
||||
def matching_cascade(distance_metric,
                     max_distance,
                     cascade_depth,
                     tracks,
                     detections,
                     track_indices=None,
                     detection_indices=None):
    """
    Run matching cascade.

    Tracks are matched level by level, ordered by how many steps ago they
    were last updated; each level only competes for the detections left over
    by earlier levels.

    Args:
        distance_metric :
            Callable[List[Track], List[Detection], List[int], List[int]) -> ndarray
            The distance metric is given a list of tracks and detections as
            well as a list of N track indices and M detection indices. The
            metric should return the NxM dimensional cost matrix, where element
            (i, j) is the association cost between the i-th track in the given
            track indices and the j-th detection in the given detection_indices.
        max_distance (float): Gating threshold. Associations with cost larger
            than this value are disregarded.
        cascade_depth (int): The cascade depth, should be set to the maximum
            track age.
        tracks (list[Track]): A list of predicted tracks at the current time
            step.
        detections (list[Detection]): A list of detections at the current time
            step.
        track_indices (list[int]): List of track indices that maps rows in
            `cost_matrix` to tracks in `tracks`. Defaults to all tracks.
        detection_indices (List[int]): List of detection indices that maps
            columns in `cost_matrix` to detections in `detections`. Defaults
            to all detections.

    Returns:
        A tuple (List[(int, int)], List[int], List[int]) with the following
        three entries:
            * A list of matched track and detection indices.
            * A list of unmatched track indices.
            * A list of unmatched detection indices.
    """
    if track_indices is None:
        track_indices = list(range(len(tracks)))
    if detection_indices is None:
        detection_indices = list(range(len(detections)))

    matches = []
    unmatched_detections = detection_indices
    # `age` is the number of steps since the track's last update (1-based).
    for age in range(1, cascade_depth + 1):
        if len(unmatched_detections) == 0:  # No detections left
            break

        same_age_tracks = [
            k for k in track_indices if tracks[k].time_since_update == age
        ]
        if len(same_age_tracks) == 0:  # Nothing to match at this level
            continue

        level_matches, _, unmatched_detections = min_cost_matching(
            distance_metric, max_distance, tracks, detections,
            same_age_tracks, unmatched_detections)
        matches += level_matches

    matched_tracks = set(k for k, _ in matches)
    unmatched_tracks = list(set(track_indices) - matched_tracks)
    return matches, unmatched_tracks, unmatched_detections
|
||||
|
||||
|
||||
def gate_cost_matrix(kf,
                     cost_matrix,
                     tracks,
                     detections,
                     track_indices,
                     detection_indices,
                     gated_cost=INFTY_COST,
                     only_position=False):
    """
    Invalidate infeasible entries in cost matrix based on the state
    distributions obtained by Kalman filtering.

    Args:
        kf (object): The Kalman filter; must provide `gating_distance`.
        cost_matrix (ndarray): The NxM dimensional cost matrix, where N is the
            number of track indices and M is the number of detection indices,
            such that entry (i, j) is the association cost between
            `tracks[track_indices[i]]` and `detections[detection_indices[j]]`.
            It is modified in place.
        tracks (list[Track]): A list of predicted tracks at the current time
            step.
        detections (list[Detection]): A list of detections at the current time
            step.
        track_indices (List[int]): List of track indices that maps rows in
            `cost_matrix` to tracks in `tracks`.
        detection_indices (List[int]): List of detection indices that maps
            columns in `cost_matrix` to detections in `detections`.
        gated_cost (Optional[float]): Entries in the cost matrix corresponding
            to infeasible associations are set this value. Defaults to a very
            large value.
        only_position (Optional[bool]): If True, only the x, y position of the
            state distribution is considered during gating. Default False.

    Returns:
        cost_matrix (ndarray): The same array that was passed in, with the
            gated entries overwritten.
    """
    # Number of measurement dimensions used for the gate.
    if only_position:
        gating_dim = 2
    else:
        gating_dim = 4
    threshold = kalman_filter.chi2inv95[gating_dim]

    measurements = np.asarray(
        [detections[i].to_xyah() for i in detection_indices])
    for row, track_idx in enumerate(track_indices):
        track = tracks[track_idx]
        distances = kf.gating_distance(track.mean, track.covariance,
                                       measurements, only_position)
        infeasible = distances > threshold
        cost_matrix[row, infeasible] = gated_cost
    return cost_matrix
|
||||
163
paddle_detection/ppdet/modeling/mot/matching/jde_matching.py
Normal file
163
paddle_detection/ppdet/modeling/mot/matching/jde_matching.py
Normal file
@@ -0,0 +1,163 @@
|
||||
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is based on https://github.com/Zhongdao/Towards-Realtime-MOT/blob/master/tracker/matching.py
|
||||
"""
|
||||
|
||||
# `lap` is an optional dependency: importing this module must still succeed
# without it, so a failed import only prints a warning here.
try:
    import lap
except Exception:
    # Catch Exception rather than using a bare `except:` so that
    # KeyboardInterrupt / SystemExit are not swallowed. ImportError alone
    # would be narrower than the original best-effort behavior.
    print(
        'Warning: Unable to use JDE/FairMOT/ByteTrack, please install lap, for example: `pip install lap`, see https://github.com/gatagat/lap'
    )
|
||||
|
||||
import scipy
|
||||
import numpy as np
|
||||
from scipy.spatial.distance import cdist
|
||||
from ..motion import kalman_filter
|
||||
import warnings
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
__all__ = [
|
||||
'merge_matches',
|
||||
'linear_assignment',
|
||||
'bbox_ious',
|
||||
'iou_distance',
|
||||
'embedding_distance',
|
||||
'fuse_motion',
|
||||
]
|
||||
|
||||
|
||||
def merge_matches(m1, m2, shape):
    """
    Chain two sets of index matches through their shared middle index.

    Args:
        m1 (array-like): Pairs `(o, p)` with `o < O` and `p < P`. May be empty.
        m2 (array-like): Pairs `(p, q)` with `p < P` and `q < Q`. May be empty.
        shape (tuple): The three sizes `(O, P, Q)`.

    Returns:
        A tuple `(match, unmatched_O, unmatched_Q)`:
            * match (list[tuple]): Pairs `(o, q)` for which some `p` exists
              with `(o, p)` in `m1` and `(p, q)` in `m2`.
            * unmatched_O (tuple): Indices in `range(O)` that appear in no
              merged pair.
            * unmatched_Q (tuple): Indices in `range(Q)` that appear in no
              merged pair.
    """
    # Import the submodule explicitly: a bare `import scipy` does not
    # guarantee that `scipy.sparse` is loaded.
    import scipy.sparse

    O, P, Q = shape
    # reshape(-1, 2) keeps the column indexing below valid for empty input,
    # which would otherwise be a 1-D array.
    m1 = np.asarray(m1, dtype=np.int64).reshape(-1, 2)
    m2 = np.asarray(m2, dtype=np.int64).reshape(-1, 2)

    # Indicator matrices: M1[o, p] = 1 and M2[p, q] = 1 for every given pair.
    M1 = scipy.sparse.coo_matrix(
        (np.ones(len(m1)), (m1[:, 0], m1[:, 1])), shape=(O, P))
    M2 = scipy.sparse.coo_matrix(
        (np.ones(len(m2)), (m2[:, 0], m2[:, 1])), shape=(P, Q))

    # The matrix product is non-zero at (o, q) exactly when a linking p exists.
    mask = M1.dot(M2)
    rows, cols = mask.nonzero()
    match = list(zip(rows, cols))
    unmatched_O = tuple(set(range(O)) - set(i for i, _ in match))
    unmatched_Q = tuple(set(range(Q)) - set(j for _, j in match))

    return match, unmatched_O, unmatched_Q
|
||||
|
||||
|
||||
def linear_assignment(cost_matrix, thresh):
    """
    Solve the assignment problem with `lap.lapjv`, using a cost limit.

    Args:
        cost_matrix (ndarray): 2-D cost matrix (rows x columns).
        thresh (float): Passed to `lap.lapjv` as `cost_limit`.

    Returns:
        A tuple `(matches, unmatched_a, unmatched_b)`:
            * matches (ndarray): Integer array of shape (k, 2) holding the
              matched `[row, column]` pairs; shape (0, 2) when nothing matched.
            * unmatched_a: Row indices without an assignment.
            * unmatched_b: Column indices without an assignment.
        For an empty cost matrix the unmatched entries are tuples covering
        every row and column.

    Raises:
        RuntimeError: If `lap` cannot be imported and the cost matrix is not
            empty.
    """
    # An empty problem has a trivial answer and needs no solver, so handle it
    # before requiring the optional `lap` dependency.
    if cost_matrix.size == 0:
        return np.empty(
            (0, 2), dtype=int), tuple(range(cost_matrix.shape[0])), tuple(
                range(cost_matrix.shape[1]))

    try:
        import lap
    except Exception as e:
        # Chain the original error so the real import failure stays visible.
        raise RuntimeError(
            'Unable to use JDE/FairMOT/ByteTrack, please install lap, for example: `pip install lap`, see https://github.com/gatagat/lap'
        ) from e

    # x[i] is the column assigned to row i, y[j] the row assigned to column j;
    # negative entries are treated as "unassigned".
    _, x, y = lap.lapjv(cost_matrix, extend_cost=True, cost_limit=thresh)
    matched_rows = np.where(x >= 0)[0]
    # Stacking index arrays always yields shape (k, 2), including k == 0.
    matches = np.stack((matched_rows, x[matched_rows]), axis=1)
    unmatched_a = np.where(x < 0)[0]
    unmatched_b = np.where(y < 0)[0]
    return matches, unmatched_a, unmatched_b
|
||||
|
||||
|
||||
def bbox_ious(atlbrs, btlbrs):
    """
    Compute the pairwise IoU between two sets of corner-format boxes.

    Widths and heights are computed as `max - min + 1`.

    Args:
        atlbrs (array-like): N boxes, one per row, read as
            `(x1, y1, x2, y2)` from columns 0-3.
        btlbrs (array-like): K boxes in the same format.

    Returns:
        ious (ndarray): float32 array of shape (N, K); entry (n, k) is the IoU
            of `atlbrs[n]` and `btlbrs[k]`, or 0 when they do not overlap.
    """
    boxes = np.ascontiguousarray(atlbrs, dtype=np.float32)
    query_boxes = np.ascontiguousarray(btlbrs, dtype=np.float32)
    N = boxes.shape[0]
    K = query_boxes.shape[0]
    ious = np.zeros((N, K), dtype=boxes.dtype)
    if N * K == 0:
        return ious

    # Broadcast (N, 1) against (1, K) to evaluate every pair at once instead
    # of looping over N * K pairs in Python.
    ax1, ay1, ax2, ay2 = (boxes[:, i][:, np.newaxis] for i in range(4))
    bx1, by1, bx2, by2 = (query_boxes[:, i][np.newaxis, :] for i in range(4))

    iw = np.minimum(ax2, bx2) - np.maximum(ax1, bx1) + 1
    ih = np.minimum(ay2, by2) - np.maximum(ay1, by1) + 1
    # A pair overlaps only if both extents are positive; others keep IoU 0.
    overlapping = (iw > 0) & (ih > 0)

    intersection = iw * ih
    area_a = (ax2 - ax1 + 1) * (ay2 - ay1 + 1)
    area_b = (bx2 - bx1 + 1) * (by2 - by1 + 1)
    union = area_a + area_b - intersection

    # Divide only where the boxes overlap; `ious` is already zero elsewhere.
    np.divide(intersection, union, out=ious, where=overlapping)
    return ious
|
||||
|
||||
|
||||
def iou_distance(atracks, btracks):
    """
    Compute an IoU-based cost matrix between two lists of tracks or boxes.

    Args:
        atracks (list): Either objects exposing a `tlbr` attribute or raw
            box arrays.
        btracks (list): Same kinds of entries as `atracks`.

    Returns:
        cost_matrix (ndarray): `1 - IoU` for every (atrack, btrack) pair.
    """
    # If the first element of either list is an ndarray, both lists are
    # taken to hold raw boxes already.
    holds_raw_boxes = (
        len(atracks) > 0 and isinstance(atracks[0], np.ndarray)) or (
            len(btracks) > 0 and isinstance(btracks[0], np.ndarray))

    if holds_raw_boxes:
        boxes_a, boxes_b = atracks, btracks
    else:
        boxes_a = [track.tlbr for track in atracks]
        boxes_b = [track.tlbr for track in btracks]

    return 1 - bbox_ious(boxes_a, boxes_b)
|
||||
|
||||
|
||||
def embedding_distance(tracks, detections, metric='euclidean'):
    """
    Compute a feature-distance cost matrix between tracks and detections.

    Args:
        tracks (list): Objects exposing a `smooth_feat` feature vector.
        detections (list): Objects exposing a `curr_feat` feature vector.
        metric (str): Metric name forwarded to `scipy.spatial.distance.cdist`.

    Returns:
        cost_matrix (ndarray): Array of shape (len(tracks), len(detections))
            with negative values clamped to 0. When either list is empty, an
            all-zero float32 array of that shape is returned.
    """
    n_tracks, n_dets = len(tracks), len(detections)
    empty_cost = np.zeros((n_tracks, n_dets), dtype=np.float32)
    if empty_cost.size == 0:
        return empty_cost

    det_features = np.asarray(
        [det.curr_feat for det in detections], dtype=np.float32)
    track_features = np.asarray(
        [trk.smooth_feat for trk in tracks], dtype=np.float32)
    # Original note says the features are normalized.
    distances = cdist(track_features, det_features, metric)
    return np.maximum(0.0, distances)
|
||||
|
||||
|
||||
def fuse_motion(kf,
                cost_matrix,
                tracks,
                detections,
                only_position=False,
                lambda_=0.98):
    """
    Blend a cost matrix with Kalman-filter gating distances.

    Args:
        kf (object): Kalman filter providing `gating_distance`.
        cost_matrix (ndarray): Cost matrix with one row per track and one
            column per detection. It is modified in place.
        tracks (list): Objects exposing `mean` and `covariance`.
        detections (list): Objects exposing `to_xyah()`.
        only_position (bool): If True, gate on 2 dimensions instead of 4.
        lambda_ (float): Weight of the original cost; the gating distance is
            weighted by `1 - lambda_`.

    Returns:
        cost_matrix (ndarray): The same array, where entries beyond the
            gating threshold are infinite and the rest are the weighted blend.
    """
    if cost_matrix.size == 0:
        return cost_matrix

    gating_dim = 4
    if only_position:
        gating_dim = 2
    threshold = kalman_filter.chi2inv95[gating_dim]
    motion_weight = 1 - lambda_

    measurements = np.asarray([det.to_xyah() for det in detections])
    for row, track in enumerate(tracks):
        gating_distance = kf.gating_distance(
            track.mean,
            track.covariance,
            measurements,
            only_position,
            metric='maha')
        # Gated pairs are set to inf first and stay inf after the blend.
        too_far = gating_distance > threshold
        cost_matrix[row, too_far] = np.inf
        cost_matrix[row] = (lambda_ * cost_matrix[row] +
                            motion_weight * gating_distance)
    return cost_matrix
|
||||
165
paddle_detection/ppdet/modeling/mot/matching/ocsort_matching.py
Normal file
165
paddle_detection/ppdet/modeling/mot/matching/ocsort_matching.py
Normal file
@@ -0,0 +1,165 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This code is based on https://github.com/noahcao/OC_SORT/blob/master/trackers/ocsort_tracker/association.py
|
||||
"""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
def iou_batch(bboxes1, bboxes2):
    """
    Compute the pairwise IoU between two sets of corner-format boxes.

    Args:
        bboxes1 (ndarray): Array with one box per row; columns 0-3 are read
            as `(x1, y1, x2, y2)`.
        bboxes2 (ndarray): Array in the same layout.

    Returns:
        iou_matrix (ndarray): Array of shape (len(bboxes1), len(bboxes2)).
            No guard is applied against a zero union.
    """
    # Shapes (1, K, C) and (N, 1, C) broadcast to every (row, column) pair.
    second = np.expand_dims(bboxes2, 0)
    first = np.expand_dims(bboxes1, 1)

    left = np.maximum(first[..., 0], second[..., 0])
    top = np.maximum(first[..., 1], second[..., 1])
    right = np.minimum(first[..., 2], second[..., 2])
    bottom = np.minimum(first[..., 3], second[..., 3])

    overlap_w = np.maximum(0., right - left)
    overlap_h = np.maximum(0., bottom - top)
    intersection = overlap_w * overlap_h

    area_first = (first[..., 2] - first[..., 0]) * (first[..., 3] - first[..., 1])
    area_second = (second[..., 2] - second[..., 0]) * (second[..., 3] - second[..., 1])
    return intersection / (area_first + area_second - intersection)
|
||||
|
||||
|
||||
def speed_direction_batch(dets, tracks):
    """
    Compute unit direction vectors from every track box centre to every
    detection box centre.

    Args:
        dets (ndarray): Detections, one per row; columns 0-3 are read as
            `(x1, y1, x2, y2)`.
        tracks (ndarray): Track boxes in the same layout.

    Returns:
        (dy, dx): Two arrays of shape (len(tracks), len(dets)) with the
            normalized y and x components. A 1e-6 term in the norm avoids
            division by zero for coincident centres.
    """
    # A trailing axis makes the track centres columns, so subtracting them
    # from the detection centres broadcasts to (tracks, dets).
    tracks = tracks[..., np.newaxis]

    det_cx = (dets[:, 0] + dets[:, 2]) / 2.0
    det_cy = (dets[:, 1] + dets[:, 3]) / 2.0
    trk_cx = (tracks[:, 0] + tracks[:, 2]) / 2.0
    trk_cy = (tracks[:, 1] + tracks[:, 3]) / 2.0

    delta_x = det_cx - trk_cx
    delta_y = det_cy - trk_cy
    length = np.sqrt(delta_x**2 + delta_y**2) + 1e-6
    return delta_y / length, delta_x / length
|
||||
|
||||
|
||||
def linear_assignment(cost_matrix):
    """
    Solve the assignment problem, preferring `lap` and falling back to SciPy.

    Args:
        cost_matrix (ndarray): 2-D cost matrix (rows x columns).

    Returns:
        ndarray: Integer array of shape (k, 2) holding the matched
            `[row, column]` pairs; shape (0, 2) when there are none, so
            callers can always index `[:, 0]` and `[:, 1]`.
    """
    # Nothing to assign; return a well-shaped empty result without a solver.
    if np.asarray(cost_matrix).size == 0:
        return np.empty((0, 2), dtype=int)

    # Guard only the import, so that an error raised by the solver itself is
    # never mistaken for a missing dependency.
    try:
        import lap
    except ImportError:
        lap = None

    if lap is not None:
        _, x, y = lap.lapjv(cost_matrix, extend_cost=True)
        # x holds the column assigned to each row (negative if none); y maps
        # each column back to its row.
        pairs = [[y[i], i] for i in x if i >= 0]
    else:
        from scipy.optimize import linear_sum_assignment
        rows, cols = linear_sum_assignment(cost_matrix)
        pairs = list(zip(rows, cols))

    # reshape(-1, 2) keeps the two-column layout even for an empty pair list.
    return np.array(pairs, dtype=int).reshape(-1, 2)
|
||||
|
||||
|
||||
def associate(detections, trackers, iou_threshold, velocities, previous_obs,
              vdc_weight):
    """
    Match detections to trackers using IoU plus a direction-consistency term.

    Args:
        detections (ndarray): Detections, one per row; the last column is
            used as the detection score.
        trackers (ndarray): Tracker boxes, one per row.
        iou_threshold (float): Matches with a lower IoU are rejected.
        velocities (ndarray): Per-tracker direction; column 0 is used as the
            y component and column 1 as the x component.
        previous_obs (ndarray): Per-tracker previous observation; rows whose
            column 4 is negative contribute no direction cost (presumably
            placeholders for trackers without history; confirm with caller).
        vdc_weight (float): Weight of the direction-consistency term.

    Returns:
        A tuple `(matches, unmatched_detections, unmatched_trackers)`:
            * matches (ndarray): Shape (k, 2) `[detection, tracker]` pairs.
            * unmatched_detections (ndarray): Detection indices.
            * unmatched_trackers (ndarray): Tracker indices. With no
              trackers this is an empty array of shape (0, 5).
    """
    if len(trackers) == 0:
        no_matches = np.empty((0, 2), dtype=int)
        every_detection = np.arange(len(detections))
        no_trackers = np.empty((0, 5), dtype=int)
        return no_matches, every_detection, no_trackers

    # Direction from each previous observation to each detection,
    # shape (trackers, detections).
    Y, X = speed_direction_batch(detections, previous_obs)
    inertia_Y = np.repeat(velocities[:, 0][:, np.newaxis], Y.shape[1], axis=1)
    inertia_X = np.repeat(velocities[:, 1][:, np.newaxis], X.shape[1], axis=1)

    # Angle between the tracker's direction and the candidate direction,
    # mapped so that aligned directions score highest.
    cos_angle = np.clip(inertia_X * X + inertia_Y * Y, a_min=-1, a_max=1)
    angle = np.arccos(cos_angle)
    direction_score = (np.pi / 2.0 - np.abs(angle)) / np.pi

    has_history = np.ones(previous_obs.shape[0])
    has_history[np.where(previous_obs[:, 4] < 0)] = 0

    iou_matrix = iou_batch(detections, trackers)
    det_scores = np.repeat(
        detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1)
    has_history = np.repeat(has_history[:, np.newaxis], X.shape[1], axis=1)

    # Transpose to (detections, trackers) to line up with `iou_matrix`.
    angle_diff_cost = ((has_history * direction_score) * vdc_weight).T
    angle_diff_cost = angle_diff_cost * det_scores

    if min(iou_matrix.shape) > 0:
        above = (iou_matrix > iou_threshold).astype(np.int32)
        if above.sum(1).max() == 1 and above.sum(0).max() == 1:
            # Thresholding alone already gives a one-to-one matching.
            matched_indices = np.stack(np.where(above), axis=1)
        else:
            matched_indices = linear_assignment(-(iou_matrix + angle_diff_cost))
    else:
        matched_indices = np.empty(shape=(0, 2))

    unmatched_detections = [
        d for d in range(len(detections)) if d not in matched_indices[:, 0]
    ]
    unmatched_trackers = [
        t for t in range(len(trackers)) if t not in matched_indices[:, 1]
    ]

    # Filter out matches with low IoU.
    matches = []
    for pair in matched_indices:
        det_idx, trk_idx = pair[0], pair[1]
        if iou_matrix[det_idx, trk_idx] < iou_threshold:
            unmatched_detections.append(det_idx)
            unmatched_trackers.append(trk_idx)
        else:
            matches.append(pair.reshape(1, 2))

    if matches:
        matches = np.concatenate(matches, axis=0)
    else:
        matches = np.empty((0, 2), dtype=int)

    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
|
||||
|
||||
|
||||
def associate_only_iou(detections, trackers, iou_threshold):
    """
    Match detections to trackers using IoU only.

    Args:
        detections (ndarray): Detections, one per row.
        trackers (ndarray): Tracker boxes, one per row.
        iou_threshold (float): Matches with a lower IoU are rejected.

    Returns:
        A tuple `(matches, unmatched_detections, unmatched_trackers)`:
            * matches (ndarray): Shape (k, 2) `[detection, tracker]` pairs.
            * unmatched_detections (ndarray): Detection indices.
            * unmatched_trackers (ndarray): Tracker indices. With no
              trackers this is an empty array of shape (0, 5).
    """
    if len(trackers) == 0:
        no_matches = np.empty((0, 2), dtype=int)
        every_detection = np.arange(len(detections))
        no_trackers = np.empty((0, 5), dtype=int)
        return no_matches, every_detection, no_trackers

    iou_matrix = iou_batch(detections, trackers)

    if min(iou_matrix.shape) > 0:
        above = (iou_matrix > iou_threshold).astype(np.int32)
        if above.sum(1).max() == 1 and above.sum(0).max() == 1:
            # Thresholding alone already gives a one-to-one matching.
            matched_indices = np.stack(np.where(above), axis=1)
        else:
            matched_indices = linear_assignment(-iou_matrix)
    else:
        matched_indices = np.empty(shape=(0, 2))

    unmatched_detections = [
        d for d in range(len(detections)) if d not in matched_indices[:, 0]
    ]
    unmatched_trackers = [
        t for t in range(len(trackers)) if t not in matched_indices[:, 1]
    ]

    # Filter out matches with low IoU.
    matches = []
    for pair in matched_indices:
        det_idx, trk_idx = pair[0], pair[1]
        if iou_matrix[det_idx, trk_idx] < iou_threshold:
            unmatched_detections.append(det_idx)
            unmatched_trackers.append(trk_idx)
        else:
            matches.append(pair.reshape(1, 2))

    if matches:
        matches = np.concatenate(matches, axis=0)
    else:
        matches = np.empty((0, 2), dtype=int)

    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
|
||||
Reference in New Issue
Block a user