import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from functools import lru_cache


# Display name -> Hugging Face Hub model ID.
MODELS = {
    "SmolLM2-135M-Instruct": "HuggingFaceTB/SmolLM2-135M-Instruct",
    "GPT-2 (Small)": "gpt2",
    "DistilGPT-2": "distilgpt2",
    "Facebook OPT-125M": "facebook/opt-125m",
}


@lru_cache(maxsize=1)
def load_model_cached(model_name):
    """Load a text-generation pipeline, keeping the most recent model cached.

    maxsize=1 holds only one model in memory at a time, so switching models
    in the UI evicts the previous one. Exceptions are allowed to propagate:
    if loading failed and we returned an error string instead, lru_cache
    would cache that string and every later call would fail too.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name)
    return pipeline("text-generation", model=model, tokenizer=tokenizer)


def chat(selected_model, user_input, chat_history, system_prompt=""):
    """Generate a reply with the selected model and append it to the history."""
    if not selected_model:
        return "Please select a model from the dropdown.", chat_history

    model_name = MODELS.get(selected_model)
    if not model_name:
        return "Invalid model selected.", chat_history

    # Loading can fail (e.g., no network for the first download), so report
    # the error to the UI instead of letting the exception escape.
    try:
        generator = load_model_cached(model_name)
    except Exception as e:
        return f"Error loading model: {e}", chat_history

    # Prepend the optional system prompt as plain text.
    full_input = f"{system_prompt}\n\n{user_input}" if system_prompt else user_input

    try:
        # Cap the total length (prompt + continuation) at 500 tokens, or at
        # the model's context window if smaller (1024 for GPT-2/DistilGPT-2).
        max_context_length = generator.model.config.max_position_embeddings
        max_length = min(500, max_context_length)

        # The text-generation pipeline expects a string prompt, not token IDs.
        # truncation=True lets the pipeline's tokenizer truncate over-long
        # prompts (supported in recent transformers releases).
        response = generator(
            full_input,
            max_length=max_length,
            truncation=True,
            num_return_sequences=1,
            do_sample=True,
            top_p=0.95,
            top_k=60,
            return_full_text=False,  # return only the new text, not the prompt
            pad_token_id=generator.tokenizer.eos_token_id,  # GPT-2 has no pad token
        )[0]["generated_text"]

        chat_history.append((user_input, response))
        return "", chat_history
    except Exception as e:
        return f"Error generating response: {e}", chat_history


with gr.Blocks() as demo:
    gr.Markdown("# Chat with Small Language Models")

    with gr.Row():
        selected_model = gr.Dropdown(
            label="Select a Model",
            choices=list(MODELS.keys()),
            value="SmolLM2-135M-Instruct",
        )

    chatbot = gr.Chatbot(label="Chat")
    user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
    system_prompt = gr.Textbox(
        label="System Prompt (Optional)",
        placeholder="e.g., You are a helpful AI assistant.",
        lines=2,
    )
    clear_button = gr.Button("Clear Chat")

    # chat() returns "" as its first output, which clears the textbox on send.
    user_input.submit(
        chat,
        [selected_model, user_input, chatbot, system_prompt],
        [user_input, chatbot],
    )
    clear_button.click(lambda: [], None, chatbot, queue=False)
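

# Possible launch tweaks (assumptions, not in the original script): launch()
# also accepts share=True for a temporary public link, and
# server_name="0.0.0.0" to listen on all interfaces (e.g., in a container).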
demo.launch()