Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
new test inference
Browse files- inference.py +60 -46
inference.py
CHANGED
@@ -29,78 +29,92 @@ class SentimentInference:
|
|
29 |
if not tokenizer_hf_repo_id and not model_hf_repo_id:
|
30 |
raise ValueError("Either model.tokenizer_name_or_path or model.name_or_path (as fallback for tokenizer) must be specified in config.yaml")
|
31 |
effective_tokenizer_repo_id = tokenizer_hf_repo_id or model_hf_repo_id
|
32 |
-
print(f"Loading tokenizer from: {effective_tokenizer_repo_id}")
|
33 |
self.tokenizer = AutoTokenizer.from_pretrained(effective_tokenizer_repo_id)
|
34 |
|
35 |
# --- Model Loading --- #
|
36 |
# Determine if we are loading from a local .pt file or from Hugging Face Hub
|
37 |
load_from_local_pt = False
|
38 |
if local_model_weights_path and os.path.isfile(local_model_weights_path):
|
39 |
-
print(f"Found local model weights path: {local_model_weights_path}")
|
40 |
print(f"--- Debug: Found local model weights path: {local_model_weights_path} ---") # Add this
|
41 |
load_from_local_pt = True
|
42 |
elif not model_hf_repo_id:
|
43 |
raise ValueError("No local model_path found and model.name_or_path (for Hub) is not specified in config.yaml")
|
44 |
|
|
|
45 |
print(f"--- Debug: load_from_local_pt is: {load_from_local_pt} ---") # Add this
|
46 |
|
47 |
if load_from_local_pt:
|
48 |
-
print("Attempting to load model from
|
49 |
print("--- Debug: Entering LOCAL .pt loading path ---") # Add this
|
50 |
# Base BERT config must still be loaded, usually from a Hub ID (e.g., original base model)
|
51 |
# This base_model_for_config_id is crucial for building the correct ModernBertForSentiment structure.
|
52 |
-
base_model_for_config_id = model_yaml_cfg.get('base_model_for_config',
|
53 |
-
print(f"--- Debug: base_model_for_config_id (for local .pt): {base_model_for_config_id} ---") # Add this
|
54 |
if not base_model_for_config_id:
|
55 |
-
raise ValueError("
|
56 |
|
57 |
-
print(f"
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
|
67 |
-
print("
|
68 |
-
self.model = ModernBertForSentiment(
|
69 |
|
70 |
-
print(f"Loading
|
71 |
checkpoint = torch.load(local_model_weights_path, map_location=torch.device('cpu'))
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
self.model.load_state_dict(model_state_to_load)
|
77 |
-
print(f"Model loaded successfully from local checkpoint: {local_model_weights_path}.")
|
78 |
|
79 |
-
|
80 |
-
|
81 |
-
print(f"
|
82 |
-
print(f"--- Debug: model_hf_repo_id (for Hub loading): {model_hf_repo_id} ---") # Add this
|
83 |
-
if not model_hf_repo_id:
|
84 |
-
raise ValueError("model.name_or_path must be specified in config.yaml for Hub loading.")
|
85 |
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
88 |
|
89 |
-
#
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
-
print(f"Instantiating and loading model weights for {model_hf_repo_id}...")
|
96 |
-
self.model = AutoModelForSequenceClassification.from_pretrained(
|
97 |
-
model_hf_repo_id,
|
98 |
-
config=loaded_config,
|
99 |
-
trust_remote_code=True,
|
100 |
-
force_download=True # <--- TEMPORARY - remove when everything is working
|
101 |
-
)
|
102 |
-
print(f"Model {model_hf_repo_id} loaded successfully from Hugging Face Hub.")
|
103 |
-
|
104 |
self.model.eval()
|
105 |
|
106 |
def predict(self, text: str) -> Dict[str, Any]:
|
|
|
29 |
if not tokenizer_hf_repo_id and not model_hf_repo_id:
|
30 |
raise ValueError("Either model.tokenizer_name_or_path or model.name_or_path (as fallback for tokenizer) must be specified in config.yaml")
|
31 |
effective_tokenizer_repo_id = tokenizer_hf_repo_id or model_hf_repo_id
|
32 |
+
print(f"[INFERENCE_LOG] Loading tokenizer from: {effective_tokenizer_repo_id}") # Logging
|
33 |
self.tokenizer = AutoTokenizer.from_pretrained(effective_tokenizer_repo_id)
|
34 |
|
35 |
# --- Model Loading --- #
|
36 |
# Determine if we are loading from a local .pt file or from Hugging Face Hub
|
37 |
load_from_local_pt = False
|
38 |
if local_model_weights_path and os.path.isfile(local_model_weights_path):
|
39 |
+
print(f"[INFERENCE_LOG] Found local model weights path: {local_model_weights_path}") # Logging
|
40 |
print(f"--- Debug: Found local model weights path: {local_model_weights_path} ---") # Add this
|
41 |
load_from_local_pt = True
|
42 |
elif not model_hf_repo_id:
|
43 |
raise ValueError("No local model_path found and model.name_or_path (for Hub) is not specified in config.yaml")
|
44 |
|
45 |
+
print(f"[INFERENCE_LOG] load_from_local_pt: {load_from_local_pt}") # Logging
|
46 |
print(f"--- Debug: load_from_local_pt is: {load_from_local_pt} ---") # Add this
|
47 |
|
48 |
if load_from_local_pt:
|
49 |
+
print("[INFERENCE_LOG] Attempting to load model from LOCAL .pt checkpoint...") # Logging
|
50 |
print("--- Debug: Entering LOCAL .pt loading path ---") # Add this
|
51 |
# Base BERT config must still be loaded, usually from a Hub ID (e.g., original base model)
|
52 |
# This base_model_for_config_id is crucial for building the correct ModernBertForSentiment structure.
|
53 |
+
base_model_for_config_id = model_yaml_cfg.get('base_model_for_config', model_yaml_cfg.get('name_or_path'))
|
|
|
54 |
if not base_model_for_config_id:
|
55 |
+
raise ValueError("model.base_model_for_config or model.name_or_path must be specified in config.yaml when loading local .pt for ModernBertForSentiment structure.")
|
56 |
|
57 |
+
print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: base_model_for_config_id: {base_model_for_config_id}") # Logging
|
58 |
+
|
59 |
+
model_config = ModernBertConfig.from_pretrained(
|
60 |
+
base_model_for_config_id,
|
61 |
+
num_labels=model_yaml_cfg.get('num_labels', 1), # from config.yaml via model_yaml_cfg
|
62 |
+
pooling_strategy=model_yaml_cfg.get('pooling_strategy', 'mean'), # from config.yaml via model_yaml_cfg
|
63 |
+
num_weighted_layers=model_yaml_cfg.get('num_weighted_layers', 4) # from config.yaml via model_yaml_cfg
|
64 |
+
)
|
65 |
+
print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Loaded ModernBertConfig: {model_config.to_diff_dict()}") # Logging
|
66 |
|
67 |
+
print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Initializing ModernBertForSentiment with this config.") # Logging
|
68 |
+
self.model = ModernBertForSentiment(config=model_config)
|
69 |
|
70 |
+
print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Loading weights from checkpoint: {local_model_weights_path}") # Logging
|
71 |
checkpoint = torch.load(local_model_weights_path, map_location=torch.device('cpu'))
|
72 |
+
|
73 |
+
state_dict_to_load = checkpoint.get('model_state_dict', checkpoint.get('state_dict', checkpoint))
|
74 |
+
if not isinstance(state_dict_to_load, dict):
|
75 |
+
raise TypeError(f"Loaded checkpoint from {local_model_weights_path} is not a dict or does not contain 'model_state_dict' or 'state_dict'.")
|
|
|
|
|
76 |
|
77 |
+
# Log first few keys for debugging
|
78 |
+
first_few_keys = list(state_dict_to_load.keys())[:5]
|
79 |
+
print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: First few keys from checkpoint state_dict: {first_few_keys}") # Logging
|
|
|
|
|
|
|
80 |
|
81 |
+
self.model.load_state_dict(state_dict_to_load)
|
82 |
+
print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Weights loaded successfully into ModernBertForSentiment from {local_model_weights_path}.") # Logging
|
83 |
+
else:
|
84 |
+
# Load from Hugging Face Hub
|
85 |
+
print(f"[INFERENCE_LOG] Attempting to load model from HUGGING_FACE_HUB: {model_hf_repo_id}") # Logging
|
86 |
|
87 |
+
# Here, we use the config that's packaged with the model on the Hub by default.
|
88 |
+
# We just add/override num_labels, pooling_strategy, num_weighted_layers if they are in our local config.yaml
|
89 |
+
# as these might be specific to our fine-tuning and not in the Hub's default config.json.
|
90 |
+
hub_config_overrides = {
|
91 |
+
"num_labels": model_yaml_cfg.get('num_labels', 1),
|
92 |
+
"pooling_strategy": model_yaml_cfg.get('pooling_strategy', 'mean'),
|
93 |
+
"num_weighted_layers": model_yaml_cfg.get('num_weighted_layers', 6) # Default to 6 now
|
94 |
+
}
|
95 |
+
print(f"[INFERENCE_LOG] HUB_LOAD: Overrides for Hub config: {hub_config_overrides}") # Logging
|
96 |
+
|
97 |
+
try:
|
98 |
+
# Using ModernBertForSentiment.from_pretrained directly.
|
99 |
+
# This assumes the config.json on the Hub for 'model_hf_repo_id' is compatible
|
100 |
+
# or that from_pretrained can correctly initialize ModernBertForSentiment with it.
|
101 |
+
self.model = ModernBertForSentiment.from_pretrained(
|
102 |
+
model_hf_repo_id,
|
103 |
+
**hub_config_overrides
|
104 |
+
)
|
105 |
+
print(f"[INFERENCE_LOG] HUB_LOAD: Model ModernBertForSentiment loaded successfully from {model_hf_repo_id}.") # Logging
|
106 |
+
except Exception as e:
|
107 |
+
print(f"[INFERENCE_LOG] HUB_LOAD: Error loading ModernBertForSentiment from {model_hf_repo_id}: {e}") # Logging
|
108 |
+
print(f"[INFERENCE_LOG] HUB_LOAD: Falling back to AutoModelForSequenceClassification for {model_hf_repo_id}.") # Logging
|
109 |
+
# Fallback: Try with AutoModelForSequenceClassification if ModernBertForSentiment fails
|
110 |
+
# This might happen if the Hub model isn't strictly saved as a ModernBertForSentiment type
|
111 |
+
# or if its config.json doesn't have _custom_class set, etc.
|
112 |
+
self.model = AutoModelForSequenceClassification.from_pretrained(
|
113 |
+
model_hf_repo_id,
|
114 |
+
**hub_config_overrides
|
115 |
+
)
|
116 |
+
print(f"[INFERENCE_LOG] HUB_LOAD: AutoModelForSequenceClassification loaded for {model_hf_repo_id}.") # Logging
|
117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
self.model.eval()
|
119 |
|
120 |
def predict(self, text: str) -> Dict[str, Any]:
|