Utilities

The utilities module provides helper functions for working with structured variables and tracking evaluations.

Structured Variable Utilities

These functions convert decision variables that have structure (e.g., matrices or multiple variable blocks) to and from the flat 1D vector used by the solver.

params_flatten

pbalm.utils.utils.params_flatten(params)

Flatten a list of arrays into a single 1D vector.

Parameters:

params (list) – List of arrays (can be scalars, vectors, or matrices)

Returns:

Flattened 1D array

Return type:

jax.numpy.ndarray

Example:

import jax.numpy as jnp
from pbalm.utils.utils import params_flatten

# Multiple variable blocks
A = jnp.zeros((3, 3))  # 9 elements
b = jnp.zeros(3)       # 3 elements
c = 1.0                # 1 element

params = [A, b, c]
x = params_flatten(params)
print(x.shape)  # (13,)
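A minimal sketch of what flattening amounts to, assuming params_flatten ravels each block in order and concatenates the results (the actual implementation may differ). `flatten_sketch` is a hypothetical name, and NumPy is used so the snippet runs without JAX; jax.numpy provides the same ravel and concatenate calls.

```python
import numpy as np

def flatten_sketch(params):
    """Hypothetical stand-in for params_flatten: ravel each block, then concatenate."""
    # np.ravel also accepts Python scalars, turning 1.0 into a length-1 array
    return np.concatenate([np.ravel(p) for p in params])

A = np.zeros((3, 3))  # 9 elements
b = np.zeros(3)       # 3 elements
c = 1.0               # 1 element

x = flatten_sketch([A, b, c])
print(x.shape)  # (13,)
```

Blocks are laid out one after another in list order, which is why the cumulative sizes returned by params_shape are enough to cut the vector apart again.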

params_shape

pbalm.utils.utils.params_shape(params)

Get shapes and cumulative sizes for a list of parameters.

Parameters:

params (list) – List of arrays

Returns:

Tuple (shapes, cumsizes), where shapes is a list of the arrays' shapes and cumsizes is a NumPy array of cumulative element counts

Return type:

tuple

Example:

from pbalm.utils.utils import params_shape

params = [A, b, c]  # A, b, c as defined in the params_flatten example
shapes, cumsizes = params_shape(params)

print(shapes)    # [(3, 3), (3,), ()]
print(cumsizes)  # [ 9 12 13]
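The cumulative sizes are running totals of each block's element count (9, then 9 + 3 = 12, then 12 + 1 = 13). A sketch of that computation, assuming params_shape works this way; `shape_sketch` is a hypothetical name and NumPy stands in for jax.numpy.

```python
import numpy as np

def shape_sketch(params):
    """Hypothetical stand-in for params_shape."""
    # np.shape works for arrays and Python scalars alike (a scalar gives ())
    shapes = [np.shape(p) for p in params]
    # np.prod(()) is 1.0, so a scalar block counts as one element
    sizes = [int(np.prod(s)) for s in shapes]
    cumsizes = np.cumsum(sizes)
    return shapes, cumsizes

A = np.zeros((3, 3))
b = np.zeros(3)
c = 1.0

shapes, cumsizes = shape_sketch([A, b, c])
print(shapes)    # [(3, 3), (3,), ()]
print(cumsizes)  # [ 9 12 13]
```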

params_unflatten

pbalm.utils.utils.params_unflatten(params_flattened, shapes, cumsizes)

Unflatten a 1D vector back into structured parameters.

Parameters:
  • params_flattened (jax.numpy.ndarray) – Flattened parameter vector

  • shapes (list) – List of shapes (from params_shape)

  • cumsizes (numpy.ndarray) – Cumulative sizes (from params_shape)

Returns:

List of arrays with original shapes

Return type:

list

Example:

from pbalm.utils.utils import params_unflatten

# After optimization (result from pbalm.solve; shapes and cumsizes from params_shape)
x_opt = result.x

# Recover structured parameters
params_opt = params_unflatten(x_opt, shapes, cumsizes)
A_opt, b_opt, c_opt = params_opt
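Unflattening reverses the layout: split the vector at every cumulative size except the last, then reshape each piece. A sketch under the assumption that params_unflatten works this way; `unflatten_sketch` is hypothetical, NumPy stands in for jax.numpy, and `np.arange(13.0)` plays the role of `result.x`.

```python
import numpy as np

def unflatten_sketch(params_flattened, shapes, cumsizes):
    """Hypothetical stand-in for params_unflatten."""
    # Split at every cumulative size except the last (which equals the total length)
    chunks = np.split(params_flattened, cumsizes[:-1])
    # Restore each block's original shape; a () shape yields a 0-d array
    return [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]

shapes = [(3, 3), (3,), ()]
cumsizes = np.array([9, 12, 13])
x_opt = np.arange(13.0)  # stand-in for result.x

A_opt, b_opt, c_opt = unflatten_sketch(x_opt, shapes, cumsizes)
print(A_opt.shape, b_opt.shape, c_opt.shape)  # (3, 3) (3,) ()
```

Using distinct values in the stand-in vector makes it easy to check that element 9 lands at the start of b and element 12 becomes the scalar c.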

Complete Workflow Example

import jax.numpy as jnp
import pbalm
from pbalm.utils.utils import params_flatten, params_shape, params_unflatten

# Define structured variables
W = jnp.zeros((5, 3))  # Weight matrix
b = jnp.zeros(5)       # Bias vector

# Get structure info
params = [W, b]
shapes, cumsizes = params_shape(params)

# Flatten for the solver
x0 = params_flatten(params)

# Define objective using flattened variables
def f(x):
    # Unflatten inside objective
    W, b = params_unflatten(x, shapes, cumsizes)
    return jnp.sum(W**2) + jnp.sum(b**2)

# Solve
problem = pbalm.Problem(f=f, jittable=True)
result = pbalm.solve(
    problem,
    x0,
    x_shapes=shapes,
    x_cumsizes=cumsizes
)

# Recover structured solution
W_opt, b_opt = params_unflatten(result.x, shapes, cumsizes)
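To see the same pattern without pbalm installed, the sketch below swaps pbalm.solve for a plain gradient-descent loop (a different technique, used here only for illustration) and reimplements the three helpers with NumPy under hypothetical names. It shows that the optimizer only ever handles the flat vector, while the objective unflattens it to work with W and b.

```python
import numpy as np

# Hypothetical NumPy stand-ins for params_flatten / params_shape / params_unflatten
def flatten(params):
    return np.concatenate([np.ravel(p) for p in params])

def shape_info(params):
    shapes = [np.shape(p) for p in params]
    return shapes, np.cumsum([int(np.prod(s)) for s in shapes])

def unflatten(x, shapes, cumsizes):
    return [c.reshape(s) for c, s in zip(np.split(x, cumsizes[:-1]), shapes)]

W = np.ones((5, 3))  # start away from the minimizer
b = np.ones(5)
shapes, cumsizes = shape_info([W, b])
x = flatten([W, b])

def f(x):
    # The objective works with structured blocks
    W, b = unflatten(x, shapes, cumsizes)
    return np.sum(W**2) + np.sum(b**2)

def grad_f(x):
    # Gradient of sum(W**2) + sum(b**2), written directly on the flat vector
    return 2.0 * x

# Plain gradient descent stands in for pbalm.solve; the loop sees only the flat x
for _ in range(50):
    x = x - 0.25 * grad_f(x)

W_opt, b_opt = unflatten(x, shapes, cumsizes)
```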

Gradient Evaluation Counter

GradEvalCounter

class pbalm.utils.utils.GradEvalCounter(fn)

Wrapper class to count function evaluations.

This is used internally to track the number of gradient evaluations during optimization.

Parameters:

fn (callable) – Function to wrap

Attributes:

fn: callable

The wrapped function.

count: int

Number of times the function has been called.

Methods:

__call__(*args, **kwargs)

Call the wrapped function and increment the counter.

reset()

Reset the evaluation counter to zero.

Example:

from pbalm.utils.utils import GradEvalCounter

def my_function(x):
    return x**2

# Wrap with counter
counted_fn = GradEvalCounter(my_function)

# Use the function
for i in range(10):
    y = counted_fn(i)

print(f"Function was called {counted_fn.count} times")  # 10

# Reset counter
counted_fn.reset()
print(f"After reset: {counted_fn.count}")  # 0
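The documented attributes (fn, count) and methods (__call__, reset) suggest a small callable wrapper. The following is a sketch under that assumption, not the pbalm source; `CallCounter` and `grad_square` are hypothetical names, and wrapping a gradient function mirrors the internal use described above.

```python
class CallCounter:
    """Hypothetical stand-in for GradEvalCounter: counts calls to a wrapped function."""

    def __init__(self, fn):
        self.fn = fn
        self.count = 0

    def __call__(self, *args, **kwargs):
        # Increment first, then delegate to the wrapped function
        self.count += 1
        return self.fn(*args, **kwargs)

    def reset(self):
        self.count = 0


def grad_square(x):
    # Hand-written gradient of x**2
    return 2 * x

counted_grad = CallCounter(grad_square)
steps = [counted_grad(x) for x in range(3)]
print(steps, counted_grad.count)  # [0, 2, 4] 3
counted_grad.reset()
print(counted_grad.count)  # 0
```

Because the counter lives in __call__, every evaluation is counted, including ones whose result the caller discards.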

Penalty Update

update_penalties

pbalm.utils.utils.update_penalties(lbda_sizes, mu_sizes, rho, nu, rho0, nu0, E_x, prev_E, h_x, prev_h, beta, xi1, xi2, phi_i)

Update penalty parameters based on constraint satisfaction progress.

This function implements the adaptive penalty update rule from the PBALM algorithm. A penalty is increased when the corresponding constraint violation has not decreased by at least the factor beta since the previous iteration, and is otherwise left unchanged.

Parameters:
  • lbda_sizes (list) – Sizes of equality constraint multipliers

  • mu_sizes (list) – Sizes of inequality constraint multipliers

  • rho (jax.numpy.ndarray) – Current equality penalty parameters

  • nu (jax.numpy.ndarray) – Current inequality penalty parameters

  • rho0 (float) – Initial equality penalty

  • nu0 (float) – Initial inequality penalty

  • E_x (jax.numpy.ndarray) – Current inequality constraint term

  • prev_E (jax.numpy.ndarray) – Previous inequality constraint term

  • h_x (jax.numpy.ndarray) – Current equality constraint values

  • prev_h (jax.numpy.ndarray) – Previous equality constraint values

  • beta (float) – Satisfaction threshold (0 < beta < 1)

  • xi1 (float) – Equality penalty scaling factor

  • xi2 (float) – Inequality penalty scaling factor

  • phi_i (float) – Minimum penalty floor value

Returns:

Updated penalty parameters (rho_new, nu_new)

Return type:

tuple

The update rule for the equality penalties is:

\[\rho_i^{k+1} = \begin{cases} \max\left(\xi_1 \rho_i^k,\ \rho_0\,\phi(k)\right) & \text{if } \|h_i(x^{k+1})\| > \beta\, \|h_i(x^k)\| \\ \rho_i^k & \text{otherwise} \end{cases}\]
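A minimal sketch of the ρ update under stated assumptions: h_x and prev_h are split into blocks using lbda_sizes, each block i has one penalty ρ_i, and the caller supplies φ(k) as a number (`phi_k`, a hypothetical parameter standing in for phi_i). `update_rho_sketch` is hypothetical and covers only the equality penalties in the formula; the ν update is left out.

```python
import numpy as np

def update_rho_sketch(rho, h_x, prev_h, lbda_sizes, rho0, xi1, beta, phi_k):
    """Hypothetical sketch of the equality-penalty rule (not the pbalm source)."""
    # Assumption: h_x and prev_h are grouped into blocks of length lbda_sizes[i],
    # with one penalty rho[i] per block
    split_at = np.cumsum(lbda_sizes)[:-1]
    new_norms = np.array([np.linalg.norm(h) for h in np.split(h_x, split_at)])
    old_norms = np.array([np.linalg.norm(h) for h in np.split(prev_h, split_at)])
    # Increase rho_i only where the violation did not shrink by the factor beta
    insufficient = new_norms > beta * old_norms
    increased = np.maximum(xi1 * rho, rho0 * phi_k)
    return np.where(insufficient, increased, rho)

rho = np.array([1.0, 1.0])
h_x = np.array([1.0, 0.1, 0.1])     # block 0 stalled, block 1 shrank 10x
prev_h = np.array([1.0, 1.0, 1.0])
result = update_rho_sketch(rho, h_x, prev_h, [1, 2], rho0=1.0, xi1=10.0, beta=0.5, phi_k=2.0)
print(result)  # [10.  1.]
```

Block 0 has ‖h‖ = 1.0 > 0.5 · 1.0, so its penalty becomes max(10 · 1, 1 · 2) = 10; block 1 dropped from about 1.414 to about 0.141, below 0.5 · 1.414, so its penalty stays at 1.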