Commit bbae7a9 by mebubo
1 Parent(s): c4d5641

First working version

completions.py CHANGED
@@ -91,13 +91,13 @@ def generate_outputs(model: PreTrainedModel, inputs: BatchEncoding, num_samples:
 def extract_replacements(outputs: GenerateOutput | torch.LongTensor, tokenizer: Tokenizer, num_inputs: int, input_len: int, num_samples: int = 5) -> list[list[str]]:
     all_new_words = []
     for i in range(num_inputs):
-        replacements = []
+        replacements = set()
         for j in range(num_samples):
             generated_ids = outputs[i * num_samples + j][input_len:]
             new_word = tokenizer.convert_ids_to_tokens(generated_ids.tolist())[0]
             if starts_with_space(new_word):
-                replacements.append(new_word[1:])
-        all_new_words.append(replacements)
+                replacements.add(" " +new_word[1:])
+        all_new_words.append(sorted(list(replacements)))
     return all_new_words

 #%%
@@ -128,7 +128,7 @@ def check_text(input_text: str, model: PreTrainedModel, tokenizer: Tokenizer, de
 input_ids = inputs["input_ids"]

 #%%
-num_samples = 5
+num_samples = 10
 start_time = time.time()
 outputs = generate_outputs(model, inputs, num_samples)
 end_time = time.time()

frontend/src/components/TokenChip.tsx CHANGED
@@ -1,4 +1,4 @@
-import React, { useState } from "react"
+import React, { useState, useEffect, useRef } from "react"

 export const TokenChip = ({
   token,
@@ -14,6 +14,20 @@ export const TokenChip = ({
   onReplace: (newToken: string) => void
 }) => {
   const [isExpanded, setIsExpanded] = useState(false);
+  const dropdownRef = useRef<HTMLSelectElement>(null);
+
+  useEffect(() => {
+    const handleClickOutside = (event: MouseEvent) => {
+      if (dropdownRef.current && !dropdownRef.current.contains(event.target as Node)) {
+        setIsExpanded(false);
+      }
+    };
+
+    document.addEventListener("mousedown", handleClickOutside);
+    return () => {
+      document.removeEventListener("mousedown", handleClickOutside);
+    };
+  }, []);

   const handleClick = () => {
     if (logprob < threshold && replacements.length > 0) {
@@ -39,6 +53,7 @@ export const TokenChip = ({
       {token}
       {isExpanded && (
         <select
+          ref={dropdownRef}
          onChange={handleReplacement}
          value={token}
          style={{
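
The dropdown collapse added above uses the standard document-level "mousedown" listener pattern. As a minimal sketch (not part of this commit; the hook name and signature below are assumptions for illustration), the same logic could be pulled into a reusable hook:

```tsx
// Hypothetical sketch only -- not part of this commit.
// A reusable click-outside hook with the same behavior as the inline useEffect above.
import { useEffect, RefObject } from "react"

export function useClickOutside(ref: RefObject<HTMLElement>, onOutside: () => void) {
  useEffect(() => {
    const handleMouseDown = (event: MouseEvent) => {
      // Only fire when the press lands outside the referenced element.
      if (ref.current && !ref.current.contains(event.target as Node)) {
        onOutside()
      }
    }
    document.addEventListener("mousedown", handleMouseDown)
    return () => document.removeEventListener("mousedown", handleMouseDown)
  }, [ref, onOutside])
}
```

With such a hook, the component body would reduce to `useClickOutside(dropdownRef, () => setIsExpanded(false))`.
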
frontend/src/components/app.tsx CHANGED
@@ -37,32 +37,40 @@ export default function App() {
   const [context, setContext] = useState("")
   const [wordlist, setWordlist] = useState("")
   const [showWholePrompt, setShowWholePrompt] = useState(false)
-  const [text, setText] = useState("The cumbersomely quick brown fox jumps over the conspicuously lazy dog.")
+  const [text, setText] = useState("I just drove to the store to but eggs, but they had some.")
   const [mode, setMode] = useState<"edit" | "check">("edit")
   const [words, setWords] = useState<Word[]>([])
   const [isLoading, setIsLoading] = useState(false)

+  const check = async () => {
+    setIsLoading(true)
+    try {
+      const checkedWords = await checkText(text)
+      setWords(checkedWords)
+    } finally {
+      setIsLoading(false)
+      setMode("check")
+    }
+  }
+
   const toggleMode = async () => {
     if (mode === "edit") {
       setIsLoading(true)
-      try {
-        const checkedWords = await checkText(text)
-        setWords(checkedWords)
-      } finally {
-        setMode("check")
-        setIsLoading(false)
-      }
+      await check()
     } else {
       setMode("edit")
     }
   }

-  const handleReplace = (index: number, newToken: string) => {
+  const handleReplace = async (index: number, newToken: string) => {
     const updatedWords = [...words]
     updatedWords[index].text = newToken
+    updatedWords[index].logprob = 0
+    updatedWords[index].replacements = []
     setWords(updatedWords)
     setText(updatedWords.map(w => w.text).join(""))
-    setMode("edit")
+    // setMode("edit")
+    await check()
   }

   let result
@@ -101,7 +109,7 @@ export default function App() {
         <details>
           <summary>Advanced settings</summary>
           <label>
-            <strong>Threshold:</strong> <input type="number" step="0.1" value={threshold} onChange={e => setThreshold(Number(e.target.value))} />
+            <strong>Threshold:</strong> <input type="number" step="1" value={threshold} onChange={e => setThreshold(Number(e.target.value))} />
             <small>
               The <a href="https://en.wikipedia.org/wiki/Log_probability" target="_blank" rel="noreferrer">logprob</a> threshold.
               Tokens with logprobs smaller than this will be marked red.
@@ -132,8 +140,8 @@ export default function App() {

       <p>
         <small>
-          Made by <a href="https://vgel.me">Theia Vogel</a> (<a href="https://twitter.com/voooooogel">@voooooogel</a>).
-          Made with Svelte, GPT-3, and transitively, most of the web.
+          Based on <a href="https://github.com/vgel/gpted">GPTed</a> by <a href="https://vgel.me">Theia Vogel</a>.
+          Made with React, Transformers, LLama 3.2, and transitively, most of the web.
           <br />
          This software is provided with absolutely no warranty.
         </small>
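
handleReplace now zeroes the replaced word's logprob and clears its replacements before re-running check(), so an accepted suggestion is not re-flagged while the new check is in flight. The Word shape and checkText contract this relies on are not shown in this diff; the sketch below is an assumption inferred from how the fields are used here and in TokenChip.tsx:

```tsx
// Assumed shapes only -- inferred from usage in this diff, not copied from the repo.
interface Word {
  text: string           // the word as it appears in the text (words are joined with "" to rebuild the text)
  logprob: number        // log-probability from the model; values below the threshold get flagged
  replacements: string[] // alternative words offered for a flagged token
}

// Assumed: checkText sends the text to the backend and resolves to one Word per token.
declare function checkText(text: string): Promise<Word[]>
```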