mebubo committed on
Commit 4164b83 · 1 Parent(s): 308bca9
Files changed (2):
  1. completions.py +0 -17
  2. main.py +1 -19
completions.py CHANGED
@@ -71,23 +71,6 @@ def calculate_log_probabilities(model: PreTrainedModel, tokenizer: Tokenizer, in
     tokens: torch.Tensor = input_ids[0][1:]
     return list(zip(tokens.tolist(), token_log_probs.tolist()))
 
-def generate_outputs(model: PreTrainedModel, inputs: BatchEncoding, num_samples: int = 5) -> GenerateOutput | torch.LongTensor:
-    input_ids = inputs["input_ids"]
-    attention_mask = inputs["attention_mask"]
-    with torch.no_grad():
-        outputs = model.generate(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            max_new_tokens=4,
-            num_return_sequences=num_samples,
-            temperature=1.0,
-            top_k=50,
-            top_p=0.95,
-            do_sample=True
-            # num_beams=num_samples
-        )
-    return outputs
-
 #%%
 
 def load_model() -> tuple[PreTrainedModel, Tokenizer, torch.device]:
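Note: the deleted generate_outputs was a thin wrapper around Hugging Face's model.generate with sampling enabled. The sketch below shows roughly how an equivalent call could still be made directly against the model/tokenizer/device returned by load_model(); the prompt string and the decoding of only the newly sampled tokens are illustrative assumptions, not part of this commit.

import torch

model, tokenizer, device = load_model()
inputs = tokenizer("The quick brown fox", return_tensors="pt").to(device)

with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=4,           # same short continuation length the removed helper used
        num_return_sequences=5,     # several sampled continuations per prompt
        do_sample=True,             # sample with temperature/top-k/top-p instead of greedy decoding
        temperature=1.0,
        top_k=50,
        top_p=0.95,
    )

# for a causal LM, generate returns prompt + continuation; keep only the new tokens
new_tokens = outputs[:, inputs["input_ids"].shape[1]:]
print(tokenizer.batch_decode(new_tokens, skip_special_tokens=True))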
main.py CHANGED
@@ -2,29 +2,13 @@ from fastapi import FastAPI
 from fastapi.staticfiles import StaticFiles
 from functools import lru_cache
 
-from models import ApiWord, CheckResponse
+from models import CheckResponse
 from completions import check_text, load_model
 
 app = FastAPI()
 
 model, tokenizer, device = load_model()
 
-def check_text_stub(text: str):
-    def rep(i):
-        if i == 3:
-            return -10, [" jumped", " leaps"]
-        if i == 5:
-            return -10, [" calm"]
-        if i == 7:
-            return -10, [" dog", " cat", " bird", " fish"]
-        return -3, []
-
-    result = []
-    for i, w in enumerate(text.split()):
-        logprob, replacements = rep(i)
-        result.append(ApiWord(text=f" {w}", logprob=logprob, replacements=replacements))
-    return result
-
 @lru_cache(maxsize=100)
 def cached_check_text(text: str):
     return check_text(text, model, tokenizer, device)
@@ -34,5 +18,3 @@ def check(text: str):
     return CheckResponse(text=text, words=cached_check_text(text))
 
 app.mount("/", StaticFiles(directory="frontend/public", html=True))
-
-#%%
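Note: main.py no longer imports ApiWord because the hard-coded check_text_stub that constructed ApiWord objects is gone; the endpoint now only builds a CheckResponse from cached_check_text. From the way the removed stub and the endpoint use these classes, the Pydantic models in models.py presumably look roughly like the following sketch (field names inferred from this diff, not taken from the actual file):

from pydantic import BaseModel

class ApiWord(BaseModel):
    text: str                # the word, including its leading space, e.g. " fox"
    logprob: float           # log-probability the language model assigns to the word
    replacements: list[str]  # suggested alternatives when the log-probability is low

class CheckResponse(BaseModel):
    text: str                # original input text
    words: list[ApiWord]     # per-word results produced by check_text

Since cached_check_text is wrapped in functools.lru_cache(maxsize=100), repeated requests for the same text reuse the earlier check_text result; the standard cached_check_text.cache_info() can be used to inspect hits and misses.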