frankaging
minor
29825db
raw
history blame contribute delete
11.3 kB
import os, json, random
import torch
import gradio as gr
import spaces
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from huggingface_hub import login, hf_hub_download
import pyreft
import pyvene as pv
from threading import Thread
from typing import Iterator
import torch.nn.functional as F
# Hugging Face access token taken from the environment; None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Only authenticate when a token is actually configured: calling
# login(token=None) makes huggingface_hub fall back to an interactive prompt,
# which cannot be answered in a headless deployment.
if HF_TOKEN:
    login(token=HF_TOKEN)

# NOTE(review): MAX_MAX_NEW_TOKENS is not referenced in the visible code;
# presumably an upper bound for a max-new-tokens control -- confirm.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 128  # smaller default to save memory
# generate() keeps only the last MAX_INPUT_TOKEN_LENGTH prompt tokens.
MAX_INPUT_TOKEN_LENGTH = 4096

# Styling for the element with id "alert-message" (the notice textbox that
# update_dropdown_choices() shows when no existing topic matches).
css = """
#alert-message textarea {
background-color: #e8f4ff;
border: 1px solid #cce5ff;
color: #084298;
font-size: 1.1em;
padding: 12px;
border-radius: 4px;
font-weight: 500;
}
"""
def load_jsonl(jsonl_path):
    """Read a JSON Lines file and return its records as a list.

    Args:
        jsonl_path: Path of the file to read. Every non-blank line must hold
            exactly one JSON document.

    Returns:
        A list with one decoded object per non-blank line, in file order.
        An empty file yields an empty list.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Decode as UTF-8 explicitly instead of relying on the platform locale.
    with open(jsonl_path, "r", encoding="utf-8") as f:
        # Blank lines (for example a trailing newline at end of file) carry no
        # record, so skip them rather than letting json.loads fail on them.
        return [json.loads(line) for line in f if line.strip()]
class Steer(pv.SourcelessIntervention):
    """Steer model via activation addition.

    The intervention adds ``mag * vector`` to the activations it receives.
    The vector is either produced on the fly by ``subspace_generator`` or
    read from row ``idx`` of ``self.proj.weight``.

    Keyword Args:
        latent_dim: Output size used to build ``self.proj``.
        subspace_generator: Callable invoked as
            ``subspace_generator(input_ids, attention_mask)``; element ``[0]``
            of its result is used as the steering vector.
        **kwargs: All keyword arguments (including the two above) are also
            forwarded to the parent constructor together with
            ``keep_last_dim=True``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs, keep_last_dim=True)
        self.proj = torch.nn.Linear(
            self.embed_dim, kwargs["latent_dim"], bias=False)
        self.subspace_generator = kwargs["subspace_generator"]

    def forward(self, base, source=None, subspaces=None):
        """Return ``base`` with the selected steering vector added.

        Args:
            base: Activations to steer.
            source: Unused.
            subspaces: ``None`` to leave ``base`` untouched, otherwise a
                mapping with ``"mag"`` (scale factor) and either
                ``"subspace_gen_inputs"`` (mapping holding ``"input_ids"`` and
                ``"attention_mask"`` for the generator) or ``"idx"`` (row of
                ``self.proj.weight`` to use).

        Returns:
            ``base`` itself when ``subspaces`` is ``None``; otherwise
            ``base + mag * vector``.
        """
        if subspaces is None:
            return base

        # A missing "subspace_gen_inputs" key is treated like an explicit
        # None: fall back to the stored dictionary of vectors.
        gen_inputs = subspaces.get("subspace_gen_inputs")
        if gen_inputs is not None:
            # we call our subspace generator to generate the subspace on-the-fly.
            direction = self.subspace_generator(
                gen_inputs["input_ids"],
                gen_inputs["attention_mask"],
            )[0]
        else:
            direction = self.proj.weight[subspaces["idx"]]

        steering_vec = torch.tensor(subspaces["mag"]) * \
            direction.unsqueeze(dim=0)
        return base + steering_vec
