# model_loader.py
from transformers import AutoModelForSequenceClassification, AutoTokenizer
def load_model_and_tokenizer():
    """
    Load the fine-tuned XLM-RoBERTa model and tokenizer.
    Returns the model and tokenizer for use in classification.
    """
    try:
        model_name = "your_username/xlm-roberta-toxic-classifier"  # Replace with your model repo ID
        # If the model is local: model_name = "./model"
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        return model, tokenizer
    except Exception as e:
        # Re-raise with context; "from e" preserves the original traceback
        raise RuntimeError(f"Error loading model or tokenizer: {e}") from e
# Load the model and tokenizer once at startup
model, tokenizer = load_model_and_tokenizer()
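
# Usage sketch (illustrative only): classify a single comment with the loaded
# model and tokenizer. The exact label names depend on how the model was
# fine-tuned; check model.config.id2label for the actual mapping rather than
# assuming specific strings like "toxic".
import torch

def classify_comment(text: str) -> str:
    # Tokenize the input and run a forward pass without tracking gradients
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Pick the highest-scoring class and map its index to a human-readable label
    predicted_id = logits.argmax(dim=-1).item()
    return model.config.id2label[predicted_id]

# Example: print(classify_comment("You are awful"))  # e.g. "toxic"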