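"""Gradio demo: chat with a selection of small Hugging Face language models."""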
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from functools import lru_cache
# Pre-selected small models
MODELS = {
    "SmolLM2-135M-Instruct": "HuggingFaceTB/SmolLM2-135M-Instruct",
    "GPT-2 (Small)": "gpt2",
    "DistilGPT-2": "distilgpt2",
    "Facebook OPT-125M": "facebook/opt-125m",
}
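# All four are roughly 100M-parameter models, small enough to run on CPU.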
# Cache loaded pipelines so switching between models doesn't reload them from disk
@lru_cache(maxsize=len(MODELS))
def load_model_cached(model_name):
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        return pipeline("text-generation", model=model, tokenizer=tokenizer)
    except Exception as e:
        return f"Error loading model: {str(e)}"
# Generate a response from the selected model
def chat(selected_model, user_input, chat_history, system_prompt=""):
    if not selected_model:
        return "Please select a model from the dropdown.", chat_history
    # Map the dropdown label to its Hugging Face Hub model ID
    model_name = MODELS.get(selected_model)
    if not model_name:
        return "Invalid model selected.", chat_history
    # Load the model (cached)
    generator = load_model_cached(model_name)
    if isinstance(generator, str):  # load_model_cached returned an error string
        return generator, chat_history
    # Prepend the optional system prompt to the user's message
    full_input = f"{system_prompt}\n\n{user_input}" if system_prompt else user_input
    # Generate a response
    try:
        # Cap total length at the model's context window
        max_context_length = generator.model.config.max_position_embeddings
        max_length = min(500, max_context_length)
        # The text-generation pipeline expects plain text, not pre-tokenized
        # tensors; let it tokenize and truncate the prompt itself.
        response = generator(
            full_input,
            max_length=max_length,
            truncation=True,
            num_return_sequences=1,
            do_sample=True,
            top_p=0.95,
            top_k=60,
            pad_token_id=generator.tokenizer.eos_token_id,
            return_full_text=False,  # exclude the prompt from the output
        )[0]["generated_text"]
        # Append the interaction to the chat history
        chat_history.append((user_input, response))
        return "", chat_history
    except Exception as e:
        return f"Error generating response: {str(e)}", chat_history
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Chat with Small Language Models")
    with gr.Row():
        selected_model = gr.Dropdown(
            label="Select a Model",
            choices=list(MODELS.keys()),
            value="SmolLM2-135M-Instruct",  # default model
        )
    chatbot = gr.Chatbot(label="Chat")
    user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
    system_prompt = gr.Textbox(
        label="System Prompt (Optional)",
        placeholder="e.g., You are a helpful AI assistant.",
        lines=2,
    )
    clear_button = gr.Button("Clear Chat")
    # Wire events: submitting the textbox sends a message; the button clears the chat
    user_input.submit(chat, [selected_model, user_input, chatbot, system_prompt], [user_input, chatbot])
    clear_button.click(lambda: [], None, chatbot, queue=False)
# Launch the app
demo.launch()
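# To run locally (assumes the usual dependencies are installed):
#   pip install gradio transformers torch
#   python app.py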