更换文档检测模型

This commit is contained in:
2024-08-27 14:42:45 +08:00
parent aea6f19951
commit 1514e09c40
2072 changed files with 254336 additions and 4967 deletions

View File

@@ -0,0 +1,18 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import kalman_filter
from .kalman_filter import *
from .gmc import *

View File

@@ -0,0 +1,368 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/WWangYuHsiang/SMILEtrack/blob/main/BoT-SORT/tracker/gmc.py
"""
import cv2
import matplotlib.pyplot as plt
import numpy as np
import copy
import time
from ppdet.core.workspace import register, serializable
@register
@serializable
class GMC:
def __init__(self, method='sparseOptFlow', downscale=2, verbose=None):
super(GMC, self).__init__()
self.method = method
self.downscale = max(1, int(downscale))
if self.method == 'orb':
self.detector = cv2.FastFeatureDetector_create(20)
self.extractor = cv2.ORB_create()
self.matcher = cv2.BFMatcher(cv2.NORM_HAMMING)
elif self.method == 'sift':
self.detector = cv2.SIFT_create(
nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
self.extractor = cv2.SIFT_create(
nOctaveLayers=3, contrastThreshold=0.02, edgeThreshold=20)
self.matcher = cv2.BFMatcher(cv2.NORM_L2)
elif self.method == 'ecc':
number_of_iterations = 5000
termination_eps = 1e-6
self.warp_mode = cv2.MOTION_EUCLIDEAN
self.criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT,
number_of_iterations, termination_eps)
elif self.method == 'sparseOptFlow':
self.feature_params = dict(
maxCorners=1000,
qualityLevel=0.01,
minDistance=1,
blockSize=3,
useHarrisDetector=False,
k=0.04)
# self.gmc_file = open('GMC_results.txt', 'w')
elif self.method == 'file' or self.method == 'files':
seqName = verbose[0]
ablation = verbose[1]
if ablation:
filePath = r'tracker/GMC_files/MOT17_ablation'
else:
filePath = r'tracker/GMC_files/MOTChallenge'
if '-FRCNN' in seqName:
seqName = seqName[:-6]
elif '-DPM' in seqName:
seqName = seqName[:-4]
elif '-SDP' in seqName:
seqName = seqName[:-4]
self.gmcFile = open(filePath + "/GMC-" + seqName + ".txt", 'r')
if self.gmcFile is None:
raise ValueError("Error: Unable to open GMC file in directory:"
+ filePath)
elif self.method == 'none' or self.method == 'None':
self.method = 'none'
else:
raise ValueError("Error: Unknown CMC method:" + method)
self.prevFrame = None
self.prevKeyPoints = None
self.prevDescriptors = None
self.initializedFirstFrame = False
def apply(self, raw_frame, detections=None):
if self.method == 'orb' or self.method == 'sift':
return self.applyFeaures(raw_frame, detections)
elif self.method == 'ecc':
return self.applyEcc(raw_frame, detections)
elif self.method == 'sparseOptFlow':
return self.applySparseOptFlow(raw_frame, detections)
elif self.method == 'file':
return self.applyFile(raw_frame, detections)
elif self.method == 'none':
return np.eye(2, 3)
else:
return np.eye(2, 3)
def applyEcc(self, raw_frame, detections=None):
# Initialize
height, width, _ = raw_frame.shape
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
H = np.eye(2, 3, dtype=np.float32)
# Downscale image (TODO: consider using pyramids)
if self.downscale > 1.0:
frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
frame = cv2.resize(frame, (width // self.downscale,
height // self.downscale))
width = width // self.downscale
height = height // self.downscale
# Handle first frame
if not self.initializedFirstFrame:
# Initialize data
self.prevFrame = frame.copy()
# Initialization done
self.initializedFirstFrame = True
return H
# Run the ECC algorithm. The results are stored in warp_matrix.
# (cc, H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode, self.criteria)
try:
(cc,
H) = cv2.findTransformECC(self.prevFrame, frame, H, self.warp_mode,
self.criteria, None, 1)
except:
print('Warning: find transform failed. Set warp as identity')
return H
def applyFeaures(self, raw_frame, detections=None):
# Initialize
height, width, _ = raw_frame.shape
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
H = np.eye(2, 3)
# Downscale image (TODO: consider using pyramids)
if self.downscale > 1.0:
# frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
frame = cv2.resize(frame, (width // self.downscale,
height // self.downscale))
width = width // self.downscale
height = height // self.downscale
# find the keypoints
mask = np.zeros_like(frame)
# mask[int(0.05 * height): int(0.95 * height), int(0.05 * width): int(0.95 * width)] = 255
mask[int(0.02 * height):int(0.98 * height), int(0.02 * width):int(
0.98 * width)] = 255
if detections is not None:
for det in detections:
tlbr = (det[:4] / self.downscale).astype(np.int_)
mask[tlbr[1]:tlbr[3], tlbr[0]:tlbr[2]] = 0
keypoints = self.detector.detect(frame, mask)
# compute the descriptors
keypoints, descriptors = self.extractor.compute(frame, keypoints)
# Handle first frame
if not self.initializedFirstFrame:
# Initialize data
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
self.prevDescriptors = copy.copy(descriptors)
# Initialization done
self.initializedFirstFrame = True
return H
# Match descriptors.
knnMatches = self.matcher.knnMatch(self.prevDescriptors, descriptors, 2)
# Filtered matches based on smallest spatial distance
matches = []
spatialDistances = []
maxSpatialDistance = 0.25 * np.array([width, height])
# Handle empty matches case
if len(knnMatches) == 0:
# Store to next iteration
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
self.prevDescriptors = copy.copy(descriptors)
return H
for m, n in knnMatches:
if m.distance < 0.9 * n.distance:
prevKeyPointLocation = self.prevKeyPoints[m.queryIdx].pt
currKeyPointLocation = keypoints[m.trainIdx].pt
spatialDistance = (
prevKeyPointLocation[0] - currKeyPointLocation[0],
prevKeyPointLocation[1] - currKeyPointLocation[1])
if (np.abs(spatialDistance[0]) < maxSpatialDistance[0]) and \
(np.abs(spatialDistance[1]) < maxSpatialDistance[1]):
spatialDistances.append(spatialDistance)
matches.append(m)
meanSpatialDistances = np.mean(spatialDistances, 0)
stdSpatialDistances = np.std(spatialDistances, 0)
inliesrs = (spatialDistances - meanSpatialDistances
) < 2.5 * stdSpatialDistances
goodMatches = []
prevPoints = []
currPoints = []
for i in range(len(matches)):
if inliesrs[i, 0] and inliesrs[i, 1]:
goodMatches.append(matches[i])
prevPoints.append(self.prevKeyPoints[matches[i].queryIdx].pt)
currPoints.append(keypoints[matches[i].trainIdx].pt)
prevPoints = np.array(prevPoints)
currPoints = np.array(currPoints)
# Draw the keypoint matches on the output image
if 0:
matches_img = np.hstack((self.prevFrame, frame))
matches_img = cv2.cvtColor(matches_img, cv2.COLOR_GRAY2BGR)
W = np.size(self.prevFrame, 1)
for m in goodMatches:
prev_pt = np.array(
self.prevKeyPoints[m.queryIdx].pt, dtype=np.int_)
curr_pt = np.array(keypoints[m.trainIdx].pt, dtype=np.int_)
curr_pt[0] += W
color = np.random.randint(0, 255, (3, ))
color = (int(color[0]), int(color[1]), int(color[2]))
matches_img = cv2.line(matches_img, prev_pt, curr_pt,
tuple(color), 1, cv2.LINE_AA)
matches_img = cv2.circle(matches_img, prev_pt, 2,
tuple(color), -1)
matches_img = cv2.circle(matches_img, curr_pt, 2,
tuple(color), -1)
plt.figure()
plt.imshow(matches_img)
plt.show()
# Find rigid matrix
if (np.size(prevPoints, 0) > 4) and (
np.size(prevPoints, 0) == np.size(prevPoints, 0)):
H, inliesrs = cv2.estimateAffinePartial2D(prevPoints, currPoints,
cv2.RANSAC)
# Handle downscale
if self.downscale > 1.0:
H[0, 2] *= self.downscale
H[1, 2] *= self.downscale
else:
print('Warning: not enough matching points')
# Store to next iteration
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
self.prevDescriptors = copy.copy(descriptors)
return H
def applySparseOptFlow(self, raw_frame, detections=None):
t0 = time.time()
# Initialize
height, width, _ = raw_frame.shape
frame = cv2.cvtColor(raw_frame, cv2.COLOR_BGR2GRAY)
H = np.eye(2, 3)
# Downscale image
if self.downscale > 1.0:
# frame = cv2.GaussianBlur(frame, (3, 3), 1.5)
frame = cv2.resize(frame, (width // self.downscale,
height // self.downscale))
# find the keypoints
keypoints = cv2.goodFeaturesToTrack(
frame, mask=None, **self.feature_params)
# Handle first frame
if not self.initializedFirstFrame:
# Initialize data
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
# Initialization done
self.initializedFirstFrame = True
return H
if self.prevFrame.shape != frame.shape:
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
return H
# find correspondences
matchedKeypoints, status, err = cv2.calcOpticalFlowPyrLK(
self.prevFrame, frame, self.prevKeyPoints, None)
# leave good correspondences only
prevPoints = []
currPoints = []
for i in range(len(status)):
if status[i]:
prevPoints.append(self.prevKeyPoints[i])
currPoints.append(matchedKeypoints[i])
prevPoints = np.array(prevPoints)
currPoints = np.array(currPoints)
# Find rigid matrix
if (np.size(prevPoints, 0) > 4) and (
np.size(prevPoints, 0) == np.size(prevPoints, 0)):
H, inliesrs = cv2.estimateAffinePartial2D(prevPoints, currPoints,
cv2.RANSAC)
# Handle downscale
if self.downscale > 1.0:
H[0, 2] *= self.downscale
H[1, 2] *= self.downscale
else:
print('Warning: not enough matching points')
# Store to next iteration
self.prevFrame = frame.copy()
self.prevKeyPoints = copy.copy(keypoints)
t1 = time.time()
# gmc_line = str(1000 * (t1 - t0)) + "\t" + str(H[0, 0]) + "\t" + str(H[0, 1]) + "\t" + str(
# H[0, 2]) + "\t" + str(H[1, 0]) + "\t" + str(H[1, 1]) + "\t" + str(H[1, 2]) + "\n"
# self.gmc_file.write(gmc_line)
return H
def applyFile(self, raw_frame, detections=None):
line = self.gmcFile.readline()
tokens = line.split("\t")
H = np.eye(2, 3, dtype=np.float_)
H[0, 0] = float(tokens[1])
H[0, 1] = float(tokens[2])
H[0, 2] = float(tokens[3])
H[1, 0] = float(tokens[4])
H[1, 1] = float(tokens[5])
H[1, 2] = float(tokens[6])
return H

View File

@@ -0,0 +1,316 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/nwojke/deep_sort/blob/master/deep_sort/kalman_filter.py
"""
import numpy as np
import scipy.linalg
use_numba = True
try:
import numba as nb
@nb.njit(fastmath=True, cache=True)
def nb_project(mean, covariance, std, _update_mat):
innovation_cov = np.diag(np.square(std))
mean = np.dot(_update_mat, mean)
covariance = np.dot(np.dot(_update_mat, covariance), _update_mat.T)
return mean, covariance + innovation_cov
@nb.njit(fastmath=True, cache=True)
def nb_multi_predict(mean, covariance, motion_cov, motion_mat):
mean = np.dot(mean, motion_mat.T)
left = np.dot(motion_mat, covariance)
covariance = np.dot(left, motion_mat.T) + motion_cov
return mean, covariance
@nb.njit(fastmath=True, cache=True)
def nb_update(mean, covariance, proj_mean, proj_cov, measurement, meas_mat):
kalman_gain = np.linalg.solve(proj_cov, (covariance @meas_mat.T).T).T
innovation = measurement - proj_mean
mean = mean + innovation @kalman_gain.T
covariance = covariance - kalman_gain @proj_cov @kalman_gain.T
return mean, covariance
except:
use_numba = False
print(
'Warning: Unable to use numba in PP-Tracking, please install numba, for example(python3.7): `pip install numba==0.56.4`'
)
pass
__all__ = ['KalmanFilter']
"""
Table for the 0.95 quantile of the chi-square distribution with N degrees of
freedom (contains values for N=1, ..., 9). Taken from MATLAB/Octave's chi2inv
function and used as Mahalanobis gating threshold.
"""
chi2inv95 = {
1: 3.8415,
2: 5.9915,
3: 7.8147,
4: 9.4877,
5: 11.070,
6: 12.592,
7: 14.067,
8: 15.507,
9: 16.919
}
class KalmanFilter(object):
"""
A simple Kalman filter for tracking bounding boxes in image space.
The 8-dimensional state space
x, y, a, h, vx, vy, va, vh
contains the bounding box center position (x, y), aspect ratio a, height h,
and their respective velocities.
Object motion follows a constant velocity model. The bounding box location
(x, y, a, h) is taken as direct observation of the state space (linear
observation model).
"""
def __init__(self):
ndim, dt = 4, 1.
# Create Kalman filter model matrices.
self._motion_mat = np.eye(2 * ndim, 2 * ndim, dtype=np.float32)
for i in range(ndim):
self._motion_mat[i, ndim + i] = dt
self._update_mat = np.eye(ndim, 2 * ndim, dtype=np.float32)
# Motion and observation uncertainty are chosen relative to the current
# state estimate. These weights control the amount of uncertainty in
# the model. This is a bit hacky.
self._std_weight_position = 1. / 20
self._std_weight_velocity = 1. / 160
def initiate(self, measurement):
"""
Create track from unassociated measurement.
Args:
measurement (ndarray): Bounding box coordinates (x, y, a, h) with
center position (x, y), aspect ratio a, and height h.
Returns:
The mean vector (8 dimensional) and covariance matrix (8x8
dimensional) of the new track. Unobserved velocities are
initialized to 0 mean.
"""
mean_pos = measurement
mean_vel = np.zeros_like(mean_pos)
mean = np.r_[mean_pos, mean_vel]
std = [
2 * self._std_weight_position * measurement[3],
2 * self._std_weight_position * measurement[3], 1e-2,
2 * self._std_weight_position * measurement[3],
10 * self._std_weight_velocity * measurement[3],
10 * self._std_weight_velocity * measurement[3], 1e-5,
10 * self._std_weight_velocity * measurement[3]
]
covariance = np.diag(np.square(std))
return mean, np.float32(covariance)
def predict(self, mean, covariance):
"""
Run Kalman filter prediction step.
Args:
mean (ndarray): The 8 dimensional mean vector of the object state
at the previous time step.
covariance (ndarray): The 8x8 dimensional covariance matrix of the
object state at the previous time step.
Returns:
The mean vector and covariance matrix of the predicted state.
Unobserved velocities are initialized to 0 mean.
"""
std_pos = [
self._std_weight_position * mean[3], self._std_weight_position *
mean[3], 1e-2, self._std_weight_position * mean[3]
]
std_vel = [
self._std_weight_velocity * mean[3], self._std_weight_velocity *
mean[3], 1e-5, self._std_weight_velocity * mean[3]
]
motion_cov = np.diag(np.square(np.r_[std_pos, std_vel]))
#mean = np.dot(self._motion_mat, mean)
mean = np.dot(mean, self._motion_mat.T)
covariance = np.linalg.multi_dot(
(self._motion_mat, covariance, self._motion_mat.T)) + motion_cov
return mean, covariance
def project(self, mean, covariance):
"""
Project state distribution to measurement space.
Args
mean (ndarray): The state's mean vector (8 dimensional array).
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
Returns:
The projected mean and covariance matrix of the given state estimate.
"""
std = np.array(
[
self._std_weight_position * mean[3], self._std_weight_position *
mean[3], 1e-1, self._std_weight_position * mean[3]
],
dtype=np.float32)
if use_numba:
return nb_project(mean, covariance, std, self._update_mat)
innovation_cov = np.diag(np.square(std))
mean = np.dot(self._update_mat, mean)
covariance = np.linalg.multi_dot((self._update_mat, covariance,
self._update_mat.T))
return mean, covariance + innovation_cov
def multi_predict(self, mean, covariance):
"""
Run Kalman filter prediction step (Vectorized version).
Args:
mean (ndarray): The Nx8 dimensional mean matrix of the object states
at the previous time step.
covariance (ndarray): The Nx8x8 dimensional covariance matrics of the
object states at the previous time step.
Returns:
The mean vector and covariance matrix of the predicted state.
Unobserved velocities are initialized to 0 mean.
"""
std_pos = np.array([
self._std_weight_position * mean[:, 3], self._std_weight_position *
mean[:, 3], 1e-2 * np.ones_like(mean[:, 3]),
self._std_weight_position * mean[:, 3]
])
std_vel = np.array([
self._std_weight_velocity * mean[:, 3], self._std_weight_velocity *
mean[:, 3], 1e-5 * np.ones_like(mean[:, 3]),
self._std_weight_velocity * mean[:, 3]
])
sqr = np.square(np.r_[std_pos, std_vel]).T
if use_numba:
means = []
covariances = []
for i in range(len(mean)):
a, b = nb_multi_predict(mean[i], covariance[i],
np.diag(sqr[i]), self._motion_mat)
means.append(a)
covariances.append(b)
return np.asarray(means), np.asarray(covariances)
motion_cov = []
for i in range(len(mean)):
motion_cov.append(np.diag(sqr[i]))
motion_cov = np.asarray(motion_cov)
mean = np.dot(mean, self._motion_mat.T)
left = np.dot(self._motion_mat, covariance).transpose((1, 0, 2))
covariance = np.dot(left, self._motion_mat.T) + motion_cov
return mean, covariance
def update(self, mean, covariance, measurement):
"""
Run Kalman filter correction step.
Args:
mean (ndarray): The predicted state's mean vector (8 dimensional).
covariance (ndarray): The state's covariance matrix (8x8 dimensional).
measurement (ndarray): The 4 dimensional measurement vector
(x, y, a, h), where (x, y) is the center position, a the aspect
ratio, and h the height of the bounding box.
Returns:
The measurement-corrected state distribution.
"""
projected_mean, projected_cov = self.project(mean, covariance)
if use_numba:
return nb_update(mean, covariance, projected_mean, projected_cov,
measurement, self._update_mat)
kalman_gain = np.linalg.solve(projected_cov,
(covariance @self._update_mat.T).T).T
innovation = measurement - projected_mean
mean = mean + innovation @kalman_gain.T
covariance = covariance - kalman_gain @projected_cov @kalman_gain.T
return mean, covariance
def gating_distance(self,
mean,
covariance,
measurements,
only_position=False,
metric='maha'):
"""
Compute gating distance between state distribution and measurements.
A suitable distance threshold can be obtained from `chi2inv95`. If
`only_position` is False, the chi-square distribution has 4 degrees of
freedom, otherwise 2.
Args:
mean (ndarray): Mean vector over the state distribution (8
dimensional).
covariance (ndarray): Covariance of the state distribution (8x8
dimensional).
measurements (ndarray): An Nx4 dimensional matrix of N measurements,
each in format (x, y, a, h) where (x, y) is the bounding box center
position, a the aspect ratio, and h the height.
only_position (Optional[bool]): If True, distance computation is
done with respect to the bounding box center position only.
metric (str): Metric type, 'gaussian' or 'maha'.
Returns
An array of length N, where the i-th element contains the squared
Mahalanobis distance between (mean, covariance) and `measurements[i]`.
"""
mean, covariance = self.project(mean, covariance)
if only_position:
mean, covariance = mean[:2], covariance[:2, :2]
measurements = measurements[:, :2]
d = measurements - mean
if metric == 'gaussian':
return np.sum(d * d, axis=1)
elif metric == 'maha':
cholesky_factor = np.linalg.cholesky(covariance)
z = scipy.linalg.solve_triangular(
cholesky_factor,
d.T,
lower=True,
check_finite=False,
overwrite_b=True)
squared_maha = np.sum(z * z, axis=0)
return squared_maha
else:
raise ValueError('invalid distance metric')

