from __future__ import annotations
__all__ = ["Trajectory"]
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from typing import Dict, Optional, Tuple, Callable
from numpy.typing import NDArray
from ADCS.helpers.math_helpers import quat_diff, quat_to_vec3
from ADCS.helpers.simresults import SimulationResults, RunResults
from ADCS.satellite_hardware.satellite import Satellite
[docs]
class Trajectory:
    r"""
    Container for trajectory optimization results produced by the ALTRO planner.

    This class stores discrete-time trajectories of states, controls, feedback
    gains, and cost-to-go values, and provides interpolation utilities for
    real-time tracking control.

    The reference trajectory defines a nominal solution
    :math:`(\mathbf{x}^\ast(t), \mathbf{u}^\ast(t))` to a nonlinear optimal control
    problem. A time-varying LQR feedback law is applied as

    .. math::

        \mathbf{u}(t) =
        \mathbf{u}^\ast(t) - \mathbf{K}(t)\,\delta\mathbf{x}(t)

    where :math:`\delta\mathbf{x}` is a reduced-dimension error state that replaces
    the quaternion with a minimal 3D attitude error.

    The state vector is indexed as ``[w(3), q(4), h(n_rw)]`` (angular velocity,
    quaternion, reaction-wheel momentum); the quaternion is used scalar-first
    when rotating vectors for plotting.

    State and control arrays may each be time-major ``(n_samples, dim)`` or
    dimension-major ``(dim, n_samples)``. The two layouts are detected
    independently of each other by comparing array shapes with ``len(t)``.

    :param t:
        Time stamps associated with the trajectory samples.
    :param x:
        State trajectory array, either time-major or state-major.
    :param u:
        Control trajectory array, either time-major or control-major. It may
        hold one sample fewer than ``t``.
    :param K:
        Feedback gain matrices for TVLQR tracking.
    :param S:
        Cost-to-go values along the trajectory.
    :param use_disturbance_estimation:
        Flag enabling augmented error state with disturbance estimation.
    """

    # Class-level type annotations
    times: NDArray[np.float64]
    states: NDArray[np.float64]
    controls: NDArray[np.float64]
    gains: NDArray[np.float64]
    costs: NDArray[np.float64]
    start_time: float
    end_time: float
    n_steps: int
    state_dim: int
    ctrl_dim: int
    use_disturbance_estimation: bool
    _is_row_major: bool        # layout of ``states``
    _ctrl_row_major: bool      # layout of ``controls`` (detected separately)
    _dist_estimate: NDArray[np.float64]

    def __init__(
        self,
        t: NDArray[np.float64],
        x: NDArray[np.float64],
        u: NDArray[np.float64],
        K: NDArray[np.float64],
        S: NDArray[np.float64],
        use_disturbance_estimation: bool = False
    ) -> None:
        """
        Initialize trajectory from planner output.

        Args:
            t: Time array of shape (n_steps,)
            x: State array, either (n_steps, state_dim) or (state_dim, n_steps)
            u: Control array, either (n_steps-1, ctrl_dim) or (ctrl_dim, n_steps-1)
               (a full ``n_steps`` control samples is accepted as well)
            K: Feedback gains array
            S: Cost-to-go array
            use_disturbance_estimation: If True, gains have 3 extra columns for
                disturbance estimation (KwDist mode, tracking_LQR_formulation=2)

        Raises:
            ValueError: If ``t`` is empty, ``x`` or ``u`` is not 2-D, or
                neither axis of ``x`` has ``len(t)`` entries.
        """
        t = np.asarray(t).ravel()
        x = np.asarray(x)
        u = np.asarray(u)
        if t.size == 0:
            raise ValueError("t must contain at least one time stamp")
        if x.ndim != 2:
            raise ValueError(f"x must be a 2-D array, got shape {x.shape}")
        if u.ndim != 2:
            raise ValueError(f"u must be a 2-D array, got shape {u.shape}")

        self.times = t
        self.states = x
        self.controls = u
        self.gains = K
        self.costs = S
        self.use_disturbance_estimation = use_disturbance_estimation
        self.start_time = float(t[0])
        self.end_time = float(t[-1])
        self.n_steps = len(t)

        # State layout: axis 0 is tried first, so a square array is treated
        # as time-major.
        if x.shape[0] == self.n_steps:
            self.state_dim = x.shape[1]
            self._is_row_major = True
        elif x.shape[1] == self.n_steps:
            self.state_dim = x.shape[0]
            self._is_row_major = False
        else:
            raise ValueError(
                f"State shape {x.shape} does not match n_steps={self.n_steps}"
            )

        # Control layout is detected (and remembered) independently of the
        # state layout; controls may have N or N-1 samples.
        self._ctrl_row_major = u.shape[0] in (self.n_steps, self.n_steps - 1)
        self.ctrl_dim = u.shape[1] if self._ctrl_row_major else u.shape[0]

        # Disturbance estimate for KwDist mode. Always created so the
        # attribute exists even when the feature is switched off.
        self._dist_estimate = np.zeros(3)

    @staticmethod
    def _sample(arr: np.ndarray, i: int, row_major: bool) -> np.ndarray:
        """Return sample ``i`` of a 2-D trajectory array for the given layout."""
        return arr[i, :] if row_major else arr[:, i]

    def is_valid_time(self, t: float) -> bool:
        r"""
        Check whether a time lies within the trajectory bounds.

        :param t:
            Query time.
        :return:
            True if ``start_time <= t <= end_time``.
        """
        return self.start_time <= t <= self.end_time

    def get_state_at(self, t: float) -> np.ndarray:
        r"""
        Interpolate the reference state at a given time.

        Linear interpolation is applied between neighboring samples. When the
        state has at least 7 entries, entries ``3:7`` are treated as a
        quaternion: the two neighbouring quaternions are brought into the same
        hemisphere (``q`` and ``-q`` describe the same attitude), interpolated
        linearly and renormalized.

        Times outside ``[start_time, end_time]`` are not clamped; the first or
        last interval is extrapolated linearly.

        :param t:
            Query time.
        :return:
            Interpolated reference state vector (a new array, never a view of
            the stored trajectory).
        """
        idx = self._get_idx(t)
        # For a single-sample trajectory there is no "next" sample.
        nxt = min(idx + 1, self.n_steps - 1)
        x0 = self._sample(self.states, idx, self._is_row_major)
        x1 = self._sample(self.states, nxt, self._is_row_major)

        dt = self.times[nxt] - self.times[idx]
        if dt == 0:
            # Copy so callers cannot modify the stored trajectory by accident.
            return x0.copy()

        alpha = (t - self.times[idx]) / dt
        state_interp = (1 - alpha) * x0 + alpha * x1

        if self.state_dim >= 7:
            q0 = x0[3:7]
            q1 = x1[3:7]
            # Without this flip, sign-flipped neighbours would interpolate
            # through (near) zero and yield a meaningless attitude.
            if np.dot(q0, q1) < 0:
                q1 = -q1
            # Simple lerp then normalize is sufficient for small steps
            q_interp = (1 - alpha) * q0 + alpha * q1
            q_norm = np.linalg.norm(q_interp)
            state_interp[3:7] = q_interp / q_norm if q_norm > 1e-9 else q_interp
        return state_interp

    def get_control_at(self, t: float) -> np.ndarray:
        r"""
        Interpolate the reference control at a given time.

        The control trajectory is interpolated linearly between neighboring
        time samples. Sample indices past the end of the control array (it may
        hold only ``n_steps - 1`` entries) are clamped to the last control.

        :param t:
            Query time.
        :return:
            Interpolated reference control vector (a new array).
        """
        idx = self._get_idx(t)
        row_major = self._ctrl_row_major
        n_ctrl = self.controls.shape[0] if row_major else self.controls.shape[1]

        def control_sample(i: int) -> np.ndarray:
            # Clamp index for N-1 controls
            return self._sample(self.controls, min(i, n_ctrl - 1), row_major)

        u0 = control_sample(idx)
        u1 = control_sample(idx + 1)

        nxt = min(idx + 1, self.n_steps - 1)
        dt = self.times[nxt] - self.times[idx]
        if dt == 0:
            return u0.copy()
        alpha = (t - self.times[idx]) / dt
        return (1 - alpha) * u0 + alpha * u1

    def get_gain_at(self, t: float) -> np.ndarray:
        r"""
        Retrieve the feedback gain matrix at a given time.

        The gain of the time sample nearest to ``t`` is returned (no
        interpolation). Supported storage conventions:

        * 3-D with time on the first axis, ``(N, nu, nx)``;
        * 3-D with time on the last axis, ``(nu, nx, N)``;
        * 2-D flattened, one column per time sample, reshaped to
          ``(ctrl_dim, state_dim - 1)``, or ``(ctrl_dim, state_dim - 1 + 3)``
          when disturbance estimation is enabled.

        In every case the time index is clamped to the last stored gain, since
        the gain array may hold fewer samples than ``times``.

        :param t:
            Query time.
        :return:
            Feedback gain matrix corresponding to the query time.
        :raises ValueError:
            If the gain array has an unsupported shape.
        """
        gains = np.asarray(self.gains)
        idx = int(np.abs(self.times - t).argmin())

        if gains.ndim == 3:
            n_intervals = self.n_steps - 1
            if gains.shape[0] >= n_intervals:
                # Time is first axis
                return gains[min(idx, gains.shape[0] - 1), :, :]
            if gains.shape[2] >= n_intervals:
                # Time is last axis
                return gains[:, :, min(idx, gains.shape[2] - 1)]
            raise ValueError(
                f"Cannot locate the time axis of gains with shape {gains.shape} "
                f"for n_steps={self.n_steps}"
            )

        if gains.ndim != 2:
            raise ValueError(f"Unsupported gain array shape {gains.shape}")

        # Flattened gain storage. K maps the reduced error state to control:
        # (n_ctrl, n_err) with n_err = state_dim - 1, the -1 accounting for
        # the quaternion 4D -> 3D reduction. KwDist mode appends 3 columns.
        k_flat = gains[:, min(idx, gains.shape[1] - 1)]
        error_dim = self.state_dim - 1
        if self.use_disturbance_estimation:
            error_dim += 3
        return k_flat.reshape(self.ctrl_dim, error_dim)

    def compute_tracking_control(self, t: float, x_current: np.ndarray) -> np.ndarray:
        r"""
        Compute the tracking control input at a given time and state.

        The control law is

        .. math::

            \mathbf{u} =
            \mathbf{u}^\ast(t) - \mathbf{K}(t)\,\delta\mathbf{x}

        where :math:`\delta\mathbf{x}` is the reduced error state computed from
        the current and reference states. With disturbance estimation enabled
        the error state is extended by the stored 3-element disturbance
        estimate before the gain is applied.

        :param t:
            Query time.
        :param x_current:
            Current full state vector.
        :return:
            Control input computed by the tracking controller.
        :raises ValueError:
            If ``t`` lies outside the trajectory interval.
        """
        if not self.is_valid_time(t):
            raise ValueError(f"Time {t} is outside bounds")

        x_ref = self.get_state_at(t)
        u_ref = self.get_control_at(t)
        K = self.get_gain_at(t)
        dx = self._state_diff(x_current, x_ref)

        if self.use_disturbance_estimation:
            # KwDist mode: augment error state with disturbance estimate
            dx = np.concatenate([dx, self._dist_estimate])
        return u_ref - K @ dx

    def _state_diff(self, x_curr: np.ndarray, x_ref: np.ndarray) -> np.ndarray:
        r"""
        Compute the reduced error state for TVLQR feedback.

        The result stacks, in order: the angular-velocity difference
        (3 entries), a 3-entry attitude error obtained from
        ``quat_to_vec3(quat_diff(q_ref, q_curr), 5)``, and the reaction-wheel
        momentum difference (``state_dim - 7`` entries).

        NOTE(review): the original documentation describes the attitude term
        both as :math:`2\,\mathrm{vec}(q_\text{ref}^{-1} \otimes q)` and as
        "2×MRP"; which one mode ``5`` of ``quat_to_vec3`` selects should be
        confirmed against that helper.

        :param x_curr:
            Current state vector.
        :param x_ref:
            Reference state vector.
        :return:
            Reduced-dimension error state vector of length ``state_dim - 1``.
        :raises ValueError:
            If the state has fewer than 7 entries.
        """
        # Number of reaction wheels: full state = 7 + n_rw
        n_rw = self.state_dim - 7
        if n_rw < 0:
            raise ValueError(
                f"state_dim={self.state_dim} is too small for a [w, q, h] state"
            )
        dx = np.zeros(6 + n_rw)

        # 1. Angular Velocity Error (indices 0:3)
        dx[0:3] = x_curr[0:3] - x_ref[0:3]

        # 2. Attitude Error (indices 3:6)
        # Per the original author's note, quat_diff returns q_ref^(-1) * q_curr
        # and mode 5 matches the C++ planner's reduced attitude error.
        q_err = quat_diff(x_ref[3:7], x_curr[3:7])
        dx[3:6] = quat_to_vec3(q_err, 5)

        # 3. RW Momentum Error (indices 6:6+n_rw, from full state 7:7+n_rw)
        if n_rw > 0:
            dx[6:6 + n_rw] = x_curr[7:7 + n_rw] - x_ref[7:7 + n_rw]
        return dx

    def update_disturbance_estimate(self, dist_torque: np.ndarray) -> None:
        r"""
        Update the internal disturbance torque estimate.

        The first three entries of the flattened input are stored. The call is
        a no-op when disturbance estimation is disabled.

        :param dist_torque:
            Estimated disturbance torque (at least 3 entries).
        :raises ValueError:
            If disturbance estimation is enabled and fewer than 3 values are
            supplied; a shorter estimate could not be combined with the gain
            matrix later on.
        """
        if not self.use_disturbance_estimation:
            return
        flat = np.asarray(dist_torque).flatten()
        if flat.size < 3:
            raise ValueError(
                f"dist_torque must provide at least 3 values, got {flat.size}"
            )
        self._dist_estimate = flat[:3]

    def get_plotting_data(self) -> Dict[str, NDArray[np.float64]]:
        r"""
        Return trajectory data packaged for plotting.

        :return:
            Dictionary with keys ``"time"``, ``"state"``, ``"control"`` and
            ``"cost"`` referring to the stored arrays (no copies are made).
        """
        return {
            "time": self.times,
            "state": self.states,
            "control": self.controls,
            "cost": self.costs
        }

    def _get_idx(self, t: float) -> int:
        r"""
        Determine the lower index for interpolation at a given time.

        :param t:
            Query time.
        :return:
            Index ``i`` of the interval ``[times[i], times[i+1]]`` used for
            ``t``, clamped to ``[0, n_steps - 2]`` (``0`` for a single-sample
            trajectory).
        """
        last = max(self.n_steps - 2, 0)
        if t >= self.end_time:
            return last
        idx = int(np.searchsorted(self.times, t, side='right')) - 1
        return max(0, min(idx, last))

    # --- Visualization Methods ---

    def plot_eci_trajectory(self,
                            body_axis: Optional[np.ndarray] = None,
                            stride: int = 1,
                            show: bool = True) -> None:
        r"""
        Plot the trajectory of a body-fixed axis expressed in the ECI frame.

        The body axis is rotated using the quaternion at each plotted sample,
        producing a 3D curve on the unit sphere, coloured by time stamp.

        :param body_axis:
            Body-frame axis to visualize. Defaults to ``[0, 0, 1]``.
        :param stride:
            Subsampling stride for plotting.
        :param show:
            Flag indicating whether to display the plot immediately.
        :raises ValueError:
            If ``body_axis`` has zero length or the state holds no quaternion
            (fewer than 7 entries).
        """
        if body_axis is None:
            body_axis = np.array([0, 0, 1])
        axis_norm = np.linalg.norm(body_axis)
        if axis_norm == 0:
            raise ValueError("body_axis must be a non-zero vector")
        if self.state_dim < 7:
            raise ValueError(
                f"state_dim={self.state_dim} does not contain a quaternion at indices 3:7"
            )
        v_body = body_axis / axis_norm

        fig = plt.figure(figsize=(10, 8))
        ax = fig.add_subplot(111, projection='3d')

        # One quaternion per row, subsampled by ``stride``.
        if self._is_row_major:
            quats = self.states[::stride, 3:7]
        else:
            quats = self.states[3:7, ::stride].T
        times = self.times[::stride]

        def unit(q: np.ndarray) -> np.ndarray:
            # Safety normalize; near-zero quaternions are left untouched.
            norm = np.linalg.norm(q)
            return q / norm if norm > 1e-6 else q

        # Rotate body vector to ECI -> array of shape (3, N_plotted)
        v_eci = np.array([self._rotate_vector(unit(q), v_body) for q in quats]).T

        # Plot Sphere
        u, v = np.mgrid[0:2*np.pi:20j, 0:np.pi:10j]
        x_sphere = np.cos(u) * np.sin(v)
        y_sphere = np.sin(u) * np.sin(v)
        z_sphere = np.cos(v)
        ax.plot_wireframe(x_sphere, y_sphere, z_sphere, color="gray", alpha=0.15)

        # Plot Trace (``c`` must have one entry per plotted point)
        p = ax.scatter(v_eci[0, :], v_eci[1, :], v_eci[2, :], c=times, cmap='viridis', s=10, label='Trace')

        # Markers
        ax.scatter(v_eci[0, 0], v_eci[1, 0], v_eci[2, 0], color='green', s=100, marker='o', label='Start')
        ax.scatter(v_eci[0, -1], v_eci[1, -1], v_eci[2, -1], color='red', s=100, marker='X', label='End')

        ax.set_xlabel("ECI X")
        ax.set_ylabel("ECI Y")
        ax.set_zlabel("ECI Z")
        ax.set_title(f"Trajectory of Body Axis {body_axis} in ECI")
        ax.legend()
        self._set_axes_equal(ax)

        cbar = fig.colorbar(p, ax=ax, shrink=0.5, aspect=10)
        cbar.set_label('Time (J2000)')

        if show:
            plt.show()

    def _rotate_vector(self, q: np.ndarray, v: np.ndarray) -> np.ndarray:
        """Rotate 3-vector ``v`` by the scalar-first quaternion ``q``.

        Uses ``v' = v + w*t + q_vec x t`` with ``t = 2 * (q_vec x v)``, which
        avoids building a rotation matrix. ``q`` is expected to be unit norm.
        """
        q_scalar = q[0]
        q_vec = q[1:]
        t = 2 * np.cross(q_vec, v)
        return v + q_scalar * t + np.cross(q_vec, t)

    def _set_axes_equal(self, ax) -> None:
        """Give the 3D axes ``ax`` equal extents so the sphere is not distorted.

        Each axis is re-centred on its current midpoint and given the half
        range of the widest axis.
        """
        limits = (ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d())
        plot_radius = 0.5 * max(abs(hi - lo) for lo, hi in limits)
        x_middle, y_middle, z_middle = (np.mean(lim) for lim in limits)

        ax.set_xlim3d([x_middle - plot_radius, x_middle + plot_radius])
        ax.set_ylim3d([y_middle - plot_radius, y_middle + plot_radius])
        ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius])

    @staticmethod
    def _expand_history(value: np.ndarray, n: int) -> list:
        """Return ``n`` per-sample copies of ``value``.

        A 1-D ``value`` is repeated ``n`` times; otherwise rows ``0..n-1`` of
        ``value`` are copied one by one.
        """
        arr = np.asarray(value)
        if arr.ndim == 1:
            return [arr.copy() for _ in range(n)]
        return [arr[k].copy() for k in range(n)]

    def to_simulation_results(
        self,
        satellite: Satellite,
        target: Optional[np.ndarray] = None,
        w_target: Optional[np.ndarray] = None,
        boresight: Optional[np.ndarray] = None,
        *,
        include_cost: bool = False,
    ) -> SimulationResults:
        r"""
        Package the trajectory as a single-run :class:`SimulationResults`.

        State and control histories are passed time-major. Elapsed seconds are
        computed as ``(t - t[0]) * 36525 * 86400``, i.e. the stored time stamps
        are treated as Julian centuries.

        :param satellite:
            Satellite model attached to the run.
        :param target:
            Optional target, either one vector used for every sample or one
            row per sample.
        :param w_target:
            Optional target angular velocity, same conventions as ``target``.
        :param boresight:
            Optional boresight vector, same conventions as ``target``.
        :param include_cost:
            If True and costs are available, store the cost-to-go per sample
            (see the review note in the body).
        :return:
            Simulation results containing exactly one run.
        """
        time_J2000 = np.asarray(self.times)
        time_s = (time_J2000 - time_J2000[0]) * 36525.0 * 86400.0
        n_samples = len(time_J2000)

        states = np.asarray(self.states)
        controls = np.asarray(self.controls)
        run = RunResults(
            satellite=satellite,
            time_J2000=time_J2000,
            time_s=time_s,
            state_hist=states if self._is_row_major else states.T,
            # Controls carry their own layout flag; it can differ from the
            # state layout.
            control_hist=controls if self._ctrl_row_major else controls.T,
        )

        if target is not None:
            run.target_hist = self._expand_history(target, n_samples)
        if w_target is not None:
            run.w_target_hist = self._expand_history(w_target, n_samples)
        if boresight is not None:
            run.boresight_hist = self._expand_history(boresight, n_samples)

        if include_cost and self.costs is not None:
            # NOTE(review): the cost-to-go is written into ``target_hist`` and
            # therefore replaces any target history set above. Existing
            # behaviour is kept; confirm whether RunResults offers a dedicated
            # cost field that should be used instead.
            costs = np.asarray(self.costs)
            run.target_hist = [np.array([costs[k]]) for k in range(n_samples)]

        return SimulationResults(runs=[run])