from typing import Any, Dict, Optional, Tuple, Type

from pydantic import BaseModel, Field
import torch
from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from PIL import Image

from medrax.llava.conversation import conv_templates
from medrax.llava.model.builder import load_pretrained_model
from medrax.llava.mm_utils import tokenizer_image_token, process_images
from medrax.llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)


class LlavaMedInput(BaseModel):
    """Input for the LLaVA-Med Visual QA tool. Only supports JPG or PNG images."""

    question: str = Field(..., description="The question to ask about the medical image")
    image_path: Optional[str] = Field(
        None,
        description="Path to the medical image file (optional), only supports JPG or PNG images",
    )


class LlavaMedTool(BaseTool):
    """Tool that performs medical visual question answering using LLaVA-Med.

    This tool uses a vision-language model fine-tuned on biomedical data to answer
    questions about medical images. It can handle both image-based questions and
    general medical questions without images.
    """

    name: str = "llava_med_qa"
    description: str = (
        "A tool that answers questions about biomedical images and general medical questions using LLaVA-Med. "
        "While it can process chest X-rays, it may not be as reliable for detailed chest X-ray analysis. "
        "Input should be a question and optionally a path to a medical image file."
    )
    args_schema: Type[BaseModel] = LlavaMedInput
    tokenizer: Any = None
    model: Any = None
    image_processor: Any = None
    context_len: int = 200000
    device: Any = None  # declared as a field so the assignment in __init__ is allowed by pydantic

    def __init__(
        self,
        model_path: str = "microsoft/llava-med-v1.5-mistral-7b",
        cache_dir: str = "/model-weights",
        low_cpu_mem_usage: bool = True,
        torch_dtype: torch.dtype = torch.bfloat16,
        device: str = "cuda",
        load_in_4bit: bool = False,
        load_in_8bit: bool = False,
        **kwargs,
    ):
        super().__init__()
        # Set the device (cuda or cpu)
        self.device = torch.device(device) if device else torch.device("cuda")

        # Load the model, tokenizer, and image processor
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path=model_path,
            model_base=None,
            model_name=model_path,
            load_in_4bit=load_in_4bit,
            load_in_8bit=load_in_8bit,
            cache_dir=cache_dir,
            low_cpu_mem_usage=low_cpu_mem_usage,
            torch_dtype=torch_dtype,
            device=device,
            **kwargs,
        )

        # Move the model to the desired device and switch to evaluation mode
        self.model.to(self.device)
        self.model.eval()

    def _process_input(
        self, question: str, image_path: Optional[str] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Build the prompt token ids and, if an image is given, the preprocessed image tensor."""
        # Only prepend image tokens when an image is actually provided, so that
        # text-only questions do not reference a non-existent image.
        if image_path:
            if self.model.config.mm_use_im_start_end:
                question = (
                    DEFAULT_IM_START_TOKEN
                    + DEFAULT_IMAGE_TOKEN
                    + DEFAULT_IM_END_TOKEN
                    + "\n"
                    + question
                )
            else:
                question = DEFAULT_IMAGE_TOKEN + "\n" + question

        conv = conv_templates["vicuna_v1"].copy()
        conv.append_message(conv.roles[0], question)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = (
            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .to(self.device)  # Move to the correct device
        )

        image_tensor = None
        if image_path:
            image = Image.open(image_path)
            image_tensor = process_images([image], self.image_processor, self.model.config)[0]
            image_tensor = image_tensor.unsqueeze(0).to(self.device, dtype=self.model.dtype)

        return input_ids, image_tensor

    def _run(
        self,
        question: str,
        image_path: Optional[str] = None,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Answer a medical question, optionally based on an input image.

        Args:
            question (str): The medical question to answer.
            image_path (Optional[str]): The path to the medical image file (if applicable).
            run_manager (Optional[CallbackManagerForToolRun]): The callback manager for the tool run.

        Returns:
            Tuple[str, Dict]: A tuple containing the model's answer and any additional metadata.

        Note:
            Errors during processing or generation are caught and returned as the answer
            string, with ``analysis_status`` set to ``"failed"`` in the metadata.
        """
        try:
            input_ids, image_tensor = self._process_input(question, image_path)

            # Ensure that inputs are on the same device as the model;
            # image_tensor is None for text-only questions.
            input_ids = input_ids.to(self.device)
            if image_tensor is not None:
                image_tensor = image_tensor.to(self.device, dtype=self.model.dtype)

            with torch.inference_mode():
                output_ids = self.model.generate(
                    input_ids,
                    images=image_tensor,
                    do_sample=False,
                    temperature=0.2,
                    max_new_tokens=500,
                    use_cache=True,
                )
            output = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()

            metadata = {
                "question": question,
                "image_path": image_path,
                "analysis_status": "completed",
            }
            return output, metadata

        except Exception as e:
            return f"Error generating answer: {str(e)}", {
                "question": question,
                "image_path": image_path,
                "analysis_status": "failed",
            }

    async def _arun(
        self,
        question: str,
        image_path: Optional[str] = None,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> Tuple[str, Dict]:
        """Asynchronously answer a medical question, optionally based on an input image.

        This method currently calls the synchronous version, as the model inference is
        not inherently asynchronous. For true asynchronous behavior, consider running
        inference in a separate thread or process.

        Args:
            question (str): The medical question to answer.
            image_path (Optional[str]): The path to the medical image file (if applicable).
            run_manager (Optional[AsyncCallbackManagerForToolRun]): The async callback manager for the tool run.

        Returns:
            Tuple[str, Dict]: A tuple containing the model's answer and any additional metadata.

        Note:
            Errors are handled the same way as in the synchronous ``_run`` method.
        """
        return self._run(question, image_path)
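

# ----------------------------------------------------------------------------
# Example usage (illustrative sketch, not part of the tool itself). It assumes
# a CUDA GPU is available, that weights can be cached under "/model-weights",
# and that "example_cxr.png" is a hypothetical path to a local chest X-ray.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    tool = LlavaMedTool(
        model_path="microsoft/llava-med-v1.5-mistral-7b",
        cache_dir="/model-weights",
        device="cuda",
    )

    # Image-based question: the answer and metadata are returned as a tuple.
    answer, metadata = tool._run(
        question="Is there evidence of pleural effusion in this image?",
        image_path="example_cxr.png",  # hypothetical image path
    )
    print(metadata["analysis_status"], answer)

    # Text-only question: no image path is provided.
    answer, metadata = tool._run(question="What are common causes of pleural effusion?")
    print(metadata["analysis_status"], answer)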