Inner Solvers#

The inner solvers module provides the optimization routines used to solve the augmented Lagrangian subproblems.

Overview#

At each outer iteration, PBALM needs to (approximately) minimize the augmented Lagrangian:

\[x^{k+1} \approx \arg\min_x \mathcal{L}_{\rho}(x, \lambda^k, \mu^k)\]

Here \(\lambda^k\) and \(\mu^k\) are the current equality and inequality multiplier estimates and \(\rho\) is the penalty parameter. The minimization is handled by inner solvers that support composite optimization (smooth + nonsmooth terms).
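
Concretely, the subproblem has the composite form

\[\min_x \; \varphi(x) + r(x)\]

where \(\varphi\) is the smooth part of \(\mathcal{L}_{\rho}\) and \(r\) is an optional nonsmooth regularizer such as pbalm.L1Norm(); proximal solvers like PANOC handle \(r\) through its proximal operator.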

PALMInnerTrainer#

class pbalm.inner_solvers.inner_solvers.PALMInnerTrainer(train_fun)#

Wrapper class that runs a PALM inner solver.

Parameters:

train_fun (callable) – Function to run the inner optimization.

The train_fun should have the following signature (see the Custom Inner Solver Example below for a complete implementation):

def train_fun(palm_obj_fun, x, max_iter, tol):
    """
    Parameters:
        palm_obj_fun: Augmented Lagrangian objective function
        x: Initial point
        max_iter: Maximum iterations
        tol: Convergence tolerance

    Returns:
        x_new: Optimized point
        state: Solver state/statistics (a dict with keys: "obj_grad_evals", "fp_res", "obj_val", "reg_val", "status")
    """
    ...
    return x_new, state

PaProblem#

class pbalm.inner_solvers.inner_solvers.PaProblem(f, x0, reg=None, lbda=None, solver_opts=None, tol=1e-9, max_iter=2000, direction=None, jittable=True)#

Internal problem class for the PANOC solver from Alpaqa.

Parameters:
  • f (callable) – Objective function

  • x0 (jax.numpy.ndarray) – Initial point (used to infer the problem dimension)

  • reg (alpaqa regularizer or None) – Regularizer (supports pbalm.L1Norm(), pbalm.NuclearNorm(), pbalm.Box(lower, upper))

  • lbda (float, list, or None) – L1 regularization weights

  • solver_opts (dict or None) – Options for PANOC solver

  • tol (float) – Convergence tolerance

  • max_iter (int) – Maximum iterations

  • direction (alpaqa direction or None) – Direction method for PANOC

  • jittable (bool) – Enable JIT compilation

Methods:

eval_objective(x)#

Evaluate the objective function at x.

eval_objective_gradient(x, grad_f)#

Evaluate the gradient and store in grad_f.
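
Although PaProblem is normally constructed internally, it can be exercised directly. A minimal sketch (the quadratic objective is purely illustrative, and the preallocated NumPy gradient buffer is an assumption based on the in-place convention of eval_objective_gradient above):

import numpy as np
import jax.numpy as jnp
from pbalm.inner_solvers.inner_solvers import PaProblem

def f(x):
    return jnp.sum(x ** 2)  # illustrative smooth objective

x0 = jnp.zeros(5)
prob = PaProblem(f, x0)  # problem dimension inferred from x0

val = prob.eval_objective(x0)           # objective value at x0
grad = np.zeros(x0.shape[0])            # preallocated gradient buffer
prob.eval_objective_gradient(x0, grad)  # gradient stored into grad in place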

get_solver_run#

pbalm.inner_solvers.inner_solvers.get_solver_run(reg=None, lbda=None, solver_opts=None, direction=None, jittable=True)#

Factory function to create the default inner solver runner.

Parameters:
  • reg (alpaqa regularizer or None) – Regularizer for the problem

  • lbda (float, list, or None) – L1 regularization weights

  • solver_opts (dict or None) – Options for PANOC solver

  • direction (alpaqa direction or None) – Direction method (default: L-BFGS with memory 20)

  • jittable (bool) – Enable JIT compilation

Returns:

A PALMInnerTrainer instance

Return type:

PALMInnerTrainer
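
For example, to build the default PANOC runner with an L1 regularizer (a sketch; pbalm.L1Norm() is taken from the reg options listed under PaProblem above, and the weight 0.1 is illustrative):

import pbalm
from pbalm.inner_solvers.inner_solvers import get_solver_run

runner = get_solver_run(reg=pbalm.L1Norm(), lbda=0.1)  # L1 weight 0.1
# runner is a PALMInnerTrainer; pass it to pbalm.solve via
# inner_solve_runner, as in the custom solver example below.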

phase_I_optim#

pbalm.inner_solvers.inner_solvers.phase_I_optim(x0, h, g, reg, lbda0, mu0, alpha=20, gamma0=1e-8, tol=1e-7, max_iter=500, inner_solver='PANOC')#

Solve the Phase I feasibility problem to find an initial feasible point.

When starting from an infeasible point, this function finds a point that satisfies (or nearly satisfies) the constraints.

Parameters:
  • x0 (jax.numpy.ndarray) – Initial (infeasible) point

  • h (list of callables or None) – Equality constraint functions

  • g (list of callables or None) – Inequality constraint functions

  • reg (alpaqa regularizer or None) – Regularizer

  • lbda0 (jax.numpy.ndarray or None) – Initial equality multipliers

  • mu0 (jax.numpy.ndarray or None) – Initial inequality multipliers

  • alpha (float) – Penalty growth parameter

  • gamma0 (float) – Initial proximal parameter

  • tol (float) – Feasibility tolerance

  • max_iter (int) – Maximum iterations

  • inner_solver (str) – Inner solver name

Returns:

A point that is feasible to within the tolerance tol

Return type:

jax.numpy.ndarray

Raises:

RuntimeError – If feasibility cannot be achieved
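
A minimal sketch of a Phase I call (the constraints are hypothetical, and the sign convention g(x) <= 0 for inequalities is assumed; regularizer and multipliers are left as None):

import jax.numpy as jnp
from pbalm.inner_solvers.inner_solvers import phase_I_optim

h = [lambda x: jnp.sum(x) - 1.0]  # equality: sum(x) = 1
g = [lambda x: x[0] - 2.0]        # inequality: x[0] <= 2 (convention assumed)

x0 = jnp.array([5.0, 5.0])  # infeasible starting point
x_feas = phase_I_optim(x0, h, g, reg=None, lbda0=None, mu0=None)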

Custom Inner Solver Example#

To use a custom inner solver:

from pbalm.inner_solvers.inner_solvers import PALMInnerTrainer
import pbalm
import jax
import jax.numpy as jnp

def custom_gradient_descent(palm_obj_fun, x0, max_iter, tol):
    """Simple gradient descent inner solver."""
    grad_fn = jax.grad(palm_obj_fun)
    x = x0  # JAX arrays are immutable, so no copy is needed
    step_size = 0.01

    grad = grad_fn(x)
    n_evals = 1
    # Take fixed-size gradient steps until the gradient norm meets tol
    while n_evals < max_iter and jnp.linalg.norm(grad) >= tol:
        x = x - step_size * grad
        grad = grad_fn(x)
        n_evals += 1

    converged = bool(jnp.linalg.norm(grad) < tol)
    return x, {
        "obj_grad_evals": n_evals,
        "fp_res": jnp.linalg.norm(grad),  # gradient norm as the fixed-point residual
        "obj_val": palm_obj_fun(x),
        "reg_val": 0.0,  # no nonsmooth regularizer in this simple example
        "status": "Converged" if converged else "MaxIterReached",
    }

# Create custom trainer
custom_runner = PALMInnerTrainer(custom_gradient_descent)

# Use with solve
result = pbalm.solve(
    problem,
    x0,
    inner_solve_runner=custom_runner
)
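
Whatever the inner algorithm, train_fun must return the full state dictionary with the keys listed under PALMInnerTrainer above, so that the outer loop can track inner-solver progress.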

PANOC Solver Options#

The default inner solver uses PANOC from Alpaqa. Common options:

import alpaqa as pa
import pbalm

# Custom direction (L-BFGS with more memory)
direction = pa.LBFGSDirection({"memory": 50})

# Custom solver options
solver_opts = {
    "max_iter": 5000,
    "stop_crit": pa.ProjGradUnitNorm,
}

result = pbalm.solve(
    problem,
    x0,
    pa_direction=direction,
    pa_solver_opts=solver_opts
)

For more details on PANOC options, see the Alpaqa documentation.