# Copyright (c) 2026 - Ayoub Ghriss & Contributors
# Licensed under CC BY-NC 4.0
# (see LICENSE or https://creativecommons.org/licenses/by-nc/4.0/)
# Non-commercial use only; contact us for commercial licensing.
"""Scope-level sparsity specification over block grid."""
from dataclasses import dataclass, field
from typing import Optional, Tuple, List, Set, Mapping, Iterable
from functools import cached_property, lru_cache
import math
from torch import Tensor
import torch
from abc import abstractmethod, ABC
from .view import View
from .block import SparseBlock
from .tensor_ops import merge_odd_dims, append_odd_dims
from .tensor_ops import normalize_order
from .tensor_ops import unmerge_odd_dims
from .tensor_ops import inverse_permutation
from .tensor_ops import kth_largest
from .tensor_ops import mid_kth_largest
from .tensor_ops import get_dtype_epsilon
from .block import Values
from .block import CouplingError
class SparseScope(ABC):
    """Abstract base class for scope-level sparsity specifications.

    A scope is a group of blocks within which pruning decisions are made.
    Concrete subclasses describe the scope grid (``grid_shape``), how many
    blocks each scope holds, and how thresholds and masks are applied to the
    underlying :class:`SparseBlock` objects.
    """

    # The size helpers below are intentionally NOT wrapped in
    # ``functools.lru_cache``: a method-level LRU cache is shared by every
    # instance, hashes ``self`` (and therefore its blocks) on each call and
    # holds strong references to up to 128 scopes -- and the parameter views
    # they reference -- after they are otherwise unused. The computations
    # themselves are small products/sums over short tuples.

    def numscp(self) -> int:
        """Number of scopes, i.e. the product of ``grid_shape``."""
        return math.prod(self.grid_shape)

    def numblk(self) -> int:
        """Total number of blocks (number of scopes x blocks per scope)."""
        return self.numscp() * self.numblk_per_scope()

    @abstractmethod
    def numblk_per_scope(self) -> int:
        """Number of blocks per scope."""

    @abstractmethod
    def numel_per_scope(self) -> int:
        """Number of elements per scope."""

    @abstractmethod
    def nnz(self, eps=1e-8) -> int:
        """Non-zero count reported by the underlying blocks.

        Args:
            eps: Tolerance forwarded to the blocks' own ``nnz``.
        """

    @cached_property
    @abstractmethod
    def grid_shape(self) -> Tuple[int, ...]:
        """Shape of the scope grid (number of scopes per dimension)."""

    @abstractmethod
    def blocks(self) -> Iterable[SparseBlock]:
        """All :class:`SparseBlock` instances covered by this scope."""

    @abstractmethod
    @torch.no_grad()
    def hard_threshold(
        self,
        thresholds: Optional[Tensor] = None,
        nnz: Optional[int] = None,
        values: Values = None,
    ):
        """Zero out blocks in-place based on scope-level thresholds.

        Args:
            thresholds: Per-scope thresholds with shape ``grid_shape``.
            nnz: Number of blocks to keep per scope; used to derive the
                thresholds when ``thresholds`` is not given.
            values: Element values used to score blocks (None uses the
                parameter data).
        """

    @abstractmethod
    @torch.no_grad()
    def soft_threshold(
        self,
        thresholds: Tensor,
        conditioners: Values = None,
        scale: bool = False,
        max_iter: int = 20,
        atol: float = 1e-8,
        eps: float | None = None,
    ) -> None:
        """Apply soft thresholding to the scope's blocks in-place.

        Args:
            thresholds: Per-scope thresholds with shape ``grid_shape``.
            conditioners: Optional per-block preconditioners.
            scale: Whether implementations should rescale thresholds by
                block size.
            max_iter: Iteration cap forwarded to the block implementation.
            atol: Convergence tolerance forwarded to the blocks.
            eps: Numerical-stability constant (None picks a default).
        """

    @abstractmethod
    @torch.no_grad()
    def get_masks(
        self,
        nnz: int,
        values: Values = None,
        block_scores: Tensor | None = None,
        block_mask: Tensor | None = None,
        **kwargs,
    ) -> Mapping[SparseBlock, Tensor]:
        """Compute element-level boolean masks keeping ``nnz`` blocks/scope.

        Args:
            nnz: Number of blocks to keep per scope.
            values: Element values used to score blocks.
            block_scores: Optional pre-computed block scores.
            block_mask: Optional pre-computed block-level mask.

        Returns:
            Mapping from each SparseBlock to its element-level mask.
        """

    def sparsity_to_nnz(self, sparsity: float) -> int:
        """Convert a sparsity ratio to number of non-zero blocks per scope.

        Args:
            sparsity: Fraction of blocks to prune (0.0 = keep all,
                1.0 = prune all).

        Returns:
            Number of blocks to keep per scope, clamped to
            ``[0, numblk_per_scope()]``.
        """
        total = self.numblk_per_scope()
        # Round away float noise before truncating: e.g. 0.29 * 100
        # evaluates to 28.999999999999996, which ``int`` alone turns into 28
        # (keeping one block too many).
        pruned = int(round(sparsity * total, 9))
        return min(max(total - pruned, 0), total)

    @abstractmethod
    def __repr__(self) -> str:
        """Debug representation; subclasses must provide one."""

    def __str__(self) -> str:
        return repr(self)

    @abstractmethod
    def __hash__(self) -> int:
        """Hash; subclasses must provide one."""
