mebubo committed on
Commit
83ec4f2
·
1 Parent(s): b295d62
Files changed (2) hide show
  1. completions.py +32 -42
  2. main.py +5 -11
completions.py CHANGED
@@ -6,14 +6,9 @@ from transformers.generation.utils import GenerateOutput
6
  from typing import cast
7
  from dataclasses import dataclass
8
 
9
- type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast
10
 
11
- @dataclass
12
- class Word:
13
- tokens: list[int]
14
- text: str
15
- logprob: float
16
- context: list[int]
17
 
18
  def starts_with_space(token: str) -> bool:
19
  return token.startswith(chr(9601)) or token.startswith(chr(288))
@@ -107,47 +102,42 @@ def extract_replacements(outputs: GenerateOutput | torch.LongTensor, tokenizer:
107
 
108
  #%%
109
 
110
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
111
-
112
- # model_name = "mistralai/Mistral-7B-v0.1"
113
- model_name = "unsloth/Llama-3.2-1B"
114
- model, tokenizer = load_model_and_tokenizer(model_name, device)
115
 
116
- #%%
117
- input_text = "He asked me to prostrate myself before the king, but I rifused."
118
- inputs: BatchEncoding = tokenize(input_text, tokenizer, device)
 
119
 
 
120
  #%%
121
- token_probs: list[tuple[int, float]] = calculate_log_probabilities(model, tokenizer, inputs)
122
 
123
- #%%
124
- words = split_into_words(token_probs, tokenizer)
125
- log_prob_threshold = -5.0
126
- low_prob_words = [word for word in words if word.logprob < log_prob_threshold]
127
 
128
- #%%
129
- contexts = [word.context for word in low_prob_words]
130
- inputs = prepare_inputs(contexts, tokenizer, device)
131
- input_ids = inputs["input_ids"]
132
 
133
- #%%
134
- num_samples = 5
135
- start_time = time.time()
136
- outputs = generate_outputs(model, inputs, num_samples)
137
- end_time = time.time()
138
- print(f"Total time taken for replacements: {end_time - start_time:.4f} seconds")
139
-
140
- #%%
141
- replacements_batch = extract_replacements(outputs, tokenizer, input_ids.shape[0], input_ids.shape[1], num_samples)
142
 
143
- #%%
144
- for word, replacements in zip(low_prob_words, replacements_batch):
145
- print(f"Original word: {word.text}, Log Probability: {word.logprob:.4f}")
146
- print(f"Proposed replacements: {replacements}")
 
 
147
 
148
- # %%
149
- generated_ids = outputs[:, input_ids.shape[-1]:]
150
- for g in generated_ids:
151
- print(tokenizer.convert_ids_to_tokens(g.tolist()))
152
 
153
- # %%
 
 
 
 
6
  from typing import cast
7
  from dataclasses import dataclass
8
 
9
+ from models import ApiWord, Word
10
 
11
+ type Tokenizer = PreTrainedTokenizer | PreTrainedTokenizerFast
 
 
 
 
 
12
 
13
# Code points checked as "this token begins a new word" markers:
# chr(9601) is U+2581 and chr(288) is U+0120.
# NOTE(review): these look like the SentencePiece-style and byte-level-BPE
# leading-space markers respectively - confirm against the tokenizers in use.
_WORD_START_MARKERS: tuple[str, ...] = (chr(9601), chr(288))


def starts_with_space(token: str) -> bool:
    """Return True when ``token`` begins with one of the leading-space markers.

    Args:
        token: A token string as produced by the tokenizer's vocabulary.

    Returns:
        True if the token starts with U+2581 or U+0120, False otherwise
        (including for the empty string).
    """
    # str.startswith accepts a tuple, so both markers are tested in one call.
    return token.startswith(_WORD_START_MARKERS)
 
102
 
103
  #%%
104
 
105
def load_model(model_name: str = "unsloth/Llama-3.2-1B") -> tuple[PreTrainedModel, Tokenizer, torch.device]:
    """Load a model and its tokenizer onto the best available device.

    Args:
        model_name: Identifier forwarded to ``load_model_and_tokenizer``.
            Defaults to the model this function previously hard-coded; another
            identifier (e.g. "mistralai/Mistral-7B-v0.1") can be passed instead.

    Returns:
        A ``(model, tokenizer, device)`` tuple, where ``device`` is the
        ``torch.device`` the model was loaded for.
    """
    # Use CUDA when torch reports it as available; otherwise run on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, tokenizer = load_model_and_tokenizer(model_name, device)
    return model, tokenizer, device
112
 
113
def check_text(
    input_text: str,
    model: PreTrainedModel,
    tokenizer: Tokenizer,
    device: torch.device,
    *,
    log_prob_threshold: float = -5.0,
    num_samples: int = 5,
) -> list[ApiWord]:
    """Find improbable words in ``input_text`` and propose replacements for them.

    The text is tokenized and scored, tokens are grouped into words, and every
    word whose ``logprob`` is below ``log_prob_threshold`` is treated as
    suspicious. For each suspicious word the model generates ``num_samples``
    continuations of that word's ``context``, from which replacement candidates
    are extracted.

    Args:
        input_text: The text to check.
        model: The language model used for scoring and generation.
        tokenizer: The tokenizer matching ``model``.
        device: The device the inputs are prepared for.
        log_prob_threshold: Words scoring strictly below this are flagged.
        num_samples: Number of generations requested per flagged word.

    Returns:
        One ``ApiWord`` per flagged word, in the order the words occur; an empty
        list when no word is below the threshold.
    """
    encoded_text: BatchEncoding = tokenize(input_text, tokenizer, device)
    token_probs: list[tuple[int, float]] = calculate_log_probabilities(model, tokenizer, encoded_text)

    words = split_into_words(token_probs, tokenizer)
    low_prob_words = [word for word in words if word.logprob < log_prob_threshold]
    if not low_prob_words:
        # Nothing was flagged, so there is nothing to generate for.
        # NOTE(review): presumably prepare_inputs/generate_outputs would not
        # cope with an empty batch - confirm; returning early sidesteps it.
        return []

    # Kept under its own name so the whole-text encoding above is not clobbered.
    contexts = [word.context for word in low_prob_words]
    context_inputs = prepare_inputs(contexts, tokenizer, device)
    input_ids = context_inputs["input_ids"]

    start_time = time.time()
    outputs = generate_outputs(model, context_inputs, num_samples)
    end_time = time.time()
    print(f"Total time taken for replacements: {end_time - start_time:.4f} seconds")

    replacements_batch = extract_replacements(
        outputs, tokenizer, input_ids.shape[0], input_ids.shape[1], num_samples
    )

    api_words: list[ApiWord] = []
    # A distinct loop name avoids shadowing the batch being iterated.
    for word, word_replacements in zip(low_prob_words, replacements_batch):
        print(f"Original word: {word.text}, Log Probability: {word.logprob:.4f}")
        print(f"Proposed replacements: {word_replacements}")
        # NOTE(review): ApiWord is declared in models.py, which is not visible
        # here. The field names below (text / logprob / replacements) are an
        # assumption - confirm them against the model definition.
        api_words.append(
            ApiWord(text=word.text, logprob=word.logprob, replacements=word_replacements)
        )

    # The original fell off the end and returned None despite the annotation.
    return api_words
main.py CHANGED
@@ -2,22 +2,16 @@ from fastapi import FastAPI
2
  from fastapi.staticfiles import StaticFiles
3
  from pydantic import BaseModel
4
 
5
- class Word(BaseModel):
6
- word: str
7
- start: int
8
- end: int
9
- logprob: float
10
- suggestions: list[str]
11
-
12
- class CheckResponse(BaseModel):
13
- text: str
14
- words: list[Word]
15
 
16
  app = FastAPI()
17
 
 
 
18
  @app.get("/check", response_model=CheckResponse)
19
  def check(text: str):
20
- return CheckResponse(text=text, words=[])
21
 
22
  # serve files from frontend/public
23
  app.mount("/", StaticFiles(directory="frontend/public", html=True))
 
2
  from fastapi.staticfiles import StaticFiles
3
  from pydantic import BaseModel
4
 
5
+ from models import CheckResponse, ApiWord
6
+ from completions import check_text, load_model
 
 
 
 
 
 
 
 
7
 
8
app = FastAPI()

# Load the model, tokenizer and device once at import time; the /check
# handler below reuses these module-level objects on every call.
model, tokenizer, device = load_model()


@app.get("/check", response_model=CheckResponse)
def check(text: str):
    # Documented with comments rather than a docstring so the generated API
    # description stays exactly as it was.
    # Runs the checker over the submitted text and wraps the result together
    # with the original text in the response model.
    checked_words = check_text(text, model, tokenizer, device)
    return CheckResponse(text=text, words=checked_words)


# Serve files from frontend/public; registered after the /check route.
app.mount(
    "/",
    StaticFiles(directory="frontend/public", html=True),
)