import json
import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

DESCRIPTION = """\
Shakti LLMs (Large Language Models) are a group of compact language models specifically optimized for resource-constrained environments such as edge devices, including smartphones, wearables, and IoT (Internet of Things) systems. These models support vernacular languages and domain-specific tasks, making them particularly suitable for industries such as healthcare, finance, and customer service.
For more details, please check [here](https://arxiv.org/pdf/2410.11331v1).
"""

# Custom CSS for the round send button
CUSTOM_CSS = """
.send-btn {
    padding: 0.5rem !important;
    width: 55px !important;
    height: 55px !important;
    border-radius: 50% !important;
    margin-top: 1rem;
    cursor: pointer;
}
.send-btn svg {
    width: 20px !important;
    height: 20px !important;
    position: absolute;
    top: 50%;
    left: 50%;
    transform: translate(-50%, -50%);
}
"""

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "2048"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Model configurations
model_options = {
    "Shakti-100M": "SandLogicTechnologies/Shakti-100M",
    "Shakti-250M": "SandLogicTechnologies/Shakti-250M",
    "Shakti-2.5B": "SandLogicTechnologies/Shakti-2.5B",
}

# Tokenizer/model globals, swapped out when the user selects another model
tokenizer = None
model = None
current_model = "Shakti-2.5B"  # Keep track of the current model


def load_model(selected_model: str):
    global tokenizer, model, current_model
    model_id = model_options[selected_model]
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=os.getenv("SHAKTI"))
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        token=os.getenv("SHAKTI"),
    )
    model.eval()
    print("Selected Model: ", selected_model)
    current_model = selected_model


# Initial model load
load_model("Shakti-2.5B")


def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    conversation = []
    if current_model == "Shakti-2.5B":
        # The Shakti-2.5B system prompt is stored as a JSON chat message in the
        # PROMPT env var; prepend it once, rather than once per history turn.
        conversation.append(json.loads(os.getenv("PROMPT")))
    for user, assistant in chat_history:
        conversation.extend([
            {"role": "user", "content": user},
            {"role": "assistant", "content": assistant},
        ])
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
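    # Run model.generate on a background thread and yield the partial text the
    # streamer has accumulated so far, so callers can render tokens incrementally.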
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        num_beams=1,
        repetition_penalty=repetition_penalty,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


def respond(message, chat_history, max_new_tokens, temperature):
    # generate() yields the cumulative response so far, so keep only the latest
    # chunk; concatenating the chunks would duplicate earlier text.
    bot_message = ""
    for chunk in generate(message, chat_history, max_new_tokens, temperature):
        bot_message = chunk
    chat_history.append((message, bot_message))
    return "", chat_history


def get_examples(selected_model):
    examples = {
        "Shakti-100M": [
            ["Tell me a story"],
            ["Write a short poem on Rose"],
            ["What are computers"],
        ],
        "Shakti-250M": [
            ["Can you explain the pathophysiology of hypertension and its impact on the cardiovascular system?"],
            ["What are the potential side effects of beta-blockers in the treatment of arrhythmias?"],
            ["What foods are good for boosting the immune system?"],
            ["What is the difference between a stock and a bond?"],
            ["How can I start saving for retirement?"],
            ["What are some low-risk investment options?"],
        ],
        "Shakti-2.5B": [
            ["Tell me a story"],
            ["Write a short poem which is hard to sing"],
            ["मुझे भारतीय इतिहास के बारे में बताएं"],  # "Tell me about Indian history"
        ],
    }
    return examples.get(selected_model, [])


def on_model_select(selected_model):
    load_model(selected_model)  # Load the selected model
    # Clear the message box and chat history
    return gr.update(value=""), gr.update(value=[])


def update_examples_visibility(selected_model):
    # Return individual visibility updates for each example section
    return (
        gr.update(visible=selected_model == "Shakti-100M"),
        gr.update(visible=selected_model == "Shakti-250M"),
        gr.update(visible=selected_model == "Shakti-2.5B"),
    )


def example_selector(example):
    return example


with gr.Blocks(css=CUSTOM_CSS) as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        model_dropdown = gr.Dropdown(
            label="Select Model",
            choices=list(model_options.keys()),
            value="Shakti-2.5B",
            interactive=True,
        )

    chatbot = gr.Chatbot()

    with gr.Row():
        with gr.Column(scale=20):
            msg = gr.Textbox(
                label="Message",
                placeholder="Enter your message here",
                lines=2,
                show_label=False,
            )
        with gr.Column(scale=1, min_width=50):
            send_btn = gr.Button(
                value="➤",
                variant="primary",
                elem_classes=["send-btn"],
            )

    with gr.Accordion("Parameters", open=False):
        max_tokens_slider = gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        )
        temperature_slider = gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.6,
        )

    # Add submit action handlers
    submit_click = send_btn.click(
        respond,
        inputs=[msg, chatbot, max_tokens_slider, temperature_slider],
        outputs=[msg, chatbot],
    )
    submit_enter = msg.submit(
        respond,
        inputs=[msg, chatbot, max_tokens_slider, temperature_slider],
        outputs=[msg, chatbot],
    )

    # Create separate example sections for each model
    with gr.Row():
        with gr.Column(visible=False) as examples_100m:
            gr.Examples(
                examples=get_examples("Shakti-100M"),
                inputs=msg,
                label="Example prompts for Shakti-100M",
                fn=example_selector,
            )
        with gr.Column(visible=False) as examples_250m:
            gr.Examples(
                examples=get_examples("Shakti-250M"),
                inputs=msg,
                label="Example prompts for Shakti-250M",
                fn=example_selector,
            )
        with gr.Column(visible=True) as examples_2_5b:
            gr.Examples(
                examples=get_examples("Shakti-2.5B"),
                inputs=msg,
                label="Example prompts for Shakti-2.5B",
                fn=example_selector,
            )
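    # Switching models clears the textbox and chat history: each checkpoint has
    # its own chat template (and Shakti-2.5B additionally uses a system prompt),
    # so replaying one model's history into another would be inconsistent.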
    # Update model selection and examples visibility
    def combined_update(selected_model):
        msg_update, chat_update = on_model_select(selected_model)
        examples_100m_update, examples_250m_update, examples_2_5b_update = update_examples_visibility(selected_model)
        return [
            msg_update,
            chat_update,
            examples_100m_update,
            examples_250m_update,
            examples_2_5b_update,
        ]

    # Change event handler for the model dropdown
    model_dropdown.change(
        combined_update,
        inputs=[model_dropdown],
        outputs=[
            msg,
            chatbot,
            examples_100m,
            examples_250m,
            examples_2_5b,
        ],
    )

if __name__ == "__main__":
    demo.queue(max_size=20).launch()
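# Local usage (a sketch; the script filename and gated-repo access are assumptions):
#   SHAKTI=<hf_access_token> \
#   PROMPT='{"role": "system", "content": "..."}' \
#   python app.py
# SHAKTI is passed as the Hugging Face token to from_pretrained(); PROMPT must hold
# a single chat message as JSON. MAX_INPUT_TOKEN_LENGTH may also be overridden via
# the environment (default 2048).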