mgbam committed on
Commit a8b15b5 · verified · 1 Parent(s): 39d4c85

Update app.py

Files changed (1)
  1. app.py +176 -89
app.py CHANGED
@@ -1,28 +1,55 @@
  import gradio as gr
  import torch
  from transformers import AutoConfig, AutoModelForCausalLM
  from janus.models import MultiModalityCausalLM, VLChatProcessor
  from janus.utils.io import load_pil_images
  from PIL import Image
- import numpy as np
- import os
- import time
  import spaces

- # Load medical imaging-optimized model and processor
  model_path = "deepseek-ai/Janus-Pro-1B"
  config = AutoConfig.from_pretrained(model_path)
  language_config = config.language_config
  language_config._attn_implementation = 'eager'

- # Initialize model with medical imaging parameters
  vl_gpt = AutoModelForCausalLM.from_pretrained(
      model_path,
      language_config=language_config,
      trust_remote_code=True,
-     medical_head=True  # Assuming custom medical imaging head
  ).to(torch.bfloat16 if torch.cuda.is_available() else torch.float16)

  if torch.cuda.is_available():
      vl_gpt = vl_gpt.cuda()

@@ -30,27 +57,41 @@ vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
  tokenizer = vl_chat_processor.tokenizer
  cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'

  @torch.inference_mode()
  @spaces.GPU(duration=120)
- def medical_image_analysis(medical_image, clinical_question, seed, top_p, temperature):
-     """Analyze medical images (CT, MRI, X-ray, histopathology) with clinical context."""
      torch.cuda.empty_cache()
      torch.manual_seed(seed)

-     # Medical-specific conversation template
      conversation = [{
          "role": "<|Radiologist|>",
-         "content": f"<medical_image>\nClinical Context: {clinical_question}",
          "images": [medical_image],
      }, {"role": "<|AI_Assistant|>", "content": ""}]

-     processed_image = [Image.fromarray(medical_image)]
      inputs = vl_chat_processor(
-         conversations=conversation,
-         images=processed_image,
          force_batchify=True
-     ).to(cuda_device, dtype=torch.bfloat16)
-
      inputs_embeds = vl_gpt.prepare_inputs_embeds(**inputs)

      # Medical-optimized generation parameters
@@ -58,115 +99,161 @@ def medical_image_analysis(medical_image, clinical_question, seed, top_p, temper
          inputs_embeds=inputs_embeds,
          attention_mask=inputs.attention_mask,
          max_new_tokens=512,
-         temperature=0.2,  # Lower for clinical precision
          top_p=0.9,
-         repetition_penalty=1.2,  # Reduce hallucination
-         medical_mode=True
      )

-     findings = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
-     return f"Clinical Findings:\n{findings}"

  @torch.inference_mode()
  @spaces.GPU(duration=120)
- def generate_medical_image(prompt, seed=None, guidance=5, t2i_temperature=0.5):
-     """Generate synthetic medical images for educational/research purposes."""
      torch.cuda.empty_cache()
      if seed is not None:
          torch.manual_seed(seed)
-
-     # Medical image generation parameters
-     medical_config = {
-         'width': 512,
-         'height': 512,
-         'parallel_size': 3,
-         'modality': 'mri',  # Can specify CT, X-ray, etc.
-         'anatomy': 'brain'  # Target anatomy
-     }

      messages = [{
          'role': '<|Clinician|>',
-         'content': f"{prompt} [Modality: {medical_config['modality']}, Anatomy: {medical_config['anatomy']}]"
      }]

-     text = vl_chat_processor.apply_medical_template(
          messages,
-         system_prompt='Generate education-quality medical imaging data'
      )

      input_ids = torch.LongTensor(tokenizer.encode(text)).to(cuda_device)
-     generated_tokens, patches = vl_gpt.generate_medical_image(
          input_ids,
-         **medical_config,
          cfg_weight=guidance,
-         temperature=t2i_temperature
      )

-     # Post-processing for medical imaging standards
-     synthetic_images = postprocess_medical_images(patches, **medical_config)
      return [Image.fromarray(img).resize((512, 512)) for img in synthetic_images]

- # Medical-optimized Gradio interface
- with gr.Blocks(title="Medical Imaging AI Suite") as demo:
-     gr.Markdown("""## Medical Image Analysis Suite v2.1
-     *For research use only - not for clinical diagnosis*""")

      with gr.Tab("Clinical Image Analysis"):
          with gr.Row():
-             medical_image_input = gr.Image(label="Upload Medical Scan")
-             clinical_question = gr.Textbox(label="Clinical Query",
-                 placeholder="E.g.: 'Assess tumor progression in this MRI series'")
-
-         with gr.Accordion("Advanced Parameters", open=False):
-             und_seed = gr.Number(42, label="Reproducibility Seed")
-             analysis_top_p = gr.Slider(0.8, 1.0, 0.95, label="Diagnostic Certainty")
-             analysis_temp = gr.Slider(0.1, 0.5, 0.2, label="Analysis Precision")
-
-         analysis_btn = gr.Button("Analyze Scan", variant="primary")
-         clinical_report = gr.Textbox(label="AI Analysis Report", interactive=False)
-
-         gr.Examples(
-             examples=[
-                 ["Identify pulmonary nodules in this CT scan", "ct_chest.png"],
-                 ["Assess MRI for multiple sclerosis lesions", "brain_mri.jpg"],
-                 ["Histopathology analysis: tumor grading", "biopsy_slide.png"]
-             ],
-             inputs=[clinical_question, medical_image_input]
-         )

-     with gr.Tab("Medical Imaging Synthesis"):
-         gr.Markdown("**Educational Image Generation**")
-         synth_prompt = gr.Textbox(label="Synthesis Prompt",
-             placeholder="E.g.: 'Synthetic brain MRI showing glioblastoma multiforme'")
-
          with gr.Row():
-             synth_guidance = gr.Slider(3, 7, 5, label="Anatomical Accuracy")
-             synth_temp = gr.Slider(0.3, 1.0, 0.6, label="Synthesis Variability")
-
-         synth_btn = gr.Button("Generate Educational Images", variant="secondary")
-         synthetic_gallery = gr.Gallery(label="Synthetic Medical Images",
-             columns=3, object_fit="contain")
-
-         gr.Examples(
-             examples=[
-                 "High-resolution CT of healthy lung parenchyma",
-                 "T2-weighted MRI of lumbar spine with herniated disc",
-                 "Histopathology slide of benign breast tissue"
-             ],
-             inputs=synth_prompt
-         )

-     # Connect functionality
      analysis_btn.click(
          medical_image_analysis,
-         inputs=[medical_image_input, clinical_question, und_seed, analysis_top_p, analysis_temp],
-         outputs=clinical_report
      )

-     synth_btn.click(
          generate_medical_image,
-         inputs=[synth_prompt, und_seed, synth_guidance, synth_temp],
-         outputs=synthetic_gallery
      )

  demo.launch(share=True, server_port=7860)
 
  import gradio as gr
  import torch
+ import numpy as np
  from transformers import AutoConfig, AutoModelForCausalLM
  from janus.models import MultiModalityCausalLM, VLChatProcessor
  from janus.utils.io import load_pil_images
  from PIL import Image
  import spaces
+ from torchvision import transforms

+ # Medical Imaging Configuration
+ MEDICAL_CONFIG = {
+     "modality": "CT",  # Default imaging modality
+     "anatomical_region": "Chest",
+     "clinical_task": "analysis",
+     "report_style": "structured"
+ }
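+ # NOTE: only "modality" (updated by the modality buttons below) and
+ # "anatomical_region" are read later, in generate_medical_image;
+ # "clinical_task" and "report_style" are currently unused.
+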
+ # Load base model
  model_path = "deepseek-ai/Janus-Pro-1B"
  config = AutoConfig.from_pretrained(model_path)
  language_config = config.language_config
  language_config._attn_implementation = 'eager'

+ # Initialize model with medical adaptations
  vl_gpt = AutoModelForCausalLM.from_pretrained(
      model_path,
      language_config=language_config,
      trust_remote_code=True,
+     hidden_dropout_prob=0.1,
+     attention_probs_dropout_prob=0.1,
+     output_attentions=True
  ).to(torch.bfloat16 if torch.cuda.is_available() else torch.float16)

+ # Add medical projection layer
+ class MedicalProjectionWrapper(torch.nn.Module):
+     def __init__(self, base_model):
+         super().__init__()
+         self.base_model = base_model
+         self.medical_proj = torch.nn.Linear(
+             base_model.config.hidden_size,
+             base_model.config.hidden_size * 2
+         )
+         self.activation = torch.nn.GELU()
+
+     def forward(self, *args, **kwargs):
+         outputs = self.base_model(*args, **kwargs)
+         medical_rep = self.activation(self.medical_proj(outputs.last_hidden_state))
+         return outputs.__class__(last_hidden_state=medical_rep)
+
+ vl_gpt.language_model = MedicalProjectionWrapper(vl_gpt.language_model)
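+ # NOTE: the projection widens last_hidden_state to 2x hidden_size, and the
+ # rebuilt output object drops every other field of the base model's output
+ # (logits, past_key_values, attentions). nn.Module also does not forward
+ # attribute lookups, so the wrapper must expose whatever callers expect
+ # (e.g. .generate) for the calls below to keep working.
+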
  if torch.cuda.is_available():
      vl_gpt = vl_gpt.cuda()

  vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
  tokenizer = vl_chat_processor.tokenizer
  cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'

+ # Medical image preprocessing
+ def preprocess_medical_image(image):
+     if isinstance(image, np.ndarray):
+         image = Image.fromarray(image)
+
+     medical_transforms = transforms.Compose([
+         transforms.Resize((512, 512)),
+         transforms.Grayscale(num_output_channels=3),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+     ])
+
+     return medical_transforms(image).unsqueeze(0).to(cuda_device)
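+ # NOTE: these are the stock torchvision/ImageNet normalization constants,
+ # not modality-specific statistics. The tensor returned here is what ends
+ # up in the conversation's "images" list; the processor call below reads
+ # the raw frame again via Image.fromarray.
+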
  @torch.inference_mode()
  @spaces.GPU(duration=120)
+ def medical_image_analysis(image, clinical_query, seed=42):
      torch.cuda.empty_cache()
      torch.manual_seed(seed)

+     # Preprocess with medical transformations
+     medical_image = preprocess_medical_image(image)
+
      conversation = [{
          "role": "<|Radiologist|>",
+         "content": f"<medical_image>\nClinical Context: {clinical_query}",
          "images": [medical_image],
      }, {"role": "<|AI_Assistant|>", "content": ""}]

      inputs = vl_chat_processor(
+         conversations=conversation,
+         images=[Image.fromarray(image)],
          force_batchify=True
+     ).to(cuda_device)
+
      inputs_embeds = vl_gpt.prepare_inputs_embeds(**inputs)

      # Medical-optimized generation parameters
      outputs = vl_gpt.language_model.generate(
          inputs_embeds=inputs_embeds,
          attention_mask=inputs.attention_mask,
          max_new_tokens=512,
+         temperature=0.2,
          top_p=0.9,
+         num_beams=5,
+         repetition_penalty=1.5,
+         early_stopping=True
      )

+     report = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
+     return format_medical_report(report)
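+     # NOTE: generate() defaults to do_sample=False, so with num_beams=5 this
+     # runs pure beam search and the temperature/top_p values above are
+     # ignored (transformers logs a warning); add do_sample=True if sampling
+     # is intended.
+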
+ def format_medical_report(raw_text):
+     sections = {
+         "Findings": "",
+         "Impression": "",
+         "Recommendations": ""
+     }
+
+     current_section = None
+     for line in raw_text.split('\n'):
+         if "FINDINGS:" in line:
+             current_section = "Findings"
+         elif "IMPRESSION:" in line:
+             current_section = "Impression"
+         elif "RECOMMENDATIONS:" in line:
+             current_section = "Recommendations"
+         elif current_section:
+             sections[current_section] += line.strip() + '\n'
+
+     return f"""**Clinical Report**
+
+ **Findings:**
+ {sections['Findings'] or 'No significant findings'}
+
+ **Impression:**
+ {sections['Impression'] or 'No conclusive diagnosis'}
+
+ **Recommendations:**
+ {sections['Recommendations'] or 'Follow-up as clinically indicated'}"""
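+ # Example: a raw completion such as
+ #   "FINDINGS:\nNo focal consolidation.\nIMPRESSION:\nNormal study."
+ # fills the Findings and Impression sections, while Recommendations falls
+ # back to 'Follow-up as clinically indicated'.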
 
+ # Medical image generation components
  @torch.inference_mode()
  @spaces.GPU(duration=120)
+ def generate_medical_image(prompt, guidance=7, temperature=0.6, seed=12345):
      torch.cuda.empty_cache()
      if seed is not None:
          torch.manual_seed(seed)
+
+     medical_prompt = f"{prompt} [Modality: {MEDICAL_CONFIG['modality']}, Anatomy: {MEDICAL_CONFIG['anatomical_region']}]"

      messages = [{
          'role': '<|Clinician|>',
+         'content': medical_prompt
      }]

+     text = vl_chat_processor.apply_chat_template(
          messages,
+         system_prompt='Generate educational medical imaging data'
      )

      input_ids = torch.LongTensor(tokenizer.encode(text)).to(cuda_device)
+
+     # Medical image generation parameters
+     generated_tokens, patches = vl_gpt.generate(
          input_ids,
+         width=512,
+         height=512,
          cfg_weight=guidance,
+         temperature=temperature,
+         parallel_size=3,
+         image_token_num_per_image=576,
+         patch_size=16
      )
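+     # NOTE: width/height/parallel_size/image_token_num_per_image/patch_size
+     # are not standard GenerationMixin arguments; this call (and
+     # apply_chat_template on VLChatProcessor above) assumes the checkpoint's
+     # remote code accepts them.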

+     synthetic_images = postprocess_medical_images(patches)
      return [Image.fromarray(img).resize((512, 512)) for img in synthetic_images]

+ def postprocess_medical_images(patches):
+     patches = patches.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
+     patches = np.clip((patches + 1) / 2 * 255, 0, 255).astype(np.uint8)
+     return [patches[i] for i in range(patches.shape[0])]
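+ # The (x + 1) / 2 * 255 mapping rescales decoder output from [-1, 1] to
+ # [0, 255] (-1 -> 0, 0 -> 127.5, 1 -> 255) ahead of the uint8 cast.
+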
+ # Medical-optimized interface
+ with gr.Blocks(title="Medical Imaging AI", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("""## Medical Imaging Analysis Suite v3.2
+     *Research use only - Not for clinical decision-making*""")

      with gr.Tab("Clinical Image Analysis"):
+         gr.Markdown("### Upload medical scan and clinical context")
          with gr.Row():
+             with gr.Column(scale=1):
+                 med_image = gr.Image(label="Medical Imaging Study", type="numpy")
+                 with gr.Row():
+                     ct_btn = gr.Button("CT Scan")
+                     mri_btn = gr.Button("MRI")
+                     xray_btn = gr.Button("X-ray")
+
+             with gr.Column(scale=2):
+                 clinical_input = gr.Textbox(label="Clinical Context", lines=3,
+                     placeholder="Patient history and clinical question...")
+                 analysis_btn = gr.Button("Analyze Study", variant="primary")
+                 report_output = gr.Markdown(label="AI Analysis Report")

+         gr.Examples([
+             ["Evaluate lung nodules in this CT scan", "ct_chest.png"],
+             ["Assess brain MRI for metastatic lesions", "brain_mri.jpg"],
+             ["Analyze bone structure in this wrist X-ray", "wrist_xray.png"]
+         ], [clinical_input, med_image])
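+         # NOTE: the example rows assume ct_chest.png, brain_mri.jpg and
+         # wrist_xray.png are bundled with the Space alongside app.py.
+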
+     with gr.Tab("Educational Image Synthesis"):
+         gr.Markdown("### Generate synthetic medical images for training")
          with gr.Row():
+             with gr.Column():
+                 synth_prompt = gr.Textbox(label="Synthesis Prompt", lines=2,
+                     placeholder="Describe the desired medical image...")
+                 gr.Markdown("**Modality Options**")
+                 with gr.Row():
+                     synth_ct_btn = gr.Button("CT")
+                     synth_mri_btn = gr.Button("MRI")
+                     synth_xray_btn = gr.Button("X-ray")
+
+             with gr.Column():
+                 with gr.Accordion("Advanced Parameters", open=False):
+                     with gr.Row():
+                         synth_guidance = gr.Slider(3, 7, 5, label="Anatomical Accuracy")
+                         synth_temp = gr.Slider(0.3, 1.0, 0.6, label="Synthesis Variability")
+                 generate_btn = gr.Button("Generate Educational Images", variant="secondary")
+
+         synth_gallery = gr.Gallery(label="Synthetic Images", columns=3, height=400)

+     # Event handlers
      analysis_btn.click(
          medical_image_analysis,
+         [med_image, clinical_input],
+         report_output
      )

+     generate_btn.click(
          generate_medical_image,
+         [synth_prompt, synth_guidance, synth_temp],
+         synth_gallery
      )
+
+     def set_modality(label):
+         MEDICAL_CONFIG["modality"] = label
+         gr.Info(f"Modality set to {label}")
+
+     # Buttons are not valid input components, so bind each button's label
+     # at definition time instead of passing the button itself as an input.
+     for btn in [ct_btn, mri_btn, xray_btn,
+                 synth_ct_btn, synth_mri_btn, synth_xray_btn]:
+         btn.click(lambda label=btn.value: set_modality(label), None, None)

  demo.launch(share=True, server_port=7860)