Update device setting for inference
inference.py (+11 -1)
@@ -16,6 +16,15 @@ class SentimentInference:
         model_yaml_cfg = config_data.get('model', {})
         inference_yaml_cfg = config_data.get('inference', {})
 
+        # Determine device early
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        elif torch.backends.mps.is_available(): # Check for MPS (Apple Silicon GPU)
+            self.device = torch.device("mps")
+        else:
+            self.device = torch.device("cpu")
+        print(f"[INFERENCE_LOG] Using device: {self.device}")
+
         model_hf_repo_id = model_yaml_cfg.get('name_or_path')
         tokenizer_hf_repo_id = model_yaml_cfg.get('tokenizer_name_or_path', model_hf_repo_id)
         local_model_weights_path = inference_yaml_cfg.get('model_path') # Path for local .pt file
@@ -127,12 +136,13 @@ class SentimentInference:
             print(f"[INFERENCE_LOG] HUB_LOAD_FALLBACK: Critical error during fallback load: {e_fallback}")
             raise e_fallback # Re-raise if fallback also fails catastrophically
 
+        self.model.to(self.device) # Move model to the determined device
         self.model.eval()
 
     def predict(self, text: str) -> Dict[str, Any]:
         inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=self.max_length, padding=True)
         with torch.no_grad():
-            outputs = self.model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
+            outputs = self.model(input_ids=inputs['input_ids'].to(self.device), attention_mask=inputs['attention_mask'].to(self.device))
             logits = outputs.get("logits") # Use .get for safety
             if logits is None:
                 raise ValueError("Model output did not contain 'logits'. Check model's forward pass.")
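As a standalone illustration of the pattern this commit applies (select a device once, move the model's weights once, move input tensors per call), here is a minimal self-contained sketch. TinyClassifier and its input tensor are hypothetical stand-ins for the Space's actual model, not part of this repository:

import torch

# Hypothetical stand-in for the Space's model (not the real SentimentInference weights).
class TinyClassifier(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 2)

    def forward(self, x):
        return {"logits": self.linear(x)}

# Same device-selection order as the commit: CUDA, then MPS, then CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():  # Apple Silicon GPU; needs PyTorch 1.12+
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model = TinyClassifier().to(device)  # move weights once, at construction time
model.eval()

x = torch.randn(1, 8)  # tensors start on CPU, like tokenizer output
with torch.no_grad():
    outputs = model(x.to(device))  # move inputs at call time, as predict() now does
print(outputs["logits"].cpu())  # bring results back to CPU for post-processing

Moving the model in the constructor and only the input tensors in predict() keeps the transfer cost to one copy of the weights plus one small copy of the input batch per request.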