Remove custom StoppingCriteria and trust `generate()`'s built-in EOS handling
Browse files
app.py
CHANGED
@@ -6,15 +6,9 @@ import spaces
|
|
6 |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
|
7 |
import threading
|
8 |
|
9 |
-
from transformers import TextIteratorStreamer
|
10 |
import threading
|
11 |
|
12 |
-
class StopOnEos(StoppingCriteria):
|
13 |
-
def __init__(self, eos_token_id):
|
14 |
-
self.eos_token_id = eos_token_id
|
15 |
-
|
16 |
-
def __call__(self, input_ids, scores, **kwargs):
|
17 |
-
return input_ids[0, -1].item() == self.eos_token_id
|
18 |
|
19 |
@spaces.GPU
|
20 |
def chat_with_model(messages):
|
@@ -33,14 +27,12 @@ def chat_with_model(messages):
|
|
33 |
inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
|
34 |
|
35 |
streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
|
36 |
-
stopping_criteria = StoppingCriteriaList([StopOnEos(current_tokenizer.eos_token_id)])
|
37 |
|
38 |
generation_kwargs = dict(
|
39 |
**inputs,
|
40 |
max_new_tokens=256,
|
41 |
do_sample=True,
|
42 |
streamer=streamer,
|
43 |
-
stopping_criteria=stopping_criteria,
|
44 |
eos_token_id=current_tokenizer.eos_token_id,
|
45 |
pad_token_id=pad_id
|
46 |
)
|
|
|
6 |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
|
7 |
import threading
|
8 |
|
9 |
+
from transformers import TextIteratorStreamer
|
10 |
import threading
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
|
13 |
@spaces.GPU
|
14 |
def chat_with_model(messages):
|
|
|
27 |
inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
|
28 |
|
29 |
streamer = TextIteratorStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
|
|
|
30 |
|
31 |
generation_kwargs = dict(
|
32 |
**inputs,
|
33 |
max_new_tokens=256,
|
34 |
do_sample=True,
|
35 |
streamer=streamer,
|
|
|
36 |
eos_token_id=current_tokenizer.eos_token_id,
|
37 |
pad_token_id=pad_id
|
38 |
)
|