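"""LangChain tool for generating chest X-ray radiology reports.

Wraps two ViT-BERT Vision-Encoder-Decoder models (a findings model and an
impression model) trained on the CheXpert and MIMIC-CXR datasets.
"""
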
from typing import Any, Dict, Optional, Tuple, Type

import torch
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from PIL import Image
from pydantic import BaseModel, Field
from transformers import (
    BertTokenizer,
    GenerationConfig,
    ViTImageProcessor,
    VisionEncoderDecoderModel,
)


class ChestXRayInput(BaseModel):
    """Input for chest X-ray analysis tools. Only supports JPG or PNG images."""

    image_path: str = Field(
        ..., description="Path to the radiology image file, only supports JPG or PNG images"
    )


class ChestXRayReportGeneratorTool(BaseTool):
    """Tool that generates comprehensive chest X-ray reports with both findings and impressions.

    This tool uses two Vision-Encoder-Decoder models (ViT-BERT) trained on the CheXpert
    and MIMIC-CXR datasets to generate structured radiology reports. It automatically
    generates both detailed findings and impression summaries for each chest X-ray,
    following standard radiological reporting format.

    The tool uses:
    - Findings model: generates detailed observations of all visible structures
    - Impression model: provides concise clinical interpretation and key diagnoses
    """

    name: str = "chest_xray_report_generator"
    description: str = (
        "A tool that analyzes chest X-ray images and generates comprehensive radiology reports "
        "containing both detailed findings and impression summaries. Input should be the path "
        "to a chest X-ray image file. Output is a structured report with both detailed "
        "observations and key clinical conclusions."
    )
    device: Optional[str] = "cuda"
    args_schema: Type[BaseModel] = ChestXRayInput
    findings_model: VisionEncoderDecoderModel = None
    impression_model: VisionEncoderDecoderModel = None
    findings_tokenizer: BertTokenizer = None
    impression_tokenizer: BertTokenizer = None
    findings_processor: ViTImageProcessor = None
    impression_processor: ViTImageProcessor = None
    generation_args: Dict[str, Any] = None

    def __init__(self, cache_dir: str = "/model-weights", device: Optional[str] = "cuda"):
        """Initialize the ChestXRayReportGeneratorTool with both findings and impression models."""
        super().__init__()
        self.device = torch.device(device) if device else torch.device("cuda")

        # Initialize findings model
        self.findings_model = VisionEncoderDecoderModel.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
        ).eval()
        self.findings_tokenizer = BertTokenizer.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
        )
        self.findings_processor = ViTImageProcessor.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
        )

        # Initialize impression model
        self.impression_model = VisionEncoderDecoderModel.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
        ).eval()
        self.impression_tokenizer = BertTokenizer.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
        )
        self.impression_processor = ViTImageProcessor.from_pretrained(
            "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
        )

        # Move models to device
        self.findings_model = self.findings_model.to(self.device)
        self.impression_model = self.impression_model.to(self.device)

        # Default generation arguments ("num_beams" enables beam search with 2 beams)
        self.generation_args = {
            "num_return_sequences": 1,
            "max_length": 128,
            "use_cache": True,
            "num_beams": 2,
        }

    def _process_image(
        self, image_path: str, processor: ViTImageProcessor, model: VisionEncoderDecoderModel
    ) -> torch.Tensor:
        """Process the input image for a specific model.

        Args:
            image_path (str): Path to the input image.
            processor: Image processor for the specific model.
            model: The model to process the image for.

        Returns:
            torch.Tensor: Processed image tensor ready for model input.
        """
        image = Image.open(image_path).convert("RGB")
        pixel_values = processor(image, return_tensors="pt").pixel_values

        # Resize if the processor output does not match the encoder's expected input size
        expected_size = model.config.encoder.image_size
        actual_size = pixel_values.shape[-1]
        if expected_size != actual_size:
            pixel_values = torch.nn.functional.interpolate(
                pixel_values,
                size=(expected_size, expected_size),
                mode="bilinear",
                align_corners=False,
            )

        return pixel_values.to(self.device)

    def _generate_report_section(
        self, pixel_values: torch.Tensor, model: VisionEncoderDecoderModel, tokenizer: BertTokenizer
    ) -> str:
        """Generate a report section using the specified model.

        Args:
            pixel_values: Processed image tensor.
            model: The model to use for generation.
            tokenizer: The tokenizer for the model.

        Returns:
            str: Generated text for the report section.
        """
        # Merge the default generation arguments with the model-specific special token ids
        generation_config = GenerationConfig(
            **{
                **self.generation_args,
                "bos_token_id": model.config.bos_token_id,
                "eos_token_id": model.config.eos_token_id,
                "pad_token_id": model.config.pad_token_id,
                "decoder_start_token_id": tokenizer.cls_token_id,
            }
        )

        generated_ids = model.generate(pixel_values, generation_config=generation_config)
        return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def _run(
        self,
        image_path: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Generate a comprehensive chest X-ray report containing both findings and impression.

        Args:
            image_path (str): The path to the chest X-ray image file.
            run_manager (Optional[CallbackManagerForToolRun]): The callback manager.

        Returns:
            Tuple[str, Dict]: A tuple containing the complete report and metadata.
        """
        try:
            # Process image for both models
            findings_pixels = self._process_image(
                image_path, self.findings_processor, self.findings_model
            )
            impression_pixels = self._process_image(
                image_path, self.impression_processor, self.impression_model
            )

            # Generate both sections
            with torch.inference_mode():
                findings_text = self._generate_report_section(
                    findings_pixels, self.findings_model, self.findings_tokenizer
                )
                impression_text = self._generate_report_section(
                    impression_pixels, self.impression_model, self.impression_tokenizer
                )

            # Combine into formatted report
            report = (
                "CHEST X-RAY REPORT\n\n"
                f"FINDINGS:\n{findings_text}\n\n"
                f"IMPRESSION:\n{impression_text}"
            )

            metadata = {
                "image_path": image_path,
                "analysis_status": "completed",
                "sections_generated": ["findings", "impression"],
            }
            return report, metadata

        except Exception as e:
            return f"Error generating report: {str(e)}", {
                "image_path": image_path,
                "analysis_status": "failed",
                "error": str(e),
            }

    async def _arun(
        self,
        image_path: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Asynchronously generate a comprehensive chest X-ray report.

        This currently delegates to the synchronous implementation.
        """
        return self._run(image_path)
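

if __name__ == "__main__":
    # Example usage: a minimal sketch, not part of the original tool definition.
    # It assumes the IAMJB model weights can be downloaded or are already cached
    # under cache_dir, and "example_chest_xray.png" is a placeholder image path.
    tool = ChestXRayReportGeneratorTool(cache_dir="/model-weights", device="cuda")
    report, metadata = tool._run(image_path="example_chest_xray.png")
    print(report)
    print(metadata)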