mgbam committed
Commit 08137ac · verified · 1 Parent(s): b2334c7

Update app.py

Files changed (1):
  1. app.py +249 -171
app.py CHANGED
@@ -1,186 +1,264 @@
  import gradio as gr
  import torch
  from PIL import Image
- import open_clip
  import numpy as np
- from LeGrad.legrad import LeWrapper, LePreprocess
- import cv2
-
- #---------------------------------
- #++++++++ Model ++++++++++
- #---------------------------------
-
- def load_biomedclip_model():
-     """Loads the BiomedCLIP model and prepares it with LeGrad."""
-     device = "cuda" if torch.cuda.is_available() else "cpu"
-     model_name = "hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224"
-     model, preprocess = open_clip.create_model_from_pretrained(
-         model_name=model_name, device=device
      )
-     tokenizer = open_clip.get_tokenizer(model_name=model_name)
-     model = LeWrapper(model)  # Equip the model with LeGrad
-     preprocess = LePreprocess(
-         preprocess=preprocess, image_size=448
-     )  # Optional higher-res preprocessing
-     return model, preprocess, tokenizer, device
-
- def classify_image_with_biomedclip(editor_value, prompts, model, preprocess, tokenizer, device):
-     """Classifies the image with the given text prompts using BiomedCLIP."""
-     if editor_value is None:
-         return None, None
-
-     image = editor_value["composite"]
-
-     if not isinstance(image, Image.Image):
-         image = Image.fromarray(image)
-
-     image_input = preprocess(image).unsqueeze(0).to(device)
-     text_inputs = tokenizer(prompts).to(device)
-
-     text_embeddings = model.encode_text(text_inputs, normalize=True)
-     image_embeddings = model.encode_image(image_input, normalize=True)
-
-     similarity = (
-         model.logit_scale.exp() * image_embeddings @ text_embeddings.T
-     ).softmax(dim=-1)
-     probabilities = similarity[0].detach().cpu().numpy()
-     explanation_maps = model.compute_legrad_clip(
-         image=image_input, text_embedding=text_embeddings[probabilities.argmax()]
-     )
-
-     explanation_maps = explanation_maps.squeeze(0).detach().cpu().numpy()
-     explanation_map = (explanation_maps * 255).astype(np.uint8)
-
-     return probabilities, explanation_map
-
- def prepare_output_image(image, explanation_map):
-     """Prepares the output image by blending the original image with the explanation map."""
-     if not isinstance(image, Image.Image):
-         image = Image.fromarray(image)
-
-     explanation_image = explanation_map[0]
-     if isinstance(explanation_image, torch.Tensor):
-         explanation_image = explanation_image.cpu().numpy()
-
-     explanation_image_resized = cv2.resize(
-         explanation_image, (image.width, image.height)
-     )
-
-     explanation_image_resized = cv2.normalize(
-         explanation_image_resized, None, 0, 255, cv2.NORM_MINMAX
-     )
-
-     explanation_colormap = cv2.applyColorMap(
-         explanation_image_resized.astype(np.uint8), cv2.COLORMAP_JET
-     )
-
-     image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
-
-     alpha = 0.5
-     blended_image = cv2.addWeighted(image_cv, 1 - alpha, explanation_colormap, alpha, 0)
-
-     blended_image_rgb = cv2.cvtColor(blended_image, cv2.COLOR_BGR2RGB)
-     output_image = Image.fromarray(blended_image_rgb)
-     return output_image
-
- #---------------------------------
- #++++++++ Gradio ++++++++++
- #---------------------------------
-
- def update_output(editor_value, prompts_input, model, preprocess, tokenizer, device):
-     """Main function to update the output based on image and prompts."""
-     prompts_list = [p.strip() for p in prompts_input.split(",") if p.strip()]
-     if not prompts_list:
-         return None, "Please enter at least one prompt."
-
-     probabilities, explanation_map = classify_image_with_biomedclip(
-         editor_value, prompts_list, model, preprocess, tokenizer, device
-     )
-
-     if probabilities is None:
-         return None, "Please upload and annotate an image."
-
-     prob_text = "\n".join(
-         [
-             f"{prompt}: {prob*100:.2f}%"
-             for prompt, prob in zip(prompts_list, probabilities)
-         ]
      )

-     image = editor_value["composite"]
-     output_image = prepare_output_image(image, explanation_map)
-
-     return output_image, prob_text

- def clear_inputs():
-     return None, ""

- # Load model outside the Gradio blocks
- model, preprocess, tokenizer, device = load_biomedclip_model()


- with gr.Blocks() as demo:
-     gr.Markdown(
-         "# ✨ Visual Prompt Engineering for Medical Vision Language Models in Radiology ✨",
-         elem_id="main-header",
      )
-
-     gr.Markdown(
-         "This tool applies **visual prompt engineering to improve the classification of medical images using the BiomedCLIP**[3], the current state of the art in zero-shot biomedical image classification. By uploading biomedical images (e.g., chest X-rays), you can manually annotate areas of interest directly on the image. These annotations serve as visual prompts, which guide the model's attention on the region of interest. This technique improves the model's ability to focus on subtle yet important details.\n\n"
-         "After annotating and inputting text prompts (e.g., 'A chest X-ray with a benign/malignant lung nodule indicated by a red circle'), the tool returns classification results. These results are accompanied by **explainability maps** generated by **LeGrad** [3], which show where the model focused its attention, conditioned on the highest scoring text prompt. This helps to better interpret the model's decision-making process.\n\n"
-         "In our paper **[Visual Prompt Engineering for Medical Vision Language Models in Radiology](https://arxiv.org/pdf/2408.15802)**, we show, that visual prompts such as arrows, circles, and contours improve the zero-shot classification of biomedical vision language models in radiology."
      )
-
-     gr.Markdown("---")
-
-     gr.Markdown(
-         "## 📝 **How It Works**:\n"
-         "1. **Upload** a biomedical image.\n"
-         "2. **Annotate** the image using the built-in editor to highlight regions of interest.\n"
-         "3. **Enter text prompts** separated by comma (e.g., 'A chest X-ray with a (benign/malignant) lung nodule indicated by a red circle').\n"
-         "4. **Submit** to get class probabilities and an explainability map conditioned on the highest scoring text prompt."
      )

-     gr.Markdown("---")
-
-     with gr.Row():
-         with gr.Column():
-             image_editor = gr.ImageEditor(
-                 label="Upload and Annotate Image",
-                 type="pil",
-                 interactive=True,
-                 mirror_webcam=False,
-                 layers=False,
-                 scale=2,
-             )
-             prompts_input = gr.Textbox(
-                 placeholder="Enter prompts, comma-separated", label="Text Prompts"
-             )
-             submit_button = gr.Button("Submit", variant="primary")
-         with gr.Column():
-             output_image = gr.Image(
-                 type="pil",
-                 label="Output Image with Explanation Map",
-             )
-             prob_text = gr.Textbox(
-                 label="Class Probabilities", interactive=False, lines=10
-             )
-
-     inputs = [image_editor, prompts_input]
-     outputs = [output_image, prob_text]
-     submit_button.click(fn=update_output, inputs=inputs, outputs=outputs,
-                         _js=None,
-                         api_name=None,
-                         scroll_to_output=True,
-                         show_progress=True,
-                         queue=True,
-                         batch=False,
-                         preprocess=True,
-                         postprocess=True,
-                         cancels=None,
-                         show_loading_status=True,
-                         scroll_to_output_id=None,
-                         model=model, preprocess=preprocess, tokenizer=tokenizer, device=device
-                         )
- if __name__ == "__main__":
-     demo.launch(share=True)
 
  import gradio as gr
  import torch
+ from janus.janusflow.models import MultiModalityCausalLM, VLChatProcessor
  from PIL import Image
+ from diffusers.models import AutoencoderKL
  import numpy as np
+ import spaces  # Import spaces for ZeroGPU compatibility
+
+ cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ # Load model and processor
+ model_path = "deepseek-ai/JanusFlow-1.3B"
+ vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
+ tokenizer = vl_chat_processor.tokenizer
+
+ vl_gpt = MultiModalityCausalLM.from_pretrained(model_path)
+ vl_gpt = vl_gpt.to(torch.bfloat16).to(cuda_device).eval()
+
+ # remember to use bfloat16 dtype, this vae doesn't work with fp16
+ vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
+ vae = vae.to(torch.bfloat16).to(cuda_device).eval()
+
+ # Multimodal Understanding function
+ @torch.inference_mode()
+ @spaces.GPU(duration=120)
+ def multimodal_understanding(image, question, seed, top_p, temperature):
+     # Clear CUDA cache before generating
+     torch.cuda.empty_cache()
+
+     # set seed
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+     torch.cuda.manual_seed(seed)
+
+     # Medical image preprocessing (this is a placeholder, implement based on your specific needs)
+     # NOTE: If input is DICOM or another medical format, add custom loading and preprocessing steps here
+     # Example: if input is DICOM:
+     # 1. load with pydicom.dcmread()
+     # 2. normalize pixel values based on windowing/leveling if necessary
+     # 3. convert to np.array
+     # else: if the input is a regular numpy array (e.g. png or jpg) no action is needed, image = image
+
+     conversation = [
+         {
+             "role": "User",
+             "content": f"<image_placeholder>\n{question}",
+             "images": [image],
+         },
+         {"role": "Assistant", "content": ""},
+     ]
+
+     pil_images = [Image.fromarray(image)]
+     prepare_inputs = vl_chat_processor(
+         conversations=conversation, images=pil_images, force_batchify=True
+     ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
+
+
+     inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
+
+     outputs = vl_gpt.language_model.generate(
+         inputs_embeds=inputs_embeds,
+         attention_mask=prepare_inputs.attention_mask,
+         pad_token_id=tokenizer.eos_token_id,
+         bos_token_id=tokenizer.bos_token_id,
+         eos_token_id=tokenizer.eos_token_id,
+         max_new_tokens=512,
+         do_sample=False if temperature == 0 else True,
+         use_cache=True,
+         temperature=temperature,
+         top_p=top_p,
      )
+
+     answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
+
+     return answer
+
+
+ @torch.inference_mode()
+ @spaces.GPU(duration=120)
+ def generate(
+     input_ids,
+     cfg_weight: float = 2.0,
+     num_inference_steps: int = 30
+ ):
+     # we generate 5 images at a time, *2 for CFG
+     tokens = torch.stack([input_ids] * 10).cuda()
+     tokens[5:, 1:] = vl_chat_processor.pad_id
+     inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
+     print(inputs_embeds.shape)
+
+     # we remove the last <bog> token and replace it with t_emb later
+     inputs_embeds = inputs_embeds[:, :-1, :]
+
+     # generate with rectified flow ode
+     # step 1: encode with vision_gen_enc
+     z = torch.randn((5, 4, 48, 48), dtype=torch.bfloat16).cuda()
+
+     dt = 1.0 / num_inference_steps
+     dt = torch.zeros_like(z).cuda().to(torch.bfloat16) + dt
+
+     # step 2: run ode
+     attention_mask = torch.ones((10, inputs_embeds.shape[1]+577)).to(vl_gpt.device)
+     attention_mask[5:, 1:inputs_embeds.shape[1]] = 0
+     attention_mask = attention_mask.int()
+     for step in range(num_inference_steps):
+         # prepare inputs for the llm
+         z_input = torch.cat([z, z], dim=0)  # for cfg
+         t = step / num_inference_steps * 1000.
+         t = torch.tensor([t] * z_input.shape[0]).to(dt)
+         z_enc = vl_gpt.vision_gen_enc_model(z_input, t)
+         z_emb, t_emb, hs = z_enc[0], z_enc[1], z_enc[2]
+         z_emb = z_emb.view(z_emb.shape[0], z_emb.shape[1], -1).permute(0, 2, 1)
+         z_emb = vl_gpt.vision_gen_enc_aligner(z_emb)
+         llm_emb = torch.cat([inputs_embeds, t_emb.unsqueeze(1), z_emb], dim=1)
+
+         # input to the llm
+         # we apply attention mask for CFG: 1 for tokens that are not masked, 0 for tokens that are masked.
+         if step == 0:
+             outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
+                                                   use_cache=True,
+                                                   attention_mask=attention_mask,
+                                                   past_key_values=None)
+             past_key_values = []
+             for kv_cache in past_key_values:
+                 k, v = kv_cache[0], kv_cache[1]
+                 past_key_values.append((k[:, :, :inputs_embeds.shape[1], :], v[:, :, :inputs_embeds.shape[1], :]))
+             past_key_values = tuple(past_key_values)
+         else:
+             outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
+                                                   use_cache=True,
+                                                   attention_mask=attention_mask,
+                                                   past_key_values=past_key_values)
+         hidden_states = outputs.last_hidden_state
+
+         # transform hidden_states back to v
+         hidden_states = vl_gpt.vision_gen_dec_aligner(vl_gpt.vision_gen_dec_aligner_norm(hidden_states[:, -576:, :]))
+         hidden_states = hidden_states.reshape(z_emb.shape[0], 24, 24, 768).permute(0, 3, 1, 2)
+         v = vl_gpt.vision_gen_dec_model(hidden_states, hs, t_emb)
+         v_cond, v_uncond = torch.chunk(v, 2)
+         v = cfg_weight * v_cond - (cfg_weight-1.) * v_uncond
+         z = z + dt * v
+
+     # step 3: decode with vision_gen_dec and sdxl vae
+     decoded_image = vae.decode(z / vae.config.scaling_factor).sample
+
+     images = decoded_image.float().clip_(-1., 1.).permute(0,2,3,1).cpu().numpy()
+     images = ((images+1) / 2. * 255).astype(np.uint8)
+
+     return images
+
+ def unpack(dec, width, height, parallel_size=5):
+     dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
+     dec = np.clip((dec + 1) / 2 * 255, 0, 255)
+
+     visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
+     visual_img[:, :, :] = dec
+
+     return visual_img
+
+
+ @torch.inference_mode()
+ @spaces.GPU(duration=120)
+ def generate_image(prompt,
+                    seed=None,
+                    guidance=5,
+                    num_inference_steps=30):
+     # Clear CUDA cache and avoid tracking gradients
+     torch.cuda.empty_cache()
+     # Set the seed for reproducible results
+     if seed is not None:
+         torch.manual_seed(seed)
+         torch.cuda.manual_seed(seed)
+         np.random.seed(seed)
+
+     with torch.no_grad():
+         messages = [{'role': 'User', 'content': prompt},
+                     {'role': 'Assistant', 'content': ''}]
+         text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
+                                                                            sft_format=vl_chat_processor.sft_format,
+                                                                            system_prompt='')
+         text = text + vl_chat_processor.image_start_tag
+         input_ids = torch.LongTensor(tokenizer.encode(text))
+         images = generate(input_ids,
+                           cfg_weight=guidance,
+                           num_inference_steps=num_inference_steps)
+         return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(images.shape[0])]
+
+
+
+ # Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown(value="# Medical Image Analysis and Generation")
+     # with gr.Row():
+     with gr.Row():
+         image_input = gr.Image(label="Medical Image Input")
+         with gr.Column():
+             question_input = gr.Textbox(label="Analysis Prompt (e.g., 'Identify tumor', 'Characterize lesion', 'Describe anatomic structures')")
+             und_seed_input = gr.Number(label="Seed", precision=0, value=42)
+             top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
+             temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
+
+     understanding_button = gr.Button("Analyze Image")
+     understanding_output = gr.Textbox(label="Analysis Response")
+
+     examples_inpainting = gr.Examples(
+         label="Multimodal Understanding examples",
+         examples=[
+             [
+                 "Identify the tumor in the given image.",
+                 "./ct_scan.png"  # Placeholder medical image path
+             ],
+             [
+                 "Characterize the lesion in the image. Is it malignant or benign?",
+                 "./mri_scan.png",  # Placeholder medical image path
+             ],
+             [
+                 "Generate a report for the given medical image.",
+                 "./xray.png",  # Placeholder medical image path
+             ],
+
+         ],
+         inputs=[question_input, image_input],
      )
+
+
+     gr.Markdown(value="# Medical Image Generation with Hugging Face Logo")

+
+
+     with gr.Row():
+         cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=2, step=0.5, label="CFG Weight")
+         step_input = gr.Slider(minimum=1, maximum=50, value=30, step=1, label="Number of Inference Steps")

+     prompt_input = gr.Textbox(label="Generation Prompt (e.g., 'Generate a CT scan with the Hugging Face logo', 'Create an MRI scan showing the Hugging Face logo', 'Render a medical x-ray with the Hugging Face logo.')")
+     seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)

+     generation_button = gr.Button("Generate Images")

+     image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)

+     examples_t2i = gr.Examples(
+         label="Medical image generation examples with Hugging Face logo.",
+         examples=[
+             "Generate a CT scan with the Hugging Face logo clearly visible.",
+             "Create an MRI scan showing the Hugging Face logo embedded within the tissue.",
+             "Render a medical x-ray with the Hugging Face logo subtly visible in the background.",
+             "Generate an ultrasound image with a faint Hugging Face logo on the screen",
+         ],
+         inputs=prompt_input,
      )
+
+     understanding_button.click(
+         multimodal_understanding,
+         inputs=[image_input, question_input, und_seed_input, top_p, temperature],
+         outputs=understanding_output
      )
+
+     generation_button.click(
+         fn=generate_image,
+         inputs=[prompt_input, seed_input, cfg_weight_input, step_input],
+         outputs=image_output
      )

+ demo.launch(share=True, ssr_mode = False)
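
Note on the DICOM placeholder in multimodal_understanding: the committed code only sketches that preprocessing in comments (load with pydicom.dcmread(), apply windowing/leveling if necessary, convert to np.array). Below is a minimal sketch of such a loader; the helper name load_medical_image, the ".dcm" extension check, and the min-max normalization fallback are illustrative assumptions, not part of this commit.

import numpy as np
import pydicom
from PIL import Image

def load_medical_image(path, window_center=None, window_width=None):
    """Illustrative loader: return an RGB uint8 array for DICOM or ordinary image files."""
    if path.lower().endswith(".dcm"):
        ds = pydicom.dcmread(path)                      # 1. load with pydicom
        pixels = ds.pixel_array.astype(np.float32)
        if window_center is not None and window_width is not None:
            lo = window_center - window_width / 2.0     # 2. optional windowing/leveling
            hi = window_center + window_width / 2.0
            pixels = np.clip(pixels, lo, hi)
        # scale to 0-255 regardless of the original bit depth
        pixels = (pixels - pixels.min()) / max(float(pixels.max() - pixels.min()), 1e-6)
        gray = (pixels * 255).astype(np.uint8)          # 3. convert to a uint8 numpy array
        return np.stack([gray] * 3, axis=-1)            # replicate to 3 channels for the processor
    return np.array(Image.open(path).convert("RGB"))    # PNG/JPG already arrive as plain arrays

The returned array has the same layout as the value Gradio's gr.Image component passes to multimodal_understanding, so it could be substituted at the point where the placeholder comment sits.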