# Source code for ADCS.controller.plan_and_track.debug_planner

"""
Debug wrapper for the trajectory planner that provides step-by-step output.

This module provides a DebugPlanner class that wraps the tplaunch.Planner
and exposes the same interface, while providing detailed diagnostic output
about the backward/forward passes, constraint violations, and active sets.

Usage:
    # In plan_and_track_exact.py, replace:
    #   self.planner = tplaunch.Planner(...)
    # With:
    #   from ADCS.controller.plan_and_track.debug_planner import DebugPlanner
    #   self.planner = DebugPlanner(csat, ..., debug_level=2)
"""
from __future__ import annotations

__all__ = ["DebugPlanner"]

import numpy as np
from typing import Tuple, Optional, Dict, List, Any, TextIO
from numpy.typing import NDArray
from ADCS.controller.helpers.optional_dependencies import get_trajectory_planner_modules


class DebugPlanner:
    """
    A wrapper around tplaunch.Planner that provides debugging output.

    Every public method forwards its arguments unchanged to the wrapped
    planner and returns the wrapped planner's result unchanged; the debug
    variants additionally log a summary of that result.

    debug_level:
        0 = No debug output (passthrough to underlying planner)
        1 = Summary output (cost, max violation, iteration counts)
        2 = Detailed output (control patterns, constraint states)
        3 = Very detailed (per-timestep values)

    If ``log_file`` is given, messages logged during a ``trajOpt`` call are
    also written to that file. The file is reopened in write mode (and
    therefore truncated) on every ``trajOpt`` call and closed when it returns
    or raises.
    """

    # A channel with more +-+ / -+- sign triples than this is flagged as oscillating.
    OSCILLATION_WARN_THRESHOLD: int = 5

    def __init__(
        self,
        csat,
        systemSettings,
        alilqrSettings,
        alilqrSettings2,
        initialTrajSettings,
        costSettings,
        costSettings2,
        costSettings_tvlqr,
        debug_level: int = 1,
        log_file: Optional[str] = None,
    ):
        tplaunch, _ = get_trajectory_planner_modules()

        # Store the underlying planner; every call is delegated to it.
        self._planner = tplaunch.Planner(
            csat,
            systemSettings,
            alilqrSettings,
            alilqrSettings2,
            initialTrajSettings,
            costSettings,
            costSettings2,
            costSettings_tvlqr,
        )

        self.debug_level: int = debug_level
        self.log_file: Optional[str] = log_file
        self._log_handle: Optional[TextIO] = None

        # Extract augmented-Lagrangian settings for analysis.
        # Assumed layout (TODO confirm against tplaunch):
        #   alilqrSettings = (lineSearch, auglag, break, reg)
        #   auglag = (lam_init, lam_max, mu_init, mu_max, mu_scale)
        self._auglag_settings = alilqrSettings[1]

        # Iteration history containers.
        # NOTE(review): nothing in this class appends to these lists.
        self.iteration_history: List[Dict[str, Any]] = []
        self.control_history: List[NDArray[np.float64]] = []
        self.violation_history: List[float] = []

    def _log(self, msg: str, level: int = 1) -> None:
        """Print ``msg`` (and mirror it to the open log file) if ``debug_level >= level``."""
        if self.debug_level < level:
            return
        print(msg)
        if self._log_handle is not None:
            self._log_handle.write(msg + "\n")

    @staticmethod
    def _fmt(value: Any, spec: str) -> str:
        """
        Format ``value`` with the format ``spec``.

        Falls back to ``str(value)`` when the value rejects the spec (e.g. a
        non-scalar numpy array with ``'.6e'``), so debug logging can never make
        a call fail that would have succeeded on the raw planner.
        """
        try:
            return format(value, spec)
        except (TypeError, ValueError):
            return str(value)

    def _analyze_controls(
        self, Uset: NDArray[np.float64], u_limit: Optional[float] = None
    ) -> Dict[str, List[Any]]:
        """
        Analyze a control trajectory for oscillations and saturation.

        Args:
            Uset: Control history of shape (ctrl_dim, N). A 1-D array is
                treated as a single channel.
            u_limit: Optional actuator limit. A sample counts as saturated when
                ``|u| > 0.95 * u_limit``. When None, saturation counts are 0.

        Returns:
            Dict of per-channel lists:
              - ``sign_changes``: number of adjacent samples whose sign differs.
              - ``saturation_count``: number of saturated samples.
              - ``rapid_oscillations``: number of consecutive sign triples in
                which both adjacent pairs differ (+-+ / -+- patterns).
              - ``range``: ``(min, max)`` as floats; ``(nan, nan)`` for an
                empty channel.
        """
        U = np.asarray(Uset)
        if U.ndim == 1:
            U = U.reshape(1, -1)
        ctrl_dim, _ = U.shape

        analysis: Dict[str, List[Any]] = {
            "sign_changes": [],
            "saturation_count": [],
            "rapid_oscillations": [],
            "range": [],
        }

        for ch in range(ctrl_dim):
            u_ch = U[ch, :]
            signs = np.sign(u_ch)

            # Sign changes between consecutive samples.
            analysis["sign_changes"].append(int(np.count_nonzero(np.diff(signs) != 0)))

            # Saturation (an explicit None check, so a limit of 0.0 is honored).
            if u_limit is not None:
                at_sat = int(np.count_nonzero(np.abs(u_ch) > 0.95 * u_limit))
            else:
                at_sat = 0
            analysis["saturation_count"].append(at_sat)

            # Rapid oscillations: vectorized check of each triple (i, i+1, i+2).
            # The slices are empty for fewer than 3 samples, giving 0.
            rapid = (signs[:-2] != signs[1:-1]) & (signs[1:-1] != signs[2:])
            analysis["rapid_oscillations"].append(int(np.count_nonzero(rapid)))

            # Range; min()/max() raise on empty arrays, so guard explicitly.
            if u_ch.size:
                analysis["range"].append((float(u_ch.min()), float(u_ch.max())))
            else:
                analysis["range"].append((float("nan"), float("nan")))

        return analysis

    def _print_control_analysis(
        self,
        Uset: NDArray[np.float64],
        prefix: str = "",
        level: int = 1,
        u_limit: Optional[float] = None,
    ) -> None:
        """
        Log a per-channel control analysis summary at ``level``.

        The analysis is skipped entirely when ``debug_level < level``, since
        nothing would be printed. When ``u_limit`` is given, the saturated
        sample count is included in each channel line.
        """
        if self.debug_level < level:
            return
        analysis = self._analyze_controls(Uset, u_limit)
        self._log(f"{prefix}Control Analysis:", level)
        rows = zip(
            analysis["sign_changes"],
            analysis["rapid_oscillations"],
            analysis["range"],
            analysis["saturation_count"],
        )
        for ch, (sc, ro, r, sat) in enumerate(rows):
            warn = " [OSCILLATING!]" if ro > self.OSCILLATION_WARN_THRESHOLD else ""
            sat_txt = f", saturated={sat}" if u_limit is not None else ""
            self._log(
                f" Ch{ch}: sign_changes={sc}, rapid_osc={ro}, "
                f"range=[{r[0]:.4e}, {r[1]:.4e}]{sat_txt}{warn}",
                level,
            )

    # =========================================================================
    # Passthrough methods that match tplaunch.Planner interface
    # =========================================================================

    def readParameters(self):
        return self._planner.readParameters()

    def readDebug(self):
        return self._planner.readDebug()

    def updateParameters(self, *args):
        return self._planner.updateParameters(*args)

    def setVerbosity(self, verbosity: bool):
        return self._planner.setVerbosity(verbosity)

    def setquaternionTo3VecMode(self, val: int):
        return self._planner.setquaternionTo3VecMode(val)

    def getdt(self):
        return self._planner.getdt()

    def echo_int(self, x: int):
        return self._planner.echo_int(x)

    # =========================================================================
    # Main trajectory optimization with debug output
    # =========================================================================

    def trajOpt(self, vecsPy, N, time_start, time_end, x0, bdotOn):
        """
        Main trajectory optimization with debug output.

        Delegates to the wrapped planner's ``trajOpt`` and returns its result
        unchanged. If ``log_file`` is set, the file is opened before any output
        (so the header is captured) and is always closed afterwards, even when
        the wrapped call raises.
        """
        if self.log_file and self._log_handle is None:
            self._log_handle = open(self.log_file, "w", encoding="utf-8")

        try:
            self._log("\n" + "=" * 70, 1)
            self._log("DEBUG PLANNER: trajOpt called", 1)
            self._log("=" * 70, 1)
            self._log(
                f" N={N}, time_start={self._fmt(time_start, '.6f')}, "
                f"time_end={self._fmt(time_end, '.6f')}",
                1,
            )
            self._log(f" bdotOn={bdotOn}", 1)
            if self.debug_level >= 2:
                # asarray so a list-valued x0 cannot break the debug output.
                x0_arr = np.asarray(x0)
                self._log(f" x0 shape: {x0_arr.shape}, x0={x0_arr[:7]}...", 2)

            result = self._planner.trajOpt(vecsPy, N, time_start, time_end, x0, bdotOn)

            # Result layout: (success, cost, opt1, lqr_opt, traj_final), where
            # lqr_opt = (Xset, Uset, Tset, Kset, Sset, lqr_times).
            (success, cost, _opt1, lqr_opt, _traj_final) = result
            (Xset, Uset, _, _, _, _) = lqr_opt

            self._log(f"\n Result: success={success}, final_cost={self._fmt(cost, '.6e')}", 1)
            self._log(f" Xset shape: {np.shape(Xset)}, Uset shape: {np.shape(Uset)}", 2)
            self._print_control_analysis(Uset, prefix=" Final ")
        finally:
            if self._log_handle is not None:
                self._log_handle.close()
                self._log_handle = None

        return result

    # =========================================================================
    # Step-by-step methods with debug output
    # =========================================================================

    def prepareForAlilqr(self, vecsPy, dt_use, time_start, time_end, x0, bdotOn):
        """Prepare for ALILQR with debug output (level 2)."""
        self._log("\n" + "-" * 50, 2)
        self._log("DEBUG: prepareForAlilqr called", 2)
        result = self._planner.prepareForAlilqr(vecsPy, dt_use, time_start, time_end, x0, bdotOn)
        (traj, _vecs_dt, _cost_settings) = result
        (Xset, Uset, _, _) = traj
        self._log(f" Initial trajectory: Xset={np.shape(Xset)}, Uset={np.shape(Uset)}", 2)
        # Same level as this method's header, so it is not printed without context.
        self._print_control_analysis(Uset, prefix=" Initial ", level=2)
        return result

    def backwardPass(self, dt, trajPy, vecsPy, auglag_valsPy, regs, costSettings, regSettings, useDist):
        """Backward pass with debug output; warns (level 2) on oscillating feedforward terms."""
        self._log("\n DEBUG: backwardPass", 3)
        result = self._planner.backwardPass(dt, trajPy, vecsPy, auglag_valsPy, regs, costSettings, regSettings, useDist)
        (bp_results, new_regs) = result
        (_Kset, dset, _Sset) = bp_results
        self._log(f" dset shape: {np.shape(dset)}", 3)
        self._log(f" Regs: {regs} -> {new_regs}", 3)

        # The feedforward (d) oscillation check only emits level-2 warnings,
        # so skip the analysis when they could not be printed.
        if self.debug_level >= 2 and np.ndim(dset) == 2:
            d_analysis = self._analyze_controls(dset)
            for ch, ro in enumerate(d_analysis["rapid_oscillations"]):
                if ro > self.OSCILLATION_WARN_THRESHOLD:
                    self._log(f" WARNING: d[{ch}] has {ro} rapid oscillations!", 2)
        return result

    def forwardPass(self, dt, trajPy, vecsPy, auglag_valsPy, BPresultsPy, regs, costSettings, regSettings, lineSearchSettings, useDist):
        """Forward pass with debug output (level 3)."""
        self._log("\n DEBUG: forwardPass", 3)
        result = self._planner.forwardPass(dt, trajPy, vecsPy, auglag_valsPy, BPresultsPy, regs, costSettings, regSettings, lineSearchSettings, useDist)
        (traj_new, newLA, _new_regs) = result
        (_Xset, Uset, _, _) = traj_new
        self._log(f" New cost: {self._fmt(newLA, '.6e')}", 3)
        # Runs on every line-search pass; kept at the header's level (3) so that
        # level 1 remains a summary.
        self._print_control_analysis(Uset, prefix=" ", level=3)
        return result

    def maxViol(self, trajPy, vecsPy, auglag_valsPy):
        """Max violation with debug output; per-constraint activity at level 3."""
        result = self._planner.maxViol(trajPy, vecsPy, auglag_valsPy)
        (clist, cmax) = result
        self._log(f" Max violation: {self._fmt(cmax, '.6e')}", 2)

        # Constraint activity: count timesteps where each constraint value is > 0.
        if self.debug_level >= 3:
            c = np.asarray(clist)
            if c.ndim == 2:
                num_constraints, N = c.shape
                active_counts = np.count_nonzero(c > 0, axis=1)
                for i in range(num_constraints):
                    self._log(f" Constraint {i}: active at {active_counts[i]}/{N} timesteps", 3)
        return result

    def incrementAugLag(self, auglag_valsPy, clistPy, auglagSettings):
        """Increment augmented Lagrangian with debug output (level 2)."""
        (_lam_old, mu_old, _muk_old) = auglag_valsPy
        # Format the old value before the call, in case the planner updates it in place.
        mu_old_txt = self._fmt(mu_old, ".2e")
        result = self._planner.incrementAugLag(auglag_valsPy, clistPy, auglagSettings)
        (_lam_new, mu_new, _muk_new) = result
        self._log(f" AugLag update: mu {mu_old_txt} -> {self._fmt(mu_new, '.2e')}", 2)
        return result

    def ilqrStep(self, dt, trajPy, vecsPy, auglag_valsPy, regs, costSettings, regSettings, lineSearchSettings, breakSettings, useDist):
        """Single iLQR step with debug output (level 3)."""
        self._log("\n DEBUG: ilqrStep", 3)
        result = self._planner.ilqrStep(dt, trajPy, vecsPy, auglag_valsPy, regs, costSettings, regSettings, lineSearchSettings, breakSettings, useDist)
        (newLA, cmax, _clist, grad, _new_regs, _traj_new) = result
        self._log(
            f" newLA={self._fmt(newLA, '.6e')}, cmax={self._fmt(cmax, '.6e')}, "
            f"grad={self._fmt(grad, '.6e')}",
            3,
        )
        return result

    def alilqr(self, dt, trajPy, vecsPy, costSettings, alilqrSettings, isFirstSearch):
        """Full ALILQR with debug output (level 1)."""
        self._log("\n" + "-" * 50, 1)
        self._log(f"DEBUG: alilqr (isFirstSearch={isFirstSearch})", 1)
        result = self._planner.alilqr(dt, trajPy, vecsPy, costSettings, alilqrSettings, isFirstSearch)
        (opt_result, final_cost, cmax) = result
        (_Xset, Uset, _, _, _, _) = opt_result
        self._log(
            f" Final cost: {self._fmt(final_cost, '.6e')}, max violation: {self._fmt(cmax, '.6e')}",
            1,
        )
        self._print_control_analysis(Uset, prefix=" ")
        return result

    def cost2Func(self, trajPy, vecsPy, auglag_valsPy, costSettings):
        """Cost function with debug output (level 3)."""
        cost = self._planner.cost2Func(trajPy, vecsPy, auglag_valsPy, costSettings)
        self._log(f" Cost: {self._fmt(cost, '.6e')}", 3)
        return cost

    def ilqrBreak(self, grad, LA, dLA, dlaZcount, cmax, iter, breakSettings, outer_iter, ilqr_iter, forOuter):
        """Break condition check with debug output (level 2, only when breaking)."""
        # NOTE: `iter` shadows the builtin but is kept for interface compatibility.
        result = self._planner.ilqrBreak(grad, LA, dLA, dlaZcount, cmax, iter, breakSettings, outer_iter, ilqr_iter, forOuter)
        if result:
            self._log(
                f" iLQR break: outer={outer_iter}, inner={ilqr_iter}, "
                f"grad={self._fmt(grad, '.2e')}, cmax={self._fmt(cmax, '.2e')}",
                2,
            )
        return result

    def outerBreak(self, auglag_valsPy, cmax, breakSettings, auglagSettings, outer_iter):
        """Outer break condition check with debug output (level 1, only when breaking)."""
        result = self._planner.outerBreak(auglag_valsPy, cmax, breakSettings, auglagSettings, outer_iter)
        if result:
            self._log(f" Outer break at iter {outer_iter}, cmax={self._fmt(cmax, '.2e')}", 1)
        return result

    def generateInitialTrajectory(self, dt, x0, Uset, vecsPy):
        """Generate initial trajectory with debug output (level 2)."""
        self._log("\n DEBUG: generateInitialTrajectory", 2)
        result = self._planner.generateInitialTrajectory(dt, x0, Uset, vecsPy)
        (Xset, Uset_out, _, _) = result
        self._log(f" Generated: Xset={np.shape(Xset)}, Uset={np.shape(Uset_out)}", 2)
        self._print_control_analysis(Uset_out, prefix=" ", level=2)
        return result

    def addRandNoise(self, dt, trajPy, dlaZcount, stepsSinceRand, breakSettings, regSettings, costSettings, auglag_vals, vecs):
        """Add random noise with debug output (level 3)."""
        self._log(" DEBUG: addRandNoise", 3)
        return self._planner.addRandNoise(dt, trajPy, dlaZcount, stepsSinceRand, breakSettings, regSettings, costSettings, auglag_vals, vecs)

    def dynamics(self, x, u, dynamics_info):
        """Dynamics evaluation (passthrough)."""
        return self._planner.dynamics(x, u, dynamics_info)

    def rk4z(self, dt, x, u, dynamics_info_k, dynamics_info_kp1):
        """RK4 integration (passthrough)."""
        return self._planner.rk4z(dt, x, u, dynamics_info_k, dynamics_info_kp1)

    def cleanUpAfterAlilqr(self, vecsPy, dt_prev, time_start, time_end, alilqrOut):
        """Clean up after ALILQR (passthrough)."""
        return self._planner.cleanUpAfterAlilqr(vecsPy, dt_prev, time_start, time_end, alilqrOut)