Linear Algebra

Linear algebra utilities for sparse optimization.

sparsekit.linalg.lsqr_gkl(A, b, max_iter=1000, tol=1e-06, x_0=None, device=None)

Solves the least-squares problem min_x ||Ax - b||_2 with the LSQR algorithm, which is built on Golub-Kahan-Lanczos bidiagonalization. Computations use PyTorch and can run on a GPU.

Parameters:
  • A (Tensor) – The matrix A.

  • b (Tensor) – The right-hand side vector b.

  • max_iter (int) – Maximum number of iterations.

  • tol (float) – Convergence tolerance.

  • x_0 (Tensor | None) – Optional initial guess for the solution.

  • device (device | None) – Optional device on which to perform the computations.

Returns:

  • x (Tensor) – The solution vector.

  • info (dict) – A dictionary containing information about convergence.

Return type:

tuple[Tensor, dict]
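To make the description above concrete, here is a minimal dense sketch of the classic Paige-Saunders LSQR recurrences. It is an illustration, not sparsekit's implementation. The name `lsqr_sketch`, the stopping rule (stop once the estimated normal-equation residual ||A^T r|| falls below tol * ||A^T b||), and the keys of the returned info dictionary are all assumptions. The sketch omits `x_0` and `device` and simply runs on whatever device `A` lives on.

```python
import torch


def lsqr_sketch(A: torch.Tensor, b: torch.Tensor, max_iter: int = 1000, tol: float = 1e-6):
    """Minimal dense LSQR (Paige & Saunders) for min ||Ax - b||_2.

    Assumed stopping rule: estimated ||A^T r|| <= tol * ||A^T b||.
    Hypothetical helper; not sparsekit's lsqr_gkl.
    """
    n = A.shape[1]
    x = torch.zeros(n, dtype=A.dtype, device=A.device)
    # Start Golub-Kahan-Lanczos bidiagonalization: beta*u = b, alpha*v = A^T u.
    beta = torch.linalg.norm(b)
    if beta == 0:  # b = 0, so x = 0 is the solution
        return x, {"iterations": 0, "residual_norm": 0.0}
    u = b / beta
    v = A.T @ u
    alpha = torch.linalg.norm(v)
    if alpha == 0:  # A^T b = 0, so x = 0 is already a least-squares solution
        return x, {"iterations": 0, "residual_norm": beta.item()}
    v = v / alpha
    w = v.clone()
    phi_bar, rho_bar = beta, alpha
    stop = tol * alpha * beta  # alpha * beta equals ||A^T b|| at the start

    it = 0
    for it in range(1, max_iter + 1):
        # Next bidiagonalization step.
        u = A @ v - alpha * u
        beta = torch.linalg.norm(u)
        if beta > 0:
            u = u / beta
        v = A.T @ u - beta * v
        alpha = torch.linalg.norm(v)
        if alpha > 0:
            v = v / alpha
        # Givens rotation that eliminates the subdiagonal entry beta.
        rho = torch.sqrt(rho_bar**2 + beta**2)
        c, s = rho_bar / rho, beta / rho
        theta = s * alpha
        rho_bar = -c * alpha
        phi = c * phi_bar
        phi_bar = s * phi_bar
        # Update the iterate and the search direction.
        x = x + (phi / rho) * w
        w = v - (theta / rho) * w
        # phi_bar estimates ||r||, and ||A^T r|| = phi_bar * alpha * |c|.
        if phi_bar * alpha * torch.abs(c) <= stop:
            break
    return x, {"iterations": it, "residual_norm": phi_bar.item()}
```

For a small, full-rank dense problem the result can be checked against `torch.linalg.lstsq(A, b.unsqueeze(1)).solution`. Working with A and A^T only through matrix-vector products is what makes LSQR attractive when A is large or sparse.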

sparsekit.linalg.hard_threshold(vec, alpha, k)

Keeps the k entries of vec whose corresponding entries in alpha have the largest magnitude and sets all other entries to zero.

Parameters:
  • vec (Tensor) – Tensor of values to threshold.

  • alpha (Tensor) – Score tensor with the same shape as vec; the k entries with the largest |alpha| are kept.

  • k (int) – Number of elements to keep.

Returns:

A tensor with the same shape as vec that keeps the k entries with the largest |alpha| and has zeros everywhere else.

Return type:

Tensor
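The selection rule above can be sketched with `torch.topk` on the flattened magnitudes of alpha. This is a minimal illustration under assumptions, not sparsekit's code: the name `hard_threshold_sketch` is hypothetical, k is clamped to the number of elements, and ties are broken however `torch.topk` happens to order them.

```python
import torch


def hard_threshold_sketch(vec: torch.Tensor, alpha: torch.Tensor, k: int) -> torch.Tensor:
    """Hypothetical helper: keep entries of vec at the k largest |alpha|; zero the rest."""
    scores = alpha.abs().flatten()
    k = max(0, min(k, scores.numel()))  # assumed: clamp k to a valid range
    # Flat indices of the k largest scores.
    idx = torch.topk(scores, k).indices
    mask = torch.zeros_like(scores, dtype=torch.bool)
    mask[idx] = True
    # torch.where avoids inf * 0 = nan for entries that are zeroed out.
    return torch.where(mask.view_as(vec), vec, torch.zeros_like(vec))
```

For example, with vec = [1, 2, 3, 4], alpha = [-5, 0.1, 3, -0.2] and k = 2, the entries at positions 0 and 2 are kept, giving [1, 0, 3, 0]. Passing alpha = vec recovers ordinary magnitude-based hard thresholding.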

sparsekit.linalg.soft_threshold(vec, threshold)

Element-wise soft-thresholding: sign(x) * max(|x| - threshold, 0), applied to each element x of vec.

Parameters:
  • vec (Tensor) – Input tensor.

  • threshold (Tensor | float) – Threshold value (scalar or broadcastable tensor).

Returns:

Soft-thresholded tensor.

Return type:

Tensor
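The formula above is a one-liner in PyTorch. Soft-thresholding is the proximal operator of threshold * ||x||_1, which is why it appears in sparse optimization. The sketch below is a minimal illustration; the name `soft_threshold_sketch` is hypothetical and the function is not sparsekit's source.

```python
import torch


def soft_threshold_sketch(vec: torch.Tensor, threshold) -> torch.Tensor:
    """Hypothetical helper: sign(x) * max(|x| - threshold, 0), element-wise.

    threshold may be a float or a tensor that broadcasts against vec.
    """
    return torch.sign(vec) * torch.clamp(vec.abs() - threshold, min=0.0)
```

With threshold = 1.0, the input [-3, -0.5, 0, 0.5, 3] maps to [-2, 0, 0, 0, 2]: entries with magnitude at most the threshold are zeroed, and the remaining ones are shrunk toward zero by the threshold.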

sparsekit.linalg.solve_proximal_adam(v_elements, hessian_elements, thresholds, eps=1e-06, max_iter=10)

Solves for the scaling factor mu in the Proximal Adam equation using bisection search.

Parameters:
  • v_elements (Tensor) – Shape (s1,s2,…,sm). The dense weights (or updates).

  • hessian_elements (Tensor) – Shape (s1,s2,…,sm). The Adam preconditioner (sqrt(v) + eps).

  • thresholds (Tensor) – Shape (num_blocks,). The target value for each block (eta * lambda).

  • eps (float)

  • max_iter (int) – Maximum number of bisection iterations.

Returns:

mu – Shape (num_blocks, 1). The scaling factor for each block.

Return type:

Tensor
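The entry states that mu is found by bisection but does not give the equation, so the sketch below uses a hypothetical stand-in to show how a vectorized, per-block bisection can be written. The assumed equation is the norm condition for the diagonally weighted group-lasso prox, min_x 0.5 * sum_i h_i (x_i - v_i)^2 + t ||x||_2, whose solution is x_i = h_i v_i mu / (h_i mu + t) with mu = ||x||. It may differ from sparsekit's Proximal Adam equation. The name `bisect_block_scale`, the layout (blocks along dim 0), and the default of 50 iterations are assumptions.

```python
import torch


def bisect_block_scale(v, h, thresholds, max_iter=50):
    """Hypothetical per-block bisection for mu >= 0 solving
        sum_i (h_i * v_i / (h_i * mu + t)) ** 2 = 1   (stand-in equation).

    v, h: shape (num_blocks, d) with h > 0; thresholds: shape (num_blocks,), > 0.
    Returns mu with shape (num_blocks, 1); mu = 0 means the block is zeroed.
    """
    t = thresholds.unsqueeze(1)  # (num_blocks, 1) broadcasts over each block

    def f(mu):
        # Strictly decreasing in mu >= 0, so bisection applies.
        return ((h * v / (h * mu + t)) ** 2).sum(dim=1, keepdim=True) - 1.0

    lo = torch.zeros_like(t)
    hi = torch.linalg.norm(v, dim=1, keepdim=True)  # f(||v||) < 0 when t > 0
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        go_right = f(mid) > 0  # the root lies to the right of mid
        lo = torch.where(go_right, mid, lo)
        hi = torch.where(go_right, hi, mid)
    mu = 0.5 * (lo + hi)
    # If f(0) <= 0 the whole block is thresholded to zero.
    return torch.where(f(torch.zeros_like(t)) <= 0, torch.zeros_like(mu), mu)
```

With h = 1 the stand-in reduces to ordinary group soft-thresholding, so mu = max(||v|| - t, 0). For v = [3, 4] and t = 1 the sketch returns mu close to 4, and for v = [0.3, 0.4] it returns 0. Each iteration halves the bracket for every block at once, so a fixed iteration count (like max_iter above) bounds the error by ||v|| / 2^max_iter.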