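"""Gradio demo for chatting with small language models from the Hugging Face Hub.

Pick one of a few small causal LMs (SmolLM2-135M-Instruct, GPT-2, DistilGPT-2,
OPT-125M), optionally set a system prompt, and exchange messages in a Chatbot
widget. Loaded pipelines are cached so repeated messages don't reload the model.
"""
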
import gradio as gr
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from functools import lru_cache

# Pre-selected small models
MODELS = {
    "SmolLM2-135M-Instruct": "HuggingFaceTB/SmolLM2-135M-Instruct",
    "GPT-2 (Small)": "gpt2",
    "DistilGPT-2": "distilgpt2",
    "Facebook OPT-125M": "facebook/opt-125m"
}

# Cache the model and tokenizer to avoid reloading
@lru_cache(maxsize=1)
def load_model_cached(model_name):
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(model_name)
        return pipeline("text-generation", model=model, tokenizer=tokenizer)
    except Exception as e:
        return f"Error loading model: {str(e)}"

# Function to generate a response from the model
def chat(selected_model, user_input, chat_history, system_prompt=""):
    if not selected_model:
        return "Please select a model from the dropdown.", chat_history

    # Get the model name from the dropdown
    model_name = MODELS.get(selected_model)
    if not model_name:
        return "Invalid model selected.", chat_history

    # Load the model (cached)
    generator = load_model_cached(model_name)
    if isinstance(generator, str):  # If there was an error loading the model
        return generator, chat_history

    # Prepare the input with an optional system prompt
    full_input = f"{system_prompt}\n\n{user_input}" if system_prompt else user_input

    # Generate a response
    try:
        # Get the model's maximum context length and reserve room for the reply
        max_context_length = generator.model.config.max_position_embeddings
        max_new_tokens = 200
        prompt_budget = max(1, min(500, max_context_length) - max_new_tokens)

        # Truncate the prompt if it's too long, then decode back to text
        # (the text-generation pipeline expects a string, not token IDs)
        input_ids = generator.tokenizer(
            full_input,
            return_tensors="pt",
            max_length=prompt_budget,
            truncation=True
        )["input_ids"]
        prompt = generator.tokenizer.decode(input_ids[0], skip_special_tokens=True)

        # Generate only the new text (return_full_text=False drops the echoed prompt)
        response = generator(
            prompt,
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            do_sample=True,
            top_p=0.95,
            top_k=60,
            return_full_text=False
        )[0]['generated_text']

        # Append the interaction to the chat history
        chat_history.append((user_input, response))
        return "", chat_history
    except Exception as e:
        return f"Error generating response: {str(e)}", chat_history

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Chat with Small Language Models")
    
    with gr.Row():
        selected_model = gr.Dropdown(
            label="Select a Model",
            choices=list(MODELS.keys()),
            value="SmolLM2-135M-Instruct"  # Default model
        )
    
    chatbot = gr.Chatbot(label="Chat")
    user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
    system_prompt = gr.Textbox(
        label="System Prompt (Optional)",
        placeholder="e.g., You are a helpful AI assistant.",
        lines=2
    )
    clear_button = gr.Button("Clear Chat")

    # Wire up the handlers: submitting a message runs chat(); the button clears the history
    user_input.submit(chat, [selected_model, user_input, chatbot, system_prompt], [user_input, chatbot])
    clear_button.click(lambda: [], None, chatbot, queue=False)

# Launch the app
demo.launch()
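
# Run locally with `python app.py`; Gradio prints a local URL to open in the browser.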