Ruurd committed
Commit 0c196de · 1 Parent(s): 2372fe6

Show loading bars

Files changed (1): app.py (+13 -8)
app.py CHANGED
@@ -15,24 +15,29 @@ import spaces
 # Dictionary to store loaded models and tokenizers
 loaded_models = {}
 
-def load_model(model_name):
-    """Load the model and tokenizer if not already loaded."""
+def load_model(model_name, progress=gr.Progress()):
+    """Load the model and tokenizer with a progress bar."""
     if model_name not in loaded_models:
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        access_token = os.getenv("HF_TOKEN")
+        progress(0, desc="Initializing model loading...")
+        tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=access_token)
+        progress(0.5, desc="Tokenizer loaded. Loading model...")
         model = AutoModelForCausalLM.from_pretrained(
-            model_name, torch_dtype=torch.float16, device_map="auto"
+            model_name, torch_dtype=torch.float16, device_map="auto", use_auth_token=access_token
         )
+        progress(1, desc="Model loaded successfully.")
         loaded_models[model_name] = (tokenizer, model)
     return loaded_models[model_name]
 
 @spaces.GPU
-def generate_text(model_name, prompt):
-    """Generate text using the selected model."""
-    tokenizer, model = load_model(model_name)
+def generate_text(model_name, prompt, progress=gr.Progress()):
+    """Generate text using the selected model with a loading indicator."""
+    tokenizer, model = load_model(model_name, progress)
     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
     outputs = model.generate(**inputs, max_new_tokens=256)
     return tokenizer.decode(outputs[0], skip_special_tokens=True)
 
+
 # List of models to choose from
 model_choices = [
     "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
@@ -40,7 +45,6 @@ model_choices = [
     "google/gemma-7b"
 ]
 
-# Gradio interface setup
 with gr.Blocks() as demo:
     gr.Markdown("## Clinical Text Analysis with Multiple Models")
     model_selector = gr.Dropdown(choices=model_choices, label="Select Model")
@@ -51,3 +55,4 @@ with gr.Blocks() as demo:
     analyze_button.click(fn=generate_text, inputs=[model_selector, input_text], outputs=output_text)
 
 demo.launch()
+
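For context, a minimal self-contained sketch of the gr.Progress pattern this commit adopts, assuming a recent Gradio release; slow_task and its widgets are hypothetical and not part of app.py. Gradio injects a progress tracker into any event handler that declares a progress=gr.Progress() keyword argument; calling it with a fraction and a desc updates the loading bar, and progress.tqdm can wrap an iterable.

import time
import gradio as gr

def slow_task(text, progress=gr.Progress()):  # hypothetical handler, not in app.py
    progress(0, desc="Starting...")  # show the bar at 0% with a status message
    for _ in progress.tqdm(range(10), desc="Working..."):  # tqdm-style per-step updates
        time.sleep(0.2)  # stand-in for real work
    progress(1, desc="Done.")  # fill the bar to 100%
    return text.upper()

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    gr.Button("Run").click(fn=slow_task, inputs=inp, outputs=out)

demo.launch()

One caveat: newer transformers releases deprecate the use_auth_token argument used in this diff in favor of token, so the from_pretrained calls may eventually need token=access_token instead.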