sitammeur commited on
Commit
5088775
·
verified ·
1 Parent(s): 099ab82

Update classytext/classifier/predict.py

Browse files
Files changed (1) hide show
  1. classytext/classifier/predict.py +61 -56
classytext/classifier/predict.py CHANGED
@@ -1,56 +1,61 @@
1
- # Necessary imports
2
- import sys
3
- from typing import Dict
4
- from classytext.logger import logging
5
- from classytext.exception import CustomExceptionHandling
6
- from transformers import pipeline
7
-
8
-
9
- # Load the zero-shot classification model
10
- classifier = pipeline(
11
- "zero-shot-classification",
12
- model="MoritzLaurer/ModernBERT-large-zeroshot-v2.0",
13
- )
14
-
15
-
16
- def ZeroShotTextClassification(
17
- text_input: str, candidate_labels: str, multi_label: bool
18
- ) -> Dict[str, float]:
19
- """
20
- Performs zero-shot classification on the given text input.
21
-
22
- Args:
23
- - text_input: The input text to classify.
24
- - candidate_labels: A comma-separated string of candidate labels.
25
- - multi_label: A boolean indicating whether to allow the model to choose multiple classes.
26
-
27
- Returns:
28
- Dictionary containing label-score pairs.
29
- """
30
- try:
31
- # Split and clean the candidate labels
32
- labels = [label.strip() for label in candidate_labels.split(",")]
33
-
34
- # Log the classification attempt
35
- logging.info(f"Attempting classification with {len(labels)} labels")
36
-
37
- # Perform zero-shot classification
38
- hypothesis_template = "This text is about {}"
39
- prediction = classifier(
40
- text_input,
41
- labels,
42
- hypothesis_template=hypothesis_template,
43
- multi_label=multi_label,
44
- )
45
-
46
- # Return the classification results
47
- logging.info("Classification completed successfully")
48
- return {
49
- prediction["labels"][i]: prediction["scores"][i]
50
- for i in range(len(prediction["labels"]))
51
- }
52
-
53
- # Handle exceptions that may occur during the process
54
- except Exception as e:
55
- # Custom exception handling
56
- raise CustomExceptionHandling(e, sys) from e
 
 
 
 
 
 
1
+ # Necessary imports
2
+ import sys
3
+ from typing import Dict
4
+ from classytext.logger import logging
5
+ from classytext.exception import CustomExceptionHandling
6
+ from transformers import pipeline
7
+ import gradio as gr
8
+
9
+
10
+ # Load the zero-shot classification model
11
+ classifier = pipeline(
12
+ "zero-shot-classification",
13
+ model="MoritzLaurer/ModernBERT-large-zeroshot-v2.0",
14
+ )
15
+
16
+
17
+ def ZeroShotTextClassification(
18
+ text_input: str, candidate_labels: str, multi_label: bool
19
+ ) -> Dict[str, float]:
20
+ """
21
+ Performs zero-shot classification on the given text input.
22
+
23
+ Args:
24
+ - text_input: The input text to classify.
25
+ - candidate_labels: A comma-separated string of candidate labels.
26
+ - multi_label: A boolean indicating whether to allow the model to choose multiple classes.
27
+
28
+ Returns:
29
+ Dictionary containing label-score pairs.
30
+ """
31
+ try:
32
+ # Check if the input and candidate labels are valid
33
+ if not text_input or not candidate_labels:
34
+ gr.Warning("Please provide valid input and candidate labels")
35
+
36
+ # Split and clean the candidate labels
37
+ labels = [label.strip() for label in candidate_labels.split(",")]
38
+
39
+ # Log the classification attempt
40
+ logging.info(f"Attempting classification with {len(labels)} labels")
41
+
42
+ # Perform zero-shot classification
43
+ hypothesis_template = "This text is about {}"
44
+ prediction = classifier(
45
+ text_input,
46
+ labels,
47
+ hypothesis_template=hypothesis_template,
48
+ multi_label=multi_label,
49
+ )
50
+
51
+ # Return the classification results
52
+ logging.info("Classification completed successfully")
53
+ return {
54
+ prediction["labels"][i]: prediction["scores"][i]
55
+ for i in range(len(prediction["labels"]))
56
+ }
57
+
58
+ # Handle exceptions that may occur during the process
59
+ except Exception as e:
60
+ # Custom exception handling
61
+ raise CustomExceptionHandling(e, sys) from e