Source code for sparsekit.linalg

# Copyright (c) 2026 - Ayoub Ghriss & Contributors
# Licensed under CC BY-NC 4.0
# (see LICENSE or https://creativecommons.org/licenses/by-nc/4.0/)
# Non-commercial use only; contact us for commercial licensing.
"""Linear algebra utilities for sparse optimization."""

import torch


def _norm(x, **kwargs):
    """Return ``torch.linalg.norm(x, **kwargs)``.

    All keyword arguments (e.g. ``dim``, ``keepdim``, ``ord``) are forwarded
    unchanged. Keeping the call in one place means the pylint suppression
    below is only needed once in this module.
    """
    norm_fn = torch.linalg.norm
    # pylint: disable=not-callable
    return norm_fn(x, **kwargs)


[docs] def lsqr_gkl( A: torch.Tensor, b: torch.Tensor, max_iter: int = 1000, tol: float = 1e-6, x_0: torch.Tensor | None = None, device: torch.device | None = None, ): """ Solves min ||Ax - b||_2 using the LSQR algorithm, leveraging PyTorch for GPU and Golub-Kahan-Lanczos bidiagonalization. Args: A: The matrix A (torch.Tensor). b: The right-hand side vector b (torch.Tensor). max_iter: Maximum number of iterations (int). tol: Tolerance for convergence (float). x_0: Initial guess for the solution (torch.Tensor, optional). device: Device to perform computations on (torch.device, optional). Returns: x: The solution vector (torch.Tensor). info: A dictionary containing information about the convergence. """ assert max_iter > 2, ( "max_iter must be greater than 2 for LSQR to work properly." ) if device is None: device = A.device n = A.shape[1] A = A.to(device) b = b.to(device) # initial guess if x_0 is None: x = torch.zeros(n, dtype=b.dtype, device=device) else: x = x_0.to(device) # initial residual norm_b = _norm(b) r = b - A @ x beta = _norm(r) if beta == 0: return x, { "converged": True, "iterations": 0, "residual": 0.0, "status": "rhs is zero", } u = r / beta # Initial Lanczos vector v1 v = A.t() @ u alpha = _norm(v) if alpha == 0: return x, { "converged": True, "iterations": 0, "rel_res": 0.0, "status": "A.T @ u is zero", } v = v / alpha # working variables phi_bar = beta rho_bar = alpha w = v # history rel_res = (alpha * phi_bar).abs().item() new_rel_res = 0.0 new_tol = 0.0 # LSQR iteration for k in range(max_iter): # bidiagonalization step (Golub-Kahan) u_prev = u u = A @ v - alpha * u_prev beta = _norm(u) if beta == 0: break u = u / beta v_prev = v v = A.t() @ u - beta * v_prev alpha = _norm(v) if alpha == 0: break v = v / alpha # apply Givens rotations rho = torch.sqrt(rho_bar**2 + beta**2) c = rho_bar / rho s = beta / rho theta = s * alpha rho_bar = c * alpha phi = c * phi_bar phi_bar = -s * phi_bar # update solution and auxiliary w x = x + (phi / rho) * w w = v 
- (theta / rho) * w # Check for convergence new_rel_res = (torch.abs(phi_bar) / norm_b).item() new_tol = abs(new_rel_res - rel_res) if new_tol < tol: return x, { "converged": True, "iterations": k + 1, "tol": new_tol, "rel_res": new_rel_res, "status": "tolerance met", } rel_res = new_rel_res return x, { "converged": False, "iterations": max_iter, "tol": new_tol, "rel_res": rel_res, "status": "max iters reached", }
def hard_threshold(
    vec: torch.Tensor, alpha: torch.Tensor, k: int
) -> torch.Tensor:
    """Keep the ``k`` entries of ``vec`` with the largest scores in ``alpha``.

    Selection is global over all elements, so ``vec`` may have any shape;
    ``alpha`` must hold one score per element of ``vec``.

    NOTE(review): scores are ranked exactly as given (``torch.topk`` on
    ``alpha``, sign included). Pass ``alpha.abs()`` or ``vec.abs()`` when
    magnitude-based selection is intended -- confirm against callers.

    Args:
        vec: Values tensor.
        alpha: Scores tensor with the same number of elements as ``vec``.
        k: Number of elements to keep.

    Returns:
        A tensor shaped like ``vec`` holding the ``k`` top-scored entries and
        zeros elsewhere. When ``k >= vec.numel()`` the input ``vec`` itself is
        returned (not a copy).
    """
    if k >= vec.numel():
        return vec

    # Work on flattened views so the top-k is taken over every element and
    # the resulting indices address the same flattened layout.
    flat_vec = vec.reshape(-1)
    _, indices = torch.topk(alpha.reshape(-1), k)

    result = torch.zeros_like(flat_vec)
    result[indices] = flat_vec[indices]
    return result.reshape(vec.shape)
