Constrained Optimization Examples

This page demonstrates various constrained optimization problems that can be solved with pbalm.

Example 1: Quadratic Program with Linear Constraints

Consider minimizing a quadratic objective subject to linear equality and inequality constraints:

\[\begin{split}\min_{x} \quad & \frac{1}{2} x^T Q x + c^T x \\ \text{s.t.} \quad & Ax = b \\ & Gx \leq h\end{split}\]

Implementation

import jax
import jax.numpy as jnp
import numpy as np
import pbalm

# Configure JAX: CPU backend and 64-bit floats, set before any arrays are created.
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)

# --- Problem data -------------------------------------------------------------
n = 10
rng = np.random.default_rng(42)

# M^T M is positive semidefinite; adding 0.1 * I makes Q strictly positive
# definite, so the quadratic objective is strictly convex.
M = rng.standard_normal((n, n))
Q = jnp.array(M.T @ M + 0.1 * np.eye(n))
c = jnp.array(rng.standard_normal(n))

# Equality constraint A x = b_eq with A = 1^T, i.e. sum(x) = 1.
A = jnp.ones((1, n))
b_eq = jnp.array([1.0])

# Inequality constraint G x <= h_ineq with G = -I and h_ineq = 0, i.e. x >= 0.
G = -jnp.eye(n)
h_ineq = jnp.zeros(n)


# --- Problem functions ----------------------------------------------------------
def f1(x):
    """Return the quadratic objective 0.5 * x^T Q x + c^T x."""
    return 0.5 * x @ Q @ x + c @ x


def h(x):
    """Return the equality residual A x - b_eq (zero when feasible)."""
    return A @ x - b_eq


def g(x):
    """Return the inequality residual G x - h_ineq (non-positive when feasible)."""
    return G @ x - h_ineq


# --- Solve ------------------------------------------------------------------------
problem = pbalm.Problem(f1=f1, h=[h], g=[g], jittable=True)

# The uniform vector 1/n lies on the probability simplex, so it satisfies both
# sum(x) = 1 and x >= 0 from the start.
x0 = jnp.ones(n) / n

result = pbalm.solve(
    problem,
    x0,
    use_proximal=True,
    tol=1e-6,
)

x_opt = result.x
print(f"Optimal x: {x_opt}")
print(f"Equality constraint: {h(x_opt)}")
print(f"Inequality constraint: {g(x_opt)}")

Example 2: Nonlinear Least Squares with Constraints

Fit the exponential-decay model \(f(x; \theta) = \theta_1 e^{-\theta_2 x}\) to noisy data, subject to parameter constraints:

\[\begin{split}\min_{\theta} \quad & \sum_{i=1}^{m} (y_i - f(x_i; \theta))^2 \\ \text{s.t.} \quad & \theta_1 + \theta_2 \leq 5 \\ & \theta \geq 0\end{split}\]

Implementation

import jax
import jax.numpy as jnp
import numpy as np
import pbalm

# Configure JAX: CPU backend and 64-bit floats, set before any arrays are created.
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)


def model(x, theta):
    """Evaluate the exponential-decay model theta[0] * exp(-theta[1] * x)."""
    amplitude, rate = theta[0], theta[1]
    return amplitude * jnp.exp(-rate * x)


# --- Synthetic data: y = theta1 * exp(-theta2 * x) + Gaussian noise -----------
rng = np.random.default_rng(123)
n_samples = 100
x_data = jnp.linspace(0, 1, n_samples)

theta_true = jnp.array([0.5, 2.0])
noise = 0.05 * rng.standard_normal(n_samples)
y_data = model(x_data, theta_true) + noise


def f1(theta):
    """Return the sum of squared residuals between y_data and the model."""
    return jnp.sum((y_data - model(x_data, theta)) ** 2)


def g1(theta):
    """Return theta1 + theta2 - 5, which is non-positive when theta1 + theta2 <= 5."""
    return theta[0] + theta[1] - 5.0


def g2(theta):
    """Return -theta, which is non-positive element-wise when theta >= 0."""
    return -theta


# --- Solve ------------------------------------------------------------------------
problem = pbalm.Problem(f1=f1, g=[g1, g2], jittable=True)

# Feasible starting guess: 0.3 + 1.0 <= 5 and both entries are non-negative.
theta0 = jnp.array([0.3, 1.0])

result = pbalm.solve(
    problem,
    theta0,
    use_proximal=True,
    tol=1e-6,
    max_iter=200,
)

theta_hat = result.x
print(f"True parameters: {theta_true}")
print(f"Estimated parameters: {theta_hat}")
print(f"Final objective: {f1(theta_hat):.6f}")

Example 3: Optimization on a Sphere

Minimize a linear function subject to a spherical constraint \(\|x\|^2 = 1\):

\[\begin{split}\min_{x} \quad & c^T x \\ \text{s.t.} \quad & \|x\|^2 = 1\end{split}\]

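This finds the point on the unit sphere that minimizes the linear objective. By the Cauchy–Schwarz inequality, \(c^T x \geq -\|c\|\) whenever \(\|x\| = 1\), with equality exactly at \(x^\star = -c/\|c\|\), which gives an analytical solution to compare against.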

Implementation

import jax
import jax.numpy as jnp
import numpy as np
import pbalm

# Configure JAX: CPU backend and 64-bit floats. float32 resolves only about
# 1e-7 near 1.0, so the tol=1e-8 requested below needs double precision.
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)

# Problem data
n = 5
rng = np.random.default_rng(456)
c = jnp.array(rng.standard_normal(n))


def f1(x):
    """Return the linear objective c^T x."""
    return c @ x


def h(x):
    """Return the sphere-constraint residual ||x||^2 - 1 (zero on the unit sphere)."""
    return jnp.sum(x**2) - 1.0


# Create problem
problem = pbalm.Problem(f1=f1, h=[h], jittable=True)

# Initial point: the normalized all-ones vector. Its norm is already exactly 1,
# so x0 lies on the sphere and no projection is needed.
x0 = jnp.ones(n) / jnp.sqrt(n)

# Solve. NOTE(review): start_feas=False is kept from the original example; see
# the pbalm.solve documentation for the exact meaning of this flag.
result = pbalm.solve(
    problem,
    x0,
    use_proximal=True,
    tol=1e-8,
    start_feas=False,
)

# Analytical solution: by Cauchy-Schwarz, c^T x >= -||c|| for ||x|| = 1, with
# equality exactly at x* = -c / ||c||.
x_analytical = -c / jnp.linalg.norm(c)

print(f"PBALM solution: {result.x}")
print(f"Analytical solution: {x_analytical}")
print(f"Solution norm: {jnp.linalg.norm(result.x)}")
print(f"Error: {jnp.linalg.norm(result.x - x_analytical):.2e}")

Example 4: Multiple Equality Constraints

A problem with multiple nonlinear equality constraints:

\[\begin{split}\min_{x} \quad & x_1^2 + x_2^2 + x_3^2 \\ \text{s.t.} \quad & x_1 x_2 = 1 \\ & x_2 x_3 = 2\end{split}\]

Implementation

import jax
import jax.numpy as jnp
import pbalm

# Configure JAX like the other examples. Double precision matters here:
# float32 resolves only about 1e-7 near 1.0, so the tol=1e-9 requested below
# would be out of reach without 64-bit floats.
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)


def f1(x):
    """Return the objective x1^2 + x2^2 + x3^2."""
    return jnp.sum(x**2)


def h1(x):
    """Return the residual of the constraint x1 * x2 = 1."""
    return x[0] * x[1] - 1.0


def h2(x):
    """Return the residual of the constraint x2 * x3 = 2."""
    return x[1] * x[2] - 2.0


# Create problem with multiple equality constraints
problem = pbalm.Problem(
    f1=f1,
    h=[h1, h2],
    jittable=True
)

# Feasible start: 1 * 1 = 1 and 1 * 2 = 2.
x0 = jnp.array([1.0, 1.0, 2.0])

result = pbalm.solve(
    problem,
    x0,
    use_proximal=True,
    tol=1e-9
)

# For reference: eliminating x1 = 1/x2 and x3 = 2/x2 gives f = x2^2 + 5/x2^2,
# which is minimized at x2^2 = sqrt(5), so the optimal objective is
# 2*sqrt(5) ~= 4.472136.
print(f"Solution: {result.x}")
print(f"h1(x) = x1*x2 - 1 = {h1(result.x):.2e}")
print(f"h2(x) = x2*x3 - 2 = {h2(result.x):.2e}")
print(f"Objective: {f1(result.x):.6f}")

Example 5: Mixed Constraints with Regularization

Combining equality constraints, inequality constraints, and L1 regularization:

\[\begin{split}\min_{x} \quad & \frac{1}{2}\|Ax - b\|^2 + \lambda \|x\|_1 \\ \text{s.t.} \quad & \mathbf{1}^T x = 1 \\ & x \geq 0\end{split}\]

Implementation

import jax
import jax.numpy as jnp
import numpy as np
import pbalm

# Configure JAX: CPU backend and 64-bit floats, set before any arrays are created.
jax.config.update('jax_platform_name', 'cpu')
jax.config.update("jax_enable_x64", True)

# --- Problem data: a random 50 x 100 least-squares system ---------------------
m, n = 50, 100
rng = np.random.default_rng(789)

A = jnp.array(rng.standard_normal((m, n)))
b = jnp.array(rng.standard_normal(m))


def f1(x):
    """Return the smooth part of the objective, 0.5 * ||A x - b||^2."""
    return 0.5 * jnp.sum((A @ x - b) ** 2)


def h(x):
    """Return the equality residual sum(x) - 1 (zero when feasible)."""
    return jnp.sum(x) - 1.0


def g(x):
    """Return the inequality residual -x (non-positive exactly when x >= 0)."""
    return -x


# Weight lambda of the L1 term, passed to the problem as f2.
# Note: on the feasible set (x >= 0 and sum(x) = 1) we have ||x||_1 = sum(x) = 1,
# so this term is constant there and does not change the minimizer.
l1_weight = 0.1
f2 = pbalm.L1Norm(l1_weight)

problem = pbalm.Problem(
    f1=f1,
    h=[h],
    g=[g],
    f2=f2,
    jittable=True,
)

# Uniform starting point on the simplex.
x0 = jnp.ones(n) / n

result = pbalm.solve(
    problem,
    x0,
    use_proximal=True,
    tol=1e-5,
    max_iter=500,
)

x_opt = result.x
print(f"Sum of x: {jnp.sum(x_opt):.6f}")
print(f"Min of x: {jnp.min(x_opt):.6f}")
print(f"Number of zeros: {jnp.sum(jnp.abs(x_opt) < 1e-4)}")
# Report the full objective f1(x) + lambda * ||x||_1.
print(f"Objective: {f1(x_opt) + l1_weight * jnp.sum(jnp.abs(x_opt)):.6f}")