Problem Class#
The Problem class defines the optimization problem to be solved by PBALM.
Class Definition#
- class pbalm.problem.Problem(f1, f2=None, f2_lbda=0.0, h=None, g=None, f1_grad=None, jittable=True, callback=None)#
Defines a general optimization problem for PBALM.
The problem has the form:
\[\begin{split}\min_{x} \quad & f_1(x) + \lambda \cdot f_2(x) \\ \text{s.t.} \quad & h_i(x) = 0, \quad i = 1, \ldots, p \\ & g_j(x) \leq 0, \quad j = 1, \ldots, m\end{split}\]
- Parameters:
f1 (callable) – The smooth objective function \(f_1: \mathbb{R}^n \to \mathbb{R}\).
f2 (alpaqa regularizer or None) – Nonsmooth regularization function with a proximal operator. Uses regularizers from alpaqa, accessible via pbalm. Examples include pbalm.L1Norm(), pbalm.NuclearNorm(), or pbalm.Box(lower, upper).
f2_lbda (float or list) – Regularization parameter \(\lambda\) multiplying the regularizer. Default is 0.0. Can be a float or a list (of the same length as the optimization variable) for element-wise weights, e.g. together with pbalm.L1Norm().
h (list of callables or None) – List of equality constraint functions. Each function \(h_i\) should return a scalar or array, and the constraint is \(h_i(x) = 0\).
g (list of callables or None) – List of inequality constraint functions. Each function \(g_j\) should return a scalar or array, and the constraint is \(g_j(x) \leq 0\).
f1_grad (callable or None) – Custom gradient function for the smooth objective. If not provided, automatic differentiation via JAX is used; a sketch of supplying a custom gradient is shown after this list.
jittable (bool) – If True, enable JAX JIT compilation for all functions. All provided functions must be JAX-compatible.
callback (callable or None) – Optional callback function called at each outer iteration; see Callback Signature below.
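For example, a closed-form gradient can be supplied instead of relying on automatic differentiation. A minimal sketch (the quadratic objective is illustrative, not part of the API):

import jax.numpy as jnp
import pbalm

def f1(x):
    return jnp.sum(x**2)

def f1_grad(x):
    # analytic gradient of sum(x**2); must be JAX-compatible when jittable=True
    return 2.0 * x

problem = pbalm.Problem(f1=f1, f1_grad=f1_grad, jittable=True)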
Attributes#
- pbalm.problem.f1: callable#
The smooth objective function (JIT-compiled if jittable=True).
- pbalm.problem.f2: alpaqa regularizer or None#
The nonsmooth regularizer.
- pbalm.problem.f2_lbda: float or list#
The regularization parameter.
- pbalm.problem.h: callable or None#
The combined equality constraint function (set by solve()).
- pbalm.problem.g: callable or None#
The combined inequality constraint function (set by solve()).
- pbalm.problem.f1_grad: GradEvalCounter#
Wrapped gradient function with evaluation counting.
- pbalm.problem.h_grad: GradEvalCounter or None#
Jacobian of equality constraints (set by solve()).
- pbalm.problem.g_grad: GradEvalCounter or None#
Jacobian of inequality constraints (set by solve()).
- pbalm.problem.callback: callable or None#
The callback function.
Methods#
- pbalm.problem.reset_counters()#
Reset gradient evaluation counters for f1_grad, h_grad, and g_grad.
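A minimal usage sketch, assuming a problem constructed as in the examples below:

problem.reset_counters()  # zero the f1_grad, h_grad, and g_grad evaluation counts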
Example Usage#
Basic problem with equality constraint:
import jax.numpy as jnp
import pbalm

def f1(x):
    return jnp.sum(x**2)

def h(x):
    return jnp.sum(x) - 1.0

problem = pbalm.Problem(f1=f1, h=[h], jittable=True)
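Inequality constraints are supplied the same way through g. A minimal sketch enforcing elementwise nonnegativity (the particular bound is illustrative):

def g(x):
    # g(x) <= 0 encodes x >= 0 elementwise
    return -x

problem = pbalm.Problem(f1=f1, g=[g], jittable=True)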
Problem with L1 regularization:
def f1(x):
    return jnp.sum((x - 1)**2)

# L1 regularization
f2 = pbalm.L1Norm()

problem = pbalm.Problem(
    f1=f1,
    f2=f2,
    f2_lbda=0.1,
    jittable=True,
)
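Because f2_lbda also accepts a list, per-coordinate weights are possible. A sketch assuming a three-dimensional optimization variable (the weight values are illustrative):

problem = pbalm.Problem(
    f1=f1,
    f2=pbalm.L1Norm(),
    f2_lbda=[0.1, 0.0, 0.5],  # one weight per coordinate of x
    jittable=True,
)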
Problem with callback:
def my_callback(iter, x, x_prev, lbda, mu, rho, nu, gamma_k, x0):
    print(f"Iteration {iter}: f1(x) = {f1(x):.6f}")

problem = pbalm.Problem(
    f1=f1,
    h=[h],
    callback=my_callback,
    jittable=True,
)
Callback Signature#
The callback function is called at each outer iteration with the following arguments:
def callback(iter, x, x_prev, lbda, mu, rho, nu, gamma_k, x0):
    """
    Callback function signature.

    Parameters:
        iter: Current iteration number (int)
        x: Current solution (jax.numpy.ndarray)
        x_prev: Previous solution (jax.numpy.ndarray)
        lbda: Current equality constraint multipliers (jax.numpy.ndarray or None)
        mu: Current inequality constraint multipliers (jax.numpy.ndarray or None)
        rho: Current equality penalty parameters (jax.numpy.ndarray or None)
        nu: Current inequality penalty parameters (jax.numpy.ndarray or None)
        gamma_k: Current proximal parameter (float)
        x0: Initial point (jax.numpy.ndarray)
    """
    pass
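For instance, a callback can track convergence by recording how far each outer iterate moves. A minimal sketch using only the documented arguments:

import jax.numpy as jnp

history = []

def tracking_callback(iter, x, x_prev, lbda, mu, rho, nu, gamma_k, x0):
    # record the iteration number and the step length of this outer iteration
    history.append((iter, float(jnp.linalg.norm(x - x_prev))))

problem = pbalm.Problem(f1=f1, h=[h], callback=tracking_callback, jittable=True)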