# Copyright (c) 2026 - Ayoub Ghriss & Contributors
# Licensed under CC BY-NC 4.0
# (see LICENSE or https://creativecommons.org/licenses/by-nc/4.0/)
# Non-commercial use only; contact us for commercial licensing.
"""Block-level sparsity specification."""
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
from functools import cached_property, lru_cache
from typing import (
Any,
Optional,
Tuple,
List,
Mapping,
Iterable,
Callable,
Union,
)
import math
import torch
from torch import Tensor
from torch.nn import Parameter
from .tensor_ops import interleave_unsqueeze
from .tensor_ops import inverse_permutation, normalize_order
from .tensor_ops import get_dtype_epsilon
from .view import View
Values = Union[Tensor, Mapping[Any, Tensor], None]
class SparseKitError(Exception):
    """Root exception type for errors raised by sparsekit.

    The other exception classes defined in this module derive from it.
    """
class ShapeMismatchError(SparseKitError):
    """Error signalling that a tensor shape differs from the expected one."""

    def __init__(
        self,
        expected: Tuple[int, ...],
        got: Tuple[int, ...],
        context: str = "",
    ):
        """Build the error message.

        Args:
            expected: The shape that was required.
            got: The shape that was actually received.
            context: Optional label; when non-empty it is prepended to the
                message as ``"<context>: "``.
        """
        detail = f"Shape mismatch: expected {expected}, got {got}"
        if context:
            detail = f"{context}: {detail}"
        super().__init__(detail)
class CouplingError(SparseKitError):
    """Error signalling that a coupling constraint has been violated."""
