import os
import json
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
DESCRIPTION = """\ | |
Shakti LLMs (Large Language Models) are a group of compact language models specifically optimized for resource-constrained environments such as edge devices, including smartphones, wearables, and IoT (Internet of Things) systems. These models provide support for vernacular languages and domain-specific tasks, making them particularly suitable for industries such as healthcare, finance, and customer service. | |
For more details, please check [here](https://arxiv.org/pdf/2410.11331v1) | |
""" | |
# """\ | |
# Shakti LLMs are a group of small language model specifically optimized for resource-constrained environments such as edge devices, including smartphones, wearables, and IoT systems. With support for vernacular languages and domain-specific tasks, Shakti excels in industries such as healthcare, finance, and customer service. | |
# For more details, please check [here](https://arxiv.org/pdf/2410.11331v1). | |
# """ | |
# Custom CSS for the send button
CUSTOM_CSS = """
.send-btn {
    padding: 0.5rem !important;
    width: 55px !important;
    height: 55px !important;
    border-radius: 50% !important;
    margin-top: 1rem;
    cursor: pointer;
}
.send-btn svg {
    width: 20px !important;
    height: 20px !important;
    position: absolute;
    top: 50%;
    left: 50%;
    transform: translate(-50%, -50%);
}
"""
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2048"))

# Note: `device` is informational only; model placement is handled by
# device_map="auto" in load_model() below.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
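# The input-length cap can be overridden from the environment before launch,
# e.g. (illustrative value, bounded in practice by the model's context window):
#   export MAX_INPUT_TOKEN_LENGTH=4096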
# Model configurations
model_options = {
    "Shakti-100M": "SandLogicTechnologies/Shakti-100M",
    "Shakti-250M": "SandLogicTechnologies/Shakti-250M",
    "Shakti-2.5B": "SandLogicTechnologies/Shakti-2.5B",
}
# Initialize tokenizer and model variables
tokenizer = None
model = None
current_model = "Shakti-2.5B"  # Keep track of the currently loaded model
def load_model(selected_model: str):
    """Load the tokenizer and weights for the selected Shakti checkpoint."""
    global tokenizer, model, current_model
    model_id = model_options[selected_model]
    # The SHAKTI env var is expected to hold a Hugging Face access token
    # with read access to the SandLogicTechnologies repositories.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv("SHAKTI"))
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        token=os.getenv("SHAKTI"),
    )
    model.eval()
    print("Selected Model:", selected_model)
    current_model = selected_model
# Initial model load
load_model("Shakti-2.5B")
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    conversation = []
    if current_model == "Shakti-2.5B":
        # Prepend the system prompt once. (The original inserted it on every
        # history turn, duplicating it in the chat template, and skipped it
        # entirely when the history was empty.)
        conversation.append(json.loads(os.getenv("PROMPT")))
    for user, assistant in chat_history:
        conversation.extend([
            {"role": "user", "content": user},
            {"role": "assistant", "content": assistant},
        ])
    conversation.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
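    # model.generate() blocks until the sequence is complete, so it runs in a
    # background thread; TextIteratorStreamer hands decoded tokens back to this
    # thread, letting the loop below yield partial text as it is produced.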
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
def respond(message, chat_history, max_new_tokens, temperature):
    # generate() yields the cumulative response on every step, so keep only the
    # latest value. (The original concatenated the cumulative chunks, which
    # duplicated text in the reply.)
    bot_message = ""
    for chunk in generate(message, chat_history, max_new_tokens, temperature):
        bot_message = chunk
    chat_history.append((message, bot_message))
    return "", chat_history
def get_examples(selected_model):
    examples = {
        "Shakti-100M": [
            ["Tell me a story"],
            ["Write a short poem on Rose"],
            ["What are computers"],
        ],
        "Shakti-250M": [
            ["Can you explain the pathophysiology of hypertension and its impact on the cardiovascular system?"],
            ["What are the potential side effects of beta-blockers in the treatment of arrhythmias?"],
            ["What foods are good for boosting the immune system?"],
            ["What is the difference between a stock and a bond?"],
            ["How can I start saving for retirement?"],
            ["What are some low-risk investment options?"],
        ],
        "Shakti-2.5B": [
            ["Tell me a story"],
            ["Write a short poem which is hard to sing"],
            ["मुझे भारतीय इतिहास के बारे में बताएं"],  # Hindi: "Tell me about Indian history" (demonstrates vernacular support)
        ],
    }
    return examples.get(selected_model, [])
def on_model_select(selected_model):
    load_model(selected_model)  # Load the selected model
    # Clear the message box and chat history for the new model
    return gr.update(value=""), gr.update(value=[])
def update_examples_visibility(selected_model):
    # Return individual visibility updates for each example section
    return (
        gr.update(visible=selected_model == "Shakti-100M"),
        gr.update(visible=selected_model == "Shakti-250M"),
        gr.update(visible=selected_model == "Shakti-2.5B"),
    )
def example_selector(example):
    return example
with gr.Blocks(css=CUSTOM_CSS) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        model_dropdown = gr.Dropdown(
            label="Select Model",
            choices=list(model_options.keys()),
            value="Shakti-2.5B",
            interactive=True,
        )

    chatbot = gr.Chatbot()

    with gr.Row():
        with gr.Column(scale=20):
            msg = gr.Textbox(
                label="Message",
                placeholder="Enter your message here",
                lines=2,
                show_label=False,
            )
        with gr.Column(scale=1, min_width=50):
            send_btn = gr.Button(
                value="➤",
                variant="primary",
                elem_classes=["send-btn"],
            )
    with gr.Accordion("Parameters", open=False):
        max_tokens_slider = gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
        temperature_slider = gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        )
    # Add submit action handlers (send button click and Enter key)
    submit_click = send_btn.click(
        respond,
        inputs=[msg, chatbot, max_tokens_slider, temperature_slider],
        outputs=[msg, chatbot],
    )
    submit_enter = msg.submit(
        respond,
        inputs=[msg, chatbot, max_tokens_slider, temperature_slider],
        outputs=[msg, chatbot],
    )
    # Create separate example sections for each model
    with gr.Row():
        with gr.Column(visible=False) as examples_100m:
            gr.Examples(
                examples=get_examples("Shakti-100M"),
                inputs=msg,
                label="Example prompts for Shakti-100M",
                fn=example_selector,
            )
        with gr.Column(visible=False) as examples_250m:
            gr.Examples(
                examples=get_examples("Shakti-250M"),
                inputs=msg,
                label="Example prompts for Shakti-250M",
                fn=example_selector,
            )
        with gr.Column(visible=True) as examples_2_5b:
            gr.Examples(
                examples=get_examples("Shakti-2.5B"),
                inputs=msg,
                label="Example prompts for Shakti-2.5B",
                fn=example_selector,
            )
    # Update model selection and examples visibility in a single handler
    def combined_update(selected_model):
        msg_update, chat_update = on_model_select(selected_model)
        examples_100m_update, examples_250m_update, examples_2_5b_update = update_examples_visibility(selected_model)
        return [
            msg_update,
            chat_update,
            examples_100m_update,
            examples_250m_update,
            examples_2_5b_update,
        ]

    model_dropdown.change(
        combined_update,
        inputs=[model_dropdown],
        outputs=[
            msg,
            chatbot,
            examples_100m,
            examples_250m,
            examples_2_5b,
        ],
    )
if __name__ == "__main__":
    demo.queue(max_size=20).launch()