Source code for ADCS.controller.saltro.SALTRO_pass_settings
"""Pass-level SALTRO optimization settings.
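This module defines Python configuration dataclasses for the cost function,
augmented-Lagrangian outer loop, iLQR middle loop, regularization, and line
search of a SALTRO pass. Each has a ``to_cpp`` helper that converts it to the
corresponding SALTRO C++ struct, and :class:`PassConfig` bundles them together.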
"""
from __future__ import annotations
__all__ = ["PassConfig"]
import numpy as np
from typing import Tuple, Optional, List
from numpy.typing import NDArray
from dataclasses import dataclass, field
def _get_saltro_py():
"""Import and return the ``saltro_py`` binding module.
The loader appends ``SALTRO/build`` to ``sys.path`` if needed.
:return: Imported ``saltro_py`` module.
:rtype: module
:raises ImportError: If the SALTRO Python extension is not available.
"""
import os
import sys
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
build_dir = os.path.join(parent_dir, "SALTRO", "build")
if build_dir not in sys.path:
sys.path.append(build_dir)
try:
import saltro_py
except ImportError as exc:
raise ImportError(f"saltro_py not available (expected in {build_dir})") from exc
return saltro_py
@dataclass
class CostConfig:
    """Cost weights and cost-function flags for one SALTRO optimization pass.

    Fields fall into four groups: running costs, actuator weights, terminal
    costs, and flags. Descriptions marked "presumably" are inferred from the
    field names only. Confirm them against the SALTRO C++ ``CostConfig``.

    :param angle: Running attitude error cost weight.
    :type angle: float
    :param ang_vel: Running angular velocity error cost weight.
    :type ang_vel: float
    :param ang_vel_mag: Running cost weight, presumably on the angular
        velocity magnitude.
    :type ang_vel_mag: float
    :param ang_vel_err_dir: Running cost weight, presumably on the direction
        of the angular velocity error.
    :type ang_vel_err_dir: float
    :param control_mult: Global multiplier for actuator effort costs.
    :type control_mult: float
    :param mtq_control_weight: Effort weight, presumably for magnetorquer
        (MTQ) commands.
    :type mtq_control_weight: float
    :param rw_control_weight: Effort weight, presumably for reaction-wheel
        commands.
    :type rw_control_weight: float
    :param magic_control_weight: Effort weight for the "magic" actuator
        channel. Its semantics are defined by SALTRO.
    :type magic_control_weight: float
    :param rw_AM_weight: Weight, presumably on reaction-wheel angular momentum.
    :type rw_AM_weight: float
    :param rw_stic_weight: Weight, presumably related to reaction-wheel
        stiction.
    :type rw_stic_weight: float
    :param RWh_max_mult: Reaction-wheel momentum multiplier, presumably for
        the near-limit region.
    :type RWh_max_mult: float
    :param RWh_stiction_mult: Reaction-wheel momentum multiplier, presumably
        for the stiction region.
    :type RWh_stiction_mult: float
    :param RWh_ok_mult: Reaction-wheel momentum multiplier, presumably for
        the nominal region.
    :type RWh_ok_mult: float
    :param angle_N: Terminal attitude error cost weight.
    :type angle_N: float
    :param ang_vel_N: Terminal angular velocity error cost weight.
    :type ang_vel_N: float
    :param ang_vel_mag_N: Terminal counterpart of ``ang_vel_mag``.
    :type ang_vel_mag_N: float
    :param ang_vel_err_dir_N: Terminal counterpart of ``ang_vel_err_dir``.
    :type ang_vel_err_dir_N: float
    :param ang_cost_func_type: Integer selector for the attitude cost
        function. The meaning of each value is defined in SALTRO C++.
    :type ang_cost_func_type: int
    :param use_cost_hess: Integer flag (nonzero means enabled), handed to
        SALTRO as a ``bool``. Presumably toggles use of the cost Hessian.
    :type use_cost_hess: int
    """

    # --- Running (per-timestep) cost weights ---
    angle: float = 1e2
    ang_vel: float = 1e5
    ang_vel_mag: float = 0.0
    ang_vel_err_dir: float = 0.0
    control_mult: float = 1.0

    # --- Per-actuator effort weights and reaction-wheel multipliers ---
    mtq_control_weight: float = 1.0
    rw_control_weight: float = 1.0
    magic_control_weight: float = 0.0
    rw_AM_weight: float = 0.0
    rw_stic_weight: float = 0.0
    RWh_max_mult: float = 0.0
    RWh_stiction_mult: float = 0.0
    RWh_ok_mult: float = 0.0

    # --- Terminal (final-timestep) cost weights ---
    angle_N: float = 0.0
    ang_vel_N: float = 0.0
    ang_vel_mag_N: float = 0.0
    ang_vel_err_dir_N: float = 0.0

    # --- Integer flags / selectors ---
    ang_cost_func_type: int = 2
    use_cost_hess: int = 1

    def to_cpp(self):
        """Build the SALTRO C++ ``CostConfig`` mirroring these settings.

        Every field is copied to the C++ attribute of the same name.
        ``use_cost_hess`` is converted to ``bool`` on the way.

        :return: C++ cost config object.
        :rtype: Any
        :raises ImportError: If the ``saltro_py`` extension cannot be loaded.
        """
        binding = _get_saltro_py()
        target = binding.CostConfig()
        copied_as_is = (
            "angle",
            "ang_vel",
            "ang_vel_mag",
            "ang_vel_err_dir",
            "control_mult",
            "mtq_control_weight",
            "rw_control_weight",
            "magic_control_weight",
            "rw_AM_weight",
            "rw_stic_weight",
            "RWh_max_mult",
            "RWh_stiction_mult",
            "RWh_ok_mult",
            "angle_N",
            "ang_vel_N",
            "ang_vel_mag_N",
            "ang_vel_err_dir_N",
            "ang_cost_func_type",
        )
        for attr in copied_as_is:
            setattr(target, attr, getattr(self, attr))
        # The Python side stores this flag as an int. The binding receives a bool.
        target.use_cost_hess = bool(self.use_cost_hess)
        return target
@dataclass
class AugLagConfig:
    """Settings for the augmented-Lagrangian outer loop of a SALTRO pass.

    Descriptions marked "presumably" are inferred from field names. Confirm
    them against the SALTRO C++ ``AugLagConfig``.

    :param max_outer_iters: Maximum number of outer iterations.
    :type max_outer_iters: int
    :param lag_mult_init: Initial value for the Lagrange multipliers.
    :type lag_mult_init: float
    :param lag_mult_max: Upper limit, presumably a clamp, on the Lagrange
        multipliers.
    :type lag_mult_max: float
    :param penalty_init: Initial constraint penalty.
    :type penalty_init: float
    :param penalty_max: Maximum constraint penalty.
    :type penalty_max: float
    :param penalty_scale: Factor that presumably scales the penalty between
        outer iterations.
    :type penalty_scale: float
    :param constraint_tol: Constraint satisfaction tolerance.
    :type constraint_tol: float
    :param total_cost_tol: Tolerance on the total cost, presumably used as an
        outer-loop convergence check.
    :type total_cost_tol: float
    """

    max_outer_iters: int = 30
    # Lagrange multiplier initialisation and limit
    lag_mult_init: float = 0.0
    lag_mult_max: float = 1e20
    # Penalty schedule
    penalty_init: float = 1e-1
    penalty_max: float = 1e16
    penalty_scale: float = 10.0
    # Outer-loop tolerances
    constraint_tol: float = 1e-3
    total_cost_tol: float = 1e-2

    def to_cpp(self):
        """Build the SALTRO C++ ``AugLagConfig`` mirroring these settings.

        Every field is copied to the C++ attribute of the same name.

        :return: C++ augmented-Lagrangian config object.
        :rtype: Any
        :raises ImportError: If the ``saltro_py`` extension cannot be loaded.
        """
        target = _get_saltro_py().AugLagConfig()
        for attr in (
            "max_outer_iters",
            "lag_mult_init",
            "lag_mult_max",
            "penalty_init",
            "penalty_max",
            "penalty_scale",
            "constraint_tol",
            "total_cost_tol",
        ):
            setattr(target, attr, getattr(self, attr))
        return target
@dataclass
class ILQRConfig:
    """Settings for the iLQR middle loop of a SALTRO pass.

    Descriptions marked "presumably" are inferred from field names. Confirm
    them against the SALTRO C++ ``ILQRConfig``.

    :param max_iters: Maximum iLQR iterations.
    :type max_iters: int
    :param grad_tol: Gradient norm tolerance.
    :type grad_tol: float
    :param cost_tol: Relative/absolute cost change tolerance.
    :type cost_tol: float
    :param z_count_lim: Integer count limit passed through as
        ``z_count_lim``. Its semantics are defined by SALTRO C++.
    :type z_count_lim: int
    :param max_cost: Cost ceiling, presumably used to detect divergence.
    :type max_cost: float
    :param state_bound: Bound on state values, presumably used to detect
        divergence.
    :type state_bound: float
    """

    # Iteration budget and convergence tolerances
    max_iters: int = 20
    grad_tol: float = 1e-3
    cost_tol: float = 1e-5
    # Additional limits forwarded verbatim to SALTRO
    z_count_lim: int = 10
    max_cost: float = 1e40
    state_bound: float = 10.0

    def to_cpp(self):
        """Build the SALTRO C++ ``ILQRConfig`` mirroring these settings.

        Every field is copied to the C++ attribute of the same name.

        :return: C++ iLQR config object.
        :rtype: Any
        :raises ImportError: If the ``saltro_py`` extension cannot be loaded.
        """
        target = _get_saltro_py().ILQRConfig()
        names = (
            "max_iters",
            "grad_tol",
            "cost_tol",
            "z_count_lim",
            "max_cost",
            "state_bound",
        )
        for attr in names:
            setattr(target, attr, getattr(self, attr))
        return target
@dataclass
class RegularizationConfig:
    """Regularization settings for the iLQR backward passes.

    Descriptions marked "presumably" are inferred from field names. Confirm
    them against the SALTRO C++ ``RegularizationConfig``.

    :param reg_init: Initial regularization value.
    :type reg_init: float
    :param reg_min: Lower bound on regularization.
    :type reg_min: float
    :param reg_max: Upper bound on regularization.
    :type reg_max: float
    :param reg_scale: Factor by which regularization is presumably scaled
        between iterations.
    :type reg_scale: float
    :param reg_bump: Factor presumably applied when regularization must be
        increased.
    :type reg_bump: float
    :param reg_min_cond: Integer setting forwarded as ``reg_min_cond``. Its
        semantics are defined by SALTRO C++.
    :type reg_min_cond: int
    :param rand_add_ratio: Ratio forwarded as ``rand_add_ratio``, presumably
        controlling a random perturbation.
    :type rand_add_ratio: float
    :param use_dynamics_hess: Integer flag (nonzero means enabled), handed to
        SALTRO as a ``bool``.
    :type use_dynamics_hess: int
    :param use_constraint_hess: Integer flag (nonzero means enabled), handed
        to SALTRO as a ``bool``.
    :type use_constraint_hess: int
    """

    # Regularization magnitude and its bounds
    reg_init: float = 1e-6
    reg_min: float = 1e-8
    reg_max: float = 1e10
    # Update factors / conditions
    reg_scale: float = 10.0
    reg_bump: float = 10.0
    reg_min_cond: int = 2
    rand_add_ratio: float = 0.0
    # Integer on/off flags (converted to bool for the C++ side)
    use_dynamics_hess: int = 0
    use_constraint_hess: int = 0

    def to_cpp(self):
        """Build the SALTRO C++ ``RegularizationConfig`` mirroring these settings.

        Numeric fields are copied to the C++ attribute of the same name.
        The two ``use_*_hess`` flags are converted to ``bool``.

        :return: C++ regularization config object.
        :rtype: Any
        :raises ImportError: If the ``saltro_py`` extension cannot be loaded.
        """
        binding = _get_saltro_py()
        target = binding.RegularizationConfig()
        numeric = (
            "reg_init",
            "reg_min",
            "reg_max",
            "reg_scale",
            "reg_bump",
            "reg_min_cond",
            "rand_add_ratio",
        )
        for attr in numeric:
            setattr(target, attr, getattr(self, attr))
        for flag in ("use_dynamics_hess", "use_constraint_hess"):
            setattr(target, flag, bool(getattr(self, flag)))
        return target
@dataclass
class LineSearchConfig:
    """Line-search settings for the iLQR forward rollout.

    :param max_iters: Maximum line-search iterations.
    :type max_iters: int
    :param beta1: Minimum step scaling factor.
    :type beta1: float
    :param beta2: Maximum step scaling factor.
    :type beta2: float
    """

    # Iteration budget for one line search
    max_iters: int = 24
    # Lower / upper line-search factors
    beta1: float = 1e-10
    beta2: float = 5000.0

    def to_cpp(self):
        """Build the SALTRO C++ ``LineSearchConfig`` mirroring these settings.

        Every field is copied to the C++ attribute of the same name.

        :return: C++ line-search config object.
        :rtype: Any
        :raises ImportError: If the ``saltro_py`` extension cannot be loaded.
        """
        target = _get_saltro_py().LineSearchConfig()
        for attr in ("max_iters", "beta1", "beta2"):
            setattr(target, attr, getattr(self, attr))
        return target
[docs]
@dataclass
class PassConfig:
    """Complete settings for one SALTRO optimization pass.

    Bundles the cost, augmented-Lagrangian, iLQR, regularization, and
    line-search configuration blocks together with a fixed planner timestep
    ``dt``. Each block defaults to a freshly constructed instance, so two
    ``PassConfig`` objects never share a sub-config by default.

    :param cost: Cost-function configuration.
    :type cost: :class:`CostConfig`
    :param aug_lag: Augmented-Lagrangian outer-loop configuration.
    :type aug_lag: :class:`AugLagConfig`
    :param ilqr: iLQR middle-loop configuration.
    :type ilqr: :class:`ILQRConfig`
    :param reg: Regularization configuration.
    :type reg: :class:`RegularizationConfig`
    :param linesearch: Line-search (inner-loop) configuration.
    :type linesearch: :class:`LineSearchConfig`
    :param dt: Planner timestep in seconds.
    :type dt: float
    """

    cost: CostConfig = field(default_factory=CostConfig)
    # Solver loops, outermost first
    aug_lag: AugLagConfig = field(default_factory=AugLagConfig)
    ilqr: ILQRConfig = field(default_factory=ILQRConfig)
    reg: RegularizationConfig = field(default_factory=RegularizationConfig)
    linesearch: LineSearchConfig = field(default_factory=LineSearchConfig)
    # Planning timestep [s]
    dt: float = 5.0

    def to_cpp(self):
        """Build the SALTRO C++ ``PassConfig`` mirroring these settings.

        Each sub-config is converted with its own ``to_cpp`` and attached to
        the C++ object, in the order cost, aug_lag, ilqr, reg, linesearch.
        ``dt`` is copied last.

        :return: C++ pass config object.
        :rtype: Any
        :raises ImportError: If the ``saltro_py`` extension cannot be loaded.
        """
        binding = _get_saltro_py()
        target = binding.PassConfig()
        # (C++ attribute name, Python sub-config). The C++ attribute for the
        # augmented-Lagrangian block is spelled ``auglag``, not ``aug_lag``.
        sections = (
            ("cost", self.cost),
            ("auglag", self.aug_lag),
            ("ilqr", self.ilqr),
            ("reg", self.reg),
            ("linesearch", self.linesearch),
        )
        for cpp_attr, section in sections:
            setattr(target, cpp_attr, section.to_cpp())
        target.dt = self.dt
        return target