Zen0 committed (verified)
Commit 1358711 · Parent: daf1822

Update tasks/text.py

Files changed (1)
  1. tasks/text.py +60 -54
tasks/text.py CHANGED
@@ -1,10 +1,9 @@
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
 from fastapi import APIRouter
 from datetime import datetime
 from datasets import load_dataset
 from sklearn.metrics import accuracy_score
-from sklearn.feature_extraction.text import TfidfVectorizer
-from sklearn.linear_model import LogisticRegression
-from sklearn.pipeline import Pipeline
+import torch
 import numpy as np
 
 from .utils.evaluation import TextEvaluationRequest
@@ -12,27 +11,9 @@ from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
 router = APIRouter()
 
-DESCRIPTION = "Climate Disinformation Detection - TF-IDF + LogReg"
+DESCRIPTION = "Efficient Climate Disinformation Detection"
 ROUTE = "/text"
 
-def create_pipeline():
-    """Create an efficient text classification pipeline"""
-    return Pipeline([
-        ('tfidf', TfidfVectorizer(
-            max_features=10000,   # Limit features for efficiency
-            ngram_range=(1, 2),   # Use unigrams and bigrams
-            stop_words='english',
-            min_df=2,             # Remove very rare terms
-            max_df=0.95           # Remove very common terms
-        )),
-        ('classifier', LogisticRegression(
-            C=1.0,
-            multi_class='multinomial',
-            max_iter=200,
-            n_jobs=-1             # Use all CPU cores
-        ))
-    ])
-
 @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
 async def evaluate_text(request: TextEvaluationRequest):
     """
@@ -53,48 +34,74 @@ async def evaluate_text(request: TextEvaluationRequest):
         "7_fossil_fuels_needed": 7
     }
 
+    # Load and prepare the dataset
+    dataset = load_dataset(request.dataset_name)
+    dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
+    train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
+    test_dataset = train_test["test"]
+
     # Start tracking emissions
     tracker.start()
     tracker.start_task("inference")
 
     try:
-        # Load and prepare the dataset
-        dataset = load_dataset(request.dataset_name)
-
-        # Convert string labels to integers
-        dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
-
-        # Split dataset
-        train_test = dataset["train"].train_test_split(
-            test_size=request.test_size,
-            seed=request.test_seed
-        )
-
-        train_dataset = train_test["train"]
-        test_dataset = train_test["test"]
+        # Model configuration
+        model_name = "distilbert-base-uncased"  # Lighter model than MobileBERT
+        BATCH_SIZE = 64   # Increased batch size
+        MAX_LENGTH = 128  # Reduced sequence length
 
-        # Create and train pipeline
-        pipeline = create_pipeline()
-
-        # Train the model
-        pipeline.fit(
-            train_dataset["quote"],
-            train_dataset["label"]
+        # Initialize tokenizer and model
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+        model = AutoModelForSequenceClassification.from_pretrained(
+            model_name,
+            num_labels=8,
+            problem_type="single_label_classification",
         )
-
-        # Make predictions
-        predictions = pipeline.predict(test_dataset["quote"])
-
-        # Get true labels
-        true_labels = test_dataset["label"]
 
-        # Stop tracking emissions
+        # Enable mixed-precision inference if a GPU is available
+        if torch.cuda.is_available():
+            model = model.half()  # Convert to FP16
+
+        # Move model to device
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        model = model.to(device)
+        model.eval()
+
+        # Get test texts
+        test_texts = test_dataset["quote"]
+        predictions = []
+
+        # Process in efficient batches
+        for i in range(0, len(test_texts), BATCH_SIZE):
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+            batch_texts = test_texts[i:i + BATCH_SIZE]
+
+            # Efficient tokenization
+            inputs = tokenizer(
+                batch_texts,
+                padding=True,
+                truncation=True,
+                max_length=MAX_LENGTH,
+                return_tensors="pt"
+            )
+
+            # Move inputs to device efficiently
+            inputs = {k: v.to(device) for k, v in inputs.items()}
+
+            # Inference with optimizations
+            with torch.no_grad(), torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
+                outputs = model(**inputs)
+                batch_preds = torch.argmax(outputs.logits, dim=1)
+                predictions.extend(batch_preds.cpu().numpy())
+
+        # Get true labels and compute accuracy
+        true_labels = test_dataset['label']
         emissions_data = tracker.stop_task()
-
-        # Calculate accuracy
         accuracy = accuracy_score(true_labels, predictions)
 
-        # Prepare results dictionary
+        # Prepare results
        results = {
            "username": username,
            "space_url": space_url,
@@ -115,6 +122,5 @@ async def evaluate_text(request: TextEvaluationRequest):
         return results
 
     except Exception as e:
-        # Stop tracking in case of error
         tracker.stop_task()
         raise e
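
For local smoke-testing, here is a minimal sketch of how the updated route might be exercised, assuming the module shown above lives at tasks/text.py, the router is mounted on a FastAPI app, and TextEvaluationRequest accepts the dataset_name, test_size, and test_seed fields the handler reads (the dataset id below is a placeholder, and the request model may require additional fields):

# Sketch only: exercises the /text endpoint via FastAPI's test client.
from fastapi import FastAPI
from fastapi.testclient import TestClient

from tasks.text import router  # the module changed in this commit

app = FastAPI()
app.include_router(router)
client = TestClient(app)

response = client.post(
    "/text",
    json={
        "dataset_name": "org/climate-claims",  # placeholder dataset id
        "test_size": 0.2,                      # held-out fraction for the split
        "test_seed": 42,                       # seed for train_test_split
    },
)
print(response.status_code)
print(response.json())  # on success: the results dict built in the handler (username, space_url, ...)

Note that, as written, AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=8) attaches a freshly initialized classification head, so the reported accuracy is only meaningful once a fine-tuned checkpoint is substituted for the base model name.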