Quick Start Guide#
This guide will help you get started with pbalm by walking through a simple example.
Basic Workflow#
The typical workflow for using pbalm consists of three steps:
1. Define the problem: Create objective and constraint functions
2. Create a Problem instance: Use pbalm.Problem to define the optimization problem
3. Solve: Call pbalm.solve to find the solution
Minimal Example#
Here’s a simple example of solving a constrained optimization problem:
import jax.numpy as jnp
import pbalm
# Step 1: Define the problem (objective function)
def f1(x):
    return (x[0] - 1)**2 + (x[1] - 2)**2
# Step 1 (continued): Define the constraints
# Equality constraint: x[0] + x[1] = 2
def h(x):
    return x[0] + x[1] - 2
# Step 2: Create the Problem instance
problem = pbalm.Problem(f1=f1, h=[h], jittable=True)
# Step 3: Set the initial point and solve
x0 = jnp.array([0.0, 0.0])
result = pbalm.solve(problem, x0, tol=1e-6)
# Access the solution
print(f"Solution: {result.x}")
print(f"Objective value: {f1(result.x)}")
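As a sanity check on what the solver should report, the example above has a closed-form answer: minimizing the squared distance to the point (1, 2) subject to x[0] + x[1] = 2 is a Euclidean projection onto a line. The sketch below computes that projection in plain Python; the helper name project_onto_hyperplane is illustrative and not part of pbalm.

```python
# Closed-form reference for the minimal example (no pbalm or JAX needed).
# The constraint set is the hyperplane {x : a.x = b} with a = (1, 1), b = 2,
# and the objective is the squared distance to c = (1, 2).

def project_onto_hyperplane(c, a, b):
    """Return the point on {x : a.x = b} that is closest to c."""
    residual = sum(ai * ci for ai, ci in zip(a, c)) - b
    scale = residual / sum(ai * ai for ai in a)
    return [ci - scale * ai for ci, ai in zip(c, a)]

x_star = project_onto_hyperplane(c=[1.0, 2.0], a=[1.0, 1.0], b=2.0)
f_star = (x_star[0] - 1) ** 2 + (x_star[1] - 2) ** 2

print(x_star)  # [0.5, 1.5]
print(f_star)  # 0.5
```

A solver run with tol=1e-6 should land close to these values rather than match them exactly.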
Understanding the Output#
The solve function returns a Result object containing:
x: The solution vector
f_hist: History of objective values
total_infeas: Total constraint infeasibility at each iteration
solve_status: Status of the solver (e.g., ‘Converged’, ‘MaxRuntimeExceeded’)
solve_runtime: Time taken to solve
# Check solver status
print(f"Status: {result.solve_status}")
print(f"Runtime: {result.solve_runtime:.4f} seconds")
print(f"Final infeasibility: {result.total_infeas[-1]:.2e}")
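The fields listed above can be combined into a simple acceptance check before the solution is used downstream. The following is a minimal sketch, assuming the status string for success is exactly "Converged" (the guide gives it only as an example value). The helper is_acceptable and the stand-in objects are hypothetical; they only mimic the documented attribute names so the snippet runs without pbalm.

```python
from types import SimpleNamespace

def is_acceptable(result, feas_tol=1e-6):
    """True if the solver reports convergence and the final infeasibility is small.

    Assumes result.solve_status == "Converged" signals success and that
    result.total_infeas is indexable, as in the guide's own print statements.
    """
    return result.solve_status == "Converged" and result.total_infeas[-1] <= feas_tol

# Stand-in objects with made-up values, used only to exercise the helper.
good = SimpleNamespace(solve_status="Converged", total_infeas=[2.0, 1e-3, 1e-8])
bad = SimpleNamespace(solve_status="MaxRuntimeExceeded", total_infeas=[2.0, 0.5])

print(is_acceptable(good))  # True
print(is_acceptable(bad))   # False
```

With a real run, the same helper would be called as is_acceptable(result).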
Adding Inequality Constraints#
To add inequality constraints of the form \(g(x) \leq 0\):
import jax.numpy as jnp
import pbalm
def f1(x):
    return x[0]**2 + x[1]**2
# Inequality constraint: x[0] + x[1] >= 1 --> -(x[0] + x[1] - 1) <= 0
def g(x):
    return -(x[0] + x[1] - 1)
problem = pbalm.Problem(f1=f1, g=[g], jittable=True)
x0 = jnp.array([0.0, 0.0])
result = pbalm.solve(problem, x0, tol=1e-6)
print(f"Solution: {result.x}")
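Constraints written as "greater than or equal" have to be flipped into the g(x) <= 0 form by hand, which is easy to get wrong. The sketch below shows one way to do the sign flip mechanically; geq_to_leq is a hypothetical helper, not a pbalm function, and whether a closure like this is accepted with jittable=True is an assumption that should be checked against the API reference.

```python
def geq_to_leq(lhs, rhs):
    """Turn the constraint lhs(x) >= rhs into g(x) = rhs - lhs(x) <= 0."""
    def g(x):
        return rhs - lhs(x)
    return g

# Same constraint as in the example above: x[0] + x[1] >= 1
g = geq_to_leq(lambda x: x[0] + x[1], 1.0)

print(g([0.0, 0.0]))  # 1.0  (positive: constraint violated)
print(g([0.5, 0.5]))  # 0.0  (zero: constraint active)
print(g([1.0, 1.0]))  # -1.0 (negative: strictly feasible)
```

For the example above, the constraint is active at the minimizer, which by symmetry is (0.5, 0.5).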
Using Proximal ALM#
For better convergence on some problems, enable the proximal ALM variant:
result = pbalm.solve(
    problem,
    x0,
    use_proximal=True,  # Enable proximal ALM
    gamma0=0.1,         # Initial proximal parameter
    tol=1e-6
)
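To give a feel for what a proximal term adds to an augmented Lagrangian method, here is a textbook-style sketch (a proximal method of multipliers) applied to the equality-constrained problem from the minimal example. It is not pbalm's implementation: the penalty rho, the fixed proximal parameter gamma, the iteration counts, and the gradient-descent inner solver are all illustrative assumptions, and how pbalm's gamma0 maps onto this gamma is not specified here.

```python
def prox_alm(x, lam=0.0, rho=10.0, gamma=1.0, outer=50, inner=200):
    """Proximal method of multipliers for
        min (x0 - 1)^2 + (x1 - 2)^2   s.t.   x0 + x1 = 2.

    Each outer iteration approximately minimizes
        f(x) + lam*h(x) + (rho/2)*h(x)^2 + (1/(2*gamma))*||x - x_k||^2
    by gradient descent, then updates the multiplier lam.
    """
    # 1 / (largest Hessian eigenvalue of the inner objective)
    step = 1.0 / (2.0 + 2.0 * rho + 1.0 / gamma)
    for _ in range(outer):
        center = list(x)  # proximal center x_k
        for _ in range(inner):
            h = x[0] + x[1] - 2.0
            mult = lam + rho * h
            grad = [
                2.0 * (x[0] - 1.0) + mult + (x[0] - center[0]) / gamma,
                2.0 * (x[1] - 2.0) + mult + (x[1] - center[1]) / gamma,
            ]
            x = [x[0] - step * grad[0], x[1] - step * grad[1]]
        lam += rho * (x[0] + x[1] - 2.0)  # multiplier update
    return x, lam

x, lam = prox_alm([0.0, 0.0])
print(x, lam)  # approximately [0.5, 1.5] and 1.0
```

The proximal term keeps each inner problem strongly convex even when the objective itself is not, at the cost of smaller steps per outer iteration.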
Adding Regularization#
You can add nonsmooth regularization terms using regularizers from alpaqa, accessible via pbalm:
import jax.numpy as jnp
import pbalm
def f1(x):
    return jnp.sum(x**2)
def h(x):
    return jnp.sum(x) - 1
# L1 regularization with lambda = 0.1
f2 = pbalm.L1Norm()
f2_lbda = 0.1 # can be a list for element-wise weights (same length as x)
problem = pbalm.Problem(
    f1=f1,
    h=[h],
    f2=f2,
    f2_lbda=f2_lbda,
    jittable=True
)
x0 = jnp.ones(10)
result = pbalm.solve(problem, x0, tol=1e-6)
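Nonsmooth terms such as the L1 norm are typically handled through their proximal operator, which for the L1 norm is element-wise soft-thresholding. The sketch below is the textbook operator in plain Python, shown only to illustrate how the weight (scalar or per-element list, as with f2_lbda above) shrinks entries toward zero. The function soft_threshold and its step argument are illustrative; the library evaluates its regularizers internally.

```python
def soft_threshold(v, lbda, step=1.0):
    """Proximal operator of step * lbda * ||.||_1, applied element-wise.

    lbda may be a single float or a list with one weight per entry of v.
    """
    weights = lbda if isinstance(lbda, (list, tuple)) else [lbda] * len(v)
    out = []
    for vi, wi in zip(v, weights):
        t = step * wi
        # Shrink toward zero by t; entries with |vi| <= t become exactly 0.
        out.append(max(vi - t, 0.0) + min(vi + t, 0.0))
    return out

v = [0.5, -0.05, -0.3]
print(soft_threshold(v, 0.1))              # roughly [0.4, 0.0, -0.2]
print(soft_threshold(v, [0.1, 0.0, 0.5]))  # roughly [0.4, -0.05, 0.0]
```

Entries whose magnitude is below the threshold are set exactly to zero, which is why L1 regularization tends to produce sparse solutions.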
Box Constraints#
Box constraints can be handled via pbalm.Box. Note that pbalm.Box requires numpy arrays:
import jax.numpy as jnp
import numpy as np
import pbalm
def f1(x):
    return jnp.sum((x - 1)**2)
# Box constraints: 0 <= x <= 2 (must use numpy arrays)
n = 5
f2 = pbalm.Box(
    lower=np.zeros(n),
    upper=2*np.ones(n)
)
problem = pbalm.Problem(f1=f1, f2=f2, jittable=True)
x0 = jnp.zeros(n)
result = pbalm.solve(problem, x0, tol=1e-6)
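A box constraint acts on the iterate as a projection, which amounts to clipping each entry into its bounds. The sketch below shows that operation in plain Python for the same bounds as above; project_box is a hypothetical helper for illustration, not a pbalm or alpaqa function.

```python
def project_box(x, lower, upper):
    """Clip each entry of x into the interval [lower_i, upper_i]."""
    return [min(max(xi, lo), hi) for xi, lo, hi in zip(x, lower, upper)]

n = 5
lower, upper = [0.0] * n, [2.0] * n

print(project_box([-1.0, 0.5, 1.0, 2.5, 3.0], lower, upper))
# [0.0, 0.5, 1.0, 2.0, 2.0]
```

In the example above, the unconstrained minimizer of the objective is the all-ones vector, which already lies inside the box, so the bounds should be inactive at the solution.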
Controlling Verbosity#
Control the amount of output with the verbosity parameter:
# Silent mode
result = pbalm.solve(problem, x0, verbosity=0)
# Normal output (default)
result = pbalm.solve(problem, x0, verbosity=1)
Next Steps#
See the User Guide for more detailed usage information
Check out the Examples for more comprehensive use cases
Refer to the API Reference for full details