sitammeur commited on
Commit
df432d6
·
verified ·
1 Parent(s): b59a012

Update src/modernbert/classifier.py

Browse files
Files changed (1) hide show
  1. src/modernbert/classifier.py +65 -61
src/modernbert/classifier.py CHANGED
@@ -1,61 +1,65 @@
1
- # Necessary imports
2
- import sys
3
- from typing import Dict
4
- from src.logger import logging
5
- from src.exception import CustomExceptionHandling
6
- from transformers import pipeline
7
- import gradio as gr
8
-
9
-
10
- # Load the zero-shot classification model
11
- classifier = pipeline(
12
- "zero-shot-classification",
13
- model="MoritzLaurer/ModernBERT-large-zeroshot-v2.0",
14
- )
15
-
16
-
17
- def ZeroShotTextClassification(
18
- text_input: str, candidate_labels: str, multi_label: bool
19
- ) -> Dict[str, float]:
20
- """
21
- Performs zero-shot classification on the given text input.
22
-
23
- Args:
24
- - text_input: The input text to classify.
25
- - candidate_labels: A comma-separated string of candidate labels.
26
- - multi_label: A boolean indicating whether to allow the model to choose multiple classes.
27
-
28
- Returns:
29
- Dictionary containing label-score pairs.
30
- """
31
- try:
32
- # Check if the input and candidate labels are valid
33
- if not text_input or not candidate_labels:
34
- gr.Warning("Please provide valid input and candidate labels")
35
-
36
- # Split and clean the candidate labels
37
- labels = [label.strip() for label in candidate_labels.split(",")]
38
-
39
- # Log the classification attempt
40
- logging.info(f"Attempting classification with {len(labels)} labels")
41
-
42
- # Perform zero-shot classification
43
- hypothesis_template = "This text is about {}"
44
- prediction = classifier(
45
- text_input,
46
- labels,
47
- hypothesis_template=hypothesis_template,
48
- multi_label=multi_label,
49
- )
50
-
51
- # Return the classification results
52
- logging.info("Classification completed successfully")
53
- return {
54
- prediction["labels"][i]: prediction["scores"][i]
55
- for i in range(len(prediction["labels"]))
56
- }
57
-
58
- # Handle exceptions that may occur during the process
59
- except Exception as e:
60
- # Custom exception handling
61
- raise CustomExceptionHandling(e, sys) from e
 
 
 
 
 
1
+ # Necessary imports
2
+ import sys
3
+ from typing import Dict
4
+ import torch
5
+ from transformers import pipeline
6
+ import gradio as gr
7
+
8
+ # Local imports
9
+ from src.logger import logging
10
+ from src.exception import CustomExceptionHandling
11
+
12
+
13
+ # Load the zero-shot classification model
14
+ classifier = pipeline(
15
+ "zero-shot-classification",
16
+ model="MoritzLaurer/ModernBERT-large-zeroshot-v2.0",
17
+ torch_dtype=torch.bfloat16,
18
+ )
19
+
20
+
21
+ def ZeroShotTextClassification(
22
+ text_input: str, candidate_labels: str, multi_label: bool
23
+ ) -> Dict[str, float]:
24
+ """
25
+ Performs zero-shot classification on the given text input.
26
+
27
+ Args:
28
+ - text_input: The input text to classify.
29
+ - candidate_labels: A comma-separated string of candidate labels.
30
+ - multi_label: A boolean indicating whether to allow the model to choose multiple classes.
31
+
32
+ Returns:
33
+ Dictionary containing label-score pairs.
34
+ """
35
+ try:
36
+ # Check if the input and candidate labels are valid
37
+ if not text_input or not candidate_labels:
38
+ gr.Warning("Please provide valid input and candidate labels")
39
+
40
+ # Split and clean the candidate labels
41
+ labels = [label.strip() for label in candidate_labels.split(",")]
42
+
43
+ # Log the classification attempt
44
+ logging.info(f"Attempting classification with {len(labels)} labels")
45
+
46
+ # Perform zero-shot classification
47
+ hypothesis_template = "This text is about {}"
48
+ prediction = classifier(
49
+ text_input,
50
+ labels,
51
+ hypothesis_template=hypothesis_template,
52
+ multi_label=multi_label,
53
+ )
54
+
55
+ # Return the classification results
56
+ logging.info("Classification completed successfully")
57
+ return {
58
+ prediction["labels"][i]: prediction["scores"][i]
59
+ for i in range(len(prediction["labels"]))
60
+ }
61
+
62
+ # Handle exceptions that may occur during the process
63
+ except Exception as e:
64
+ # Custom exception handling
65
+ raise CustomExceptionHandling(e, sys) from e