import os

import torch
import torch.nn as nn
from transformers import AutoModelForSequenceClassification

# Define target allergen columns (Indonesian): susu = milk, kacang = nuts,
# telur = egg, makanan_laut = seafood, gandum = wheat
target_columns = ['susu', 'kacang', 'telur', 'makanan_laut', 'gandum']
# Define model for multilabel classification
class MultilabelBertClassifier(nn.Module):
    def __init__(self, model_name, num_labels):
        super(MultilabelBertClassifier, self).__init__()
        self.bert = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
        # Replace the classification head with our own for multilabel
        self.bert.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.logits
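# Note: the forward pass returns raw logits, one per label. For multilabel
# targets these are typically paired with nn.BCEWithLogitsLoss during training
# and passed through a sigmoid at inference time, e.g. (illustrative sketch,
# not part of this script):
#   loss = nn.BCEWithLogitsLoss()(logits, labels.float())
#   probs = torch.sigmoid(logits)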
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize model with one output per allergen label
model = MultilabelBertClassifier('indobenchmark/indobert-base-p1', len(target_columns))
# Load the best checkpoint for re-export
print("Loading model from best_alergen_model.pt...")
state_dict = torch.load('best_alergen_model.pt', map_location=device)

# If the model was trained with DataParallel, strip the 'module.' prefix
# that DataParallel adds to every parameter name
new_state_dict = {}
for k, v in state_dict.items():
    name = k[7:] if k.startswith('module.') else k  # len('module.') == 7
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)
model.to(device)
# Create model directory
os.makedirs('model', exist_ok=True)

# Save the cleaned state dict together with the label names
print("Saving model to model/alergen_model.pt...")
torch.save({
    'model_state_dict': model.state_dict(),
    'target_columns': target_columns,
}, 'model/alergen_model.pt')
print("Done!")