# File: news_verification/src/application/text/ai_classification.py
# Last commit: 00b1038 by pmkhanh7890
# Commit message: refactor code + fix bug of label after grouping URLs
# File size: 2.39 kB
from typing import (
    Dict,
    List,
    Tuple,
    Union,
)

import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

from src.application.config import AI_TEXT_CLASSIFICATION_MODEL
def load_model_and_tokenizer(
    model_path: str = AI_TEXT_CLASSIFICATION_MODEL,
) -> Tuple[AutoTokenizer, AutoModelForSequenceClassification]:
    """
    Load the text classifier and its tokenizer, ready for inference.

    Args:
        model_path: Location handed to ``from_pretrained`` for both the
            tokenizer and the model. Defaults to
            ``AI_TEXT_CLASSIFICATION_MODEL`` from the application config.

    Returns:
        A ``(tokenizer, model)`` tuple. The model has already been put
        into evaluation mode.
    """
    text_tokenizer = AutoTokenizer.from_pretrained(model_path)
    # ``eval()`` returns the module itself, so the call can be chained.
    classifier = AutoModelForSequenceClassification.from_pretrained(
        model_path,
    ).eval()
    return text_tokenizer, classifier
def predict(
    texts: Union[str, List[str]],
    model: AutoModelForSequenceClassification,
    tokenizer: AutoTokenizer,
) -> List[Dict[str, Union[str, float]]]:
    """
    Classify input texts as written by GPT-4o or GPT-4o mini.

    Args:
        texts: A list of input text strings to be classified. A single
            string is also accepted and is treated as a one-element list.
        model: The loaded model for sequence classification.
        tokenizer: The loaded tokenizer.

    Returns:
        A list with one dictionary per input text, in input order. Each
        dictionary has the keys ``"input"`` (the text), ``"prediction"``
        (``"GPT-4o"`` or ``"GPT-4o mini"``) and ``"confidence"`` (the
        softmax probability of the predicted class, as a float). An empty
        input yields an empty list.

    Raises:
        KeyError: If the model predicts a class index other than 0 or 1.
    """
    label_map = {0: "GPT-4o", 1: "GPT-4o mini"}

    # A bare string would be tokenized as one sequence but zipped character
    # by character below, so "input" would hold only its first character.
    if isinstance(texts, str):
        texts = [texts]
    # Nothing to classify: skip the tokenizer and model calls entirely.
    if not texts:
        return []

    inputs = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        outputs = model(**inputs)
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
        # Highest class probability per text, and the index of that class.
        confidence, predictions = torch.max(probabilities, dim=-1)

    return [
        {"input": text, "prediction": label_map[pred], "confidence": conf}
        for text, pred, conf in zip(
            texts,
            predictions.tolist(),
            confidence.tolist(),
        )
    ]
if __name__ == "__main__":
    # Manual smoke test: classify one sample paragraph and print the result.
    text = (
        "The resignation brings a long political chapter to an end.\n"
        "Trudeau has been in office since 2015, when he brought the Liberals back\n"
        "to power from the political wilderness.\n"
    )
    tokenizer, model = load_model_and_tokenizer("ductuan024/gpts-detector")
    # predict() is declared to take a list of texts, so wrap the single sample.
    predictions = predict([text], model, tokenizer)
    print(predictions[0]["prediction"])
    print(predictions[0]["confidence"])