Basis Pursuit
This example demonstrates how to solve a basis pursuit problem, which seeks the sparsest solution to an underdetermined linear system.
Problem Formulation
Given a matrix \(B \in \mathbb{R}^{m \times n}\) with \(m < n\) and a vector \(b \in \mathbb{R}^m\), the basis pursuit problem seeks a sparse solution \(z\) such that \(Bz = b\):

\[
\min_{z \in \mathbb{R}^n} \|z\|_1 \quad \text{subject to} \quad Bz = b.
\]

By introducing a change of variables \(z = u_1 \circ u_1 - u_2 \circ u_2\) where \(u_1, u_2 \geq 0\), we can reformulate the problem as follows [BP1]:

\[
\min_{u_1, u_2 \in \mathbb{R}^n} \|u_1\|_2^2 + \|u_2\|_2^2 \quad \text{subject to} \quad B(u_1 \circ u_1 - u_2 \circ u_2) = b.
\]

(The sign constraints on \(u_1, u_2\) can be dropped because only their squares appear.) Let \(x = [u_1; u_2] \in \mathbb{R}^{2n}\) and \(B_{\text{big}} = [B, -B]\). The problem becomes:

\[
\min_{x \in \mathbb{R}^{2n}} \|x\|_2^2 \quad \text{subject to} \quad B_{\text{big}}\,(x \circ x) = b,
\]

where \(\circ\) denotes element-wise multiplication. The reformulation is exact: \(\|z\|_1 \leq \|u_1\|_2^2 + \|u_2\|_2^2\) for any such split, with equality when \(u_1\) and \(u_2\) have disjoint supports, so both problems share the same optimal value.
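
To see why this works, note that \(B_{\text{big}}(x \circ x) = B(u_1 \circ u_1) - B(u_2 \circ u_2) = Bz\). The snippet below is a purely illustrative check of this identity and of the \(\ell_1\) equality on a toy example (the dimensions and vectors are arbitrary):

import jax.numpy as jnp
import numpy as np

rng = np.random.default_rng(0)
B_toy = jnp.array(rng.standard_normal((2, 4)))
z_toy = jnp.array([1.5, 0.0, -2.0, 0.0])  # toy sparse signal

# Split z into positive/negative parts and take elementwise square roots
u1 = jnp.sqrt(jnp.maximum(z_toy, 0.0))
u2 = jnp.sqrt(jnp.maximum(-z_toy, 0.0))
x_toy = jnp.concatenate([u1, u2])

B_big_toy = jnp.concatenate([B_toy, -B_toy], axis=1)
assert jnp.allclose(B_big_toy @ (x_toy * x_toy), B_toy @ z_toy)   # B_big (x ∘ x) = B z
assert jnp.allclose(jnp.sum(x_toy**2), jnp.sum(jnp.abs(z_toy)))  # ||x||^2 = ||z||_1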
Implementation
import jax
import jax.numpy as jnp
import pbalm
import numpy as np

# Configure JAX: run on CPU with double precision
jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)

# Generate basis pursuit data
def get_basis_pursuit_data(m, n, k, seed=1234):
    """
    Generate random data for a basis pursuit problem.

    Parameters:
        m: Number of measurements
        n: Signal dimension
        k: Sparsity level (number of nonzero entries)
        seed: Random seed

    Returns:
        B: Measurement matrix (m x n)
        b: Observation vector (m,)
        z_star: Ground-truth sparse signal (n,)
        x_star: Corresponding solution of the squared formulation (2n,)
        B_big: Extended matrix [B, -B] (m x 2n)
    """
    rng = np.random.default_rng(seed)
    B = jnp.array(rng.standard_normal((m, n)))

    # Generate a k-sparse ground truth with random support and amplitudes
    z_star = jnp.zeros(n)
    support = rng.choice(n, size=k, replace=False)
    amplitudes = jnp.array(rng.standard_normal(k))
    z_star = z_star.at[support].set(amplitudes)

    # Compute noiseless observations
    b = B @ z_star

    # Extended matrix for the squared formulation
    B_big = jnp.concatenate([B, -B], axis=1)

    # Compute the optimal x from z_star: u1 and u2 carry the elementwise
    # square roots of the positive and negative parts of z_star
    z_star_pos = jnp.maximum(z_star, 0.0)
    z_star_neg = jnp.maximum(-z_star, 0.0)
    u1_star = jnp.sqrt(z_star_pos)
    u2_star = jnp.sqrt(z_star_neg)
    x_star = jnp.concatenate([u1_star, u2_star], axis=0)

    return B, b, z_star, x_star, B_big

# Problem dimensions
m, n, k = 200, 512, 10  # 200 measurements, 512 variables, 10 nonzeros

# Generate data
B, b, z_star, x_star, B_big = get_basis_pursuit_data(m, n, k)

# Define the objective function: f1(x) = ||x||_2^2
def f1(x):
    return jnp.sum(x**2)

# Define the equality constraint: h(x) = B_big (x ∘ x) - b = 0
def h(x):
    return B_big @ (x**2) - b
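
# Sanity check (illustrative): x_star satisfies the constraint up to
# floating-point error, since B_big (x_star ∘ x_star) = B z_star = b
assert jnp.allclose(h(x_star), 0.0, atol=1e-8)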

# Check the optimal value
f_star = f1(x_star)
print(f"Optimal objective value: {f_star}")

# Random initial point (typically infeasible)
rng = np.random.default_rng(1234)
x0 = jnp.array(rng.standard_normal(2 * n))
f0 = f1(x0)
print(f"Initial objective value: {f0}")

# Define the problem
problem = pbalm.Problem(
    f1=f1,
    h=[h],
    jittable=True,  # Enable JIT compilation
)

# Solve using PBALM
tol = 1e-9  # small tolerance for high accuracy
result = pbalm.solve(
    problem,
    x0,
    use_proximal=True,  # Use proximal ALM
    tol=tol,
    max_iter=300,
    alpha=10,
    delta=1.0,
)

# Results
x_pbalm = result.x
print(f"Solver status: {result.solve_status}")
print(f"Final objective: {f1(x_pbalm)}")
print(f"Relative error: {(f1(x_pbalm) - f_star) / (f0 - f_star):.6e}")
Running this example produces output similar to:
Optimal objective value: 9.393346945133704
Initial objective value: 1045.3144493189802
Initial point is not feasible. Finding a feasible point...
Phase I optimization successful.
iter | f | p. term | total infeas | rho | nu | gamma
------------------------------------------------------------------------------------------
0 | 8.3323e+02 | nan | 2.1452e-09 | 1.0000e-03 | 0.0000e+00 | 1.0000e-01
19 | 9.3933e+00 | 5.1580e-21 | 8.2671e-10 | 6.1311e+09 | 0.0000e+00 | 6.1311e+11
------------------------------------------------------------------------------------------
...
Solver status: Converged
Final objective: 9.393346955103313
Relative error: 9.623908e-12
Key Observations
Phase I: Since the random initial point is typically infeasible, PBALM first solves a Phase I problem to find a feasible starting point. Users can skip this step by setting start_feas=False in pbalm.solve(), as sketched after this list.

Nonconvex formulation: The squared reformulation introduces nonconvexity, which PBALM handles effectively.

Proximal ALM: Setting use_proximal=True improves convergence for this nonconvex problem.

Sparse recovery: The solution accurately recovers the sparse ground-truth signal.
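
As a minimal sketch of skipping Phase I (illustrative only: it reuses the ground-truth x_star as the starting point, which is trivially feasible since h(x_star) = 0):

# Warm start from a feasible point, skipping the Phase I search
result_warm = pbalm.solve(
    problem,
    x_star,            # feasible starting point: h(x_star) = 0
    start_feas=False,  # skip the Phase I feasibility search
    use_proximal=True,
    tol=tol,
    max_iter=300,
    alpha=10,
    delta=1.0,
)
print(f"Warm-start status: {result_warm.solve_status}")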
After solving, recover the original sparse signal:
# Split the solution into u1 and u2
u1 = x_pbalm[:n]
u2 = x_pbalm[n:]

# Recover z = u1^2 - u2^2 (elementwise)
z_recovered = u1**2 - u2**2

# Compute the relative recovery error
z_error = jnp.linalg.norm(z_recovered - z_star) / jnp.linalg.norm(z_star)
print(f"Signal recovery error: {z_error:.6e}")

# Check sparsity
threshold = 1e-5
nnz = jnp.sum(jnp.abs(z_recovered) > threshold)
print(f"Number of nonzeros in recovered signal: {nnz}")
References
[BP1] Sahin, M. F., Alacaoglu, A., Latorre, F., & Cevher, V. (2019). An inexact augmented Lagrangian framework for nonconvex optimization with nonlinear constraints. Advances in Neural Information Processing Systems, 32.