import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import DynamicCache
USE_GPU = torch.cuda.is_available()
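
# Cache the model across Streamlit reruns so the weights are only loaded once per process.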
@st.cache_resource
def load_model():
    model_name = 'google/gemma-2-9b-it'
    dtype = torch.bfloat16 if USE_GPU else torch.float16
    llm = {
        'tokenizer': AutoTokenizer.from_pretrained(model_name),
        'model': AutoModelForCausalLM.from_pretrained(
            model_name,
            device_map="auto" if USE_GPU else "cpu",
            torch_dtype=dtype,
            attn_implementation='eager'
        )
    }
    llm['model'].eval()
    return llm
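
# Main UI: a chat-style editor where each next-token suggestion is a button the user can click
# to extend the message in progress (for either role).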
def type_assistant_response():
    if 'messages' not in st.session_state or st.button("Start a new conversation"):
        st.session_state['messages'] = [{"role": "user", "content": ""}]
        st.session_state['msg_in_progress'] = ""
    messages = st.session_state.messages

    def rewind_to(i):
        st.session_state.messages = st.session_state.messages[:i + 1]
        st.session_state['msg_in_progress'] = st.session_state.messages[-1]['content']

    # Show all completed messages, each with an "Edit" button that rewinds the conversation to that point.
    for i, message in enumerate(st.session_state.messages[:-1]):
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
            st.button("Edit", on_click=rewind_to, args=(i,), key=f"rewind_to_{i}")

    # Display the message-in-progress in a chat message container.
    last_role = messages[-1]["role"]
    with st.chat_message(last_role):
        label = "Your message" if last_role == "user" else "Assistant response"
        msg_in_progress = st.text_area(
            label,
            placeholder="Clicking the buttons below will update this field. You can also edit it directly; press Ctrl+Enter to apply changes.",
            height=300,
            key="msg_in_progress",
        )
        if msg_in_progress is None:
            msg_in_progress = ""
        messages[-1]['content'] = msg_in_progress

    def append_token(word):
        messages[-1]['content'] = st.session_state['msg_in_progress'] = (
            msg_in_progress + word
        )

    allow_multi_word = st.checkbox("Allow multi-word predictions", value=False)

    response = continue_messages(
        messages=messages,
        n_branch_tokens=5,
        n_future_tokens=2
    )
    continuations = response['continuations']

    # One button per candidate continuation; clicking it appends that text to the message in progress.
    for i, (col, continuation) in enumerate(zip(st.columns(len(continuations)), continuations)):
        token = continuation['doc_text']
        with col:
            if not allow_multi_word and ' ' in token[1:]:
                # Keep only the first word (preserving any leading space).
                token = token[0] + token[1:].split(' ', 1)[0]
            token_display = show_token(token)
            st.button(token_display, on_click=append_token, args=(token,), key=i, use_container_width=True)

    def send_message():
        other_role = "assistant" if last_role == "user" else "user"
        st.session_state['messages'].append({"role": other_role, "content": ""})
        st.session_state['msg_in_progress'] = ""
    st.button("Send", on_click=send_message)
def show_token(token: str, escape_markdown=True) -> str:
    token_display = token.replace('\n', '↵').replace('\t', '⇥')
    if escape_markdown:
        for c in "\\`*_{}[]()#+-.!":
            token_display = token_display.replace(c, "\\" + c)
    return token_display
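
# Thin wrapper around the model call: normalizes the message dicts and packages the decoded continuations.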
def continue_messages(messages, n_branch_tokens, n_future_tokens):
    # The session-state messages are plain dicts, so index them rather than using attribute access.
    messages = [{"role": m["role"], "content": m["content"]} for m in messages]
    if len(messages) == 0:
        raise ValueError("At least one message must be provided.")

    llm = load_model()
    model = llm['model']
    tokenizer = llm['tokenizer']

    generated_docs = continue_messages_inner(model, tokenizer, messages, n_branch_tokens, n_future_tokens)

    return {
        'continuations': [dict(doc_text=doc) for doc in generated_docs]
    }
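
# Two forward passes: the first runs the full prompt, fills a KV cache, and picks the top-k next tokens;
# the second reuses the replicated cache to look one token further ahead for each branch.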
def get_lookahead_sequences(model, tokenizer, hypotheses, n_branch_tokens, device):
    """
    For each of the n_branch_tokens most likely next tokens, generate the single
    most likely token that follows it, and return the resulting 2-token sequences.
    """
    assert len(hypotheses.shape) == 2
    assert hypotheses.shape[0] == 1
    n_tokens_so_far = hypotheses.shape[1]
    past_key_values = DynamicCache()

    with torch.no_grad():
        model_outs_onestep = model(hypotheses, output_hidden_states=True, past_key_values=past_key_values)

    branch_tokens = model_outs_onestep.logits[0, -1].topk(n_branch_tokens).indices

    # Replicate the KV cache n_branch_tokens times, as if we were doing a beam search.
    past_key_values.reorder_cache(torch.zeros((n_branch_tokens,), dtype=torch.long, device=device))

    # Now call the model again, passing the KV cache, so we can continue generating.
    # Each of the n_branch_tokens next tokens is treated as one sequence in a "batch".
    next_tokens_as_batch = branch_tokens.unsqueeze(1)
    assert next_tokens_as_batch.shape == (n_branch_tokens, 1)

    position_id_for_final_token = n_tokens_so_far
    cache_position = torch.full((1,), position_id_for_final_token, dtype=torch.long, device=device)
    with torch.no_grad():
        model_outs = model(
            next_tokens_as_batch,
            past_key_values=past_key_values,
            output_hidden_states=True,
            use_cache=True,
            # The cache doesn't track the position of the new token, so pass it explicitly.
            cache_position=cache_position
        )

    # Grab the single most likely token from each of the n_branch_tokens sequences.
    next_token_logits = model_outs.logits[:, -1]
    vocab_size = model.config.vocab_size
    assert next_token_logits.shape == (n_branch_tokens, vocab_size), f"{next_token_logits.shape=}, {n_branch_tokens=}, {vocab_size=}"
    most_likely_token_ids = next_token_logits.argmax(dim=-1)

    # Stick them onto the end of the branch tokens.
    assert most_likely_token_ids.shape == (n_branch_tokens,)
    lookahead_sequences = torch.cat([
        branch_tokens.unsqueeze(1),
        most_likely_token_ids.unsqueeze(1)
    ], dim=1)
    assert lookahead_sequences.shape == (n_branch_tokens, 2)
    return lookahead_sequences, next_token_logits
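
# Applies the chat template (continuing the final message rather than starting a new turn),
# runs the lookahead, and decodes each two-token continuation back to text.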
def continue_messages_inner(model, tokenizer, messages, n_branch_tokens, n_future_tokens):
    # Note: we're ignoring n_future_tokens right now since the old implementation was buggy.
    device = model.device

    tokenized_chat = tokenizer.apply_chat_template(
        messages, tokenize=True, return_tensors="pt", continue_final_message=True
    ).to(model.device)
    print(tokenizer.batch_decode(tokenized_chat, skip_special_tokens=False))

    lookahead_sequences, next_token_logits = get_lookahead_sequences(
        model, tokenizer, tokenized_chat, n_branch_tokens, device)

    generated_docs = tokenizer.batch_decode(lookahead_sequences, skip_special_tokens=True)
    return generated_docs
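
# Streamlit reruns this script top to bottom on every interaction; this call builds the whole page.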
type_assistant_response()