Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,186 +1,264 @@
|
|
1 |
import gradio as gr
|
2 |
import torch
|
|
|
3 |
from PIL import Image
|
4 |
-
import
|
5 |
import numpy as np
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
)
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
)
|
|
|
|
|
|
|
109 |
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
|
|
114 |
|
115 |
-
|
116 |
-
|
117 |
|
118 |
-
|
119 |
-
model, preprocess, tokenizer, device = load_biomedclip_model()
|
120 |
|
|
|
121 |
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
126 |
)
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
)
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
"1. **Upload** a biomedical image.\n"
|
139 |
-
"2. **Annotate** the image using the built-in editor to highlight regions of interest.\n"
|
140 |
-
"3. **Enter text prompts** separated by comma (e.g., 'A chest X-ray with a (benign/malignant) lung nodule indicated by a red circle').\n"
|
141 |
-
"4. **Submit** to get class probabilities and an explainability map conditioned on the highest scoring text prompt."
|
142 |
)
|
143 |
|
144 |
-
|
145 |
-
|
146 |
-
with gr.Row():
|
147 |
-
with gr.Column():
|
148 |
-
image_editor = gr.ImageEditor(
|
149 |
-
label="Upload and Annotate Image",
|
150 |
-
type="pil",
|
151 |
-
interactive=True,
|
152 |
-
mirror_webcam=False,
|
153 |
-
layers=False,
|
154 |
-
scale=2,
|
155 |
-
)
|
156 |
-
prompts_input = gr.Textbox(
|
157 |
-
placeholder="Enter prompts, comma-separated", label="Text Prompts"
|
158 |
-
)
|
159 |
-
submit_button = gr.Button("Submit", variant="primary")
|
160 |
-
with gr.Column():
|
161 |
-
output_image = gr.Image(
|
162 |
-
type="pil",
|
163 |
-
label="Output Image with Explanation Map",
|
164 |
-
)
|
165 |
-
prob_text = gr.Textbox(
|
166 |
-
label="Class Probabilities", interactive=False, lines=10
|
167 |
-
)
|
168 |
-
|
169 |
-
inputs = [image_editor, prompts_input]
|
170 |
-
outputs = [output_image, prob_text]
|
171 |
-
submit_button.click(fn=update_output, inputs=inputs, outputs=outputs,
|
172 |
-
_js=None,
|
173 |
-
api_name=None,
|
174 |
-
scroll_to_output=True,
|
175 |
-
show_progress=True,
|
176 |
-
queue=True,
|
177 |
-
batch=False,
|
178 |
-
preprocess=True,
|
179 |
-
postprocess=True,
|
180 |
-
cancels=None,
|
181 |
-
show_loading_status=True,
|
182 |
-
scroll_to_output_id=None,
|
183 |
-
model=model, preprocess=preprocess, tokenizer=tokenizer, device=device
|
184 |
-
)
|
185 |
-
if __name__ == "__main__":
|
186 |
-
demo.launch(share=True)
|
|
|
1 |
import gradio as gr
|
2 |
import torch
|
3 |
+
from janus.janusflow.models import MultiModalityCausalLM, VLChatProcessor
|
4 |
from PIL import Image
|
5 |
+
from diffusers.models import AutoencoderKL
|
6 |
import numpy as np
|
7 |
+
import spaces # Import spaces for ZeroGPU compatibility
|
8 |
+
|
9 |
+
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
10 |
+
|
11 |
+
# Load model and processor
|
12 |
+
model_path = "deepseek-ai/JanusFlow-1.3B"
|
13 |
+
vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
|
14 |
+
tokenizer = vl_chat_processor.tokenizer
|
15 |
+
|
16 |
+
vl_gpt = MultiModalityCausalLM.from_pretrained(model_path)
|
17 |
+
vl_gpt = vl_gpt.to(torch.bfloat16).to(cuda_device).eval()
|
18 |
+
|
19 |
+
# remember to use bfloat16 dtype, this vae doesn't work with fp16
|
20 |
+
vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae")
|
21 |
+
vae = vae.to(torch.bfloat16).to(cuda_device).eval()
|
22 |
+
|
23 |
+
# Multimodal Understanding function
|
24 |
+
@torch.inference_mode()
|
25 |
+
@spaces.GPU(duration=120)
|
26 |
+
def multimodal_understanding(image, question, seed, top_p, temperature):
|
27 |
+
# Clear CUDA cache before generating
|
28 |
+
torch.cuda.empty_cache()
|
29 |
+
|
30 |
+
# set seed
|
31 |
+
torch.manual_seed(seed)
|
32 |
+
np.random.seed(seed)
|
33 |
+
torch.cuda.manual_seed(seed)
|
34 |
+
|
35 |
+
# Medical image preprocessing (this is a placeholder, implement based on your specific needs)
|
36 |
+
# NOTE: If input is DICOM or another medical format, add custom loading and preprocessing steps here
|
37 |
+
# Example: if input is DICOM:
|
38 |
+
# 1. load with pydicom.dcmread()
|
39 |
+
# 2. normalize pixel values based on windowing/leveling if necessary
|
40 |
+
# 3. convert to np.array
|
41 |
+
# else: if the input is a regular numpy array (e.g. png or jpg) no action is needed, image = image
|
42 |
+
|
43 |
+
conversation = [
|
44 |
+
{
|
45 |
+
"role": "User",
|
46 |
+
"content": f"<image_placeholder>\n{question}",
|
47 |
+
"images": [image],
|
48 |
+
},
|
49 |
+
{"role": "Assistant", "content": ""},
|
50 |
+
]
|
51 |
+
|
52 |
+
pil_images = [Image.fromarray(image)]
|
53 |
+
prepare_inputs = vl_chat_processor(
|
54 |
+
conversations=conversation, images=pil_images, force_batchify=True
|
55 |
+
).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
|
56 |
+
|
57 |
+
|
58 |
+
inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
|
59 |
+
|
60 |
+
outputs = vl_gpt.language_model.generate(
|
61 |
+
inputs_embeds=inputs_embeds,
|
62 |
+
attention_mask=prepare_inputs.attention_mask,
|
63 |
+
pad_token_id=tokenizer.eos_token_id,
|
64 |
+
bos_token_id=tokenizer.bos_token_id,
|
65 |
+
eos_token_id=tokenizer.eos_token_id,
|
66 |
+
max_new_tokens=512,
|
67 |
+
do_sample=False if temperature == 0 else True,
|
68 |
+
use_cache=True,
|
69 |
+
temperature=temperature,
|
70 |
+
top_p=top_p,
|
71 |
)
|
72 |
+
|
73 |
+
answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
|
74 |
+
|
75 |
+
return answer
|
76 |
+
|
77 |
+
|
78 |
+
@torch.inference_mode()
|
79 |
+
@spaces.GPU(duration=120)
|
80 |
+
def generate(
|
81 |
+
input_ids,
|
82 |
+
cfg_weight: float = 2.0,
|
83 |
+
num_inference_steps: int = 30
|
84 |
+
):
|
85 |
+
# we generate 5 images at a time, *2 for CFG
|
86 |
+
tokens = torch.stack([input_ids] * 10).cuda()
|
87 |
+
tokens[5:, 1:] = vl_chat_processor.pad_id
|
88 |
+
inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
|
89 |
+
print(inputs_embeds.shape)
|
90 |
+
|
91 |
+
# we remove the last <bog> token and replace it with t_emb later
|
92 |
+
inputs_embeds = inputs_embeds[:, :-1, :]
|
93 |
+
|
94 |
+
# generate with rectified flow ode
|
95 |
+
# step 1: encode with vision_gen_enc
|
96 |
+
z = torch.randn((5, 4, 48, 48), dtype=torch.bfloat16).cuda()
|
97 |
+
|
98 |
+
dt = 1.0 / num_inference_steps
|
99 |
+
dt = torch.zeros_like(z).cuda().to(torch.bfloat16) + dt
|
100 |
+
|
101 |
+
# step 2: run ode
|
102 |
+
attention_mask = torch.ones((10, inputs_embeds.shape[1]+577)).to(vl_gpt.device)
|
103 |
+
attention_mask[5:, 1:inputs_embeds.shape[1]] = 0
|
104 |
+
attention_mask = attention_mask.int()
|
105 |
+
for step in range(num_inference_steps):
|
106 |
+
# prepare inputs for the llm
|
107 |
+
z_input = torch.cat([z, z], dim=0) # for cfg
|
108 |
+
t = step / num_inference_steps * 1000.
|
109 |
+
t = torch.tensor([t] * z_input.shape[0]).to(dt)
|
110 |
+
z_enc = vl_gpt.vision_gen_enc_model(z_input, t)
|
111 |
+
z_emb, t_emb, hs = z_enc[0], z_enc[1], z_enc[2]
|
112 |
+
z_emb = z_emb.view(z_emb.shape[0], z_emb.shape[1], -1).permute(0, 2, 1)
|
113 |
+
z_emb = vl_gpt.vision_gen_enc_aligner(z_emb)
|
114 |
+
llm_emb = torch.cat([inputs_embeds, t_emb.unsqueeze(1), z_emb], dim=1)
|
115 |
+
|
116 |
+
# input to the llm
|
117 |
+
# we apply attention mask for CFG: 1 for tokens that are not masked, 0 for tokens that are masked.
|
118 |
+
if step == 0:
|
119 |
+
outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
|
120 |
+
use_cache=True,
|
121 |
+
attention_mask=attention_mask,
|
122 |
+
past_key_values=None)
|
123 |
+
past_key_values = []
|
124 |
+
for kv_cache in past_key_values:
|
125 |
+
k, v = kv_cache[0], kv_cache[1]
|
126 |
+
past_key_values.append((k[:, :, :inputs_embeds.shape[1], :], v[:, :, :inputs_embeds.shape[1], :]))
|
127 |
+
past_key_values = tuple(past_key_values)
|
128 |
+
else:
|
129 |
+
outputs = vl_gpt.language_model.model(inputs_embeds=llm_emb,
|
130 |
+
use_cache=True,
|
131 |
+
attention_mask=attention_mask,
|
132 |
+
past_key_values=past_key_values)
|
133 |
+
hidden_states = outputs.last_hidden_state
|
134 |
+
|
135 |
+
# transform hidden_states back to v
|
136 |
+
hidden_states = vl_gpt.vision_gen_dec_aligner(vl_gpt.vision_gen_dec_aligner_norm(hidden_states[:, -576:, :]))
|
137 |
+
hidden_states = hidden_states.reshape(z_emb.shape[0], 24, 24, 768).permute(0, 3, 1, 2)
|
138 |
+
v = vl_gpt.vision_gen_dec_model(hidden_states, hs, t_emb)
|
139 |
+
v_cond, v_uncond = torch.chunk(v, 2)
|
140 |
+
v = cfg_weight * v_cond - (cfg_weight-1.) * v_uncond
|
141 |
+
z = z + dt * v
|
142 |
+
|
143 |
+
# step 3: decode with vision_gen_dec and sdxl vae
|
144 |
+
decoded_image = vae.decode(z / vae.config.scaling_factor).sample
|
145 |
+
|
146 |
+
images = decoded_image.float().clip_(-1., 1.).permute(0,2,3,1).cpu().numpy()
|
147 |
+
images = ((images+1) / 2. * 255).astype(np.uint8)
|
148 |
+
|
149 |
+
return images
|
150 |
+
|
151 |
+
def unpack(dec, width, height, parallel_size=5):
|
152 |
+
dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
|
153 |
+
dec = np.clip((dec + 1) / 2 * 255, 0, 255)
|
154 |
+
|
155 |
+
visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
|
156 |
+
visual_img[:, :, :] = dec
|
157 |
+
|
158 |
+
return visual_img
|
159 |
+
|
160 |
+
|
161 |
+
@torch.inference_mode()
|
162 |
+
@spaces.GPU(duration=120)
|
163 |
+
def generate_image(prompt,
|
164 |
+
seed=None,
|
165 |
+
guidance=5,
|
166 |
+
num_inference_steps=30):
|
167 |
+
# Clear CUDA cache and avoid tracking gradients
|
168 |
+
torch.cuda.empty_cache()
|
169 |
+
# Set the seed for reproducible results
|
170 |
+
if seed is not None:
|
171 |
+
torch.manual_seed(seed)
|
172 |
+
torch.cuda.manual_seed(seed)
|
173 |
+
np.random.seed(seed)
|
174 |
+
|
175 |
+
with torch.no_grad():
|
176 |
+
messages = [{'role': 'User', 'content': prompt},
|
177 |
+
{'role': 'Assistant', 'content': ''}]
|
178 |
+
text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
|
179 |
+
sft_format=vl_chat_processor.sft_format,
|
180 |
+
system_prompt='')
|
181 |
+
text = text + vl_chat_processor.image_start_tag
|
182 |
+
input_ids = torch.LongTensor(tokenizer.encode(text))
|
183 |
+
images = generate(input_ids,
|
184 |
+
cfg_weight=guidance,
|
185 |
+
num_inference_steps=num_inference_steps)
|
186 |
+
return [Image.fromarray(images[i]).resize((1024, 1024), Image.LANCZOS) for i in range(images.shape[0])]
|
187 |
+
|
188 |
+
|
189 |
+
|
190 |
+
# Gradio interface
|
191 |
+
with gr.Blocks() as demo:
|
192 |
+
gr.Markdown(value="# Medical Image Analysis and Generation")
|
193 |
+
# with gr.Row():
|
194 |
+
with gr.Row():
|
195 |
+
image_input = gr.Image(label="Medical Image Input")
|
196 |
+
with gr.Column():
|
197 |
+
question_input = gr.Textbox(label="Analysis Prompt (e.g., 'Identify tumor', 'Characterize lesion', 'Describe anatomic structures')")
|
198 |
+
und_seed_input = gr.Number(label="Seed", precision=0, value=42)
|
199 |
+
top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
|
200 |
+
temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")
|
201 |
+
|
202 |
+
understanding_button = gr.Button("Analyze Image")
|
203 |
+
understanding_output = gr.Textbox(label="Analysis Response")
|
204 |
+
|
205 |
+
examples_inpainting = gr.Examples(
|
206 |
+
label="Multimodal Understanding examples",
|
207 |
+
examples=[
|
208 |
+
[
|
209 |
+
"Identify the tumor in the given image.",
|
210 |
+
"./ct_scan.png" # Placeholder medical image path
|
211 |
+
],
|
212 |
+
[
|
213 |
+
"Characterize the lesion in the image. Is it malignant or benign?",
|
214 |
+
"./mri_scan.png", # Placeholder medical image path
|
215 |
+
],
|
216 |
+
[
|
217 |
+
"Generate a report for the given medical image.",
|
218 |
+
"./xray.png", # Placeholder medical image path
|
219 |
+
],
|
220 |
+
|
221 |
+
],
|
222 |
+
inputs=[question_input, image_input],
|
223 |
)
|
224 |
+
|
225 |
+
|
226 |
+
gr.Markdown(value="# Medical Image Generation with Hugging Face Logo")
|
227 |
|
228 |
+
|
229 |
+
|
230 |
+
with gr.Row():
|
231 |
+
cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=2, step=0.5, label="CFG Weight")
|
232 |
+
step_input = gr.Slider(minimum=1, maximum=50, value=30, step=1, label="Number of Inference Steps")
|
233 |
|
234 |
+
prompt_input = gr.Textbox(label="Generation Prompt (e.g., 'Generate a CT scan with the Hugging Face logo', 'Create an MRI scan showing the Hugging Face logo', 'Render a medical x-ray with the Hugging Face logo.')")
|
235 |
+
seed_input = gr.Number(label="Seed (Optional)", precision=0, value=12345)
|
236 |
|
237 |
+
generation_button = gr.Button("Generate Images")
|
|
|
238 |
|
239 |
+
image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
|
240 |
|
241 |
+
examples_t2i = gr.Examples(
|
242 |
+
label="Medical image generation examples with Hugging Face logo.",
|
243 |
+
examples=[
|
244 |
+
"Generate a CT scan with the Hugging Face logo clearly visible.",
|
245 |
+
"Create an MRI scan showing the Hugging Face logo embedded within the tissue.",
|
246 |
+
"Render a medical x-ray with the Hugging Face logo subtly visible in the background.",
|
247 |
+
"Generate an ultrasound image with a faint Hugging Face logo on the screen",
|
248 |
+
],
|
249 |
+
inputs=prompt_input,
|
250 |
)
|
251 |
+
|
252 |
+
understanding_button.click(
|
253 |
+
multimodal_understanding,
|
254 |
+
inputs=[image_input, question_input, und_seed_input, top_p, temperature],
|
255 |
+
outputs=understanding_output
|
256 |
)
|
257 |
+
|
258 |
+
generation_button.click(
|
259 |
+
fn=generate_image,
|
260 |
+
inputs=[prompt_input, seed_input, cfg_weight_input, step_input],
|
261 |
+
outputs=image_output
|
|
|
|
|
|
|
|
|
262 |
)
|
263 |
|
264 |
+
demo.launch(share=True, ssr_mode = False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|