Resolved the PermissionError
Browse files
medrax/tools/report_generation.py
CHANGED
@@ -2,6 +2,7 @@ from typing import Any, Dict, Optional, Tuple, Type
|
|
2 |
from pydantic import BaseModel, Field
|
3 |
|
4 |
import torch
|
|
|
5 |
|
6 |
from langchain_core.callbacks import (
|
7 |
AsyncCallbackManagerForToolRun,
|
@@ -28,18 +29,6 @@ class ChestXRayInput(BaseModel):
|
|
28 |
|
29 |
|
30 |
class ChestXRayReportGeneratorTool(BaseTool):
|
31 |
-
"""Tool that generates comprehensive chest X-ray reports with both findings and impressions.
|
32 |
-
|
33 |
-
This tool uses two Vision-Encoder-Decoder models (ViT-BERT) trained on CheXpert
|
34 |
-
and MIMIC-CXR datasets to generate structured radiology reports. It automatically
|
35 |
-
generates both detailed findings and impression summaries for each chest X-ray,
|
36 |
-
following standard radiological reporting format.
|
37 |
-
|
38 |
-
The tool uses:
|
39 |
-
- Findings model: Generates detailed observations of all visible structures
|
40 |
-
- Impression model: Provides concise clinical interpretation and key diagnoses
|
41 |
-
"""
|
42 |
-
|
43 |
name: str = "chest_xray_report_generator"
|
44 |
description: str = (
|
45 |
"A tool that analyzes chest X-ray images and generates comprehensive radiology reports "
|
@@ -47,7 +36,7 @@ class ChestXRayReportGeneratorTool(BaseTool):
|
|
47 |
"to a chest X-ray image file. Output is a structured report with both detailed "
|
48 |
"observations and key clinical conclusions."
|
49 |
)
|
50 |
-
device: Optional[str] = "cpu"
|
51 |
args_schema: Type[BaseModel] = ChestXRayInput
|
52 |
findings_model: VisionEncoderDecoderModel = None
|
53 |
impression_model: VisionEncoderDecoderModel = None
|
@@ -57,12 +46,12 @@ class ChestXRayReportGeneratorTool(BaseTool):
|
|
57 |
impression_processor: ViTImageProcessor = None
|
58 |
generation_args: Dict[str, Any] = None
|
59 |
|
60 |
-
def __init__(self, cache_dir: str = "
|
61 |
-
"""Initialize the ChestXRayReportGeneratorTool with both findings and impression models."""
|
62 |
super().__init__()
|
63 |
-
|
|
|
64 |
|
65 |
-
#
|
66 |
self.findings_model = VisionEncoderDecoderModel.from_pretrained(
|
67 |
"IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
|
68 |
).eval()
|
@@ -73,7 +62,7 @@ class ChestXRayReportGeneratorTool(BaseTool):
|
|
73 |
"IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
|
74 |
)
|
75 |
|
76 |
-
#
|
77 |
self.impression_model = VisionEncoderDecoderModel.from_pretrained(
|
78 |
"IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
|
79 |
).eval()
|
@@ -84,11 +73,10 @@ class ChestXRayReportGeneratorTool(BaseTool):
|
|
84 |
"IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
|
85 |
)
|
86 |
|
87 |
-
# Move models to
|
88 |
self.findings_model = self.findings_model.to(self.device)
|
89 |
self.impression_model = self.impression_model.to(self.device)
|
90 |
|
91 |
-
# Default generation arguments
|
92 |
self.generation_args = {
|
93 |
"num_return_sequences": 1,
|
94 |
"max_length": 128,
|
@@ -99,19 +87,8 @@ class ChestXRayReportGeneratorTool(BaseTool):
|
|
99 |
def _process_image(
|
100 |
self, image_path: str, processor: ViTImageProcessor, model: VisionEncoderDecoderModel
|
101 |
) -> torch.Tensor:
|
102 |
-
"""Process the input image for a specific model.
|
103 |
-
|
104 |
-
Args:
|
105 |
-
image_path (str): Path to the input image.
|
106 |
-
processor: Image processor for the specific model.
|
107 |
-
model: The model to process the image for.
|
108 |
-
|
109 |
-
Returns:
|
110 |
-
torch.Tensor: Processed image tensor ready for model input.
|
111 |
-
"""
|
112 |
image = Image.open(image_path).convert("RGB")
|
113 |
pixel_values = processor(image, return_tensors="pt").pixel_values
|
114 |
-
|
115 |
expected_size = model.config.encoder.image_size
|
116 |
actual_size = pixel_values.shape[-1]
|
117 |
|
@@ -123,23 +100,11 @@ class ChestXRayReportGeneratorTool(BaseTool):
|
|
123 |
align_corners=False,
|
124 |
)
|
125 |
|
126 |
-
|
127 |
-
|
128 |
-
return pixel_values
|
129 |
|
130 |
def _generate_report_section(
|
131 |
self, pixel_values: torch.Tensor, model: VisionEncoderDecoderModel, tokenizer: BertTokenizer
|
132 |
) -> str:
|
133 |
-
"""Generate a report section using the specified model.
|
134 |
-
|
135 |
-
Args:
|
136 |
-
pixel_values: Processed image tensor.
|
137 |
-
model: The model to use for generation.
|
138 |
-
tokenizer: The tokenizer for the model.
|
139 |
-
|
140 |
-
Returns:
|
141 |
-
str: Generated text for the report section.
|
142 |
-
"""
|
143 |
generation_config = GenerationConfig(
|
144 |
**{
|
145 |
**self.generation_args,
|
@@ -149,9 +114,7 @@ class ChestXRayReportGeneratorTool(BaseTool):
|
|
149 |
"decoder_start_token_id": tokenizer.cls_token_id,
|
150 |
}
|
151 |
)
|
152 |
-
|
153 |
generated_ids = model.generate(pixel_values, generation_config=generation_config)
|
154 |
-
|
155 |
return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
156 |
|
157 |
def _run(
|
@@ -159,17 +122,7 @@ class ChestXRayReportGeneratorTool(BaseTool):
|
|
159 |
image_path: str,
|
160 |
run_manager: Optional[CallbackManagerForToolRun] = None,
|
161 |
) -> Tuple[str, Dict]:
|
162 |
-
"""Generate a comprehensive chest X-ray report containing both findings and impression.
|
163 |
-
|
164 |
-
Args:
|
165 |
-
image_path (str): The path to the chest X-ray image file.
|
166 |
-
run_manager (Optional[CallbackManagerForToolRun]): The callback manager.
|
167 |
-
|
168 |
-
Returns:
|
169 |
-
Tuple[str, Dict]: A tuple containing the complete report and metadata.
|
170 |
-
"""
|
171 |
try:
|
172 |
-
# Process image for both models
|
173 |
findings_pixels = self._process_image(
|
174 |
image_path, self.findings_processor, self.findings_model
|
175 |
)
|
@@ -177,7 +130,6 @@ class ChestXRayReportGeneratorTool(BaseTool):
|
|
177 |
image_path, self.impression_processor, self.impression_model
|
178 |
)
|
179 |
|
180 |
-
# Generate both sections
|
181 |
with torch.inference_mode():
|
182 |
findings_text = self._generate_report_section(
|
183 |
findings_pixels, self.findings_model, self.findings_tokenizer
|
@@ -186,19 +138,16 @@ class ChestXRayReportGeneratorTool(BaseTool):
|
|
186 |
impression_pixels, self.impression_model, self.impression_tokenizer
|
187 |
)
|
188 |
|
189 |
-
# Combine into formatted report
|
190 |
report = (
|
191 |
"CHEST X-RAY REPORT\n\n"
|
192 |
f"FINDINGS:\n{findings_text}\n\n"
|
193 |
f"IMPRESSION:\n{impression_text}"
|
194 |
)
|
195 |
-
|
196 |
metadata = {
|
197 |
"image_path": image_path,
|
198 |
"analysis_status": "completed",
|
199 |
"sections_generated": ["findings", "impression"],
|
200 |
}
|
201 |
-
|
202 |
return report, metadata
|
203 |
|
204 |
except Exception as e:
|
@@ -213,5 +162,4 @@ class ChestXRayReportGeneratorTool(BaseTool):
|
|
213 |
image_path: str,
|
214 |
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
215 |
) -> Tuple[str, Dict]:
|
216 |
-
"""Asynchronously generate a comprehensive chest X-ray report."""
|
217 |
return self._run(image_path)
|
|
|
2 |
from pydantic import BaseModel, Field
|
3 |
|
4 |
import torch
|
5 |
+
import os # Added to create local cache dir
|
6 |
|
7 |
from langchain_core.callbacks import (
|
8 |
AsyncCallbackManagerForToolRun,
|
|
|
29 |
|
30 |
|
31 |
class ChestXRayReportGeneratorTool(BaseTool):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
name: str = "chest_xray_report_generator"
|
33 |
description: str = (
|
34 |
"A tool that analyzes chest X-ray images and generates comprehensive radiology reports "
|
|
|
36 |
"to a chest X-ray image file. Output is a structured report with both detailed "
|
37 |
"observations and key clinical conclusions."
|
38 |
)
|
39 |
+
device: Optional[str] = "cpu"
|
40 |
args_schema: Type[BaseModel] = ChestXRayInput
|
41 |
findings_model: VisionEncoderDecoderModel = None
|
42 |
impression_model: VisionEncoderDecoderModel = None
|
|
|
46 |
impression_processor: ViTImageProcessor = None
|
47 |
generation_args: Dict[str, Any] = None
|
48 |
|
49 |
+
def __init__(self, cache_dir: str = "./model_weights", device: Optional[str] = "cpu"):
|
|
|
50 |
super().__init__()
|
51 |
+
os.makedirs(cache_dir, exist_ok=True) # ✅ Ensure local folder exists
|
52 |
+
self.device = torch.device(device) if device else torch.device("cpu")
|
53 |
|
54 |
+
# Load findings model
|
55 |
self.findings_model = VisionEncoderDecoderModel.from_pretrained(
|
56 |
"IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
|
57 |
).eval()
|
|
|
62 |
"IAMJB/chexpert-mimic-cxr-findings-baseline", cache_dir=cache_dir
|
63 |
)
|
64 |
|
65 |
+
# Load impression model
|
66 |
self.impression_model = VisionEncoderDecoderModel.from_pretrained(
|
67 |
"IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
|
68 |
).eval()
|
|
|
73 |
"IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
|
74 |
)
|
75 |
|
76 |
+
# Move models to the configured device (defaults to CPU)
|
77 |
self.findings_model = self.findings_model.to(self.device)
|
78 |
self.impression_model = self.impression_model.to(self.device)
|
79 |
|
|
|
80 |
self.generation_args = {
|
81 |
"num_return_sequences": 1,
|
82 |
"max_length": 128,
|
|
|
87 |
def _process_image(
|
88 |
self, image_path: str, processor: ViTImageProcessor, model: VisionEncoderDecoderModel
|
89 |
) -> torch.Tensor:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
90 |
image = Image.open(image_path).convert("RGB")
|
91 |
pixel_values = processor(image, return_tensors="pt").pixel_values
|
|
|
92 |
expected_size = model.config.encoder.image_size
|
93 |
actual_size = pixel_values.shape[-1]
|
94 |
|
|
|
100 |
align_corners=False,
|
101 |
)
|
102 |
|
103 |
+
return pixel_values.to(self.device)
|
|
|
|
|
104 |
|
105 |
def _generate_report_section(
|
106 |
self, pixel_values: torch.Tensor, model: VisionEncoderDecoderModel, tokenizer: BertTokenizer
|
107 |
) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
generation_config = GenerationConfig(
|
109 |
**{
|
110 |
**self.generation_args,
|
|
|
114 |
"decoder_start_token_id": tokenizer.cls_token_id,
|
115 |
}
|
116 |
)
|
|
|
117 |
generated_ids = model.generate(pixel_values, generation_config=generation_config)
|
|
|
118 |
return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
119 |
|
120 |
def _run(
|
|
|
122 |
image_path: str,
|
123 |
run_manager: Optional[CallbackManagerForToolRun] = None,
|
124 |
) -> Tuple[str, Dict]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
try:
|
|
|
126 |
findings_pixels = self._process_image(
|
127 |
image_path, self.findings_processor, self.findings_model
|
128 |
)
|
|
|
130 |
image_path, self.impression_processor, self.impression_model
|
131 |
)
|
132 |
|
|
|
133 |
with torch.inference_mode():
|
134 |
findings_text = self._generate_report_section(
|
135 |
findings_pixels, self.findings_model, self.findings_tokenizer
|
|
|
138 |
impression_pixels, self.impression_model, self.impression_tokenizer
|
139 |
)
|
140 |
|
|
|
141 |
report = (
|
142 |
"CHEST X-RAY REPORT\n\n"
|
143 |
f"FINDINGS:\n{findings_text}\n\n"
|
144 |
f"IMPRESSION:\n{impression_text}"
|
145 |
)
|
|
|
146 |
metadata = {
|
147 |
"image_path": image_path,
|
148 |
"analysis_status": "completed",
|
149 |
"sections_generated": ["findings", "impression"],
|
150 |
}
|
|
|
151 |
return report, metadata
|
152 |
|
153 |
except Exception as e:
|
|
|
162 |
image_path: str,
|
163 |
run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
|
164 |
) -> Tuple[str, Dict]:
|
|
|
165 |
return self._run(image_path)
|