Update app.py
app.py CHANGED
@@ -1,8 +1,10 @@
 import gradio as gr
 from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
+from functools import lru_cache
 
-# …
-def load_model(model_name):
+# Cache the model and tokenizer to avoid reloading
+@lru_cache(maxsize=1)
+def load_model_cached(model_name):
     try:
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -11,18 +13,45 @@ def load_model(model_name):
         return f"Error loading model: {str(e)}"
 
 # Function to generate a response from the model
-def chat(model_name, user_input, chat_history):
+def chat(model_name, user_input, chat_history, system_prompt=""):
     if model_name.strip() == "":
         return "Please enter a valid model name.", chat_history
 
-    # Load the model
-    generator = load_model(model_name)
+    # Load the model (cached)
+    generator = load_model_cached(model_name)
     if isinstance(generator, str):  # If there was an error loading the model
         return generator, chat_history
 
+    # Prepare the input with an optional system prompt
+    full_input = f"{system_prompt}\n\n{user_input}" if system_prompt else user_input
+
     # Generate a response
     try:
-        response = generator(…)
+        # Get the model's maximum context length
+        max_context_length = generator.model.config.max_position_embeddings
+        max_length = min(500, max_context_length)  # Ensure we don't exceed the model's limit
+
+        # Truncate the input if it's too long
+        inputs = generator.tokenizer(
+            full_input,
+            return_tensors="pt",
+            max_length=max_length,
+            truncation=True
+        )
+
+        # Generate the response with a progress indicator
+        with gr.Progress() as progress:
+            progress(0.5, desc="Generating response...")
+            response = generator(
+                inputs['input_ids'],
+                max_length=max_length,
+                num_return_sequences=1,
+                do_sample=True,
+                top_p=0.95,
+                top_k=60
+            )[0]['generated_text']
+
+        # Append the interaction to the chat history
         chat_history.append((user_input, response))
         return "", chat_history
     except Exception as e:
@@ -30,17 +59,26 @@ def chat(model_name, user_input, chat_history):
 
 # Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("# Chat with …")
+    gr.Markdown("# Chat with SmolLM2-135M-Instruct")
 
     with gr.Row():
-        model_name = gr.Textbox(…)
+        model_name = gr.Textbox(
+            label="Enter Hugging Face Model Name",
+            value="HuggingFaceTB/SmolLM2-135M-Instruct",  # Default model
+            placeholder="e.g., HuggingFaceTB/SmolLM2-135M-Instruct"
+        )
 
     chatbot = gr.Chatbot(label="Chat")
     user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
+    system_prompt = gr.Textbox(
+        label="System Prompt (Optional)",
+        placeholder="e.g., You are a helpful AI assistant.",
+        lines=2
+    )
     clear_button = gr.Button("Clear Chat")
 
     # Define the chat function
-    user_input.submit(chat, [model_name, user_input, chatbot], [user_input, chatbot])
+    user_input.submit(chat, [model_name, user_input, chatbot, system_prompt], [user_input, chatbot])
     clear_button.click(lambda: [], None, chatbot, queue=False)
 
 # Launch the app
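
Note on the new generation block: Gradio's progress tracker is normally injected by declaring `progress=gr.Progress()` as a default parameter of the event handler rather than by using `gr.Progress()` as a context manager, and a transformers text-generation pipeline expects a prompt string rather than pre-tokenized `input_ids`. Also, `max_length` bounds prompt plus reply together, so truncating the input to `max_length` leaves no budget for newly generated tokens. Below is a minimal sketch of the same step written against those documented APIs, assuming `load_model_cached` returns a text-generation pipeline as the diff suggests; the `max_new_tokens=256` cap is an illustrative assumption, the sampling parameters are taken from the diff.

import gradio as gr

def chat(model_name, user_input, chat_history, system_prompt="",
         progress=gr.Progress()):  # Gradio fills this default in for event handlers
    if model_name.strip() == "":
        return "Please enter a valid model name.", chat_history

    generator = load_model_cached(model_name)  # helper from the diff above
    if isinstance(generator, str):  # loader returned an error message
        return generator, chat_history

    full_input = f"{system_prompt}\n\n{user_input}" if system_prompt else user_input
    progress(0.5, desc="Generating response...")

    # A pipeline takes the prompt text directly and can truncate it itself.
    response = generator(
        full_input,
        max_new_tokens=256,      # assumption: bound only the reply, not prompt+reply
        num_return_sequences=1,
        do_sample=True,
        top_p=0.95,
        top_k=60,
        truncation=True,
        return_full_text=False,  # drop the echoed prompt from the output
    )[0]["generated_text"]

    chat_history.append((user_input, response))
    return "", chat_history

With `max_new_tokens` the prompt can use the model's full context window while the reply length stays fixed, which avoids the silent empty-output case that arises when a truncated prompt already fills `max_length`.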