@dataclass
class ScopeSpec(SparseScope):
    """Organizes blocks from a SparseNode into scopes (decision units).

    Divides the block grid into scopes of ``shape`` blocks.
    Pruning decisions (hard/soft threshold, mask selection) operate at the
    scope level: within each scope, blocks compete to survive.

    Args:
        block: SparseBlock or BlockCoupling that defines the block grid.
            Must have ``grid_shape`` divisible by ``shape``.
        shape: Number of blocks of a scope in each dimension.
            Use -1 to span the entire dimension. Missing trailing
            dimensions are padded with -1; an empty (or None) shape makes a
            single scope span the whole grid.
        name: Optional name for identification.

    Raises:
        ValueError: If ``shape`` has more dimensions than the block grid,
            contains an entry that is neither positive nor -1, or does not
            evenly divide the block grid.
    """

    block: SparseBlock
    shape: Tuple[int, ...]
    name: Optional[str] = None

    def __post_init__(self):
        grid = tuple(self.block.grid_shape)
        ndim = len(grid)
        # Normalize to a tuple first: a list shape would otherwise fail the
        # tuple concatenation used for padding below.
        shape = tuple(self.shape) if self.shape else ()
        # Pad with -1 for missing trailing dimensions
        shape = shape + (-1,) * (ndim - len(shape))
        if len(shape) != ndim:
            raise ValueError(
                f"scope shape {shape} has len {len(shape)} "
                f"but block grid_shape {grid} has {ndim} dims"
            )
        resolved = []
        for i, (grid_dim, gi) in enumerate(zip(grid, shape)):
            if gi == -1:
                gi = grid_dim  # -1 spans the whole dimension
            elif gi <= 0:
                # Reject explicitly; 0 would otherwise surface as a
                # ZeroDivisionError in the modulo check below.
                raise ValueError(
                    f"dim {i}: scope shape entries must be positive or -1, "
                    f"got {gi}"
                )
            if grid_dim % gi != 0:
                raise ValueError(
                    f"dim {i}: block_grid[{i}]={grid_dim} "
                    f"not divisible by "
                    f"scope_shape[{i}]={gi}"
                )
            resolved.append(gi)
        self.shape = tuple(resolved)

    @cached_property
    def grid_shape(self) -> Tuple[int, ...]:
        """Scope-grid shape: block-grid size divided by ``shape`` per dim."""
        return tuple(
            block_idx // gi
            for block_idx, gi in zip(self.block.grid_shape, self.shape)
        )

    def numel_per_scope(self) -> int:
        """Size of one scope as counted by this spec.

        NOTE(review): this returns the number of *blocks* per scope (same
        value as ``numblk_per_scope``), although the base class documents
        ``numel_per_scope`` as the number of elements. If elements are
        intended it would be ``math.prod(self.shape) * self.block.numel()``
        -- confirm against callers before changing.
        """
        return math.prod(self.shape)

    def numblk_per_scope(self) -> int:
        """Number of blocks per scope (product of ``shape``)."""
        return math.prod(self.shape)

    def blocks(self) -> Iterable[SparseBlock]:
        """The single underlying block, as a one-element list."""
        return [self.block]

    def nnz(self, eps=1e-8) -> int:
        """Non-zero count of the underlying block (``block.nnz``)."""
        return self.block.nnz(eps=eps)

    def block_to_scope(
        self, b: Tensor, reorder: bool = True, merge: bool = False
    ) -> Tensor:
        """Reshape a block-grid tensor into scope layout.

        The grid dims are first split into interleaved pairs
        ``(G1, g1, G2, g2, ...)``, where ``Gi`` is the scope-grid size and
        ``gi`` the scope shape along dim ``i``.

        Args:
            b: Tensor whose leading dims equal ``block.grid_shape``. The
                view covers only the grid dims, so ``b`` must in practice
                have exactly ``block.grid_shape``.
            reorder: If True, apply ``append_odd_dims`` to the interleaved
                view (presumably moving the ``gi`` dims after the ``Gi``
                dims -- confirm in tensor_ops).
            merge: If True, additionally apply ``merge_odd_dims`` (implies
                reorder).

        Returns:
            The interleaved view when both flags are False, otherwise the
            result of the helper(s) above.
        """
        # NOTE(review): ``block_norms`` calls ``merge_odd_dims`` directly on
        # the interleaved (reorder=False) layout, whereas ``merge=True`` here
        # applies ``append_odd_dims`` first; confirm both are intended.
        assert b.shape[: len(self.block.grid_shape)] == self.block.grid_shape
        inter_shape = [
            grid_idx
            for pair in zip(self.grid_shape, self.shape)
            for grid_idx in pair
        ]
        view = b.view(*inter_shape)
        if reorder or merge:
            view = append_odd_dims(view)
        if merge:
            view = merge_odd_dims(view)
        return view

    def scope_to_block(self, block_values: Tensor) -> Tensor:
        """Broadcast scope-level values back to the block grid.

        Args:
            block_values: Tensor with shape ``(G1, ..., Gm)``.

        Returns:
            Tensor with shape ``(B1, ..., Bm)`` where ``Bi = Gi * gi``; every
            block of a scope receives that scope's value.
        """
        assert tuple(block_values.shape) == self.grid_shape
        inter_values = block_values.view(self.grid_shape)
        for i, gi in enumerate(self.shape):  # type: ignore
            # Insert a size-1 axis after scope dim i and repeat it gi times,
            # giving the interleaved (G1, g1, G2, g2, ...) layout.
            inter_values = inter_values.unsqueeze(2 * i + 1)
            inter_values = inter_values.repeat_interleave(gi, dim=2 * i + 1)
        inter_values = inter_values.view(self.block.grid_shape)
        return inter_values

    def block_norms(self, values: Values) -> Tensor:
        """Compute block norms arranged in scope layout.

        Args:
            values: Element values to compute norms from (None uses param
                data).

        Returns:
            Output of ``merge_odd_dims`` on the interleaved block norms;
            callers treat it as shape ``(*grid_shape, numblk_per_scope())``.
        """
        block_norms = self.block.norms(values)
        block_norms = self.block_to_scope(block_norms, reorder=False)
        merged = merge_odd_dims(block_norms)
        return merged

    def kth_largest(self, element_values: Values, nnz: int) -> Tensor:
        """Per-scope k-th largest block norm, with shape ``grid_shape``.

        Used as the hard-threshold value that keeps ``nnz`` blocks per scope.
        """
        block_scores = self.block_norms(element_values)
        top_scores = kth_largest(block_scores, k=nnz, dim=-1)
        top_scores = top_scores.view(self.grid_shape)
        return top_scores

    def kth_mid(self, element_values: Values, nnz: int, k_weight=1.0) -> Tensor:
        """Per-scope threshold from ``mid_kth_largest``, shape ``grid_shape``.

        Presumably a value between the k-th and (k+1)-th largest block norm,
        weighted by ``k_weight`` -- confirm in tensor_ops.
        """
        block_scores = self.block_norms(element_values)
        top_scores = mid_kth_largest(
            block_scores, k=nnz, dim=-1, k_weight=k_weight
        )
        top_scores = top_scores.view(self.grid_shape)
        return top_scores

    @torch.no_grad()
    def hard_threshold(
        self,
        thresholds: Optional[Tensor] = None,
        nnz: Optional[int] = None,
        values: Values = None,
    ):
        """Zero out blocks in-place based on thresholds.

        One of ``thresholds`` or ``nnz`` must be given; if both are,
        ``thresholds`` takes precedence and ``nnz`` is ignored.

        Args:
            thresholds: Per-scope thresholds, shape ``grid_shape``; each is
                broadcast to every block of its scope.
            nnz: Number of blocks to keep per scope. A value >= the number
                of blocks per scope keeps everything (no-op).
            values: Element values for computing norms (None uses param
                data); also forwarded to ``block.hard_threshold``.

        Raises:
            ValueError: If neither ``thresholds`` nor ``nnz`` is given.
        """
        if thresholds is None:
            if nnz is None:
                raise ValueError("One of {thresholds, nnz} must be provided")
            if nnz >= self.numblk_per_scope():
                # Keeping every block is a no-op; also avoids asking
                # kth_largest for k beyond the scope size.
                return
            thresholds = self.kth_largest(values, nnz=nnz)
        assert thresholds.shape == self.grid_shape
        block_thresholds = self.scope_to_block(thresholds)
        self.block.hard_threshold(block_thresholds, values=values)

    @torch.no_grad()
    def get_masks(
        self,
        nnz: int,
        values: Values = None,
        block_scores: Tensor | None = None,
        block_mask: Tensor | None = None,
        **kwargs,
    ) -> Mapping[SparseBlock, Tensor]:
        """Compute element-level boolean masks from scope-level scores.

        Args:
            nnz: Number of blocks to keep per scope. Blocks tied with the
                k-th largest score are all kept; ``nnz`` >= the scope size
                keeps every block.
            values: Element values for computing block norms (if scores
                are not given).
            block_scores: Pre-computed scores with shape
                ``(*grid_shape, numblk_per_scope())``.
            block_mask: Pre-computed boolean mask (same shape as
                ``block_scores``) to use directly; ``nnz`` is then unused.
            **kwargs: Ignored; accepted for interface compatibility.

        Returns:
            Dict mapping each SparseBlock to its element-level boolean mask.
        """
        numblk = self.numblk_per_scope()
        merged_shape = self.grid_shape + (numblk,)
        if block_mask is None:
            if block_scores is None:
                block_scores = self.block_norms(values)
            else:
                assert block_scores.shape == merged_shape
            if nnz >= numblk:
                block_mask = torch.ones_like(block_scores, dtype=torch.bool)
            else:
                thresholds = kth_largest(block_scores, k=nnz, dim=-1)
                block_mask = block_scores >= thresholds.unsqueeze(-1)
        # reshape (not view): a mask handed in by ScopeCoupling is a permuted
        # slice and need not be contiguous.
        block_mask = unmerge_odd_dims(
            block_mask.reshape(merged_shape),
            self.shape,
        )
        block_mask = block_mask.reshape(self.block.grid_shape)
        return self.block.get_masks(block_mask)

    @torch.no_grad()
    def soft_threshold(
        self,
        thresholds: Tensor,
        conditioners: Values = None,
        scale: bool = False,
        max_iter: int = 20,
        atol: float = 1e-8,
        eps: float | None = None,
    ) -> None:
        """Apply soft thresholding (L1 proximal operator) to scopes in-place.

        Args:
            thresholds: Per-scope threshold values with shape
                ``grid_shape``; broadcast to every block of the scope.
            conditioners: Diagonal preconditioner per SparseBlock.
            scale: If True, scale thresholds by ``sqrt(block.numel())``.
            max_iter: Forwarded to ``block.soft_threshold``.
            atol: Absolute tolerance, forwarded to the block.
            eps: Numerical-stability constant; resolved against the
                thresholds' dtype by ``get_dtype_epsilon``.
        """
        assert tuple(thresholds.shape) == self.grid_shape
        eps = get_dtype_epsilon(thresholds.dtype, eps)
        block_thresholds = self.scope_to_block(thresholds)
        if scale:
            block_thresholds = block_thresholds * (self.block.numel() ** 0.5)
        self.block.soft_threshold(
            block_thresholds,
            conditioners=conditioners,
            max_iter=max_iter,
            atol=atol,
            eps=eps,
        )

    def __repr__(self):
        return (
            f"{self.__class__.__name__}[block_shape={self.shape}, "
            f"grid_shape={self.grid_shape}, "
            f"name={self.name}], "
            f"block={self.block}"
        )

    def __hash__(self) -> int:
        return hash((hash(self.block), self.shape))
