"""Gradio demo: Janus-Pro-1B analysis of echocardiograms and histopathology slides.

Research use only. Loads the DeepSeek Janus-Pro-1B vision-language model,
builds a modality-specific prompt, generates a free-text report and extracts
a few named findings from it into a Markdown summary.
"""

import os
import warnings

import gradio as gr
import numpy as np
import spaces
import torch
from janus.models import VLChatProcessor
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

# Suppress FutureWarning noise in the app logs.
warnings.filterwarnings("ignore", category=FutureWarning)

# Medical Imaging Analysis Configuration
MEDICAL_CONFIG = {
    "echo_guidelines": "ASE 2023 Standards",
    "histo_guidelines": "CAP Protocols 2024",
    "cardiac_params": ["LVEF", "E/A Ratio", "Wall Motion"],
    "histo_params": ["Nuclear Atypia", "Mitotic Count", "Stromal Invasion"],
}

# Report layout per modality: (section header, field label, field label, ...).
# The field labels are both requested in the prompt and searched for in the
# generated text by format_medical_report().
REPORT_SECTIONS = {
    "Echo": [
        ("Chamber Dimensions", "LVEDD", "LVESD"),
        ("Valvular Function", "Aortic Valve", "Mitral Valve"),
        ("Hemodynamics", "E/A Ratio", "LVEF"),
    ],
    "Histo": [
        ("Architecture", "Gland Formation", "Stromal Pattern"),
        ("Cellular Features", "Nuclear Atypia", "Mitotic Count"),
        ("Diagnostic Impression", "Tumor Grade", "Margin Status"),
    ],
}

# Initialize Medical Imaging Model
model_path = "deepseek-ai/Janus-Pro-1B"


class MedicalImagingAdapter(torch.nn.Module):
    """Pass-through wrapper around the Janus language model.

    ``forward`` calls ``base_model`` unchanged. Any attribute the wrapper does
    not define itself (e.g. ``get_input_embeddings``, ``generate``, ``config``)
    is looked up on ``base_model``, so code that expects the original language
    model keeps working after it is wrapped.

    Args:
        base_model: The module being wrapped.
        hidden_size: Width of the modality-specific projections.
    """

    def __init__(self, base_model, hidden_size=2048):
        super().__init__()
        self.base_model = base_model
        # Modality-specific projections. NOTE: these are freshly (randomly)
        # initialized -- they are created after from_pretrained(), so no
        # checkpoint weights are loaded into them -- and forward() does not use
        # them yet. They are placeholders for future fine-tuning.
        self.cardiac_proj = torch.nn.Linear(hidden_size, hidden_size)
        self.histo_proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, *args, **kwargs):
        return self.base_model(*args, **kwargs)

    def __getattr__(self, name):
        # nn.Module.__getattr__ only resolves registered parameters, buffers and
        # submodules; fall back to the wrapped model for everything else.
        try:
            return super().__getattr__(name)
        except AttributeError:
            if name == "base_model":
                # Not registered yet (e.g. mid-construction / unpickling):
                # re-raise instead of recursing forever.
                raise
            return getattr(self.base_model, name)


vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
vl_gpt.language_model = MedicalImagingAdapter(vl_gpt.language_model)
if torch.cuda.is_available():
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
# The freshly created adapter starts in training mode; switch the whole model
# to inference behaviour.
vl_gpt.eval()

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
# Use the non-legacy tokenizer behaviour (legacy=False).
# NOTE(review): the processor resolved its image-token ids from the tokenizer
# it was constructed with; this assumes the replacement, loaded from the same
# checkpoint, assigns identical ids -- confirm.
vl_chat_processor.tokenizer = AutoTokenizer.from_pretrained(model_path, legacy=False)


# Medical Image Processing Pipelines
def preprocess_echo(image):
    """Return an echocardiography frame as a 512x512 single-channel (grayscale) array."""
    img = Image.fromarray(image).convert("L")
    return np.array(img.resize((512, 512)))


def preprocess_histo(image):
    """Return a histopathology slide resized to 1024x1024 (channels unchanged)."""
    img = Image.fromarray(image)
    return np.array(img.resize((1024, 1024)))


def _build_prompt(clinical_context, modality):
    """Build the user question for *modality*, asking for the labelled findings we parse."""
    is_echo = modality == "Echo"
    guideline = MEDICAL_CONFIG["echo_guidelines" if is_echo else "histo_guidelines"]
    persona = "radiologist" if is_echo else "pathologist"
    field_names = ", ".join(
        field for _header, *fields in REPORT_SECTIONS[modality] for field in fields
    )
    context = clinical_context.strip() if clinical_context else "Not provided"
    return (
        f"You are assisting a {persona}. "
        f"Analyze this {modality} image following {guideline}.\n"
        f"Clinical Context: {context}\n"
        f"Report each of the following findings as its own paragraph, starting "
        f"with the finding name followed by a colon: {field_names}."
    )


