voxmenthe committed
Commit e37651e · 1 Parent(s): 105a9fa

Update device setting for inference
Files changed (1)
  1. inference.py +11 -1
inference.py CHANGED
@@ -16,6 +16,15 @@ class SentimentInference:
         model_yaml_cfg = config_data.get('model', {})
         inference_yaml_cfg = config_data.get('inference', {})
 
+        # Determine device early
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        elif torch.backends.mps.is_available():  # Check for MPS (Apple Silicon GPU)
+            self.device = torch.device("mps")
+        else:
+            self.device = torch.device("cpu")
+        print(f"[INFERENCE_LOG] Using device: {self.device}")
+
         model_hf_repo_id = model_yaml_cfg.get('name_or_path')
         tokenizer_hf_repo_id = model_yaml_cfg.get('tokenizer_name_or_path', model_hf_repo_id)
         local_model_weights_path = inference_yaml_cfg.get('model_path')  # Path for local .pt file
@@ -127,12 +136,13 @@ class SentimentInference:
             print(f"[INFERENCE_LOG] HUB_LOAD_FALLBACK: Critical error during fallback load: {e_fallback}")
             raise e_fallback  # Re-raise if fallback also fails catastrophically
 
+        self.model.to(self.device)  # Move model to the determined device
         self.model.eval()
 
     def predict(self, text: str) -> Dict[str, Any]:
         inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=self.max_length, padding=True)
         with torch.no_grad():
-            outputs = self.model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
+            outputs = self.model(input_ids=inputs['input_ids'].to(self.device), attention_mask=inputs['attention_mask'].to(self.device))
         logits = outputs.get("logits")  # Use .get for safety
         if logits is None:
             raise ValueError("Model output did not contain 'logits'. Check model's forward pass.")
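
The device-selection pattern this commit introduces can be exercised on its own. Below is a minimal, self-contained sketch of the same cuda → mps → cpu fallback and of moving a model and its inputs to the selected device; the nn.Linear stand-in model and the random input tensor are illustrative placeholders, not part of inference.py.

import torch
import torch.nn as nn

# Same fallback order as the commit: CUDA first, then Apple Silicon MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

model = nn.Linear(4, 2).to(device)  # placeholder model; stands in for self.model
model.eval()

x = torch.randn(1, 4).to(device)  # inputs must live on the same device as the model
with torch.no_grad():
    logits = model(x)
print(logits.cpu())

Note that both sides of the change matter: once the model lives on CUDA or MPS, feeding it CPU tensors raises a device-mismatch RuntimeError, which is why predict() gains the matching .to(self.device) calls on input_ids and attention_mask.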