Cleanup
- completions.py +0 -17
- main.py +1 -19
completions.py
CHANGED
@@ -71,23 +71,6 @@ def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, in
     tokens: torch.Tensor = input_ids[0][1:]
     return list(zip(tokens.tolist(), token_log_probs.tolist()))
 
-def generate_outputs(model: PreTrainedModel, inputs: BatchEncoding, num_samples: int = 5) -> GenerateOutput | torch.LongTensor:
-    input_ids = inputs["input_ids"]
-    attention_mask = inputs["attention_mask"]
-    with torch.no_grad():
-        outputs = model.generate(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            max_new_tokens=4,
-            num_return_sequences=num_samples,
-            temperature=1.0,
-            top_k=50,
-            top_p=0.95,
-            do_sample=True
-            # num_beams=num_samples
-        )
-    return outputs
-
 #%%
 
 def load_model() -> tuple[PreTrainedModel, Tokenizer, torch.device]:
main.py
CHANGED
@@ -2,29 +2,13 @@ from fastapi import FastAPI
 from fastapi.staticfiles import StaticFiles
 from functools import lru_cache
 
-from models import
+from models import CheckResponse
 from completions import check_text, load_model
 
 app = FastAPI()
 
 model, tokenizer, device = load_model()
 
-def check_text_stub(text: str):
-    def rep(i):
-        if i == 3:
-            return -10, [" jumped", " leaps"]
-        if i == 5:
-            return -10, [" calm"]
-        if i == 7:
-            return -10, [" dog", " cat", " bird", " fish"]
-        return -3, []
-
-    result = []
-    for i, w in enumerate(text.split()):
-        logprob, replacements = rep(i)
-        result.append(ApiWord(text=f" {w}", logprob=logprob, replacements=replacements))
-    return result
-
 @lru_cache(maxsize=100)
 def cached_check_text(text: str):
     return check_text(text, model, tokenizer, device)
@@ -34,5 +18,3 @@ def check(text: str):
     return CheckResponse(text=text, words=cached_check_text(text))
 
 app.mount("/", StaticFiles(directory="frontend/public", html=True))
-
-#%%
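For reference, a minimal sketch of the models.py definitions the cleaned-up main.py relies on. Only the field names are grounded in this diff (ApiWord(text=..., logprob=..., replacements=...) in the removed stub and CheckResponse(text=..., words=...) in the kept handler); the Pydantic base class and the exact field types are assumptions, not shown here.

# models.py (hypothetical sketch; only the field names come from the diff above)
from pydantic import BaseModel

class ApiWord(BaseModel):
    text: str                 # the word, leading space preserved
    logprob: float            # log-probability the model assigns to the word
    replacements: list[str]   # suggested alternative words

class CheckResponse(BaseModel):
    text: str                 # the original input text
    words: list[ApiWord]      # per-word scores and suggestions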