Ruurd committed on
Commit b6b7c74 · 1 Parent(s): d0f4aff

Load models using dropdown

Files changed (1): app.py (+42 -49)
app.py CHANGED
@@ -1,71 +1,64 @@
-import os
-import subprocess
-
-def install(package):
-    subprocess.check_call([os.sys.executable, "-m", "pip", "install", package])
-
-install("transformers")
-
 import os
 import torch
 import gradio as gr
 import spaces
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+# Use a global variable to hold the current model and tokenizer
+current_model = None
+current_tokenizer = None
+
+def load_model_on_selection(model_name, progress=gr.Progress(track_tqdm=False)):
+    global current_model, current_tokenizer
+    token = os.getenv("HF_TOKEN")
+
+    progress(0, desc="Loading tokenizer...")
+    current_tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=token)
 
-# Global cache for loaded models
-model_cache = {}
+    progress(0.5, desc="Loading model...")
+    current_model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.float16,
+        device_map="auto",
+        use_auth_token=token
+    )
 
-# Load a model with progress bar
-def load_model(model_name, progress=gr.Progress(track_tqdm=False)):
-    if model_name not in model_cache:
-        token = os.getenv("HF_TOKEN")
-        progress(0, desc="Loading tokenizer...")
-        tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=token)
-        progress(0.5, desc="Loading model...")
-        model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=torch.float16,
-            device_map="auto",
-            use_auth_token=token
-        )
-        model_cache[model_name] = (tokenizer, model)
-        progress(1, desc="Model ready.")
-        return f"{model_name} loaded and ready!"
-    else:
-        return f"{model_name} already loaded."
+    progress(1, desc="Model ready.")
+    return f"{model_name} loaded and ready!"
 
-# Inference function using GPU
 @spaces.GPU
-def generate_text(model_name, prompt):
-    tokenizer, model = model_cache[model_name]
-    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
-    outputs = model.generate(**inputs, max_new_tokens=256)
-    return tokenizer.decode(outputs[0], skip_special_tokens=True)
+def generate_text(prompt):
+    global current_model, current_tokenizer
+    if current_model is None or current_tokenizer is None:
+        return "⚠️ No model loaded yet. Please select a model first."
+
+    inputs = current_tokenizer(prompt, return_tensors="pt").to(current_model.device)
+    outputs = current_model.generate(**inputs, max_new_tokens=256)
+    return current_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-# Available models
+# Model options
 model_choices = [
     "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
     "meta-llama/Llama-3.2-3B-Instruct",
     "google/gemma-7b"
 ]
 
-# Gradio Interface
+# Gradio UI
 with gr.Blocks() as demo:
-    gr.Markdown("## Clinical Text Analysis with LLMs (LLaMA, DeepSeek, Gemma)")
-
-    with gr.Row():
-        model_selector = gr.Dropdown(choices=model_choices, label="Select Model")
-        model_status = gr.Textbox(label="Model Status", interactive=False)
+    gr.Markdown("## Clinical Text Testing with LLaMA, DeepSeek, and Gemma")
+
+    model_selector = gr.Dropdown(choices=model_choices, label="Select Model")
+    model_status = gr.Textbox(label="Model Status", interactive=False)
 
     input_text = gr.Textbox(label="Input Clinical Text")
     output_text = gr.Textbox(label="Generated Output")
-
-    analyze_button = gr.Button("Analyze")
 
-    # Load model when changed
-    model_selector.change(fn=load_model, inputs=model_selector, outputs=model_status)
-
-    # Generate output
-    analyze_button.click(fn=generate_text, inputs=[model_selector, input_text], outputs=output_text)
+    generate_btn = gr.Button("Generate")
+
+    # Load model on dropdown change
+    model_selector.change(fn=load_model_on_selection, inputs=model_selector, outputs=model_status)
+
+    # Generate with current model
+    generate_btn.click(fn=generate_text, inputs=input_text, outputs=output_text)
 
 demo.launch()
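
A note for anyone adapting this commit: `use_auth_token=` is deprecated in recent `transformers` releases in favor of `token=`, and because `load_model_on_selection` only rebinds the globals, the previously selected model's weights stay allocated until Python garbage-collects them, so switching models a few times can exhaust GPU memory. Below is a minimal sketch of a loader that addresses both points. It is not part of the commit; it assumes a recent `transformers` version and that discarding the old model on every dropdown change is acceptable for this Space.

import gc
import os

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

current_model = None
current_tokenizer = None

def load_model_on_selection(model_name, progress=gr.Progress(track_tqdm=False)):
    global current_model, current_tokenizer

    # Drop references to the previous model so its GPU memory can be
    # reclaimed before the next checkpoint is loaded. (Assumption: the old
    # model is no longer needed once a new one is selected.)
    current_model = None
    current_tokenizer = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    token = os.getenv("HF_TOKEN")
    progress(0, desc="Loading tokenizer...")
    # `token=` is the current replacement for the deprecated `use_auth_token=`.
    current_tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)

    progress(0.5, desc="Loading model...")
    current_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        token=token,
    )

    progress(1, desc="Model ready.")
    return f"{model_name} loaded and ready!"

The event wiring in the commit (`model_selector.change(...)` and `generate_btn.click(...)`) would not need to change under this variant.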