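"""LangChain tool that generates synthetic chest X-ray images from text
descriptions using a fine-tuned Stable Diffusion pipeline."""
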
from typing import Dict, Optional, Tuple, Type
from pathlib import Path
import uuid
import tempfile

import torch
from pydantic import BaseModel, Field
from diffusers import StableDiffusionPipeline
from langchain_core.callbacks import AsyncCallbackManagerForToolRun, CallbackManagerForToolRun
from langchain_core.tools import BaseTool


class ChestXRayGeneratorInput(BaseModel):
    """Input schema for the Chest X-Ray Generator Tool."""

    prompt: str = Field(
        ...,
        description="Description of the medical condition to generate (e.g., 'big left-sided pleural effusion')"
    )
    height: int = Field(
        512,
        description="Height of generated image in pixels"
    )
    width: int = Field(
        512,
        description="Width of generated image in pixels"
    )
    num_inference_steps: int = Field(
        75,
        description="Number of denoising steps (higher = better quality but slower)"
    )
    guidance_scale: float = Field(
        4.0,
        description="How closely to follow the prompt (higher = more faithful but less diverse)"
    )


class ChestXRayGeneratorTool(BaseTool):
    """Tool for generating synthetic chest X-ray images using a fine-tuned Stable Diffusion model."""

    name: str = "chest_xray_generator"
    description: str = (
        "Generates synthetic chest X-ray images from text descriptions of medical conditions. "
        "Input: Text description of the medical finding or condition to generate, "
        "along with optional parameters for image size (height, width), "
        "quality (num_inference_steps), and prompt adherence (guidance_scale). "
        "Output: Path to the generated X-ray image and generation metadata."
    )
    args_schema: Type[BaseModel] = ChestXRayGeneratorInput
    model: Optional[StableDiffusionPipeline] = None
    device: Optional[torch.device] = None
    temp_dir: Optional[Path] = None

    def __init__(
        self,
        model_path: str = "/model-weights/roentgen",
        cache_dir: str = "/model-weights",
        temp_dir: Optional[str] = None,
        device: Optional[str] = "cuda",
    ):
        """Initialize the chest X-ray generator tool."""
        super().__init__()
        # Fall back to CUDA when no device is specified; always store a torch.device.
        self.device = torch.device(device if device else "cuda")
        self.model = StableDiffusionPipeline.from_pretrained(model_path, cache_dir=cache_dir)
        self.model = self.model.to(torch.float32).to(self.device)
        self.temp_dir = Path(temp_dir if temp_dir else tempfile.mkdtemp())
        self.temp_dir.mkdir(parents=True, exist_ok=True)

    def _run(
        self,
        prompt: str,
        num_inference_steps: int = 75,
        guidance_scale: float = 4.0,
        height: int = 512,
        width: int = 512,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, str], Dict]:
        """Generate a chest X-ray image from a text description.

        Args:
            prompt: Text description of the medical condition to generate
            num_inference_steps: Number of denoising steps
            guidance_scale: How closely to follow the prompt
            height: Height of generated image in pixels
            width: Width of generated image in pixels
            run_manager: Optional callback manager

        Returns:
            Tuple[Dict, Dict]: Output dictionary with image path and metadata dictionary
        """
        try:
            # Generate image
            generation_output = self.model(
                [prompt],
                num_inference_steps=num_inference_steps,
                height=height,
                width=width,
                guidance_scale=guidance_scale,
            )

            # Save generated image
            image_path = self.temp_dir / f"generated_xray_{uuid.uuid4().hex[:8]}.png"
            generation_output.images[0].save(image_path)

            output = {
                "image_path": str(image_path),
            }
            metadata = {
                "prompt": prompt,
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
                "device": str(self.device),
                "image_size": (height, width),
                "analysis_status": "completed",
            }
            return output, metadata
        except Exception as e:
            return (
                {"error": str(e)},
                {
                    "prompt": prompt,
                    "analysis_status": "failed",
                    "error_details": str(e),
                },
            )

    async def _arun(
        self,
        prompt: str,
        num_inference_steps: int = 75,
        guidance_scale: float = 4.0,
        height: int = 512,
        width: int = 512,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[Dict[str, str], Dict]:
        """Async version of _run."""
        return self._run(prompt, num_inference_steps, guidance_scale, height, width)
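

# --- Usage sketch (illustrative only) ---
# A minimal example of how this tool might be instantiated and exercised,
# assuming the fine-tuned weights live at the default /model-weights/roentgen
# path and a CUDA-capable GPU is available. In a LangChain agent the tool
# would normally be called through its standard run/invoke interface rather
# than _run directly.
if __name__ == "__main__":
    tool = ChestXRayGeneratorTool(
        model_path="/model-weights/roentgen",
        cache_dir="/model-weights",
        device="cuda",
    )
    output, metadata = tool._run(
        prompt="big left-sided pleural effusion",
        num_inference_steps=75,
        guidance_scale=4.0,
    )
    print(output)
    print(metadata)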