class RegressionWrapper(torch.nn.Module):
    """Backbone plus a linear head that emits L2-normalised vectors.

    Args:
        base_model: Object exposing a callable ``.model`` attribute that
            accepts ``input_ids``, ``attention_mask``,
            ``output_hidden_states`` and ``return_dict`` and returns an
            object with a ``hidden_states`` sequence.
        hidden_size: Input feature size of the regression head.
        output_dim: Output feature size of the regression head.
    """

    def __init__(self, base_model, hidden_size, output_dim):
        super().__init__()
        self.base_model = base_model
        self.regression_head = torch.nn.Linear(
            in_features=hidden_size, out_features=output_dim)

    def forward(self, input_ids, attention_mask):
        """Return unit-norm predictions computed from the final position.

        The last entry of ``hidden_states`` is indexed at position ``-1``
        along dimension 1 (independently of ``attention_mask``), passed
        through the regression head and L2-normalised over the last dimension.
        """
        backbone_out = self.base_model.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
        )
        final_position = backbone_out.hidden_states[-1][:, -1]
        return F.normalize(self.regression_head(final_position), p=2, dim=-1)
# Check GPU
if not torch.cuda.is_available():
    print("Warning: Running on CPU, may be slow.")

# Load model & dictionary
model_id = "google/gemma-2-2b-it"
# These stay at their empty defaults when no GPU is available, because all
# loading below happens only inside the CUDA branch.
pv_model = None
tokenizer = None
concept_list = []
concept_id_map = {}

if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="cuda", torch_dtype=torch.bfloat16
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    # Download dictionary
    weight_path = hf_hub_download(
        repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/weight.pt")
    meta_path = hf_hub_download(
        repo_id="pyvene/gemma-reft-2b-it-res", filename="l20/metadata.jsonl")
    # NOTE(review): torch.load unpickles the downloaded file; only load
    # checkpoints from repositories you trust.
    params = torch.load(weight_path).cuda()
    md = load_jsonl(meta_path)

    concept_list = [item["concept"] for item in md]
    # the reason to reindex is because there is one concept that is missing.
    # Each concept maps to its position in the metadata file; if a concept
    # name repeats, the last position wins.
    concept_id_map = {item["concept"]: i for i, item in enumerate(md)}

    # load subspace generator.
    base_tokenizer = AutoTokenizer.from_pretrained(
        "google/gemma-2-2b", model_max_length=512)
    config = AutoConfig.from_pretrained("google/gemma-2-2b")
    # Built from the config only; the weights come from the state dict
    # loaded below.
    base_model = AutoModelForCausalLM.from_config(config)
    subspace_generator_weight_path = hf_hub_download(
        repo_id="pyvene/gemma-reft-2b-it-res-generator",
        filename="l20/weight.pt")
    hidden_size = base_model.config.hidden_size
    subspace_generator = RegressionWrapper(
        base_model, hidden_size, hidden_size).bfloat16().to("cuda")
    # NOTE(review): same unpickling caveat as for the dictionary weights.
    subspace_generator.load_state_dict(
        torch.load(subspace_generator_weight_path))
    print(f"Loading model from saved file {subspace_generator_weight_path}")
    _ = subspace_generator.eval()

    steer = Steer(
        embed_dim=params.shape[0], latent_dim=params.shape[1],
        subspace_generator=subspace_generator)
    # The freshly initialised projection weight is replaced wholesale by the
    # downloaded tensor; Steer.forward indexes its rows by concept id.
    # NOTE(review): presumably params is (num_concepts, hidden) -- confirm.
    steer.proj.weight.data = params.float()

    pv_model = pv.IntervenableModel({
        "component": "model.layers[20].output",
        "intervention": steer}, model=model)

terminators = [tokenizer.eos_token_id] if tokenizer else []
@spaces.GPU
def generate(
    message: str,
    chat_history: list[dict],
    subspaces_list: list[dict],
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
) -> Iterator[str]:
    """Stream a steered reply to ``message``.

    Args:
        message: The new user message.
        chat_history: Earlier messages as dicts with ``"role"`` and
            ``"content"`` keys; only the last four entries are used.
        subspaces_list: Steering entries; only the first one is used. It must
            provide ``"idx"``, ``"internal_mag"`` and ``"subspace_gen_text"``.
            An empty list or ``None`` passes ``subspaces=None`` to the model.
        max_new_tokens: Maximum number of tokens to generate.

    Yields:
        A truncation notice first if the prompt had to be trimmed, then the
        accumulated reply text after every streamed chunk.
    """
    # limit to last 4 turns
    recent_history = chat_history[-4:]

    # build list of messages
    messages = [
        {"role": rh["role"], "content": rh["content"]}
        for rh in recent_history
    ]
    messages.append({"role": "user", "content": message})

    input_ids = torch.tensor([tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True)]).cuda()

    # trim if needed
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        yield "[Truncated prior text]\n"

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    print(subspaces_list)

    subspaces = None
    if subspaces_list:
        entry = subspaces_list[0]
        gen_text = entry["subspace_gen_text"]
        gen_inputs = None
        if gen_text is not None:
            gen_inputs = base_tokenizer(
                gen_text, return_tensors="pt").to("cuda")
        subspaces = [
            {
                "idx": int(entry["idx"]),
                # round() instead of int(): the magnitude is a float product
                # (e.g. 2.3 * 50 == 114.99999999999999) and truncation would
                # silently turn that into 114.
                "mag": round(entry["internal_mag"]),
                "subspace_gen_inputs": gen_inputs,
            }
        ]

    generate_kwargs = {
        "base": {"input_ids": input_ids},
        "unit_locations": None,
        "max_new_tokens": max_new_tokens,
        "intervene_on_prompt": True,
        "subspaces": subspaces,
        "streamer": streamer,
        "do_sample": True,
    }

    # Generation runs in a worker thread so the streamer can be consumed here.
    t = Thread(target=pv_model.generate, kwargs=generate_kwargs)
    t.start()

    partial_text = []
    for token_str in streamer:
        partial_text.append(token_str)
        yield "".join(partial_text)

    # The stream is exhausted; wait for the worker to finish before returning.
    t.join()
