SandLogicTechnologies committed on
Commit
2e33ec7
·
verified ·
1 Parent(s): e997b79

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +191 -74
app.py CHANGED
@@ -1,7 +1,6 @@
1
  import os
2
  from threading import Thread
3
  from typing import Iterator
4
-
5
  import gradio as gr
6
  import spaces
7
  import torch
@@ -9,8 +8,36 @@ import json
9
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
10
 
11
  DESCRIPTION = """\
12
- Shakti is a 2.5 billion parameter language model specifically optimized for resource-constrained environments such as edge devices, including smartphones, wearables, and IoT systems. With support for vernacular languages and domain-specific tasks, Shakti excels in industries such as healthcare, finance, and customer service.
13
- For more details, please check [here](https://arxiv.org/pdf/2410.11331v1).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  """
15
 
16
  MAX_MAX_NEW_TOKENS = 2048
@@ -43,14 +70,14 @@ def load_model(selected_model: str):
43
  token=os.getenv("SHAKTI")
44
  )
45
  model.eval()
46
- current_model = selected_model # Update the current model
 
47
 
48
 
49
- # Initial model load (default to 2.5B)
50
  load_model("Shakti-2.5B")
51
 
52
 
53
- @spaces.GPU(duration=90)
54
  def generate(
55
  message: str,
56
  chat_history: list[tuple[str, str]],
@@ -62,24 +89,19 @@ def generate(
62
  ) -> Iterator[str]:
63
  conversation = []
64
 
65
- # Conditional logic for adding prompt based on model
66
  if current_model == "Shakti-2.5B":
67
  for user, assistant in chat_history:
68
- conversation.extend(
69
- [
70
- json.loads(os.getenv("PROMPT")),
71
- {"role": "user", "content": user},
72
- {"role": "assistant", "content": assistant},
73
- ]
74
- )
75
  else:
76
  for user, assistant in chat_history:
77
- conversation.extend(
78
- [
79
- {"role": "user", "content": user},
80
- {"role": "assistant", "content": assistant},
81
- ]
82
- )
83
 
84
  conversation.append({"role": "user", "content": message})
85
 
@@ -110,72 +132,167 @@ def generate(
110
  yield "".join(outputs)
111
 
112
 
113
- def update_examples(selected_model):
114
- if selected_model == "Shakti-100M":
115
- return [["Tell me a story"],
116
- ["Write a short poem on Rose"],
117
- ["What are computers"]]
118
- elif selected_model == "Shakti-250M":
119
- return [["Can you explain the pathophysiology of hypertension and its impact on the cardiovascular system?"],
120
- ["What are the potential side effects of beta-blockers in the treatment of arrhythmias?"],
121
- ["What foods are good for boosting the immune system?"],
122
- ["What is the difference between a stock and a bond?"],
123
- ["How can I start saving for retirement?"],
124
- ["What are some low-risk investment options?"],
125
- ["What is a power of attorney and when is it used?"],
126
- ["What are the key differences between a will and a trust?"],
127
- ["How do I legally protect my business name?"]]
128
- else:
129
- return [["Tell me a story"], ["write a short poem which is hard to sing"],
130
- ['मुझे भारतीय इतिहास के बारे में बताएं']]
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
 
133
  def on_model_select(selected_model):
134
  load_model(selected_model) # Load the selected model
135
- examples = update_examples(selected_model) # Update examples
136
- return gr.update(examples=examples), gr.update(value=[]) # Clear the chat space and update examples
 
 
 
 
 
 
 
 
 
137
 
138
 
139
- chat_history = gr.Chatbot()
 
140
 
141
- with gr.Blocks(css="style.css", fill_height=True) as demo:
 
142
  gr.Markdown(DESCRIPTION)
143
- gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
144
-
145
- # Dropdown for model selection
146
- model_dropdown = gr.Dropdown(
147
- label="Select Model",
148
- choices=["Shakti-100M", "Shakti-250M", "Shakti-2.5B"],
149
- value="Shakti-2.5B",
150
- interactive=True,
151
- )
152
 
153
- # Create the interface with dynamic inputs and chat history
154
- max_tokens_slider = gr.Slider(
155
- label="Max new tokens",
156
- minimum=1,
157
- maximum=MAX_MAX_NEW_TOKENS,
158
- step=1,
159
- value=DEFAULT_MAX_NEW_TOKENS,
160
- )
161
 
162
- temperature_slider = gr.Slider(
163
- label="Temperature",
164
- minimum=0.1,
165
- maximum=4.0,
166
- step=0.1,
167
- value=0.6,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  )
169
 
170
- chat_interface = gr.Interface(
171
- fn=generate,
172
- inputs=[gr.Textbox(lines=2, placeholder="Enter your message here"), chat_history, max_tokens_slider,
173
- temperature_slider],
174
- outputs=chat_history,
175
- live=True,
176
  )
177
 
178
- # Function to handle model change and update examples dynamically
179
- model_dropdown.change(on_model_select, inputs=model_dropdown, outputs=[chat_interface, chat_history])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
- demo.queue(max_size=20).launch()
 
 
1
  import os
2
  from threading import Thread
3
  from typing import Iterator
 
4
  import gradio as gr
5
  import spaces
6
  import torch
 
8
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
9
 
10
  DESCRIPTION = """\
11
+ Shakti LLMs (Large Language Models) are a group of compact language models specifically optimized for resource-constrained environments such as edge devices, including smartphones, wearables, and IoT (Internet of Things) systems. These models provide support for vernacular languages and domain-specific tasks, making them particularly suitable for industries such as healthcare, finance, and customer service.
12
+ For more details, please check [here](https://arxiv.org/pdf/2410.11331v1)
13
+ """
14
+
15
+
16
+ # """\
17
+ # Shakti LLMs are a group of small language model specifically optimized for resource-constrained environments such as edge devices, including smartphones, wearables, and IoT systems. With support for vernacular languages and domain-specific tasks, Shakti excels in industries such as healthcare, finance, and customer service.
18
+ # For more details, please check [here](https://arxiv.org/pdf/2410.11331v1).
19
+ # """
20
+
21
+
22
+ # Custom CSS for the send button
23
+ CUSTOM_CSS = """
24
+ .send-btn {
25
+ padding: 0.5rem !important;
26
+ width: 55px !important;
27
+ height: 55px !important;
28
+ border-radius: 50% !important;
29
+ margin-top: 1rem;
30
+ cursor: pointer;
31
+ }
32
+
33
+ .send-btn svg {
34
+ width: 20px !important;
35
+ height: 20px !important;
36
+ position: absolute;
37
+ top: 50%;
38
+ left: 50%;
39
+ transform: translate(-50%, -50%);
40
+ }
41
  """
42
 
43
  MAX_MAX_NEW_TOKENS = 2048
 
70
  token=os.getenv("SHAKTI")
71
  )
72
  model.eval()
73
+ print("Selected Model: ", selected_model)
74
+ current_model = selected_model
75
 
76
 
77
+ # Initial model load
78
  load_model("Shakti-2.5B")
79
 
80
 
 
81
  def generate(
82
  message: str,
83
  chat_history: list[tuple[str, str]],
 
89
  ) -> Iterator[str]:
90
  conversation = []
91
 
 
92
  if current_model == "Shakti-2.5B":
93
  for user, assistant in chat_history:
94
+ conversation.extend([
95
+ json.loads(os.getenv("PROMPT")),
96
+ {"role": "user", "content": user},
97
+ {"role": "assistant", "content": assistant},
98
+ ])
 
 
99
  else:
100
  for user, assistant in chat_history:
101
+ conversation.extend([
102
+ {"role": "user", "content": user},
103
+ {"role": "assistant", "content": assistant},
104
+ ])
 
 
105
 
106
  conversation.append({"role": "user", "content": message})
107
 
 
132
  yield "".join(outputs)
133
 
134
 
135
def respond(message, chat_history, max_new_tokens, temperature):
    """Run one chat turn: stream the model reply, then commit it to history.

    Parameters mirror the Gradio event inputs: the user's message, the
    chatbot history as a list of ``(user, assistant)`` tuples, and the two
    generation sliders.  Returns ``("", chat_history)`` so the textbox is
    cleared and the chatbot component refreshed.
    """
    bot_message = ""
    # generate() yields the *cumulative* text so far (it re-joins all output
    # chunks on every yield), so keep only the latest value — "+=" would
    # duplicate earlier text in the final reply.
    for chunk in generate(message, chat_history, max_new_tokens, temperature):
        bot_message = chunk
    # Append exactly one (user, reply) pair after streaming finishes;
    # appending inside the loop would add one history row per chunk.
    chat_history.append((message, bot_message))
    return "", chat_history
141
+
142
+
143
# Example prompts shown beneath the chat box, keyed by model name.
_EXAMPLE_PROMPTS = {
    "Shakti-100M": [
        ["Tell me a story"],
        ["Write a short poem on Rose"],
        ["What are computers"],
    ],
    "Shakti-250M": [
        ["Can you explain the pathophysiology of hypertension and its impact on the cardiovascular system?"],
        ["What are the potential side effects of beta-blockers in the treatment of arrhythmias?"],
        ["What foods are good for boosting the immune system?"],
        ["What is the difference between a stock and a bond?"],
        ["How can I start saving for retirement?"],
        ["What are some low-risk investment options?"],
    ],
    "Shakti-2.5B": [
        ["Tell me a story"],
        ["write a short poem which is hard to sing"],
        ['मुझे भारतीय इतिहास के बारे में बताएं'],
    ],
}


def get_examples(selected_model):
    """Return the example prompt list for *selected_model*.

    Unknown model names yield an empty list rather than raising.
    """
    return _EXAMPLE_PROMPTS.get(selected_model, [])
165
 
166
 
167
  def on_model_select(selected_model):
168
  load_model(selected_model) # Load the selected model
169
+ # Return the message and chat history updates
170
+ return gr.update(value=""), gr.update(value=[]) # Clear message and chat history
171
+
172
+
173
def update_examples_visibility(selected_model):
    """Toggle the three per-model example sections.

    Returns one visibility update per section, in the fixed order
    (Shakti-100M, Shakti-250M, Shakti-2.5B); only the section matching
    *selected_model* becomes visible.
    """
    ordered_models = ("Shakti-100M", "Shakti-250M", "Shakti-2.5B")
    return tuple(
        gr.update(visible=(name == selected_model)) for name in ordered_models
    )
180
 
181
 
182
def example_selector(example):
    """Identity passthrough: hand the clicked example straight to the textbox."""
    selected = example
    return selected
184
 
185
+
186
+ with gr.Blocks(css=CUSTOM_CSS) as demo:
187
  gr.Markdown(DESCRIPTION)
 
 
 
 
 
 
 
 
 
188
 
189
+ with gr.Row():
190
+ model_dropdown = gr.Dropdown(
191
+ label="Select Model",
192
+ choices=list(model_options.keys()),
193
+ value="Shakti-2.5B",
194
+ interactive=True
195
+ )
 
196
 
197
+ chatbot = gr.Chatbot()
198
+
199
+ with gr.Row():
200
+ with gr.Column(scale=20):
201
+ msg = gr.Textbox(
202
+ label="Message",
203
+ placeholder="Enter your message here",
204
+ lines=2,
205
+ show_label=False
206
+ )
207
+ with gr.Column(scale=1, min_width=50):
208
+ send_btn = gr.Button(
209
+ value="➤",
210
+ variant="primary",
211
+ elem_classes=["send-btn"]
212
+ )
213
+
214
+ with gr.Accordion("Parameters", open=False):
215
+ max_tokens_slider = gr.Slider(
216
+ label="Max new tokens",
217
+ minimum=1,
218
+ maximum=MAX_MAX_NEW_TOKENS,
219
+ step=1,
220
+ value=DEFAULT_MAX_NEW_TOKENS,
221
+ )
222
+ temperature_slider = gr.Slider(
223
+ label="Temperature",
224
+ minimum=0.1,
225
+ maximum=4.0,
226
+ step=0.1,
227
+ value=0.6,
228
+ )
229
+
230
+ # Add submit action handlers
231
+ submit_click = send_btn.click(
232
+ respond,
233
+ inputs=[msg, chatbot, max_tokens_slider, temperature_slider],
234
+ outputs=[msg, chatbot]
235
  )
236
 
237
+ submit_enter = msg.submit(
238
+ respond,
239
+ inputs=[msg, chatbot, max_tokens_slider, temperature_slider],
240
+ outputs=[msg, chatbot]
 
 
241
  )
242
 
243
+ # Create separate example sections for each model
244
+ with gr.Row():
245
+ with gr.Column(visible=False) as examples_100m:
246
+ gr.Examples(
247
+ examples=get_examples("Shakti-100M"),
248
+ inputs=msg,
249
+ label="Example prompts for Shakti-100M",
250
+ fn=example_selector
251
+ )
252
+
253
+ with gr.Column(visible=False) as examples_250m:
254
+ gr.Examples(
255
+ examples=get_examples("Shakti-250M"),
256
+ inputs=msg,
257
+ label="Example prompts for Shakti-250M",
258
+ fn=example_selector
259
+ )
260
+
261
+ with gr.Column(visible=True) as examples_2_5b:
262
+ gr.Examples(
263
+ examples=get_examples("Shakti-2.5B"),
264
+ inputs=msg,
265
+ label="Example prompts for Shakti-2.5B",
266
+ fn=example_selector
267
+ )
268
+
269
+
270
+ # Update model selection and examples visibility
271
def combined_update(selected_model):
    """Handle a model-dropdown change in one shot.

    Combines on_model_select() (clears the textbox and chat history)
    with update_examples_visibility() (shows the matching example
    section), returning the five updates in the order the change
    handler's outputs expect.
    """
    msg_update, chat_update = on_model_select(selected_model)
    visibility_updates = update_examples_visibility(selected_model)
    return [msg_update, chat_update, *visibility_updates]
282
+
283
+
284
+ # Updated change event handler
285
+ model_dropdown.change(
286
+ combined_update,
287
+ inputs=[model_dropdown],
288
+ outputs=[
289
+ msg,
290
+ chatbot,
291
+ examples_100m,
292
+ examples_250m,
293
+ examples_2_5b
294
+ ]
295
+ )
296
 
297
+ if __name__ == "__main__":
298
+ demo.queue(max_size=20).launch()