Update app.py
Browse files
app.py
CHANGED
@@ -15,7 +15,8 @@ AVAILABLE_MODELS = [
|
|
15 |
def classify_text(
|
16 |
model_name: str,
|
17 |
text: str,
|
18 |
-
labels: str
|
|
|
19 |
) -> Dict[str, float]:
|
20 |
"""
|
21 |
Classifies the input text into one of the provided labels using a zero-shot classification model.
|
@@ -39,7 +40,7 @@ def classify_text(
|
|
39 |
|
40 |
classifier = pipeline("zero-shot-classification", model=model_name, device=device)
|
41 |
labels_list = [label.strip() for label in labels.split(",")]
|
42 |
-
result = classifier(text, candidate_labels=labels_list)
|
43 |
return {label: score for label, score in zip(result["labels"], result["scores"])}
|
44 |
except Exception as _:
|
45 |
return "Error: An unexpected error occurred. Please try again later."
|
@@ -102,8 +103,11 @@ with gr.Blocks(css=css) as iface:
|
|
102 |
gr.Markdown("# Zero-Shot Text Classifier")
|
103 |
gr.Markdown("Select a model, enter text, and a set of labels to classify the text using a zero-shot classification model.")
|
104 |
|
105 |
-
|
106 |
-
|
|
|
|
|
|
|
107 |
|
108 |
# Input fields for text and labels
|
109 |
with gr.Row():
|
@@ -115,7 +119,7 @@ with gr.Blocks(css=css) as iface:
|
|
115 |
|
116 |
# Classification button
|
117 |
submit_button = gr.Button("Classify", elem_classes=["classify-button"])
|
118 |
-
submit_button.click(fn=classify_text, inputs=[model_dropdown, text_input, label_input], outputs=output_label)
|
119 |
|
120 |
# Example input/output pairs
|
121 |
gr.Examples(examples, inputs=[text_input, label_input])
|
|
|
15 |
def classify_text(
|
16 |
model_name: str,
|
17 |
text: str,
|
18 |
+
labels: str,
|
19 |
+
multi_label: bool = False,
|
20 |
) -> Dict[str, float]:
|
21 |
"""
|
22 |
Classifies the input text into one of the provided labels using a zero-shot classification model.
|
|
|
40 |
|
41 |
classifier = pipeline("zero-shot-classification", model=model_name, device=device)
|
42 |
labels_list = [label.strip() for label in labels.split(",")]
|
43 |
+
result = classifier(text, candidate_labels=labels_list, multi_label=multi_label)
|
44 |
return {label: score for label, score in zip(result["labels"], result["scores"])}
|
45 |
except Exception as _:
|
46 |
return "Error: An unexpected error occurred. Please try again later."
|
|
|
103 |
gr.Markdown("# Zero-Shot Text Classifier")
|
104 |
gr.Markdown("Select a model, enter text, and a set of labels to classify the text using a zero-shot classification model.")
|
105 |
|
106 |
+
with gr.Row():
|
107 |
+
# Dropdown to select a model
|
108 |
+
model_dropdown = gr.Dropdown(AVAILABLE_MODELS, label="Choose Model")
|
109 |
+
# Checkbox for multi-label classification
|
110 |
+
multi_label = gr.Checkbox(label="True", value=False, info="Check for multi-label classification, uncheck for single-label (multi-class).")
|
111 |
|
112 |
# Input fields for text and labels
|
113 |
with gr.Row():
|
|
|
119 |
|
120 |
# Classification button
|
121 |
submit_button = gr.Button("Classify", elem_classes=["classify-button"])
|
122 |
+
submit_button.click(fn=classify_text, inputs=[model_dropdown, text_input, label_input, multi_label], outputs=output_label)
|
123 |
|
124 |
# Example input/output pairs
|
125 |
gr.Examples(examples, inputs=[text_input, label_input])
|