def filter_concepts(search_text: str):
    """Return up to 500 concepts whose name contains ``search_text``.

    Matching is a case-insensitive substring test. Leading and trailing
    whitespace in the query is ignored, consistently with the blank-query
    check.

    Args:
        search_text: Text typed by the user.

    Returns:
        The first 500 entries of ``concept_list`` when the query is blank,
        otherwise the first 500 matching entries, in ``concept_list`` order.
    """
    # Normalise the query once instead of once per concept.
    query = search_text.strip().lower()
    if not query:
        return concept_list[:500]
    return [c for c in concept_list if query in c.lower()][:500]
def add_concept_to_list(selected_concept, user_slider_val, current_list):
    """Build the steering selection for the chosen concept.

    Args:
        selected_concept: Dropdown value. A value starting with ``"[New] "``
            denotes a free-text topic; anything else is looked up in
            ``concept_id_map``.
        user_slider_val: Steering intensity chosen in the UI.
        current_list: Current selection; returned unchanged when
            ``selected_concept`` is empty or ``None``.

    Returns:
        ``current_list`` when nothing is selected, otherwise a new
        single-element list holding the entry for ``selected_concept``
        (the previous selection is replaced, not extended).

    Raises:
        KeyError: If a non-"[New] " concept is missing from
            ``concept_id_map``.
    """
    if not selected_concept:
        return current_list

    if selected_concept.startswith("[New] "):
        # Free-text topic: keep the text after the 6-character "[New] "
        # marker and use index 0.
        idx, generator_text = 0, selected_concept[6:]
    else:
        idx, generator_text = concept_id_map[selected_concept], None

    entry = dict(
        text=selected_concept,
        idx=idx,
        display_mag=user_slider_val,
        # The slider value is scaled by 50 to obtain the internal magnitude.
        internal_mag=user_slider_val * 50,
        subspace_gen_text=generator_text,
    )
    return [entry]
