add full app and model initial test
Files changed:
- README.md +68 -11
- app.py +105 -4
- classifiers.py +141 -0
- config.yaml +12 -0
- inference.py +79 -0
- models.py +172 -0
- requirements.txt +18 -0
- src/config.yaml +46 -0
- train_utils.py +156 -0
- upload_to_hf.py +110 -0
README.md
CHANGED
@@ -1,14 +1,71 @@
---
language: en
tags:
- sentiment-analysis
- modernbert
- imdb
datasets:
- imdb
metrics:
- accuracy
- f1
---

# ModernBERT IMDb Sentiment Analysis Model

## Model Description
Fine-tuned ModernBERT model for sentiment analysis on IMDb movie reviews. Achieves 95.75% accuracy on the test set.

## Usage
```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained("voxmenthe/modernbert-imdb-sentiment")
tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")

# Input processing
inputs = tokenizer("This movie was fantastic!", return_tensors="pt")
outputs = model(**inputs)

# Get the predicted class
predicted_class_id = outputs.logits.argmax().item()

# Convert class ID to label
predicted_label = model.config.id2label[predicted_class_id]
print(f"Predicted label: {predicted_label}")
```

## Model Card

### Model Details
- **Model Name**: ModernBERT IMDb Sentiment Analysis
- **Base Model**: answerdotai/ModernBERT-base
- **Task**: Sentiment Analysis
- **Dataset**: IMDb Movie Reviews
- **Training Epochs**: 5

### Model Performance
- **Test Accuracy**: 95.75%
- **Test F1 Score**: 95.75%

### Model Architecture
- **Base Model**: answerdotai/ModernBERT-base
- **Task-Specific Head**: ClassifierHead (from `classifiers.py`)
- **Number of Labels**: 2 (Positive, Negative)

### Model Inference
- **Input Format**: Text (single review)
- **Output Format**: Predicted sentiment label (Positive or Negative)

### Model Version
- **Version**: 1.0
- **Date**: 2025-05-07

### Model License
- **License**: MIT License

### Model Contact
- **Contact**: [email protected]

### Model Citation
- **Citation**: voxmenthe/modernbert-imdb-sentiment
app.py
CHANGED
@@ -1,7 +1,108 @@
```python
import gradio as gr
from inference import SentimentInference
import os
from datasets import load_dataset
import random

# --- Initialize Sentiment Model ---
CONFIG_PATH = os.path.join(os.path.dirname(__file__), "config.yaml")
if not os.path.exists(CONFIG_PATH):
    CONFIG_PATH = "config.yaml"
    if not os.path.exists(CONFIG_PATH):
        raise FileNotFoundError(
            f"Configuration file not found. Tried {os.path.join(os.path.dirname(__file__), 'config.yaml')} and {CONFIG_PATH}. "
            f"Ensure 'config.yaml' exists and is accessible."
        )

print(f"Loading model with config: {CONFIG_PATH}")
try:
    sentiment_inferer = SentimentInference(config_path=CONFIG_PATH)
    print("Sentiment model loaded successfully.")
except Exception as e:
    print(f"Error loading sentiment model: {e}")
    sentiment_inferer = None

# --- Load IMDB Dataset ---
print("Loading IMDB dataset for samples...")
try:
    imdb_dataset = load_dataset("imdb", split="test")
    print("IMDB dataset loaded successfully.")
except Exception as e:
    print(f"Failed to load IMDB dataset: {e}. Sample loading will be disabled.")
    imdb_dataset = None

def load_random_imdb_sample():
    """Loads a random sample text from the IMDB dataset."""
    if imdb_dataset is None:
        return "IMDB dataset not available. Cannot load sample.", None
    random_index = random.randint(0, len(imdb_dataset) - 1)
    sample = imdb_dataset[random_index]
    return sample["text"], sample["label"]

def predict_sentiment(text_input, true_label_state):
    """Predicts sentiment for the given text_input."""
    if sentiment_inferer is None:
        return "Error: Sentiment model could not be loaded. Please check the logs.", true_label_state

    if not text_input or not text_input.strip():
        return "Please enter some text for analysis.", true_label_state

    try:
        prediction = sentiment_inferer.predict(text_input)
        sentiment = prediction['sentiment']

        # Convert numerical label to text if available
        true_sentiment = None
        if true_label_state is not None:
            true_sentiment = "positive" if true_label_state == 1 else "negative"

        result = f"Predicted Sentiment: {sentiment.capitalize()}"
        if true_sentiment:
            result += f"\nTrue IMDB Label: {true_sentiment.capitalize()}"

        return result, None  # Reset true label state after display

    except Exception as e:
        print(f"Error during prediction: {e}")
        return f"Error during prediction: {str(e)}", true_label_state

# --- Gradio Interface ---
with gr.Blocks() as demo:
    true_label = gr.State()

    gr.Markdown("## IMDb Sentiment Analyzer")
    gr.Markdown("Enter a movie review to classify its sentiment as Positive or Negative, or load a random sample from the IMDb dataset.")

    with gr.Row():
        input_textbox = gr.Textbox(lines=7, placeholder="Enter movie review here...", label="Movie Review", scale=3)
        output_text = gr.Text(label="Analysis Result", scale=1)

    with gr.Row():
        submit_button = gr.Button("Analyze Sentiment")
        load_sample_button = gr.Button("Load Random IMDB Sample")

    gr.Examples(
        examples=[
            ["This movie was absolutely fantastic! The acting was superb and the plot was gripping."],
            ["I was really disappointed with this film. It was boring and the story made no sense."],
            ["An average movie, had some good parts but overall quite forgettable."],
            ["Wow so I don't think I've ever seen a movie quite like that. The plot was... interesting, and the acting was, well, hmm."]
        ],
        inputs=input_textbox
    )

    # Wire actions
    submit_button.click(
        fn=predict_sentiment,
        inputs=[input_textbox, true_label],
        outputs=[output_text, true_label]
    )
    load_sample_button.click(
        fn=load_random_imdb_sample,
        inputs=None,
        outputs=[input_textbox, true_label]
    )

if __name__ == '__main__':
    print("Launching Gradio interface...")
    demo.launch(share=False)
```
classifiers.py
ADDED
@@ -0,0 +1,141 @@
```python
from torch import nn
import torch


class ClassifierHead(nn.Module):
    """Basically a fancy MLP: 3-layer classifier head with GELU, LayerNorm, and Skip Connections."""
    def __init__(self, hidden_size, num_labels, dropout_prob):
        super().__init__()
        # Layer 1
        self.dense1 = nn.Linear(hidden_size, hidden_size)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout_prob)

        # Layer 2
        self.dense2 = nn.Linear(hidden_size, hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout2 = nn.Dropout(dropout_prob)

        # Output Layer
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, features):
        # Layer 1
        identity1 = features
        x = self.norm1(features)
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout1(x)
        x = x + identity1  # skip connection

        # Layer 2
        identity2 = x
        x = self.norm2(x)
        x = self.dense2(x)
        x = self.activation(x)
        x = self.dropout2(x)
        x = x + identity2  # skip connection

        # Output Layer
        logits = self.out_proj(x)
        return logits


class ConcatClassifierHead(nn.Module):
    """
    An enhanced classifier head designed for concatenated CLS + Mean Pooling input.
    Includes an initial projection layer before the standard enhanced block.
    """
    def __init__(self, input_size, hidden_size, num_labels, dropout_prob):
        super().__init__()
        # Initial projection from concatenated size (2*hidden) down to hidden_size
        self.initial_projection = nn.Linear(input_size, hidden_size)
        self.initial_norm = nn.LayerNorm(hidden_size)  # Norm after projection
        self.initial_activation = nn.GELU()
        self.initial_dropout = nn.Dropout(dropout_prob)

        # Layer 1
        self.dense1 = nn.Linear(hidden_size, hidden_size)
        self.norm1 = nn.LayerNorm(hidden_size)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout_prob)

        # Layer 2
        self.dense2 = nn.Linear(hidden_size, hidden_size)
        self.norm2 = nn.LayerNorm(hidden_size)
        self.dropout2 = nn.Dropout(dropout_prob)

        # Output Layer
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, features):
        # Initial Projection Step
        x = self.initial_projection(features)
        x = self.initial_norm(x)
        x = self.initial_activation(x)
        x = self.initial_dropout(x)
        # x should now be of shape (batch_size, hidden_size)

        # Layer 1 + Skip
        identity1 = x  # Skip connection starts after initial projection
        x_res = self.norm1(x)
        x_res = self.dense1(x_res)
        x_res = self.activation(x_res)
        x_res = self.dropout1(x_res)
        x = x + x_res  # skip connection

        # Layer 2 + Skip
        identity2 = x
        x_res = self.norm2(x)
        x_res = self.dense2(x_res)
        x_res = self.activation(x_res)
        x_res = self.dropout2(x_res)
        x = x + x_res  # skip connection

        # Output Layer
        logits = self.out_proj(x)
        return logits


# ExpansionClassifierHead currently not used
class ExpansionClassifierHead(nn.Module):
    """
    A classifier head using FFN-style expansion (input -> 4*hidden -> hidden -> labels).
    Takes concatenated CLS + Mean Pooled features as input.
    """
    def __init__(self, input_size, hidden_size, num_labels, dropout_prob):
        super().__init__()
        intermediate_size = hidden_size * 4  # FFN expansion factor

        # Layer 1 (Expansion)
        self.norm1 = nn.LayerNorm(input_size)
        self.dense1 = nn.Linear(input_size, intermediate_size)
        self.activation = nn.GELU()
        self.dropout1 = nn.Dropout(dropout_prob)

        # Layer 2 (Projection back down)
        self.norm2 = nn.LayerNorm(intermediate_size)
        self.dense2 = nn.Linear(intermediate_size, hidden_size)
        # Activation and Dropout applied after projection
        self.dropout2 = nn.Dropout(dropout_prob)

        # Output Layer
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, features):
        # Layer 1
        x = self.norm1(features)
        x = self.dense1(x)
        x = self.activation(x)
        x = self.dropout1(x)

        # Layer 2
        x = self.norm2(x)
        x = self.dense2(x)
        x = self.activation(x)
        x = self.dropout2(x)

        # Output Layer
        logits = self.out_proj(x)
        return logits
```
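A quick shape check of the two heads used by the model; this is a sketch, not part of the commit, and the values `hidden_size=768` and `num_labels=1` are assumptions matching ModernBERT-base and the single-logit setup in `inference.py` below.

```python
import torch
from classifiers import ClassifierHead, ConcatClassifierHead

hidden_size, num_labels, batch = 768, 1, 4  # assumed values for illustration
cls_head = ClassifierHead(hidden_size, num_labels, dropout_prob=0.1)
concat_head = ConcatClassifierHead(input_size=2 * hidden_size, hidden_size=hidden_size,
                                   num_labels=num_labels, dropout_prob=0.1)

pooled = torch.randn(batch, hidden_size)              # e.g. mean-pooled features
concatenated = torch.randn(batch, 2 * hidden_size)    # e.g. CLS + mean pooling concatenated

print(cls_head(pooled).shape)           # torch.Size([4, 1])
print(concat_head(concatenated).shape)  # torch.Size([4, 1])
```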
config.yaml
ADDED
@@ -0,0 +1,12 @@
```yaml
model:
  name: "voxmenthe/modernbert-imdb-sentiment"
  output_dir: "checkpoints"
  max_length: 880 # 256
  dropout: 0.1
  pooling_strategy: "mean" # Current default, change as needed

inference:
  # Default path, can be overridden
  model_path: "checkpoints/mean_epoch5_0.9575acc_0.9575f1.pt"
  # Using the same max_length as training for consistency
  max_length: 880 # 256
```
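As a minimal sketch of how this file is consumed (the `SentimentInference` class in `inference.py` below does the same thing in full):

```python
import yaml

with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

base_model_name = cfg["model"]["name"]            # "voxmenthe/modernbert-imdb-sentiment"
checkpoint_path = cfg["inference"]["model_path"]  # "checkpoints/mean_epoch5_0.9575acc_0.9575f1.pt"
max_length = cfg["inference"]["max_length"]       # 880
```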
inference.py
ADDED
@@ -0,0 +1,79 @@
```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from models import ModernBertForSentiment
from transformers import ModernBertConfig
from typing import Dict, Any
import yaml
import os


class SentimentInference:
    def __init__(self, config_path: str = "config.yaml"):
        """Load configuration and initialize model and tokenizer."""
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)

        model_cfg = config.get('model', {})
        inference_cfg = config.get('inference', {})

        # Path to the .pt model weights file
        model_weights_path = inference_cfg.get('model_path',
                                               os.path.join(model_cfg.get('output_dir', 'checkpoints'), 'best_model.pt'))

        # Base model name from config (e.g., 'answerdotai/ModernBERT-base')
        # This will be used for loading both tokenizer and base BERT config from Hugging Face Hub
        base_model_name = model_cfg.get('name', 'answerdotai/ModernBERT-base')

        self.max_length = inference_cfg.get('max_length', model_cfg.get('max_length', 256))

        # Load tokenizer from the base model name (e.g., from Hugging Face Hub)
        print(f"Loading tokenizer from: {base_model_name}")
        self.tokenizer = AutoTokenizer.from_pretrained(base_model_name)

        # Load base BERT config from the base model name
        print(f"Loading ModernBertConfig from: {base_model_name}")
        bert_config = ModernBertConfig.from_pretrained(base_model_name)

        # --- Apply any necessary overrides from your config to the loaded bert_config ---
        # For example, if your ModernBertForSentiment expects specific config values beyond the base BERT model.
        # Your current ModernBertForSentiment takes the entire config object, which might implicitly carry these.
        # However, explicitly setting them on bert_config loaded from HF is safer if they are architecturally relevant.
        bert_config.classifier_dropout = model_cfg.get('dropout', bert_config.classifier_dropout)  # Example
        # Ensure num_labels is set if your inference model needs it (usually for HF pipeline, less so for manual predict)
        # bert_config.num_labels = model_cfg.get('num_labels', 1)  # Typically 1 for binary sentiment regression-style output

        # It's also important that pooling_strategy and num_weighted_layers are set on the config object
        # that ModernBertForSentiment receives, as it uses these to build its layers.
        # These are usually fine-tuning specific, not part of the base HF config, so they should come from your model_cfg.
        bert_config.pooling_strategy = model_cfg.get('pooling_strategy', 'cls')
        bert_config.num_weighted_layers = model_cfg.get('num_weighted_layers', 4)
        bert_config.loss_function = model_cfg.get('loss_function', {'name': 'SentimentWeightedLoss', 'params': {}})  # Needed by model init
        # Ensure num_labels is explicitly set for the model's classifier head
        bert_config.num_labels = 1  # For sentiment (positive/negative) often treated as 1 logit output

        print("Instantiating ModernBertForSentiment model structure...")
        self.model = ModernBertForSentiment(bert_config)

        print(f"Loading model weights from local checkpoint: {model_weights_path}")
        # Load the entire checkpoint dictionary first
        checkpoint = torch.load(model_weights_path, map_location=torch.device('cpu'))

        # Extract the model_state_dict from the checkpoint
        # This handles the case where the checkpoint saves more than just the model weights (e.g., optimizer state, epoch)
        if 'model_state_dict' in checkpoint:
            model_state_to_load = checkpoint['model_state_dict']
        else:
            # If the checkpoint is just the state_dict itself (older format or different saving convention)
            model_state_to_load = checkpoint

        self.model.load_state_dict(model_state_to_load)
        self.model.eval()
        print("Model loaded successfully.")

    def predict(self, text: str) -> Dict[str, Any]:
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=self.max_length)
        with torch.no_grad():
            outputs = self.model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
        logits = outputs["logits"]
        prob = torch.sigmoid(logits).item()
        return {"sentiment": "positive" if prob > 0.5 else "negative", "confidence": prob}
```
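A minimal usage sketch, assuming `config.yaml` and the local checkpoint it points to are present:

```python
from inference import SentimentInference

inferer = SentimentInference(config_path="config.yaml")
result = inferer.predict("This movie was fantastic!")
print(result)  # e.g. {'sentiment': 'positive', 'confidence': 0.98}
```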
models.py
ADDED
@@ -0,0 +1,172 @@
```python
from transformers import ModernBertModel, ModernBertPreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
from torch import nn
import torch
from train_utils import SentimentWeightedLoss, SentimentFocalLoss
import torch.nn.functional as F

from classifiers import ClassifierHead, ConcatClassifierHead


class ModernBertForSentiment(ModernBertPreTrainedModel):
    """ModernBERT encoder with a dynamically configurable classification head and pooling strategy."""

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.bert = ModernBertModel(config)  # Base BERT model, config may have output_hidden_states=True

        # Store pooling strategy from config
        self.pooling_strategy = getattr(config, 'pooling_strategy', 'mean')
        self.num_weighted_layers = getattr(config, 'num_weighted_layers', 4)

        if self.pooling_strategy in ['weighted_layer', 'cls_weighted_concat'] and not config.output_hidden_states:
            # This check is more of an assertion; train.py should set output_hidden_states=True
            raise ValueError(
                "output_hidden_states must be True in BertConfig for weighted_layer pooling."
            )

        # Initialize weights for weighted layer pooling
        if self.pooling_strategy in ['weighted_layer', 'cls_weighted_concat']:
            # num_weighted_layers specifies how many *top* layers of BERT to use.
            # If num_weighted_layers is e.g. 4, we use the last 4 layers.
            self.layer_weights = nn.Parameter(torch.ones(self.num_weighted_layers) / self.num_weighted_layers)

        # Determine classifier input size and choose head
        classifier_input_size = config.hidden_size
        if self.pooling_strategy in ['cls_mean_concat', 'cls_weighted_concat']:
            classifier_input_size = config.hidden_size * 2

        # Dropout for features fed into the classifier head
        classifier_dropout_prob = (
            config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
        )
        self.features_dropout = nn.Dropout(classifier_dropout_prob)

        # Select the appropriate classifier head based on input feature dimension
        if classifier_input_size == config.hidden_size:
            self.classifier = ClassifierHead(
                hidden_size=config.hidden_size,  # input_size for ClassifierHead is just hidden_size
                num_labels=config.num_labels,
                dropout_prob=classifier_dropout_prob
            )
        elif classifier_input_size == config.hidden_size * 2:
            self.classifier = ConcatClassifierHead(
                input_size=config.hidden_size * 2,
                hidden_size=config.hidden_size,  # Internal hidden size of the head
                num_labels=config.num_labels,
                dropout_prob=classifier_dropout_prob
            )
        else:
            # This case should ideally not be reached with current strategies
            raise ValueError(f"Unexpected classifier_input_size: {classifier_input_size}")

        # Initialize loss function based on config
        loss_config = getattr(config, 'loss_function', {'name': 'SentimentWeightedLoss', 'params': {}})
        loss_name = loss_config.get('name', 'SentimentWeightedLoss')
        loss_params = loss_config.get('params', {})

        if loss_name == "SentimentWeightedLoss":
            self.loss_fct = SentimentWeightedLoss()  # SentimentWeightedLoss takes no arguments
        elif loss_name == "SentimentFocalLoss":
            # Ensure only relevant params are passed, or that loss_params is structured correctly for SentimentFocalLoss
            # For SentimentFocalLoss, expected params are 'gamma_focal' and 'label_smoothing_epsilon'
            self.loss_fct = SentimentFocalLoss(**loss_params)
        else:
            raise ValueError(f"Unsupported loss function: {loss_name}")

        self.post_init()  # Initialize weights and apply final processing

    def _mean_pool(self, last_hidden_state, attention_mask):
        if attention_mask is None:
            attention_mask = torch.ones_like(last_hidden_state[:, :, 0])  # Assuming first dim of last hidden state is token ids
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sum_embeddings / sum_mask

    def _weighted_layer_pool(self, all_hidden_states):
        # all_hidden_states includes embeddings + output of each layer.
        # We want the outputs of the last num_weighted_layers.
        # Example: 12 layers -> all_hidden_states have 13 items (embeddings + 12 layers)
        # num_weighted_layers = 4 -> use layers 9, 10, 11, 12 (indices -4, -3, -2, -1)
        layers_to_weigh = torch.stack(all_hidden_states[-self.num_weighted_layers:], dim=0)
        # layers_to_weigh shape: (num_weighted_layers, batch_size, sequence_length, hidden_size)

        # Normalize weights to sum to 1 (softmax or simple division)
        normalized_weights = F.softmax(self.layer_weights, dim=-1)

        # Weighted sum across layers
        # Reshape weights for broadcasting: (num_weighted_layers, 1, 1, 1)
        weighted_hidden_states = layers_to_weigh * normalized_weights.view(-1, 1, 1, 1)
        weighted_sum_hidden_states = torch.sum(weighted_hidden_states, dim=0)
        # weighted_sum_hidden_states shape: (batch_size, sequence_length, hidden_size)

        # Pool the result (e.g., take [CLS] token of this weighted sum)
        return weighted_sum_hidden_states[:, 0]  # Return CLS token of the weighted sum

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        labels=None,
        lengths=None,
        return_dict=None,
        **kwargs
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        bert_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            return_dict=return_dict,
            output_hidden_states=self.config.output_hidden_states  # Controlled by train.py
        )

        last_hidden_state = bert_outputs[0]  # Or bert_outputs.last_hidden_state
        pooled_features = None

        if self.pooling_strategy == 'cls':
            pooled_features = last_hidden_state[:, 0]  # CLS token
        elif self.pooling_strategy == 'mean':
            pooled_features = self._mean_pool(last_hidden_state, attention_mask)
        elif self.pooling_strategy == 'cls_mean_concat':
            cls_output = last_hidden_state[:, 0]
            mean_output = self._mean_pool(last_hidden_state, attention_mask)
            pooled_features = torch.cat((cls_output, mean_output), dim=1)
        elif self.pooling_strategy == 'weighted_layer':
            if not self.config.output_hidden_states or bert_outputs.hidden_states is None:
                raise ValueError("Weighted layer pooling requires output_hidden_states=True and hidden_states in BERT output.")
            all_hidden_states = bert_outputs.hidden_states
            pooled_features = self._weighted_layer_pool(all_hidden_states)
        elif self.pooling_strategy == 'cls_weighted_concat':
            if not self.config.output_hidden_states or bert_outputs.hidden_states is None:
                raise ValueError("Weighted layer pooling requires output_hidden_states=True and hidden_states in BERT output.")
            cls_output = last_hidden_state[:, 0]
            all_hidden_states = bert_outputs.hidden_states
            weighted_output = self._weighted_layer_pool(all_hidden_states)
            pooled_features = torch.cat((cls_output, weighted_output), dim=1)
        else:
            raise ValueError(f"Unknown pooling_strategy: {self.pooling_strategy}")

        pooled_features = self.features_dropout(pooled_features)
        logits = self.classifier(pooled_features)

        loss = None
        if labels is not None:
            if lengths is None:
                raise ValueError("lengths must be provided when labels are specified for loss calculation.")
            loss = self.loss_fct(logits.squeeze(-1), labels, lengths)

        if not return_dict:
            # Ensure 'outputs' from BERT is appropriately handled. If it's a tuple:
            bert_model_outputs = bert_outputs[1:] if isinstance(bert_outputs, tuple) else (bert_outputs.hidden_states, bert_outputs.attentions)
            output = (logits,) + bert_model_outputs
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_outputs.hidden_states,
            attentions=bert_outputs.attentions,
        )
```
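A sketch of building the model standalone, mirroring how `inference.py` attaches the fine-tuning-specific attributes to a `ModernBertConfig` before instantiation; the attribute values here are assumptions taken from `config.yaml`:

```python
from transformers import ModernBertConfig
from models import ModernBertForSentiment

config = ModernBertConfig.from_pretrained("answerdotai/ModernBERT-base")
config.pooling_strategy = "mean"        # or "cls", "cls_mean_concat", "weighted_layer", "cls_weighted_concat"
config.num_weighted_layers = 4          # only used by the weighted_layer strategies
config.classifier_dropout = 0.1
config.num_labels = 1                   # single sigmoid logit, as in inference.py
config.loss_function = {"name": "SentimentWeightedLoss", "params": {}}

model = ModernBertForSentiment(config)  # head is randomly initialized until a checkpoint is loaded
```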
requirements.txt
ADDED
@@ -0,0 +1,18 @@
```
gradio
ipykernel
ipywidgets
tqdm
kagglehub
transformers>=4.51.3,<5.0.0
torch>=2.7.0,<2.8.0
datasets>=2.16.1,<2.17.0
markdown>=3.7.0,<4.0.0
matplotlib>=3.9.0,<4.0.0
notebook>=7.2.0,<8.0.0
numpy>=2.1.0,<3.0.0
pandas>=2.2.0,<3.0.0
python-json-logger>=2.0.7,<3.0.0
requests>=2.27.1,<3.0.0
scikit-learn>=1.5.0
seaborn>=0.13.0
weasyprint
```
src/config.yaml
ADDED
@@ -0,0 +1,46 @@
```yaml
model:
  name: "voxmenthe/modernbert-imdb-sentiment"
  loss_function:
    name: "SentimentWeightedLoss" # Options: "SentimentWeightedLoss", "SentimentFocalLoss"
    # Parameters for the chosen loss function.
    # For SentimentFocalLoss, common params are:
    # gamma_focal: 1.0 # (e.g., 2.0 for standard, -2.0 for reversed, 0 for none)
    # label_smoothing_epsilon: 0.05 # (e.g., 0.0 to 0.1)
    # For SentimentWeightedLoss, params is empty:
    params:
      gamma_focal: 1.0
      label_smoothing_epsilon: 0.05
  output_dir: "checkpoints"
  max_length: 880 # 256
  dropout: 0.1
  # --- Pooling Strategy --- #
  # Options: "cls", "mean", "cls_mean_concat", "weighted_layer", "cls_weighted_concat"
  # "cls" uses just the [CLS] token for classification
  # "mean" uses mean pooling over final hidden states for classification
  # "cls_mean_concat" uses both [CLS] and mean pooling over final hidden states for classification
  # "weighted_layer" uses a weighted combination of the final hidden states from the top N layers for classification
  # "cls_weighted_concat" uses a weighted combination of the final hidden states from the top N layers and the [CLS] token for classification

  pooling_strategy: "mean" # Current default, change as needed

  num_weighted_layers: 6 # Number of top BERT layers to use for 'weighted_layer' strategies (e.g., 1 to 12 for BERT-base)

data:
  # No specific data paths needed as we use HF datasets at the moment

training:
  epochs: 6
  batch_size: 16
  lr: 1e-5 # 1e-5 # 2.0e-5
  weight_decay_rate: 0.02 # 0.01
  resume_from_checkpoint: "" # "checkpoints/mean_epoch2_0.9361acc_0.9355f1.pt" # Path to checkpoint file, or empty to not resume

inference:
  # Default path, can be overridden
  model_path: "checkpoints/mean_epoch5_0.9575acc_0.9575f1.pt"
  # Using the same max_length as training for consistency
  max_length: 880 # 256


# "answerdotai/ModernBERT-base"
# "answerdotai/ModernBERT-large"
```
train_utils.py
ADDED
@@ -0,0 +1,156 @@
```python
import math
from torch import nn
import torch
import torch.nn.functional as F


class SentimentWeightedLoss(nn.Module):
    """BCEWithLogits + dynamic weighting.

    We weight each sample by:
      • length_weight: sqrt(num_tokens) / sqrt(max_tokens)
      • confidence_weight: |sigmoid(logits) - 0.5| (higher confidence ⇒ larger weight)

    The two weights are combined multiplicatively then normalized.
    """

    def __init__(self):
        super().__init__()
        # Initialize BCE loss without reduction, since we're applying per-sample weights
        self.bce = nn.BCEWithLogitsLoss(reduction="none")
        self.min_len_weight_sqrt = 0.1  # Minimum length weight

    def forward(self, logits, targets, lengths):
        base_loss = self.bce(logits.view(-1), targets.float())  # shape [B]

        prob = torch.sigmoid(logits.view(-1))
        confidence_weight = (prob - 0.5).abs() * 2  # ∈ [0,1]

        if lengths.numel() == 0:
            # Handle empty batch: return 0.0 loss or mean of base_loss if it's also empty (becomes nan then)
            # If base_loss on empty input is empty tensor, mean is nan. So return 0.0 is safer.
            return torch.tensor(0.0, device=logits.device, requires_grad=logits.requires_grad)

        length_weight = torch.sqrt(lengths.float()) / math.sqrt(lengths.max().item())
        length_weight = length_weight.clamp(self.min_len_weight_sqrt, 1.0)  # Clamp to avoid extreme weights

        weights = confidence_weight * length_weight
        weights = weights / (weights.mean() + 1e-8)  # normalize so E[w]=1
        return (base_loss * weights).mean()


class SentimentFocalLoss(nn.Module):
    """
    This loss function incorporates:
    1. Base BCEWithLogitsLoss.
    2. Label Smoothing.
    3. Focal Loss modulation to focus more on hard examples (can be reversed to focus on easy examples).
    4. Sample weighting based on review length.
    5. Sample weighting based on prediction confidence.

    The final loss for each sample is calculated roughly as:
    Loss_sample = FocalModulator(pt, gamma) * BCE(logits, smoothed_targets) * NormalizedExternalWeight
    NormalizedExternalWeight = (ConfidenceWeight * LengthWeight) / Mean(ConfidenceWeight * LengthWeight)
    """

    def __init__(self, gamma_focal: float = 0.1, label_smoothing_epsilon: float = 0.05):
        """
        Args:
            gamma_focal (float): Gamma parameter for Focal Loss.
                - If gamma_focal > 0 (e.g., 2.0), applies standard Focal Loss,
                  down-weighting easy examples (focus on hard examples).
                - If gamma_focal < 0 (e.g., -2.0), applies a reversed Focal Loss,
                  down-weighting hard examples (focus on easy examples by up-weighting pt).
                - If gamma_focal = 0, no Focal Loss modulation is applied.
            label_smoothing_epsilon (float): Epsilon for label smoothing. (0.0 <= epsilon < 1.0)
                - If 0.0, no label smoothing is applied. Converts hard labels (0, 1)
                  to soft labels (epsilon, 1-epsilon).
        """
        super().__init__()
        if not (0.0 <= label_smoothing_epsilon < 1.0):
            raise ValueError("label_smoothing_epsilon must be between 0.0 and <1.0.")

        self.gamma_focal = gamma_focal
        self.label_smoothing_epsilon = label_smoothing_epsilon
        # Initialize BCE loss without reduction, since we're applying per-sample weights
        self.bce_loss_no_reduction = nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, logits: torch.Tensor, targets: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
        """
        Computes the custom loss.

        Args:
            logits (torch.Tensor): Raw logits from the model. Expected shape [B] or [B, 1].
            targets (torch.Tensor): Ground truth labels (0 or 1). Expected shape [B] or [B, 1].
            lengths (torch.Tensor): Number of tokens in each review. Expected shape [B].

        Returns:
            torch.Tensor: The computed scalar loss.
        """
        B = logits.size(0)
        if B == 0:  # Handle empty batch case
            return torch.tensor(0.0, device=logits.device, requires_grad=True)

        logits_flat = logits.view(-1)
        original_targets_flat = targets.view(-1).float()  # Ensure targets are float

        # 1. Label Smoothing
        if self.label_smoothing_epsilon > 0:
            # Smooth 1 to (1 - epsilon), and 0 to epsilon
            targets_for_bce = original_targets_flat * (1.0 - self.label_smoothing_epsilon) + \
                              (1.0 - original_targets_flat) * self.label_smoothing_epsilon
        else:
            targets_for_bce = original_targets_flat

        # 2. Calculate Base BCE loss terms (using potentially smoothed targets)
        base_bce_loss_terms = self.bce_loss_no_reduction(logits_flat, targets_for_bce)

        # 3. Focal Loss Modulation Component
        # For the focal modulator, 'pt' is the probability assigned by the model to the *original* ground truth class.
        probs = torch.sigmoid(logits_flat)
        # pt: probability of the original true class
        pt = torch.where(original_targets_flat.bool(), probs, 1.0 - probs)

        focal_modulator = torch.ones_like(pt)  # Default to 1 (no modulation if gamma_focal is 0)
        if self.gamma_focal > 0:  # Standard Focal Loss: (1-pt)^gamma. Focus on hard examples (pt is small).
            focal_modulator = (1.0 - pt + 1e-8).pow(self.gamma_focal)  # Epsilon for stability if pt is 1
        elif self.gamma_focal < 0:  # Reversed Focal: (pt)^|gamma|. Focus on easy examples (pt is large).
            focal_modulator = (pt + 1e-8).pow(abs(self.gamma_focal))  # Epsilon for stability if pt is 0

        modulated_loss_terms = focal_modulator * base_bce_loss_terms

        # 4. Confidence Weighting (based on how far probability is from 0.5)
        # Uses the same `probs` calculated for focal `pt`.
        confidence_w = (probs - 0.5).abs() * 2.0  # Scales to range [0, 1]

        # 5. Length Weighting (longer reviews potentially weighted more)
        lengths_flat = lengths.view(-1).float()
        max_len_in_batch = lengths_flat.max().item()

        if max_len_in_batch == 0:  # Edge case: if all reviews in batch have 0 length
            length_w = torch.ones_like(lengths_flat)
        else:
            # Normalize by sqrt of max length in the current batch. Add epsilon for stability.
            length_w = torch.sqrt(lengths_flat) / (math.sqrt(max_len_in_batch) + 1e-8)
        length_w = torch.clamp(length_w, 0.0, 1.0)  # Ensure weights are capped at 1

        # 6. Combine External Weights (Confidence and Length)
        # These weights are applied ON TOP of the focal-modulated loss terms.
        external_weights = confidence_w * length_w

        # Normalize these combined external_weights so their mean is approximately 1.
        # This prevents the weighting scheme from drastically changing the overall loss magnitude.
        if external_weights.sum() > 1e-8:  # Avoid division by zero if all weights are zero
            normalized_external_weights = external_weights / (external_weights.mean() + 1e-8)
        else:  # If all external weights are zero, use ones to not nullify the loss.
            normalized_external_weights = torch.ones_like(external_weights)

        # 7. Apply Normalized External Weights to the (Focal) Modulated Loss Terms
        final_loss_terms_per_sample = modulated_loss_terms * normalized_external_weights

        # 8. Final Reduction: Mean of the per-sample losses
        loss = final_loss_terms_per_sample.mean()

        return loss
```
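Both losses take `(logits, targets, lengths)`, matching the call `self.loss_fct(logits.squeeze(-1), labels, lengths)` in `models.py`. A quick sketch with dummy data:

```python
import torch
from train_utils import SentimentWeightedLoss, SentimentFocalLoss

logits = torch.randn(8)                      # one raw logit per review
labels = torch.randint(0, 2, (8,)).float()   # 0 = negative, 1 = positive
lengths = torch.randint(10, 400, (8,))       # token counts per review

print(SentimentWeightedLoss()(logits, labels, lengths))
print(SentimentFocalLoss(gamma_focal=1.0, label_smoothing_epsilon=0.05)(logits, labels, lengths))
```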
upload_to_hf.py
ADDED
@@ -0,0 +1,110 @@
```python
from huggingface_hub import HfApi, upload_folder, create_repo
from transformers import AutoTokenizer, AutoConfig
import os
import shutil
import tempfile

# --- Configuration ---
HUGGING_FACE_USERNAME = "voxmenthe"  # Your Hugging Face username
MODEL_NAME_ON_HF = "modernbert-imdb-sentiment"  # The name of the model on Hugging Face
REPO_ID = f"{HUGGING_FACE_USERNAME}/{MODEL_NAME_ON_HF}"

# Original base model from which the tokenizer and initial config were derived
ORIGINAL_BASE_MODEL_NAME = "answerdotai/ModernBERT-base"

# Local path to your fine-tuned model checkpoint
LOCAL_MODEL_CHECKPOINT_DIR = "checkpoints"
FINE_TUNED_MODEL_FILENAME = "mean_epoch5_0.9575acc_0.9575f1.pt"  # Your best checkpoint
# If your fine-tuned model is just a .pt file, ensure you also have a config.json for ModernBertForSentiment
# For simplicity, we'll re-save the config from the fine-tuned model structure if possible, or from original base.

# Files from your project to include (e.g., custom model code, inference script)
# The user has moved these to the root directory.
PROJECT_FILES_TO_UPLOAD = [
    "config.yaml",
    "inference.py",
    "models.py",
    "train_utils.py",
    "classifiers.py",
    "README.md"
]

def upload_model_and_tokenizer():
    api = HfApi()

    # Create the repository on Hugging Face Hub (if it doesn't exist)
    print(f"Creating repository {REPO_ID} on Hugging Face Hub...")
    create_repo(repo_id=REPO_ID, repo_type="model", exist_ok=True)

    # Create a temporary directory to gather all files for upload
    with tempfile.TemporaryDirectory() as tmp_upload_dir:
        print(f"Created temporary directory for upload: {tmp_upload_dir}")

        # 1. Save tokenizer files from the ORIGINAL_BASE_MODEL_NAME
        print(f"Saving tokenizer from {ORIGINAL_BASE_MODEL_NAME} to {tmp_upload_dir}...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(ORIGINAL_BASE_MODEL_NAME)
            tokenizer.save_pretrained(tmp_upload_dir)
            print("Tokenizer files saved.")
        except Exception as e:
            print(f"Error saving tokenizer from {ORIGINAL_BASE_MODEL_NAME}: {e}")
            print("Please ensure this model name is correct and accessible.")
            return

        # 2. Save base model config.json (architecture) from ORIGINAL_BASE_MODEL_NAME
        # This is crucial for AutoModelForSequenceClassification.from_pretrained(REPO_ID) to work.
        print(f"Saving model config.json from {ORIGINAL_BASE_MODEL_NAME} to {tmp_upload_dir}...")
        try:
            config = AutoConfig.from_pretrained(ORIGINAL_BASE_MODEL_NAME)
            # If your fine-tuned ModernBertForSentiment has specific architectural changes in its config
            # that are NOT automatically handled by loading the state_dict (e.g. num_labels if not standard),
            # you might need to update 'config' here before saving.
            # For now, we assume the base config is sufficient or your model's state_dict handles it.
            config.save_pretrained(tmp_upload_dir)
            print("Model config.json saved.")
        except Exception as e:
            print(f"Error saving config.json from {ORIGINAL_BASE_MODEL_NAME}: {e}")
            return

        # 3. Copy fine-tuned model checkpoint to temporary directory
        # The fine-tuned weights should be named 'pytorch_model.bin' or 'model.safetensors' for HF to auto-load.
        # Or, your config.json in the repo should point to the custom name.
        # For simplicity, we'll rename it to HF standard name of pytorch_model.bin.
        local_checkpoint_path = os.path.join(LOCAL_MODEL_CHECKPOINT_DIR, FINE_TUNED_MODEL_FILENAME)
        if os.path.exists(local_checkpoint_path):
            hf_model_path = os.path.join(tmp_upload_dir, "pytorch_model.bin")
            shutil.copyfile(local_checkpoint_path, hf_model_path)
            print(f"Copied fine-tuned model {FINE_TUNED_MODEL_FILENAME} to {hf_model_path}.")
        else:
            print(f"Error: Fine-tuned model checkpoint {local_checkpoint_path} not found.")
            return

        # 4. Copy other project files
        for project_file in PROJECT_FILES_TO_UPLOAD:
            local_project_file_path = project_file  # Files are now at the root
            if os.path.exists(local_project_file_path):
                shutil.copy(local_project_file_path, os.path.join(tmp_upload_dir, os.path.basename(project_file)))
                print(f"Copied project file {project_file} to {tmp_upload_dir}.")
            else:
                print(f"Warning: Project file {project_file} not found at {local_project_file_path}.")

        # 5. Upload the contents of the temporary directory
        print(f"Uploading all files from {tmp_upload_dir} to {REPO_ID}...")
        try:
            upload_folder(
                folder_path=tmp_upload_dir,
                repo_id=REPO_ID,
                repo_type="model",
                commit_message=f"Upload fine-tuned model, tokenizer, and supporting files for {MODEL_NAME_ON_HF}"
            )
            print("All files uploaded successfully!")
        except Exception as e:
            print(f"Error uploading folder to Hugging Face Hub: {e}")

if __name__ == "__main__":
    # Make sure you are logged in to Hugging Face CLI:
    # Run `huggingface-cli login` or `huggingface-cli login --token YOUR_HF_WRITE_TOKEN` in your terminal first.
    print("Starting upload process...")
    print(f"Target Hugging Face Repo ID: {REPO_ID}")
    print("Ensure you have run 'huggingface-cli login' with a write token.")
    upload_model_and_tokenizer()
```
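The script assumes `huggingface-cli login` has already been run with a write token. A possible alternative is to authenticate programmatically before invoking it; the `HF_TOKEN` environment variable in this sketch is an assumption, not something the script defines:

```python
import os
from huggingface_hub import login

login(token=os.environ["HF_TOKEN"])  # assumes a write token exported as HF_TOKEN

from upload_to_hf import upload_model_and_tokenizer
upload_model_and_tokenizer()
```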