import time

import pandas as pd
import torch
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

MODEL = "tabularisai/multilingual-sentiment-analysis"

print("Loading model and tokenizer...")
start_load = time.time()

# Prefer the Apple-Silicon GPU (MPS) when available; otherwise fall back to CPU.
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained(MODEL).to(device)
model.eval()  # inference only: disable dropout and other train-time behavior
config = AutoConfig.from_pretrained(MODEL)

load_time = time.time() - start_load
print(f"Model and tokenizer loaded in {load_time:.2f} seconds\n")


def preprocess(text):
    """Normalize a tweet-style string: mask user handles as '@user' and URLs
    as 'http', and coerce NaN/non-string inputs to safe strings."""
    if not isinstance(text, str):
        text = str(text) if not pd.isna(text) else ""

    new_text = []
    for t in text.split(" "):
        t = "@user" if t.startswith("@") and len(t) > 1 else t
        t = "http" if t.startswith("http") else t
        new_text.append(t)
    return " ".join(new_text)


def predict_sentiment_batch(texts: list, batch_size: int = 16) -> list:
    if not isinstance(texts, list):
        raise TypeError(f"Expected list of texts, got {type(texts)}")

    # Keep only non-empty strings. Note: the returned list aligns with these
    # valid texts, not with the raw input, so lengths may differ.
    valid_texts = [text for text in texts if isinstance(text, str) and text.strip()]
    if not valid_texts:
        return []

    print(f"Processing {len(valid_texts)} valid samples...")
    processed_texts = [preprocess(text) for text in valid_texts]

    predictions = []
    start_pred = time.time()  # time inference separately from model loading
    for i in range(0, len(processed_texts), batch_size):
        batch = processed_texts[i:i + batch_size]
        try:
            inputs = tokenizer(
                batch,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=64,
            ).to(device)

            with torch.no_grad():
                outputs = model(**inputs)

            batch_preds = outputs.logits.argmax(dim=1).cpu().numpy()
            predictions.extend([config.id2label[int(p)] for p in batch_preds])
        except Exception as e:
            print(f"Error processing batch {i // batch_size}: {e}")
            # Fall back to a neutral label for every text in the failed batch.
            predictions.extend(["neutral"] * len(batch))

    print(f"Predictions for {len(valid_texts)} samples generated in {time.time() - start_pred:.2f} seconds")
    # Normalize labels: lowercase, and fold "very negative"/"very positive"
    # into their base classes so only negative/neutral/positive remain.
    predictions = [prediction.lower().replace("very ", "") for prediction in predictions]

    print(predictions)

    return predictions
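

# Minimal usage sketch (illustrative only; the sample sentences are made up,
# one non-English to exercise the multilingual model):
if __name__ == "__main__":
    sample_texts = [
        "I absolutely love this!",
        "Este producto es terrible.",
        "It's okay, nothing special.",
    ]
    for text, label in zip(sample_texts, predict_sentiment_batch(sample_texts, batch_size=2)):
        print(f"{label:>8}  {text}")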