# Copyright (c) 2026 - Ayoub Ghriss & Contributors
# Licensed under CC BY-NC 4.0
# (see LICENSE or https://creativecommons.org/licenses/by-nc/4.0/)
# Non-commercial use only; contact us for commercial licensing.
"""Structured OBS (Optimal Brain Surgeon) via BlockSpec/ScopeSpec.
Compensation modes: ``local``, ``full``, ``split``, ``interleaved``.
"""
import math
from itertools import combinations
import torch
import torch.linalg as LA
from torch import Tensor
from ..view import View
from ..block import BlockSpec
from ..scope import ScopeSpec
def _block_flat_offsets(
    block: BlockSpec,
    device: torch.device = torch.device("cpu"),
) -> Tensor:
    """Return the flat storage offset of each element of each block.

    The multi-dimensional index of an element is built as
    ``grid_index * block_shape + index_inside_block`` and then turned
    into a flat offset, either through ``linear_offset`` when
    ``block.view`` is a ``View`` or with the strides of
    ``block.view.data`` otherwise.

    Args:
        block: Block specification providing ``grid_shape``, ``shape``
            and ``view``.
        device: Device on which the index tensors are created.

    Returns:
        flat_offsets: ``(*grid_shape, block_numel)`` long tensor.
    """

    def _index_lattice(extents) -> Tensor:
        # Every integer coordinate of a box with the given extents; the
        # coordinate components are stacked on the last axis.
        axes = [torch.arange(extent, device=device) for extent in extents]
        return torch.stack(torch.meshgrid(*axes, indexing="ij"), dim=-1)

    blk_dims = block.shape
    n_axes = len(blk_dims)
    # (block_numel, rank): position of each element inside one block.
    inner = _index_lattice(blk_dims).reshape(-1, n_axes)
    # (*grid_shape, rank): position of each block inside the grid.
    outer = _index_lattice(block.grid_shape)
    blk_extent = torch.tensor(blk_dims, device=device)
    # Broadcasts to (*grid_shape, block_numel, rank).
    coords = outer.unsqueeze(-2) * blk_extent + inner

    target = block.view
    if not isinstance(target, View):
        stride_vec = torch.tensor(
            target.data.stride(), device=device, dtype=torch.long
        )
        return (coords * stride_vec).sum(dim=-1)
    return target.linear_offset(coords)
[docs]
def block_col_indices(
    block: BlockSpec,
    num_cols: int,
    device: torch.device = torch.device("cpu"),
) -> Tensor:
    """Return the parameter column of every element of each column block.

    Columns are the flat storage offsets modulo ``num_cols``. Only the
    first entry along the leading grid axis is returned, so the result
    has no row-grid dimension.

    Args:
        block: Block specification to index.
        num_cols: Number of columns of the flattened parameter.
        device: Device on which the index tensors are created.

    Returns:
        col_idx: ``(*grid_shape[1:], block_numel)`` long tensor.
    """
    grid = block.grid_shape
    # Number of leading grid axes that are dropped by ``grid[1:]``.
    n_leading = len(grid) - len(grid[1:])
    columns = _block_flat_offsets(block, device) % num_cols
    return columns[(0,) * n_leading]
[docs]
def block_param_rc(
    block: BlockSpec,
    num_cols: int,
    device: torch.device = torch.device("cpu"),
) -> tuple[Tensor, Tensor]:
    """Return ``(param_row, param_col)`` for every element of every block.

    In contrast to :func:`block_col_indices`, the complete grid is kept
    (row dimension included) and both coordinates are returned. They are
    the quotient and remainder of the flat storage offset by
    ``num_cols``.

    Args:
        block: Block specification to index.
        num_cols: Number of columns of the flattened parameter.
        device: Device on which the index tensors are created.

    Returns:
        row_idx: ``(*grid_shape, block_numel)`` long tensor.
        col_idx: ``(*grid_shape, block_numel)`` long tensor.
    """
    offsets = _block_flat_offsets(block, device)
    param_rows = offsets // num_cols
    param_cols = offsets % num_cols
    return param_rows, param_cols
[docs]
class StructuredOBS:
"""Structured OBS pruner operating through ScopeSpec.
Args:
scope: ScopeSpec defining the block and scope
structure.
hessian: (K, K) Hessian matrix.
damp: Damping factor for hessian regularization.
inv_h: Precomputed (K, K) damped inverse. Skips
inversion if given.
"""
def __init__(
    self,
    scope: ScopeSpec,
    hessian: Tensor,
    damp: float = 1e-4,
    inv_h: Tensor | None = None,
):
    """Precompute block/scope index maps and the damped inverse Hessian.

    Args:
        scope: ScopeSpec whose ``block`` is the BlockSpec to prune.
        hessian: (K, K) Hessian matrix; ``K`` is used as the number of
            parameter columns.
        damp: Damping factor forwarded to :meth:`compute_inverse`.
        inv_h: Precomputed (K, K) damped inverse. When given it is
            stored as-is and no inversion is performed.

    Raises:
        TypeError: If ``scope.block`` is not a ``BlockSpec``.
    """
    block = scope.block
    # Explicit check rather than ``assert`` so it survives ``python -O``.
    if not isinstance(block, BlockSpec):
        raise TypeError(
            f"scope.block must be a BlockSpec, got {type(block).__name__}"
        )
    self.scope = scope
    self.block = block
    self.hessian = hessian
    self.damp = damp
    self.num_cols = hessian.shape[0]
    device = hessian.device
    self.W = block.view.param
    self.M = self.W.shape[0]
    self.bk = block.numel_per_block()

    grid_shape = block.grid_shape
    scope_shape = scope.shape
    scope_grid = scope.grid_shape
    self.rows_per_scope_row = block.shape[0] * scope_shape[0]
    self.num_scope_rows = scope_grid[0]
    col_grid = grid_shape[1:]
    self.total_blocks = math.prod(col_grid)
    self.blocks_per_scope = (
        math.prod(scope_shape[1:]) if len(scope_shape) > 1 else 1
    )
    self.num_scopes_per_row = (
        math.prod(scope_grid[1:]) if len(scope_grid) > 1 else 1
    )

    # Block column indices: (total_blocks, bk)
    col_idx_full = block_col_indices(block, self.num_cols, device=device)
    self.col_idx = col_idx_full.reshape(self.total_blocks, self.bk)
    if self.bk == 1:
        # Only defined for single-element blocks; readers guard on bk.
        self.col_idx_flat = self.col_idx.squeeze(-1)

    # Block -> scope mapping along the column grid.
    self.block_to_scope, _ = self._build_block_mapping(
        col_grid, scope_shape[1:], scope_grid[1:], device
    )

    # Row coupling detection: do blocks within the same scope span
    # different param rows? (e.g. block-16 with 8-row coupling)
    row_full, _ = block_param_rc(block, self.num_cols, device=device)
    row_slice = tuple(0 for _ in range(len(grid_shape) - len(col_grid)))
    row_cg = row_full[row_slice].reshape(self.total_blocks, self.bk)
    # Sorting by scope id groups the blocks of each scope together.
    sort_perm = torch.argsort(self.block_to_scope)
    gs = self.blocks_per_scope
    num_parts = self.num_scopes_per_row
    sorted_row_base = row_cg[sort_perm].view(num_parts, gs, self.bk)
    # Param row of the first element of each block, per scope.
    block_rows = sorted_row_base[:, :, 0]
    self.row_coupled = bool(
        (block_rows.max(1).values - block_rows.min(1).values > 0).any()
    )
    if self.row_coupled:
        # Extra maps that only the row-coupled pruning path reads.
        self.block_sort_perm = sort_perm
        self.full_row_idx = row_full
        self.block_col_map = self.col_idx[sort_perm].view(
            num_parts, gs, self.bk
        )

    if inv_h is not None:
        self.inv_hessian = inv_h
    else:
        self.inv_hessian = self.compute_inverse(hessian, damp)
