Utilities#
The utilities module provides helper functions for working with structured variables and tracking evaluations.
Structured Variable Utilities#
These functions help manage decision variables that have structure (e.g., matrices, multiple blocks).
params_flatten#
- pbalm.utils.utils.params_flatten(params)#
Flatten a list of arrays into a single 1D vector.
- Parameters:
params (list) – List of parameter blocks; each entry may be a scalar, vector, or matrix
- Returns:
Flattened 1D array
- Return type:
jax.numpy.ndarray
Example:
import jax.numpy as jnp
from pbalm.utils.utils import params_flatten

# Multiple variable blocks
A = jnp.zeros((3, 3))  # 9 elements
b = jnp.zeros(3)       # 3 elements
c = 1.0                # 1 element

params = [A, b, c]
x = params_flatten(params)
print(x.shape)  # (13,)
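The flattened vector concatenates the blocks in list order. The sketch below additionally assumes row-major (C-order) flattening within each block, as jnp.ravel would produce; treat that ordering as an assumption and verify it if your code depends on it:

import jax.numpy as jnp
from pbalm.utils.utils import params_flatten

A = jnp.arange(9.0).reshape(3, 3)  # entries 0..8
b = jnp.full(3, -1.0)

x = params_flatten([A, b])
print(x[:9])  # A's entries (row-major order assumed)
print(x[9:])  # [-1. -1. -1.], b's entries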
params_shape#
- pbalm.utils.utils.params_shape(params)#
Get shapes and cumulative sizes for a list of parameters.
- Parameters:
params (list) – List of arrays
- Returns:
Tuple of (shapes, cumsizes) where shapes is a list of shapes and cumsizes is a numpy array of cumulative sizes
- Return type:
tuple
Example:
from pbalm.utils.utils import params_shape

# A, b, c as defined in the previous example
params = [A, b, c]
shapes, cumsizes = params_shape(params)
print(shapes)    # [(3, 3), (3,), ()]
print(cumsizes)  # [ 9 12 13]
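To make the relationship between the two return values concrete, the sketch below reproduces the slicing that cumsizes enables. params_unflatten (next) does this for you, and its actual implementation may differ:

import numpy as np
import jax.numpy as jnp
from pbalm.utils.utils import params_flatten, params_shape

params = [jnp.zeros((3, 3)), jnp.zeros(3), jnp.asarray(1.0)]
shapes, cumsizes = params_shape(params)
x = params_flatten(params)

# Block i occupies x[start_i:cumsizes[i]], where start_0 = 0 and
# start_i = cumsizes[i - 1] for i > 0.
starts = np.concatenate(([0], cumsizes[:-1]))
blocks = [x[lo:hi].reshape(shape)
          for shape, lo, hi in zip(shapes, starts, cumsizes)]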
params_unflatten#
- pbalm.utils.utils.params_unflatten(params_flattened, shapes, cumsizes)#
Unflatten a 1D vector back into structured parameters.
- Parameters:
params_flattened (jax.numpy.ndarray) – Flattened parameter vector
shapes (list) – List of shapes (from params_shape)
cumsizes (numpy.ndarray) – Cumulative sizes (from params_shape)
- Returns:
List of arrays with original shapes
- Return type:
list
Example:
from pbalm.utils.utils import params_unflatten

# After optimization, recover the structured parameters
x_opt = result.x
params_opt = params_unflatten(x_opt, shapes, cumsizes)
A_opt, b_opt, c_opt = params_opt
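A self-contained round trip is a quick sanity check that unflattening inverts flattening:

import jax.numpy as jnp
from pbalm.utils.utils import params_flatten, params_shape, params_unflatten

params = [jnp.ones((2, 2)), jnp.zeros(3)]
shapes, cumsizes = params_shape(params)
recovered = params_unflatten(params_flatten(params), shapes, cumsizes)

# Each recovered block should match the original
assert all(jnp.array_equal(p, r) for p, r in zip(params, recovered))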
Complete Workflow Example#
import jax.numpy as jnp
import pbalm
from pbalm.utils.utils import params_flatten, params_shape, params_unflatten
# Define structured variables
W = jnp.zeros((5, 3)) # Weight matrix
b = jnp.zeros(5) # Bias vector
# Get structure info
params = [W, b]
shapes, cumsizes = params_shape(params)
# Flatten for the solver
x0 = params_flatten(params)
# Define objective using flattened variables
def f(x):
    # Unflatten inside objective
    W, b = params_unflatten(x, shapes, cumsizes)
    return jnp.sum(W**2) + jnp.sum(b**2)
# Solve
problem = pbalm.Problem(f=f, jittable=True)
result = pbalm.solve(
    problem,
    x0,
    x_shapes=shapes,
    x_cumsizes=cumsizes
)
# Recover structured solution
W_opt, b_opt = params_unflatten(result.x, shapes, cumsizes)
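Because the objective takes the flat vector, standard JAX transformations apply to it directly, and the same structure metadata recovers per-block quantities. Continuing the example above:

import jax

# The gradient w.r.t. the flat vector shares its layout, so the same
# (shapes, cumsizes) pair recovers per-block gradients.
g = jax.grad(f)(x0)
gW, gb = params_unflatten(g, shapes, cumsizes)
print(gW.shape, gb.shape)  # (5, 3) (5,)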
Gradient Evaluation Counter#
GradEvalCounter#
- class pbalm.utils.utils.GradEvalCounter(fn)#
Wrapper class to count function evaluations.
This is used internally to track the number of gradient evaluations during optimization.
- Parameters:
fn (callable) – Function to wrap
Attributes:
- fn: callable#
The wrapped function.
- count: int#
Number of times the wrapped function has been called.
Methods:
- __call__(*args, **kwargs)#
Call the wrapped function and increment counter.
- reset()#
Reset the evaluation counter to zero.
Example:
from pbalm.utils.utils import GradEvalCounter

def my_function(x):
    return x**2

# Wrap with counter
counted_fn = GradEvalCounter(my_function)

# Use the function
for i in range(10):
    y = counted_fn(i)

print(f"Function was called {counted_fn.count} times")  # 10

# Reset counter
counted_fn.reset()
print(f"After reset: {counted_fn.count}")  # 0
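For intuition, a wrapper of this kind takes only a few lines. The sketch below (named CallCounter to avoid confusion with the real class) illustrates the counting mechanism; it is not pbalm's actual source:

class CallCounter:
    """Illustrative call-counting wrapper, not pbalm's implementation."""

    def __init__(self, fn):
        self.fn = fn    # the wrapped function
        self.count = 0  # number of calls so far

    def __call__(self, *args, **kwargs):
        self.count += 1
        return self.fn(*args, **kwargs)

    def reset(self):
        self.count = 0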
Penalty Update#
update_penalties#
- pbalm.utils.utils.update_penalties(lbda_sizes, mu_sizes, rho, nu, rho0, nu0, E_x, prev_E, h_x, prev_h, beta, xi1, xi2, phi_i)#
Update penalty parameters based on constraint satisfaction progress.
This function implements the adaptive penalty update rule from the PBALM algorithm. A penalty is increased when its constraint violation has not decreased sufficiently relative to the previous iterate.
- Parameters:
lbda_sizes (list) – Sizes of equality constraint multipliers
mu_sizes (list) – Sizes of inequality constraint multipliers
rho (jax.numpy.ndarray) – Current equality penalty parameters
nu (jax.numpy.ndarray) – Current inequality penalty parameters
rho0 (float) – Initial equality penalty
nu0 (float) – Initial inequality penalty
E_x (jax.numpy.ndarray) – Current inequality constraint term
prev_E (jax.numpy.ndarray) – Previous inequality constraint term
h_x (jax.numpy.ndarray) – Current equality constraint values
prev_h (jax.numpy.ndarray) – Previous equality constraint values
beta (float) – Satisfaction threshold (0 < beta < 1)
xi1 (float) – Equality penalty scaling factor
xi2 (float) – Inequality penalty scaling factor
phi_i (float) – Minimum penalty floor value
- Returns:
Updated penalty parameters (rho_new, nu_new)
- Return type:
tuple
The update rule is:
\[\begin{split}\rho_i^{k+1} = \begin{cases} \max(\xi_1 \rho_i^k, \rho_0 \cdot \phi(k)) & \text{if } \|h_i(x^{k+1})\| > \beta \|h_i(x^k)\| \\ \rho_i^k & \text{otherwise} \end{cases}\end{split}\]
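The rule above is stated only for the equality penalties \(\rho\). Judging from the parameter list (xi2, nu0, and the inequality terms E_x, prev_E), the inequality penalties \(\nu\) are presumably updated analogously; the following is an inference, not a formula quoted from the source:

\[\begin{split}\nu_i^{k+1} = \begin{cases} \max(\xi_2 \nu_i^k, \nu_0 \cdot \phi(k)) & \text{if } \|E_i(x^{k+1})\| > \beta \|E_i(x^k)\| \\ \nu_i^k & \text{otherwise} \end{cases}\end{split}\]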