@dataclass
class ScopeCoupling(SparseScope):
    """Couples multiple ScopeSpec instances for joint pruning.

    Aligns scope grids from different parameters via dimension permutations
    (``orders``) so they share a common grid shape.
    Within each aligned scope, blocks from all specs compete to survive pruning.

    Args:
        scopes: List of ScopeSpec instances to couple (at least one).
        orders: Dimension permutations to align each scope's grid.
            Identity permutation if omitted or empty.
        name: Optional name for identification.

    Raises:
        ValueError: If ``scopes`` is empty or ``orders`` has a different
            length than ``scopes``.
        CouplingError: If the permuted scope grids do not all match.
    """

    scopes: List[ScopeSpec]
    orders: List[Tuple[int, ...]] = field(default_factory=list)
    name: Optional[str] = None
    _ref_order: Tuple[int, ...] = field(init=False)
    _ref_scope_grid_shape: Tuple[int, ...] = field(init=False)
    _reverse_orders: List[Tuple[int, ...]] = field(init=False)

    def __post_init__(self):
        if not self.scopes:
            raise ValueError("ScopeCoupling requires at least one scope.")
        if not self.orders:
            self.orders = [tuple(range(len(g.grid_shape))) for g in self.scopes]
        if len(self.orders) != len(self.scopes):
            raise ValueError("orders must match number of specs.")
        self.orders = [
            normalize_order(o, len(g.grid_shape))
            for o, g in zip(self.orders, self.scopes)
        ]
        # The first scope's permuted grid is the reference all others match.
        self._ref_order = self.orders[0]
        ref_permute = tuple(
            self.scopes[0].grid_shape[i] for i in self._ref_order
        )
        self._ref_scope_grid_shape = ref_permute
        self._reverse_orders = []
        for g, o in zip(self.scopes, self.orders):
            gperm = tuple(g.grid_shape[i] for i in o)
            if gperm != ref_permute:
                s_name = g.name or "<unnamed>"
                raise CouplingError(
                    "Incompatible scope shapes "
                    f"after order: {gperm} vs {ref_permute} (spec {s_name})"
                )
            # Stored to map aligned tensors back to each scope's own order.
            self._reverse_orders.append(inverse_permutation(o))

    # ── Properties ────────────────────────────────────────────────────

    def numblk(self) -> int:
        """Total number of blocks across all coupled scopes."""
        return sum(s.numblk() for s in self.scopes)

    @property
    def params(self) -> Set[View]:
        """Expose underlying views for optimizer integration."""
        return {p for g in self.scopes for p in g.block.parameters()}

    def blocks(self) -> Iterable[SparseBlock]:
        """All SparseBlock instances across all child scopes."""
        merged = []
        for sc in self.scopes:
            merged.extend(sc.blocks())
        return merged

    @cached_property
    def grid_shape(self) -> Tuple[int, ...]:
        """Reference scope grid shape (after order permutation)."""
        return self._ref_scope_grid_shape

    def numblk_per_scope(self) -> int:
        """Total blocks per aligned scope across all coupled scopes."""
        return sum(s.numblk_per_scope() for s in self.scopes)

    def numel_per_scope(self) -> int:
        """Sum of the child scopes' ``numel_per_scope``."""
        return sum(s.numel_per_scope() for s in self.scopes)

    def nnz(self, eps=1e-8) -> int:
        """Sum of the child scopes' non-zero counts."""
        return sum(g.nnz(eps=eps) for g in self.scopes)

    def specs(self) -> Iterable[SparseBlock]:
        """All SparseBlock instances across all child scopes."""
        return [s for g in self.scopes for s in g.blocks()]

    def block_norms(self, values: Values) -> Tensor:
        """Block norms of all scopes, aligned and concatenated.

        Returns:
            Tensor with shape ``(*grid_shape, numblk_per_scope())``; the last
            dim concatenates each child's per-scope blocks in ``scopes``
            order.
        """
        blocked_block_norms = torch.cat(
            [
                # Permute the scope-grid dims into the reference order; the
                # trailing per-scope block dim stays last.
                g.block_norms(values).permute((*o, len(o)))
                for o, g in zip(self.orders, self.scopes)
            ],
            dim=-1,
        )
        assert blocked_block_norms.shape[:-1] == self.grid_shape
        return blocked_block_norms

    def kth_largest(
        self,
        k: int,
        values: Values,
    ) -> Tensor:
        """k-th largest block norm per aligned scope, across all specs.

        This is used to determine the threshold for pruning.
        """
        blocked_scores = self.block_norms(values)
        return kth_largest(blocked_scores, k=k, dim=-1)

    @torch.no_grad()
    def hard_threshold(
        self,
        thresholds: Optional[Tensor] = None,
        nnz: Optional[int] = None,
        values: Values = None,
    ):
        """Hard-threshold all coupled scopes in-place with shared thresholds.

        When ``thresholds`` is None, it is the ``nnz``-th largest block norm
        among the coupled scopes from all specs. The thresholds are then sent
        to each spec, together with the same ``values``, so blocks are
        compared against thresholds derived from the same scores. Because the
        threshold is shared across coupled scopes, some parameters may be
        pruned more than others (expected).

        Args:
            thresholds: Per-scope thresholds with shape ``grid_shape``.
            nnz: Blocks to keep per aligned scope; >= ``numblk_per_scope()``
                is a no-op.
            values: Element values for scoring (None uses param data).

        Raises:
            ValueError: If neither ``thresholds`` nor ``nnz`` is given.
        """
        if thresholds is None:
            if nnz is None:
                raise ValueError("Either thresholds or nnz")
            if nnz >= self.numblk_per_scope():
                # Keeping every block is a no-op (mirrors ScopeSpec).
                return
            thresholds = self.kth_largest(k=nnz, values=values)
        assert thresholds.shape == self.grid_shape
        for ro, s in zip(self._reverse_orders, self.scopes):
            s.hard_threshold(
                thresholds=thresholds.permute(ro).reshape(s.grid_shape),
                values=values,
            )

    @torch.no_grad()
    def soft_threshold(
        self,
        thresholds: Tensor,
        conditioners: Values = None,
        scale: bool = False,
        max_iter: int = 20,
        atol: float = 1e-8,
        eps: float | None = None,
    ) -> None:
        """Performs soft thresholding on all coupled parameters.

        ``thresholds`` (shape ``grid_shape``) is permuted back into each
        child's grid order and forwarded with the remaining arguments.
        """
        assert thresholds.shape == self.grid_shape
        for ro, g in zip(self._reverse_orders, self.scopes):
            g.soft_threshold(
                thresholds.permute(ro).reshape(g.grid_shape),
                conditioners=conditioners,
                scale=scale,
                max_iter=max_iter,
                atol=atol,
                eps=eps,
            )

    @torch.no_grad()
    def get_masks(
        self,
        nnz: int,
        block_scores: Tensor | None = None,
        values: Values = None,
        block_mask: Tensor | None = None,
        **kwargs,
    ) -> Mapping[SparseBlock, Tensor]:
        """Compute element-level masks, selecting blocks jointly per scope.

        Within each aligned scope, the ``nnz`` highest-scoring blocks across
        all coupled specs are kept (exactly ``min(nnz, numblk_per_scope())``;
        ties are resolved by ``torch.topk``). The joint mask is then split
        back per spec.

        NOTE: the positional order here is ``(nnz, block_scores, values)``,
        unlike ``SparseScope.get_masks`` ``(nnz, values, ...)``; pass these
        by keyword when calling polymorphically.

        Args:
            nnz: Number of blocks to keep per aligned scope.
            block_scores: Pre-computed scores with shape
                ``(*grid_shape, numblk_per_scope())``.
            values: Element values for computing block norms (if scores are
                not given).
            block_mask: Pre-computed boolean mask with the same shape as
                ``block_scores``; when given, scoring is skipped.
            **kwargs: Ignored; accepted for interface compatibility.

        Returns:
            Dict mapping each SparseBlock to its element-level boolean mask.
        """
        numblk = self.numblk_per_scope()
        merged_shape = self.grid_shape + (numblk,)
        if block_mask is None:
            if block_scores is None:
                block_scores = self.block_norms(values)
            else:
                assert block_scores.shape == merged_shape
            # topk raises for k > dim size; keeping "more than all" == all.
            k = min(nnz, numblk)
            indices = torch.topk(block_scores, k=k, dim=-1).indices
            block_mask = torch.zeros_like(block_scores, dtype=torch.bool)
            block_mask.scatter_(-1, indices, True)
        else:
            assert block_mask.shape == merged_shape
        spec_masks = {}
        slice_start = 0
        for ro, g in zip(self._reverse_orders, self.scopes):
            width = g.numblk_per_scope()
            block_slice = block_mask[..., slice_start : slice_start + width]
            # Undo the alignment permutation; the block dim stays last.
            spec_block_mask = block_slice.permute((*ro, len(ro)))
            spec_masks.update(g.get_masks(nnz=0, block_mask=spec_block_mask))
            slice_start += width
        return spec_masks

    def __hash__(self):
        return hash(tuple(hash(g) for g in self.scopes))

    def __repr__(self):
        parts = ", ".join(str(s) for s in self.scopes)
        return f"ScopeCoupling(orders={self.orders}, {parts})"