Result Class

The Result class stores the output of the PBALM solver.

Class Definition

class pbalm.result.Result(x, fp_res, kkt_res, total_infeas, f_hist, rho_hist, nu_hist, gamma_hist, prox_hist, solve_status, total_runtime, solve_runtime, grad_evals=None)

Class to store the results of the PBALM solver.

This object is returned by pbalm.solve() and contains the solution, convergence history, and solver diagnostics.

Attributes

Solution:

pbalm.result.Result.x: jax.numpy.ndarray

The solution vector found by the solver, i.e. the final iterate of the decision variables. It satisfies the solver's optimality tolerance only when solve_status is "Converged".

Solver Status:

pbalm.result.Result.solve_status: str or None

Status indicating how the solver terminated. Possible values:

  • "Converged": Solver converged to tolerance

  • "Stopped": Solver was stopped by a user-defined condition

  • "MaxRuntimeExceeded": Maximum runtime limit reached

  • "NaNOrInf": Numerical issues encountered (NaN or Inf values)

  • None: Solver still running or status not set
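One possible way to act on solve_status is sketched below. The policy (raise on "NaNOrInf", warn on every other non-converged status) is an illustrative choice, and `check_status` is a hypothetical helper rather than a pbalm function.

```python
import warnings
from typing import Any


def check_status(result: Any) -> bool:
    """Interpret Result.solve_status (hypothetical helper, illustrative policy).

    Returns True only for "Converged". Raises on "NaNOrInf", since result.x
    is then unlikely to be meaningful, and warns for any other status.
    """
    status = result.solve_status
    if status == "Converged":
        return True
    if status == "NaNOrInf":
        raise FloatingPointError("PBALM hit NaN/Inf values; result.x is unreliable")
    # "Stopped", "MaxRuntimeExceeded", or None: x is the last iterate,
    # not a point that met the convergence tolerance.
    warnings.warn(f"PBALM did not converge (status={status!r}); using the last iterate")
    return False
```

Raising on numerical failure while only warning on early termination reflects the fact that a stopped or time-limited run still returns a usable iterate.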

Convergence History:

pbalm.result.Result.f_hist: list of float

History of objective function values at each outer iteration.

pbalm.result.Result.fp_res: list of float

History of fixed-point residuals, which measure how far the current iterate is from satisfying the optimality condition of the composite problem.

pbalm.result.Result.kkt_res: list of float

History of KKT residuals, which measure the violation of the Karush-Kuhn-Tucker optimality conditions of the constrained problem.

pbalm.result.Result.total_infeas: list of float

History of total constraint infeasibility at each iteration. Includes both equality and inequality constraint violations.
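The final entries of these histories can be compared against a tolerance of your choosing. In this sketch `meets_tolerance` and its `tol` argument are hypothetical; PBALM's own stopping test may combine the residuals differently, so this check and a "Converged" status need not agree.

```python
from typing import Any


def meets_tolerance(result: Any, tol: float = 1e-6) -> bool:
    """Check the last recorded residuals against a caller-chosen tolerance.

    tol is a hypothetical threshold; it is not read from PBALM's settings.
    """
    finals = [result.fp_res[-1], result.kkt_res[-1], result.total_infeas[-1]]
    # float() also handles 0-d array entries, should the lists contain them.
    return all(float(v) <= tol for v in finals)
```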

Parameter History:

pbalm.result.Result.rho_hist: list of float or None

History of penalty parameters \(\rho\) for equality constraints.

pbalm.result.Result.nu_hist: list of float or None

History of penalty parameters \(\nu\) for inequality constraints.

pbalm.result.Result.gamma_hist: list of float

History of proximal parameter \(\gamma\) values.

pbalm.result.Result.prox_hist: list of float

History of proximal term values.
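Because rho_hist and nu_hist may be None, code that inspects them should guard against that case. The sketch below counts how often a penalty parameter grew between consecutive recorded iterations; `count_increases` is a hypothetical helper, and treating a leading None entry as "no history" mirrors the check used in the plotting example later on this page.

```python
from typing import Optional, Sequence


def count_increases(hist: Optional[Sequence[Optional[float]]]) -> int:
    """Count recorded iterations at which a penalty parameter increased.

    Hypothetical helper. Returns 0 when the history is None, empty, or
    starts with None (assumed to mean the parameter is not in use).
    """
    if not hist or hist[0] is None:
        return 0
    return sum(1 for prev, cur in zip(hist, hist[1:]) if cur > prev)
```

For example, `count_increases(result.rho_hist)` and `count_increases(result.nu_hist)` give a rough picture of how aggressively each penalty was adapted.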

Timing:

pbalm.result.Result.total_runtime: float

Total runtime of the solver in seconds, including Phase I (if applicable).

pbalm.result.Result.solve_runtime: float

Runtime of the main solving phase in seconds (excluding Phase I).
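The two runtimes can be combined into a rough breakdown. This sketch assumes f_hist holds one entry per outer iteration of the main phase, so the per-iteration figure is only an average; `timing_breakdown` is a hypothetical helper.

```python
from typing import Any, Dict


def timing_breakdown(result: Any) -> Dict[str, float]:
    """Split recorded runtimes into Phase I and main-phase parts (hypothetical helper)."""
    # Clamp at zero in case of small timer inconsistencies.
    phase_one = max(result.total_runtime - result.solve_runtime, 0.0)
    # Guard against division by zero for an empty history.
    n_iter = max(len(result.f_hist), 1)
    return {
        "phase_one_s": phase_one,
        "solve_s": result.solve_runtime,
        "avg_per_iteration_s": result.solve_runtime / n_iter,
    }
```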

Evaluation Counts:

pbalm.result.Result.grad_evals: list of int or None

Number of gradient evaluations at each iteration, or None if gradient evaluations were not tracked.
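If the entries of grad_evals are per-iteration counts (an assumption; this page does not say whether they are cumulative), a running total can be obtained with `itertools.accumulate`. `cumulative_grad_evals` is a hypothetical helper.

```python
from itertools import accumulate
from typing import List, Optional, Sequence


def cumulative_grad_evals(grad_evals: Optional[Sequence[int]]) -> Optional[List[int]]:
    """Running total of gradient evaluations, or None if they were not tracked.

    Assumes each entry counts only that iteration's evaluations; if the
    stored values are already cumulative, use the list as-is instead.
    """
    if grad_evals is None:
        return None
    return list(accumulate(grad_evals))
```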

Example Usage

Accessing basic solution information:

import pbalm

# `problem` and `x0` are assumed to be defined beforehand (setup not shown)
result = pbalm.solve(problem, x0)

# Get the solution
x_opt = result.x
print(f"Optimal x: {x_opt}")

# Check solver status
if result.solve_status == "Converged":
    print("Solver converged successfully!")
else:
    print(f"Solver status: {result.solve_status}")

Analyzing convergence:

# Objective value convergence
print(f"Initial objective: {result.f_hist[0]:.6f}")
print(f"Final objective: {result.f_hist[-1]:.6f}")

# Constraint satisfaction
print(f"Initial infeasibility: {result.total_infeas[0]:.2e}")
print(f"Final infeasibility: {result.total_infeas[-1]:.2e}")

# Number of iterations
print(f"Iterations: {len(result.f_hist)}")
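For analysis outside Python, the histories can be written to a CSV file. This is a sketch using only the standard library; `write_history_csv` and `HISTORY_FIELDS` are hypothetical names, and histories that are None or shorter than the others are written as empty cells.

```python
import csv
from itertools import zip_longest
from typing import Any, TextIO

# Per-iteration history attributes documented for pbalm.result.Result.
HISTORY_FIELDS = ["f_hist", "fp_res", "kkt_res", "total_infeas",
                  "rho_hist", "nu_hist", "gamma_hist", "prox_hist"]


def write_history_csv(result: Any, stream: TextIO) -> None:
    """Write each history list as one CSV column, one row per iteration."""
    # A history that is None becomes an empty column.
    columns = [getattr(result, name, None) or [] for name in HISTORY_FIELDS]
    writer = csv.writer(stream)
    writer.writerow(["iteration"] + HISTORY_FIELDS)
    # Histories need not share a length; pad shorter ones with None.
    for i, row in enumerate(zip_longest(*columns, fillvalue=None)):
        writer.writerow([i] + ["" if v is None else float(v) for v in row])
```

To write to disk, open the file with `newline=""`, as the `csv` module documentation recommends, and pass the file object as `stream`.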

Plotting convergence:

import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Objective value (linear scale: the objective may be zero or negative,
# which a log scale cannot display)
axes[0, 0].plot(result.f_hist)
axes[0, 0].set_xlabel('Iteration')
axes[0, 0].set_ylabel('Objective')
axes[0, 0].set_title('Objective Convergence')
axes[0, 0].grid(True)

# Infeasibility
axes[0, 1].semilogy(result.total_infeas)
axes[0, 1].set_xlabel('Iteration')
axes[0, 1].set_ylabel('Infeasibility')
axes[0, 1].set_title('Constraint Satisfaction')
axes[0, 1].grid(True)

# Fixed-point residual
axes[1, 0].semilogy(result.fp_res)
axes[1, 0].set_xlabel('Iteration')
axes[1, 0].set_ylabel('Residual')
axes[1, 0].set_title('Fixed-Point Residual')
axes[1, 0].grid(True)

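# Penalty parameters (either history may be None, or hold None entries,
# when the corresponding constraint type is absent)
if result.rho_hist and result.rho_hist[0] is not None:
    axes[1, 1].semilogy(result.rho_hist, label='rho (equality)')
if result.nu_hist and result.nu_hist[0] is not None:
    axes[1, 1].semilogy(result.nu_hist, label='nu (inequality)')
axes[1, 1].set_xlabel('Iteration')
axes[1, 1].set_ylabel('Penalty')
axes[1, 1].set_title('Penalty Parameters')
# Only draw a legend if at least one labeled line was plotted
if axes[1, 1].get_legend_handles_labels()[0]:
    axes[1, 1].legend()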
axes[1, 1].grid(True)

plt.tight_layout()
plt.savefig('convergence.png', dpi=150)
plt.show()

Timing analysis:

print(f"Total runtime: {result.total_runtime:.4f} seconds")
print(f"Solve runtime: {result.solve_runtime:.4f} seconds")

if result.total_runtime > result.solve_runtime:
    phase_I_time = result.total_runtime - result.solve_runtime
    print(f"Phase I time: {phase_I_time:.4f} seconds")