jchwenger committed on
Commit
62100c8
·
1 Parent(s): 3de15d6
Files changed (1)
  1. app.py +102 -0
app.py ADDED
@@ -0,0 +1,102 @@
+ # Adapted from the Gradio tutorials:
+ # https://www.gradio.app/guides/creating-a-chatbot-fast#example-using-a-local-open-source-llm-with-hugging-face
+
+ import gradio as gr
+
+ import torch
+
+ # Get the cpu, gpu or mps device for inference.
+ # See: https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html#creating-models
+ device = (
+     "cuda"
+     if torch.cuda.is_available()
+     else "mps"
+     if torch.backends.mps.is_available()
+     else "cpu"
+ )
+
+ from transformers import AutoTokenizer
+ from transformers import AutoModelForCausalLM
+ from transformers import StoppingCriteria
+ from transformers import StoppingCriteriaList
+ from transformers import TextIteratorStreamer
+
+ from threading import Thread
+
+ MODEL_ID = "togethercomputer/RedPajama-INCITE-Chat-3B-v1"
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+ model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
+ model = model.to(device)  # move the model to the selected device
+
+ class StopOnTokens(StoppingCriteria):
+     """
+     Class used as the `stopping_criteria` in `generate_kwargs`; it provides an
+     additional way of ending the generation loop (if this class returns `True`
+     on a token, generation is stopped).
+     """
+     # note: Python now supports type hints, see: https://realpython.com/lessons/type-hinting/
+     # (for the **kwargs see also: https://realpython.com/python-kwargs-and-args/)
+     # this could also be written: def __call__(self, input_ids, scores, **kwargs):
+     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
+         stop_ids = [29, 0]  # see the sketch after the diff to understand where these come from
+         for stop_id in stop_ids:
+             if input_ids[0][-1] == stop_id:
+                 return True
+         return False
+
+ def predict(message, history):
+
+     history_transformer_format = history + [[message, ""]]
+     stop = StopOnTokens()
+
+     # useful for debugging
+     # msg = "history"
+     # print(msg)
+     # print(*history_transformer_format, sep="\n")
+     # print("***")
+
+     # at each step, we feed the entire history as one string,
+     # restoring the format used in the model's dataset, with new lines
+     # and <human>: or <bot>: added before each message
+     messages = "".join(
+         ["".join(
+             ["\n<human>:" + item[0], "\n<bot>:" + item[1]]
+         )
+         for item in history_transformer_format]
+     )
+     # # to see what we feed to our net:
+     # msg = "string prompt"
+     # print(msg)
+     # print("-" * len(msg))
+     # print(messages)
+     # print("-" * 40)
+
+     # convert the string into tensors & move them to the same device as the model
+     model_inputs = tokenizer([messages], return_tensors="pt").to(device)
+
+     streamer = TextIteratorStreamer(tokenizer, timeout=30., skip_prompt=True, skip_special_tokens=True)
+     generate_kwargs = dict(
+         model_inputs,
+         streamer=streamer,
+         max_new_tokens=1024,
+         do_sample=True,
+         top_p=0.95,
+         top_k=1000,
+         temperature=1.0,
+         pad_token_id=tokenizer.eos_token_id,  # mute an annoying warning: https://stackoverflow.com/a/71397707
+         num_beams=1,  # this is for beam search (disabled here), see: https://huggingface.co/blog/how-to-generate#beam-search
+         stopping_criteria=StoppingCriteriaList([stop])
+     )
+     t = Thread(target=model.generate, kwargs=generate_kwargs)
+     t.start()
+
+     partial_message = ""
+     for new_token in streamer:
+         # see the format \n<human>: and \n<bot>: above (where 'messages' is defined)?
+         # we stream the message only *until* we encounter '<', which marks the
+         # start of the next turn marker (and is one of our stop tokens)
+         if new_token != '<':
+             partial_message += new_token
+             yield partial_message
+
+
+ gr.ChatInterface(predict).queue().launch(debug=True)
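A note on `stop_ids = [29, 0]`: the original comment pointed to a notebook cell that is not part of this file. Below is a minimal sketch (not part of the commit) of where those values are assumed to come from: the id the RedPajama tokenizer gives to '<' (the first character of the \n<human>:/\n<bot>: turn markers) and its end-of-text token id.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("togethercomputer/RedPajama-INCITE-Chat-3B-v1")

# assumption: '<' encodes to id 29 and the EOS token is id 0 for this tokenizer
print(tokenizer("<").input_ids)  # expected: [29]
print(tokenizer.eos_token_id)    # expected: 0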
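For reference, here is a small illustration (again, not part of the commit; the history is made up) of the string prompt that `predict` builds: the join produces a single transcript, with the empty string standing in for the bot reply about to be generated.

history_transformer_format = [
    ["Hello!", "Hi! How can I help?"],
    ["Tell me about Gradio.", ""],  # the bot turn still to be generated
]
messages = "".join(
    ["".join(["\n<human>:" + item[0], "\n<bot>:" + item[1]])
     for item in history_transformer_format]
)
print(messages)
# <human>:Hello!
# <bot>:Hi! How can I help?
# <human>:Tell me about Gradio.
# <bot>: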