Papaya-Voldemort committed on
Commit
66a12b0
·
verified ·
1 Parent(s): 502c1eb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -9
app.py CHANGED
@@ -1,8 +1,10 @@
1
  import gradio as gr
2
  from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
 
3
 
4
- # Function to load the model and tokenizer
5
- def load_model(model_name):
 
6
  try:
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForCausalLM.from_pretrained(model_name)
@@ -11,18 +13,45 @@ def load_model(model_name):
11
  return f"Error loading model: {str(e)}"
12
 
13
  # Function to generate a response from the model
14
- def chat(model_name, user_input, chat_history):
15
  if model_name.strip() == "":
16
  return "Please enter a valid model name.", chat_history
17
 
18
- # Load the model
19
- generator = load_model(model_name)
20
  if isinstance(generator, str): # If there was an error loading the model
21
  return generator, chat_history
22
 
 
 
 
23
  # Generate a response
24
  try:
25
- response = generator(user_input, max_length=500, num_return_sequences=1, do_sample=True, top_p=0.95, top_k=60)[0]['generated_text']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  chat_history.append((user_input, response))
27
  return "", chat_history
28
  except Exception as e:
@@ -30,17 +59,26 @@ def chat(model_name, user_input, chat_history):
30
 
31
  # Gradio interface
32
  with gr.Blocks() as demo:
33
- gr.Markdown("# Chat with Any Hugging Face Model")
34
 
35
  with gr.Row():
36
- model_name = gr.Textbox(label="Enter Hugging Face Model Name", placeholder="e.g., gpt2, facebook/opt-125m")
 
 
 
 
37
 
38
  chatbot = gr.Chatbot(label="Chat")
39
  user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
 
 
 
 
 
40
  clear_button = gr.Button("Clear Chat")
41
 
42
  # Define the chat function
43
- user_input.submit(chat, [model_name, user_input, chatbot], [user_input, chatbot])
44
  clear_button.click(lambda: [], None, chatbot, queue=False)
45
 
46
  # Launch the app
 
1
  import gradio as gr
2
  from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
3
+ from functools import lru_cache
4
 
5
+ # Cache the model and tokenizer to avoid reloading
6
+ @lru_cache(maxsize=1)
7
+ def load_model_cached(model_name):
8
  try:
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
  model = AutoModelForCausalLM.from_pretrained(model_name)
 
13
  return f"Error loading model: {str(e)}"
14
 
15
  # Function to generate a response from the model
16
+ def chat(model_name, user_input, chat_history, system_prompt=""):
17
  if model_name.strip() == "":
18
  return "Please enter a valid model name.", chat_history
19
 
20
+ # Load the model (cached)
21
+ generator = load_model_cached(model_name)
22
  if isinstance(generator, str): # If there was an error loading the model
23
  return generator, chat_history
24
 
25
+ # Prepare the input with an optional system prompt
26
+ full_input = f"{system_prompt}\n\n{user_input}" if system_prompt else user_input
27
+
28
  # Generate a response
29
  try:
30
+ # Get the model's maximum context length
31
+ max_context_length = generator.model.config.max_position_embeddings
32
+ max_length = min(500, max_context_length) # Ensure we don't exceed the model's limit
33
+
34
+ # Truncate the input if it's too long
35
+ inputs = generator.tokenizer(
36
+ full_input,
37
+ return_tensors="pt",
38
+ max_length=max_length,
39
+ truncation=True
40
+ )
41
+
42
+ # Generate the response with a progress indicator
43
+ with gr.Progress() as progress:
44
+ progress(0.5, desc="Generating response...")
45
+ response = generator(
46
+ inputs['input_ids'],
47
+ max_length=max_length,
48
+ num_return_sequences=1,
49
+ do_sample=True,
50
+ top_p=0.95,
51
+ top_k=60
52
+ )[0]['generated_text']
53
+
54
+ # Append the interaction to the chat history
55
  chat_history.append((user_input, response))
56
  return "", chat_history
57
  except Exception as e:
 
59
 
60
  # Gradio interface
61
  with gr.Blocks() as demo:
62
+ gr.Markdown("# Chat with SmolLM2-135M-Instruct")
63
 
64
  with gr.Row():
65
+ model_name = gr.Textbox(
66
+ label="Enter Hugging Face Model Name",
67
+ value="HuggingFaceTB/SmolLM2-135M-Instruct", # Default model
68
+ placeholder="e.g., HuggingFaceTB/SmolLM2-135M-Instruct"
69
+ )
70
 
71
  chatbot = gr.Chatbot(label="Chat")
72
  user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
73
+ system_prompt = gr.Textbox(
74
+ label="System Prompt (Optional)",
75
+ placeholder="e.g., You are a helpful AI assistant.",
76
+ lines=2
77
+ )
78
  clear_button = gr.Button("Clear Chat")
79
 
80
  # Define the chat function
81
+ user_input.submit(chat, [model_name, user_input, chatbot, system_prompt], [user_input, chatbot])
82
  clear_button.click(lambda: [], None, chatbot, queue=False)
83
 
84
  # Launch the app