VanshK04 committed on
Commit
f29e855
·
verified ·
1 Parent(s): a24ea0c

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +59 -42
tasks/text.py CHANGED
@@ -1,34 +1,70 @@
1
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
- import torch
3
  from fastapi import APIRouter
4
  from datetime import datetime
5
- from datasets import Dataset
6
  from sklearn.metrics import accuracy_score
7
- from sklearn.model_selection import train_test_split
8
- from torch.utils.data import DataLoader, Dataset
9
- import pandas as pd
10
- from sklearn.preprocessing import LabelEncoder
11
 
12
  from .utils.evaluation import TextEvaluationRequest
13
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
14
 
15
  router = APIRouter()
16
 
17
- ROUTE = "/text" # Define the route
18
- DESCRIPTION = "Evaluate text classification for climate disinformation detection" # Define the description
19
 
20
- # Load the fine-tuned BERT model and tokenizer
21
- model_dir = "./" # Path to the fine-tuned BERT model directory
22
- tokenizer = AutoTokenizer.from_pretrained(model_dir)
23
- model = AutoModelForSequenceClassification.from_pretrained(model_dir)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  # Assign device
26
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
- model.to(device)
28
- model.eval() # Set the model to evaluation mode
29
 
30
- # Dataset class
31
- class TextDataset(Dataset):
32
  def __init__(self, texts, labels, tokenizer, max_len=128):
33
  self.texts = texts
34
  self.labels = labels
@@ -54,7 +90,7 @@ class TextDataset(Dataset):
54
  'labels': torch.tensor(label, dtype=torch.long)
55
  }
56
 
57
- @router.post(ROUTE, tags=["Text Task"],
58
  description=DESCRIPTION)
59
  async def evaluate_text(request: TextEvaluationRequest):
60
  """
@@ -82,30 +118,11 @@ async def evaluate_text(request: TextEvaluationRequest):
82
  val_dataset = TextDataset(val_texts, val_labels, tokenizer)
83
  val_loader = DataLoader(val_dataset, batch_size=32)
84
 
85
- # Start tracking emissions
86
- tracker.start()
87
- tracker.start_task("inference")
88
-
89
- #--------------------------------------------------------------------------------------------
90
- # Fine-tuned BERT model inference
91
- #--------------------------------------------------------------------------------------------
92
- predictions = []
93
- true_labels = val_labels.tolist()
94
-
95
- with torch.no_grad():
96
- for batch in val_loader:
97
- input_ids = batch["input_ids"].to(device)
98
- attention_mask = batch["attention_mask"].to(device)
99
-
100
- outputs = model(input_ids=input_ids, attention_mask=attention_mask)
101
- logits = outputs.logits
102
- batch_predictions = torch.argmax(logits, dim=1).cpu().tolist()
103
- predictions.extend(batch_predictions)
104
-
105
- #--------------------------------------------------------------------------------------------
106
- # Fine-tuned BERT model inference stops here
107
  #--------------------------------------------------------------------------------------------
 
 
108
 
 
109
  # Stop tracking emissions
110
  emissions_data = tracker.stop_task()
111
 
@@ -130,4 +147,4 @@ async def evaluate_text(request: TextEvaluationRequest):
130
  }
131
  }
132
 
133
- return results
 
 
 
1
  from fastapi import APIRouter
2
  from datetime import datetime
3
+ from datasets import load_dataset
4
  from sklearn.metrics import accuracy_score
5
+ import random
 
 
 
6
 
7
  from .utils.evaluation import TextEvaluationRequest
8
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
10
  router = APIRouter()
11
 
12
+ DESCRIPTION = "Evaluate text classification for climate disinformation detection"
13
+ ROUTE = "/text"
14
 
15
+ @router.post(ROUTE, tags=["Text Task"],
16
+ description=DESCRIPTION)
17
+ async def evaluate_text(request: TextEvaluationRequest):
18
+ """
19
+ Evaluate text classification for climate disinformation detection.
20
+
21
+ Current Model: Random Baseline
22
+ - Makes random predictions from the label space (0-7)
23
+ - Used as a baseline for comparison
24
+ """
25
+ # Get space info
26
+ username, space_url = get_space_info()
27
+
28
+ # Define the label mapping
29
+ LABEL_MAPPING = {
30
+ "0_not_relevant": 0,
31
+ "1_not_happening": 1,
32
+ "2_not_human": 2,
33
+ "3_not_bad": 3,
34
+ "4_solutions_harmful_unnecessary": 4,
35
+ "5_science_unreliable": 5,
36
+ "6_proponents_biased": 6,
37
+ "7_fossil_fuels_needed": 7
38
+ }
39
+
40
+ # Load and prepare the dataset
41
+ dataset = load_dataset(request.dataset_name)
42
+
43
+ # Convert string labels to integers
44
+ dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
45
+
46
+ # Split dataset
47
+ train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
48
+ test_dataset = train_test["test"]
49
+
50
+ # Start tracking emissions
51
+ tracker.start()
52
+ tracker.start_task("inference")
53
+
54
+ #--------------------------------------------------------------------------------------------
55
+ # YOUR MODEL INFERENCE CODE HERE
56
+ # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
57
+ #--------------------------------------------------------------------------------------------
58
+ model_dir = "./" # Path to the fine-tuned BERT model directory
59
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
60
+ model = AutoModelForSequenceClassification.from_pretrained(model_dir)
61
 
62
  # Assign device
63
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
64
+ model.to(device)
65
+ model.eval() # Set the model to evaluation mode
66
 
67
+ class TextDataset(Dataset):
 
68
  def __init__(self, texts, labels, tokenizer, max_len=128):
69
  self.texts = texts
70
  self.labels = labels
 
90
  'labels': torch.tensor(label, dtype=torch.long)
91
  }
92
 
93
+ @router.post(ROUTE, tags=["Text Task"],
94
  description=DESCRIPTION)
95
  async def evaluate_text(request: TextEvaluationRequest):
96
  """
 
118
  val_dataset = TextDataset(val_texts, val_labels, tokenizer)
119
  val_loader = DataLoader(val_dataset, batch_size=32)
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  #--------------------------------------------------------------------------------------------
122
+ # YOUR MODEL INFERENCE STOPS HERE
123
+ #--------------------------------------------------------------------------------------------
124
 
125
+
126
  # Stop tracking emissions
127
  emissions_data = tracker.stop_task()
128
 
 
147
  }
148
  }
149
 
150
+ return results