@spaces.GPU(duration=120)
def analyze_medical_case(image, clinical_context, modality):
    """Run the model on one uploaded image and return a Markdown report.

    Args:
        image: Image array as delivered by ``gr.Image`` (None if nothing uploaded).
        clinical_context: Free-text clinical question; may be empty.
        modality: ``"Echo"`` or ``"Histo"``.

    Returns:
        Markdown string produced by :func:`format_medical_report`.

    Raises:
        gr.Error: if no image was uploaded or no valid modality was selected.
    """
    if image is None:
        raise gr.Error("Please upload a medical image before running the analysis.")
    if modality not in REPORT_SECTIONS:
        raise gr.Error("Please select an image modality ('Echo' or 'Histo').")

    # Preprocess based on modality
    processed_img = preprocess_echo(image) if modality == "Echo" else preprocess_histo(image)
    # The Echo path yields a single-channel 'L' image; give the Janus image
    # processor the RGB input it expects.
    pil_img = Image.fromarray(processed_img).convert("RGB")

    # Janus chat template: the user turn must contain <image_placeholder> so the
    # image embeddings are spliced into the prompt; roles follow the template.
    conversation = [
        {
            "role": "<|User|>",
            "content": f"<image_placeholder>\n{_build_prompt(clinical_context, modality)}",
            "images": [pil_img],
        },
        {"role": "<|Assistant|>", "content": ""},
    ]

    tokenizer = vl_chat_processor.tokenizer
    # Entered inside the GPU-decorated function so it is active in whatever
    # process/thread actually executes the model.
    with torch.inference_mode():
        inputs = vl_chat_processor(
            conversations=conversation,
            images=[pil_img],
            force_batchify=True,
        ).to(vl_gpt.device, dtype=vl_gpt.dtype)  # match model dtype (bf16 on GPU, fp32 on CPU)

        inputs_embeds = vl_gpt.prepare_inputs_embeds(**inputs)

        # Generate with the causal language model; the multimodal wrapper only
        # prepares the embeddings.
        outputs = vl_gpt.language_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=inputs.attention_mask,
            pad_token_id=tokenizer.eos_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            max_new_tokens=512,
            do_sample=True,  # required for temperature/top_p to take effect
            temperature=0.1,
            top_p=0.9,
            repetition_penalty=1.5,
            use_cache=True,
        )

    report = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    return format_medical_report(report, modality)


def _extract_field(text, label):
    """Return the text after the first *label* up to the next blank line, or None if absent."""
    start = text.find(label)
    if start == -1:
        return None
    value_start = start + len(label)
    end = text.find("\n\n", value_start)
    if end == -1:  # last field in the text: take everything to the end
        end = len(text)
    # Drop the separator between label and value (":", "-", whitespace).
    return text[value_start:end].lstrip(" \t:-").strip()


def format_medical_report(text, modality):
    """Arrange generated *text* into a Markdown report for *modality*.

    For every field label in ``REPORT_SECTIONS[modality]`` that occurs in
    *text*, the text following it (up to the next blank line) is listed under
    its section header. If no label is found, the raw model output is appended
    so the user still sees it.

    Raises:
        KeyError: if *modality* is not a key of ``REPORT_SECTIONS``.
    """
    sections = REPORT_SECTIONS[modality]
    parts = [f"**{modality} Analysis Report**\n\n"]
    found_any = False
    for header, *fields in sections:
        parts.append(f"### {header}\n")
        for field in fields:
            value = _extract_field(text, field)
            if value is None:
                continue
            found_any = True
            parts.append(f"- **{field}:** {value}\n")

    if not found_any:
        raw = text.strip()
        parts.append("\n### Model Output\n")
        parts.append(f"{raw}\n" if raw else "_The model returned no text._\n")
    return "".join(parts)


# Preloaded examples: (clinical context, image path, modality).
EXAMPLE_CASES = [
    ["Evaluate LV systolic function", "case1.png", "Echo"],
    ["Assess mitral valve function", "case2.jpg", "Echo"],
    ["Analyze for malignant features", "case3.png", "Histo"],
    ["Evaluate tumor margins", "case4.png", "Histo"],
]

# Medical Imaging Interface
with gr.Blocks(title="Cardiac & Histopathology AI", theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "## Medical Imaging Analysis Platform\n"
        "*Analyzes echocardiograms and histopathology slides - Research Use Only*"
    )

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Medical Image")
            modality_select = gr.Radio(
                ["Echo", "Histo"],
                label="Image Modality",
                info="Select 'Echo' for cardiac ultrasound, 'Histo' for biopsy slides",
            )
            clinical_input = gr.Textbox(
                label="Clinical Context",
                placeholder="e.g., 'Assess LV function' or 'Evaluate for malignancy'",
            )
            analyze_btn = gr.Button("Analyze Case", variant="primary")

        with gr.Column():
            report_output = gr.Markdown(label="AI Clinical Report")

    # Only offer examples whose image files are actually present; a missing
    # example file would otherwise break app start-up.
    available_examples = [case for case in EXAMPLE_CASES if os.path.exists(case[1])]
    if available_examples:
        gr.Examples(
            examples=available_examples,
            inputs=[clinical_input, image_input, modality_select],
            label="Example Medical Cases",
        )

    analyze_btn.click(
        analyze_medical_case,
        inputs=[image_input, clinical_input, modality_select],
        outputs=report_output,
    )

if __name__ == "__main__":
    demo.launch(share=True)