Ruurd committed
Commit d0f4aff · 1 Parent(s): 0c196de

Fix loading of models (only once per model)
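The core of the change is memoizing model loads in a module-level dict so each checkpoint is loaded at most once per process. A minimal, self-contained sketch of that pattern (illustrative names; the real app below caches `(tokenizer, model)` tuples in `model_cache`):

```python
# Sketch of the once-per-model cache this commit introduces; expensive_load
# stands in for AutoModelForCausalLM.from_pretrained.
loads = 0

def expensive_load(name: str) -> str:
    global loads
    loads += 1                      # count real loads to show caching works
    return f"<model {name}>"

cache: dict[str, str] = {}

def get_model(name: str) -> str:
    if name not in cache:           # only the first request pays the cost
        cache[name] = expensive_load(name)
    return cache[name]              # later requests reuse the same instance

get_model("google/gemma-7b")
get_model("google/gemma-7b")
assert loads == 1                   # loaded once, served twice
```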

Files changed (1):
  app.py +37 -24
app.py CHANGED
@@ -6,53 +6,66 @@ def install(package):
 
 install("transformers")
 
-import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer
+import os
 import torch
+import gradio as gr
 import spaces
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
+# Global cache for loaded models
+model_cache = {}
 
-# Dictionary to store loaded models and tokenizers
-loaded_models = {}
-
-def load_model(model_name, progress=gr.Progress()):
-    """Load the model and tokenizer with a progress bar."""
-    if model_name not in loaded_models:
-        access_token = os.getenv("HF_TOKEN")
-        progress(0, desc="Initializing model loading...")
-        tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=access_token)
-        progress(0.5, desc="Tokenizer loaded. Loading model...")
+# Load a model with progress bar
+def load_model(model_name, progress=gr.Progress(track_tqdm=False)):
+    if model_name not in model_cache:
+        token = os.getenv("HF_TOKEN")
+        progress(0, desc="Loading tokenizer...")
+        tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=token)
+        progress(0.5, desc="Loading model...")
         model = AutoModelForCausalLM.from_pretrained(
-            model_name, torch_dtype=torch.float16, device_map="auto", use_auth_token=access_token
+            model_name,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            use_auth_token=token
         )
-        progress(1, desc="Model loaded successfully.")
-        loaded_models[model_name] = (tokenizer, model)
-    return loaded_models[model_name]
+        model_cache[model_name] = (tokenizer, model)
+        progress(1, desc="Model ready.")
+        return f"{model_name} loaded and ready!"
+    else:
+        return f"{model_name} already loaded."
 
+# Inference function using GPU
 @spaces.GPU
-def generate_text(model_name, prompt, progress=gr.Progress()):
-    """Generate text using the selected model with a loading indicator."""
-    tokenizer, model = load_model(model_name, progress)
+def generate_text(model_name, prompt):
+    tokenizer, model = model_cache[model_name]
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
     outputs = model.generate(**inputs, max_new_tokens=256)
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-
-# List of models to choose from
+# Available models
 model_choices = [
     "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
     "meta-llama/Llama-3.2-3B-Instruct",
     "google/gemma-7b"
 ]
 
+# Gradio Interface
 with gr.Blocks() as demo:
-    gr.Markdown("## Clinical Text Analysis with Multiple Models")
-    model_selector = gr.Dropdown(choices=model_choices, label="Select Model")
+    gr.Markdown("## Clinical Text Analysis with LLMs (LLaMA, DeepSeek, Gemma)")
+
+    with gr.Row():
+        model_selector = gr.Dropdown(choices=model_choices, label="Select Model")
+        model_status = gr.Textbox(label="Model Status", interactive=False)
+
     input_text = gr.Textbox(label="Input Clinical Text")
     output_text = gr.Textbox(label="Generated Output")
+
     analyze_button = gr.Button("Analyze")
 
+    # Load model when changed
+    model_selector.change(fn=load_model, inputs=model_selector, outputs=model_status)
+
+    # Generate output
     analyze_button.click(fn=generate_text, inputs=[model_selector, input_text], outputs=output_text)
 
 demo.launch()
-
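With this split, the dropdown's `change` event is what populates the cache, and `generate_text` assumes it already has: `model_cache[model_name]` is indexed directly, so clicking Analyze before the selected model has finished loading would raise a `KeyError`. A hypothetical guard, not part of this commit and reusing the `load_model` and `model_cache` names defined above, could fall back to loading on demand:

```python
# Hypothetical guard (not in the commit): load on demand if the dropdown's
# change event has not populated model_cache yet.
def get_cached(model_name):
    if model_name not in model_cache:
        load_model(model_name)       # fills model_cache as a side effect
    return model_cache[model_name]
```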