"""Provides a class for mapping transformer hidden states to logits (and vice versa). | |
Example: | |
from standalone_logit_lens import LogitLens, ReverseLogitLens | |
model = AutoModelForCausalLM.from_pretrained(model_name).to(device).to(dtype) | |
lens = LogitLens.from_model(model).to(device).to(dtype) | |
reverse_lens = ReverseLogitLens.from_model(model).to(device).to(dtype) | |
hidden_state = ... | |
result = lens(hidden_state, layer_index) # layer_index is not really used, you can pass whatever | |
""" | |

import abc
import logging
import copy
from typing import Union

import torch
from torch import nn
import torch.nn.functional as F
import transformers
from transformers import models
from transformers import PreTrainedModel

Model = Union[PreTrainedModel]
Norm = Union[
    nn.LayerNorm,
    models.llama.modeling_llama.LlamaRMSNorm,
    models.gemma.modeling_gemma.GemmaRMSNorm,
    models.gemma2.modeling_gemma2.Gemma2RMSNorm,
    nn.Module,
]


def get_unembedding_matrix(model: Model) -> nn.Linear:
    """The final linear transformation from the model hidden state to the output logits."""
    if isinstance(model, PreTrainedModel):
        unembed = model.get_output_embeddings()
        if not isinstance(unembed, nn.Linear):
            raise ValueError("We currently only support linear unembeddings")
        return unembed
    else:
        raise ValueError(f"Model class {type(model)} not recognized!")


def get_embedding_matrix(model: nn.Module) -> nn.Embedding:
    """The initial embedding matrix from the input tokens to the model hidden state."""
    if isinstance(model, PreTrainedModel):
        embed = model.get_input_embeddings()
        if not isinstance(embed, nn.Embedding):
            raise ValueError("We currently only support embedding matrices")
        return embed
    else:
        raise ValueError(f"Model class {type(model)} not recognized!")


def get_final_norm(model: Model) -> Norm:
    """Get the final norm from a model.

    This isn't standardized across models, so this will need to be updated as
    we add new models.
    """
    if not hasattr(model, "base_model"):
        raise ValueError("Model does not have a `base_model` attribute.")
    base_model = model.base_model
    if isinstance(base_model, models.opt.modeling_opt.OPTModel):
        final_layer_norm = base_model.decoder.final_layer_norm
    elif isinstance(base_model, models.gpt_neox.modeling_gpt_neox.GPTNeoXModel):
        final_layer_norm = base_model.final_layer_norm
    elif isinstance(
        base_model,
        (
            models.bloom.modeling_bloom.BloomModel,
            models.gpt2.modeling_gpt2.GPT2Model,
            models.gpt_neo.modeling_gpt_neo.GPTNeoModel,
            models.gptj.modeling_gptj.GPTJModel,
        ),
    ):
        final_layer_norm = base_model.ln_f
    elif isinstance(base_model, models.llama.modeling_llama.LlamaModel):
        final_layer_norm = base_model.norm
    elif isinstance(base_model, models.mistral.modeling_mistral.MistralModel):
        final_layer_norm = base_model.norm
    elif isinstance(base_model, models.t5.modeling_t5.T5ForConditionalGeneration):
        # For T5, use the LayerNorm from the last decoder block, before the feed-forward layer.
        final_layer_norm = base_model.decoder.block[-1].layer[1].layer_norm
    else:
        raise NotImplementedError(f"Unknown model type {type(base_model)}")
    if final_layer_norm is None:
        raise ValueError("Model does not have a final layer norm.")
    assert isinstance(final_layer_norm, Norm.__args__)  # type: ignore
    return final_layer_norm
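

# Usage sketch for the helpers above (illustrative only; assumes a GPT-2-style
# decoder-only model, where the final norm is `ln_f` and the output head is `lm_head`):
#
#     model = AutoModelForCausalLM.from_pretrained("gpt2")
#     final_norm = get_final_norm(model)         # model.transformer.ln_f for GPT-2
#     unembed = get_unembedding_matrix(model)    # model.lm_head for GPT-2
#     logits = unembed(final_norm(hidden_state)) # the classic logit-lens projection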


class Unembed(nn.Module):
    """Module that maps transformer hidden states to logits."""

    final_norm: Norm
    unembedding: nn.Linear

    def __init__(
        self,
        model: Model,
    ):
        """Initialize Unembed.

        Args:
            model: A HuggingFace model from which to extract the unembedding matrix.
        """
        super().__init__()
        final_norm = get_final_norm(model)
        unembedding_matrix = get_unembedding_matrix(model)
        self.final_norm = copy.deepcopy(final_norm)
        self.unembedding = copy.deepcopy(unembedding_matrix)
        # In general we don't want to finetune the unembed operation.
        self.requires_grad_(False)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        """Convert hidden states into logits."""
        return self.unembedding(self.final_norm(h))


class Reembed(nn.Module):
    """Module that maps hidden states back to the vocabulary via the input embedding matrix."""

    embedding: torch.Tensor

    def __init__(
        self,
        model: Model,
        distance_metric: str = "logits",
    ):
        """Initialize Reembed.

        Args:
            model: A HuggingFace model from which to extract the input embedding matrix.
            distance_metric: How to score hidden states against the embedding rows;
                one of "logits" (dot product), "cosine", or "euclidean".
        """
        super().__init__()
        embedding_matrix = get_embedding_matrix(model)
        # Register as a buffer so that `.to(device)` / `.to(dtype)` move it with the module.
        self.register_buffer("embedding", embedding_matrix.weight.data.clone())
        self.distance_metric = distance_metric
        # In general we don't want to finetune the reembed operation.
        self.requires_grad_(False)

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Convert hidden states into logits (or similarity scores)."""
        if self.distance_metric == "logits":
            logits = torch.matmul(hidden_state, self.embedding.T).squeeze(0)
        elif self.distance_metric == "cosine":
            # Normalize E and h, then take the dot product (cosine similarity).
            E_normalized = F.normalize(self.embedding, p=2, dim=-1)
            h_normalized = F.normalize(hidden_state, p=2, dim=-1)
            logits = torch.matmul(h_normalized, E_normalized.T).squeeze(0)
        elif self.distance_metric == "euclidean":
            # Negate the Euclidean distance so that closer embeddings get larger scores.
            distances = torch.cdist(hidden_state, self.embedding, p=2).squeeze(0)
            logits = -distances
        else:  # Fall back to a plain dot product as the similarity measure.
            logits = torch.matmul(hidden_state, self.embedding.T).squeeze(0)
        return logits
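

# Note (illustrative): for models that tie their input and output embeddings (GPT-2,
# for example), the default "logits" metric above computes the same matrix product as
# `Unembed`, just without the final layer norm applied first:
#
#     reembed = Reembed(model)               # uses the input embedding matrix E
#     scores = reembed(hidden_state)         # h @ E.T
#     logits = Unembed(model)(hidden_state)  # lm_head(final_norm(h))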


class ReverseLens(abc.ABC, nn.Module):
    """Abstract base class for all reverse lenses."""

    reembed: Reembed

    def __init__(self, reembed: Reembed):
        """Create a ReverseLens.

        Args:
            reembed: The reembed operation to use.
        """
        super().__init__()
        self.reembed = reembed

    @abc.abstractmethod
    def forward(self, h: torch.Tensor, idx: int) -> torch.Tensor:
        """Decode hidden states into logits."""
        ...


class ReverseLogitLens(ReverseLens):
    """Reembeds the residual stream into logits."""

    reembed: Reembed

    def __init__(
        self,
        reembed: Reembed,
    ):
        """Create a ReverseLogitLens.

        Args:
            reembed: The reembed operation to use.
        """
        super().__init__(reembed)

    @classmethod
    def from_model(
        cls,
        model: PreTrainedModel,
    ) -> "ReverseLogitLens":
        """Create a ReverseLogitLens from a pretrained model.

        Args:
            model: A pretrained model from the transformers library you wish to inspect.
        """
        reembed = Reembed(model)
        return cls(reembed)

    def forward(self, h: torch.Tensor, idx: int) -> torch.Tensor:
        """Decode a hidden state into logits.

        Args:
            h: The hidden state to decode.
            idx: The layer of the transformer these hidden states come from (unused).
        """
        del idx
        return self.reembed.forward(h)


class Lens(abc.ABC, nn.Module):
    """Abstract base class for all lenses."""

    unembed: Unembed

    def __init__(self, unembed: Unembed):
        """Create a Lens.

        Args:
            unembed: The unembed operation to use.
        """
        super().__init__()
        self.unembed = unembed

    @abc.abstractmethod
    def forward(self, h: torch.Tensor, idx: int) -> torch.Tensor:
        """Decode hidden states into logits."""
        ...


class LogitLens(Lens):
    """Unembeds the residual stream into logits."""

    unembed: Unembed

    def __init__(
        self,
        unembed: Unembed,
    ):
        """Create a LogitLens.

        Args:
            unembed: The unembed operation to use.
        """
        super().__init__(unembed)

    @classmethod
    def from_model(
        cls,
        model: PreTrainedModel,
    ) -> "LogitLens":
        """Create a LogitLens from a pretrained model.

        Args:
            model: A pretrained model from the transformers library you wish to inspect.
        """
        unembed = Unembed(model)
        return cls(unembed)

    def forward(self, h: torch.Tensor, idx: int) -> torch.Tensor:
        """Decode a hidden state into logits.

        Args:
            h: The hidden state to decode.
            idx: The layer of the transformer these hidden states come from (unused).
        """
        del idx
        return self.unembed.forward(h)
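

# Minimal smoke test / usage sketch (illustrative, not part of the module API; assumes
# the "gpt2" checkpoint, which supports `output_hidden_states=True` like most causal LMs).
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
    tokenizer = AutoTokenizer.from_pretrained("gpt2")

    lens = LogitLens.from_model(model).to(device)
    reverse_lens = ReverseLogitLens.from_model(model).to(device)

    inputs = tokenizer("The capital of France is", return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)

    # Project the last token's hidden state at every layer back onto the vocabulary.
    for layer_index, hidden in enumerate(outputs.hidden_states):
        h = hidden[:, -1, :]  # (batch, hidden_size)
        logits = lens(h, layer_index)  # layer_index is ignored by LogitLens
        top_token = tokenizer.decode(logits.argmax(dim=-1)[0])
        print(f"layer {layer_index:2d}: {top_token!r}")

    # The reverse lens scores hidden states against the *input* embedding matrix instead.
    scores = reverse_lens(outputs.hidden_states[-1][:, -1, :], 0)
    print("reverse-lens top token:", tokenizer.decode(scores.argmax(dim=-1)))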