Ruurd committed on
Commit
5e84c69
·
1 Parent(s): bc19680

Fix loading on startup

Browse files
Files changed (1) hide show
  1. app.py +9 -6
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import torch
3
  import gradio as gr
4
  import spaces
5
- from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
6
 
7
 
8
  # Use a global variable to hold the current model and tokenizer
@@ -46,7 +46,7 @@ def generate_text(prompt):
46
  return_dict_in_generate=True,
47
  output_scores=False
48
  ).sequences[0]:
49
-
50
  output_ids.append(token_id.item())
51
 
52
  yield current_tokenizer.decode(output_ids, skip_special_tokens=True)
@@ -63,21 +63,24 @@ model_choices = [
63
  with gr.Blocks() as demo:
64
  gr.Markdown("## Clinical Text Testing with LLaMA, DeepSeek, and Gemma")
65
 
 
 
 
66
  model_selector = gr.Dropdown(choices=model_choices, label="Select Model")
67
  model_status = gr.Textbox(label="Model Status", interactive=False)
68
 
69
  input_text = gr.Textbox(label="Input Clinical Text")
70
  generate_btn = gr.Button("Generate")
71
-
72
  output_text = gr.Textbox(label="Generated Output")
73
 
74
- # Load model on dropdown change
 
 
 
75
  model_selector.change(fn=load_model_on_selection, inputs=model_selector, outputs=model_status)
76
 
77
  # Generate with current model
78
  generate_btn.click(fn=generate_text, inputs=input_text, outputs=output_text)
79
  input_text.submit(fn=generate_text, inputs=input_text, outputs=output_text)
80
 
81
-
82
- load_model_on_selection("meta-llama/Llama-3.2-3B-Instruct")
83
  demo.launch()
 
2
  import torch
3
  import gradio as gr
4
  import spaces
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
 
7
 
8
  # Use a global variable to hold the current model and tokenizer
 
46
  return_dict_in_generate=True,
47
  output_scores=False
48
  ).sequences[0]:
49
+
50
  output_ids.append(token_id.item())
51
 
52
  yield current_tokenizer.decode(output_ids, skip_special_tokens=True)
 
63
  with gr.Blocks() as demo:
64
  gr.Markdown("## Clinical Text Testing with LLaMA, DeepSeek, and Gemma")
65
 
66
+ # State to track initial model to load
67
+ default_model = gr.State("meta-llama/Llama-3.2-3B-Instruct")
68
+
69
  model_selector = gr.Dropdown(choices=model_choices, label="Select Model")
70
  model_status = gr.Textbox(label="Model Status", interactive=False)
71
 
72
  input_text = gr.Textbox(label="Input Clinical Text")
73
  generate_btn = gr.Button("Generate")
 
74
  output_text = gr.Textbox(label="Generated Output")
75
 
76
+ # Auto-load default model on launch
77
+ demo.load(fn=load_model_on_selection, inputs=default_model, outputs=model_status)
78
+
79
+ # Manual model selection
80
  model_selector.change(fn=load_model_on_selection, inputs=model_selector, outputs=model_status)
81
 
82
  # Generate with current model
83
  generate_btn.click(fn=generate_text, inputs=input_text, outputs=output_text)
84
  input_text.submit(fn=generate_text, inputs=input_text, outputs=output_text)
85
 
 
 
86
  demo.launch()