Sartc commited on
Commit
6306cf0
·
verified ·
1 Parent(s): 64d28a7

trying to fix float error

Browse files
Files changed (1) hide show
  1. app.py +12 -1
app.py CHANGED
@@ -39,7 +39,18 @@ def load_model(repo_id, local_file):
39
  model.eval()
40
  return model
41
 
42
- def generate(model, prompt, max_tokens, temperature=0.7):
 
 
 
 
 
 
 
 
 
 
 
43
  for _ in range(max_tokens):
44
  prompt = prompt[:, :config.context_len]
45
  logits = model(prompt)
 
39
  model.eval()
40
  return model
41
 
42
+ # def generate(model, prompt, max_tokens, temperature=0.7):
43
+ # for _ in range(max_tokens):
44
+ # prompt = prompt[:, :config.context_len]
45
+ # logits = model(prompt)
46
+ # logits = logits[:, -1, :] / temperature
47
+ # logit_probs = nn.functional.softmax(logits, dim=-1)
48
+ # next_prompt = torch.multinomial(logit_probs, num_samples=1)
49
+ # prompt = torch.cat((prompt, next_prompt), dim=1)
50
+ # return prompt
51
+
52
+ def generate(model, input_ids, max_tokens, temperature=0.7):
53
+ prompt = input_ids
54
  for _ in range(max_tokens):
55
  prompt = prompt[:, :config.context_len]
56
  logits = model(prompt)