[docs] def soft_threshold( vec: torch.Tensor, threshold: torch.Tensor | float ) -> torch.Tensor: """Element-wise soft-thresholding: ``sign(x) * max(abs(x) - threshold, 0)``. Args: vec: Input tensor. threshold: Threshold value (scalar or broadcastable tensor). Returns: Soft-thresholded tensor. """ return torch.sign(vec) * torch.nn.functional.relu( torch.abs(vec) - threshold )
@torch.no_grad()
def solve_proximal_adam(
    v_elements: torch.Tensor,
    hessian_elements: torch.Tensor,
    thresholds: torch.Tensor,
    eps: float = 1e-6,
    max_iter: int = 10,
) -> torch.Tensor:
    """Solve for ``mu`` in the Proximal Adam equation, one value per block.

    For every block (row) the function finds ``mu`` such that
    ``Zeta(mu) = mu * ||(H / (H + mu)) * v||_2`` equals the block threshold,
    using a fixed number of bisection steps between analytic bounds. Blocks
    with ``||H * v||_2 <= threshold`` get ``mu = 0``.

    Args:
        v_elements: Shape ``(num_blocks, block_size)``. The dense weights
            (or updates); each row is one block.
        hessian_elements: Same shape as ``v_elements``. The Adam
            preconditioner (sqrt(v) + eps).
        thresholds: The target value (eta * lambda) per block. Accepts shape
            ``(num_blocks,)``, ``(num_blocks, 1)`` or a single scalar shared
            by all blocks.
        eps: Small constant added to the bound denominator
            ``||H v|| - threshold`` to avoid division by zero.
        max_iter: Number of bisection steps.

    Returns:
        mu: Shape ``(num_blocks, 1)``. The scaling factor (zero for blocks
        that do not exceed their threshold).
    """
    # 1. Compute per-block norms ||H * v||_2, shape (num_blocks, 1).
    #    The reduction must run over dim=1: a global norm would be a 0-dim
    #    tensor that cannot be indexed per block further down.
    hess_weighted = hessian_elements * v_elements
    hess_weighted_norms = _norm(hess_weighted, dim=1, keepdim=True)
    num_blocks = hess_weighted_norms.shape[0]

    # Bring thresholds to shape (num_blocks, 1) so they line up with the
    # norms instead of broadcasting into a (num_blocks, num_blocks) matrix.
    thresholds = (
        torch.as_tensor(
            thresholds,
            dtype=hess_weighted_norms.dtype,
            device=hess_weighted_norms.device,
        )
        .reshape(-1, 1)
        .expand(num_blocks, 1)
    )

    # 2. Identify survivors
    # If ||Hv|| <= threshold, the optimal weight is 0.
    # We only solve for the rest.
    is_survivor = (hess_weighted_norms > thresholds).squeeze(1)

    # Prepare output
    mu_solutions = torch.zeros_like(hess_weighted_norms)

    # Filter to active blocks to save compute. flatten() keeps the index
    # tensor 1-D even when exactly one block survives.
    indices = torch.nonzero(is_survivor).flatten()
    if indices.numel() == 0:
        return mu_solutions  # All zero

    v_active = v_elements[indices]
    mask_active = hessian_elements[indices]
    thresh_active = thresholds[indices]
    norm_active = hess_weighted_norms[indices]

    # 3. Compute bounds (from the derivation)
    # S = ||Hv||^2, so sqrt(S) = norm_active
    # mu_low  = (lambda * h_min) / (sqrt(S) - lambda)
    # mu_high = (lambda * h_max) / (sqrt(S) - lambda)
    # eps is added to the denominator to avoid division by zero.
    denom = norm_active - thresh_active + eps
    h_min = mask_active.min(dim=1, keepdim=True).values
    h_max = mask_active.max(dim=1, keepdim=True).values
    low = (thresh_active * h_min) / denom
    high = (thresh_active * h_max) / denom

    # 4. Bisection: search for mu such that Zeta(mu) = threshold, with
    # Zeta(mu) = mu * || (H / (H + mu)) * v ||
    for _ in range(max_iter):
        mu = (low + high) / 2

        # scaling vector = H / (H + mu)
        scaling = mask_active / (mask_active + mu)
        zeta = mu * _norm(scaling * v_active, dim=1, keepdim=True)

        # Zeta is strictly increasing with mu.
        # If zeta < threshold, mu is too small -> low = mu
        # Otherwise mu is too big (or exact)   -> high = mu
        mask_low = zeta < thresh_active
        low = torch.where(mask_low, mu, low)
        high = torch.where(~mask_low, mu, high)

    # Final estimate: midpoint of the remaining bracket
    mu_solutions[indices] = (low + high) / 2
    return mu_solutions