Pruthvi369i committed
Commit 6a21530 · verified · 1 Parent(s): 867a39c

Update app.py

Files changed (1): app.py (+37, -27)
app.py CHANGED
@@ -1,54 +1,59 @@
 import os
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor, BitsAndBytesConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer, AutoProcessor
 from PIL import Image
 
 # Model ID
 MODEL_ID = "0llheaven/Llama-3.2-11B-Vision-Radiology-mini"
 
-# Configure 4-bit quantization
-quantization_config = BitsAndBytesConfig(
-    load_in_4bit=True,
-    bnb_4bit_compute_dtype=torch.float16,
-    bnb_4bit_quant_type="nf4",
-    bnb_4bit_use_double_quant=True
-)
-
 # Load tokenizer and processor
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 processor = AutoProcessor.from_pretrained(MODEL_ID)
 
-# Load the model with quantization
-print("Loading model with 4-bit quantization...")
+# Load the model with reduced precision and memory optimizations
+print("Loading model with memory optimizations...")
 model = AutoModelForCausalLM.from_pretrained(
     MODEL_ID,
-    quantization_config=quantization_config,
-    device_map="auto",
+    torch_dtype=torch.float16,    # Use half precision
+    device_map="auto",            # Let the library decide how to map the model
+    low_cpu_mem_usage=True,       # Optimize CPU memory usage
+    offload_folder="offload",     # Offload weights to disk if needed
+    offload_state_dict=True,      # Enable state dict offloading
     trust_remote_code=True,
 )
 print("Model loaded!")
 
-def generate_response(image_file, prompt, max_new_tokens=512, temperature=0.7, top_p=0.9):
+# Clear CUDA cache after loading
+if torch.cuda.is_available():
+    torch.cuda.empty_cache()
+
+def generate_response(image_file, prompt, max_new_tokens=256, temperature=0.7, top_p=0.9):
     try:
         # Process image if provided
         if image_file is not None:
             image = Image.open(image_file).convert('RGB')
 
-            # Process inputs with processor
+            # Process inputs
             inputs = processor(
                 text=prompt,
                 images=image,
                 return_tensors="pt"
-            ).to(model.device)
+            )
 
-            # For multimodal models, we need to handle the inputs differently
-            # Extract only the input_ids and attention_mask for generation
-            input_ids = inputs.get("input_ids")
-            attention_mask = inputs.get("attention_mask", None)
+            # Move inputs to the same device as model
+            inputs = {k: v.to(model.device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}
 
-            # Generate response
+            # For safer generation, extract only what's needed
+            input_ids = inputs.pop("input_ids", None)
+            attention_mask = inputs.pop("attention_mask", None)
+
+            # Generate response with conservative memory settings
             with torch.no_grad():
+                # Clear cache before generation
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+
                 outputs = model.generate(
                     input_ids=input_ids,
                     attention_mask=attention_mask,
@@ -65,8 +70,12 @@ def generate_response(image_file, prompt, max_new_tokens=512, temperature=0.7, top_p=0.9):
             # Text-only input
             inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
 
-            # Generate response
+            # Generate response with conservative memory settings
            with torch.no_grad():
+                # Clear cache before generation
+                if torch.cuda.is_available():
+                    torch.cuda.empty_cache()
+
                 outputs = model.generate(
                     **inputs,
                     max_new_tokens=max_new_tokens,
@@ -78,7 +87,7 @@ def generate_response(image_file, prompt, max_new_tokens=512, temperature=0.7, top_p=0.9):
         # Decode and return the response
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-        # Remove the input prompt from the response
+        # Remove the input prompt from the response if present
         if response.startswith(prompt):
             response = response[len(prompt):].strip()
 
@@ -98,7 +107,7 @@ with gr.Blocks() as demo:
         prompt_input = gr.Textbox(label="Question or Prompt", placeholder="Describe what you see in this image and identify any abnormalities.")
 
         with gr.Row():
-            max_tokens = gr.Slider(minimum=16, maximum=1024, value=512, step=8, label="Max New Tokens")
+            max_tokens = gr.Slider(minimum=16, maximum=512, value=256, step=8, label="Max New Tokens")
             temperature = gr.Slider(minimum=0.1, maximum=1.5, value=0.7, step=0.1, label="Temperature")
             top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p")
 
@@ -115,10 +124,11 @@ with gr.Blocks() as demo:
 
     gr.Examples(
        [
-            ["sample_xray.jpg", "Describe what you see in this chest X-ray and identify any abnormalities."],
-            ["sample_ct.jpg", "Analyze this CT scan and provide a detailed report."],
+            ["sample_xray.jpg", "What abnormalities do you see in this X-ray?"],
+            ["sample_ct.jpg", "Describe this image and any findings."],
        ],
        inputs=[image_input, prompt_input],
    )
 
-demo.launch()
+# Reduce maximum allowed concurrent users to conserve memory
+demo.launch(max_threads=1)
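
Note on the change: the update drops the 4-bit BitsAndBytesConfig quantization and instead loads the checkpoint in float16 with device_map="auto", low_cpu_mem_usage, and disk offload, clears the CUDA cache around generation, halves the default max_new_tokens, and caps Gradio at a single worker thread. Below is a minimal sketch, not part of the commit, of how one might check where the weights actually landed and how much GPU memory the fp16 load uses; it assumes a CUDA device and that the accelerate-backed loader populated hf_device_map, and the helper name report_memory_and_placement is purely illustrative.

import torch

def report_memory_and_placement(model):
    # hf_device_map is set by the accelerate-backed loader when device_map is used;
    # values are GPU indices, "cpu", or "disk" for modules sent to the offload folder
    device_map = getattr(model, "hf_device_map", None)
    if device_map:
        for module_name, device in device_map.items():
            print(f"{module_name or '<root>'} -> {device}")

    # Rough footprint of the weights currently resident on the GPU
    if torch.cuda.is_available():
        allocated = torch.cuda.memory_allocated() / 1024**3
        reserved = torch.cuda.memory_reserved() / 1024**3
        print(f"GPU memory: {allocated:.2f} GiB allocated, {reserved:.2f} GiB reserved")

report_memory_and_placement(model)  # call once after from_pretrained(...)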