fahmiaziz98 committed
Commit 7f6c186 · 1 Parent(s): 4eb341e

py 3.9 + torch cpu

Files changed (3):
  1. Dockerfile +2 -2
  2. requirements.txt +2 -1
  3. train_bert.py +177 -0
Dockerfile CHANGED
@@ -1,4 +1,4 @@
-FROM python:3.11.11-slim AS build
+FROM python:3.9-slim AS build
 
 ENV PIP_DEFAULT_TIMEOUT=100 \
     PYTHONUNBUFFERED=1 \
@@ -20,7 +20,7 @@ WORKDIR /app
 COPY --chown=user requirements.txt .
 RUN pip install --no-cache-dir --user -r requirements.txt
 
-FROM python:3.11.11-slim
+FROM python:3.9-slim
 
 RUN apt-get update && apt-get install -y \
     libjpeg-dev \
requirements.txt CHANGED
@@ -1,5 +1,6 @@
 boto3==1.37.27
-torch==2.6.0
+# torch==2.6.0  # uncomment this line (and remove the CPU pin below) to use the GPU build of PyTorch
+torch==2.6.0+cpu --index-url https://download.pytorch.org/whl/cpu
 pillow==11.1.0
 transformers==4.50.3
 fastapi==0.115.12
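After rebuilding the image, a quick sanity check confirms that the CPU-only wheel is the one installed. The snippet below is a hypothetical check, not part of this commit; it only assumes torch imports successfully. One caveat worth flagging: pip normally expects --index-url as a line-level option in a requirements file, so if the inline form above fails to parse, moving the flag onto its own line just before the torch pin is the likely fix.

import torch

# CPU-only wheels carry a "+cpu" local version tag, e.g. "2.6.0+cpu".
print("torch version:", torch.__version__)

# Expected to be False in the CPU-only image; True would mean a CUDA build slipped in.
print("CUDA available:", torch.cuda.is_available())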
train_bert.py ADDED
@@ -0,0 +1,177 @@
+import logging
+import os
+from datetime import datetime
+import pandas as pd
+import numpy as np
+import torch
+from torch.nn import CrossEntropyLoss
+from torch.utils.data import Dataset, DataLoader
+from transformers import (
+    BertConfig,
+    BertForSequenceClassification,
+    BertTokenizer,
+    Trainer,
+    TrainingArguments,
+    EarlyStoppingCallback,
+)
+from sklearn.model_selection import train_test_split
+from sklearn.metrics import (
+    accuracy_score,
+    f1_score,
+    precision_score,
+    recall_score,
+    confusion_matrix,
+)
+from sklearn.utils.class_weight import compute_class_weight
+
+# Setup
+tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+config = BertConfig.from_pretrained("bert-base-uncased", num_labels=2)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+last_confusion_matrix = None
+
+class WeightedBertForSequenceClassification(BertForSequenceClassification):
+    """BERT classifier that applies per-class weights in the cross-entropy loss."""
+
+    def __init__(self, config, class_weights):
+        super().__init__(config)
+        self.class_weights = class_weights
+
+    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
+        # Run the stock forward pass without labels, then compute the weighted loss here.
+        outputs = super().forward(input_ids=input_ids, attention_mask=attention_mask, labels=None, **kwargs)
+        logits = outputs.logits
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss(weight=self.class_weights)
+            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+        return {"loss": loss, "logits": logits}
+
+class SMSClassificationDataset(Dataset):
+    def __init__(self, encodings, labels):
+        self.encodings = encodings
+        self.labels = torch.tensor(labels, dtype=torch.long)
+
+    def __len__(self):
+        return len(self.labels)
+
+    def __getitem__(self, idx):
+        item = {key: val[idx] for key, val in self.encodings.items()}
+        item["labels"] = self.labels[idx]
+        return item
+
+def compute_metrics(eval_pred):
+    # Keep the confusion matrix at module level so train() can write it to the log file.
+    global last_confusion_matrix
+
+    logits, labels = eval_pred
+    predictions = torch.argmax(torch.tensor(logits), dim=1)
+
+    acc = accuracy_score(labels, predictions)
+    precision = precision_score(labels, predictions, average="weighted", zero_division=0)
+    recall = recall_score(labels, predictions, average="weighted")
+    f1 = f1_score(labels, predictions, average='weighted')
+    cm = confusion_matrix(labels, predictions)
+
+    last_confusion_matrix = cm
+
+    return {
+        'accuracy': acc,
+        'precision': precision,
+        'recall': recall,
+        'f1': f1
+    }
+
+def train():
+    # Load and preprocess data
+    df = pd.read_csv('data/spam.csv', encoding='iso-8859-1')[['label', 'text']]
+    df['label'] = df['label'].map({'spam': 1, 'ham': 0})
+
+    # Split into train (70%), validation (15%), test (15%)
+    train_texts, temp_texts, train_labels, temp_labels = train_test_split(
+        df['text'], df['label'], test_size=0.30, random_state=42, stratify=df['label']
+    )
+    val_texts, test_texts, val_labels, test_labels = train_test_split(
+        temp_texts, temp_labels, test_size=0.5, random_state=42, stratify=temp_labels
+    )
+
+    # Compute class weights from training labels
+    class_weights = compute_class_weight(
+        class_weight='balanced',
+        classes=np.unique(train_labels),
+        y=train_labels
+    )
+    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
+
+    # Silence excessive logging
+    for logger in logging.root.manager.loggerDict:
+        if "transformers" in logger.lower():
+            logging.getLogger(logger).setLevel(logging.ERROR)
+
+    # Initialize the weighted model and copy the pretrained weights into it
+    # (strict=False keeps the state-dict copy lenient about mismatched keys).
+    model = WeightedBertForSequenceClassification(config, class_weights=class_weights)
+    model.load_state_dict(BertForSequenceClassification.from_pretrained(
+        "bert-base-uncased", num_labels=2, use_safetensors=True, return_dict=False, attn_implementation="sdpa"
+    ).state_dict(), strict=False)
+    model.to(device)
+
+    # Tokenize
+    train_encodings = tokenizer(train_texts.tolist(), truncation=True, padding=True, return_tensors="pt")
+    val_encodings = tokenizer(val_texts.tolist(), truncation=True, padding=True, return_tensors="pt")
+    test_encodings = tokenizer(test_texts.tolist(), truncation=True, padding=True, return_tensors="pt")
+
+    # Datasets
+    train_dataset = SMSClassificationDataset(train_encodings, train_labels.tolist())
+    val_dataset = SMSClassificationDataset(val_encodings, val_labels.tolist())
+    test_dataset = SMSClassificationDataset(test_encodings, test_labels.tolist())
+
+    # Training setup
+    training_args = TrainingArguments(
+        output_dir='./models/pretrained',
+        num_train_epochs=5,
+        per_device_train_batch_size=8,
+        per_device_eval_batch_size=16,
+        warmup_steps=500,
+        weight_decay=0.01,
+        logging_dir='./logs',
+        logging_steps=10,
+        eval_strategy="epoch",
+        report_to="none",
+        save_total_limit=1,
+        load_best_model_at_end=True,
+        save_strategy="epoch",
+    )
+
+    # Trainer
+    trainer = Trainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        eval_dataset=val_dataset,
+        compute_metrics=compute_metrics,
+        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
+    )
+
+    # Train
+    trainer.train()
+
+    # Save logs (create the directory first so to_csv does not fail)
+    os.makedirs("logs", exist_ok=True)
+    logs = trainer.state.log_history
+    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+    pd.DataFrame(logs).to_csv(f"logs/training_logs_{timestamp}.csv", index=False)
+
+    # Save model and tokenizer
+    tokenizer.save_pretrained('./models/pretrained')
+    model.save_pretrained('./models/pretrained')
+
+    # Final test set evaluation
+    print("\nEvaluating on FINAL TEST SET:")
+    final_test_metrics = trainer.evaluate(eval_dataset=test_dataset)
+    print("Final Test Set Metrics:", final_test_metrics)
+
+    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+    log_filename = f"logs/final_test_results_{timestamp}.txt"
+
+    with open(log_filename, "w") as f:
+        f.write("FINAL TEST SET METRICS\n")
+        for key, value in final_test_metrics.items():
+            f.write(f"{key}: {value}\n")
+
+        f.write("\nCONFUSION MATRIX\n")
+        f.write(str(last_confusion_matrix))
+
+if __name__ == "__main__":
+    train()
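Since the new script writes its fine-tuned weights and tokenizer to ./models/pretrained, a small smoke test can exercise the saved artifacts. The sketch below is hypothetical usage, not part of this commit: it assumes training has already run, and it loads the checkpoint with the stock BertForSequenceClassification head, which works here because the weighted subclass adds no extra parameters of its own.

import torch
from transformers import BertForSequenceClassification, BertTokenizer

model_dir = "./models/pretrained"  # directory train_bert.py saves into
tokenizer = BertTokenizer.from_pretrained(model_dir)
model = BertForSequenceClassification.from_pretrained(model_dir)
model.eval()

# Label ids follow the training script's mapping: {'spam': 1, 'ham': 0}.
id2label = {0: "ham", 1: "spam"}

text = "Congratulations! You have won a free prize. Reply WIN to claim."
inputs = tokenizer(text, truncation=True, padding=True, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

pred = int(torch.argmax(logits, dim=-1))
print(id2label[pred], logits.tolist())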