gpt-99 commited on
Commit
46e1acd
·
verified ·
1 Parent(s): 0c97850

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -0
app.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ from einops import einsum
5
+ from tqdm import tqdm
6
+
7
+ device = "cuda" if torch.cuda.is_available() else "cpu"
8
+ model_name = 'microsoft/Phi-3-mini-4k-instruct'
9
+
10
+ model = AutoModelForCausalLM.from_pretrained(
11
+ model_name,
12
+ device_map=device,
13
+ torch_dtype="auto",
14
+ trust_remote_code=True,
15
+ )
16
+
17
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
18
+
19
+ def tokenize_instructions(tokenizer, instructions):
20
+ return tokenizer.apply_chat_template(
21
+ instructions,
22
+ padding=True,
23
+ truncation=False,
24
+ return_tensors="pt",
25
+ return_dict=True,
26
+ add_generation_prompt=True,
27
+ ).input_ids
28
+
29
+ def find_steering_vecs(model, base_toks, target_toks, batch_size=16):
30
+ device = model.device
31
+ num_its = len(range(0, base_toks.shape[0], batch_size))
32
+ steering_vecs = {}
33
+ for i in tqdm(range(0, base_toks.shape[0], batch_size)):
34
+ base_out = model(base_toks[i:i+batch_size].to(device), output_hidden_states=True).hidden_states
35
+ target_out = model(target_toks[i:i+batch_size].to(device), output_hidden_states=True).hidden_states
36
+ for layer in range(len(base_out)):
37
+ if i == 0:
38
+ steering_vecs[layer] = torch.mean(target_out[layer][:,-1,:].detach().cpu() - base_out[layer][:,-1,:].detach().cpu(), dim=0)/num_its
39
+ else:
40
+ steering_vecs[layer] += torch.mean(target_out[layer][:,-1,:].detach().cpu() - base_out[layer][:,-1,:].detach().cpu(), dim=0)/num_its
41
+ return steering_vecs
42
+
43
+ def do_steering(model, test_toks, steering_vec, scale=1, normalise=True, layer=None, proj=True, batch_size=16):
44
+ def modify_activation():
45
+ def hook(model, input):
46
+ if normalise:
47
+ sv = steering_vec / steering_vec.norm()
48
+ else:
49
+ sv = steering_vec
50
+ if proj:
51
+ sv = einsum(input[0], sv.view(-1,1), 'b l h, h s -> b l s') * sv
52
+ input[0][:,:,:] = input[0][:,:,:] - scale * sv
53
+ return hook
54
+
55
+ handles = []
56
+ if steering_vec is not None:
57
+ for i in range(len(model.model.layers)):
58
+ if layer is None or i == layer:
59
+ handles.append(model.model.layers[i].register_forward_pre_hook(modify_activation()))
60
+
61
+ outs_all = []
62
+ for i in tqdm(range(0, test_toks.shape[0], batch_size)):
63
+ outs = model.generate(test_toks[i:i+batch_size], num_beams=4, do_sample=True, max_new_tokens=60)
64
+ outs_all.append(outs)
65
+ outs_all = torch.cat(outs_all, dim=0)
66
+
67
+ for handle in handles:
68
+ handle.remove()
69
+
70
+ return outs_all
71
+
72
+ def create_steering_vector(towards, away):
73
+ towards_data = [[{"role": "user", "content": text.strip()}] for text in towards.split(',')]
74
+ away_data = [[{"role": "user", "content": text.strip()}] for text in away.split(',')]
75
+
76
+ towards_toks = tokenize_instructions(tokenizer, towards_data)
77
+ away_toks = tokenize_instructions(tokenizer, away_data)
78
+
79
+ steering_vecs = find_steering_vecs(model, away_toks, towards_toks)
80
+ return steering_vecs
81
+
82
+ def chat(message, history, steering_vec, layer):
83
+ history_formatted = [{"role": "user" if i % 2 == 0 else "assistant", "content": msg} for i, msg in enumerate(history)]
84
+ history_formatted.append({"role": "user", "content": message})
85
+
86
+ input_ids = tokenize_instructions(tokenizer, [history_formatted])
87
+
88
+ generations_baseline = do_steering(model, input_ids.to(device), None)
89
+ for j in range(generations_baseline.shape[0]):
90
+ response_baseline = f"BASELINE: {tokenizer.decode(generations_baseline[j], skip_special_tokens=True, layer=layer)}"
91
+
92
+ if steering_vec is not None:
93
+ generation_intervene = do_steering(model, input_ids.to(device), steering_vec[layer].to(device), scale=1)
94
+ for j in range(generation_intervene.shape[0]):
95
+ response_intervention = f"INTERVENTION: {tokenizer.decode(generation_intervene[j], skip_special_tokens=True)}"
96
+
97
+ response = response_baseline + "\n\n" + response_intervention
98
+
99
+ return [(message, response)]
100
+
101
+ def launch_app():
102
+ with gr.Blocks() as demo:
103
+ steering_vec = gr.State(None)
104
+ layer = gr.State(None)
105
+
106
+ away_default = ['hate','i hate this', 'hating the', 'hater', 'hating', 'hated in']
107
+
108
+ towards_default = ['love','i love this', 'loving the', 'lover', 'loving', 'loved in']
109
+
110
+ with gr.Row():
111
+ towards = gr.Textbox(label="Towards (comma-separated)", value= ", ".join(sentence.replace(",", "") for sentence in towards_default))
112
+ away = gr.Textbox(label="Away from (comma-separated)", value= ", ".join(sentence.replace(",", "") for sentence in away_default))
113
+
114
+ with gr.Row():
115
+ create_vector = gr.Button("Create Steering Vector")
116
+ layer_slider = gr.Slider(minimum=0, maximum=len(model.model.layers)-1, step=1, label="Layer", value=0)
117
+
118
+ def create_vector_and_set_layer(towards, away, layer_value):
119
+ vectors = create_steering_vector(towards, away)
120
+ layer.value = int(layer_value)
121
+ steering_vec.value = vectors
122
+ return f"Steering vector created for layer {layer_value}"
123
+ create_vector.click(create_vector_and_set_layer, [towards, away, layer_slider], gr.Textbox())
124
+
125
+ chatbot = gr.Chatbot()
126
+ msg = gr.Textbox()
127
+
128
+ msg.submit(chat, [msg, chatbot, steering_vec, layer], chatbot)
129
+
130
+ demo.launch()
131
+
132
+ if __name__ == "__main__":
133
+ launch_app()
134
+
135
+
136
+ # clean up
137
+ # nicer baseline vs intervention
138
+ # auto clear after messgae