|
import logging
|
|
|
|
from .image_analyzer import XRayImageAnalyzer
|
|
from .text_analyzer import MedicalReportAnalyzer
|
|
|
|
|
|
class MultimodalFusion:
    """
    Fuse insights from image analysis and text analysis of medical data.

    This fusion approach combines the strengths of both modalities:
    - Images provide visual evidence of abnormalities
    - Text reports provide context, history and radiologist interpretations

    The combined analysis provides a more comprehensive understanding than
    either modality alone.

    Modality results are plain dicts. The keys read here are
    ``has_abnormality``, ``primary_finding`` and ``confidence`` from the image
    results, and ``severity`` (``level`` / ``score`` / ``confidence``),
    ``entities["problem"]``, ``findings`` and ``followup_recommendations``
    from the text results. Every key is optional. A dict carrying a truthy
    ``error`` entry is treated as a failed analysis, i.e. as *no evidence*
    rather than as a normal study.
    """

    _DEFAULT_IMAGE_MODEL = "facebook/deit-base-patch16-224-medical-cxr"
    _DEFAULT_TEXT_MODEL = "medicalai/ClinicalBERT"

    # Conditions searched for (as lowercase substrings) in both modalities
    # when scoring agreement.
    _COMMON_CONDITIONS = (
        "pneumonia",
        "effusion",
        "nodule",
        "mass",
        "cardiomegaly",
        "opacity",
        "fracture",
        "tumor",
        "edema",
    )

    # Substring of the image's primary finding -> (label, score).
    # Only the score is used in the computation; the label documents the
    # intended level for that score. Iteration order matters: the first key
    # contained in the finding wins.
    _IMAGE_SEVERITY = {
        "pneumonia": ("Moderate", 2.5),
        "pneumothorax": ("Severe", 3.0),
        "effusion": ("Moderate", 2.0),
        "pulmonary edema": ("Moderate", 2.5),
        "nodule": ("Mild", 1.5),
        "mass": ("Moderate", 2.5),
        "tumor": ("Severe", 3.0),
        "cardiomegaly": ("Mild", 1.5),
        "fracture": ("Moderate", 2.0),
        "consolidation": ("Moderate", 2.0),
    }

    # Index == integer severity level.
    _SEVERITY_LEVELS = ("Normal", "Mild", "Moderate", "Severe", "Critical")

    # Punctuation stripped from word edges before comparing recommendations.
    _EDGE_PUNCTUATION = ".,;:!?()[]\"'"

    def __init__(self, image_model=None, text_model=None, device=None):
        """
        Initialize the multimodal fusion module with image and text analyzers.

        Construction is best-effort: if an analyzer cannot be created the
        failure is logged and the corresponding attribute is set to ``None``.

        Args:
            image_model (str, optional): Model to use for image analysis
            text_model (str, optional): Model to use for text analysis
            device (str, optional): Device to run models on ('cuda' or 'cpu')
        """
        self.logger = logging.getLogger(__name__)

        if device is None:
            # Imported lazily so an explicit device never requires torch here.
            try:
                import torch

                self.device = "cuda" if torch.cuda.is_available() else "cpu"
            except ImportError:
                # Keep the constructor best-effort like the analyzer set-up below.
                self.logger.warning("PyTorch is not installed; falling back to CPU")
                self.device = "cpu"
        else:
            self.device = device

        self.logger.info("Using device: %s", self.device)

        try:
            self.image_analyzer = XRayImageAnalyzer(
                model_name=image_model or self._DEFAULT_IMAGE_MODEL,
                device=self.device,
            )
            self.logger.info("Successfully initialized image analyzer")
        except Exception as e:
            self.logger.error("Failed to initialize image analyzer: %s", e)
            self.image_analyzer = None

        try:
            self.text_analyzer = MedicalReportAnalyzer(
                classifier_model=text_model or self._DEFAULT_TEXT_MODEL,
                device=self.device,
            )
            self.logger.info("Successfully initialized text analyzer")
        except Exception as e:
            self.logger.error("Failed to initialize text analyzer: %s", e)
            self.text_analyzer = None

    @staticmethod
    def _has_error(results):
        """Return True when a modality result dict reports a failed analysis."""
        return bool(results.get("error"))

    @staticmethod
    def _extract_problems(text_results):
        """Return the report's ``entities["problem"]`` value as a list."""
        problems = (text_results.get("entities") or {}).get("problem") or []
        if isinstance(problems, str):
            # A bare string is one problem, not a sequence of characters.
            return [problems]
        return list(problems)

    def analyze_image(self, image_path):
        """
        Analyze a medical image.

        Args:
            image_path (str): Path to the medical image

        Returns:
            dict: Image analysis results, or ``{"error": message}`` when the
                analyzer is unavailable or raises.
        """
        if self.image_analyzer is None:
            self.logger.warning("Image analyzer not available")
            return {"error": "Image analyzer not available"}

        try:
            return self.image_analyzer.analyze(image_path)
        except Exception as e:
            self.logger.error("Error analyzing image: %s", e)
            return {"error": str(e)}

    def analyze_text(self, text):
        """
        Analyze medical report text.

        Args:
            text (str): Medical report text

        Returns:
            dict: Text analysis results, or ``{"error": message}`` when the
                analyzer is unavailable or raises.
        """
        if self.text_analyzer is None:
            self.logger.warning("Text analyzer not available")
            return {"error": "Text analyzer not available"}

        try:
            return self.text_analyzer.analyze(text)
        except Exception as e:
            self.logger.error("Error analyzing text: %s", e)
            return {"error": str(e)}

    def _calculate_agreement_score(self, image_results, text_results):
        """
        Calculate agreement score between image and text analyses.

        Starts from a neutral 0.5, moves by 0.25 depending on whether both
        modalities agree on normal/abnormal, adds 0.05 per shared known
        condition and up to 0.2 for the share of mentioned conditions that
        appear in both modalities.

        Args:
            image_results (dict): Results from image analysis
            text_results (dict): Results from text analysis

        Returns:
            float: Agreement score (0-1, where 1 is perfect agreement).
                0.5 is returned when either analysis failed or on any
                internal error.
        """
        try:
            # A failed analysis carries no evidence; without this guard two
            # failures would look like two "normal" studies agreeing.
            if self._has_error(image_results) or self._has_error(text_results):
                return 0.5

            agreement = 0.5

            image_abnormal = bool(image_results.get("has_abnormality", False))
            text_level = (text_results.get("severity") or {}).get("level", "Unknown")
            text_abnormal = text_level not in ("Normal", "Unknown")

            if image_abnormal == text_abnormal:
                agreement += 0.25
            else:
                agreement -= 0.25

            image_finding = (image_results.get("primary_finding") or "").lower()
            problem_text = " ".join(self._extract_problems(text_results)).lower()

            matching_conditions = 0
            total_mentioned = 0
            for condition in self._COMMON_CONDITIONS:
                in_image = condition in image_finding
                in_text = condition in problem_text

                if in_image or in_text:
                    total_mentioned += 1
                if in_image and in_text:
                    matching_conditions += 1
                    agreement += 0.05

            if total_mentioned > 0:
                agreement += (matching_conditions / total_mentioned) * 0.2

            return max(0.0, min(1.0, agreement))

        except Exception as e:
            self.logger.error("Error calculating agreement score: %s", e)
            return 0.5

    def _get_confidence_weighted_finding(self, image_results, text_results, agreement):
        """
        Get the most confident finding weighted by modality confidence.

        Decision order:
        1. No report problems: the image finding if its confidence exceeds
           0.7, otherwise "No significant findings".
        2. No image finding: the first report problem.
        3. One modality more confident by over 0.2: that modality's finding.
        4. Agreement above 0.7: a report problem contained in the image
           finding, else the image finding when its confidence exceeds 0.6,
           else the first report problem.
        5. Otherwise both findings are reported side by side.

        Args:
            image_results (dict): Results from image analysis
            text_results (dict): Results from text analysis
            agreement (float): Agreement score between modalities

        Returns:
            str: Most confident finding
        """
        try:
            image_finding = image_results.get("primary_finding") or ""
            image_confidence = image_results.get("confidence", 0.5)
            problems = self._extract_problems(text_results)
            text_confidence = (text_results.get("severity") or {}).get(
                "confidence", 0.5
            )

            if not problems:
                if image_confidence > 0.7 and image_finding:
                    return image_finding
                return "No significant findings"

            # From here on `problems` is non-empty. Never return an empty
            # image finding as the headline result.
            if not image_finding:
                return problems[0]

            if image_confidence > text_confidence + 0.2:
                return image_finding
            if text_confidence > image_confidence + 0.2:
                return problems[0]

            if agreement > 0.7:
                finding_lower = image_finding.lower()
                for problem in problems:
                    if problem.lower() in finding_lower:
                        return problem
                return image_finding if image_confidence > 0.6 else problems[0]

            # Modalities disagree and neither dominates: surface both.
            return f"{image_finding} (image) / {problems[0]} (report)"

        except Exception as e:
            self.logger.error("Error getting weighted finding: %s", e)
            return "Unable to determine primary finding"

    def _merge_followup_recommendations(self, image_results, text_results):
        """
        Merge follow-up recommendations from both modalities.

        Report recommendations come first, followed by at most one
        recommendation derived from an abnormal image finding. Near-duplicate
        entries (see ``_is_similar_recommendation``) are dropped, keeping the
        first occurrence.

        Args:
            image_results (dict): Results from image analysis
            text_results (dict): Results from text analysis

        Returns:
            list: Combined follow-up recommendations
        """
        try:
            text_recommendations = text_results.get("followup_recommendations") or []
            if isinstance(text_recommendations, str):
                text_recommendations = [text_recommendations]

            image_recommendations = []
            if image_results.get("has_abnormality", False):
                primary = image_results.get("primary_finding") or ""
                primary_lower = primary.lower()
                confidence = image_results.get("confidence", 0)

                if any(term in primary_lower for term in ("nodule", "mass", "tumor")):
                    image_recommendations.append(
                        f"Follow-up imaging recommended to further evaluate {primary}."
                    )
                elif "pneumonia" in primary_lower:
                    image_recommendations.append(
                        "Clinical correlation and follow-up imaging recommended."
                    )
                elif confidence > 0.8:
                    image_recommendations.append(
                        f"Consider follow-up imaging to monitor {primary}."
                    )
                elif confidence > 0.5:
                    image_recommendations.append(
                        "Consider clinical correlation and potential follow-up."
                    )

            unique_recommendations = []
            for rec in [*text_recommendations, *image_recommendations]:
                if not any(
                    self._is_similar_recommendation(rec, existing)
                    for existing in unique_recommendations
                ):
                    unique_recommendations.append(rec)

            return unique_recommendations

        except Exception as e:
            self.logger.error("Error merging follow-up recommendations: %s", e)
            return ["Follow-up recommended based on findings."]

    def _is_similar_recommendation(self, rec1, rec2):
        """
        Check if two recommendations are near-duplicates.

        Similarity is the Jaccard index of the two lowercase word sets, with
        punctuation stripped from word edges so that "recommended." and
        "recommended" count as the same word.

        Args:
            rec1 (str): First recommendation
            rec2 (str): Second recommendation

        Returns:
            bool: True when the word-set overlap exceeds 0.6
        """
        words1 = {w.strip(self._EDGE_PUNCTUATION) for w in rec1.lower().split()}
        words2 = {w.strip(self._EDGE_PUNCTUATION) for w in rec2.lower().split()}
        # Tokens that were pure punctuation collapse to "" and carry no meaning.
        words1.discard("")
        words2.discard("")

        union = words1 | words2
        if not union:
            return False

        return len(words1 & words2) / len(union) > 0.6

    def _get_final_severity(self, image_results, text_results, agreement):
        """
        Determine final severity based on both modalities.

        The image score is 0 for a normal image and 2.0 for an abnormal one,
        overridden by the first ``_IMAGE_SEVERITY`` key found in the primary
        finding. When agreement exceeds 0.7 the two scores are averaged;
        otherwise they are weighted by each modality's confidence. If one
        analysis failed, the other modality's score and confidence are used
        alone; if both failed the result is "Unknown".

        Args:
            image_results (dict): Results from image analysis
            text_results (dict): Results from text analysis
            agreement (float): Agreement score between modalities

        Returns:
            dict: Final severity assessment with ``level``, ``score`` and
                ``confidence``
        """
        try:
            image_failed = self._has_error(image_results)
            text_failed = self._has_error(text_results)
            if image_failed and text_failed:
                return {"level": "Unknown", "score": 0, "confidence": 0}

            text_severity = text_results.get("severity") or {}
            text_score = text_severity.get("score", 0)
            text_confidence = text_severity.get("confidence", 0.5)

            image_abnormal = image_results.get("has_abnormality", False)
            image_confidence = image_results.get("confidence", 0.5)

            image_score = 2.0 if image_abnormal else 0
            primary_finding = (image_results.get("primary_finding") or "").lower()
            # NOTE(review): the table is applied even when has_abnormality is
            # False, as in the original logic - confirm this is intended.
            for key, (_label, score) in self._IMAGE_SEVERITY.items():
                if key in primary_finding:
                    image_score = score
                    break

            if text_failed:
                # Do not dilute the surviving modality with a phantom 0 score.
                final_score, confidence = image_score, image_confidence
            elif image_failed:
                final_score, confidence = text_score, text_confidence
            else:
                confidence = (image_confidence + text_confidence) / 2
                total_confidence = image_confidence + text_confidence
                if agreement > 0.7 or total_confidence <= 0:
                    final_score = (image_score + text_score) / 2
                else:
                    image_weight = image_confidence / total_confidence
                    text_weight = text_confidence / total_confidence
                    final_score = (image_score * image_weight) + (
                        text_score * text_weight
                    )

            # Round to the nearest level with ties going down. Built-in
            # round() rounds ties to even (1.5 -> 2 but 2.5 -> 2), which
            # contradicts the table above where 1.5 is "Mild" and 2.5 is
            # "Moderate".
            clamped = min(4, max(0, final_score))
            level_index = int(clamped)
            if clamped - level_index > 0.5:
                level_index += 1

            return {
                "level": self._SEVERITY_LEVELS[level_index],
                "score": round(final_score, 1),
                "confidence": round(confidence, 2),
            }

        except Exception as e:
            self.logger.error("Error determining final severity: %s", e)
            return {"level": "Unknown", "score": 0, "confidence": 0}

    def fuse_analyses(self, image_results, text_results):
        """
        Fuse the results from image and text analyses.

        Args:
            image_results (dict): Results from image analysis
            text_results (dict): Results from text analysis

        Returns:
            dict: Fused analysis results with ``agreement_score``,
                ``primary_finding``, ``severity``, ``findings``,
                ``followup_recommendations`` and ``modality_results``. When
                both analyses failed, or fusion itself raises, a dict with
                ``error`` and ``modality_results`` is returned instead.
        """
        try:
            if self._has_error(image_results) and self._has_error(text_results):
                # Nothing to fuse; report the failure instead of a
                # normal-looking result built from defaults.
                self.logger.error("Both image and text analyses failed")
                return {
                    "error": (
                        f"Image analysis failed: {image_results['error']}; "
                        f"text analysis failed: {text_results['error']}"
                    ),
                    "modality_results": {"image": image_results, "text": text_results},
                }

            agreement = self._calculate_agreement_score(image_results, text_results)
            self.logger.info("Agreement score between modalities: %.2f", agreement)

            primary_finding = self._get_confidence_weighted_finding(
                image_results, text_results, agreement
            )
            followup = self._merge_followup_recommendations(image_results, text_results)
            severity = self._get_final_severity(image_results, text_results, agreement)

            # Report findings first; add the image finding only if no report
            # finding already contains it.
            findings = list(text_results.get("findings") or [])
            image_finding = image_results.get("primary_finding") or ""
            if image_finding and not any(
                image_finding.lower() in f.lower() for f in findings
            ):
                findings.append(f"Image finding: {image_finding}")

            return {
                "agreement_score": round(agreement, 2),
                "primary_finding": primary_finding,
                "severity": severity,
                "findings": findings,
                "followup_recommendations": followup,
                "modality_results": {"image": image_results, "text": text_results},
            }

        except Exception as e:
            self.logger.error("Error fusing analyses: %s", e)
            return {
                "error": str(e),
                "modality_results": {"image": image_results, "text": text_results},
            }

    def analyze(self, image_path, report_text):
        """
        Perform multimodal analysis of medical image and report.

        Args:
            image_path (str): Path to the medical image
            report_text (str): Medical report text

        Returns:
            dict: Fused analysis results (see ``fuse_analyses``), or
                ``{"error": message}`` on an unexpected failure.
        """
        try:
            image_results = self.analyze_image(image_path)
            text_results = self.analyze_text(report_text)
            return self.fuse_analyses(image_results, text_results)
        except Exception as e:
            self.logger.error("Error in multimodal analysis: %s", e)
            return {"error": str(e)}

    def get_explanation(self, fused_results):
        """
        Generate a human-readable explanation of the fused analysis.

        The output is Markdown: an overview, the detailed findings, the
        recommended follow-up and a closing confidence note. Missing keys are
        rendered as "Unknown" / 0%.

        Args:
            fused_results (dict): Results from the fused analysis

        Returns:
            str: A text explanation of the fused analysis
        """
        try:
            severity_info = fused_results.get("severity", {})
            primary_finding = fused_results.get("primary_finding", "Unknown")
            severity = severity_info.get("level", "Unknown")

            agreement = fused_results.get("agreement_score", 0)
            if agreement > 0.7:
                agreement_text = "High"
            elif agreement > 0.4:
                agreement_text = "Moderate"
            else:
                agreement_text = "Low"

            explanation = [
                "# Medical Analysis Summary\n",
                "## Overview\n",
                f"Primary finding: **{primary_finding}**\n",
                f"Severity level: **{severity}**\n",
                f"Image and text analysis agreement: **{agreement_text}** ({agreement:.0%})\n",
                "\n## Detailed Findings\n",
            ]

            findings = fused_results.get("findings", [])
            if findings:
                explanation.extend(f"- {finding}\n" for finding in findings)
            else:
                explanation.append("No specific findings detailed.\n")

            explanation.append("\n## Recommended Follow-up\n")
            followups = fused_results.get("followup_recommendations", [])
            if followups:
                explanation.extend(f"- {followup}\n" for followup in followups)
            else:
                explanation.append("No specific follow-up recommendations provided.\n")

            confidence = severity_info.get("confidence", 0)
            explanation.append(
                f"\n*Note: This analysis has a confidence level of {confidence:.0%}. "
                f"Please consult with healthcare professionals for official diagnosis.*"
            )

            # Entries already end in "\n"; joining with "\n" leaves a blank
            # line between Markdown blocks.
            return "\n".join(explanation)

        except Exception as e:
            self.logger.error("Error generating explanation: %s", e)
            return "Error generating analysis explanation."
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: fuse a sample image with a sample report when an image is
    # available under ../data/sample, otherwise run the text analysis alone.
    logging.basicConfig(level=logging.INFO)

    import os

    fusion = MultimodalFusion()

    sample_report = """
    CHEST X-RAY EXAMINATION

    CLINICAL HISTORY: 55-year-old male with cough and fever.

    FINDINGS: The heart size is at the upper limits of normal. The lungs are clear without focal consolidation,
    effusion, or pneumothorax. There is mild prominence of the pulmonary vasculature. No pleural effusion is seen.
    There is a small nodular opacity noted in the right lower lobe measuring approximately 8mm, which is suspicious
    and warrants further investigation. The mediastinum is unremarkable. The visualized bony structures show no acute abnormalities.

    IMPRESSION:
    1. Mild cardiomegaly.
    2. 8mm nodular opacity in the right lower lobe, recommend follow-up CT for further evaluation.
    3. No acute pulmonary parenchymal abnormality.

    RECOMMENDATIONS: Follow-up chest CT to further characterize the nodular opacity in the right lower lobe.
    """

    sample_dir = "../data/sample"
    # List the directory once, in a deterministic order, skipping hidden
    # files; isdir() avoids listdir() raising when the path is a plain file.
    sample_files = (
        [name for name in sorted(os.listdir(sample_dir)) if not name.startswith(".")]
        if os.path.isdir(sample_dir)
        else []
    )

    if sample_files:
        sample_image = os.path.join(sample_dir, sample_files[0])
        print(f"Analyzing sample image: {sample_image}")

        fused_results = fusion.analyze(sample_image, sample_report)
        explanation = fusion.get_explanation(fused_results)

        print("\nFused Analysis Results:")
        print(explanation)
    else:
        print("No sample images found. Only analyzing text report.")

        text_results = fusion.analyze_text(sample_report)

        if "error" in text_results:
            # analyze_text() reports failures as {"error": ...}; indexing
            # 'severity' on such a result would raise KeyError.
            print(f"\nText analysis failed: {text_results['error']}")
        else:
            severity = text_results.get("severity", {})
            print("\nText Analysis Results:")
            print(
                f"Severity: {severity.get('level', 'Unknown')} (Score: {severity.get('score', 0)})"
            )

            print("\nKey Findings:")
            for finding in text_results.get("findings", []):
                print(f"- {finding}")

            print("\nEntities:")
            for category, items in text_results.get("entities", {}).items():
                if items:
                    print(f"- {category.capitalize()}: {', '.join(items)}")

            print("\nFollow-up Recommendations:")
            for rec in text_results.get("followup_recommendations", []):
                print(f"- {rec}")
|
|
|