View File

@@ -0,0 +1,93 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/danbochman/SORT/blob/danny_opencv/kalman_filter.py
"""
import numpy as np
from numpy import dot, zeros, eye
from numpy.linalg import inv
use_numba = True
try:
import numba as nb
@nb.njit(fastmath=True, cache=True)
def nb_predict(x, F, P, Q):
x = dot(F, x)
P = dot(dot(F, P), F.T) + Q
return x, P
@nb.njit(fastmath=True, cache=True)
def nb_update(x, z, H, P, R, _I):
y = z - np.dot(H, x)
PHT = dot(P, H.T)
S = dot(H, PHT) + R
K = dot(PHT, inv(S))
x = x + dot(K, y)
I_KH = _I - dot(K, H)
P = dot(dot(I_KH, P), I_KH.T) + dot(dot(K, R), K.T)
return x, P
except:
use_numba = False
print(
'Warning: Unable to use numba in PP-Tracking, please install numba, for example(python3.7): `pip install numba==0.56.4`'
)
pass
class OCSORTKalmanFilter:
def __init__(self, dim_x, dim_z):
self.dim_x = dim_x
self.dim_z = dim_z
self.x = zeros((dim_x, 1))
self.P = eye(dim_x)
self.Q = eye(dim_x)
self.F = eye(dim_x)
self.H = zeros((dim_z, dim_x))
self.R = eye(dim_z)
self.M = zeros((dim_z, dim_z))
self._I = eye(dim_x)
def predict(self):
if use_numba:
self.x, self.P = nb_predict(self.x, self.F, self.P, self.Q)
else:
self.x = dot(self.F, self.x)
self.P = dot(dot(self.F, self.P), self.F.T) + self.Q
def update(self, z):
if z is None:
return
if use_numba:
self.x, self.P = nb_update(self.x, z, self.H, self.P, self.R,
self._I)
else:
y = z - np.dot(self.H, self.x)
PHT = dot(self.P, self.H.T)
S = dot(self.H, PHT) + self.R
K = dot(PHT, inv(S))
self.x = self.x + dot(K, y)
I_KH = self._I - dot(K, self.H)
self.P = dot(dot(I_KH, self.P), I_KH.T) + dot(dot(K, self.R), K.T)