"""Fine-tune ``bert-base-uncased`` as a binary SMS spam/ham classifier.

Reads ``data/spam.csv`` (columns ``label`` and ``text``), trains a BERT
sequence classifier with class-weighted cross-entropy using the Hugging Face
``Trainer``, then saves the model and tokenizer to ``./models/pretrained`` and
the trainer's log history to ``logs/training_logs_<timestamp>.csv``.
"""
import logging
import os
from datetime import datetime
import re
from collections import Counter

import pandas as pd
import numpy as np
import torch
from torch.nn import CrossEntropyLoss
from torch.utils.data import Dataset, DataLoader
import transformers
from transformers import (
    BertConfig,
    BertForSequenceClassification,
    BertTokenizer,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
)
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
    confusion_matrix,
)
from sklearn.utils.class_weight import compute_class_weight

logger = logging.getLogger(__name__)

# NOTE: loaded at import time (may download from the Hugging Face hub).
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
config = BertConfig.from_pretrained("bert-base-uncased", num_labels=2)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class WeightedBertForSequenceClassification(BertForSequenceClassification):
    """BERT sequence classifier whose training loss is class-weighted cross-entropy.

    Args:
        config: BERT configuration; ``config.num_labels`` is the number of classes.
        class_weights: Per-class loss weights indexed by label id (expected length
            ``config.num_labels``). ``None`` gives an unweighted loss and lets the
            class be re-created via ``from_pretrained`` without extra arguments.
    """

    # The loss below is a per-batch mean. Recent transformers ``Trainer`` versions
    # consult this flag; False stops them from passing ``num_items_in_batch`` and
    # from treating the returned loss as a sum.
    accepts_loss_kwargs = False

    def __init__(self, config, class_weights=None):
        super().__init__(config)
        if class_weights is not None:
            class_weights = torch.as_tensor(class_weights, dtype=torch.float)
        # A buffer moves with the model on .to()/.cuda(), so the weights always sit
        # on the same device as the logits. persistent=False keeps it out of the
        # saved state dict, as the former plain attribute was.
        self.register_buffer("class_weights", class_weights, persistent=False)

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the base classifier and compute the weighted loss if ``labels`` is given.

        Extra keyword arguments are forwarded to ``BertForSequenceClassification.forward``.

        Returns:
            dict with ``"loss"`` (scalar tensor, or ``None`` when ``labels`` is None)
            and ``"logits"`` (the base model's classification logits).
        """
        # Loss-only kwarg some Trainer versions inject; the base forward does not use it.
        kwargs.pop("num_items_in_batch", None)
        # Always request a ModelOutput so ``.logits`` exists even if the config or
        # caller asked for tuple outputs.
        kwargs["return_dict"] = True
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=None,  # the loss is computed below, with class weights
            **kwargs,
        )
        logits = outputs.logits

        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss(weight=self.class_weights)
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
        return {"loss": loss, "logits": logits}


class SMSClassificationDataset(Dataset):
    """Map-style dataset pairing pre-tokenized encodings with integer labels.

    Args:
        encodings: Tokenizer output (mapping of field name -> indexable batch,
            e.g. ``input_ids``, ``attention_mask``).
        labels: Sequence of integer class ids, one per example.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return one example as a dict of its encoding fields plus ``"labels"``."""
        item = {key: val[idx] for key, val in self.encodings.items()}
        item["labels"] = self.labels[idx]
        return item


def compute_metrics(eval_pred):
    """Compute accuracy and weighted precision/recall/F1 for ``Trainer`` evaluation.

    Also prints the confusion matrix.

    Args:
        eval_pred: ``(logits, labels)`` pair as supplied by ``Trainer``.

    Returns:
        dict with ``accuracy``, ``precision``, ``recall`` and ``f1``.
    """
    logits, labels = eval_pred
    # If the model ever returns extra outputs, Trainer hands over a tuple with
    # the logits first.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=-1)

    acc = accuracy_score(labels, predictions)
    # zero_division=0 on all three so undefined scores are reported as 0
    # consistently instead of warning for some metrics only.
    precision = precision_score(labels, predictions, average="weighted", zero_division=0)
    recall = recall_score(labels, predictions, average="weighted", zero_division=0)
    f1 = f1_score(labels, predictions, average="weighted", zero_division=0)
    cm = confusion_matrix(labels, predictions)
    print("Confusion Matrix:\n", cm)
    return {
        'accuracy': acc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }


def _load_labelled_texts(path='data/spam.csv'):
    """Read the SMS CSV and return ``(texts, labels)`` with labels mapped spam=1, ham=0.

    Rows with an unknown label or a missing text are dropped with a warning.
    """
    df = pd.read_csv(path, encoding='iso-8859-1')[['label', 'text']]
    label_mapping = {'spam': 1, 'ham': 0}
    df['label'] = df['label'].map(label_mapping)
    # Unmapped labels become NaN (breaks torch.tensor(..., dtype=long)) and
    # missing texts are NaN floats (breaks the tokenizer).
    bad = df['label'].isna() | df['text'].isna()
    if bad.any():
        logger.warning("Dropping %d row(s) with unknown label or missing text", int(bad.sum()))
        df = df[~bad]
    return df['text'].astype(str).tolist(), df['label'].astype(int).tolist()


def train():
    """Fine-tune the weighted BERT classifier on ``data/spam.csv`` and save the results."""
    texts, labels = _load_labelled_texts('data/spam.csv')

    # Stratify so the (imbalanced) class ratio is the same in both splits.
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        texts, labels, test_size=0.25, random_state=42, stratify=labels)

    # Explicit class ids keep weight i aligned with label id i; compute_class_weight
    # raises if a class is missing from the training labels.
    class_weights = compute_class_weight(
        class_weight='balanced',
        classes=np.arange(config.num_labels),
        y=train_labels,
    )
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

    model = WeightedBertForSequenceClassification(config, class_weights=class_weights)

    # Silence transformers warnings (e.g. about the freshly initialised classifier head).
    transformers.logging.set_verbosity_error()

    # Copy the pretrained encoder weights into the custom model.
    pretrained = BertForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=config.num_labels, use_safetensors=True,
        return_dict=False, attn_implementation="sdpa")
    load_result = model.load_state_dict(pretrained.state_dict(), strict=False)
    del pretrained  # free the second copy of the weights before training
    if load_result.missing_keys or load_result.unexpected_keys:
        logger.warning("Pretrained weight mismatch: missing=%s unexpected=%s",
                       load_result.missing_keys, load_result.unexpected_keys)
    model.to(device)

    train_encodings = tokenizer(train_texts, truncation=True, padding=True, return_tensors="pt")
    val_encodings = tokenizer(val_texts, truncation=True, padding=True, return_tensors="pt")
    train_dataset = SMSClassificationDataset(train_encodings, train_labels)
    val_dataset = SMSClassificationDataset(val_encodings, val_labels)

    training_args = TrainingArguments(
        output_dir='./models/pretrained',
        num_train_epochs=5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=16,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir='./logs',
        logging_steps=10,
        eval_strategy="epoch",
        report_to="none",
        save_total_limit=1,
        load_best_model_at_end=True,
        save_strategy="epoch",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )
    trainer.train()

    # Save the model first so a failure while writing logs cannot lose it.
    tokenizer.save_pretrained('./models/pretrained')
    model.save_pretrained('./models/pretrained')

    df_logs = pd.DataFrame(trainer.state.log_history)
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    # to_csv does not create parent directories.
    os.makedirs('logs', exist_ok=True)
    df_logs.to_csv(f"logs/training_logs_{timestamp}.csv", index=False)


if __name__ == "__main__":
    train()