mjwong commited on
Commit
7535af8
·
verified ·
1 Parent(s): 6c3db23

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +84 -0
app.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+
4
+ # Available models for zero-shot classification
5
+ AVAILABLE_MODELS = [
6
+ "mjwong/multilingual-e5-large-instruct-xnli-anli",
7
+ "mjwong/multilingual-e5-base-xnli-anli",
8
+ "mjwong/multilingual-e5-large-xnli-anli",
9
+ "mjwong/mcontriever-msmarco-xnli",
10
+ "mjwong/mcontriever-xnli"
11
+ ]
12
+
13
+ def classify_text(model_name, text, labels):
14
+ classifier = pipeline("zero-shot-classification", model=model_name)
15
+ labels_list = [label.strip() for label in labels.split(",")]
16
+ result = classifier(text, candidate_labels=labels_list)
17
+ return {label: score for label, score in zip(result["labels"], result["scores"])}
18
+
19
+ # Example Input
20
+ examples = [["One day I will see the world", "travel, live, die, future"]]
21
+
22
+ # Define the Gradio interface
23
+ css = """
24
+ footer {display:none !important}
25
+ .output-markdown{display:none !important}
26
+ .gr-button-primary {
27
+ z-index: 14;
28
+ height: 43px;
29
+ width: 130px;
30
+ left: 0px;
31
+ top: 0px;
32
+ padding: 0px;
33
+ cursor: pointer !important;
34
+ background: none rgb(17, 20, 45) !important;
35
+ border: none !important;
36
+ text-align: center !important;
37
+ font-family: Poppins !important;
38
+ font-size: 14px !important;
39
+ font-weight: 500 !important;
40
+ color: rgb(255, 255, 255) !important;
41
+ line-height: 1 !important;
42
+ border-radius: 12px !important;
43
+ transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
44
+ box-shadow: none !important;
45
+ }
46
+ .gr-button-primary:hover{
47
+ z-index: 14;
48
+ height: 43px;
49
+ width: 130px;
50
+ left: 0px;
51
+ top: 0px;
52
+ padding: 0px;
53
+ cursor: pointer !important;
54
+ background: none rgb(66, 133, 244) !important;
55
+ border: none !important;
56
+ text-align: center !important;
57
+ font-family: Poppins !important;
58
+ font-size: 14px !important;
59
+ font-weight: 500 !important;
60
+ color: rgb(255, 255, 255) !important;
61
+ line-height: 1 !important;
62
+ border-radius: 12px !important;
63
+ transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
64
+ box-shadow: rgb(0 0 0 / 23%) 0px 1px 7px 0px !important;
65
+ }
66
+ """
67
+
68
+ iface = gr.Interface(
69
+ fn=classify_text,
70
+ inputs=[
71
+ gr.Dropdown(AVAILABLE_MODELS, label="Choose Model"),
72
+ gr.Textbox(label="Enter Text", placeholder="Type or paste text here..."),
73
+ gr.Textbox(label="Enter Labels (comma-separated)", placeholder="e.g., sports, politics, technology")
74
+ ],
75
+ outputs=gr.Label(label="Classification Scores"),
76
+ title="Zero-Shot Text Classifier",
77
+ description="Select a model, enter text, and a set of labels to classify it using a zero-shot classification model.",
78
+ examples=examples,
79
+ css=css
80
+ )
81
+
82
+ # Launch the app
83
+ if __name__ == "__main__":
84
+ iface.launch()