Zen0 committed
Commit aee4009 · verified · 1 Parent(s): d778205

Update tasks/text.py

Files changed (1)
  1. tasks/text.py +28 -17
tasks/text.py CHANGED
@@ -3,13 +3,12 @@ from fastapi import APIRouter
 from datetime import datetime
 from datasets import load_dataset
 from sklearn.metrics import accuracy_score
+import torch
+import numpy as np
 
 from .utils.evaluation import TextEvaluationRequest
 from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
-import numpy as np
-import torch
-
 router = APIRouter()
 
 DESCRIPTION = "FrugalDisinfoHunter Model"
@@ -51,46 +50,58 @@ async def evaluate_text(request: TextEvaluationRequest):
 
     try:
         # Model configuration
-        model_name = "Zen0/FrugalDisinfoHunter"  # Model path
-        tokenizer_name = "google/mobilebert-uncased"  # Base MobileBERT tokenizer
-        BATCH_SIZE = 32  # Batch size for efficient processing
-        MAX_LENGTH = 128  # Maximum sequence length
+        model_name = "google/mobilebert-uncased"  # Base model
+        local_weights = "model/model.pt"  # Path to our trained weights
+        BATCH_SIZE = 32
+        MAX_LENGTH = 256  # Increased from 128
 
-        # Initialize model and tokenizer
+        # Initialize tokenizer and model
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
         model = AutoModelForSequenceClassification.from_pretrained(
             model_name,
             num_labels=8,
-            output_hidden_states=True,
             problem_type="single_label_classification"
         )
-        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+
+        # Load our trained weights
+        try:
+            state_dict = torch.load(local_weights, map_location='cpu')
+            model.load_state_dict(state_dict)
+        except Exception as e:
+            print(f"Error loading weights: {e}")
+            # Continue with base model if weights fail to load
+            pass
 
         # Move model to appropriate device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         model = model.to(device)
-        model.eval()  # Set model to evaluation mode
+        model.eval()  # Set to evaluation mode
 
-        # Get test texts
+        # Get test texts and process in batches
         test_texts = test_dataset["quote"]
         predictions = []
 
         # Process in batches
         for i in range(0, len(test_texts), BATCH_SIZE):
+            # Clear CUDA cache if using GPU
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
             batch_texts = test_texts[i:i + BATCH_SIZE]
 
-            # Tokenize batch
+            # Tokenize with padding and attention masks
             inputs = tokenizer(
                 batch_texts,
                 padding=True,
                 truncation=True,
-                return_tensors="pt",
-                max_length=MAX_LENGTH
+                max_length=MAX_LENGTH,
+                return_tensors="pt"
             )
 
             # Move inputs to device
-            inputs = {key: val.to(device) for key, val in inputs.items()}
+            inputs = {k: v.to(device) for k, v in inputs.items()}
 
-            # Run inference
+            # Run inference with no gradient computation
             with torch.no_grad():
                 outputs = model(**inputs)
                 batch_preds = torch.argmax(outputs.logits, dim=1)
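
For reference, a minimal standalone sketch of the inference path as it stands after this commit. It assumes the transformers imports (AutoTokenizer, AutoModelForSequenceClassification) that presumably appear elsewhere in tasks/text.py, and substitutes a small hard-coded text list for test_dataset["quote"]; the model/model.pt path comes from the commit and only resolves inside this Space.

# Sketch only: mirrors the commit's flow under the assumptions above.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

MODEL_NAME = "google/mobilebert-uncased"
LOCAL_WEIGHTS = "model/model.pt"  # path from the commit; adjust outside this Space
BATCH_SIZE = 32
MAX_LENGTH = 256

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=8,
    problem_type="single_label_classification",
)

# Overlay the fine-tuned weights; fall back to the base model on failure,
# as the commit's try/except does.
try:
    model.load_state_dict(torch.load(LOCAL_WEIGHTS, map_location="cpu"))
except Exception as e:
    print(f"Error loading weights: {e}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()

texts = ["a first test quote", "a second test quote"]  # stand-in for test_dataset["quote"]
predictions = []
for i in range(0, len(texts), BATCH_SIZE):
    batch = texts[i:i + BATCH_SIZE]
    inputs = tokenizer(batch, padding=True, truncation=True,
                       max_length=MAX_LENGTH, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    predictions.extend(torch.argmax(logits, dim=1).cpu().tolist())
print(predictions)

One property of this fallback worth noting: load_state_dict also raises on any key or shape mismatch, so a checkpoint saved with a different classifier head silently leaves the base (untrained-head) model in place, which the commit only surfaces via the printed error.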