Dhruv-Ty committed
Commit 0ffa584 · 1 Parent(s): 4000a69

gpu to cpu
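
The change is the same across the tools below: instead of hard-coding CUDA, the device is resolved at construction time and CUDA is used only when it is actually available. A minimal standalone sketch of that pattern, using plain PyTorch and nothing MedRAX-specific:

from typing import Optional

import torch

def resolve_device(device: Optional[str] = None) -> torch.device:
    # Use the caller's choice if given; otherwise pick CUDA when present, else CPU.
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")
    return torch.device(device)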

medrax/tools/classification.py CHANGED
@@ -47,14 +47,19 @@ class ChestXRayClassifierTool(BaseTool):
     )
     args_schema: Type[BaseModel] = ChestXRayInput
     model: xrv.models.DenseNet = None
-    device: Optional[str] = "cuda"
+    device: Optional[torch.device] = torch.device("cpu")  # Default to CPU
     transform: torchvision.transforms.Compose = None
 
-    def __init__(self, model_name: str = "densenet121-res224-all", device: Optional[str] = "cuda"):
+    def __init__(self, model_name: str = "densenet121-res224-all", device: Optional[str] = None):
         super().__init__()
+
+        # If device is not specified, use CUDA if available, else fallback to CPU
+        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         self.model = xrv.models.DenseNet(weights=model_name)
         self.model.eval()
-        self.device = torch.device(device) if device else "cuda"
+
+        # Assign device based on the passed or auto-detected option
+        self.device = torch.device(device)
         self.model = self.model.to(self.device)
         self.transform = torchvision.transforms.Compose([xrv.datasets.XRayCenterCrop()])
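
Assuming the package import path mirrors the file layout above (an assumption, not shown in the diff), instantiating the classifier on a CPU-only machine would look roughly like:

from medrax.tools.classification import ChestXRayClassifierTool

tool = ChestXRayClassifierTool()                  # auto-detects: CUDA if available, else CPU
cpu_tool = ChestXRayClassifierTool(device="cpu")  # force CPU explicitly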
 
medrax/tools/generation.py CHANGED
@@ -61,7 +61,10 @@ class ChestXRayGeneratorTool(BaseTool):
         """Initialize the chest X-ray generator tool."""
         super().__init__()
 
-        self.device = torch.device(device) if device else "cuda"
+        # Automatically detect device (cuda if available, else cpu)
+        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = torch.device(device)
+
         self.model = StableDiffusionPipeline.from_pretrained(model_path, cache_dir=cache_dir)
         self.model = self.model.to(torch.float32).to(self.device)
 
@@ -121,7 +124,7 @@ class ChestXRayGeneratorTool(BaseTool):
 
         except Exception as e:
             return (
-                {"error": str(e)},
+                {"error": str(e)} ,
                 {
                     "prompt": prompt,
                     "analysis_status": "failed",
medrax/tools/grounding.py CHANGED
@@ -50,7 +50,7 @@ class XRayPhraseGroundingTool(BaseTool):
 
     model: Any = None
     processor: Any = None
-    device: str = "cuda"
+    device: torch.device = None
     temp_dir: Path = None
 
     def __init__(
@@ -64,7 +64,10 @@ class XRayPhraseGroundingTool(BaseTool):
     ):
         """Initialize the XRay Phrase Grounding Tool."""
         super().__init__()
-        self.device = torch.device(device) if device else "cuda"
+
+        # Automatically detect device (cuda if available, else cpu)
+        device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+        self.device = torch.device(device)
 
         # Setup quantization config
         if load_in_4bit:
@@ -93,7 +96,6 @@ class XRayPhraseGroundingTool(BaseTool):
             model_path, cache_dir=cache_dir, trust_remote_code=True
         )
 
-
         self.model = self.model.eval()
 
         self.temp_dir = Path(temp_dir if temp_dir else tempfile.mkdtemp())
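
One caveat the hunk leaves untouched: 4-bit loading via bitsandbytes is generally a CUDA-only path, so a CPU fallback typically has to skip the quantization config. A hedged sketch of such a guard (load_in_4bit mirrors the constructor flag seen in the hunk; the rest is illustrative, not the tool's actual code):

import torch
from transformers import BitsAndBytesConfig

load_in_4bit = True  # mirrors the tool's constructor flag
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

quantization_config = None
if load_in_4bit and device.type == "cuda":
    # bitsandbytes 4-bit loading needs a CUDA device; on CPU fall back to full precision.
    quantization_config = BitsAndBytesConfig(load_in_4bit=True)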
medrax/tools/llava_med.py CHANGED
@@ -11,7 +11,6 @@ from langchain_core.tools import BaseTool
 
 from PIL import Image
 
-
 from medrax.llava.conversation import conv_templates
 from medrax.llava.model.builder import load_pretrained_model
 from medrax.llava.mm_utils import tokenizer_image_token, process_images
@@ -65,6 +64,11 @@ class LlavaMedTool(BaseTool):
         **kwargs,
     ):
         super().__init__()
+
+        # Set the device (cuda or cpu)
+        self.device = torch.device(device) if device else torch.device("cuda")
+
+        # Load the model and tokenizer
         self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
             model_path=model_path,
             model_base=None,
@@ -77,6 +81,9 @@ class LlavaMedTool(BaseTool):
             device=device,
             **kwargs,
         )
+
+        # Move the model to the desired device
+        self.model.to(self.device)
         self.model.eval()
 
     def _process_input(
@@ -101,14 +108,14 @@ class LlavaMedTool(BaseTool):
         input_ids = (
             tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
             .unsqueeze(0)
-            .cuda()
+            .to(self.device)  # Move to the correct device
         )
 
         image_tensor = None
         if image_path:
             image = Image.open(image_path)
             image_tensor = process_images([image], self.image_processor, self.model.config)[0]
-            image_tensor = image_tensor.unsqueeze(0).half().cuda()
+            image_tensor = image_tensor.unsqueeze(0).to(self.device, dtype=self.model.dtype)  # Move to device
 
         return input_ids, image_tensor
 
@@ -133,8 +140,10 @@ class LlavaMedTool(BaseTool):
         """
         try:
             input_ids, image_tensor = self._process_input(question, image_path)
-            input_ids = input_ids.to(device=self.model.device)
-            image_tensor = image_tensor.to(device=self.model.device, dtype=self.model.dtype)
+
+            # Ensure that inputs are on the same device as the model
+            input_ids = input_ids.to(self.device)
+            image_tensor = image_tensor.to(self.device, dtype=self.model.dtype)
 
             with torch.inference_mode():
                 output_ids = self.model.generate(
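
The LLaVA-Med edits replace hard-coded .cuda() calls with .to(self.device); the difference in a tiny self-contained example (the tensor is illustrative):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.zeros(1, 3)

# x.cuda() raises if no CUDA runtime is present;
# x.to(device) runs everywhere and is a no-op when the tensor is already there.
x = x.to(device)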
medrax/tools/report_generation.py CHANGED
@@ -47,7 +47,7 @@ class ChestXRayReportGeneratorTool(BaseTool):
         "to a chest X-ray image file. Output is a structured report with both detailed "
         "observations and key clinical conclusions."
     )
-    device: Optional[str] = "cuda"
+    device: Optional[str] = "cpu"  # Change the device to "cpu"
     args_schema: Type[BaseModel] = ChestXRayInput
     findings_model: VisionEncoderDecoderModel = None
     impression_model: VisionEncoderDecoderModel = None
@@ -57,10 +57,10 @@ class ChestXRayReportGeneratorTool(BaseTool):
     impression_processor: ViTImageProcessor = None
     generation_args: Dict[str, Any] = None
 
-    def __init__(self, cache_dir: str = "/model-weights", device: Optional[str] = "cuda"):
+    def __init__(self, cache_dir: str = "/model-weights", device: Optional[str] = "cpu"):
         """Initialize the ChestXRayReportGeneratorTool with both findings and impression models."""
         super().__init__()
-        self.device = torch.device(device) if device else "cuda"
+        self.device = torch.device(device) if device else torch.device("cpu")  # Ensure CPU is used
 
         # Initialize findings model
         self.findings_model = VisionEncoderDecoderModel.from_pretrained(
@@ -84,7 +84,7 @@ class ChestXRayReportGeneratorTool(BaseTool):
             "IAMJB/chexpert-mimic-cxr-impression-baseline", cache_dir=cache_dir
         )
 
-        # Move models to device
+        # Move models to device (CPU)
         self.findings_model = self.findings_model.to(self.device)
         self.impression_model = self.impression_model.to(self.device)
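
With the new defaults the report generator runs without a GPU. Assuming the import path mirrors the file layout (an assumption, not shown in the diff), usage might look like:

from medrax.tools.report_generation import ChestXRayReportGeneratorTool

# Defaults to CPU now; pass device="cuda" to opt back into GPU execution.
tool = ChestXRayReportGeneratorTool(cache_dir="/model-weights")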