mebubo committed
Commit c4d5641 · 1 Parent(s): 83ec4f2
completions.py CHANGED
@@ -120,10 +120,10 @@ def check_text(input_text: str, model: PreTrainedModel, tokenizer: Tokenizer, de
     #%%
     words = split_into_words(token_probs, tokenizer)
     log_prob_threshold = -5.0
-    low_prob_words = [word for word in words if word.logprob < log_prob_threshold]
+    low_prob_words = [(i, word) for i, word in enumerate(words) if word.logprob < log_prob_threshold]

     #%%
-    contexts = [word.context for word in low_prob_words]
+    contexts = [word.context for _, word in low_prob_words]
     inputs = prepare_inputs(contexts, tokenizer, device)
     input_ids = inputs["input_ids"]

@@ -137,7 +137,12 @@ def check_text(input_text: str, model: PreTrainedModel, tokenizer: Tokenizer, de
     #%%
     replacements = extract_replacements(outputs, tokenizer, input_ids.shape[0], input_ids.shape[1], num_samples)

-    #%%
-    for word, replacements in zip(low_prob_words, replacements):
-        print(f"Original word: {word.text}, Log Probability: {word.logprob:.4f}")
-        print(f"Proposed replacements: {replacements}")
+    low_prob_words_with_replacements = { i: (w, r) for (i, w), r in zip(low_prob_words, replacements) }
+
+    result = []
+    for i, word in enumerate(words):
+        if i in low_prob_words_with_replacements:
+            result.append(ApiWord(text=word.text, logprob=word.logprob, replacements=low_prob_words_with_replacements[i][1]))
+        else:
+            result.append(ApiWord(text=word.text, logprob=word.logprob, replacements=[]))
+    return result
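The rewritten tail of check_text replaces the debug prints with an index-preserving merge: every word goes into the response in its original position, and only the flagged indices carry replacement proposals. A minimal sketch of that merge with stand-in data in place of the real model outputs (the Word fields here are a trimmed-down version of models.py below):

    from dataclasses import dataclass

    @dataclass
    class Word:
        text: str
        logprob: float

    # Stand-in data: two of four words fall below the threshold.
    words = [Word("the", -0.1), Word("quick", -6.2), Word("fox", -0.3), Word("jumsp", -8.9)]
    log_prob_threshold = -5.0

    low_prob_words = [(i, w) for i, w in enumerate(words) if w.logprob < log_prob_threshold]
    replacements = [["slow", "brown"], ["jumps", "leaps"]]  # one proposal list per flagged word

    # Key the proposals by original position so unflagged words keep their place.
    by_index = {i: r for (i, _), r in zip(low_prob_words, replacements)}
    result = [(w.text, by_index.get(i, [])) for i, w in enumerate(words)]
    # [('the', []), ('quick', ['slow', 'brown']), ('fox', []), ('jumsp', ['jumps', 'leaps'])]

Carrying the enumerate index through low_prob_words is what lets the final loop reassemble the full word list in order, rather than returning only the flagged subset.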
frontend/src/components/TokenChip.tsx CHANGED
@@ -1,7 +1,5 @@
 import React, { useState } from "react"

-import React, { useState } from "react"
-
 export const TokenChip = ({
   token,
   logprob,
frontend/src/components/app.tsx CHANGED
@@ -8,6 +8,13 @@ interface Word {
 }

 async function checkText(text: string): Promise<Word[]> {
+  const response = await fetch(`/check?text=${text}`)
+  const data = await response.json()
+  console.log(data)
+  return data.words
+}
+
+async function checkText0(text: string): Promise<Word[]> {
   await new Promise(resolve => setTimeout(resolve, 1000));

   const words = text.split(/\b/)
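One caveat in the new checkText: the template literal interpolates text into the query string unescaped, so input containing characters like & or # would be cut short before it reaches the backend. A client would normally percent-encode the value first (encodeURIComponent(text) in the browser); a small Python illustration of why, using a hypothetical input not taken from this commit:

    from urllib.parse import quote

    text = "rock & roll #1"
    print(f"/check?text={text}")         # /check?text=rock & roll #1   (query breaks at '&' and '#')
    print(f"/check?text={quote(text)}")  # /check?text=rock%20%26%20roll%20%231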
main.py CHANGED
@@ -1,8 +1,7 @@
 from fastapi import FastAPI
 from fastapi.staticfiles import StaticFiles
-from pydantic import BaseModel
 
-from models import CheckResponse, ApiWord
+from models import CheckResponse
 from completions import check_text, load_model

 app = FastAPI()
@@ -13,5 +12,4 @@ model, tokenizer, device = load_model()
 def check(text: str):
     return CheckResponse(text=text, words=check_text(text, model, tokenizer, device))

-# serve files from frontend/public
 app.mount("/", StaticFiles(directory="frontend/public", html=True))
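A quick way to exercise this endpoint without the frontend is FastAPI's TestClient. A minimal sketch, assuming the route is a GET on /check (which is how the frontend calls it) and that load_model() can run at import time:

    from fastapi.testclient import TestClient

    from main import app

    client = TestClient(app)
    response = client.get("/check", params={"text": "The quick brown fox"})
    print(response.status_code)      # 200
    print(response.json()["words"])  # list of {text, logprob, replacements} objects

Note that params= handles the percent-encoding that the frontend's raw template literal skips.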
models.py ADDED
@@ -0,0 +1,19 @@
+from dataclasses import dataclass
+
+from pydantic import BaseModel
+
+@dataclass
+class Word:
+    tokens: list[int]
+    text: str
+    logprob: float
+    context: list[int]
+
+class ApiWord(BaseModel):
+    text: str
+    logprob: float
+    replacements: list[str]
+
+class CheckResponse(BaseModel):
+    text: str
+    words: list[ApiWord]
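These models define the JSON shape that the frontend's data.words consumes: the internal Word dataclass keeps token ids and context on the backend, while ApiWord exposes only text, logprob, and replacements over the wire. A minimal sketch of the serialization, assuming Pydantic v2 (on v1 it would be .json() instead of .model_dump_json()):

    from models import ApiWord, CheckResponse

    resp = CheckResponse(
        text="The quick fox",
        words=[
            ApiWord(text="quick", logprob=-6.2, replacements=["slow", "brown"]),
            ApiWord(text="fox", logprob=-0.3, replacements=[]),
        ],
    )
    print(resp.model_dump_json(indent=2))
    # {"text": "The quick fox", "words": [{"text": "quick", ...}, {"text": "fox", ...}]}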