Inner Solvers#
The inner solvers module provides the optimization routines used to solve the augmented Lagrangian subproblems.
Overview#
At each outer iteration, PBALM (approximately) minimizes the augmented Lagrangian.
This subproblem is handled by inner solvers that support composite optimization (smooth + nonsmooth terms).
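As a rough illustration of the smooth part of such a subproblem, the sketch below evaluates a textbook augmented Lagrangian for equality constraints h(x) = 0 and inequality constraints g(x) <= 0. This is an assumption made for illustration: the exact function PBALM minimizes may differ (for example, by including proximal terms), and the name `augmented_lagrangian`, the penalty parameter `rho`, and the toy problem are hypothetical. The nonsmooth regularizer is not included, since a composite inner solver handles it separately.

```python
import jax
import jax.numpy as jnp

def augmented_lagrangian(f, h, g, x, lbda, mu, rho):
    """Textbook augmented Lagrangian (illustrative, not PBALM's exact form)."""
    hx = h(x)
    gx = g(x)
    # Equality part: multiplier term plus quadratic penalty
    eq_term = jnp.dot(lbda, hx) + 0.5 * rho * jnp.sum(hx ** 2)
    # Inequality part: only violated or active constraints contribute
    shifted = jnp.maximum(0.0, mu + rho * gx)
    ineq_term = (jnp.sum(shifted ** 2) - jnp.sum(mu ** 2)) / (2.0 * rho)
    return f(x) + eq_term + ineq_term

# Toy problem (hypothetical): minimize ||x||^2 s.t. x[0] + x[1] = 1 and x[0] <= 2
f = lambda x: jnp.sum(x ** 2)
h = lambda x: jnp.array([x[0] + x[1] - 1.0])
g = lambda x: jnp.array([x[0] - 2.0])

x = jnp.array([0.5, 0.5])
lbda = jnp.zeros(1)  # equality multipliers
mu = jnp.zeros(1)    # inequality multipliers
rho = 10.0

al_value = augmented_lagrangian(f, h, g, x, lbda, mu, rho)
# The function is differentiable, so an inner solver can obtain gradients via JAX
al_grad = jax.grad(lambda z: augmented_lagrangian(f, h, g, z, lbda, mu, rho))(x)
```

At a feasible point with zero multipliers, both penalty terms vanish and the value reduces to f(x).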
PALMInnerTrainer#
- class pbalm.inner_solvers.inner_solvers.PALMInnerTrainer(train_fun)#
Class to run PALM inner solver.
- Parameters:
train_fun (callable) – Function to run the inner optimization.
The train_fun should have the signature:
def train_fun(palm_obj_fun, x, max_iter, tol):
    """
    Parameters:
        palm_obj_fun: Augmented Lagrangian objective function
        x: Initial point
        max_iter: Maximum iterations
        tol: Convergence tolerance
    Returns:
        x_new: Optimized point
        state: Solver state/statistics (a dict with keys: "obj_grad_evals",
            "fp_res", "obj_val", "reg_val", "status")
    """
    ...
    return x_new, state
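The contract above can be checked before a function is wrapped in PALMInnerTrainer. The sketch below is one way to do that without PBALM itself; `check_train_fun` and `identity_train_fun` are hypothetical helpers written for this illustration, and the quadratic is only a stand-in for the augmented Lagrangian objective.

```python
import jax
import jax.numpy as jnp

REQUIRED_STATE_KEYS = {"obj_grad_evals", "fp_res", "obj_val", "reg_val", "status"}

def check_train_fun(train_fun, dim=3):
    """Call train_fun on a toy objective and verify the shape of its outputs."""
    objective = lambda x: jnp.sum((x - 1.0) ** 2)  # stand-in objective
    x_new, state = train_fun(objective, jnp.zeros(dim), 50, 1e-6)
    missing = REQUIRED_STATE_KEYS - set(state)
    if missing:
        raise ValueError(f"state is missing keys: {sorted(missing)}")
    if x_new.shape != (dim,):
        raise ValueError(f"unexpected shape for x_new: {x_new.shape}")
    return state

def identity_train_fun(palm_obj_fun, x, max_iter, tol):
    """Trivial train_fun: returns the start point unchanged with a valid state."""
    grad_norm = jnp.linalg.norm(jax.grad(palm_obj_fun)(x))
    state = {
        "obj_grad_evals": 1,
        "fp_res": grad_norm,
        "obj_val": palm_obj_fun(x),
        "reg_val": 0.0,
        "status": "Converged" if grad_norm < tol else "MaxIterReached",
    }
    return x, state

state = check_train_fun(identity_train_fun)
```

The check only validates the return structure; it says nothing about whether the solver actually converges.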
PaProblem#
- class pbalm.inner_solvers.inner_solvers.PaProblem(f, x0, reg=None, lbda=None, solver_opts=None, tol=1e-9, max_iter=2000, direction=None, jittable=True)#
Internal problem class for the PANOC solver from Alpaqa.
- Parameters:
f (callable) – Objective function
x0 (jax.numpy.ndarray) – Initial point (used for problem dimensions)
reg (alpaqa regularizer or None) – Regularizer (supports pbalm.L1Norm(), pbalm.NuclearNorm(), pbalm.Box(lower, upper))
solver_opts (dict or None) – Options for PANOC solver
tol (float) – Convergence tolerance
max_iter (int) – Maximum iterations
direction (alpaqa direction or None) – Direction method for PANOC
jittable (bool) – Enable JIT compilation
Methods:
- eval_objective(x)#
Evaluate the objective function at x.
- eval_objective_gradient(x, grad_f)#
Evaluate the gradient and store it in grad_f.
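The second method writes its result into a caller-supplied buffer rather than returning a new array. The sketch below mimics that calling convention with a hypothetical `ToyProblem` class; it is not PaProblem itself and makes no claim about how PaProblem is implemented internally.

```python
import jax
import jax.numpy as jnp
import numpy as np

class ToyProblem:
    """Hypothetical stand-in mimicking the two methods listed above."""

    def __init__(self, f):
        self.f = f
        self.grad = jax.grad(f)

    def eval_objective(self, x):
        # Return the objective value as a plain Python float
        return float(self.f(jnp.asarray(x)))

    def eval_objective_gradient(self, x, grad_f):
        # Write into the caller's buffer instead of returning a new array
        grad_f[:] = np.asarray(self.grad(jnp.asarray(x)))

problem = ToyProblem(lambda x: jnp.sum((x - 1.0) ** 2))
x = np.array([0.0, 3.0])
grad_buffer = np.zeros_like(x)

value = problem.eval_objective(x)
problem.eval_objective_gradient(x, grad_buffer)  # grad_buffer now holds 2 * (x - 1)
```

Reusing one output buffer avoids allocating a new gradient array on every solver iteration.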
get_solver_run#
- pbalm.inner_solvers.inner_solvers.get_solver_run(reg=None, lbda=None, solver_opts=None, direction=None, jittable=True)#
Factory function to create the default inner solver runner.
- Parameters:
- Returns:
A PALMInnerTrainer instance
- Return type:
PALMInnerTrainer
phase_I_optim#
- pbalm.inner_solvers.inner_solvers.phase_I_optim(x0, h, g, reg, lbda0, mu0, alpha=20, gamma0=1e-8, tol=1e-7, max_iter=500, inner_solver='PANOC')#
Solve the Phase I feasibility problem to find an initial feasible point.
When starting from an infeasible point, this function finds a point that satisfies (or nearly satisfies) the constraints.
- Parameters:
x0 (jax.numpy.ndarray) – Initial (infeasible) point
h (list of callables or None) – Equality constraint functions
g (list of callables or None) – Inequality constraint functions
reg (alpaqa regularizer or None) – Regularizer
lbda0 (jax.numpy.ndarray or None) – Initial equality multipliers
mu0 (jax.numpy.ndarray or None) – Initial inequality multipliers
alpha (float) – Penalty growth parameter
gamma0 (float) – Initial proximal parameter
tol (float) – Feasibility tolerance
max_iter (int) – Maximum iterations
inner_solver (str) – Inner solver name
- Returns:
Feasible point
- Return type:
jax.numpy.ndarray
- Raises:
RuntimeError – If feasibility cannot be achieved
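To make the idea of a feasibility phase concrete, the sketch below minimizes the sum of squared constraint violations with plain gradient descent. This is a simplified stand-in and not the algorithm used by phase_I_optim, which takes multipliers, a penalty growth parameter, and a proximal parameter. The names `infeasibility` and `toy_phase_one`, the step size, and the toy constraints are hypothetical; constraints are assumed to have the form h(x) = 0 and g(x) <= 0.

```python
import jax
import jax.numpy as jnp

def infeasibility(x, h_list, g_list):
    """Sum of squared constraint violations; zero exactly at feasible points."""
    total = jnp.zeros(())
    for h in h_list:
        total = total + jnp.sum(jnp.square(h(x)))
    for g in g_list:
        # Only positive values of g(x) count as violations
        total = total + jnp.sum(jnp.square(jnp.maximum(0.0, g(x))))
    return total

def toy_phase_one(x0, h_list, g_list, step=0.1, tol=1e-7, max_iter=500):
    """Gradient descent on the violation measure (illustrative only)."""
    measure = lambda z: infeasibility(z, h_list, g_list)
    grad_fn = jax.grad(measure)
    x = x0
    for _ in range(max_iter):
        if measure(x) < tol:
            return x
        x = x - step * grad_fn(x)
    raise RuntimeError("feasibility not reached within max_iter")

# Toy constraints (hypothetical): x[0] + x[1] = 1 and x[0] >= 0
h_list = [lambda x: x[0] + x[1] - 1.0]
g_list = [lambda x: -x[0]]

x_feas = toy_phase_one(jnp.array([-1.0, -1.0]), h_list, g_list)
```

As with phase_I_optim, failure to reach the tolerance is reported by raising RuntimeError rather than by returning an infeasible point.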
Custom Inner Solver Example#
To use a custom inner solver:
import jax
import jax.numpy as jnp
import pbalm
from pbalm.inner_solvers.inner_solvers import PALMInnerTrainer

def custom_gradient_descent(palm_obj_fun, x0, max_iter, tol):
    """Simple gradient descent inner solver."""
    grad_fn = jax.grad(palm_obj_fun)
    x = x0
    step_size = 0.01
    grad = grad_fn(x)
    n_evals = 1
    for _ in range(max_iter):
        # Stop once the gradient at the current iterate is small enough
        if jnp.linalg.norm(grad) < tol:
            break
        x = x - step_size * grad
        grad = grad_fn(x)
        n_evals += 1
    # The reported residual always belongs to the returned point
    grad_norm = jnp.linalg.norm(grad)
    return x, {
        "obj_grad_evals": n_evals,
        "fp_res": grad_norm,
        "obj_val": palm_obj_fun(x),
        "reg_val": 0.0,  # assuming no regularization in this simple example
        "status": "Converged" if grad_norm < tol else "MaxIterReached",
    }

# Create custom trainer
custom_runner = PALMInnerTrainer(custom_gradient_descent)

# Use with solve (problem and x0 are assumed to be defined already)
result = pbalm.solve(
    problem,
    x0,
    inner_solve_runner=custom_runner
)
PANOC Solver Options#
The default inner solver uses PANOC from Alpaqa. Common options:
import alpaqa as pa
import pbalm

# Custom direction (L-BFGS with more memory)
direction = pa.LBFGSDirection({"memory": 50})

# Custom solver options
solver_opts = {
    "max_iter": 5000,
    "stop_crit": pa.ProjGradUnitNorm,
}

# problem and x0 are assumed to be defined already
result = pbalm.solve(
    problem,
    x0,
    pa_direction=direction,
    pa_solver_opts=solver_opts
)
For more details on PANOC options, see the Alpaqa documentation.