voxmenthe committed on
Commit
ff78fc6
·
1 Parent(s): 6529956

new test inference

Browse files
Files changed (1) hide show
  1. inference.py +60 -46
inference.py CHANGED
@@ -29,78 +29,92 @@ class SentimentInference:
29
  if not tokenizer_hf_repo_id and not model_hf_repo_id:
30
  raise ValueError("Either model.tokenizer_name_or_path or model.name_or_path (as fallback for tokenizer) must be specified in config.yaml")
31
  effective_tokenizer_repo_id = tokenizer_hf_repo_id or model_hf_repo_id
32
- print(f"Loading tokenizer from: {effective_tokenizer_repo_id}")
33
  self.tokenizer = AutoTokenizer.from_pretrained(effective_tokenizer_repo_id)
34
 
35
  # --- Model Loading --- #
36
  # Determine if we are loading from a local .pt file or from Hugging Face Hub
37
  load_from_local_pt = False
38
  if local_model_weights_path and os.path.isfile(local_model_weights_path):
39
- print(f"Found local model weights path: {local_model_weights_path}")
40
  print(f"--- Debug: Found local model weights path: {local_model_weights_path} ---") # Add this
41
  load_from_local_pt = True
42
  elif not model_hf_repo_id:
43
  raise ValueError("No local model_path found and model.name_or_path (for Hub) is not specified in config.yaml")
44
 
 
45
  print(f"--- Debug: load_from_local_pt is: {load_from_local_pt} ---") # Add this
46
 
47
  if load_from_local_pt:
48
- print("Attempting to load model from local .pt checkpoint...")
49
  print("--- Debug: Entering LOCAL .pt loading path ---") # Add this
50
  # Base BERT config must still be loaded, usually from a Hub ID (e.g., original base model)
51
  # This base_model_for_config_id is crucial for building the correct ModernBertForSentiment structure.
52
- base_model_for_config_id = model_yaml_cfg.get('base_model_for_config', model_hf_repo_id or tokenizer_hf_repo_id)
53
- print(f"--- Debug: base_model_for_config_id (for local .pt): {base_model_for_config_id} ---") # Add this
54
  if not base_model_for_config_id:
55
- raise ValueError("For local .pt loading, model.base_model_for_config must be specified in config.yaml (e.g., 'answerdotai/ModernBERT-base') to build the model structure.")
56
 
57
- print(f"Loading ModernBertConfig for structure from: {base_model_for_config_id}")
58
- bert_config = ModernBertConfig.from_pretrained(base_model_for_config_id)
59
-
60
- # Augment config with parameters from model_yaml_cfg
61
- bert_config.pooling_strategy = model_yaml_cfg.get('pooling_strategy', 'mean')
62
- bert_config.num_weighted_layers = model_yaml_cfg.get('num_weighted_layers', 4)
63
- bert_config.classifier_dropout = model_yaml_cfg.get('dropout')
64
- bert_config.num_labels = model_yaml_cfg.get('num_labels', 1)
65
- # bert_config.loss_function = model_yaml_cfg.get('loss_function') # If needed by __init__
66
 
67
- print("Instantiating ModernBertForSentiment model structure...")
68
- self.model = ModernBertForSentiment(bert_config)
69
 
70
- print(f"Loading model weights from local checkpoint: {local_model_weights_path}")
71
  checkpoint = torch.load(local_model_weights_path, map_location=torch.device('cpu'))
72
- if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
73
- model_state_to_load = checkpoint['model_state_dict']
74
- else:
75
- model_state_to_load = checkpoint # Assume it's the state_dict itself
76
- self.model.load_state_dict(model_state_to_load)
77
- print(f"Model loaded successfully from local checkpoint: {local_model_weights_path}.")
78
 
79
- else: # Load from Hugging Face Hub
80
- print(f"Attempting to load model from Hugging Face Hub: {model_hf_repo_id}...")
81
- print(f"--- Debug: Entering HUGGING FACE HUB loading path ---") # Add this
82
- print(f"--- Debug: model_hf_repo_id (for Hub loading): {model_hf_repo_id} ---") # Add this
83
- if not model_hf_repo_id:
84
- raise ValueError("model.name_or_path must be specified in config.yaml for Hub loading.")
85
 