[docs]
@staticmethod
def compute_inverse(hessian: Tensor, damp: float = 1e-4) -> Tensor:
"""Compute damped inverse of the Hessian matrix.
Args:
hessian: (K, K) symmetric positive semi-definite
Hessian.
damp: Damping factor. If < 1.0, scaled by mean
diagonal; otherwise used as absolute value.
Returns:
(K, K) inverse of the damped Hessian:
``(H + damp * I)^{-1}``.
"""
num_cols = hessian.shape[0]
device = hessian.device
damp_val = (
damp * torch.mean(torch.diag(hessian)) if damp < 1.0 else damp
)
hessian_reg = hessian.clone()
diag_idx = torch.arange(num_cols, device=device)
hessian_reg[diag_idx, diag_idx] += damp_val
return LA.inv(hessian_reg) # pylint: disable=not-callable
@staticmethod
def _build_block_mapping(
col_grid, block_shape_rest, block_grid_rest, device
):
"""Build mapping from block index to scope index."""
total_blocks = math.prod(col_grid)
if len(col_grid) == 0:
return (
torch.zeros(1, dtype=torch.long, device=device),
torch.zeros(1, dtype=torch.long, device=device),
)
ranges = [torch.arange(g, device=device) for g in col_grid]
bc_grid = torch.stack(
torch.meshgrid(*ranges, indexing="ij"), dim=-1
).reshape(total_blocks, -1)
gs_rest = list(block_shape_rest)
gg_rest = list(block_grid_rest)
block_pos = bc_grid // torch.tensor(gs_rest, device=device)
gg_strides = []
s = 1
for g in reversed(gg_rest):
gg_strides.append(s)
s *= g
gg_strides.reverse()
block_to_scope = (
block_pos * torch.tensor(gg_strides, device=device)
).sum(dim=-1)
within_pos = bc_grid % torch.tensor(gs_rest, device=device)
gs_strides = []
s = 1
for g in reversed(gs_rest):
gs_strides.append(s)
s *= g
gs_strides.reverse()
block_within_scope = (
within_pos * torch.tensor(gs_strides, device=device)
).sum(dim=-1)
return block_to_scope, block_within_scope
def _find_best_subsets( # pylint: disable=too-many-locals
self,
W,
block_cols,
prune_subsets,
keep_subsets, # pylint: disable=unused-argument
block_size,
inv_h=None,
):
"""Phase 1: find optimal pruning subset per (row, block).
Args:
inv_h: Optional inverse matrix. Uses
self.inv_hessian if not provided. When provided,
block_cols must be in inv_h's coordinate space.
Returns:
best_si: ``(M, G_input)`` long tensor -- index into
prune_subsets.
"""
device = W.device
M, num_cols = W.shape
gs = self.blocks_per_scope
bk = self.bk
if inv_h is None:
inv_h = self.inv_hessian
num_prune = prune_subsets.shape[1]
n_subsets = prune_subsets.shape[0]
elem_per_block = gs * bk if bk > 1 else gs
num_parts = block_cols.shape[0]
if bk == 1:
block_start = block_cols[:, 0]
else:
block_start = block_cols[:, 0, 0]
best_si_full = torch.zeros(
M, num_parts, dtype=torch.long, device=device
)
for b_start in range(0, num_cols, block_size):
b_end = min(b_start + block_size, num_cols)
in_block = (block_start >= b_start) & (block_start < b_end)
gidx = torch.where(in_block)[0]
num_batch_parts = gidx.shape[0]
if num_batch_parts == 0:
continue
if bk == 1:
cols = block_cols[gidx]
flat = cols.view(-1)
else:
cols = block_cols[gidx]
flat = cols.view(num_batch_parts, elem_per_block).view(-1)
if bk == 1:
inv_sub = inv_h[
cols.unsqueeze(-1),
cols.unsqueeze(-2),
]
else:
fc = cols.view(num_batch_parts, elem_per_block)
inv_sub = inv_h[fc.unsqueeze(-1), fc.unsqueeze(-2)]
weight_block = W[:, flat].view(M, num_batch_parts, elem_per_block)
eye_np = 1e-8 * torch.eye(num_prune * bk, device=device)
all_costs = torch.empty(
n_subsets,
M,
num_batch_parts,
device=device,
)
for si in range(n_subsets):
pidx = prune_subsets[si]
if bk == 1:
fp = pidx
else:
fp = (
pidx.unsqueeze(-1) * bk
+ torch.arange(bk, device=device)
).view(-1)
inv_pruned = inv_sub[:, fp][:, :, fp]
inv_pruned_inv = LA.inv( # pylint: disable=not-callable
inv_pruned + eye_np
)
weight_pruned = weight_block[:, :, fp]
temp = torch.einsum(
"mgp,gpq->mgq",
weight_pruned,
inv_pruned_inv,
)
all_costs[si] = (temp * weight_pruned).sum(dim=2)
best_si_full[:, gidx] = all_costs.argmin(dim=0)
return best_si_full
def _compensate_local(
self,
W,
block_cols,
prune_subsets,
keep_subsets,
best_si,
block_size,
):
"""Within-block compensation only."""
device = W.device
M, num_cols = W.shape
gs = self.blocks_per_scope
bk = self.bk
inv_h = self.inv_hessian
num_prune = prune_subsets.shape[1]
n_subsets = prune_subsets.shape[0]
elem_per_block = gs * bk if bk > 1 else gs
if bk == 1:
block_start = block_cols[:, 0]
else:
block_start = block_cols[:, 0, 0]
for b_start in range(0, num_cols, block_size):
b_end = min(b_start + block_size, num_cols)
in_block = (block_start >= b_start) & (block_start < b_end)
gidx = torch.where(in_block)[0]
num_batch_parts = gidx.shape[0]
if num_batch_parts == 0:
continue
if bk == 1:
cols = block_cols[gidx]
flat = cols.view(-1)
else:
cols = block_cols[gidx]
flat = cols.view(num_batch_parts, elem_per_block).view(-1)
if bk == 1:
inv_sub = inv_h[
cols.unsqueeze(-1),
cols.unsqueeze(-2),
]
else:
fc = cols.view(num_batch_parts, elem_per_block)
inv_sub = inv_h[fc.unsqueeze(-1), fc.unsqueeze(-2)]
weight_block = W[:, flat].view(M, num_batch_parts, elem_per_block)
weight_block_new = weight_block.clone()
eye_np = 1e-8 * torch.eye(num_prune * bk, device=device)
block_best_si = best_si[:, gidx]
for si in range(n_subsets):
pidx = prune_subsets[si]
kidx = keep_subsets[si]
if bk == 1:
fp, fk = pidx, kidx
else:
fp = (
pidx.unsqueeze(-1) * bk
+ torch.arange(bk, device=device)
).view(-1)
fk = (
kidx.unsqueeze(-1) * bk
+ torch.arange(bk, device=device)
).view(-1)
mask = block_best_si == si
if not mask.any():
continue
inv_pruned = inv_sub[:, fp][:, :, fp]
inv_pruned_inv = LA.inv( # pylint: disable=not-callable
inv_pruned + eye_np
)
inv_keep_prune = inv_sub[:, fk][:, :, fp]
comp = torch.bmm(inv_keep_prune, inv_pruned_inv)
weight_pruned = weight_block[:, :, fp]
delta_k = torch.einsum(
"mgp,gnp->mgn",
weight_pruned,
comp,
)
mask_exp = mask.unsqueeze(-1)
weight_block_new[:, :, fk] -= delta_k * mask_exp
weight_block_new[:, :, fp] = weight_block_new[
:, :, fp
].masked_fill(mask_exp, 0.0)
W.scatter_(
1,
flat.unsqueeze(0).expand(M, -1),
weight_block_new.reshape(M, num_batch_parts * elem_per_block),
)
def _compensate_full(self, W, block_cols, prune_subsets, best_si):
"""Sequential full-column compensation.
Uses inv_hessian. For each block, compensate ALL K
columns (not just within-block), processing blocks
sequentially so each sees updated W.
"""
device = W.device
M, num_cols = W.shape
num_parts = self.num_scopes_per_row
bk = self.bk
inv_h = self.inv_hessian
num_prune = prune_subsets.shape[1]
n_subsets = prune_subsets.shape[0]
np_bk = num_prune * bk
# Precompute compensation matrices for all
# (block, subset) pairs:
# comp[g, si] = inv(C[P,P]) @ C[P, :]
# shape (np*bk, K)
# where P = absolute column indices of pruned
# blocks in scope g, subset si.
# Pruned column indices per (block, subset):
# (num_parts, n_subsets, np*bk)
if bk == 1:
pruned_cols = block_cols[:, prune_subsets]
else:
pruned_blk = block_cols[:, prune_subsets]
pruned_cols = pruned_blk.view(num_parts, n_subsets, np_bk)
inv_pruned = inv_h[
pruned_cols[:, :, :, None],
pruned_cols[:, :, None, :],
]
eye = 1e-8 * torch.eye(np_bk, device=device)
inv_pruned_inv = torch.inverse(inv_pruned + eye)
inv_prune_rows = inv_h[pruned_cols.reshape(-1), :].view(
num_parts, n_subsets, np_bk, num_cols
)
comp_all = torch.einsum(
"gsij,gsjk->gsik",
inv_pruned_inv,
inv_prune_rows,
)
active = torch.ones(M, num_cols, device=device, dtype=torch.bool)
for g in range(num_parts):
si_per_row = best_si[:, g]
comp_g = comp_all[g]
comp_rows = comp_g[si_per_row]
p_cols = pruned_cols[g, si_per_row]
comp_rows = comp_rows * active.unsqueeze(1)
weight_pruned = torch.gather(W, 1, p_cols)
delta = torch.bmm(weight_pruned.unsqueeze(1), comp_rows).squeeze(1)
W -= delta
W.scatter_(
1,
p_cols,
torch.zeros_like(weight_pruned),
)
active.scatter_(1, p_cols, False)
def _compensate_full_split(
self,
W,
block_cols,
prune_subsets,
best_si,
n_splits=2,
):
"""Full-column compensation with C recomputed.
Splits blocks into n_splits chunks by column position.
Each chunk uses C = inv(H[active, active] + damp*I)
where 'active' excludes columns from previous chunks.
"""
device = W.device
M, num_cols = W.shape
num_parts = self.num_scopes_per_row
bk = self.bk
num_prune = prune_subsets.shape[1]
n_subsets = prune_subsets.shape[0]
np_bk = num_prune * bk
chunk_g = (num_parts + n_splits - 1) // n_splits
active_mask = torch.ones(num_cols, dtype=torch.bool, device=device)
for split_idx in range(n_splits):
g_start = split_idx * chunk_g
g_end = min((split_idx + 1) * chunk_g, num_parts)
if g_start >= num_parts:
break
num_split_parts = g_end - g_start
if split_idx == 0:
active_cols = torch.arange(num_cols, device=device)
n_active = num_cols
inv_h = self.inv_hessian
abs_to_local = torch.arange(num_cols, device=device)
else:
active_cols = torch.where(active_mask)[0]
n_active = active_cols.shape[0]
hessian_active = self.hessian[
active_cols[:, None],
active_cols[None, :],
]
inv_h = self.compute_inverse(hessian_active, self.damp)
abs_to_local = torch.full(
(num_cols,),
-1,
dtype=torch.long,
device=device,
)
abs_to_local[active_cols] = torch.arange(
n_active, device=device
)
split_block_cols = block_cols[g_start:g_end]
if bk == 1:
pruned_abs = split_block_cols[:, prune_subsets]
else:
pruned_abs = split_block_cols[:, prune_subsets].view(
num_split_parts, n_subsets, np_bk
)
pruned_local = abs_to_local[pruned_abs]
inv_pruned = inv_h[
pruned_local[:, :, :, None],
pruned_local[:, :, None, :],
]
eye = 1e-8 * torch.eye(np_bk, device=device)
inv_pruned_inv = torch.inverse(inv_pruned + eye)
inv_prune_rows = inv_h[pruned_local.reshape(-1), :].view(
num_split_parts,
n_subsets,
np_bk,
n_active,
)
comp_split = torch.einsum(
"gsij,gsjk->gsik",
inv_pruned_inv,
inv_prune_rows,
)
split_active = torch.ones(
M,
n_active,
device=device,
dtype=torch.bool,
)
for g_local in range(num_split_parts):
g = g_start + g_local
si_per_row = best_si[:, g]
comp_rows = comp_split[g_local][si_per_row]
p_cols = pruned_abs[g_local, si_per_row]
comp_rows = comp_rows * split_active.unsqueeze(1)
weight_pruned = torch.gather(W, 1, p_cols)
delta = torch.bmm(
weight_pruned.unsqueeze(1),
comp_rows,
).squeeze(1)
if n_active == num_cols:
W -= delta
else:
ac_exp = active_cols.unsqueeze(0).expand(M, -1)
W.scatter_add_(1, ac_exp, -delta)
W.scatter_(
1,
p_cols,
torch.zeros_like(weight_pruned),
)
p_local = abs_to_local[p_cols]
split_active.scatter_(1, p_local, False)
if split_idx < n_splits - 1:
if bk == 1:
frozen = split_block_cols.reshape(-1)
else:
frozen = split_block_cols.view(-1)
active_mask[frozen] = False
def _interleaved(
self,
W,
block_cols,
prune_subsets,
keep_subsets,
n_splits=16,
block_size=2048,
):
"""Interleaved selection + compensation.
Unlike split compensation (which selects masks once
then compensates), this re-selects which columns to
prune at each split using the updated C. Uses a
single shared C (not per-row), so O(K^2) memory.
Key: _find_best_subsets is called with local-space
block_cols and the recomputed C, so mask selection
uses the Schur-updated inverse.
"""
device = W.device
M, num_cols = W.shape
num_parts = self.num_scopes_per_row
bk = self.bk
num_prune = prune_subsets.shape[1]
n_subsets = prune_subsets.shape[0]
np_bk = num_prune * bk
chunk_g = (num_parts + n_splits - 1) // n_splits
active_mask = torch.ones(num_cols, dtype=torch.bool, device=device)
for split_idx in range(n_splits):
g_start = split_idx * chunk_g
g_end = min((split_idx + 1) * chunk_g, num_parts)
if g_start >= num_parts:
break
num_split_parts = g_end - g_start
if split_idx == 0:
active_cols = torch.arange(num_cols, device=device)
n_active = num_cols
inv_h = self.inv_hessian
abs_to_local = torch.arange(num_cols, device=device)
else:
active_cols = torch.where(active_mask)[0]
n_active = active_cols.shape[0]
hessian_active = self.hessian[
active_cols[:, None],
active_cols[None, :],
]
inv_h = self.compute_inverse(hessian_active, self.damp)
abs_to_local = torch.full(
(num_cols,),
-1,
dtype=torch.long,
device=device,
)
abs_to_local[active_cols] = torch.arange(
n_active, device=device
)
split_block_cols = block_cols[g_start:g_end]
if bk == 1:
split_block_cols_local = abs_to_local[split_block_cols]
else:
split_block_cols_local = abs_to_local[split_block_cols]
weight_active = W[:, active_cols]
best_si_split = self._find_best_subsets(
weight_active,
split_block_cols_local,
prune_subsets,
keep_subsets,
block_size,
inv_h=inv_h,
)
if bk == 1:
pruned_local = split_block_cols_local[:, prune_subsets]
pruned_abs = split_block_cols[:, prune_subsets]
else:
pruned_local = abs_to_local[
split_block_cols[:, prune_subsets].view(
num_split_parts,
n_subsets,
np_bk,
)
]
pruned_abs = split_block_cols[:, prune_subsets].view(
num_split_parts,
n_subsets,
np_bk,
)
inv_pruned = inv_h[
pruned_local[:, :, :, None],
pruned_local[:, :, None, :],
]
eye = 1e-8 * torch.eye(np_bk, device=device)
inv_pruned_inv = torch.inverse(inv_pruned + eye)
inv_prune_rows = inv_h[pruned_local.reshape(-1), :].view(
num_split_parts,
n_subsets,
np_bk,
n_active,
)
comp_split = torch.einsum(
"gsij,gsjk->gsik",
inv_pruned_inv,
inv_prune_rows,
)
split_active = torch.ones(
M,
n_active,
device=device,
dtype=torch.bool,
)
for g_local in range(num_split_parts):
si_per_row = best_si_split[:, g_local]
comp_rows = comp_split[g_local][si_per_row]
p_cols = pruned_abs[g_local, si_per_row]
comp_rows = comp_rows * split_active.unsqueeze(1)
weight_pruned = torch.gather(W, 1, p_cols)
delta = torch.bmm(
weight_pruned.unsqueeze(1),
comp_rows,
).squeeze(1)
if n_active == num_cols:
W -= delta
else:
ac_exp = active_cols.unsqueeze(0).expand(M, -1)
W.scatter_add_(1, ac_exp, -delta)
W.scatter_(
1,
p_cols,
torch.zeros_like(weight_pruned),
)
p_local = abs_to_local[p_cols]
split_active.scatter_(1, p_local, False)
if split_idx < n_splits - 1:
if bk == 1:
frozen = split_block_cols.reshape(-1)
else:
frozen = split_block_cols.view(-1)
active_mask[frozen] = False
[docs]
@torch.no_grad()
def prune(
self,
nnz: int,
block_size: int = 2048,
compensate: str = "local",
n_splits: int = 1,
) -> None:
"""Prune to nnz blocks per scope.
Phase 1: enumerate all C(bs, num_prune) subsets per scope, pick
the best per (row, scope) using C = H^{-1} submatrices.
Phase 2 (compensation):
- 'local': within-scope only (fast, independent scopes)
- 'full': sequential compensation to ALL K columns via C[P, :]
(slower but ~44% better than SparseGPT)
- 'split': like 'full' but recomputes C between column splits.
Use n_splits to control granularity (2 = one C update
at the midpoint).
- 'interleaved': re-selects masks AND compensates at each split
using recomputed C. Single shared C (O(K²) memory).
Args:
nnz: Blocks to keep per scope.
block_size: Column chunk size for subset search.
compensate: 'local', 'full', 'split', or 'interleaved'.
n_splits: Number of column splits (for 'split'/'interleaved').
"""
gs = self.blocks_per_scope
num_prune = gs - nnz
if num_prune <= 0:
return
device = self.W.data.device
M = self.M
num_cols = self.num_cols
bk = self.bk
num_parts = self.num_scopes_per_row
W = self.W.data.clone().float().view(M, num_cols)
# Block column indices
if bk == 1:
block_cols = self.col_idx_flat.view(num_parts, gs)
else:
block_cols = self.col_idx.view(num_parts, gs, bk)
# Pruning subsets
prune_subsets = torch.tensor(
list(combinations(range(gs), num_prune)),
device=device,
dtype=torch.long,
)
keep_subsets = torch.tensor(
[
sorted(set(range(gs)) - set(s))
for s in combinations(range(gs), num_prune)
],
device=device,
dtype=torch.long,
)
if compensate == "interleaved":
self._interleaved(
W,
block_cols,
prune_subsets,
keep_subsets,
n_splits=n_splits,
block_size=block_size,
)
self.W.data.copy_(W.view_as(self.W.data))
return
# Phase 1: find best subset per (row, block)
best_si = self._find_best_subsets(
W,
block_cols,
prune_subsets,
keep_subsets,
block_size,
)
# Phase 2: apply compensation
if compensate == "split":
self._compensate_full_split(
W,
block_cols,
prune_subsets,
best_si,
n_splits=n_splits,
)
elif compensate == "full":
self._compensate_full(W, block_cols, prune_subsets, best_si)
else:
self._compensate_local(
W,
block_cols,
prune_subsets,
keep_subsets,
best_si,
block_size,
)
self.W.data.copy_(W.view_as(self.W.data))
# ── True OBS (per-row C with Schur updates) ──────────────────────
def _prescore_blocks_order(
self,
W,
col_map,
prune_subsets,
sub_to_cols,
np_cols,
):
"""Pre-score blocks by OBS cost.
Uses shared inv_hessian. Returns block indices sorted
by descending total cost (highest-cost blocks first,
so they are processed while C is most accurate).
"""
device = W.device
M = W.shape[0]
num_parts = col_map.shape[0]
n_subs = prune_subsets.shape[0]
inv_h = self.inv_hessian
eye_np = 1e-8 * torch.eye(np_cols, device=device)
scores = torch.zeros(num_parts, device=device)
for g in range(num_parts):
cols = col_map[g]
inv_block = inv_h[cols][:, cols]
weight_block = W[:, cols]
best_cost = torch.full((M,), float("inf"), device=device)
for si in range(n_subs):
co = sub_to_cols[si]
inv_pruned_inv = torch.inverse(inv_block[co][:, co] + eye_np)
weight_pruned = weight_block[:, co]
cost = (weight_pruned @ inv_pruned_inv * weight_pruned).sum(1)
better = cost < best_cost
best_cost[better] = cost[better]
scores[g] = best_cost.sum()
return torch.argsort(scores, descending=True)
# ── Row-coupled True OBS ─────────────────────────────────────────
def _prescore_coupled_blocks(
self,
weight_chunk,
local_row_map,
n_vr,
eye_bk,
):
"""Pre-score row-coupled blocks.
Uses shared inv_hessian. Returns block indices
sorted descending by OBS cost. Fully vectorized
over all num_parts blocks.
"""
gcm = self.block_col_map
num_parts, gs, bk = gcm.shape
device = weight_chunk.device
inv_h = self.inv_hessian
cols = gcm[:, 0, :]
num_cols = inv_h.shape[0]
ci = cols.unsqueeze(2).expand(-1, -1, bk)
cj = cols.unsqueeze(1).expand(-1, bk, -1)
flat_idx = ci * num_cols + cj
inv_pruned = inv_h.reshape(-1)[flat_idx.reshape(-1)].view(
num_parts, bk, bk
)
inv_pruned_inv = torch.linalg.inv( # pylint: disable=not-callable
inv_pruned.float() + eye_bk
)
block_scores = torch.empty(n_vr, num_parts, gs, device=device)
for b in range(gs):
lr = local_row_map[:, :, b]
lr_exp = lr.unsqueeze(2).expand(-1, -1, bk)
cols_exp = cols.unsqueeze(0).expand(n_vr, -1, -1)
weight_blk = weight_chunk[
lr_exp.reshape(-1),
cols_exp.reshape(-1),
].view(n_vr, num_parts, bk)
temp = torch.einsum(
"ngb,gbc->ngc",
weight_blk,
inv_pruned_inv,
)
block_scores[:, :, b] = (temp * weight_blk).sum(2)
scores = block_scores.min(dim=2).values.sum(dim=0)
return torch.argsort(scores, descending=True)
@torch.no_grad()
def _prune_true_obs_coupled(
    self,
    nnz,
    ng,  # pylint: disable=unused-argument
    chunk_size,
    order,
    progress_fn,
):
    """True OBS for row-coupled blocks.

    Used when the blocks of a scope lie on different parameter rows.
    Scope rows ("view-rows") are handled in chunks of ``chunk_size``.
    For each chunk, one fp16 copy of ``self.inv_hessian`` is kept per
    distinct parameter row and updated after every pruned block. Scopes
    are visited sequentially; scoring and compensation are batched over
    the view-rows of the chunk. The result is written back into
    ``self.W.data``.

    Args:
        nnz: Blocks to keep per scope.
        ng: Unused.
        chunk_size: Number of view-rows processed together.
        order: ``"largest_first"`` orders scopes with
            :meth:`_prescore_coupled_blocks`; any other value keeps
            index order.
        progress_fn: Optional ``callable(str)`` for progress messages.
    """
    gs = self.blocks_per_scope
    bk = self.bk
    num_parts = self.num_scopes_per_row
    num_cols = self.num_cols
    M = self.M
    device = self.inv_hessian.device
    num_prune = gs - nnz
    num_gr = self.num_scope_rows
    gcm = self.block_col_map
    gsp = self.block_sort_perm
    # Regulariser added to each (bk, bk) sub-matrix before inversion.
    eye_bk = 1e-4 * torch.eye(bk, device=device)
    # Float working copy of the parameter.
    W = self.W.data.clone().float().view(M, num_cols)
    n_chunks = (num_gr + chunk_size - 1) // chunk_size
    for ci in range(n_chunks):
        vr0 = ci * chunk_size
        vr1 = min(vr0 + chunk_size, num_gr)
        n_vr = vr1 - vr0
        if progress_fn:
            progress_fn(
                f"chunk {ci + 1}/{n_chunks}" f" ({vr0}/{num_gr} view-rows)"
            )
        # Parameter row of every block element, per view-row of the
        # chunk, with blocks grouped by scope: (n_vr, num_parts, gs, bk).
        chunk_row_map = torch.empty(
            n_vr,
            num_parts,
            gs,
            bk,
            device=device,
            dtype=torch.long,
        )
        for vr_local in range(n_vr):
            row_cg = self.full_row_idx[vr0 + vr_local].reshape(
                self.total_blocks, bk
            )
            chunk_row_map[vr_local] = row_cg[gsp].view(num_parts, gs, bk)
        # Distinct parameter rows touched by this chunk (taken from the
        # first element of each block).
        unique_rows = chunk_row_map[:, :, :, 0].reshape(-1).unique()
        b_param = unique_rows.shape[0]
        # p2l: parameter row -> position in ``weight_chunk`` (-1 when the
        # row is not part of this chunk).
        p2l = torch.full(
            (M,),
            -1,
            device=device,
            dtype=torch.long,
        )
        p2l[unique_rows] = torch.arange(b_param, device=device)
        local_row_map = p2l[chunk_row_map[:, :, :, 0]]
        weight_chunk = W[unique_rows].clone()
        # One fp16 copy of the inverse Hessian per parameter row.
        inv_h = torch.empty(
            b_param,
            num_cols,
            num_cols,
            device=device,
            dtype=torch.float16,
        )
        inv_h[:] = self.inv_hessian.half()
        # 1.0 where a weight is still alive, 0.0 once pruned.
        pruned_mask = torch.ones(b_param, num_cols, device=device)
        # NOTE(review): ``inv_flat`` is not read again before the ``del``
        # at the end of the chunk.
        inv_flat = inv_h.reshape(-1)
        if order == "largest_first":
            if progress_fn:
                progress_fn("Pre-scoring blocks...")
            block_order = self._prescore_coupled_blocks(
                weight_chunk,
                local_row_map,
                n_vr,
                eye_bk,
            )
        else:
            block_order = torch.arange(num_parts, device=device)
        arange_vr = torch.arange(n_vr, device=device)
        for gi in range(num_parts):
            g = block_order[gi]
            # Columns of the first block of scope ``g``; they are used
            # for every block of the scope.
            cols = gcm[g, 0, :]
            rows_g = local_row_map[:, g, :]
            n_score = n_vr * gs
            sc_rows = rows_g.reshape(n_score)
            # Per (view-row, block): C[cols, cols] of that block's row.
            inv_pruned = inv_h[:, cols, :][:, :, cols][sc_rows]
            inv_pruned_inv = (
                torch.linalg.inv(  # pylint: disable=not-callable
                    inv_pruned.float() + eye_bk
                )
            )
            weight_blk = weight_chunk[sc_rows][:, cols]
            temp = torch.bmm(
                weight_blk.unsqueeze(1),
                inv_pruned_inv,
            ).squeeze(1)
            # Cost w^T (C[cols, cols] + eye_bk)^{-1} w of every block.
            scores = (temp * weight_blk).sum(1).view(n_vr, gs)
            # The ``num_prune`` cheapest blocks of each view-row.
            _, prune_bi = scores.topk(
                num_prune,
                dim=1,
                largest=False,
            )
            for pi in range(num_prune):
                pb = prune_bi[:, pi]
                flat_rows = rows_g[arange_vr, pb]
                # Index into the flattened (view-row, block) scoring
                # batch, to reuse the inverses computed above.
                si = arange_vr * gs + pb
                c_inv = inv_pruned_inv[si]
                # C[:, cols] of the row that holds the pruned block.
                inv_col = inv_h[:, :, cols][flat_rows]
                comp = torch.bmm(inv_col.float(), c_inv)
                weight_pruned = weight_chunk[flat_rows][:, cols]
                delta = torch.bmm(
                    comp,
                    weight_pruned.unsqueeze(2),
                ).squeeze(2)
                # Compensate the whole row, then zero the pruned block.
                weight_chunk.index_add_(0, flat_rows, -delta)
                weight_chunk[
                    flat_rows.unsqueeze(1).expand(-1, bk),
                    cols.unsqueeze(0).expand(n_vr, -1),
                ] = 0.0
                comp_h = comp.half()
                for vr in range(n_vr):
                    r = flat_rows[vr]
                    # In-place update of this row's matrix:
                    # inv_h[r] -= comp_h[vr] @ inv_col[vr].T
                    torch.addmm(
                        inv_h[r],
                        comp_h[vr],
                        inv_col[vr].T,
                        beta=1,
                        alpha=-1,
                        out=inv_h[r],
                    )
                pruned_mask[
                    flat_rows.unsqueeze(1).expand(-1, bk),
                    cols.unsqueeze(0).expand(n_vr, -1),
                ] = 0.0
        # Force every pruned entry to exactly zero before writing back.
        weight_chunk *= pruned_mask
        W[unique_rows] = weight_chunk
        del inv_h, inv_flat
        if progress_fn:
            progress_fn("")
    self.W.data.copy_(W.view_as(self.W.data))
