First working version
Browse files- completions.py +4 -4
- frontend/src/components/TokenChip.tsx +16 -1
- frontend/src/components/app.tsx +21 -13
completions.py
CHANGED
@@ -91,13 +91,13 @@ def generate_outputs(model: PreTrainedModel, inputs: BatchEncoding, num_samples:
|
|
91 |
def extract_replacements(outputs: GenerateOutput | torch.LongTensor, tokenizer: Tokenizer, num_inputs: int, input_len: int, num_samples: int = 5) -> list[list[str]]:
|
92 |
all_new_words = []
|
93 |
for i in range(num_inputs):
|
94 |
-
replacements =
|
95 |
for j in range(num_samples):
|
96 |
generated_ids = outputs[i * num_samples + j][input_len:]
|
97 |
new_word = tokenizer.convert_ids_to_tokens(generated_ids.tolist())[0]
|
98 |
if starts_with_space(new_word):
|
99 |
-
replacements.
|
100 |
-
all_new_words.append(replacements)
|
101 |
return all_new_words
|
102 |
|
103 |
#%%
|
@@ -128,7 +128,7 @@ def check_text(input_text: str, model: PreTrainedModel, tokenizer: Tokenizer, de
|
|
128 |
input_ids = inputs["input_ids"]
|
129 |
|
130 |
#%%
|
131 |
-
num_samples =
|
132 |
start_time = time.time()
|
133 |
outputs = generate_outputs(model, inputs, num_samples)
|
134 |
end_time = time.time()
|
|
|
91 |
def extract_replacements(outputs: GenerateOutput | torch.LongTensor, tokenizer: Tokenizer, num_inputs: int, input_len: int, num_samples: int = 5) -> list[list[str]]:
|
92 |
all_new_words = []
|
93 |
for i in range(num_inputs):
|
94 |
+
replacements = set()
|
95 |
for j in range(num_samples):
|
96 |
generated_ids = outputs[i * num_samples + j][input_len:]
|
97 |
new_word = tokenizer.convert_ids_to_tokens(generated_ids.tolist())[0]
|
98 |
if starts_with_space(new_word):
|
99 |
+
replacements.add(" " +new_word[1:])
|
100 |
+
all_new_words.append(sorted(list(replacements)))
|
101 |
return all_new_words
|
102 |
|
103 |
#%%
|
|
|
128 |
input_ids = inputs["input_ids"]
|
129 |
|
130 |
#%%
|
131 |
+
num_samples = 10
|
132 |
start_time = time.time()
|
133 |
outputs = generate_outputs(model, inputs, num_samples)
|
134 |
end_time = time.time()
|
frontend/src/components/TokenChip.tsx
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
import React, { useState } from "react"
|
2 |
|
3 |
export const TokenChip = ({
|
4 |
token,
|
@@ -14,6 +14,20 @@ export const TokenChip = ({
|
|
14 |
onReplace: (newToken: string) => void
|
15 |
}) => {
|
16 |
const [isExpanded, setIsExpanded] = useState(false);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
const handleClick = () => {
|
19 |
if (logprob < threshold && replacements.length > 0) {
|
@@ -39,6 +53,7 @@ export const TokenChip = ({
|
|
39 |
{token}
|
40 |
{isExpanded && (
|
41 |
<select
|
|
|
42 |
onChange={handleReplacement}
|
43 |
value={token}
|
44 |
style={{
|
|
|
1 |
+
import React, { useState, useEffect, useRef } from "react"
|
2 |
|
3 |
export const TokenChip = ({
|
4 |
token,
|
|
|
14 |
onReplace: (newToken: string) => void
|
15 |
}) => {
|
16 |
const [isExpanded, setIsExpanded] = useState(false);
|
17 |
+
const dropdownRef = useRef<HTMLSelectElement>(null);
|
18 |
+
|
19 |
+
useEffect(() => {
|
20 |
+
const handleClickOutside = (event: MouseEvent) => {
|
21 |
+
if (dropdownRef.current && !dropdownRef.current.contains(event.target as Node)) {
|
22 |
+
setIsExpanded(false);
|
23 |
+
}
|
24 |
+
};
|
25 |
+
|
26 |
+
document.addEventListener("mousedown", handleClickOutside);
|
27 |
+
return () => {
|
28 |
+
document.removeEventListener("mousedown", handleClickOutside);
|
29 |
+
};
|
30 |
+
}, []);
|
31 |
|
32 |
const handleClick = () => {
|
33 |
if (logprob < threshold && replacements.length > 0) {
|
|
|
53 |
{token}
|
54 |
{isExpanded && (
|
55 |
<select
|
56 |
+
ref={dropdownRef}
|
57 |
onChange={handleReplacement}
|
58 |
value={token}
|
59 |
style={{
|
frontend/src/components/app.tsx
CHANGED
@@ -37,32 +37,40 @@ export default function App() {
|
|
37 |
const [context, setContext] = useState("")
|
38 |
const [wordlist, setWordlist] = useState("")
|
39 |
const [showWholePrompt, setShowWholePrompt] = useState(false)
|
40 |
-
const [text, setText] = useState("
|
41 |
const [mode, setMode] = useState<"edit" | "check">("edit")
|
42 |
const [words, setWords] = useState<Word[]>([])
|
43 |
const [isLoading, setIsLoading] = useState(false)
|
44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
const toggleMode = async () => {
|
46 |
if (mode === "edit") {
|
47 |
setIsLoading(true)
|
48 |
-
|
49 |
-
const checkedWords = await checkText(text)
|
50 |
-
setWords(checkedWords)
|
51 |
-
} finally {
|
52 |
-
setMode("check")
|
53 |
-
setIsLoading(false)
|
54 |
-
}
|
55 |
} else {
|
56 |
setMode("edit")
|
57 |
}
|
58 |
}
|
59 |
|
60 |
-
const handleReplace = (index: number, newToken: string) => {
|
61 |
const updatedWords = [...words]
|
62 |
updatedWords[index].text = newToken
|
|
|
|
|
63 |
setWords(updatedWords)
|
64 |
setText(updatedWords.map(w => w.text).join(""))
|
65 |
-
setMode("edit")
|
|
|
66 |
}
|
67 |
|
68 |
let result
|
@@ -101,7 +109,7 @@ export default function App() {
|
|
101 |
<details>
|
102 |
<summary>Advanced settings</summary>
|
103 |
<label>
|
104 |
-
<strong>Threshold:</strong> <input type="number" step="
|
105 |
<small>
|
106 |
The <a href="https://en.wikipedia.org/wiki/Log_probability" target="_blank" rel="noreferrer">logprob</a> threshold.
|
107 |
Tokens with logprobs smaller than this will be marked red.
|
@@ -132,8 +140,8 @@ export default function App() {
|
|
132 |
|
133 |
<p>
|
134 |
<small>
|
135 |
-
|
136 |
-
Made with
|
137 |
<br />
|
138 |
This software is provided with absolutely no warranty.
|
139 |
</small>
|
|
|
37 |
const [context, setContext] = useState("")
|
38 |
const [wordlist, setWordlist] = useState("")
|
39 |
const [showWholePrompt, setShowWholePrompt] = useState(false)
|
40 |
+
const [text, setText] = useState("I just drove to the store to but eggs, but they had some.")
|
41 |
const [mode, setMode] = useState<"edit" | "check">("edit")
|
42 |
const [words, setWords] = useState<Word[]>([])
|
43 |
const [isLoading, setIsLoading] = useState(false)
|
44 |
|
45 |
+
const check = async () => {
|
46 |
+
setIsLoading(true)
|
47 |
+
try {
|
48 |
+
const checkedWords = await checkText(text)
|
49 |
+
setWords(checkedWords)
|
50 |
+
} finally {
|
51 |
+
setIsLoading(false)
|
52 |
+
setMode("check")
|
53 |
+
}
|
54 |
+
}
|
55 |
+
|
56 |
const toggleMode = async () => {
|
57 |
if (mode === "edit") {
|
58 |
setIsLoading(true)
|
59 |
+
await check()
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
} else {
|
61 |
setMode("edit")
|
62 |
}
|
63 |
}
|
64 |
|
65 |
+
const handleReplace = async (index: number, newToken: string) => {
|
66 |
const updatedWords = [...words]
|
67 |
updatedWords[index].text = newToken
|
68 |
+
updatedWords[index].logprob = 0
|
69 |
+
updatedWords[index].replacements = []
|
70 |
setWords(updatedWords)
|
71 |
setText(updatedWords.map(w => w.text).join(""))
|
72 |
+
// setMode("edit")
|
73 |
+
await check()
|
74 |
}
|
75 |
|
76 |
let result
|
|
|
109 |
<details>
|
110 |
<summary>Advanced settings</summary>
|
111 |
<label>
|
112 |
+
<strong>Threshold:</strong> <input type="number" step="1" value={threshold} onChange={e => setThreshold(Number(e.target.value))} />
|
113 |
<small>
|
114 |
The <a href="https://en.wikipedia.org/wiki/Log_probability" target="_blank" rel="noreferrer">logprob</a> threshold.
|
115 |
Tokens with logprobs smaller than this will be marked red.
|
|
|
140 |
|
141 |
<p>
|
142 |
<small>
|
143 |
+
Based on <a href="https://github.com/vgel/gpted">GPTed</a> by <a href="https://vgel.me">Theia Vogel</a>.
|
144 |
+
Made with React, Transformers, LLama 3.2, and transitively, most of the web.
|
145 |
<br />
|
146 |
This software is provided with absolutely no warranty.
|
147 |
</small>
|