# toxic-comment-classifier / model_loader.py
# Provenance: Hugging Face repo by JanviMl, commit 583e12e (verified), 824 bytes.
from transformers import AutoModelForSequenceClassification, AutoTokenizer
def load_model_and_tokenizer():
    """
    Load the fine-tuned XLM-RoBERTa toxicity classifier and its tokenizer.

    Downloads (or reads from the local HF cache) the model weights and
    tokenizer files for the configured Hugging Face repo.

    Returns:
        tuple: (model, tokenizer) — an AutoModelForSequenceClassification
        instance and its matching AutoTokenizer, ready for classification.

    Raises:
        RuntimeError: if the model or tokenizer cannot be loaded (e.g. the
        repo is unreachable, missing, or the files are corrupt). The original
        exception is chained as __cause__ for debugging.
    """
    model_name = "JanviMl/xlm-roberta-toxic-classifier-capstone"  # Replace with your model repo ID
    # If the model is local: model_name = "./model"
    try:
        model = AutoModelForSequenceClassification.from_pretrained(model_name)
        # use_fast=False forces the slow (Python) tokenizer; the fast
        # SentencePiece conversion can misbehave for some XLM-R checkpoints.
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
    except Exception as e:
        # RuntimeError is a subclass of Exception, so existing callers that
        # catch Exception still work; `from e` preserves the traceback chain
        # instead of discarding the root cause.
        raise RuntimeError(f"Error loading model or tokenizer: {str(e)}") from e
    return model, tokenizer
# Load the model and tokenizer once at startup
# NOTE: this runs at import time and may download weights from the Hugging
# Face Hub (network I/O). Importers can then use `from model_loader import
# model, tokenizer` without reloading.
model, tokenizer = load_model_and_tokenizer()