mgbam committed on
Commit 3afe8c3 · verified · 1 Parent(s): a8b15b5

Update app.py

Files changed (1): app.py (+111 -214)
app.py CHANGED
@@ -1,259 +1,156 @@
  import gradio as gr
  import torch
  import numpy as np
- from transformers import AutoConfig, AutoModelForCausalLM
- from janus.models import MultiModalityCausalLM, VLChatProcessor
- from janus.utils.io import load_pil_images
  from PIL import Image
  import spaces
- from torchvision import transforms

- # Medical Imaging Configuration
  MEDICAL_CONFIG = {
-     "modality": "CT",  # Default imaging modality
-     "anatomical_region": "Chest",
-     "clinical_task": "analysis",
-     "report_style": "structured"
  }

- # Load base model
  model_path = "deepseek-ai/Janus-Pro-1B"
- config = AutoConfig.from_pretrained(model_path)
- language_config = config.language_config
- language_config._attn_implementation = 'eager'

- # Initialize model with medical adaptations
- vl_gpt = AutoModelForCausalLM.from_pretrained(
-     model_path,
-     language_config=language_config,
-     trust_remote_code=True,
-     hidden_dropout_prob=0.1,
-     attention_probs_dropout_prob=0.1,
-     output_attentions=True
- ).to(torch.bfloat16 if torch.cuda.is_available() else torch.float16)
-
- # Add medical projection layer
- class MedicalProjectionWrapper(torch.nn.Module):
      def __init__(self, base_model):
          super().__init__()
          self.base_model = base_model
-         self.medical_proj = torch.nn.Linear(
-             base_model.config.hidden_size,
-             base_model.config.hidden_size * 2
-         )
-         self.activation = torch.nn.GELU()
-
      def forward(self, *args, **kwargs):
          outputs = self.base_model(*args, **kwargs)
-         medical_rep = self.activation(self.medical_proj(outputs.last_hidden_state))
-         return outputs.__class__(last_hidden_state=medical_rep)

- vl_gpt.language_model = MedicalProjectionWrapper(vl_gpt.language_model)

  if torch.cuda.is_available():
-     vl_gpt = vl_gpt.cuda()

  vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
- tokenizer = vl_chat_processor.tokenizer
- cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'

- # Medical image preprocessing
- def preprocess_medical_image(image):
-     if isinstance(image, np.ndarray):
-         image = Image.fromarray(image)
-
-     medical_transforms = transforms.Compose([
-         transforms.Resize((512, 512)),
-         transforms.Grayscale(num_output_channels=3),
-         transforms.ToTensor(),
-         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-     ])
-
-     return medical_transforms(image).unsqueeze(0).to(cuda_device)

  @torch.inference_mode()
  @spaces.GPU(duration=120)
- def medical_image_analysis(image, clinical_query, seed=42):
-     torch.cuda.empty_cache()
-     torch.manual_seed(seed)

-     # Preprocess with medical transformations
-     medical_image = preprocess_medical_image(image)

      conversation = [{
-         "role": "<|Radiologist|>",
-         "content": f"<medical_image>\nClinical Context: {clinical_query}",
-         "images": [medical_image],
      }, {"role": "<|AI_Assistant|>", "content": ""}]
-
      inputs = vl_chat_processor(
          conversations=conversation,
-         images=[Image.fromarray(image)],
          force_batchify=True
-     ).to(cuda_device)

-     inputs_embeds = vl_gpt.prepare_inputs_embeds(**inputs)
-
-     # Medical-optimized generation parameters
-     outputs = vl_gpt.language_model.generate(
-         inputs_embeds=inputs_embeds,
          attention_mask=inputs.attention_mask,
          max_new_tokens=512,
-         temperature=0.2,
          top_p=0.9,
-         num_beams=5,
-         repetition_penalty=1.5,
-         early_stopping=True
      )

-     report = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
-     return format_medical_report(report)

- def format_medical_report(raw_text):
      sections = {
-         "Findings": "",
-         "Impression": "",
-         "Recommendations": ""
      }

-     current_section = None
-     for line in raw_text.split('\n'):
-         if "FINDINGS:" in line:
-             current_section = "Findings"
-         elif "IMPRESSION:" in line:
-             current_section = "Impression"
-         elif "RECOMMENDATIONS:" in line:
-             current_section = "Recommendations"
-         elif current_section:
-             sections[current_section] += line.strip() + '\n'
-
-     return f"""**Clinical Report**
-
- **Findings:**
- {sections['Findings'] or 'No significant findings'}
-
- **Impression:**
- {sections['Impression'] or 'No conclusive diagnosis'}
-
- **Recommendations:**
- {sections['Recommendations'] or 'Follow-up as clinically indicated'}"""
-
- # Medical image generation components
- @torch.inference_mode()
- @spaces.GPU(duration=120)
- def generate_medical_image(prompt, seed=12345, guidance=7, temperature=0.6):
-     torch.cuda.empty_cache()
-     if seed is not None:
-         torch.manual_seed(seed)
-
-     medical_prompt = f"{prompt} [Modality: {MEDICAL_CONFIG['modality']}, Anatomy: {MEDICAL_CONFIG['anatomical_region']}]"
-
-     messages = [{
-         'role': '<|Clinician|>',
-         'content': medical_prompt
-     }]
-
-     text = vl_chat_processor.apply_chat_template(
-         messages,
-         system_prompt='Generate educational medical imaging data'
-     )
-
-     input_ids = torch.LongTensor(tokenizer.encode(text)).to(cuda_device)
-
-     # Medical image generation parameters
-     generated_tokens, patches = vl_gpt.generate(
-         input_ids,
-         width=512,
-         height=512,
-         cfg_weight=guidance,
-         temperature=temperature,
-         parallel_size=3,
-         image_token_num_per_image=576,
-         patch_size=16
      )
-
-     synthetic_images = postprocess_medical_images(patches)
-     return [Image.fromarray(img).resize((512, 512)) for img in synthetic_images]

- def postprocess_medical_images(patches):
-     patches = patches.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
-     patches = np.clip((patches + 1) / 2 * 255, 0, 255).astype(np.uint8)
-     return [patches[i] for i in range(patches.shape[0])]
-
- # Medical-optimized interface
- with gr.Blocks(title="Medical Imaging AI", theme=gr.themes.Soft()) as demo:
-     gr.Markdown("""## Medical Imaging Analysis Suite v3.2
-     *Research use only - Not for clinical decision-making*""")
-
-     with gr.Tab("Clinical Image Analysis"):
-         gr.Markdown("### Upload medical scan and clinical context")
-         with gr.Row():
-             with gr.Column(scale=1):
-                 med_image = gr.Image(label="Medical Imaging Study", type="numpy")
-                 med_upload_btns = gr.Row([
-                     gr.Button("CT Scan"),
-                     gr.Button("MRI"),
-                     gr.Button("X-ray")
-                 ])
-
-             with gr.Column(scale=2):
-                 clinical_input = gr.Textbox(label="Clinical Context", lines=3,
-                                             placeholder="Patient history and clinical question...")
-                 analysis_btn = gr.Button("Analyze Study", variant="primary")
-                 report_output = gr.Markdown(label="AI Analysis Report")
-
-         gr.Examples([
-             ["Evaluate lung nodules in this CT scan", "ct_chest.png"],
-             ["Assess brain MRI for metastatic lesions", "brain_mri.jpg"],
-             ["Analyze bone structure in this wrist X-ray", "wrist_xray.png"]
-         ], [clinical_input, med_image])
-
-     with gr.Tab("Educational Image Synthesis"):
-         gr.Markdown("### Generate synthetic medical images for training")
-         with gr.Row():
-             with gr.Column():
-                 synth_prompt = gr.Textbox(label="Synthesis Prompt", lines=2,
-                                           placeholder="Describe the desired medical image...")
-                 gr.Markdown("**Modality Options**")
-                 modality_btns = gr.Row([
-                     gr.Button("CT"),
-                     gr.Button("MRI"),
-                     gr.Button("X-ray")
-                 ])
-
-             with gr.Column():
-                 synth_params = gr.Accordion("Advanced Parameters", open=False)
-                 with synth_params:
-                     gr.Row([
-                         gr.Slider(3, 7, 5, label="Anatomical Accuracy"),
-                         gr.Slider(0.3, 1.0, 0.6, label="Synthesis Variability")
-                     ])
-                 generate_btn = gr.Button("Generate Educational Images", variant="secondary")
-
-         synth_gallery = gr.Gallery(label="Synthetic Images", columns=3, height=400)
-
-     # Event handlers
-     analysis_btn.click(
-         medical_image_analysis,
-         [med_image, clinical_input],
-         report_output
-     )
-
-     generate_btn.click(
-         generate_medical_image,
-         [synth_prompt, synth_params],
-         synth_gallery
-     )
-
-     for btn in [*med_upload_btns.children, *modality_btns.children]:
-         btn.click(
-             lambda m: MEDICAL_CONFIG.update(modality=m),
-             [btn],
-             None
-         ).then(
-             lambda: gr.Info(f"Modality set to {MEDICAL_CONFIG['modality']}"),
-             None,
-             None
-         )

- demo.launch(share=True, server_port=7860)
  import gradio as gr
  import torch
  import numpy as np
+ from transformers import AutoModelForCausalLM
+ from janus.models import VLChatProcessor
  from PIL import Image
  import spaces

+ # Medical Image Analysis Configuration
  MEDICAL_CONFIG = {
+     "echo_guidelines": "ASE 2023 Standards",
+     "histo_guidelines": "CAP Protocols 2024",
+     "cardiac_params": ["LVEF", "E/A Ratio", "Wall Motion"],
+     "histo_params": ["Nuclear Atypia", "Mitotic Count", "Stromal Invasion"]
  }
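+ # Only the *_guidelines strings are interpolated into prompts below;
+ # cardiac_params / histo_params are reference lists the current code never reads.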

+ # Initialize Medical Imaging Model
  model_path = "deepseek-ai/Janus-Pro-1B"

+ class MedicalImagingAdapter(torch.nn.Module):
      def __init__(self, base_model):
          super().__init__()
          self.base_model = base_model
+         # Cardiac-specific projections
+         self.cardiac_proj = torch.nn.Linear(2048, 2048)
+         # Histopathology-specific projections
+         self.histo_proj = torch.nn.Linear(2048, 2048)
+
      def forward(self, *args, **kwargs):
          outputs = self.base_model(*args, **kwargs)
+         return outputs
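+ # Note: cardiac_proj and histo_proj are initialized above but never applied in
+ # forward(), so the adapter currently passes base-model outputs through unchanged.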
 
+ vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
+ vl_gpt.language_model = MedicalImagingAdapter(vl_gpt.language_model)

  if torch.cuda.is_available():
+     vl_gpt = vl_gpt.to(torch.bfloat16).cuda()

  vl_chat_processor = VLChatProcessor.from_pretrained(model_path)

+ # Medical Image Processing Pipelines
+ def preprocess_echo(image):
+     """Process echocardiography images"""
+     img = Image.fromarray(image).convert('L')  # Grayscale
+     return np.array(img.resize((512, 512)))
+
+ def preprocess_histo(image):
+     """Process histopathology slides"""
+     img = Image.fromarray(image)
+     return np.array(img.resize((1024, 1024)))

  @torch.inference_mode()
  @spaces.GPU(duration=120)
+ def analyze_medical_case(image, clinical_context, modality):
+     # Preprocess based on modality
+     processed_img = preprocess_echo(image) if modality == "Echo" else preprocess_histo(image)

+     # Create modality-specific prompt
+     system_prompt = f"""
+     Analyze this {modality} image following {MEDICAL_CONFIG['echo_guidelines' if modality=='Echo' else 'histo_guidelines']}.
+     Clinical Context: {clinical_context}
+     """

      conversation = [{
+         "role": "<|Radiologist|>" if modality == "Echo" else "<|Pathologist|>",
+         "content": system_prompt,
+         "images": [processed_img],
      }, {"role": "<|AI_Assistant|>", "content": ""}]
+
      inputs = vl_chat_processor(
          conversations=conversation,
+         images=[Image.fromarray(processed_img)],
          force_batchify=True
+     ).to(vl_gpt.device)

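+     # prepare_inputs_embeds(**inputs) fuses the processed image features and the
+     # text tokens into one embedding sequence, which generate() consumes directly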
+     outputs = vl_gpt.generate(
+         inputs_embeds=vl_gpt.prepare_inputs_embeds(**inputs),
          attention_mask=inputs.attention_mask,
          max_new_tokens=512,
+         temperature=0.1,
          top_p=0.9,
+         repetition_penalty=1.5
      )
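+     # temperature/top_p above only influence decoding when sampling is enabled,
+     # i.e. do_sample=True is passed here or set in the model's generation_config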

+     report = vl_chat_processor.tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
+     return format_medical_report(report, modality)

+ def format_medical_report(text, modality):
+     # Structure report based on modality
      sections = {
+         "Echo": [
+             ("Chamber Dimensions", "LVEDD", "LVESD"),
+             ("Valvular Function", "Aortic Valve", "Mitral Valve"),
+             ("Hemodynamics", "E/A Ratio", "LVEF")
+         ],
+         "Histo": [
+             ("Architecture", "Gland Formation", "Stromal Pattern"),
+             ("Cellular Features", "Nuclear Atypia", "Mitotic Count"),
+             ("Diagnostic Impression", "Tumor Grade", "Margin Status")
+         ]
      }

+     formatted = f"**{modality} Analysis Report**\n\n"
+     for section in sections[modality]:
+         header = section[0]
+         formatted += f"### {header}\n"
+         for sub in section[1:]:
+             if sub in text:
+                 start = text.find(sub)
+                 end = text.find("\n\n", start)
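+                 # if no blank line follows the finding, find() returns -1 and
+                 # the slice below silently drops the report's last character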
+                 formatted += f"- **{sub}:** {text[start+len(sub)+1:end].strip()}\n"
+     return formatted
+
+ # Medical Imaging Interface
+ with gr.Blocks(title="Cardiac & Histopathology AI", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("""
+     ## Medical Imaging Analysis Platform
+     *Analyzes echocardiograms and histopathology slides - Research Use Only*
+     """)
+
+     with gr.Row():
+         with gr.Column():
+             image_input = gr.Image(label="Upload Medical Image")
+             modality_select = gr.Radio(
+                 ["Echo", "Histo"],
+                 label="Image Modality",
+                 info="Select 'Echo' for cardiac ultrasound, 'Histo' for biopsy slides"
+             )
+             clinical_input = gr.Textbox(
+                 label="Clinical Context",
+                 placeholder="e.g., 'Assess LV function' or 'Evaluate for malignancy'"
+             )
+             analyze_btn = gr.Button("Analyze Case", variant="primary")
+
+         with gr.Column():
+             report_output = gr.Markdown(label="AI Clinical Report")
+
+     # Preloaded examples from space files
+     gr.Examples(
+         examples=[
+             ["Evaluate LV systolic function", "case1.png", "Echo"],
+             ["Assess mitral valve function", "case2.jpg", "Echo"],
+             ["Analyze for malignant features", "case3.png", "Histo"],
+             ["Evaluate tumor margins", "case4.png", "Histo"]
+         ],
+         inputs=[clinical_input, image_input, modality_select],
+         label="Example Medical Cases"
+     )

+     analyze_btn.click(
+         analyze_medical_case,
+         [image_input, clinical_input, modality_select],
+         report_output
+     )

+ demo.launch(share=True)
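
A minimal sketch of how the new analyze_medical_case pipeline could be exercised outside the Gradio UI. This assumes it runs in the same process as the updated app.py (model and processor already loaded); "sample_echo.png" is a hypothetical file name, not part of the Space.

    import numpy as np
    from PIL import Image

    # load any study exported as an RGB image; the path is a placeholder
    frame = np.array(Image.open("sample_echo.png").convert("RGB"))

    # modality must be "Echo" or "Histo"; it selects the preprocessing branch
    # and the guideline string interpolated into the system prompt
    report_md = analyze_medical_case(frame, "Assess LV systolic function", "Echo")
    print(report_md)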