import pandas as pd
import time
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig

# Use a sentiment-specific model (replace with a fine-tuned TinyBERT if available)
MODEL = "tabularisai/multilingual-sentiment-analysis"  # Multi-class sentiment; labels are normalized to positive/neutral/negative below

print("Loading model and tokenizer...")
start_load = time.time()

# Check for MPS (Metal) availability on the M2 chip; fall back to CPU otherwise
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

# Load the tokenizer, model, and config once and move the model to the selected device
tokenizer = AutoTokenizer.from_pretrained(MODEL, use_fast=True)
model = AutoModelForSequenceClassification.from_pretrained(MODEL).to(device)
config = AutoConfig.from_pretrained(MODEL)

load_time = time.time() - start_load
print(f"Model and tokenizer loaded in {load_time:.2f} seconds\n")

# Preprocess tweets: mask user mentions as '@user' and URLs as 'http'
def preprocess(text):
    if not isinstance(text, str):
        text = str(text) if not pd.isna(text) else ""
    
    new_text = []
    for t in text.split(" "):
        t = '@user' if t.startswith('@') and len(t) > 1 else t
        t = 'http' if t.startswith('http') else t
        new_text.append(t)
    return " ".join(new_text)

# Batch prediction function (optimized for performance)
def predict_sentiment_batch(texts: list, batch_size: int = 16) -> list:
    if not isinstance(texts, list):
        raise TypeError(f"Expected list of texts, got {type(texts)}")
    
    # Validate and clean inputs; filtering may make the output shorter than the input
    # if some entries are empty or not strings
    valid_texts = [text for text in texts if isinstance(text, str) and text.strip()]
    if not valid_texts:
        return []  # Return empty list if no valid texts
    
    print(f"Processing {len(valid_texts)} valid samples...")
    processed_texts = [preprocess(text) for text in valid_texts]
    
    start_pred = time.time()  # Time only the prediction loop, not model loading
    predictions = []
    for i in range(0, len(processed_texts), batch_size):
        batch = processed_texts[i:i + batch_size]
        try:
            inputs = tokenizer(
                batch,
                padding=True,
                truncation=True,
                return_tensors="pt",
                max_length=64  # Reduced for speed on short texts like tweets
            ).to(device)
            
            with torch.no_grad():
                outputs = model(**inputs)
            
            batch_preds = outputs.logits.argmax(dim=1).cpu().numpy()
            predictions.extend([config.id2label[p] for p in batch_preds])
        except Exception as e:
            print(f"Error processing batch {i // batch_size}: {str(e)}")
            predictions.extend(["neutral"] * len(batch))  # Consider logging instead
        
    print(f"Predictions for {len(valid_texts)} samples generated in {time.time() - start_load:.2f} seconds")
    # Collapse the model's graded labels (e.g. "Very Positive") to positive/neutral/negative
    predictions = [prediction.lower().replace("very ", "") for prediction in predictions]

    print(predictions[:5])  # Sanity check: show the first few predictions
    
    return predictions

# # Example usage with a CSV of tweets (uncomment and adjust the path)
# test_data = pd.read_csv("/Users/caasidev/development/AI/last try/Whatssap-project/srcs/tweets.csv")
# print(f"Processing {len(test_data)} samples...")
# start_prediction = time.time()

# text_samples = test_data['text'].tolist()
# # Note: predict_sentiment_batch drops empty/non-string entries, so this assignment
# # assumes every row has valid text; otherwise the lengths will not match
# test_data['predicted_sentiment'] = predict_sentiment_batch(text_samples)

# prediction_time = time.time() - start_prediction
# time_per_sample = prediction_time / len(test_data)

# # Print runtime statistics
# print("\nRuntime Statistics:")
# print(f"- Model loading time: {load_time:.2f} seconds")
# print(f"- Total prediction time for {len(test_data)} samples: {prediction_time:.2f} seconds")
# print(f"- Average time per sample: {time_per_sample:.4f} seconds")
# print(f"- Estimated time for 1000 samples: {(time_per_sample * 1000):.2f} seconds")
# print(f"- Estimated time for 20000 samples: {(time_per_sample * 20000 / 60):.2f} minutes")

# # Print a sample of predictions
# print("\nPredicted Sentiments (first 5 samples):")
# print(test_data[['text', 'predicted_sentiment']].head())
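
# Minimal smoke test (a sketch; the sample texts below are illustrative, not from the dataset)
if __name__ == "__main__":
    sample_tweets = [
        "I love this! @brand did a great job https://example.com",
        "This is the worst service I have ever used.",
        "The package arrived today.",
    ]
    print(predict_sentiment_batch(sample_tweets, batch_size=2))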