allergen_detector_bert/save_model.py
import os
import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification
# Define target columns
target_columns = ['susu', 'kacang', 'telur', 'makanan_laut', 'gandum']
# Define model for multilabel classification
class MultilabelBertClassifier(nn.Module):
    def __init__(self, model_name, num_labels):
        super(MultilabelBertClassifier, self).__init__()
        self.bert = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
        # Replace the classification head with our own linear layer for multilabel outputs
        self.bert.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.logits
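# Note: forward() returns raw logits. For multilabel prediction these would
# typically be passed through a sigmoid and thresholded, e.g.:
#   probs = torch.sigmoid(logits)
#   preds = (probs > 0.5).int()
# (illustrative sketch only; the 0.5 threshold is an assumption, not part of this script)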
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Initialize model
model = MultilabelBertClassifier('indobenchmark/indobert-base-p1', len(target_columns))
# Load the best model for evaluation
print("Loading model from best_alergen_model.pt...")
state_dict = torch.load('best_alergen_model.pt', map_location=device)
# If the model was trained with DataParallel, we need to remove the 'module.' prefix
new_state_dict = {}
for k, v in state_dict.items():
    name = k[7:] if k.startswith('module.') else k
    new_state_dict[name] = v
model.load_state_dict(new_state_dict)
model.to(device)
# Create model directory
os.makedirs('model', exist_ok=True)
# Save model
print("Saving model to model/alergen_model.pt...")
torch.save({
    'model_state_dict': model.state_dict(),
    'target_columns': target_columns,
}, 'model/alergen_model.pt')
print("Done!")