Zen0 committed on
Commit
f63546e
·
verified ·
1 Parent(s): 1358711

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +38 -14
tasks/text.py CHANGED
@@ -1,5 +1,6 @@
1
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
- from fastapi import APIRouter
 
3
  from datetime import datetime
4
  from datasets import load_dataset
5
  from sklearn.metrics import accuracy_score
@@ -9,12 +10,23 @@ import numpy as np
9
  from .utils.evaluation import TextEvaluationRequest
10
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
11
 
 
 
12
  router = APIRouter()
13
 
 
 
 
 
 
 
 
 
 
14
  DESCRIPTION = "Efficient Climate Disinformation Detection"
15
  ROUTE = "/text"
16
 
17
- @router.post(ROUTE, tags=["Text Task"], description=DESCRIPTION)
18
  async def evaluate_text(request: TextEvaluationRequest):
19
  """
20
  Evaluate text classification for climate disinformation detection.
@@ -46,21 +58,21 @@ async def evaluate_text(request: TextEvaluationRequest):
46
 
47
  try:
48
  # Model configuration
49
- model_name = "distilbert-base-uncased" # Lighter model than MobileBERT
50
- BATCH_SIZE = 64 # Increased batch size
51
- MAX_LENGTH = 128 # Reduced sequence length
52
 
53
  # Initialize tokenizer and model
54
  tokenizer = AutoTokenizer.from_pretrained(model_name)
55
  model = AutoModelForSequenceClassification.from_pretrained(
56
  model_name,
57
  num_labels=8,
58
- problem_type="single_label_classification",
59
  )
60
 
61
- # Enable mixed precision training if available
62
  if torch.cuda.is_available():
63
- model = model.half() # Convert to FP16
64
 
65
  # Move model to device
66
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -71,14 +83,14 @@ async def evaluate_text(request: TextEvaluationRequest):
71
  test_texts = test_dataset["quote"]
72
  predictions = []
73
 
74
- # Process in efficient batches
75
  for i in range(0, len(test_texts), BATCH_SIZE):
76
  if torch.cuda.is_available():
77
  torch.cuda.empty_cache()
78
 
79
  batch_texts = test_texts[i:i + BATCH_SIZE]
80
 
81
- # Efficient tokenization
82
  inputs = tokenizer(
83
  batch_texts,
84
  padding=True,
@@ -87,18 +99,22 @@ async def evaluate_text(request: TextEvaluationRequest):
87
  return_tensors="pt"
88
  )
89
 
90
- # Move inputs to device efficiently
91
  inputs = {k: v.to(device) for k, v in inputs.items()}
92
 
93
- # Inference with optimizations
94
  with torch.no_grad(), torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
95
  outputs = model(**inputs)
96
  batch_preds = torch.argmax(outputs.logits, dim=1)
97
  predictions.extend(batch_preds.cpu().numpy())
98
 
99
- # Get true labels and compute accuracy
100
  true_labels = test_dataset['label']
 
 
101
  emissions_data = tracker.stop_task()
 
 
102
  accuracy = accuracy_score(true_labels, predictions)
103
 
104
  # Prepare results
@@ -123,4 +139,12 @@ async def evaluate_text(request: TextEvaluationRequest):
123
 
124
  except Exception as e:
125
  tracker.stop_task()
126
- raise e
 
 
 
 
 
 
 
 
 
1
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
+ from fastapi import FastAPI, APIRouter
3
+ from fastapi.middleware.cors import CORSMiddleware
4
  from datetime import datetime
5
  from datasets import load_dataset
6
  from sklearn.metrics import accuracy_score
 
10
  from .utils.evaluation import TextEvaluationRequest
11
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
12
 
13
+ # Initialize FastAPI app and router
14
+ app = FastAPI()
15
  router = APIRouter()
16
 
17
+ # Add CORS middleware
18
+ app.add_middleware(
19
+ CORSMiddleware,
20
+ allow_origins=["*"],
21
+ allow_credentials=True,
22
+ allow_methods=["*"],
23
+ allow_headers=["*"],
24
+ )
25
+
26
  DESCRIPTION = "Efficient Climate Disinformation Detection"
27
  ROUTE = "/text"
28
 
29
+ @router.post("/text", tags=["Text Task"], description=DESCRIPTION)
30
  async def evaluate_text(request: TextEvaluationRequest):
31
  """
32
  Evaluate text classification for climate disinformation detection.
 
58
 
59
  try:
60
  # Model configuration
61
+ model_name = "distilbert-base-uncased"
62
+ BATCH_SIZE = 64
63
+ MAX_LENGTH = 128
64
 
65
  # Initialize tokenizer and model
66
  tokenizer = AutoTokenizer.from_pretrained(model_name)
67
  model = AutoModelForSequenceClassification.from_pretrained(
68
  model_name,
69
  num_labels=8,
70
+ problem_type="single_label_classification"
71
  )
72
 
73
+ # Enable mixed precision if available
74
  if torch.cuda.is_available():
75
+ model = model.half()
76
 
77
  # Move model to device
78
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
83
  test_texts = test_dataset["quote"]
84
  predictions = []
85
 
86
+ # Process in batches
87
  for i in range(0, len(test_texts), BATCH_SIZE):
88
  if torch.cuda.is_available():
89
  torch.cuda.empty_cache()
90
 
91
  batch_texts = test_texts[i:i + BATCH_SIZE]
92
 
93
+ # Tokenize batch
94
  inputs = tokenizer(
95
  batch_texts,
96
  padding=True,
 
99
  return_tensors="pt"
100
  )
101
 
102
+ # Move inputs to device
103
  inputs = {k: v.to(device) for k, v in inputs.items()}
104
 
105
+ # Run inference
106
  with torch.no_grad(), torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
107
  outputs = model(**inputs)
108
  batch_preds = torch.argmax(outputs.logits, dim=1)
109
  predictions.extend(batch_preds.cpu().numpy())
110
 
111
+ # Get true labels
112
  true_labels = test_dataset['label']
113
+
114
+ # Stop tracking emissions
115
  emissions_data = tracker.stop_task()
116
+
117
+ # Calculate accuracy
118
  accuracy = accuracy_score(true_labels, predictions)
119
 
120
  # Prepare results
 
139
 
140
  except Exception as e:
141
  tracker.stop_task()
142
+ raise e
143
+
144
+ # Include the router
145
+ app.include_router(router)
146
+
147
+ # Add a health check endpoint
148
+ @app.get("/health")
149
+ async def health_check():
150
+ return {"status": "healthy"}