NicholasGuerrero committed on
Commit 59466c6 · 1 Parent(s): 97a5583

llama deepsparse

Files changed (1)
  1. app.py +268 -4
app.py CHANGED
@@ -1,7 +1,271 @@
  import gradio as gr
 
 
- def greet(name):
-     return "Hello " + name + "!!"
 
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import deepsparse
+ from deepsparse.legacy import Pipeline
  import gradio as gr
+ from typing import List, Tuple
 
+ deepsparse.cpu.print_hardware_capability()
 
 
+ MODEL_ID = "hf:neuralmagic/Llama-2-7b-pruned70-retrained-ultrachat-quant-ds"
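+ # ("hf:" stubs are resolved by DeepSparse from the Hugging Face Hub)
+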
+ DESCRIPTION = f"""
+ # Chat with an Efficient Sparse Llama 2 Model on CPU
+ This demo showcases a [sparse Llama 2 7B model](https://huggingface.co/neuralmagic/Llama-2-7b-pruned70-retrained-ultrachat-quant-ds) that has been pruned to 70% sparsity, retrained on pretraining data, and then sparse-transferred for chat using the UltraChat 200k dataset. Sparse transfer learning lets the model deliver high-quality chat while significantly reducing computational cost and inference time.
+ ### Under the Hood
+ - **Sparse Transfer Learning**: The model's pre-sparsified structure enables efficient fine-tuning on new tasks, minimizing the need for extensive hyperparameter tuning and reducing training time.
+ - **Accelerated Inference**: Running on the [DeepSparse CPU inference runtime](https://github.com/neuralmagic/deepsparse), the model exploits its sparsity for fast token generation on CPUs.
+ - **Quantization**: 8-bit weight and activation quantization further reduces latency and memory footprint without compromising output quality.
+ By combining state-of-the-art sparsity techniques with the robustness of the Llama 2 architecture, this model pushes the boundaries of efficient generation on everyday hardware.
+ """
+
+ MAX_MAX_NEW_TOKENS = 1024
+ DEFAULT_MAX_NEW_TOKENS = 200
+
+ # Set up the engine (legacy Pipeline API, imported above)
+ pipe = Pipeline.create(
+     task="text-generation",
+     model_path=MODEL_ID,
+     sequence_length=MAX_MAX_NEW_TOKENS,
+     prompt_sequence_length=8,
+     num_cores=8,
+ )
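+ # sequence_length caps total tokens (prompt + generated) per request;
+ # prompt_sequence_length is how many prompt tokens the engine processes per
+ # forward pass during prefill; num_cores pins inference to 8 CPU cores.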
+
+
+ def clear_and_save_textbox(message: str) -> Tuple[str, str]:
+     # Clear the textbox and stash the submitted message for the next step
+     return "", message
+
+
+ def display_input(
+     message: str, history: List[List[str]]
+ ) -> List[List[str]]:
+     # Append the user turn with an empty assistant slot; use a list rather
+     # than a tuple so generate() can mutate the assistant text in place
+     history.append([message, ""])
+     return history
+
+
+ def delete_prev_fn(history: List[List[str]]) -> Tuple[List[List[str]], str]:
+     # Remove the latest turn and hand back its user message (for retry/undo)
+     try:
+         message, _ = history.pop()
+     except IndexError:
+         message = ""
+     return history, message or ""
+
+
+ with gr.Blocks(css="style.css") as demo:
+     gr.Markdown(DESCRIPTION)
+
+     with gr.Group():
+         chatbot = gr.Chatbot(label="Chatbot")
+         with gr.Row():
+             textbox = gr.Textbox(
+                 container=False,
+                 show_label=False,
+                 placeholder="Type a message...",
+                 scale=10,
+             )
+             submit_button = gr.Button("Submit", variant="primary", scale=1, min_width=0)
+
+     with gr.Row():
+         retry_button = gr.Button("🔄 Retry", variant="secondary")
+         undo_button = gr.Button("↩️ Undo", variant="secondary")
+         clear_button = gr.Button("🗑️ Clear", variant="secondary")
+
+     # Holds the last submitted message between chained events
+     saved_input = gr.State()
+
+     gr.Examples(
+         examples=[
+             "Write a story about sparse neurons.",
+             "Write a story about a summer camp.",
+             "Make a recipe for banana bread.",
+             "Write a cookbook for gluten-free snacks.",
+             "Write about the role of animation in video games.",
+         ],
+         inputs=[textbox],
+     )
+
+     max_new_tokens = gr.Slider(
+         label="Max new tokens",
+         value=DEFAULT_MAX_NEW_TOKENS,
+         minimum=0,
+         maximum=MAX_MAX_NEW_TOKENS,
+         step=1,
+         interactive=True,
+         info="The maximum number of new tokens",
+     )
+     temperature = gr.Slider(
+         label="Temperature",
+         value=0.9,
+         minimum=0.05,
+         maximum=1.0,
+         step=0.05,
+         interactive=True,
+         info="Higher values produce more diverse outputs",
+     )
+     top_p = gr.Slider(
+         label="Top-p (nucleus) sampling",
+         value=0.40,
+         minimum=0.0,
+         maximum=1.0,
+         step=0.05,
+         interactive=True,
+         info="Higher values sample more low-probability tokens",
+     )
+     top_k = gr.Slider(
+         label="Top-k sampling",
+         value=20,
+         minimum=1,
+         maximum=100,
+         step=1,
+         interactive=True,
+         info="Sample from the top_k most likely tokens",
+     )
+     repetition_penalty = gr.Slider(
+         label="Repetition penalty",
+         value=1.2,
+         minimum=1.0,
+         maximum=2.0,
+         step=0.05,
+         interactive=True,
+         info="Penalize repeated tokens",
+     )
+
+     # Generation inference
+     def generate(
+         message,
+         history,
+         max_new_tokens: int,
+         temperature: float,
+         top_p: float,
+         top_k: int,
+         repetition_penalty: float,
+     ):
+         generation_config = {
+             "max_new_tokens": max_new_tokens,
+             "do_sample": True,
+             "temperature": temperature,
+             "top_p": top_p,
+             "top_k": top_k,
+             "repetition_penalty": repetition_penalty,
+         }
+
+         # Only the latest user message is sent; earlier turns in `history`
+         # are not included in the prompt, so the bot has no chat memory
+         conversation = [{"role": "user", "content": message}]
+
+         formatted_conversation = pipe.tokenizer.apply_chat_template(
+             conversation, tokenize=False, add_generation_prompt=True
+         )
+
+         # Stream tokens as the engine produces them
+         inference = pipe(
+             sequences=formatted_conversation,
+             generation_config=generation_config,
+             streaming=True,
+         )
+
+         for token in inference:
+             history[-1][1] += token.generations[0].text
+             yield history
+
+         # timer_manager aggregates per-request timing for the run
+         print(pipe.timer_manager)
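+
+     # Each input path below follows the same three-step chain:
+     #   1) clear_and_save_textbox stores the message in saved_input,
+     #   2) display_input renders the user turn in the chatbot,
+     #   3) generate streams the model's reply into that turn.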
+
+     # Hooking up all the buttons
+     textbox.submit(
+         fn=clear_and_save_textbox,
+         inputs=textbox,
+         outputs=[textbox, saved_input],
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=display_input,
+         inputs=[saved_input, chatbot],
+         outputs=chatbot,
+         api_name=False,
+         queue=False,
+     ).success(
+         generate,
+         inputs=[
+             saved_input,
+             chatbot,
+             max_new_tokens,
+             temperature,
+             top_p,
+             top_k,
+             repetition_penalty,
+         ],
+         outputs=[chatbot],
+         api_name=False,
+     )
+
+     submit_button.click(
+         fn=clear_and_save_textbox,
+         inputs=textbox,
+         outputs=[textbox, saved_input],
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=display_input,
+         inputs=[saved_input, chatbot],
+         outputs=chatbot,
+         api_name=False,
+         queue=False,
+     ).success(
+         generate,
+         inputs=[
+             saved_input,
+             chatbot,
+             max_new_tokens,
+             temperature,
+             top_p,
+             top_k,
+             repetition_penalty,
+         ],
+         outputs=[chatbot],
+         api_name=False,
+     )
+
+     retry_button.click(
+         fn=delete_prev_fn,
+         inputs=chatbot,
+         outputs=[chatbot, saved_input],
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=display_input,
+         inputs=[saved_input, chatbot],
+         outputs=chatbot,
+         api_name=False,
+         queue=False,
+     ).then(
+         generate,
+         inputs=[
+             saved_input,
+             chatbot,
+             max_new_tokens,
+             temperature,
+             top_p,
+             top_k,
+             repetition_penalty,
+         ],
+         outputs=[chatbot],
+         api_name=False,
+     )
+
+     undo_button.click(
+         fn=delete_prev_fn,
+         inputs=chatbot,
+         outputs=[chatbot, saved_input],
+         api_name=False,
+         queue=False,
+     ).then(
+         fn=lambda x: x,
+         inputs=[saved_input],
+         outputs=textbox,
+         api_name=False,
+         queue=False,
+     )
+
+     clear_button.click(
+         fn=lambda: ([], ""),
+         outputs=[chatbot, saved_input],
+         queue=False,
+         api_name=False,
+     )
+
+ demo.queue().launch(share=True)
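
For quick reference, the generation path above can also be exercised outside Gradio. The sketch below mirrors `generate()` (model stub, chat template, streaming loop); it is an illustration, not part of the commit, and assumes a `deepsparse[llm]` install and the same legacy `Pipeline` API used in `app.py`.

```python
# Minimal standalone sketch of app.py's generation path (assumes
# deepsparse[llm] is installed and the legacy Pipeline API is available)
from deepsparse.legacy import Pipeline

pipe = Pipeline.create(
    task="text-generation",
    model_path="hf:neuralmagic/Llama-2-7b-pruned70-retrained-ultrachat-quant-ds",
    sequence_length=1024,
)

# app.py sends only the latest message; listing prior turns here is how
# multi-turn memory could be added on top of the same pipeline
conversation = [{"role": "user", "content": "Make a recipe for banana bread."}]
prompt = pipe.tokenizer.apply_chat_template(
    conversation, tokenize=False, add_generation_prompt=True
)

# Same sampling settings as the app's slider defaults
generation_config = {
    "max_new_tokens": 200,
    "do_sample": True,
    "temperature": 0.9,
    "top_p": 0.4,
    "top_k": 20,
    "repetition_penalty": 1.2,
}

# Stream tokens to stdout as they are generated, like the app's yield loop
for token in pipe(sequences=prompt, generation_config=generation_config, streaming=True):
    print(token.generations[0].text, end="", flush=True)
```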