kcarnold committed on
Commit
aa3a290
·
verified ·
1 Parent(s): 8b2fdac

Add a version that calls our quick-and-dirty API

Browse files
Files changed (1) hide show
  1. app.py +54 -37
app.py CHANGED
@@ -7,6 +7,7 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
7
  import pandas as pd
8
 
9
  model_options = [
 
10
  'google/gemma-1.1-2b-it',
11
  'google/gemma-1.1-7b-it'
12
  ]
@@ -26,46 +27,62 @@ def get_model(model_name):
26
  print(f"Loaded model, {model.num_parameters():,d} parameters.")
27
  return model
28
 
29
- tokenizer = get_tokenizer(model_name)
30
- model = get_model(model_name)
31
-
32
  prompt = st.text_area("Prompt", "Rewrite this document to be more clear and concise.")
33
  doc = st.text_area("Document", "This is a document that I would like to have rewritten to be more concise.")
34
 
35
 
36
- messages = [
37
- {
38
- "role": "user",
39
- "content": f"{prompt}\n\n{doc}",
40
- },
41
- ]
42
- tokenized_chat = tokenizer.apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt")[0]
43
- assert len(tokenized_chat.shape) == 1
44
-
45
- doc_ids = tokenizer(doc, return_tensors='pt')['input_ids'][0]
46
- joined_ids = torch.cat([tokenized_chat, doc_ids[1:]])
47
-
48
- # Call the model
49
- with torch.no_grad():
50
- logits = model(joined_ids[None].to(model.device)).logits[0].cpu()
51
-
52
- spans = []
53
- length_so_far = 0
54
- for idx in range(len(tokenized_chat), len(joined_ids)):
55
- probs = logits[idx - 1].softmax(dim=-1)
56
- token_id = joined_ids[idx]
57
- token = tokenizer.decode(token_id)
58
- token_loss = -probs[token_id].log().item()
59
- most_likely_token_id = probs.argmax()
60
- print(idx, token, token_loss, tokenizer.decode(most_likely_token_id))
61
- spans.append(dict(
62
- start=length_so_far,
63
- end=length_so_far + len(token),
64
- token=token,
65
- token_loss=token_loss,
66
- most_likely_token=tokenizer.decode(most_likely_token_id)
67
- ))
68
- length_so_far += len(token)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  highest_loss = max(span['token_loss'] for span in spans[1:])
71
  for span in spans:
@@ -79,4 +96,4 @@ html = f"<p style=\"background: white;\">{html}</p>"
79
 
80
  st.subheader("Rewritten document")
81
  st.write(html, unsafe_allow_html=True)
82
- st.write(pd.DataFrame(spans))
 
7
  import pandas as pd
8
 
9
  model_options = [
10
+ 'API',
11
  'google/gemma-1.1-2b-it',
12
  'google/gemma-1.1-7b-it'
13
  ]
 
27
  print(f"Loaded model, {model.num_parameters():,d} parameters.")
28
  return model
29
 
 
 
 
30
  prompt = st.text_area("Prompt", "Rewrite this document to be more clear and concise.")
31
  doc = st.text_area("Document", "This is a document that I would like to have rewritten to be more concise.")
32
 
33
 
34
def get_spans_local(prompt, doc):
    """Score every token of ``doc`` with a locally loaded causal language model.

    The prompt and document are sent as one user chat message. The document's
    own tokens are then appended after the generation prompt, so a single
    forward pass yields the model's next-token distribution at each document
    position.

    Relies on ``model_name``, ``get_tokenizer`` and ``get_model`` from module
    scope.

    Args:
        prompt: Instruction text placed before the document in the chat message.
        doc: Document text whose tokens are scored.

    Returns:
        A list with one dict per scored document token, with keys:
        ``start`` / ``end`` (character offsets into the concatenation of the
        decoded tokens), ``token`` (decoded text), ``token_loss`` (negative
        log-probability of the token given everything before it) and
        ``most_likely_token`` (decoded argmax prediction at that position).
    """
    tokenizer = get_tokenizer(model_name)
    model = get_model(model_name)

    messages = [
        {
            "role": "user",
            "content": f"{prompt}\n\n{doc}",
        },
    ]
    tokenized_chat = tokenizer.apply_chat_template(
        messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
    )[0]
    # Sanity check: after dropping the batch dimension we expect a flat id vector.
    assert len(tokenized_chat.shape) == 1

    doc_ids = tokenizer(doc, return_tensors='pt')['input_ids'][0]
    # The first id of the separately tokenized document is dropped before
    # concatenation -- presumably a BOS token; confirm for other tokenizers.
    joined_ids = torch.cat([tokenized_chat, doc_ids[1:]])

    # Call the model
    with torch.no_grad():
        logits = model(joined_ids[None].to(model.device)).logits[0].cpu()

    spans = []
    length_so_far = 0
    for idx in range(len(tokenized_chat), len(joined_ids)):
        # logits[idx - 1] is the prediction for position idx. log_softmax on
        # float32 avoids the underflow / precision loss of softmax().log(),
        # which matters in particular for half-precision logits.
        log_probs = logits[idx - 1].float().log_softmax(dim=-1)
        token_id = joined_ids[idx]
        token = tokenizer.decode(token_id)
        token_loss = -log_probs[token_id].item()
        most_likely_token = tokenizer.decode(log_probs.argmax())
        print(idx, token, token_loss, most_likely_token)
        spans.append(dict(
            start=length_so_far,
            end=length_so_far + len(token),
            token=token,
            token_loss=token_loss,
            most_likely_token=most_likely_token,
        ))
        length_so_far += len(token)
    return spans
73
+
74
def get_highlights_api(prompt, doc, timeout=60):
    """Fetch highlight spans for ``doc`` from the hosted highlights API.

    ``prompt`` and ``doc`` are sent as query parameters, e.g.
    https://tools.kenarnold.org/api/highlights?prompt=Rewrite%20this%20document&doc=This%20is%20a%20document

    Args:
        prompt: Instruction text sent as the ``prompt`` query parameter.
        doc: Document text sent as the ``doc`` query parameter.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a stalled connection would hang the app indefinitely.

    Returns:
        The value stored under the ``'highlights'`` key of the JSON response
        object (the caller iterates it and reads ``'token_loss'`` from each
        item).

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    # Imported here so the dependency is only needed when the API backend is used.
    import requests

    response = requests.get(
        "https://tools.kenarnold.org/api/highlights",
        params=dict(prompt=prompt, doc=doc),
        timeout=timeout,
    )
    # Fail with a clear HTTP error instead of a confusing JSON/KeyError later on.
    response.raise_for_status()
    return response.json()['highlights']
81
+
82
# Pick the scoring backend: the hosted API when 'API' is selected, otherwise
# the locally loaded model named by model_name.
spans = (
    get_highlights_api(prompt, doc)
    if model_name == 'API'
    else get_spans_local(prompt, doc)
)
86
 
87
  highest_loss = max(span['token_loss'] for span in spans[1:])
88
  for span in spans:
 
96
 
97
  st.subheader("Rewritten document")
98
  st.write(html, unsafe_allow_html=True)
99
+ st.write(pd.DataFrame(spans)[['token', 'token_loss', 'most_likely_token', 'loss_ratio']])