# File size: 8,153 Bytes
# 84f6785
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
from typing import Any, Dict, Optional, Tuple, Type
from pydantic import BaseModel, Field

import torch

from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool

from PIL import Image

from transformers import (
    BertTokenizer,
    ViTImageProcessor,
    VisionEncoderDecoderModel,
    GenerationConfig,
)


# Argument schema for the chest X-ray tool defined in this module, which
# references it through its ``args_schema`` attribute.
class ChestXRayInput(BaseModel):
    """Input for chest X-ray analysis tools. Only supports JPG or PNG images."""

    # Required field (``...`` means there is no default). The JPG/PNG
    # restriction is stated in the description text only; nothing in this
    # class validates the file extension or checks that the path exists.
    image_path: str = Field(
        ..., description="Path to the radiology image file, only supports JPG or PNG images"
    )


class ChestXRayReportGeneratorTool(BaseTool):
    """Tool that generates comprehensive chest X-ray reports with both findings and impressions.

    This tool uses two Vision-Encoder-Decoder models (ViT-BERT) trained on CheXpert
    and MIMIC-CXR datasets to generate structured radiology reports. It automatically
    generates both detailed findings and impression summaries for each chest X-ray,
    following standard radiological reporting format.

    The tool uses:
    - Findings model: Generates detailed observations of all visible structures
    - Impression model: Provides concise clinical interpretation and key diagnoses

    Each model has its own tokenizer and image processor, loaded from the same
    checkpoint as the model. All of them are loaded eagerly in ``__init__``.
    """

    name: str = "chest_xray_report_generator"
    description: str = (
        "A tool that analyzes chest X-ray images and generates comprehensive radiology reports "
        "containing both detailed findings and impression summaries. Input should be the path "
        "to a chest X-ray image file. Output is a structured report with both detailed "
        "observations and key clinical conclusions."
    )
    # Declared as a string, but ``__init__`` replaces it with a ``torch.device``.
    device: Optional[str] = "cuda"
    args_schema: Type[BaseModel] = ChestXRayInput
    # The fields below default to None at class level and are populated by ``__init__``.
    findings_model: Optional[VisionEncoderDecoderModel] = None
    impression_model: Optional[VisionEncoderDecoderModel] = None
    findings_tokenizer: Optional[BertTokenizer] = None
    impression_tokenizer: Optional[BertTokenizer] = None
    findings_processor: Optional[ViTImageProcessor] = None
    impression_processor: Optional[ViTImageProcessor] = None
    generation_args: Optional[Dict[str, Any]] = None

    def __init__(self, cache_dir: str = "/model-weights", device: Optional[str] = "cuda"):
        """Initialize the ChestXRayReportGeneratorTool with both findings and impression models.

        Args:
            cache_dir (str): Directory passed as ``cache_dir`` to every
                ``from_pretrained`` call.
            device (Optional[str]): Device specifier handed to ``torch.device``.
                A falsy value (``None`` or ``""``) selects ``"cuda"``.
        """
        super().__init__()
        # Always store a torch.device so the attribute has one consistent type,
        # whether or not a device was supplied.
        self.device = torch.device(device or "cuda")

        findings_checkpoint = "IAMJB/chexpert-mimic-cxr-findings-baseline"
        impression_checkpoint = "IAMJB/chexpert-mimic-cxr-impression-baseline"

        # Initialize findings model
        self.findings_model = VisionEncoderDecoderModel.from_pretrained(
            findings_checkpoint, cache_dir=cache_dir
        ).eval()
        self.findings_tokenizer = BertTokenizer.from_pretrained(
            findings_checkpoint, cache_dir=cache_dir
        )
        self.findings_processor = ViTImageProcessor.from_pretrained(
            findings_checkpoint, cache_dir=cache_dir
        )

        # Initialize impression model
        self.impression_model = VisionEncoderDecoderModel.from_pretrained(
            impression_checkpoint, cache_dir=cache_dir
        ).eval()
        self.impression_tokenizer = BertTokenizer.from_pretrained(
            impression_checkpoint, cache_dir=cache_dir
        )
        self.impression_processor = ViTImageProcessor.from_pretrained(
            impression_checkpoint, cache_dir=cache_dir
        )

        # Move models to device
        self.findings_model = self.findings_model.to(self.device)
        self.impression_model = self.impression_model.to(self.device)

        # Default generation arguments.
        # ``num_beams`` is the generation option that controls beam search.
        # The previous key, ``beam_width``, is not a recognized generation
        # option, so the requested 2-beam search was never applied.
        self.generation_args = {
            "num_return_sequences": 1,
            "max_length": 128,
            "use_cache": True,
            "num_beams": 2,
        }

    def _process_image(
        self, image_path: str, processor: ViTImageProcessor, model: VisionEncoderDecoderModel
    ) -> torch.Tensor:
        """Process the input image for a specific model.

        The image is opened, converted to RGB, run through ``processor``, and
        resized when its spatial dimensions differ from the encoder's
        configured ``image_size``. The result is moved to ``self.device``.

        Args:
            image_path (str): Path to the input image.
            processor: Image processor for the specific model.
            model: The model to process the image for.

        Returns:
            torch.Tensor: Processed image tensor ready for model input.
        """
        # The context manager releases the file handle as soon as the pixel
        # data has been read; ``convert`` returns an in-memory image that stays
        # valid after the file is closed.
        with Image.open(image_path) as source_image:
            image = source_image.convert("RGB")
        pixel_values = processor(image, return_tensors="pt").pixel_values

        expected_size = model.config.encoder.image_size
        # Accept either a single int (square input) or an explicit
        # (height, width) pair.
        if isinstance(expected_size, int):
            target_hw = (expected_size, expected_size)
        else:
            target_hw = tuple(expected_size)

        # Compare both spatial dimensions: checking only the width would let a
        # height mismatch through without resizing.
        if tuple(pixel_values.shape[-2:]) != target_hw:
            pixel_values = torch.nn.functional.interpolate(
                pixel_values,
                size=target_hw,
                mode="bilinear",
                align_corners=False,
            )

        pixel_values = pixel_values.to(self.device)

        return pixel_values

    def _generate_report_section(
        self, pixel_values: torch.Tensor, model: VisionEncoderDecoderModel, tokenizer: BertTokenizer
    ) -> str:
        """Generate a report section using the specified model.

        Args:
            pixel_values: Processed image tensor.
            model: The model to use for generation.
            tokenizer: The tokenizer for the model.

        Returns:
            str: Generated text for the report section.
        """
        # Combine the shared defaults with the special-token ids of this
        # particular model; decoding starts from the tokenizer's [CLS] token.
        generation_config = GenerationConfig(
            **{
                **self.generation_args,
                "bos_token_id": model.config.bos_token_id,
                "eos_token_id": model.config.eos_token_id,
                "pad_token_id": model.config.pad_token_id,
                "decoder_start_token_id": tokenizer.cls_token_id,
            }
        )

        generated_ids = model.generate(pixel_values, generation_config=generation_config)

        # ``num_return_sequences`` defaults to 1, so the first decoded string
        # is the report section.
        return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

    def _run(
        self,
        image_path: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Generate a comprehensive chest X-ray report containing both findings and impression.

        Any exception raised while loading the image or generating text is
        caught and reported through the return value instead of propagating.

        Args:
            image_path (str): The path to the chest X-ray image file.
            run_manager (Optional[CallbackManagerForToolRun]): The callback manager.
                Accepted but not used by this method.

        Returns:
            Tuple[str, Dict]: A tuple containing the complete report and metadata.
                On failure the first element is an error message and the
                metadata has ``analysis_status`` set to ``"failed"`` plus an
                ``error`` entry.
        """
        try:
            # Process image for both models (each model has its own processor)
            findings_pixels = self._process_image(
                image_path, self.findings_processor, self.findings_model
            )
            impression_pixels = self._process_image(
                image_path, self.impression_processor, self.impression_model
            )

            # Generate both sections
            with torch.inference_mode():
                findings_text = self._generate_report_section(
                    findings_pixels, self.findings_model, self.findings_tokenizer
                )
                impression_text = self._generate_report_section(
                    impression_pixels, self.impression_model, self.impression_tokenizer
                )

            # Combine into formatted report
            report = (
                "CHEST X-RAY REPORT\n\n"
                f"FINDINGS:\n{findings_text}\n\n"
                f"IMPRESSION:\n{impression_text}"
            )

            metadata = {
                "image_path": image_path,
                "analysis_status": "completed",
                "sections_generated": ["findings", "impression"],
            }

            return report, metadata

        except Exception as e:
            # Tool boundary: report the failure as data rather than raising.
            return f"Error generating report: {str(e)}", {
                "image_path": image_path,
                "analysis_status": "failed",
                "error": str(e),
            }

    async def _arun(
        self,
        image_path: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Asynchronously generate a comprehensive chest X-ray report.

        Delegates directly to ``_run``; ``run_manager`` is not forwarded.

        NOTE(review): ``_run`` is called synchronously here, so the calling
        coroutine does not yield while the report is generated. Offloading to
        an executor would allow concurrent use of the shared models - confirm
        that is acceptable before changing it.
        """
        return self._run(image_path)