[docs]
@torch.no_grad()
def prune_true_obs(
    self,
    nnz: int,
    ng: int = 64,
    chunk_size: int = 16,
    order: str = "left_to_right",
    scoring: str = "independent",
    c_dtype=None,
    progress_fn=None,
) -> None:
    """Per-row True OBS with Schur complement updates.

    Each row maintains its own C = inv(H), updated via
    Schur complement after pruning. Processes ``ng``
    scopes simultaneously per batch. The result is written back into
    ``self.W.data``; nothing happens when ``nnz`` is not smaller than
    the number of blocks per scope.

    When ``self.row_coupled`` is set the work is delegated to
    :meth:`_prune_true_obs_coupled`, which receives only ``nnz``,
    ``ng``, ``chunk_size``, ``order`` and ``progress_fn``.

    Args:
        nnz: Blocks to keep per scope.
        ng: Number of scopes to process per batch.
        chunk_size: Rows to process simultaneously.
        order: ``"left_to_right"`` or
            ``"largest_first"``. Any value other than
            ``"largest_first"`` keeps index order.
        scoring: ``"joint"`` (enumerate subsets) or
            ``"independent"`` (per-element w^2/diag(C)
            + topk). Any value other than ``"independent"`` selects
            the joint path.
        c_dtype: Dtype for per-row C matrices. Default
            ``None`` uses fp16 for tensor-core Schur
            updates.
        progress_fn: Optional ``callable(str)`` for
            progress messages.
    """
    gs = self.blocks_per_scope
    bk = self.bk
    num_parts = self.num_scopes_per_row
    num_cols = self.num_cols
    M = self.M
    device = self.inv_hessian.device
    num_prune = gs - nnz
    if num_prune <= 0:
        return
    if self.row_coupled:
        self._prune_true_obs_coupled(
            nnz,
            ng,
            chunk_size,
            order,
            progress_fn,
        )
        return
    # Elements per scope.
    epg = gs * bk if bk > 1 else gs
    if bk == 1:
        block_cols = self.col_idx_flat.view(num_parts, gs)
    else:
        block_cols = self.col_idx.view(num_parts, gs, bk)
    # (num_parts, epg): columns of every scope, flattened over blocks.
    col_map = block_cols.reshape(num_parts, epg)
    prune_subsets = torch.tensor(
        list(combinations(range(gs), num_prune)),
        device=device,
        dtype=torch.long,
    )
    n_subs = prune_subsets.shape[0]
    # Positions (inside a scope's column list) pruned by each subset.
    if bk == 1:
        sub_to_cols = prune_subsets
    else:
        sub_to_cols = (
            prune_subsets.unsqueeze(-1) * bk
            + torch.arange(bk, device=device)
        ).view(n_subs, num_prune * bk)
    np_cols = sub_to_cols.shape[1]
    # A 2x2 system is inverted with the explicit determinant formula.
    use_closed_form = np_cols == 2
    if order == "largest_first":
        if progress_fn:
            progress_fn("Pre-scoring blocks for ordering...")
        weight_flat = self.W.data.clone().float().view(M, num_cols)
        block_order = self._prescore_blocks_order(
            weight_flat,
            col_map,
            prune_subsets,
            sub_to_cols,
            np_cols,
        )
        del weight_flat
        if progress_fn:
            progress_fn("Pre-scoring done.")
    else:
        block_order = torch.arange(num_parts, device=device)
    # Float working copy of the parameter.
    W = self.W.data.clone().float().view(M, num_cols)
    n_chunks = (M + chunk_size - 1) // chunk_size
    num_batches = (num_parts + ng - 1) // ng
    for ci in range(n_chunks):
        c0 = ci * chunk_size
        c1 = min(c0 + chunk_size, M)
        B = c1 - c0
        if progress_fn and n_chunks > 4:
            progress_fn(f"chunk {ci + 1}/{n_chunks}" f" ({c0}/{M} rows)")
        # Slice of W: the in-place updates below write into W directly.
        weight_chunk = W[c0:c1]
        c_dt = c_dtype or torch.float16
        # One private copy of the inverse Hessian per row of the chunk.
        inv_h = (
            self.inv_hessian.to(c_dt).unsqueeze(0).expand(B, -1, -1).clone()
        )
        # 1.0 where a weight is still alive, 0.0 once pruned.
        pruned_mask = torch.ones(B, num_cols, device=device)
        for blk in range(num_batches):
            batch_gids = block_order[
                blk * ng : min((blk + 1) * ng, num_parts)
            ]
            n_g = batch_gids.shape[0]
            batch_col_map = col_map[batch_gids]
            base_cols = batch_col_map.reshape(-1)
            # Index pairs that pick, for each scope of the batch, the
            # (epg, epg) sub-matrix of C over that scope's columns.
            ri = batch_col_map.unsqueeze(2).expand(n_g, epg, epg)
            ci_idx = batch_col_map.unsqueeze(1).expand(n_g, epg, epg)
            inv_diag = (
                inv_h[
                    :,
                    ri.reshape(-1),
                    ci_idx.reshape(-1),
                ]
                .view(B, n_g, epg, epg)
                .float()
            )
            weight_all = weight_chunk[:, base_cols].view(B, n_g, epg)
            if scoring == "independent":
                # Per-element cost w^2 / C_ii; no subset enumeration.
                inv_diag_vec = torch.diagonal(inv_diag, dim1=-2, dim2=-1)
                if bk == 1:
                    elem_cost = weight_all**2 / (inv_diag_vec + 1e-8)
                    _, prune_idx = elem_cost.topk(
                        num_prune,
                        dim=-1,
                        largest=False,
                    )
                else:
                    # Block cost is the sum of its element costs.
                    block_cost = (
                        (weight_all**2 / (inv_diag_vec + 1e-8))
                        .view(B, n_g, gs, bk)
                        .sum(-1)
                    )
                    _, blk_idx = block_cost.topk(
                        num_prune,
                        dim=-1,
                        largest=False,
                    )
                    # Expand chosen blocks to their element positions.
                    prune_idx = (
                        blk_idx.unsqueeze(-1) * bk
                        + torch.arange(bk, device=device)
                    ).view(B, n_g, np_cols)
                pruned_local = prune_idx
            else:
                # Joint scoring: evaluate w_P^T inv(C_PP) w_P for every
                # subset and keep the running minimum.
                best_cost = torch.full(
                    (B, n_g),
                    float("inf"),
                    device=device,
                )
                best_si = torch.zeros(
                    B,
                    n_g,
                    dtype=torch.long,
                    device=device,
                )
                for si in range(n_subs):
                    co = sub_to_cols[si]
                    weight_pruned = weight_all[:, :, co]
                    if use_closed_form:
                        # 2x2 case: quadratic form with the adjugate
                        # divided by the (slightly shifted) determinant.
                        a = inv_diag[:, :, co[0], co[0]]
                        b_ = inv_diag[:, :, co[0], co[1]]
                        d = inv_diag[:, :, co[1], co[1]]
                        det = a * d - b_ * b_ + 1e-8
                        w0 = weight_pruned[:, :, 0]
                        w1 = weight_pruned[:, :, 1]
                        cost = (
                            w0 * w0 * d - 2 * w0 * w1 * b_ + w1 * w1 * a
                        ) / det
                    else:
                        inv_pruned = inv_diag[:, :, co][:, :, :, co]
                        eye_pp = 1e-8 * torch.eye(
                            np_cols,
                            device=device,
                        )
                        inv_pruned_inv = (
                            LA.inv(  # pylint: disable=not-callable
                                (inv_pruned + eye_pp).reshape(
                                    B * n_g,
                                    np_cols,
                                    np_cols,
                                )
                            )
                        )
                        weight_flat = weight_pruned.reshape(
                            B * n_g,
                            1,
                            np_cols,
                        )
                        cost = (
                            (
                                torch.bmm(
                                    weight_flat,
                                    inv_pruned_inv,
                                ).squeeze(1)
                                * weight_pruned.reshape(
                                    B * n_g,
                                    np_cols,
                                )
                            )
                            .sum(1)
                            .view(B, n_g)
                        )
                    better = cost < best_cost
                    best_cost[better] = cost[better]
                    best_si[better] = si
                pruned_local = sub_to_cols[best_si.view(-1)].view(
                    B, n_g, np_cols
                )
            # Translate within-scope positions to absolute columns:
            # all_p is (B, n_g * np_cols).
            g_exp = (
                torch.arange(n_g, device=device)
                .view(1, n_g, 1)
                .expand(B, n_g, np_cols)
            )
            all_p = batch_col_map[g_exp, pruned_local].reshape(
                B, n_g * np_cols
            )
            np_total = all_p.shape[1]
            pc_exp = all_p.unsqueeze(1).expand(B, num_cols, np_total)
            # C[:, P] per row: (B, num_cols, np_total).
            inv_col_prune = inv_h.gather(2, pc_exp)
            eye_n = 1e-8 * torch.eye(np_total, device=device)
            # C[P, P] per row: (B, np_total, np_total).
            inv_pruned = inv_col_prune.gather(
                1,
                all_p.unsqueeze(2).expand(B, np_total, np_total),
            ).float()
            inv_pruned_inv = LA.inv(  # pylint: disable=not-callable
                inv_pruned + eye_n
            )
            weight_pruned = weight_chunk.gather(1, all_p)
            # OBS update for the whole batch:
            # delta = C[:, P] inv(C[P, P]) w_P.
            comp = torch.bmm(
                inv_col_prune.float(),
                inv_pruned_inv,
            )
            delta = torch.bmm(
                comp,
                weight_pruned.unsqueeze(2),
            ).squeeze(2)
            weight_chunk -= delta
            weight_chunk.scatter_(
                1,
                all_p,
                torch.zeros(B, np_total, device=device),
            )
            # Schur update C -= C[:, P] inv(C[P, P]) C[P, :], applied as
            # F F^T with F = C[:, P] L and L the Cholesky factor of
            # inv(C[P, P]).
            lpp = LA.cholesky(  # pylint: disable=not-callable
                inv_pruned_inv + eye_n
            )
            schur_factor = torch.bmm(inv_col_prune, lpp.to(c_dt))
            inv_h.baddbmm_(
                schur_factor,
                schur_factor.transpose(1, 2),
                alpha=-1.0,
            )
            pruned_mask.scatter_(1, all_p, 0.0)
        # Force every pruned entry of the chunk to exactly zero.
        weight_chunk *= pruned_mask
        del inv_h
        if progress_fn and n_chunks > 4:
            progress_fn("")
    self.W.data.copy_(W.view_as(self.W.data))