Source code for mekf

import numpy as np
import matplotlib.pyplot as plt
import utils as u


[docs] class MEKF: """ Class wrapper around MEKF implementation. """
[docs] def __init__(self, inputs: dict | None = None) -> None: if inputs is None: inputs = {} # Constants self.I3 = np.eye(3) self.O3 = np.zeros((3, 3)) self.pi = np.pi self.arcsec_to_rad = self.pi / (180 * 3600) self.degh_to_rads = self.pi / (180 * 3600) # Simulation and sensor parameters (defaults taken from ekf.py) self.t_max = inputs.get("t_max", 1000) self.dt = inputs.get("dt", 0.01) self.sigma_startracker = inputs.get("sigma_startracker", 6) self.sigma_v = inputs.get("sigma_v", (10 ** 0.5) * 1e-6) self.sigma_u = inputs.get("sigma_u", (10 ** 0.5) * 1e-9) self.freq_startracker = inputs.get("freq_startracker", 1) self.freq_gyro = inputs.get("freq_gyro", 20) self.init_inaccuracy = inputs.get("init_inaccuracy", 30) self.rng_seed = inputs.get("rng_seed", 1) # Algorithm switches self.Joseph = inputs.get("Joseph", True) self.simple_Phi = inputs.get("simple_Phi", False) # Initial estimate covariances and bias self.B_h_0 = np.asarray(inputs.get("B_h_0", np.array([0, 0, 0])), dtype=float) self.Pq = np.asarray( inputs.get( "Pq", (self.init_inaccuracy * self.arcsec_to_rad) ** 2 * self.I3, ), dtype=float, ) self.Pb = np.asarray( inputs.get("Pb", (0.2 * self.degh_to_rads) ** 2 * self.I3), dtype=float ) # Ground-truth initial conditions self.B_t_0 = np.asarray( inputs.get( "B_t_0", np.array([0.1, 0.1, 0.1]) * self.degh_to_rads ), dtype=float, ) self.q_t_0 = np.asarray( inputs.get("q_t_0", np.array([1, 0, 0, 1]) / 2 ** 0.5), dtype=float ) # Optional custom truth-rate function self.w_t_fun = inputs.get("w_t_fun", self._default_w_t_fun) # Results container (populated by calculate) self.results: dict | None = None
# --- Defaults --------------------------------------------------------- def _default_w_t_fun(self, t: np.ndarray) -> np.ndarray: w1 = 0.1 * np.sin(0.01 * t) * self.pi / 180 w2 = 0.1 * np.sin(0.0085 * t) * self.pi / 180 w3 = 0.1 * np.cos(0.0085 * t) * self.pi / 180 return np.vstack((w1, w2, w3)) # --- Core computation --------------------------------------------------
[docs] def calculate(self) -> dict: """ Run the MEKF simulation and return a dictionary with logs and results. The same information is also stored on self.results. """ I3 = self.I3 O3 = self.O3 arcsec_to_rad = self.arcsec_to_rad # Measurement model matrices H = np.hstack((I3, O3)) R = I3 * (self.sigma_startracker * arcsec_to_rad) ** 2 # Truth time grid and measurement indices times = np.arange(0, self.t_max, self.dt) idx_gyro = u.measurement_indices(self.t_max, self.dt, self.freq_gyro) idx_star = u.measurement_indices(self.t_max, self.dt, self.freq_startracker) idx_all = idx_gyro | idx_star timesteps = len(idx_all) n = 3 w_t_l = self.w_t_fun(times) np.random.seed(self.rng_seed) # Initial attitude estimate from a noisy measurement Z_n = np.random.normal( 0, self.sigma_startracker * arcsec_to_rad * self.init_inaccuracy, n ).reshape(-1, 1) q_m_0 = self.q_t_0.reshape(-1, 1) + 0.5 * u.Xi(self.q_t_0) @ Z_n q_m_0 = q_m_0.flatten() / np.linalg.norm(q_m_0) q_h_0 = q_m_0 # Allocate logs s_l = np.empty((6, timesteps)) q_h_l = np.empty((4, timesteps)) B_h_l = np.empty((3, timesteps)) q_t_l = np.empty((4, timesteps)) Z_d_l = np.empty((3, timesteps)) B_t_l = np.empty((3, timesteps)) t_l = np.empty(timesteps) G_l = np.empty(timesteps) # Initial states q_t = self.q_t_0.copy() B_t = self.B_t_0.copy() B_h = self.B_h_0.copy() q_h = q_h_0.copy() q_d = u.quat_mul(q_t, u.quat_inv(q_h)) Z_d = u.quat_to_rotvec(q_d) G = np.linalg.norm(Z_d) P = np.block([[self.Pq, O3], [O3, self.Pb]]) # Optional initial log at t=0 k = 0 if 0 in idx_all and k < timesteps: s = np.sqrt(np.diag(P)) s_l[:, k] = s t_l[k] = times[0] q_h_l[:, k] = q_h q_t_l[:, k] = q_t Z_d_l[:, k] = Z_d B_h_l[:, k] = B_h.flatten() B_t_l[:, k] = B_t k += 1 last_gyro_i = 0 for i in range(1, len(times)): # Propagate ground truth w_t = w_t_l[:, i - 1] q_t = u.quat_propagate(q_t, w_t, self.dt) B_t = B_t + np.random.normal(0, self.sigma_u * self.dt**0.5, n) # Prediction on gyro event if i in idx_gyro: dt_g = times[i] - times[last_gyro_i] 
if dt_g <= 0: dt_g = self.dt w_t_meas = w_t_l[:, i] if i < w_t_l.shape[1] else w_t_l[:, -1] w_m = w_t_meas + B_t + np.random.standard_normal(n) * ( self.sigma_v / np.sqrt(dt_g) ) w_h = w_m - B_h Phi = u.Phi(dt_g, w_h, I3, self.simple_Phi) Qk = u.Q(self.sigma_v, self.sigma_u, dt_g, I3) P = u.P_prop(P, Phi, Qk) q_h = u.quat_propagate(q_h, w_h, dt_g) last_gyro_i = i # Update on star tracker event if i in idx_star: dZ_m = u.startracker_meas( q_t, q_h, self.sigma_startracker * arcsec_to_rad, n ) K, K_Z, K_B = u.K(P, H, R) P = u.P_meas(K, H, P, R, self.Joseph) dB_h = K_B @ dZ_m dZ_h = K_Z @ dZ_m B_h = B_h + dB_h theta = np.linalg.norm(dZ_h) if theta > 0: axis = dZ_h / theta dq_err = np.hstack((axis * np.sin(0.5 * theta), np.cos(0.5 * theta))) else: dq_err = np.array([0.0, 0.0, 0.0, 1.0]) q_h = u.quat_mul(dq_err, q_h) q_h = q_h / np.linalg.norm(q_h) q_d = u.quat_mul(q_t, u.quat_inv(q_h)) Z_d = u.quat_to_rotvec(q_d) G = np.linalg.norm(Z_d) # Log at measurement events if i in idx_all and k < timesteps: s = np.sqrt(np.diag(P)) s_l[:, k] = s t_l[k] = times[i] G_l[k] = G q_h_l[:, k] = q_h q_t_l[:, k] = q_t Z_d_l[:, k] = Z_d B_h_l[:, k] = B_h.flatten() B_t_l[:, k] = B_t k += 1 B_d = B_t_l - B_h_l self.results = { "t": t_l, "G": G_l, "q_h": q_h_l, "q_t": q_t_l, "Z_d": Z_d_l, "B_h": B_h_l, "B_t": B_t_l, "B_d": B_d, "s": s_l, } return self.results
# --- Plotting helpers -------------------------------------------------- def _require_results(self) -> dict: if self.results is None: raise RuntimeError("No results found. Call calculate() first.") return self.results
[docs] def plot_pointing_error(self) -> None: r = self._require_results() plt.figure(figsize=(10, 6)) plt.plot(r["t"], r["G"]) plt.title("Total pointing error") plt.ylabel("Error (rad)") plt.xlabel("Time (s)") plt.grid(True)
[docs] def plot_bias(self) -> None: r = self._require_results() plt.figure(figsize=(10, 6)) plt.plot(r["t"], r["B_t"][0, :]) plt.plot(r["t"], r["B_h"][0, :]) plt.title("Bias and Estimated Bias (X component)") plt.ylabel("Bias (rad/s)") plt.xlabel("Time (s)") plt.grid(True)
[docs] def plot_attitude(self) -> None: r = self._require_results() fig, axs = plt.subplots(2, 2, figsize=(18, 10), sharex=True) axs[0, 0].plot(r["t"], r["q_t"][0, :]) axs[0, 0].plot(r["t"], r["q_h"][0, :]) axs[0, 1].plot(r["t"], r["q_t"][1, :]) axs[0, 1].plot(r["t"], r["q_h"][1, :]) axs[1, 0].plot(r["t"], r["q_t"][2, :]) axs[1, 0].plot(r["t"], r["q_h"][2, :]) axs[1, 1].plot(r["t"], r["q_t"][3, :]) axs[1, 1].plot(r["t"], r["q_h"][3, :]) axs[0, 0].set_title("X-component") axs[0, 1].set_title("Y-component") axs[1, 0].set_title("Z-component") axs[1, 1].set_title("W-component") for i in range(2): for j in range(2): axs[i, j].set_ylabel("Attitude (rad)") axs[i, j].legend(loc="upper right") axs[i, j].set_xlabel("Time (s)") axs[i, j].grid(True) plt.tight_layout()
[docs] def plot_errors_with_bounds(self) -> None: r = self._require_results() fig, axs = plt.subplots(2, 3, figsize=(18, 10), sharex=True) fig.suptitle( "Attitude and gyro bias estimation errors with 3-sigma bounds", fontsize=16, ) component = ["X", "Y", "Z"] for i in range(3): ax = axs[0, i] ax.plot(r["t"], r["Z_d"][i, :], "b", label=f"Error Axis {i+1}") ax.plot(r["t"], 3 * r["s"][i, :], "r--", label="3-sigma") ax.plot(r["t"], -3 * r["s"][i, :], "r--") ax.set_title(f"Attitude error - Component {component[i]}") ax.set_ylabel("Error (rad)") ax.grid(True) ax.legend() for i in range(3): ax = axs[1, i] ax.plot(r["t"], r["B_d"][i, :], "b", label=f"Error Axis {i+1}") ax.plot(r["t"], 3 * r["s"][i + 3, :], "r--", label="3-sigma") ax.plot(r["t"], -3 * r["s"][i + 3, :], "r--") ax.set_title(f"Gyro bias error - Component {component[i]}") ax.set_ylabel("Error (rad/s)") ax.grid(True) ax.legend() for ax in axs.flat: ax.set_xlabel("Time (s)") plt.tight_layout(rect=[0, 0.03, 1, 0.95])
[docs] def plot_all(self) -> None: self.plot_pointing_error() self.plot_bias() self.plot_attitude() self.plot_errors_with_bounds() plt.show()
if __name__ == "__main__":
    # Script entry point: simulate with the default settings, then show
    # every diagnostic figure.
    demo_filter = MEKF()
    demo_filter.calculate()
    demo_filter.plot_all()