"""
Library for extracting interesting quantities from autograd, see README.md
Not thread-safe because of module-level variables
Author: cybertronai @ https://github.com/cybertronai/autograd-hacks.
Notation:
o: number of output classes (exact Hessian), number of Hessian samples (sampled Hessian)
n: batch-size
do: output dimension (output channels for convolution)
di: input dimension (input channels for convolution)
Hi: per-example Hessian of matmul, shaped as matrix of [dim, dim], indices have been row-vectorized
Hi_bias: per-example Hessian of bias
Oh, Ow: output height, output width (convolution)
Kh, Kw: kernel height, kernel width (convolution)
Jb: batch output Jacobian of matmul, output sensitivity for (example, class) pair, [o, n, ...]
Jb_bias: as above, but for bias
A, activations: inputs into current layer
B, backprops: backprop values (aka Lop aka Jacobian-vector product) observed at current layer
"""
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
_supported_layers = ['Linear', 'Conv2d'] # Supported layer class types
_hooks_disabled: bool = False # work-around for https://github.com/pytorch/pytorch/issues/25723
_enforce_fresh_backprop: bool = False # global switch to catch double backprop errors on Hessian computation
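# Typical call sequence (a rough sketch inferred from the functions below; the
# model/loss names are placeholders, not part of the library):
#
#   add_hooks(model)                 # instrument supported layers
#   loss = loss_fn(model(data), targets)
#   loss.backward()                  # hooks record activations and backprops
#   compute_grad1(model)             # per-example gradients land in param.grad1
#   remove_hooks(model)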
def add_hooks(model: nn.Module) -> None:
"""
Adds hooks to model to save activations and backprop values.
The hooks will
1. save activations into layer.activations during the forward pass
2. append backprop values to layer.backprops_list during the backward pass.
Call "remove_hooks(model)" to disable this.
Args:
model: model whose supported layers (Linear, Conv2d) will be instrumented
"""
global _hooks_disabled
_hooks_disabled = False
handles = []
for layer in model.modules():
if _layer_type(layer) in _supported_layers:
handles.append(layer.register_forward_hook(_capture_activations))
handles.append(layer.register_backward_hook(_capture_backprops))
model.__dict__.setdefault('autograd_hacks_hooks', []).extend(handles)
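# Minimal sketch of the hook lifecycle (illustrative only; the `_demo_hooks`
# name and the toy layer sizes are invented for this example):
def _demo_hooks():
    model = nn.Sequential(nn.Linear(4, 3))
    add_hooks(model)
    x = torch.randn(2, 4)
    model(x).sum().backward()
    layer = model[0]
    # the forward hook stored the layer input, the backward hook appended the
    # output sensitivity observed during backward()
    assert layer.activations.shape == (2, 4)
    assert layer.backprops_list[0].shape == (2, 3)
    remove_hooks(model)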
def remove_hooks(model: nn.Module) -> None:
"""
Remove hooks added by add_hooks(model)
"""
# assert model == 0, "not working, remove this after fix to https://github.com/pytorch/pytorch/issues/25723"
if not hasattr(model, 'autograd_hacks_hooks'):
print("Warning, asked to remove hooks, but no hooks found")
else:
for handle in model.autograd_hacks_hooks:
handle.remove()
del model.autograd_hacks_hooks
def disable_hooks() -> None:
"""
Globally disable all hooks installed by this library.
"""
global _hooks_disabled
_hooks_disabled = True
def enable_hooks() -> None:
"""the opposite of disable_hooks()"""
global _hooks_disabled
_hooks_disabled = False
def is_supported(layer: nn.Module) -> bool:
"""Check if this layer is supported"""
return _layer_type(layer) in _supported_layers
def _layer_type(layer: nn.Module) -> str:
return layer.__class__.__name__
def _capture_activations(layer: nn.Module, input: List[torch.Tensor], output: torch.Tensor):
"""Save activations into layer.activations in forward pass"""
if _hooks_disabled:
return
assert _layer_type(layer) in _supported_layers, "Hook installed on unsupported layer, this shouldn't happen"
setattr(layer, "activations", input[0].detach())
def _capture_backprops(layer: nn.Module, _input, output):
"""Append backprop to layer.backprops_list in backward pass."""
global _enforce_fresh_backprop
if _hooks_disabled:
return
if _enforce_fresh_backprop:
assert not hasattr(layer, 'backprops_list'), "Seeing result of previous backprop, use clear_backprops(model) to clear"
_enforce_fresh_backprop = False
if not hasattr(layer, 'backprops_list'):
setattr(layer, 'backprops_list', [])
layer.backprops_list.append(output[0].detach())
def clear_backprops(model: nn.Module) -> None:
"""Delete layer.backprops_list in every layer."""
for layer in model.modules():
if hasattr(layer, 'backprops_list'):
del layer.backprops_list
def compute_grad1(model: nn.Module, loss_type: str = 'mean') -> None:
"""
Compute per-example gradients and save them under 'param.grad1'. Must be called after loss.backward()
Args:
model:
loss_type: either "mean" or "sum", depending on whether the backpropped loss was averaged or summed over the batch
"""
assert loss_type in ('sum', 'mean')
for layer in model.modules():
layer_type = _layer_type(layer)
if layer_type not in _supported_layers:
continue
assert hasattr(layer, 'activations'), "No activations detected, run forward after add_hooks(model)"
assert hasattr(layer, 'backprops_list'), "No backprops detected, run backward after add_hooks(model)"
assert len(layer.backprops_list) == 1, "Multiple backprops detected, make sure to call clear_backprops(model)"
A = layer.activations
n = A.shape[0]
if loss_type == 'mean':
B = layer.backprops_list[0] * n
else: # loss_type == 'sum':
B = layer.backprops_list[0]
if layer_type == 'Linear':
setattr(layer.weight, 'grad1', torch.einsum('ni,nj->nij', B, A))
if layer.bias is not None:
setattr(layer.bias, 'grad1', B)
elif layer_type == 'Conv2d':
A = torch.nn.functional.unfold(A, layer.kernel_size)  # note: unfold is called without stride/padding/dilation, so this assumes the conv uses the defaults (stride=1, padding=0, dilation=1)
B = B.reshape(n, -1, A.shape[-1])
grad1 = torch.einsum('ijk,ilk->ijl', B, A)
shape = [n] + list(layer.weight.shape)
setattr(layer.weight, 'grad1', grad1.reshape(shape))
if layer.bias is not None:
setattr(layer.bias, 'grad1', torch.sum(B, dim=2))
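# Sketch of the per-example gradient workflow (illustrative; `_demo_grad1` and
# the toy sizes are invented for this example). Averaging grad1 over the batch
# should reproduce the aggregate gradient for a mean-reduced loss.
def _demo_grad1():
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(5, 3), nn.ReLU(), nn.Linear(3, 2))
    add_hooks(model)
    x = torch.randn(8, 5)
    loss = model(x).pow(2).mean()          # mean-reduced loss over the batch
    loss.backward()
    compute_grad1(model, loss_type='mean')
    for param in model.parameters():
        assert param.grad1.shape[0] == 8   # one gradient per example
        assert torch.allclose(param.grad1.mean(dim=0), param.grad, atol=1e-5)
    clear_backprops(model)
    remove_hooks(model)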
def compute_hess(model: nn.Module) -> None:
"""Save Hessian under param.hess for each param in the model"""
for layer in model.modules():
layer_type = _layer_type(layer)
if layer_type not in _supported_layers:
continue
assert hasattr(layer, 'activations'), "No activations detected, run forward after add_hooks(model)"
assert hasattr(layer, 'backprops_list'), "No backprops detected, run backward after add_hooks(model)"
if layer_type == 'Linear':
A = layer.activations
B = torch.stack(layer.backprops_list)
n = A.shape[0]
o = B.shape[0]
A = torch.stack([A] * o)
Jb = torch.einsum("oni,onj->onij", B, A).reshape(n*o, -1)
H = torch.einsum('ni,nj->ij', Jb, Jb) / n
setattr(layer.weight, 'hess', H)
if layer.bias is not None:
setattr(layer.bias, 'hess', torch.einsum('oni,onj->ij', B, B)/n)
elif layer_type == 'Conv2d':
Kh, Kw = layer.kernel_size
di, do = layer.in_channels, layer.out_channels
A = layer.activations.detach()
A = torch.nn.functional.unfold(A, (Kh, Kw)) # n, di * Kh * Kw, Oh * Ow
n = A.shape[0]
B = torch.stack([Bt.reshape(n, do, -1) for Bt in layer.backprops_list]) # o, n, do, Oh*Ow
o = B.shape[0]
A = torch.stack([A] * o) # o, n, di * Kh * Kw, Oh*Ow
Jb = torch.einsum('onij,onkj->onik', B, A) # o, n, do, di * Kh * Kw
Hi = torch.einsum('onij,onkl->nijkl', Jb, Jb) # n, do, di*Kh*Kw, do, di*Kh*Kw
Jb_bias = torch.einsum('onij->oni', B)
Hi_bias = torch.einsum('oni,onj->nij', Jb_bias, Jb_bias)
setattr(layer.weight, 'hess', Hi.mean(dim=0))
if layer.bias is not None:
setattr(layer.bias, 'hess', Hi_bias.mean(dim=0))
def backprop_hess(output: torch.Tensor, hess_type: str) -> None:
"""
Call backward() one or more times to accumulate the values needed for Hessian computation.
Args:
output: prediction of the neural network (i.e., the input to nn.CrossEntropyLoss())
hess_type: type of Hessian propagation; "CrossEntropy" gives the exact Hessian for cross-entropy loss
"""
assert hess_type in ('LeastSquares', 'CrossEntropy')
global _enforce_fresh_backprop
n, o = output.shape
_enforce_fresh_backprop = True
if hess_type == 'CrossEntropy':
batch = F.softmax(output, dim=1)
mask = torch.eye(o).expand(n, o, o)
diag_part = batch.unsqueeze(2).expand(n, o, o) * mask
outer_prod_part = torch.einsum('ij,ik->ijk', batch, batch)
hess = diag_part - outer_prod_part  # per-example Hessian of cross-entropy w.r.t. logits: diag(p) - p p^T
assert hess.shape == (n, o, o)
for i in range(n):
hess[i, :, :] = symsqrt(hess[i, :, :])
hess = hess.transpose(0, 1)
elif hess_type == 'LeastSquares':
hess = []
assert len(output.shape) == 2
batch_size, output_size = output.shape
id_mat = torch.eye(output_size)
for out_idx in range(output_size):
hess.append(torch.stack([id_mat[out_idx]] * batch_size))
for i in range(o):
    output.backward(hess[i], retain_graph=True)
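# Sketch of the Hessian workflow (illustrative; `_demo_hess` and the toy sizes
# are invented for this example): forward pass, then backprop_hess() to run the
# extra backward passes, then compute_hess() to aggregate.
def _demo_hess():
    model = nn.Sequential(nn.Linear(4, 3))
    add_hooks(model)
    x = torch.randn(6, 4)
    output = model(x)                          # logits, shape [n, o]
    backprop_hess(output, hess_type='CrossEntropy')
    compute_hess(model)
    layer = model[0]
    # weight Hessian is [do*di, do*di] with row-vectorized weight indices
    assert layer.weight.hess.shape == (12, 12)
    assert layer.bias.hess.shape == (3, 3)
    clear_backprops(model)
    remove_hooks(model)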
def symsqrt(a, cond=None, return_rank=False, dtype=torch.float32):
"""Symmetric square root of a positive semi-definite matrix.
See https://github.com/pytorch/pytorch/issues/25481"""
s, u = torch.linalg.eigh(a)  # torch.symeig in the original; symeig was deprecated and removed in PyTorch 2.x
cond_dict = {torch.float32: 1e3 * 1.1920929e-07, torch.float64: 1E6 * 2.220446049250313e-16}
if cond in [None, -1]:
cond = cond_dict[dtype]
above_cutoff = (abs(s) > cond * torch.max(abs(s)))
psigma_diag = torch.sqrt(s[above_cutoff])
u = u[:, above_cutoff]
B = u @ torch.diag(psigma_diag) @ u.t()
if return_rank:
return B, len(psigma_diag)
else:
return B
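# Quick check of symsqrt (illustrative only; `_demo_symsqrt` is not part of the
# library): for a positive semi-definite matrix a, the returned factor should
# satisfy B @ B ≈ a.
def _demo_symsqrt():
    torch.manual_seed(0)
    m = torch.randn(5, 5)
    a = m @ m.t()                  # symmetric positive semi-definite
    b = symsqrt(a)
    assert torch.allclose(b @ b, a, atol=1e-4)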