File size: 6,183 Bytes
7535af8
4578ac5
42d4264
764d4a1
7535af8
42d4264
 
 
 
 
7535af8
 
 
 
 
3ed80b2
f74c03b
7535af8
 
42d4264
7535af8
d950576
 
 
d5b6595
 
d950576
 
 
 
 
 
 
 
 
 
 
 
4578ac5
 
 
 
 
 
 
 
 
42d4264
 
bbf91a6
42d4264
 
 
4578ac5
d5b6595
4578ac5
 
 
7535af8
d950576
f1abc3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14ff620
 
 
 
f1abc3d
 
7535af8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58aec17
 
7535af8
 
 
d950576
a83cca7
 
d950576
a98f9fe
d950576
d5b6595
 
 
 
 
d950576
 
a83cca7
 
 
d950576
 
a83cca7
d950576
 
58aec17
d5b6595
d950576
 
a83cca7
7535af8
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import functools
import logging
from typing import Dict, Union

import gradio as gr
import torch
from transformers import AutoTokenizer, pipeline

# Zero-shot models that must be loaded with trust_remote_code=True
CUSTOM_MODELS = ["mjwong/gte-multilingual-base-xnli-anli"]

# Zero-shot models that load without trust_remote_code
_STANDARD_MODELS = (
    "mjwong/multilingual-e5-large-instruct-xnli-anli",
    "mjwong/multilingual-e5-base-xnli-anli",
    "mjwong/multilingual-e5-large-xnli-anli",
    "mjwong/drama-base-xnli-anli",
    "mjwong/drama-large-xnli-anli",
    "mjwong/mcontriever-msmarco-xnli",
    "mjwong/mcontriever-xnli",
)

# Every model offered for selection: standard models first, custom ones last
AVAILABLE_MODELS = [*_STANDARD_MODELS, *CUSTOM_MODELS]

logger = logging.getLogger(__name__)


@functools.lru_cache(maxsize=1)
def _load_classifier(model_name: str):
    """Build the zero-shot classification pipeline for ``model_name``.

    Only the most recently used pipeline is kept, so repeated requests
    against the same model skip the reload while at most one model stays
    resident in memory.
    """
    # Set device: 0 if GPU available, else -1 for CPU
    device = 0 if torch.cuda.is_available() else -1

    if model_name in CUSTOM_MODELS:
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        return pipeline(
            "zero-shot-classification",
            model=model_name,
            device=device,
            tokenizer=tokenizer,
            trust_remote_code=True,
        )
    return pipeline("zero-shot-classification", model=model_name, device=device)


def _parse_labels(labels: str) -> list:
    """Split a comma-separated string into unique, non-empty labels (order kept)."""
    stripped = (label.strip() for label in labels.split(","))
    # dict.fromkeys drops duplicates while preserving first-seen order.
    return list(dict.fromkeys(label for label in stripped if label))


def classify_text(
        model_name: str,
        text: str,
        labels: str,
        multi_label: bool = False,
    ) -> Union[Dict[str, float], str]:
    """
    Scores the input text against the provided labels using a zero-shot classification model.

    Args:
        model_name: The name of the Hugging Face model to use.
        text: The input text to classify.
        labels: A comma-separated string of candidate labels. Blank entries
            (e.g. from "a,,b") and duplicates are ignored.
        multi_label: Passed through to the pipeline's ``multi_label`` argument.

    Returns:
        A dictionary mapping each label to its classification score, or a
        string starting with "Error:" when the input is invalid or the
        classification fails.
    """
    if not text or not text.strip():
        return "Error: Please enter some text to classify."

    labels_list = _parse_labels(labels) if labels else []
    if not labels_list:
        return "Error: Please enter some labels to classify the text."

    # Without this check a missing selection would be passed on as model=None.
    if not model_name:
        return "Error: Please select a model."

    try:
        classifier = _load_classifier(model_name)
        result = classifier(text, candidate_labels=labels_list, multi_label=multi_label)
        return dict(zip(result["labels"], result["scores"]))
    except Exception:
        # Keep the user-facing message generic, but record the real cause.
        logger.exception("Zero-shot classification failed for model %r", model_name)
        return "Error: An unexpected error occurred. Please try again later."