def update_dropdown_choices(search_text):
    """Compute the dropdown update and notice textbox for a search query.

    Args:
        search_text: Current content of the search box.

    Returns:
        A pair ``(dropdown_update, notice)``. When ``filter_concepts`` finds
        matches, the dropdown lists them with the first one selected and the
        notice is hidden. Otherwise the dropdown offers a single
        ``"[New] <search_text>"`` choice and the notice is shown.
    """
    matches = filter_concepts(search_text)

    if matches:
        # Automatically select the first matching concept
        dropdown_update = gr.update(
            choices=matches,
            value=matches[0],  # Select the first match
            interactive=True, visible=True
        )
        return dropdown_update, gr.Textbox(visible=False)

    # Nothing matched: offer the raw query as a new, generated topic.
    new_choice = f"[New] {search_text}"
    dropdown_update = gr.update(
        choices=[new_choice], value=new_choice, interactive=True)
    notice = gr.Textbox(
        label="No matching existing topics were found!",
        value="Good news! Based on the topic you provided, we will automatically generate a steering vector. Try it out by starting a chat!",
        lines=3,
        interactive=False,
        visible=True,
        elem_id="alert-message"
    )
    return dropdown_update, notice
# NOTE(review): the nesting below is reconstructed from the comments and
# component order; confirm it against the deployed layout.
with gr.Blocks(css=css, fill_height=True) as demo:
    # Steering selection shared between the controls and the chat function;
    # starts empty (no default subspaces).
    selected_subspaces = gr.State([])

    with gr.Row(min_height=300):
        # Left column: the chat itself (larger share of the width).
        with gr.Column(scale=7):
            chat_interface = gr.ChatInterface(
                fn=generate,
                title="Chat with a Topic Steering Model",
                description="""Choose a topic you want the model to discuss on the right →\n\nWe intervene on Gemma-2-2B-it by adding steering vectors to the residual stream at layer 20. You can also try our **conditioned steering** model [here](https://huggingface.co/spaces/pyvene/AxBench-ReFT-cr1-16K).""",
                type="messages",
                additional_inputs=[selected_subspaces],
            )

        # Right column: topic search, selection and intensity.
        with gr.Column(scale=3):
            gr.Markdown("# Steer model responses")
            gr.Markdown("Search and then select a topic you want the model to discuss. The closest match will be automatically selected. If there is no match, a finetuned Gemma-2-2B model auto-steers for you!")

            # Concept search and selection controls.
            with gr.Group():
                search_box = gr.Textbox(
                    label="Search topics to steer",
                    placeholder="Try: 'time travel'",
                    lines=2,
                )
                # Hidden until update_dropdown_choices returns a visible
                # replacement for it.
                msg = gr.TextArea(visible=False)
                concept_dropdown = gr.Dropdown(
                    label="Select a topic to steer the model (Click to see more!)",
                    interactive=True,
                    allow_custom_value=False,
                )
                concept_magnitude = gr.Slider(
                    label="Steering intensity",
                    minimum=-5,
                    maximum=5,
                    step=0.1,
                    value=3,
                )

    # Event wiring.
    # Typing in the search box refreshes the dropdown and notice, then
    # immediately applies the (auto-selected) concept.
    search_box.input(
        fn=update_dropdown_choices,
        inputs=[search_box],
        outputs=[concept_dropdown, msg],
    ).then(
        fn=add_concept_to_list,
        inputs=[concept_dropdown, concept_magnitude, selected_subspaces],
        outputs=[selected_subspaces],
    )

    # Picking a dropdown entry, any change of the dropdown value, and moving
    # the intensity slider all rebuild the selection in the same way.
    for register in (
        concept_dropdown.select,
        concept_dropdown.change,
        concept_magnitude.input,
    ):
        register(
            add_concept_to_list,
            [concept_dropdown, concept_magnitude, selected_subspaces],
            [selected_subspaces],
        )

demo.launch(share=True)