@dataclass
class SparseBlock(ABC):
    """Abstract base class for block-structured sparse tensors.

    Provides the interface for viewing tensors as block grids, computing
    per-block statistics (norm / min / max), and applying soft or hard
    thresholding to whole blocks.
    """

    def _lp_norm_fn(self, t: Tensor, p, keepdim=False) -> Tensor:
        """Compute the Lp norm of ``t`` over ``self._reduction_dim``."""
        return torch.linalg.vector_norm(
            t, ord=p, dim=self._reduction_dim, keepdim=keepdim
        )

    def _min_fn(self, t: Tensor, keepdim=False) -> Tensor:
        """Compute the minimum of ``t`` over ``self._reduction_dim``."""
        return torch.amin(t, dim=self._reduction_dim, keepdim=keepdim)

    def _max_fn(self, t: Tensor, keepdim=False) -> Tensor:
        """Compute the maximum of ``t`` over ``self._reduction_dim``."""
        return torch.amax(t, dim=self._reduction_dim, keepdim=keepdim)

    def norms(self, values: Values = None, p: int = 2) -> Tensor:
        """Compute the Lp norm of each block (L2 by default)."""
        return self.reduce(values, lambda t: self._lp_norm_fn(t, p=p))

    def min(self, values: Values = None) -> Tensor:
        """Compute the minimum value of each block."""
        return self.reduce(values, self._min_fn)

    def max(self, values: Values = None) -> Tensor:
        """Compute the maximum value of each block."""
        return self.reduce(values, self._max_fn)

    @property
    def grid_ndim(self) -> int:
        """Number of dimensions in the block grid."""
        return len(self.grid_shape)

    # NOTE: declared as an abstract *property*, not ``cached_property``:
    # ``functools.cached_property`` does not forward ``__isabstractmethod__``,
    # so wrapping ``abstractmethod`` in it silently disables the ABC check.
    # Subclasses may still override this with ``cached_property``.
    @property
    @abstractmethod
    def grid_shape(self) -> Tuple[int, ...]:
        """Shape of the block grid (number of blocks per dimension)."""

    @abstractmethod
    def parameters(self) -> Iterable[View]:
        """Iterable of the View objects managed by this node."""

    @property
    @abstractmethod
    def data(self) -> Mapping["BlockSpec", Tensor] | Tensor:
        """Raw tensor data of the underlying parameter(s)."""

    @abstractmethod
    def nnz(self, eps=1e-8) -> int:
        """Count non-zero elements with absolute value > eps."""

    def numblk(self) -> int:
        """Total number of blocks in the grid."""
        return math.prod(self.grid_shape)

    def numel(self) -> int:
        """Total number of elements across all blocks."""
        return self.numblk() * self.numel_per_block()

    # No ``lru_cache`` here: caching a method keys on ``self`` and keeps
    # every instance alive for the lifetime of the cache.
    @abstractmethod
    def numel_per_block(self) -> int:
        """Number of elements per block."""

    # Abstract property for the same reason as ``grid_shape`` above.
    @property
    @abstractmethod
    def _reduction_dim(self) -> int | Tuple[int, ...]:
        """Dimension(s) to reduce over when computing block statistics."""

    @abstractmethod
    def apply_mask(self, mask: Tensor) -> None:
        """Zero out blocks where mask is True."""

    @abstractmethod
    def apply_multiplier(self, multiplier: Tensor):
        """Multiply each block by corresponding scalar in multiplier."""

    @abstractmethod
    def block_view(self, values: Values, reorder=True, merge=False) -> Tensor:
        """Return a block-structured view of values."""

    @abstractmethod
    def reduce(
        self, values: Values, reduce_fn: Callable[[Tensor], Tensor]
    ) -> Tensor:
        """Apply reduce_fn over each block and return grid-shaped result."""

    def _soft_threshold_euclid(self, block_thresholds: Tensor, eps: float):
        """In-place Euclidean (L2) proximal step.

        Each block is multiplied by
        ``max(0, 1 - threshold / (block_norm + eps))``.

        Args:
            block_thresholds: Per-block thresholds, shaped like
                ``self.grid_shape``.
            eps: Added to the block norms to avoid division by zero.
        """
        assert tuple(block_thresholds.shape) == self.grid_shape
        block_scale = 1 - block_thresholds / (self.norms(self.data) + eps)
        block_scale.clamp_(min=0.0)
        self.apply_multiplier(block_scale)

    @abstractmethod
    def _soft_threshold_diag_cond(
        self,
        block_thresholds: Tensor,
        conditioners: Values,
        max_iter: int,
        atol: float,
        eps: float,
    ) -> None:
        """In-place Adam-conditioned proximal step."""

    @torch.no_grad()
    def soft_threshold(
        self,
        block_thresholds: Tensor,
        conditioners: Values = None,
        max_iter: int = 20,
        atol: float = 1e-8,
        eps: float | None = None,
    ) -> None:
        """Apply soft thresholding to shrink block norms in place.

        Dispatches to the Euclidean step when ``conditioners`` is None and
        to ``_soft_threshold_diag_cond`` otherwise.

        Args:
            block_thresholds: Per-block threshold values, shaped like
                ``self.grid_shape``.
            conditioners: Optional diagonal conditioner for Adam variant.
            max_iter: Maximum iterations for Adam variant.
            atol: Absolute tolerance for convergence.
            eps: Small constant for numerical stability; passed through
                ``get_dtype_epsilon`` together with the thresholds' dtype
                (presumably yielding a dtype-based default when None;
                confirm in ``tensor_ops``).

        Raises:
            AssertionError: If ``block_thresholds`` does not have shape
                ``self.grid_shape`` (the check is an ``assert``).
        """
        assert tuple(block_thresholds.shape) == self.grid_shape
        eps = get_dtype_epsilon(block_thresholds.dtype, eps)
        if conditioners is None:
            self._soft_threshold_euclid(block_thresholds, eps=eps)
        else:
            self._soft_threshold_diag_cond(
                block_thresholds=block_thresholds,
                conditioners=conditioners,
                max_iter=max_iter,
                atol=atol,
                eps=eps,
            )

    @torch.no_grad()
    def hard_threshold(self, thresholds: Tensor, values: Values = None):
        """Zero out blocks whose L2 norm is strictly below the threshold.

        Args:
            thresholds: Per-block threshold values, shaped like
                ``self.grid_shape``.
            values: Values whose block norms are compared against the
                thresholds; None (the default) is forwarded to ``norms``
                unchanged, which means "use this node's own data".

        Raises:
            ValueError: If ``thresholds`` does not have shape
                ``self.grid_shape``.
        """
        if tuple(thresholds.shape) != self.grid_shape:
            raise ValueError(
                f"thresholds shape {thresholds.shape} must match "
                f"block_grid_size {self.grid_shape}"
            )
        blocks_to_mask = self.norms(values) < thresholds
        self.apply_mask(blocks_to_mask)

    @abstractmethod
    def get_masks(self, block_masks: Tensor) -> Mapping["SparseBlock", Tensor]:
        """Convert block-level mask to element-level masks per BlockSpec."""

    @abstractmethod
    def __repr__(self) -> str:
        """Return a string representation of the node."""

    def __str__(self) -> str:
        """Return the same text as ``repr(self)``."""
        return repr(self)

    @abstractmethod
    def __hash__(self) -> int:
        """Return a hash for the node."""