# Sample news-style inputs offered as clickable examples. Each entry becomes
# a [text, comma-separated candidate labels] row.
examples = [
    [sample_text, sample_labels]
    for sample_text, sample_labels in (
        (
            "The government announced a new economic policy today aimed at reducing inflation and stabilizing the currency market.",
            "economy, politics, finance, policy, inflation, government, currency",
        ),
        (
            "中国的科技公司在人工智能领域取得了重大突破,这可能会影响全球市场。",
            "科技, 经济, 创新, 市场, 人工智能, 全球",
        ),
        (
            "นักวิจัยค้นพบวิธีใหม่ในการรักษาโรคมะเร็ง ซึ่งอาจช่วยชีวิตผู้ป่วยหลายล้านคนทั่วโลก",
            "การแพทย์, วิทยาศาสตร์, นวัตกรรม, สุขภาพ, โรคมะเร็ง, การรักษา",
        ),
        (
            "La conférence des Nations Unies sur le climat a abouti à un nouvel accord pour réduire les émissions de carbone d'ici 2030.",
            "environnement, climat, politique, énergie, carbone, écologie, ONU",
        ),
        (
            "सरकार ने आज एक नई आर्थिक नीति की घोषणा की, जिसका उद्देश्य मुद्रास्फीति को कम करना और मुद्रा बाजार को स्थिर करना है।",
            "अर्थव्यवस्था, राजनीति, वित्त, नीति, मुद्रास्फीति, सरकार, मुद्रा",
        ),
    )
]

# Custom stylesheet handed to gr.Blocks(css=...) below: hides the page footer
# and .output-markdown elements, restyles .gr-button-primary, and gives the
# .classify-button class a yellow-to-orange gradient background.
css = """
footer {display:none !important}
.output-markdown{display:none !important}
.gr-button-primary {
    z-index: 14;
    height: 43px;
    width: 130px;
    left: 0px;
    top: 0px;
    padding: 0px;
    cursor: pointer !important; 
    background: none rgb(17, 20, 45) !important;
    border: none !important;
    text-align: center !important;
    font-family: Poppins !important;
    font-size: 14px !important;
    font-weight: 500 !important;
    color: rgb(255, 255, 255) !important;
    line-height: 1 !important;
    border-radius: 12px !important;
    transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
    box-shadow: none !important;
}
.classify-button {
    background: linear-gradient(90deg, yellow, orange) !important;
}
"""

# Build the Gradio UI: model picker and multi-label toggle on the first row,
# text and label inputs on the second, then the score display, the submit
# button and the clickable examples.
with gr.Blocks(css=css) as iface:
    gr.Markdown("# Zero-Shot Text Classifier")
    gr.Markdown("Select a model, enter text, and a set of labels to classify the text using a zero-shot classification model.")
    gr.Markdown("More than 10 languages are officially supported, including: English, Arabic, Bulgarian, German, Greek, Spanish, French, Hindi, Russian, Swahili, Thai, Turkish, Urdu, Vietnamese and Chinese.")

    with gr.Row():
        # Dropdown to select a model; pre-select the first entry so the
        # handler never receives an empty selection.
        model_dropdown = gr.Dropdown(AVAILABLE_MODELS, value=AVAILABLE_MODELS[0], label="Choose Model")
        # Checkbox for multi-label classification
        multi_label = gr.Checkbox(label="Multi-label", value=False, info="Check for multi-label classification, uncheck for single-label (multi-class).")

    # Input fields for text and labels
    with gr.Row():
        text_input = gr.Textbox(label="Enter Text", placeholder="Type or paste text here...")
        label_input = gr.Textbox(label="Enter Labels (comma-separated)", placeholder="e.g., sports, politics, technology")

    # Output display
    output_label = gr.Label(label="Classification Scores")

    # Classification button; the input order matches classify_text's parameters
    submit_button = gr.Button("Classify", elem_classes=["classify-button"])
    submit_button.click(fn=classify_text, inputs=[model_dropdown, text_input, label_input, multi_label], outputs=output_label)

    # Clicking an example fills the text and label boxes
    gr.Examples(examples, inputs=[text_input, label_input])

# Launch the app only when run as a script
if __name__ == "__main__":
    iface.launch()