import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig, ModernBertConfig
from typing import Dict, Any
import yaml
import os
from models import ModernBertForSentiment
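
# Expected config.yaml layout (illustrative sketch only; key names follow the lookups
# in SentimentInference below, values are placeholders and may differ in your setup):
#
#   model:
#     name_or_path: "your-org/your-modernbert-sentiment"     # Hub repo id
#     tokenizer_name_or_path: "your-org/your-modernbert-sentiment"
#     base_model_for_config: "<base-modernbert-repo-id>"     # used when loading a local .pt
#     num_labels: 1
#     pooling_strategy: "mean"
#     num_weighted_layers: 4
#     max_length: 512
#   inference:
#     model_path: "checkpoints/best_model.pt"                # optional local checkpoint
#     max_length: 512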

class SentimentInference:
    def __init__(self, config_path: str = "config.yaml"):
        """Load configuration and initialize model and tokenizer from local checkpoint or Hugging Face Hub."""
        print(f"--- Debug: SentimentInference __init__ received config_path: {config_path} ---") # Add this
        with open(config_path, 'r') as f:
            config_data = yaml.safe_load(f)
        print(f"--- Debug: SentimentInference loaded config_data: {config_data} ---") # Add this
        
        model_yaml_cfg = config_data.get('model', {})
        inference_yaml_cfg = config_data.get('inference', {})
        
        # Determine device early
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif torch.backends.mps.is_available(): # Check for MPS (Apple Silicon GPU)
            self.device = torch.device("mps")
        else:
            self.device = torch.device("cpu")
        print(f"[INFERENCE_LOG] Using device: {self.device}")
        
        model_hf_repo_id = model_yaml_cfg.get('name_or_path')
        tokenizer_hf_repo_id = model_yaml_cfg.get('tokenizer_name_or_path', model_hf_repo_id)
        local_model_weights_path = inference_yaml_cfg.get('model_path') # Path for local .pt file

        print(f"--- Debug: model_hf_repo_id: {model_hf_repo_id} ---") # Add this
        print(f"--- Debug: local_model_weights_path: {local_model_weights_path} ---") # Add this

        self.max_length = inference_yaml_cfg.get('max_length', model_yaml_cfg.get('max_length', 512))

        # --- Tokenizer Loading (always from Hub for now, or could be made conditional) ---
        if not tokenizer_hf_repo_id and not model_hf_repo_id:
            raise ValueError("Either model.tokenizer_name_or_path or model.name_or_path (as fallback for tokenizer) must be specified in config.yaml")
        effective_tokenizer_repo_id = tokenizer_hf_repo_id or model_hf_repo_id
        print(f"[INFERENCE_LOG] Loading tokenizer from: {effective_tokenizer_repo_id}") # Logging
        self.tokenizer = AutoTokenizer.from_pretrained(effective_tokenizer_repo_id)

        # --- Model Loading --- #
        # Determine if we are loading from a local .pt file or from Hugging Face Hub
        load_from_local_pt = False
        if local_model_weights_path and os.path.isfile(local_model_weights_path):
            print(f"[INFERENCE_LOG] Found local model weights path: {local_model_weights_path}") # Logging
            print(f"--- Debug: Found local model weights path: {local_model_weights_path} ---") # Add this
            load_from_local_pt = True
        elif not model_hf_repo_id:
            raise ValueError("No local model_path found and model.name_or_path (for Hub) is not specified in config.yaml")

        print(f"[INFERENCE_LOG] load_from_local_pt: {load_from_local_pt}") # Logging
        print(f"--- Debug: load_from_local_pt is: {load_from_local_pt} ---") # Add this

        if load_from_local_pt:
            print("[INFERENCE_LOG] Attempting to load model from LOCAL .pt checkpoint...") # Logging
            print("--- Debug: Entering LOCAL .pt loading path ---") # Add this
            # Base BERT config must still be loaded, usually from a Hub ID (e.g., original base model)
            # This base_model_for_config_id is crucial for building the correct ModernBertForSentiment structure.
            base_model_for_config_id = model_yaml_cfg.get('base_model_for_config', model_yaml_cfg.get('name_or_path'))
            if not base_model_for_config_id:
                 raise ValueError("model.base_model_for_config or model.name_or_path must be specified in config.yaml when loading local .pt for ModernBertForSentiment structure.")
            
            print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: base_model_for_config_id: {base_model_for_config_id}") # Logging

            model_config = ModernBertConfig.from_pretrained(
                base_model_for_config_id, 
                num_labels=model_yaml_cfg.get('num_labels', 1), # from config.yaml via model_yaml_cfg
                pooling_strategy=model_yaml_cfg.get('pooling_strategy', 'mean'), # from config.yaml via model_yaml_cfg
                num_weighted_layers=model_yaml_cfg.get('num_weighted_layers', 4) # from config.yaml via model_yaml_cfg
            )
            print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Loaded ModernBertConfig: {model_config.to_diff_dict()}") # Logging

            print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Initializing ModernBertForSentiment with this config.") # Logging
            self.model = ModernBertForSentiment(config=model_config)
            
            print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Loading weights from checkpoint: {local_model_weights_path}") # Logging
            checkpoint = torch.load(local_model_weights_path, map_location=torch.device('cpu'))
            
            state_dict_to_load = checkpoint.get('model_state_dict', checkpoint.get('state_dict', checkpoint))
            if not isinstance(state_dict_to_load, dict):
                raise TypeError(f"Loaded checkpoint from {local_model_weights_path} is not a dict or does not contain 'model_state_dict' or 'state_dict'.")

            # Log first few keys for debugging
            first_few_keys = list(state_dict_to_load.keys())[:5]
            print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: First few keys from checkpoint state_dict: {first_few_keys}") # Logging

            self.model.load_state_dict(state_dict_to_load)
            print(f"[INFERENCE_LOG] LOCAL_PT_LOAD: Weights loaded successfully into ModernBertForSentiment from {local_model_weights_path}.") # Logging
        else:
            # Load from Hugging Face Hub
            print(f"[INFERENCE_LOG] Attempting to load model from HUGGING_FACE_HUB: {model_hf_repo_id}") # Logging
            
            hub_config_params = {
                "num_labels": model_yaml_cfg.get('num_labels', 1),
                "pooling_strategy": model_yaml_cfg.get('pooling_strategy', 'mean'),
                "num_weighted_layers": model_yaml_cfg.get('num_weighted_layers', 6)
            }
            print(f"[INFERENCE_LOG] HUB_LOAD: Parameters to update Hub config: {hub_config_params}") # Logging

            try:
                # Step 1: Load config from Hub, allowing for our custom ModernBertConfig
                config = ModernBertConfig.from_pretrained(model_hf_repo_id)
                # Step 2: Update the loaded config with our specific parameters
                for key, value in hub_config_params.items():
                    setattr(config, key, value)
                print(f"[INFERENCE_LOG] HUB_LOAD: Updated config: {config.to_diff_dict()}")

                # Step 3: Load model with the updated config
                self.model = ModernBertForSentiment.from_pretrained(
                    model_hf_repo_id,
                    config=config
                )
                print(f"[INFERENCE_LOG] HUB_LOAD: Model ModernBertForSentiment loaded successfully from {model_hf_repo_id} with updated config.") # Logging
            except Exception as e:
                print(f"[INFERENCE_LOG] HUB_LOAD: Error loading ModernBertForSentiment from {model_hf_repo_id} with explicit config: {e}") # Logging
                print(f"[INFERENCE_LOG] HUB_LOAD: Falling back to AutoModelForSequenceClassification for {model_hf_repo_id}.") # Logging
                
                # Fallback: Try with AutoModelForSequenceClassification
                # Load its config (could be BertConfig or ModernBertConfig if auto-detected)
                # AutoConfig should ideally resolve to ModernBertConfig if architectures field is set in Hub's config.json
                try:
                    config_fallback = AutoConfig.from_pretrained(model_hf_repo_id)
                    for key, value in hub_config_params.items():
                        setattr(config_fallback, key, value)
                    print(f"[INFERENCE_LOG] HUB_LOAD_FALLBACK: Updated fallback config: {config_fallback.to_diff_dict()}")

                    self.model = AutoModelForSequenceClassification.from_pretrained(
                        model_hf_repo_id,
                        config=config_fallback
                    )
                    print(f"[INFERENCE_LOG] HUB_LOAD_FALLBACK: AutoModelForSequenceClassification loaded for {model_hf_repo_id} with updated config.") # Logging
                except Exception as e_fallback:
                    print(f"[INFERENCE_LOG] HUB_LOAD_FALLBACK: Critical error during fallback load: {e_fallback}")
                    raise e_fallback # Re-raise if fallback also fails catastrophically

        self.model.to(self.device) # Move model to the determined device
        self.model.eval()
        
    def predict(self, text: str) -> Dict[str, Any]:
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=self.max_length, padding=True)
        with torch.no_grad():
            outputs = self.model(input_ids=inputs['input_ids'].to(self.device), attention_mask=inputs['attention_mask'].to(self.device))
        logits = outputs.get("logits") # Use .get for safety
        if logits is None:
            raise ValueError("Model output did not contain 'logits'. Check model's forward pass.")
        prob = torch.sigmoid(logits).item()
        sentiment = "positive" if prob > 0.5 else "negative"
        # Report confidence in the predicted class rather than the raw positive probability.
        confidence = prob if sentiment == "positive" else 1.0 - prob
        return {"sentiment": sentiment, "confidence": confidence}