@dataclass
class BlockSpec(SparseBlock):
    """Treats the entire tensor as a grid of blocks.

    Attributes:
        view: View giving access to the tensor's shape, stride and data;
            normalized through ``View.from_existing`` on construction.
        shape: Shape of each block. An empty tuple means one element per
            block; a non-positive entry means "the whole dimension".
        name: Optional name for identification.
    """

    view: View
    shape: Tuple[int, ...]
    name: Optional[str] = None

    def __post_init__(self):
        """Validate and normalize the block shape after initialization.

        Raises:
            ValueError: If the block shape's length differs from the
                tensor's rank, or a dimension is not divisible by its
                block size.
        """
        self.view = View.from_existing(self.view)
        if len(self.shape) == 0:  # if block size empty, default to 1
            self.shape = (1,) * self.view.ndim
        if len(self.shape) != self.view.ndim:
            raise ValueError(
                f"{self.name} block has len {len(self.shape)}:{self.shape} "
                f"but tensor is {self.view.ndim}D:{self.view.shape}"
            )
        # A non-positive entry (e.g. -1) selects the entire dimension.
        self.shape = tuple(
            bi if bi > 0 else self.view.shape[i]
            for i, bi in enumerate(self.shape)
        )
        for i, (si, bi) in enumerate(zip(self.view.shape, self.shape)):
            if si % bi != 0:
                raise ValueError(
                    f"dim {i}: size {si} not divisible by block_shape[{i}]={bi}"
                )

    @property
    def ndim(self) -> int:
        """Number of dimensions in the underlying tensor."""
        return self.view.ndim

    @property
    def data(self) -> Tensor:
        """Raw tensor data exposed by the underlying view."""
        return self.view.data

    @cached_property
    def grid_shape(self) -> Tuple[int, ...]:
        """Number of blocks along each tensor dimension."""
        return tuple(si // bi for si, bi in zip(self.view.shape, self.shape))

    # Plain method: ``lru_cache`` on a method keys on ``self`` and keeps the
    # instance (and its view) alive for as long as the cache exists.
    def numel_per_block(self) -> int:
        """Number of elements per block."""
        return math.prod(self.shape)

    @cached_property
    def _reduction_dim(self) -> Tuple[int, ...]:
        """Odd-indexed dimensions to reduce over for block statistics.

        These index into the ``2 * len(self.shape)``-dimensional block view
        produced with ``reorder=False``.
        """
        return tuple(range(1, 2 * len(self.shape), 2))

    def blocks(self) -> Iterable[SparseBlock]:
        """Return self as the only BlockSpec."""
        return [self]

    def parameters(self) -> List[View]:
        """List containing the single underlying view."""
        return [self.view]

    def set_data(self, data):
        """Copy ``data`` in place into the underlying view's tensor."""
        self.view.data.copy_(data)

    def nnz(self, eps=1e-8) -> int:
        """Number of elements whose absolute value exceeds ``eps``."""
        return int((self.view.data.abs() > eps).sum().item())

    def _resolve_values(self, values: Values) -> Tensor:
        """Resolve ``values`` to a view-shaped tensor.

        Args:
            values: None (use this spec's own data), a mapping keyed by
                BlockSpec (this spec's entry is used), or a tensor shaped
                like either the view or the view's underlying parameter.

        Returns:
            A tensor with the view's shape. A parameter-shaped tensor is
            re-strided with ``torch.as_strided`` using the view's shape and
            stride so block operations see the same layout as the view.

        Raises:
            KeyError: If ``values`` is a mapping without an entry for self.
            ShapeMismatchError: If a tensor matches neither accepted shape.
            ValueError: If ``values`` is none of the accepted types.
        """
        if values is None:
            return self.view.data
        # Accept any Mapping (not just dict), matching the Values alias.
        if isinstance(values, Mapping):
            return values[self]
        if isinstance(values, Tensor):
            if tuple(values.shape) == tuple(self.view.shape):
                return values
            if tuple(values.shape) == tuple(self.view.param.shape):
                return torch.as_strided(
                    values.contiguous(), self.view.shape, self.view.stride
                )
            raise ShapeMismatchError(
                tuple(self.view.shape), tuple(values.shape), "values"
            )
        raise ValueError(
            "values has to be None, Tensor or Dict[BlockSpec, Tensor]"
        )

    def block_view(
        self, values: Values, reorder: bool = True, merge=False
    ) -> Tensor:
        """Reshape tensor to interleaved block view.

        Args:
            values: Input values (None uses the view's own data).
            reorder: If True, permute grid dims before block dims.
            merge: If True, collapse block dims to trailing dim.
        """
        t = self._resolve_values(values)
        return View.block_view_of(t, self.shape, reorder=reorder, merge=merge)

    def expand_block_tensor(self, block_values: Tensor) -> Tensor:
        """View ``block_values`` with shape ``self.grid_shape``.

        Args:
            block_values: Tensor holding one value per block.

        Returns:
            ``block_values.view(self.grid_shape)``.
        """
        return block_values.view(self.grid_shape)

    def broadcast_block_to_element(
        self, block_values: Tensor, fake=False
    ) -> Tensor:
        """Broadcast a grid-shaped tensor to element resolution.

        Delegates to ``View.broadcast_block_to_element``.

        Args:
            block_values: Tensor with shape ``self.grid_shape``.
            fake: If True, only unsqueeze without repeating.

        Returns:
            Tensor at element resolution (or interleaved with singleton
            block dims if ``fake=True``).
        """
        assert tuple(block_values.shape) == self.grid_shape
        return View.broadcast_block_to_element(
            block_values, self.shape, fake=fake
        )

    def apply_mask(self, mask: Tensor):
        """Zero out blocks where mask is True."""
        self.apply_multiplier(~mask)

    def apply_multiplier(self, multiplier: Tensor):
        """Multiply each block in place by its scalar in ``multiplier``.

        Args:
            multiplier: Tensor with shape ``self.grid_shape``.
        """
        assert multiplier.shape == self.grid_shape
        multiplier = self.expand_block_tensor(multiplier)
        if isinstance(self.view, View):
            self.view.apply_multiplier(multiplier, self.shape)
        else:
            # Fallback for a view object that is not a View instance:
            # multiply through the interleaved block view directly.
            multiplier = interleave_unsqueeze(multiplier)
            b_view = View.block_view_of(
                self.view.data, self.shape, reorder=False
            )
            b_view.mul_(multiplier)

    def reduce(
        self, values: Values, reduce_fn: Callable[[Tensor], Tensor]
    ) -> Tensor:
        """Apply reduce_fn over each block and return grid-shaped result."""
        t = self.block_view(values, reorder=False)
        return reduce_fn(t).view(self.grid_shape)

    def _soft_threshold_diag_cond(
        self,
        block_thresholds: Tensor,
        conditioners: Values,
        max_iter,
        atol,
        eps,
    ):
        """In-place Adam-conditioned proximal step via bisection.

        For each block a scalar ``mu`` is searched by bisection so that
        ``mu * ||H v / (H + mu)|| `` matches the block threshold, where
        ``H`` is the conditioner and ``v`` the current data. The data is
        then scaled elementwise by ``H / (H + mu)``, and blocks whose
        ``||H v||`` does not exceed the threshold are zeroed.

        Args:
            block_thresholds: Per-block thresholds, shaped like
                ``self.grid_shape``.
            conditioners: Diagonal conditioner, resolved through
                ``_resolve_values``; must be view-shaped.
            max_iter: Maximum number of bisection iterations.
            atol: Stop once the widest bracket is narrower than this.
            eps: Numerical-stability constant; None falls back to the
                machine epsilon of the thresholds' dtype.
        """
        assert block_thresholds.shape == self.grid_shape
        assert conditioners is not None
        conditioner = self._resolve_values(conditioners)
        assert isinstance(conditioner, Tensor)
        assert conditioner.shape == self.view.shape
        # Resolve eps before any use, including the early return below.
        if eps is None:
            eps = torch.finfo(block_thresholds.dtype).eps
        if self.numel_per_block() == 1:
            # Single-element blocks: the conditioned step reduces to the
            # Euclidean one with thresholds divided by the conditioner.
            return self._soft_threshold_euclid(
                block_thresholds / conditioner.view(self.grid_shape), eps=eps
            )
        hess_weighted = conditioner * self.view.data
        hess_weighted_norms = self.norms(hess_weighted)
        denom = hess_weighted_norms - block_thresholds
        # Blocks with ||H v|| <= threshold are zeroed at the end.
        non_survivors = denom <= 0.0
        denom.clamp_(min=0.0).add_(eps)
        h_min = self.min(conditioner)
        h_max = self.max(conditioner)
        # Initial bisection bracket for mu.
        mu_low = (block_thresholds * h_min) / denom
        mu_high = (block_thresholds * h_max) / denom
        # (B1, 1, B2, 1, ...)
        mu_low = self.broadcast_block_to_element(mu_low, fake=True).clamp_(
            min=0.0
        )
        mu_high = self.broadcast_block_to_element(mu_high, fake=True).clamp_(
            min=0.0
        )
        blocked_e_thresholds = self.broadcast_block_to_element(
            block_thresholds, fake=True
        )
        # (B1, b1, B2, b2, ...)
        blocked_e_conditioner = self.block_view(conditioner, reorder=False)
        blocked_hess_weighted = self.block_view(hess_weighted, reorder=False)
        mu = (mu_low + mu_high) / 2
        for _ in range(max_iter):
            # Compute Zeta(mu) = mu * ||H / (H + mu) v|| per block.
            # (B1, 1, B2, 1, ...)
            weighted_norm = self._lp_norm_fn(
                blocked_hess_weighted / (blocked_e_conditioner + mu),
                p=2,
                keepdim=True,
            )
            zeta = mu * weighted_norm
            # Zeta is strictly increasing with mu.
            # If zeta < threshold, mu is too small -> low = mu
            # If zeta > threshold, mu is too big -> high = mu
            mask_low = zeta < blocked_e_thresholds
            mu_low = torch.where(mask_low, mu, mu_low)
            mu_high = torch.where(~mask_low, mu, mu_high)
            mu = (mu_low + mu_high) / 2
            if (mu_low - mu_high).abs().max() < atol:
                break
        scaling = conditioner / (
            conditioner
            + self.broadcast_block_to_element(mu.view(self.grid_shape))
        )
        self.set_data(scaling * self.view.data)
        # only keep survivors
        self.apply_mask(non_survivors)

    def get_masks(self, block_masks: Tensor) -> Mapping["BlockSpec", Tensor]:
        """Convert block-level mask to element-level mask.

        Args:
            block_masks: Boolean tensor with shape ``self.grid_shape``.

        Returns:
            Dict mapping self to the broadcasted element mask.
        """
        block_masks = self.broadcast_block_to_element(block_masks)
        return {self: block_masks}

    def __repr__(self) -> str:
        """Return string representation with shape information."""
        return (
            f"{self.__class__.__name__}(shape={self.shape}, "
            f"grid_shape={self.grid_shape}, name={self.name!r})"
        )

    def __hash__(self) -> int:
        """Hash of the underlying view."""
        return hash(self.view)
@dataclass
class BlockCoupling(SparseBlock):
    """Merges multiple BlockSpec objects into one coupled sparse node.

    Attributes:
        specs: List of BlockSpec objects to couple.
        orders: Axis permutations to align block grids; when empty, the
            identity order is used for every spec.
        name: Optional name for identification.
    """

    specs: List[BlockSpec]
    orders: List[Tuple[int, ...]]
    name: Optional[str] = None
    # Order of the first spec; it defines the coupling's grid layout.
    _ref_order: Tuple[int, ...] = field(init=False)
    # Inverse of each entry in ``orders`` (coupling grid -> spec grid).
    _reverse_orders: List[Tuple[int, ...]] = field(init=False)

    def __post_init__(self):
        """Validate and compute axis orderings for all specs.

        Raises:
            ValueError: If ``orders`` and ``specs`` differ in length, if
                ``specs`` is empty, or if the permuted grid shapes of the
                specs do not all agree.
        """
        if not self.orders:
            self.orders = [tuple(range(len(s.grid_shape))) for s in self.specs]
        if len(self.orders) != len(self.specs):
            raise ValueError("orders must match number of specs.")
        if not self.specs:
            raise ValueError("specs must contain at least one BlockSpec.")
        self.orders = [
            normalize_order(o, len(s.grid_shape))
            for o, s in zip(self.orders, self.specs)
        ]
        # Must be assigned: the dataclass-generated __eq__ reads every
        # field, so leaving it unset makes ``==`` raise AttributeError.
        self._ref_order = tuple(self.orders[0])
        ref_permute = tuple(
            self.specs[0].grid_shape[i] for i in self._ref_order
        )
        self._reverse_orders = []
        for s, o in zip(self.specs, self.orders):
            gperm = tuple(s.grid_shape[i] for i in o)
            if gperm != ref_permute:
                s_name = s.name or "<unnamed>"
                raise ValueError(
                    "Incompatible block grid shapes "
                    f"after order: {gperm} vs {ref_permute} (spec {s_name})"
                )
            self._reverse_orders.append(inverse_permutation(o))

    @property
    def shape(self) -> Tuple[int, ...]:
        """Return placeholder shape (-1, -1) for coupled specs."""
        return (-1, -1)

    @property
    def data(self) -> Mapping[BlockSpec, Tensor]:
        """Dict mapping each spec to its tensor data."""
        return {s: s.view.data for s in self.specs}

    @cached_property
    def grid_shape(self) -> Tuple[int, ...]:
        """Grid shape for the coupling (after order permutation)."""
        return tuple(self.specs[0].grid_shape[i] for i in self.orders[0])

    # Plain method: ``lru_cache`` on a method keys on ``self`` and keeps
    # the instance alive for as long as the cache exists.
    def numel_per_block(self) -> int:
        """Total elements per block across all specs."""
        return sum(s.numel_per_block() for s in self.specs)

    @cached_property
    def _reduction_dim(self):
        """Reduction dimension for block statistics (last dim)."""
        return -1

    def parameters(self) -> List[View]:
        """List of the views of all coupled specs."""
        return [s.view for s in self.specs]

    def nnz(self, eps=1e-8) -> int:
        """Count non-zero elements across all specs."""
        return sum(s.nnz(eps=eps) for s in self.specs)

    def _resolve_values(self, values: Values) -> Mapping[BlockSpec, Tensor]:
        """Resolve ``values`` to a dict with one tensor per coupled spec.

        Args:
            values: None (use each spec's own data) or a mapping keyed by
                BlockSpec.

        Raises:
            KeyError: If the mapping lacks an entry for one of the specs.
            ValueError: If ``values`` is neither None nor a mapping.
        """
        if values is None:
            return {s: s.view.data for s in self.specs}
        # Any Mapping is accepted, consistent with the isinstance check in
        # ``_soft_threshold_diag_cond``.
        if isinstance(values, Mapping):
            return {s: values[s] for s in self.specs}
        raise ValueError("values must be Mapping[BlockSpec,Tensor]")

    def _raw_block_view(
        self, spec_values: Mapping[BlockSpec, Tensor]
    ) -> Tensor:
        """Reshape and concatenate all spec values into unified block view.

        Args:
            spec_values: Mapping from BlockSpec to tensor values.

        Returns:
            Concatenated tensor of shape (B0, B1, ..., total_block_numel).
        """
        values = []
        for o, s in zip(self.orders, self.specs):
            merged = s.block_view(spec_values[s], merge=True)
            # merged shape: (*grid_shape, block_numel); permute the grid
            # dims into the coupling order and keep the block dim last.
            ndim = len(s.grid_shape)
            values.append(merged.permute(*o, ndim))
        return torch.concat(values, dim=-1)

    def block_view(
        self, values: Values, reorder: bool = True, merge=False
    ) -> Tensor:
        """Return concatenated block view across all specs.

        ``reorder`` and ``merge`` are accepted for interface compatibility
        but are not used: the result is always the merged, order-permuted
        view built by ``_raw_block_view``.
        """
        spec_values = self._resolve_values(values)
        return self._raw_block_view(spec_values)

    def reduce(
        self, values: Values, reduce_fn: Callable[[Tensor], Tensor]
    ) -> Tensor:
        """Apply reduce_fn over concatenated block view."""
        spec_values = self._resolve_values(values)
        concat_values = self._raw_block_view(spec_values)
        return reduce_fn(concat_values)

    def _soft_threshold_euclid(
        self, block_thresholds: Tensor, eps: float | None = None
    ):
        """In-place Euclidean (L2) proximal step.

        Each coupled block is multiplied by
        ``max(0, 1 - threshold / (block_norm + eps))``, where the norm is
        taken over the concatenation of all specs' block elements.
        """
        assert tuple(block_thresholds.shape) == self.grid_shape
        eps = get_dtype_epsilon(block_thresholds.dtype, eps)
        block_norms = self.norms({s: s.view.data for s in self.specs})
        block_factor = 1 - block_thresholds / (block_norms + eps)
        block_factor.clamp_(min=0.0)
        self.apply_multiplier(block_factor)

    def _soft_threshold_diag_cond(
        self,
        block_thresholds: Tensor,
        conditioners: Values,
        max_iter: int = 20,
        atol: float = 1e-8,
        eps: float | None = None,
    ) -> None:
        """In-place Adam-conditioned proximal step via bisection.

        For each coupled block a scalar ``mu`` is searched by bisection so
        that ``mu * ||H v / (H + mu)||`` matches the block threshold; every
        spec's data is then scaled elementwise by ``H / (H + mu)`` and
        blocks whose ``||H v||`` does not exceed the threshold are zeroed.

        Args:
            block_thresholds: Per-block thresholds, shaped like
                ``self.grid_shape``.
            conditioners: Mapping from each spec to a conditioner tensor
                shaped like that spec's view.
            max_iter: Maximum number of bisection iterations.
            atol: Stop once the widest bracket is narrower than this.
            eps: Numerical-stability constant, resolved through
                ``get_dtype_epsilon``.
        """
        assert tuple(block_thresholds.shape) == self.grid_shape
        assert isinstance(conditioners, Mapping)
        for s in self.specs:
            assert conditioners[s].shape == s.view.shape
        eps = get_dtype_epsilon(block_thresholds.dtype, eps)
        hess_weighted = {s: conditioners[s] * s.view.data for s in self.specs}
        hess_weighted_norms = self.norms(hess_weighted)
        denom = hess_weighted_norms - block_thresholds
        # Blocks with ||H v|| <= threshold are zeroed at the end.
        non_survivors = denom <= 0.0
        denom.clamp_(min=eps)
        h_min = self.min(conditioners)
        h_max = self.max(conditioners)
        # Initial bisection bracket for mu.
        mu_low = ((block_thresholds * h_min) / denom).clamp_(min=0.0)
        mu_high = ((block_thresholds * h_max) / denom).clamp_(min=0.0)
        mu = (mu_low + mu_high) / 2
        for _ in range(max_iter):
            # Compute Zeta(mu): scaling = H_block / (H_block + mu)
            weighted_vs = {
                s: hess_weighted[s]
                / (
                    conditioners[s]
                    + s.broadcast_block_to_element(
                        mu.permute(ro).reshape(s.grid_shape).contiguous()
                    )
                )
                for ro, s in zip(self._reverse_orders, self.specs)
            }
            weighted_norm = self.norms(weighted_vs)
            # zeta = mu * ||weighted_v||
            zeta = mu * weighted_norm
            # Zeta is strictly increasing with mu.
            # If zeta < threshold, mu is too small -> low = mu
            # If zeta > threshold, mu is too big -> high = mu
            mask_low = zeta < block_thresholds
            mu_low = torch.where(mask_low, mu, mu_low)
            mu_high = torch.where(~mask_low, mu, mu_high)
            mu = (mu_low + mu_high) / 2
            if (mu_low - mu_high).abs().max() < atol:
                break
        for ro, s in zip(self._reverse_orders, self.specs):
            s.set_data(
                s.view.data
                * conditioners[s]
                / (
                    conditioners[s]
                    + s.broadcast_block_to_element(
                        mu.permute(ro).reshape(s.grid_shape)
                    )
                )
            )
        # only keep survivors
        self.apply_mask(non_survivors)

    def get_masks(self, block_masks: Tensor) -> Mapping["BlockSpec", Tensor]:
        """Convert block-level mask to element-level masks for each spec."""
        spec_masks = {}
        for ro, s in zip(self._reverse_orders, self.specs):
            # Permute the coupling-grid mask back to this spec's grid.
            m = block_masks.permute(ro).reshape(s.grid_shape).contiguous()
            spec_masks.update(s.get_masks(m))
        return spec_masks

    def apply_mask(self, mask: Tensor):
        """Zero out blocks where mask is True across all specs."""
        self.apply_multiplier(~mask)

    def apply_multiplier(self, multiplier: Tensor):
        """Multiply each block by corresponding scalar across all specs.

        Args:
            multiplier: Tensor with shape ``self.grid_shape``; it is
                permuted back to each spec's own grid before being applied.
        """
        assert (
            tuple(multiplier.shape) == self.grid_shape
        ), "Incompatible Multiplier"
        for ro, s in zip(self._reverse_orders, self.specs):
            m = multiplier.permute(ro).reshape(s.grid_shape)
            s.apply_multiplier(m)

    def __repr__(self) -> str:
        """Return string representation with specs info."""
        specs_str = ",\n\t".join(str(s) for s in self.specs)
        return (
            f"{self.__class__.__name__}"
            f"(grid_shape={self.grid_shape}, "
            f"name={self.name!r}, "
            f"BlockSpecs=[\n\t{specs_str}])"
        )

    def __hash__(self) -> int:
        """Hash based on hashes of all coupled specs."""
        return hash(tuple(hash(s) for s in self.specs))