In [1]:
# test probe loading 
import pickle as pkl
import numpy as np
import sklearn 
from sklearn import linear_model
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# load the probe data
with open("./model/20240625-131035_demo.pkl", "rb") as f:
    probe_data = pkl.load(f)
# take the NQ open one
probe_data = probe_data[-2]
probe_data



Found Intel OpenMP ('libiomp') and LLVM OpenMP ('libomp') loaded at
the same time. Both libraries are known to be incompatible and this
can cause random crashes or deadlocks on Linux when loaded in the
same Python program.
Using threadpoolctl may cause crashes or deadlocks. For more
information and possible workarounds, please see
    https://github.com/joblib/threadpoolctl/blob/master/multiple_openmp.md



{'name': 'nq',
 't_bmodel': LogisticRegression(),
 't_amodel': LogisticRegression(),
 'sep_layer_range': (27, 32),
 'ap_layer_range': (17, 22)}

In [2]:
se_probe = probe_data['t_bmodel']
se_layer_range = probe_data['sep_layer_range']
acc_probe = probe_data['t_amodel']
acc_layer_range = probe_data['ap_layer_range']

In [3]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "meta-llama/Llama-2-7b-chat-hf"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.use_default_system_prompt = False

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some parameters are on the meta device device because they were offloaded to the disk.


In [6]:
from typing import Tuple

MAX_INPUT_TOKEN_LENGTH = 512


def generate(
    message: str,
    system_prompt: str,
    max_new_tokens: int = 10,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Tuple[str, str]:
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    input_ids = input_ids.to(model.device)

    #### Generate without threading
    generation_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        output_hidden_states=True,
        return_dict_in_generate=True,
        attention_mask=torch.ones_like(input_ids),
    )
    with torch.no_grad():
        outputs = model.generate(**generation_kwargs)
    generated_tokens = outputs.sequences[0, input_ids.shape[1]:]
    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    print(generated_text)
    # hidden states
    hidden = outputs.hidden_states  # list of tensors, one for each token, then (batch size, sequence length, hidden size)

    se_highlighted_text = ""
    acc_highlighted_text = ""

    # skip the first hidden state as it is the prompt
    for i in range(1, len(hidden)):

        # Semantic Uncertainty Probe
        token_embeddings = torch.stack([generated_token[0, 0, :].cpu() for generated_token in hidden[i]]).numpy()   # (num_layers, hidden_size)
        se_concat_layers = token_embeddings[se_layer_range[0]:se_layer_range[1]].reshape(-1)
        se_probe_pred = se_probe.predict_proba(se_concat_layers.reshape(1, -1))[0][1] * 2 - 1
        
        # Accuracy Probe
        acc_concat_layers = token_embeddings[acc_layer_range[0]:acc_layer_range[1]].reshape(-1)
        acc_probe_pred = (1 - acc_probe.predict_proba(acc_concat_layers.reshape(1, -1))[0][1]) * 2 - 1
        
        output_id = outputs.sequences[0, input_ids.shape[1]+i]
        output_word = tokenizer.decode(output_id)
        print(output_id, output_word, se_probe_pred, acc_probe_pred)  

        se_new_highlighted_text = highlight_text(output_word, se_probe_pred)
        acc_new_highlighted_text = highlight_text(output_word, acc_probe_pred)
        se_highlighted_text += f" {se_new_highlighted_text}"
        acc_highlighted_text += f" {acc_new_highlighted_text}"
        
    return se_highlighted_text, acc_highlighted_text


def highlight_text(text: str, uncertainty_score: float) -> str:
    if uncertainty_score > 0:
        html_color = "#%02X%02X%02X" % (
            255,
            int(255 * (1 - uncertainty_score)),
            int(255 * (1 - uncertainty_score)),
        )
    else:
        html_color = "#%02X%02X%02X" % (
            int(255 * (1 + uncertainty_score)),
            255,
            int(255 * (1 + uncertainty_score)),
        )
    return '<span style="background-color: {}; color: black">{}</span>'.format(
        html_color, text
    )

message = "What is the capital of France?"
system_prompt = ""
se_highlighted_text, acc_highlighted_text = generate(message, system_prompt)
print(se_highlighted_text)
    

