Result Class
The Result class stores the output of the PBALM solver.
Class Definition
- class pbalm.result.Result(x, fp_res, kkt_res, total_infeas, f_hist, rho_hist, nu_hist, gamma_hist, prox_hist, solve_status, total_runtime, solve_runtime, grad_evals=None)
Class to store the results of the PBALM solver.
This object is returned by pbalm.solve() and contains the solution, convergence history, and solver diagnostics.
Attributes
Solution:
- pbalm.result.x: jax.numpy.ndarray
The solution vector returned by the solver: the optimal value of the decision variables if the solver converged, otherwise the final iterate.
Solver Status:
- pbalm.result.solve_status: str or None
Status indicating how the solver terminated. Possible values:
  - "Converged": Solver converged to tolerance
  - "Stopped": Solver stopped (user-defined condition)
  - "MaxRuntimeExceeded": Maximum runtime limit reached
  - "NaNOrInf": Numerical issues encountered (NaN or Inf values)
  - None: Solver still running or status not set
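The status values listed above can be grouped into coarse outcomes before deciding whether to trust the solution. The following is a minimal sketch, not part of PBALM: the function name `classify_status`, the outcome labels, and the `demo_result` stand-in object are all hypothetical and chosen for illustration only.

```python
from types import SimpleNamespace

# Status strings taken from the list above; None means "still running or not set".
FAILURE_STATUSES = {"MaxRuntimeExceeded", "NaNOrInf"}


def classify_status(solve_status):
    """Map a documented solve_status value to a coarse outcome label."""
    if solve_status == "Converged":
        return "success"
    if solve_status == "Stopped":
        return "stopped"
    if solve_status in FAILURE_STATUSES:
        return "failure"
    # Covers None and any value not listed in the documentation.
    return "unknown"


# Hypothetical stand-in for a Result object, used only so the sketch runs
# without calling the solver.
demo_result = SimpleNamespace(solve_status="MaxRuntimeExceeded")
print(classify_status(demo_result.solve_status))
```

Treating unlisted values as "unknown" rather than raising keeps the check usable if further status strings exist that are not documented here.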
Convergence History:
- pbalm.result.f_hist: list of float
History of objective function values at each outer iteration.
- pbalm.result.fp_res: list of float
History of fixed-point residuals, which measure violation of the optimality condition for the composite problem.
- pbalm.result.kkt_res: list of float
History of KKT residuals, measuring overall optimality.
- pbalm.result.total_infeas: list of float
History of total constraint infeasibility at each iteration, covering both equality and inequality constraint violations.
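Since each of these attributes is a list, the most recent value of every convergence measure is its last entry. The sketch below collects them into one dictionary. The helper name `final_metrics` and the `demo_result` object with made-up numbers are assumptions for illustration; they are not part of PBALM and the numbers are not solver output.

```python
from types import SimpleNamespace


def final_metrics(result):
    """Return the last recorded value of each convergence history."""
    names = ("f_hist", "fp_res", "kkt_res", "total_infeas")
    summary = {}
    for name in names:
        history = getattr(result, name)
        # An empty history has no last entry; report None instead of raising.
        summary[name] = history[-1] if len(history) > 0 else None
    return summary


# Hypothetical stand-in with made-up numbers, not solver output.
demo_result = SimpleNamespace(
    f_hist=[4.0, 2.5, 2.1],
    fp_res=[1e-1, 1e-3, 1e-6],
    kkt_res=[5e-1, 2e-3, 4e-6],
    total_infeas=[1.0, 1e-2, 1e-7],
)
for name, value in final_metrics(demo_result).items():
    print(f"{name}: {value}")
```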
Parameter History:
- pbalm.result.rho_hist: list of float or None
History of penalty parameters \(\rho\) for equality constraints.
- pbalm.result.nu_hist: list of float or None
History of penalty parameters \(\nu\) for inequality constraints.
- pbalm.result.gamma_hist: list of float
History of proximal parameter \(\gamma\) values.
- pbalm.result.prox_hist: list of float
History of proximal term values.
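The type "list of float or None" for rho_hist and nu_hist can be read two ways: the attribute itself may be None, or individual entries may be None (for example when the problem has no constraints of that kind). The sketch below guards against both readings. The helper name `last_recorded` is hypothetical, and which reading actually applies is an assumption that should be checked against the solver's behavior.

```python
def last_recorded(history):
    """Return the last non-None entry of a history, or None if there is none."""
    # Case 1: the whole history is missing.
    if history is None:
        return None
    # Case 2: the history exists but may contain None placeholders.
    recorded = [value for value in history if value is not None]
    return recorded[-1] if recorded else None


# Made-up histories for illustration, not solver output.
print(last_recorded([1.0, 10.0, 100.0]))
print(last_recorded([None, None]))
print(last_recorded(None))
```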
Timing:
- pbalm.result.total_runtime: float
Total runtime of the solver in seconds, including Phase I (if applicable).
- pbalm.result.solve_runtime: float
Runtime of the main solving phase in seconds (excluding Phase I).
Evaluation Counts:
- pbalm.result.grad_evals: list of int or None
Number of gradient evaluations at each iteration (if tracked).
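Because grad_evals defaults to None when evaluations are not tracked, code that uses it should check for None first. The sketch below builds a running total, assuming each entry is a per-iteration count rather than an already cumulative one; that interpretation is an assumption, and the helper name `cumulative_grad_evals` is hypothetical.

```python
from itertools import accumulate


def cumulative_grad_evals(grad_evals):
    """Return running totals of gradient evaluations, or None if untracked."""
    if grad_evals is None:
        return None
    # Assumes each entry counts the evaluations of a single iteration.
    return list(accumulate(grad_evals))


# Made-up counts for illustration, not solver output.
print(cumulative_grad_evals([3, 5, 2]))
print(cumulative_grad_evals(None))
```

If the entries turn out to be cumulative already, the last entry is the total and no accumulation is needed.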
Example Usage
Accessing basic solution information:
import pbalm
# problem and x0 are assumed to be defined beforehand
result = pbalm.solve(problem, x0)
# Get the solution
x_opt = result.x
print(f"Optimal x: {x_opt}")
# Check solver status
if result.solve_status == "Converged":
    print("Solver converged successfully!")
else:
    print(f"Solver status: {result.solve_status}")
Analyzing convergence:
# Objective value convergence
print(f"Initial objective: {result.f_hist[0]:.6f}")
print(f"Final objective: {result.f_hist[-1]:.6f}")
# Constraint satisfaction
print(f"Initial infeasibility: {result.total_infeas[0]:.2e}")
print(f"Final infeasibility: {result.total_infeas[-1]:.2e}")
# Number of iterations
print(f"Iterations: {len(result.f_hist)}")
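Beyond the first and last values, it can be useful to know when a history first drops below a tolerance, for instance when the iterates first became feasible. The following is a minimal sketch: the helper name `first_index_below`, the tolerance, and the sample history are hypothetical and not taken from PBALM.

```python
def first_index_below(history, tol):
    """Return the first index whose value is at most tol, or None if never."""
    for index, value in enumerate(history):
        if value <= tol:
            return index
    return None


# Made-up infeasibility history for illustration, not solver output.
demo_infeas = [1.0, 3e-2, 8e-5, 2e-7, 1e-9]
print(first_index_below(demo_infeas, 1e-6))
```

With a real result, the same call could be applied to result.total_infeas or result.fp_res.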
Plotting convergence:
import matplotlib.pyplot as plt
fig, axes = plt.subplots(2, 2, figsize=(12, 8))
# Objective value (linear axis: objective values may be zero or negative,
# which a logarithmic axis cannot display)
axes[0, 0].plot(result.f_hist)
axes[0, 0].set_xlabel('Iteration')
axes[0, 0].set_ylabel('Objective')
axes[0, 0].set_title('Objective Convergence')
axes[0, 0].grid(True)
# Infeasibility
axes[0, 1].semilogy(result.total_infeas)
axes[0, 1].set_xlabel('Iteration')
axes[0, 1].set_ylabel('Infeasibility')
axes[0, 1].set_title('Constraint Satisfaction')
axes[0, 1].grid(True)
# Fixed-point residual
axes[1, 0].semilogy(result.fp_res)
axes[1, 0].set_xlabel('Iteration')
axes[1, 0].set_ylabel('Residual')
axes[1, 0].set_title('Fixed-Point Residual')
axes[1, 0].grid(True)
# Penalty parameters
if result.rho_hist and result.rho_hist[0] is not None:
    axes[1, 1].semilogy(result.rho_hist, label='rho (equality)')
if result.nu_hist and result.nu_hist[0] is not None:
    axes[1, 1].semilogy(result.nu_hist, label='nu (inequality)')
axes[1, 1].set_xlabel('Iteration')
axes[1, 1].set_ylabel('Penalty')
axes[1, 1].set_title('Penalty Parameters')
axes[1, 1].legend()
axes[1, 1].grid(True)
plt.tight_layout()
plt.savefig('convergence.png', dpi=150)
plt.show()
Timing analysis:
print(f"Total runtime: {result.total_runtime:.4f} seconds")
print(f"Solve runtime: {result.solve_runtime:.4f} seconds")
if result.total_runtime > result.solve_runtime:
    phase_I_time = result.total_runtime - result.solve_runtime
    print(f"Phase I time: {phase_I_time:.4f} seconds")