jayebaku commited on
Commit
bc1afba
·
verified ·
1 Parent(s): 2ff1f7b

Update classifier.py

Browse files

updated function to include threshold

Files changed (1) hide show
  1. classifier.py +8 -3
classifier.py CHANGED
@@ -2,7 +2,7 @@ import spaces
2
  from transformers import pipeline
3
 
4
  #@spaces.GPU(duration=60)
5
- def classify(tweet, event_model, hftoken):
6
 
7
  # event type prediction
8
  event_predictor = pipeline(task="text-classification", model=event_model,
@@ -13,7 +13,12 @@ def classify(tweet, event_model, hftoken):
13
  prediction = event_predictor(tweet, **tokenizer_kwargs)[0]
14
 
15
  results["text"] = tweet
16
- results["event"] = prediction["label"]
17
- results["score"] = prediction["score"]
 
 
 
 
 
18
 
19
  return results
 
2
  from transformers import pipeline
3
 
4
  #@spaces.GPU(duration=60)
5
+ def classify(tweet, event_model, hftoken, threshold):
6
 
7
  # event type prediction
8
  event_predictor = pipeline(task="text-classification", model=event_model,
 
13
  prediction = event_predictor(tweet, **tokenizer_kwargs)[0]
14
 
15
  results["text"] = tweet
16
+
17
+ if prediction["label"] != "none" and round(prediction["score"], 2) <= threshold:
18
+ results["event"] = "none"
19
+ results["score"] = prediction["score"]
20
+ else:
21
+ results["event"] = prediction["label"]
22
+ results["score"] = prediction["score"]
23
 
24
  return results