Problem Class

The Problem class defines the optimization problem to be solved by PBALM.

Class Definition

class pbalm.problem.Problem(f1, f2=None, f2_lbda=0.0, h=None, g=None, f1_grad=None, jittable=True, callback=None)

Defines a general optimization problem for PBALM.

The problem has the form:

\[\begin{split}\min_{x} \quad & f_1(x) + \text{f2\_lbda} \cdot f_2(x) \\ \text{s.t.} \quad & h_i(x) = 0, \quad i = 1, \ldots, p \\ & g_j(x) \leq 0, \quad j = 1, \ldots, m\end{split}\]
Parameters:
  • f1 (callable) – The smooth objective function \(f_1: \mathbb{R}^n \to \mathbb{R}\).

  • f2 (alpaqa regularizer or None) – Nonsmooth regularization function with a proximal operator. Uses regularizers from alpaqa, accessible via pbalm. Examples include pbalm.L1Norm(), pbalm.NuclearNorm(), or pbalm.Box(lower, upper).

  • f2_lbda (float or list) – Regularization parameter \(\lambda\) multiplying the regularizer. Default is 0.0. When set alongside pbalm.L1Norm(), this value takes priority in the inner iterations. Can be a float, or a list with the same length as the optimization variable for element-wise weights.

  • h (list of callables or None) – List of equality constraint functions. Each function \(h_i\) should return a scalar or array, and the constraint is \(h_i(x) = 0\).

  • g (list of callables or None) – List of inequality constraint functions. Each function \(g_j\) should return a scalar or array, and the constraint is \(g_j(x) \leq 0\).

  • f1_grad (callable or None) – Custom gradient function for the smooth objective. If not provided, automatic differentiation via JAX is used.

  • jittable (bool) – If True, enable JAX JIT compilation for all functions. All provided functions must be JAX-compatible.

  • callback (callable or None) – Optional callback function called at each outer iteration.
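The constraint and gradient parameters above can be illustrated with a short sketch. It rewrites two common constraints into the \(g_j(x) \leq 0\) form (one returning an array, one returning a scalar) and checks a hand-written gradient against JAX automatic differentiation before it would be passed as f1_grad. The function names g_nonneg and g_ball and the test point are hypothetical and chosen only for illustration; the final commented line uses the constructor signature documented above and is not executed here.

```python
import jax
import jax.numpy as jnp


def f1(x):
    # Smooth objective: squared distance to the all-ones vector.
    return jnp.sum((x - 1.0) ** 2)


def f1_grad(x):
    # Hand-written gradient of f1.
    return 2.0 * (x - 1.0)


def g_nonneg(x):
    # x >= 0 rewritten as -x <= 0; returns one value per component.
    return -x


def g_ball(x):
    # ||x||^2 <= 4 rewritten as ||x||^2 - 4 <= 0; returns a scalar.
    return jnp.sum(x ** 2) - 4.0


x_test = jnp.array([0.5, 1.0, 1.5])

# The custom gradient should agree with automatic differentiation.
grad_matches = bool(jnp.allclose(f1_grad(x_test), jax.grad(f1)(x_test)))

# A point is feasible when every inequality constraint value is <= 0.
is_feasible = bool(jnp.all(g_nonneg(x_test) <= 0.0)) and bool(g_ball(x_test) <= 0.0)

# With pbalm installed, these would be passed using the documented signature:
# problem = pbalm.Problem(f1=f1, g=[g_nonneg, g_ball], f1_grad=f1_grad)
```

Checking a custom gradient this way before handing it to the solver is inexpensive and catches sign or scaling mistakes early.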

Attributes

pbalm.problem.f1: callable

The smooth objective function, JIT-compiled if jittable=True.

pbalm.problem.f2: alpaqa regularizer or None

The nonsmooth regularizer.

pbalm.problem.f2_lbda: float or list

The regularization parameter.

pbalm.problem.h: callable or None

The combined equality constraint function (set by solve()).

pbalm.problem.g: callable or None

The combined inequality constraint function (set by solve()).

pbalm.problem.f1_grad: GradEvalCounter

Wrapped gradient function with evaluation counting.

pbalm.problem.h_grad: GradEvalCounter or None

Jacobian of equality constraints (set by solve()).

pbalm.problem.g_grad: GradEvalCounter or None

Jacobian of inequality constraints (set by solve()).

pbalm.problem.jittable: bool

Whether JIT compilation is enabled.

pbalm.problem.callback: callable or None

The callback function.

pbalm.problem.lbda_sizes: list

Sizes of equality constraint multipliers (set by solve()).

pbalm.problem.mu_sizes: list

Sizes of inequality constraint multipliers (set by solve()).

Methods

pbalm.problem.reset_counters()

Reset the gradient evaluation counters for f1_grad, h_grad, and g_grad.

Example Usage

Basic problem with equality constraint:

import jax.numpy as jnp
import pbalm

def f1(x):
    return jnp.sum(x**2)

def h(x):
    return jnp.sum(x) - 1.0

problem = pbalm.Problem(f1=f1, h=[h], jittable=True)

Problem with L1 regularization:

import jax.numpy as jnp
import pbalm

def f1(x):
    return jnp.sum((x - 1)**2)

# L1 regularization
f2 = pbalm.L1Norm()

problem = pbalm.Problem(
    f1=f1,
    f2=f2,
    f2_lbda=0.1,
    jittable=True
)
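To make the role of f2_lbda more concrete, the following sketch shows the textbook proximal operator of a weighted L1 norm (soft-thresholding), with either a scalar weight or element-wise weights. This is not pbalm's or alpaqa's internal implementation; the function name soft_threshold and its step argument are hypothetical and only illustrate what a regularizer "with a proximal operator" computes.

```python
import jax.numpy as jnp


def soft_threshold(v, lbda, step=1.0):
    """Proximal operator of step * sum_i lbda_i * |x_i|, evaluated at v."""
    # lbda may be a scalar or per-element weights; it broadcasts against v.
    lbda = jnp.asarray(lbda)
    # Shrink each entry toward zero by step * lbda, clipping at zero.
    return jnp.sign(v) * jnp.maximum(jnp.abs(v) - step * lbda, 0.0)


v = jnp.array([1.0, -0.05, 0.3])

# Scalar weight: every entry is shrunk by the same amount.
shrunk_scalar = soft_threshold(v, 0.1)

# Element-wise weights: a zero weight leaves that entry untouched.
shrunk_weighted = soft_threshold(v, [0.1, 0.0, 0.5])
```

Entries whose magnitude falls below the threshold are set to zero, which is why L1 regularization tends to produce sparse solutions.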

Problem with callback:

# Reuses f1 and h defined in the examples above.
def my_callback(iter, x, x_prev, lbda, mu, rho, nu, gamma_k, x0):
    print(f"Iteration {iter}: f1(x) = {f1(x):.6f}")

problem = pbalm.Problem(
    f1=f1,
    h=[h],
    callback=my_callback,
    jittable=True
)

Callback Signature

The callback function is called at each outer iteration with the following arguments:

def callback(iter, x, x_prev, lbda, mu, rho, nu, gamma_k, x0):
    """
    Callback function signature.

    Parameters:
        iter: Current iteration number (int)
        x: Current solution (jax.numpy.ndarray)
        x_prev: Previous solution (jax.numpy.ndarray)
        lbda: Current equality constraint multipliers (jax.numpy.ndarray or None)
        mu: Current inequality constraint multipliers (jax.numpy.ndarray or None)
        rho: Current equality penalty parameters (jax.numpy.ndarray or None)
        nu: Current inequality penalty parameters (jax.numpy.ndarray or None)
        gamma_k: Current proximal parameter (float)
        x0: Initial point (jax.numpy.ndarray)
    """
    pass
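
A callback with this signature can also collect data instead of printing it. The sketch below records the step length between consecutive iterates at each outer iteration. The factory make_recording_callback and the history list are hypothetical names, and the sketch assumes that x and x_prev are one-dimensional, that the callback receives concrete values, and that its return value is ignored. The final call uses stand-in values only to show the recorded format.

```python
import math


def make_recording_callback(history):
    """Return a callback that appends one summary dict per outer iteration."""

    def callback(iter, x, x_prev, lbda, mu, rho, nu, gamma_k, x0):
        # Euclidean distance between consecutive iterates (1-D sequences).
        step = math.sqrt(
            sum((float(a) - float(b)) ** 2 for a, b in zip(x, x_prev))
        )
        history.append({"iter": iter, "step": step, "gamma_k": float(gamma_k)})

    return callback


history = []
record = make_recording_callback(history)

# Manual call with stand-in values, mimicking one outer iteration.
record(0, [1.0, 2.0], [1.0, 0.0], None, None, None, None, 0.5, [0.0, 0.0])
```

The returned function would be passed as callback=record when constructing the Problem, and history can be inspected after the solve.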