Љ ( "ass
ЪЏ
հ MO-OC
tensor(30488, device='mps:0') Љ 1.0 -0.014414779243550946
tensor(313, device='mps:0') ( -0.9998164331881116 0.9597905489862286
tensor(376, device='mps:0') " 0.9999998197256226 -0.9792630307582237
tensor(465, device='mps:0') ass -0.9999994897301452 0.9680999957882863
tensor(13, device='mps:0') 
 -0.99999964561314 0.9983907264450047
tensor(31147, device='mps:0') Ъ 1.0 -0.9999976710226259
tensor(30282, device='mps:0') Џ 1.0 0.9999912572082477
tensor(13, device='mps:0') 
 0.9999999999869607 0.9999964462206883
tensor(31488, device='mps:0') հ 1.0 -1.0
tensor(341, device='mps:0') M 0.9045896738793786 0.5590883316684834
tensor(29949, device='mps:0') O -0.9999999803476437 -0.5270551643185932
tensor(29899, device='mps:0') - 0.9992488974195408 0.9987826119127319
tensor(29949, device='mps:0') O -0.9713693636571169 0.9993573968241007
tensor(29907, device='mps:0') C -0.9999999701427968 0.9904799691607524
 <span style="background-color: #FF0000; color: black">Љ</span> <span style=

In [13]:
from threading import Thread

system_prompt = "You are a helpful assistant."
message = "what is the capital of France?"
max_new_tokens = 100
top_p = 0.9
top_k = 50
temperature = 0.7
repetition_penalty = 1.2

conversation = []

conversation.append({"role": "system", "content": system_prompt})
conversation.append({"role": "user", "content": message})
input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
input_ids = input_ids.to(model.device)
print(input_ids, input_ids.shape)
streamer = TextIteratorStreamer(tokenizer, timeout=1000.0, skip_prompt=True, skip_special_tokens=True)
generation_kwargs = dict(
    input_ids=input_ids,
    max_new_tokens=max_new_tokens,
    do_sample=True,
    top_p=top_p,
    top_k=top_k,
    temperature=temperature,
    repetition_penalty=repetition_penalty,
    streamer=streamer,
    output_hidden_states=True,
    return_dict_in_generate=True,
)

thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()

generated_text = ""
highlighted_text = ""
for output in streamer:
    print(output)
    generated_text += output

    # yield generated_text
for new_text in streamer:
    print(new_text)
    generated_text += new_text
    print(generated_text)
    current_input_ids = tokenizer.encode(generated_text, return_tensors="pt").to(model.device)
    print(current_input_ids, current_input_ids.shape)
    with torch.no_grad():
        outputs = model(current_input_ids, output_hidden_states=True)
        hidden = outputs.hidden_states    
        print(len(hidden))
        print(hidden[-1].shape)
        # Stack second last token embeddings from all layers 
        # if len(hidden) == 1:  # FIX: runtime error for mistral-7b on bioasq
        #     sec_last_input = hidden[0]
        # elif ((n_generated - 2) >= len(hidden)):
        #     sec_last_input = hidden[-2]
        # else:
        #     sec_last_input = hidden[n_generated - 2]
        sec_last_token_embedding = torch.stack([layer[:, -1, :].cpu() for layer in hidden])
        print(sec_last_token_embedding.shape)
    last_hidden_state = hidden[-1][:, -1, :].cpu().numpy()
    print(last_hidden_state.shape)  
    # TODO potentially need to only compute uncertainty for the last token in sentence?


tensor([[    1,   518, 25580, 29962,  3532, 14816, 29903,  6778,    13,  3492,
           526,   263,  8444, 20255, 29889,    13, 29966,   829, 14816, 29903,
          6778,    13,    13,  5816,   338,   278,  7483,   310,  3444, 29973,
           518, 29914, 25580, 29962]]) torch.Size([1, 34])

 


KeyboardInterrupt: 

In [None]:
# concat hidden states

sec_last_token_embedding = np.concatenate(sec_last_token_embedding.cpu().numpy()[layer_range], axis=1)
# predict with probe
pred = probe.predict(hidden_states)
print(pred)