mjwong commited on
Commit
d5b6595
·
verified ·
1 Parent(s): 4578ac5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -5
app.py CHANGED
@@ -15,7 +15,8 @@ AVAILABLE_MODELS = [
15
  def classify_text(
16
  model_name: str,
17
  text: str,
18
- labels: str
 
19
  ) -> Dict[str, float]:
20
  """
21
  Classifies the input text into one of the provided labels using a zero-shot classification model.
@@ -39,7 +40,7 @@ def classify_text(
39
 
40
  classifier = pipeline("zero-shot-classification", model=model_name, device=device)
41
  labels_list = [label.strip() for label in labels.split(",")]
42
- result = classifier(text, candidate_labels=labels_list)
43
  return {label: score for label, score in zip(result["labels"], result["scores"])}
44
  except Exception as _:
45
  return "Error: An unexpected error occurred. Please try again later."
@@ -102,8 +103,11 @@ with gr.Blocks(css=css) as iface:
102
  gr.Markdown("# Zero-Shot Text Classifier")
103
  gr.Markdown("Select a model, enter text, and a set of labels to classify the text using a zero-shot classification model.")
104
 
105
- # Dropdown to select a model
106
- model_dropdown = gr.Dropdown(AVAILABLE_MODELS, label="Choose Model")
 
 
 
107
 
108
  # Input fields for text and labels
109
  with gr.Row():
@@ -115,7 +119,7 @@ with gr.Blocks(css=css) as iface:
115
 
116
  # Classification button
117
  submit_button = gr.Button("Classify", elem_classes=["classify-button"])
118
- submit_button.click(fn=classify_text, inputs=[model_dropdown, text_input, label_input], outputs=output_label)
119
 
120
  # Example input/output pairs
121
  gr.Examples(examples, inputs=[text_input, label_input])
 
15
  def classify_text(
16
  model_name: str,
17
  text: str,
18
+ labels: str,
19
+ multi_label: bool = False,
20
  ) -> Dict[str, float]:
21
  """
22
  Classifies the input text into one of the provided labels using a zero-shot classification model.
 
40
 
41
  classifier = pipeline("zero-shot-classification", model=model_name, device=device)
42
  labels_list = [label.strip() for label in labels.split(",")]
43
+ result = classifier(text, candidate_labels=labels_list, multi_label=multi_label)
44
  return {label: score for label, score in zip(result["labels"], result["scores"])}
45
  except Exception as _:
46
  return "Error: An unexpected error occurred. Please try again later."
 
103
  gr.Markdown("# Zero-Shot Text Classifier")
104
  gr.Markdown("Select a model, enter text, and a set of labels to classify the text using a zero-shot classification model.")
105
 
106
+ with gr.Row():
107
+ # Dropdown to select a model
108
+ model_dropdown = gr.Dropdown(AVAILABLE_MODELS, label="Choose Model")
109
+ # Checkbox for multi-label classification
110
+ multi_label = gr.Checkbox(label="True", value=False, info="Check for multi-label classification, uncheck for single-label (multi-class).")
111
 
112
  # Input fields for text and labels
113
  with gr.Row():
 
119
 
120
  # Classification button
121
  submit_button = gr.Button("Classify", elem_classes=["classify-button"])
122
+ submit_button.click(fn=classify_text, inputs=[model_dropdown, text_input, label_input, multi_label], outputs=output_label)
123
 
124
  # Example input/output pairs
125
  gr.Examples(examples, inputs=[text_input, label_input])