86
- print(f"Loading base ModernBertConfig from: {model_hf_repo_id}")
87
- loaded_config = ModernBertConfig.from_pretrained(model_hf_repo_id)
 
 
 
88
 
89
- # Augment loaded_config
90
- loaded_config.pooling_strategy = model_yaml_cfg.get('pooling_strategy', 'mean')
91
- loaded_config.num_weighted_layers = model_yaml_cfg.get('num_weighted_layers', 6) # Default to 6 now
92
- loaded_config.classifier_dropout = model_yaml_cfg.get('dropout')
93
- loaded_config.num_labels = model_yaml_cfg.get('num_labels', 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- print(f"Instantiating and loading model weights for {model_hf_repo_id}...")
96
- self.model = AutoModelForSequenceClassification.from_pretrained(
97
- model_hf_repo_id,
98
- config=loaded_config,
99
- trust_remote_code=True,
100
- force_download=True # <--- TEMPORARY - remove when everything is working
101
- )
102
- print(f"Model {model_hf_repo_id} loaded successfully from Hugging Face Hub.")
103
-
104
  self.model.eval()
105
 
106
  def predict(self, text: str) -> Dict[str, Any]:
 
29
  if not tokenizer_hf_repo_id and not model_hf_repo_id:
30
  raise ValueError("Either model.tokenizer_name_or_path or model.name_or_path (as fallback for tokenizer) must be specified in config.yaml")
31
  effective_tokenizer_repo_id = tokenizer_hf_repo_id or model_hf_repo_id
32
+ print(f"[INFERENCE_LOG] Loading tokenizer from: {effective_tokenizer_repo_id}") # Logging
33
  self.tokenizer = AutoTokenizer.from_pretrained(effective_tokenizer_repo_id)
34
 
35
  # --- Model Loading --- #
36
  # Determine if we are loading from a local .pt file or from Hugging Face Hub
37
  load_from_local_pt = False
38
  if local_model_weights_path and os.path.isfile(local_model_weights_path):
39
+ print(f"[INFERENCE_LOG] Found local model weights path: {local_model_weights_path}") # Logging
40
  print(f"--- Debug: Found local model weights path: {local_model_weights_path} ---") # Add this
41
  load_from_local_pt = True
42
  elif not model_hf_repo_id:
43
  raise ValueError("No local model_path found and model.name_or_path (for Hub) is not specified in config.yaml")
44
 
45
+ print(f"[INFERENCE_LOG] load_from_local_pt: {load_from_local_pt}") # Logging
46
  print(f"--- Debug: load_from_local_pt is: {load_from_local_pt} ---") # Add this
47
 
48
  if load_from_local_pt:
49
+ print("[INFERENCE_LOG] Attempting to load model from LOCAL .pt checkpoint...") # Logging
50
  print("--- Debug: Entering LOCAL .pt loading path ---") # Add this
51
  # Base BERT config must still be loaded, usually from a Hub ID (e.g., original base model)
52
  # This base_model_for_config_id is crucial for building the correct ModernBertForSentiment structure.
53
+ base_model_for_config_id = model_yaml_cfg.get('base_model_for_config', model_yaml_cfg.get('name_or_path'))
 
54
  if not base_model_for_config_id:
55
+ raise ValueError("model.base_model_for_config or model.name_or_path must be specified in config.yaml when loading local .pt for ModernBertForSentiment structure.")
56
 
57
+ print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: base_model_for_config_id: {base_model_for_config_id}") # Logging
58
+
59
+ model_config = ModernBertConfig.from_pretrained(
60
+ base_model_for_config_id,
61
+ num_labels=model_yaml_cfg.get('num_labels', 1), # from config.yaml via model_yaml_cfg
62
+ pooling_strategy=model_yaml_cfg.get('pooling_strategy', 'mean'), # from config.yaml via model_yaml_cfg
63
+ num_weighted_layers=model_yaml_cfg.get('num_weighted_layers', 4) # from config.yaml via model_yaml_cfg
64
+ )
65
+ print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Loaded ModernBertConfig: {model_config.to_diff_dict()}") # Logging
66
 
67
+ print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Initializing ModernBertForSentiment with this config.") # Logging
68
+ self.model = ModernBertForSentiment(config=model_config)
69
 
70
+ print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Loading weights from checkpoint: {local_model_weights_path}") # Logging
71
  checkpoint = torch.load(local_model_weights_path, map_location=torch.device('cpu'))
72
+
73
+ state_dict_to_load = checkpoint.get('model_state_dict', checkpoint.get('state_dict', checkpoint))
74
+ if not isinstance(state_dict_to_load, dict):
75
+ raise TypeError(f"Loaded checkpoint from {local_model_weights_path} is not a dict or does not contain 'model_state_dict' or 'state_dict'.")
 
 
76
 
77
+ # Log first few keys for debugging
78
+ first_few_keys = list(state_dict_to_load.keys())[:5]
79
+ print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: First few keys from checkpoint state_dict: {first_few_keys}") # Logging
 
 
 
80
 
81
+ self.model.load_state_dict(state_dict_to_load)
82
+ print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Weights loaded successfully into ModernBertForSentiment from {local_model_weights_path}.") # Logging
83
+ else:
84
+ # Load from Hugging Face Hub
85
+ print(f"[INFERENCE_LOG] Attempting to load model from HUGGING_FACE_HUB: {model_hf_repo_id}") # Logging
86
 
87
+ # Here, we use the config that's packaged with the model on the Hub by default.
88
+ # We just add/override num_labels, pooling_strategy, num_weighted_layers if they are in our local config.yaml
89
+ # as these might be specific to our fine-tuning and not in the Hub's default config.json.
90
+ hub_config_overrides = {
91
+ "num_labels": model_yaml_cfg.get('num_labels', 1),
92
+ "pooling_strategy": model_yaml_cfg.get('pooling_strategy', 'mean'),
93
+ "num_weighted_layers": model_yaml_cfg.get('num_weighted_layers', 6) # Default to 6 now
94
+ }
95
+ print(f"[INFERENCE_LOG] HUB_LOAD: Overrides for Hub config: {hub_config_overrides}") # Logging
96
+
97
+ try:
98
+ # Using ModernBertForSentiment.from_pretrained directly.
99
+ # This assumes the config.json on the Hub for 'model_hf_repo_id' is compatible
100
+ # or that from_pretrained can correctly initialize ModernBertForSentiment with it.
101
+ self.model = ModernBertForSentiment.from_pretrained(
102
+ model_hf_repo_id,
103
+ **hub_config_overrides
104
+ )
105
+ print(f"[INFERENCE_LOG] HUB_LOAD: Model ModernBertForSentiment loaded successfully from {model_hf_repo_id}.") # Logging
106
+ except Exception as e:
107
+ print(f"[INFERENCE_LOG] HUB_LOAD: Error loading ModernBertForSentiment from {model_hf_repo_id}: {e}") # Logging
108
+ print(f"[INFERENCE_LOG] HUB_LOAD: Falling back to AutoModelForSequenceClassification for {model_hf_repo_id}.") # Logging
109
+ # Fallback: Try with AutoModelForSequenceClassification if ModernBertForSentiment fails
110
+ # This might happen if the Hub model isn't strictly saved as a ModernBertForSentiment type
111
+ # or if its config.json doesn't have _custom_class set, etc.
112
+ self.model = AutoModelForSequenceClassification.from_pretrained(
113
+ model_hf_repo_id,
114
+ **hub_config_overrides
115
+ )
116
+ print(f"[INFERENCE_LOG] HUB_LOAD: AutoModelForSequenceClassification loaded for {model_hf_repo_id}.") # Logging
117
 
 
 
 
 
 
 
 
 
 
118
  self.model.eval()
119
 
120
  def predict(self, text: str) -> Dict[str, Any]: