Ruurd committed on
Commit
ae0ca85
·
1 Parent(s): 6df6769

Remove custom stoppingcriteria and trust generate

Browse files
Files changed (1) hide show
  1. app.py +1 -9
app.py CHANGED
@@ -6,15 +6,9 @@ import spaces
6
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
7
  import threading
8
 
9
- from transformers import TextIteratorStreamer, StoppingCriteria, StoppingCriteriaList
10
  import threading
11
 
12
- class StopOnEos(StoppingCriteria):
13
- def __init__(self, eos_token_id):
14
- self.eos_token_id = eos_token_id
15
-
16
- def __call__(self, input_ids, scores, **kwargs):
17
- return input_ids[0, -1].item() == self.eos_token_id
18
 
19
  @spaces.GPU
20
  def chat_with_model(messages):
@@ -33,14 +27,12 @@ def chat_with_model(messages):
33
  inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
34
 
35
  streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
36
- stopping_criteria = StoppingCriteriaList([StopOnEos(current_tokenizer.eos_token_id)])
37
 
38
  generation_kwargs = dict(
39
  **inputs,
40
  max_new_tokens=256,
41
  do_sample=True,
42
  streamer=streamer,
43
- stopping_criteria=stopping_criteria,
44
  eos_token_id=current_tokenizer.eos_token_id,
45
  pad_token_id=pad_id
46
  )
 
6
  from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
7
  import threading
8
 
9
+ from transformers import TextIteratorStreamer
10
  import threading
11
 
 
 
 
 
 
 
12
 
13
  @spaces.GPU
14
  def chat_with_model(messages):
 
27
  inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
28
 
29
  streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
 
30
 
31
  generation_kwargs = dict(
32
  **inputs,
33
  max_new_tokens=256,
34
  do_sample=True,
35
  streamer=streamer,
 
36
  eos_token_id=current_tokenizer.eos_token_id,
37
  pad_token_id=pad_id
38
  )