import gradio as gr
import torch
from transformers import AutoTokenizer, pipeline
from typing import Dict, Union

# Custom models for zero-shot classification requiring trust_remote_code=True
CUSTOM_MODELS = [
    "mjwong/gte-multilingual-base-xnli-anli"
]

# Available models for zero-shot classification
AVAILABLE_MODELS = [
    "mjwong/multilingual-e5-large-instruct-xnli-anli",
    "mjwong/multilingual-e5-base-xnli-anli",
    "mjwong/multilingual-e5-large-xnli-anli",
    "mjwong/drama-base-xnli-anli",
    "mjwong/drama-large-xnli-anli",
    "mjwong/mcontriever-msmarco-xnli",
    "mjwong/mcontriever-xnli"
] + CUSTOM_MODELS


def classify_text(
    model_name: str,
    text: str,
    labels: str,
    multi_label: bool = False,
) -> Union[Dict[str, float], str]:
    """
    Classifies the input text against the provided labels using a zero-shot classification model.

    Args:
        model_name: The name of the Hugging Face model to use.
        text: The input text to classify.
        labels: A comma-separated string of candidate labels.
        multi_label: If True, each label is scored independently; otherwise the scores sum to 1.

    Returns:
        A dictionary mapping each label to its classification score, or an
        error message string if the input is invalid or classification fails.
    """
    if not text.strip():
        return "Error: Please enter some text to classify."
    if not labels.strip():
        return "Error: Please enter some labels to classify the text."

    try:
        # Set device: 0 if a GPU is available, else -1 for CPU
        device = 0 if torch.cuda.is_available() else -1

        if model_name in CUSTOM_MODELS:
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            classifier = pipeline(
                "zero-shot-classification",
                model=model_name,
                tokenizer=tokenizer,
                device=device,
                trust_remote_code=True,
            )
        else:
            classifier = pipeline("zero-shot-classification", model=model_name, device=device)

        labels_list = [label.strip() for label in labels.split(",")]
        result = classifier(text, candidate_labels=labels_list, multi_label=multi_label)
        return {label: score for label, score in zip(result["labels"], result["scores"])}
    except Exception:
        return "Error: An unexpected error occurred. Please try again later."
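
# Minimal standalone usage sketch (illustrative only; the model name comes from
# AVAILABLE_MODELS above and the scores shown are hypothetical example values):
#
#   scores = classify_text(
#       "mjwong/multilingual-e5-base-xnli-anli",
#       "The central bank raised interest rates today.",
#       "economy, sports, technology",
#   )
#   # e.g. {"economy": 0.91, "technology": 0.06, "sports": 0.03}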

# Example inputs with mutually exclusive labels from news articles
examples = [
    [
        "The government announced a new economic policy today aimed at reducing inflation and stabilizing the currency market.",
        "economy, politics, finance, policy, inflation, government, currency"
    ],
    [
        "中国的科技公司在人工智能领域取得了重大突破,这可能会影响全球市场。",
        "科技, 经济, 创新, 市场, 人工智能, 全球"
    ],
    [
        "นักวิจัยค้นพบวิธีใหม่ในการรักษาโรคมะเร็ง ซึ่งอาจช่วยชีวิตผู้ป่วยหลายล้านคนทั่วโลก",
        "การแพทย์, วิทยาศาสตร์, นวัตกรรม, สุขภาพ, โรคมะเร็ง, การรักษา"
    ],
    [
        "La conférence des Nations Unies sur le climat a abouti à un nouvel accord pour réduire les émissions de carbone d'ici 2030.",
        "environnement, climat, politique, énergie, carbone, écologie, ONU"
    ],
    [
        "सरकार ने आज एक नई आर्थिक नीति की घोषणा की, जिसका उद्देश्य मुद्रास्फीति को कम करना और मुद्रा बाजार को स्थिर करना है।",
        "अर्थव्यवस्था, राजनीति, वित्त, नीति, मुद्रास्फीति, सरकार, मुद्रा"
    ]
]

# Custom CSS for the Gradio interface
css = """
footer {display: none !important}
.output-markdown {display: none !important}
.gr-button-primary {
    z-index: 14;
    height: 43px;
    width: 130px;
    left: 0px;
    top: 0px;
    padding: 0px;
    cursor: pointer !important;
    background: none rgb(17, 20, 45) !important;
    border: none !important;
    text-align: center !important;
    font-family: Poppins !important;
    font-size: 14px !important;
    font-weight: 500 !important;
    color: rgb(255, 255, 255) !important;
    line-height: 1 !important;
    border-radius: 12px !important;
    transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
    box-shadow: none !important;
}
.classify-button {
    background: linear-gradient(90deg, yellow, orange) !important;
}
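
/* Note: .gr-button-primary is intended to restyle Gradio's built-in primary buttons,
   while .classify-button is attached to the Classify button below via elem_classes. */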
"""

# Initialize Gradio interface
with gr.Blocks(css=css) as iface:
    gr.Markdown("# Zero-Shot Text Classifier")
    gr.Markdown("Select a model, enter some text and a set of labels, and the text will be classified using a zero-shot classification model.")
    gr.Markdown("More than 10 languages are officially supported, including English, Arabic, Bulgarian, German, Greek, Spanish, French, Hindi, Russian, Swahili, Thai, Turkish, Urdu, Vietnamese, and Chinese.")
    with gr.Row():
        # Dropdown to select a model
        model_dropdown = gr.Dropdown(AVAILABLE_MODELS, label="Choose Model")
        # Checkbox for multi-label classification
        multi_label = gr.Checkbox(label="Multi-label", value=False, info="Check for multi-label classification, uncheck for single-label (multi-class).")
    # Input fields for text and labels
    with gr.Row():
        text_input = gr.Textbox(label="Enter Text", placeholder="Type or paste text here...")
        label_input = gr.Textbox(label="Enter Labels (comma-separated)", placeholder="e.g., sports, politics, technology")

    # Output display
    output_label = gr.Label(label="Classification Scores")

    # Classification button
    submit_button = gr.Button("Classify", elem_classes=["classify-button"])
    submit_button.click(
        fn=classify_text,
        inputs=[model_dropdown, text_input, label_input, multi_label],
        outputs=output_label,
    )

    # Example inputs that populate the text and label fields
    gr.Examples(examples, inputs=[text_input, label_input])

# Launch the app
if __name__ == "__main__":
    iface.launch()