from typing import Any, Dict, Optional, Tuple, Type
from pydantic import BaseModel, Field
import torch
from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from PIL import Image
from transformers import (
BertTokenizer,
ViTImageProcessor,
VisionEncoderDecoderModel,
GenerationConfig,
)
class ChestXRayInput(BaseModel):
"""Input for chest X-ray analysis tools. Only supports JPG or PNG images."""
image_path: str = Field(
..., description="Path to the radiology image file, only supports JPG or PNG images"
)
class ChestXRayReportGeneratorTool(BaseTool):
"""Tool that generates comprehensive chest X-ray reports with both findings and impressions.
This tool uses two Vision-Encoder-Decoder models (ViT-BERT) trained on CheXpert
and MIMIC-CXR datasets to generate structured radiology reports. It automatically
generates both detailed findings and impression summaries for each chest X-ray,
following standard radiological reporting format.
The tool uses:
- Findings model: Generates detailed observations of all visible structures
- Impression model: Provides concise clinical interpretation and key diagnoses
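
    Example (illustrative sketch only; assumes the pretrained weights can be
    fetched or are already cached, and that a chest X-ray image exists at the
    placeholder path):
        tool = ChestXRayReportGeneratorTool(device="cpu")
        report, metadata = tool._run("path/to/chest_xray.png")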
"""
name: str = "chest_xray_report_generator"
description: str = (
"A tool that analyzes chest X-ray images and generates comprehensive radiology reports "
"containing both detailed findings and impression summaries. Input should be the path "
"to a chest X-ray image file. Output is a structured report with both detailed "
"observations and key clinical conclusions."
)
device: Optional[str] = "cuda"
args_schema: Type[BaseModel] = ChestXRayInput
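    # Model, tokenizer, and image-processor handles; populated in __init__ once the
    # pretrained weights have been loaded.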
    findings_model: Optional[VisionEncoderDecoderModel] = None
    impression_model: Optional[VisionEncoderDecoderModel] = None
    findings_tokenizer: Optional[BertTokenizer] = None
    impression_tokenizer: Optional[BertTokenizer] = None
    findings_processor: Optional[ViTImageProcessor] = None
    impression_processor: Optional[ViTImageProcessor] = None
    generation_args: Optional[Dict[str, Any]] = None
def __init__(self, cache_dir: str = "/model-weights", device: Optional[str] = "cuda"):
"""Initialize the ChestXRayReportGeneratorTool with both findings and impression models."""
super().__init__()
        # Fall back to CUDA when no device is specified; wrap the result in torch.device either way.
        self.device = torch.device(device) if device else torch.device("cuda")
# Initialize findings model
self.findings_model = VisionEncoderDecoderModel.from_pretrained(
"IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
).eval()
self.findings_tokenizer = BertTokenizer.from_pretrained(
"IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
)
self.findings_processor = ViTImageProcessor.from_pretrained(
"IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
)
# Initialize impression model
self.impression_model = VisionEncoderDecoderModel.from_pretrained(
"IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
).eval()
self.impression_tokenizer = BertTokenizer.from_pretrained(
"IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
)
self.impression_processor = ViTImageProcessor.from_pretrained(
"IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
)
# Move models to device
self.findings_model = self.findings_model.to(self.device)
self.impression_model = self.impression_model.to(self.device)
# Default generation arguments
self.generation_args = {
"num_return_sequences": 1,
"max_length": 128,
"use_cache": True,
"beam_width": 2,
}
def _process_image(
self, image_path: str, processor: ViTImageProcessor, model: VisionEncoderDecoderModel
) -> torch.Tensor:
"""Process the input image for a specific model.
Args:
image_path (str): Path to the input image.
processor: Image processor for the specific model.
model: The model to process the image for.
Returns:
torch.Tensor: Processed image tensor ready for model input.
"""
image = Image.open(image_path).convert("RGB")
pixel_values = processor(image, return_tensors="pt").pixel_values
expected_size = model.config.encoder.image_size
actual_size = pixel_values.shape[-1]
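        # Resize to the encoder's expected input resolution if the processor's
        # output size differs from what the model was trained on.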
if expected_size != actual_size:
pixel_values = torch.nn.functional.interpolate(
pixel_values,
size=(expected_size, expected_size),
mode="bilinear",
align_corners=False,
)
pixel_values = pixel_values.to(self.device)
return pixel_values
def _generate_report_section(
self, pixel_values: torch.Tensor, model: VisionEncoderDecoderModel, tokenizer: BertTokenizer
) -> str:
"""Generate a report section using the specified model.
Args:
pixel_values: Processed image tensor.
model: The model to use for generation.
tokenizer: The tokenizer for the model.
Returns:
str: Generated text for the report section.
"""
generation_config = GenerationConfig(
**{
**self.generation_args,
"bos_token_id": model.config.bos_token_id,
"eos_token_id": model.config.eos_token_id,
"pad_token_id": model.config.pad_token_id,
"decoder_start_token_id": tokenizer.cls_token_id,
}
)
generated_ids = model.generate(pixel_values, generation_config=generation_config)
return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
def _run(
self,
image_path: str,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> Tuple[str, Dict]:
"""Generate a comprehensive chest X-ray report containing both findings and impression.
Args:
image_path (str): The path to the chest X-ray image file.
run_manager (Optional[CallbackManagerForToolRun]): The callback manager.
Returns:
Tuple[str, Dict]: A tuple containing the complete report and metadata.
"""
try:
# Process image for both models
findings_pixels = self._process_image(
image_path, self.findings_processor, self.findings_model
)
impression_pixels = self._process_image(
image_path, self.impression_processor, self.impression_model
)
# Generate both sections
with torch.inference_mode():
findings_text = self._generate_report_section(
findings_pixels, self.findings_model, self.findings_tokenizer
)
impression_text = self._generate_report_section(
impression_pixels, self.impression_model, self.impression_tokenizer
)
# Combine into formatted report
report = (
"CHEST X-RAY REPORT\n\n"
f"FINDINGS:\n{findings_text}\n\n"
f"IMPRESSION:\n{impression_text}"
)
metadata = {
"image_path": image_path,
"analysis_status": "completed",
"sections_generated": ["findings", "impression"],
}
return report, metadata
except Exception as e:
return f"Error generating report: {str(e)}", {
"image_path": image_path,
"analysis_status": "failed",
"error": str(e),
}
async def _arun(
self,
image_path: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> Tuple[str, Dict]:
"""Asynchronously generate a comprehensive chest X-ray report."""
return self._run(image_path)
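

if __name__ == "__main__":
    # Minimal usage sketch (not part of the tool itself): instantiate the tool and
    # generate a report for a local chest X-ray. The image path below is a placeholder,
    # the default cache directory is assumed to be writable, and CUDA is used only
    # when it is actually available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tool = ChestXRayReportGeneratorTool(device=device)
    report, metadata = tool._run("path/to/chest_xray.png")
    print(report)
    print(metadata)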