File size: 1,270 Bytes
6fe8523
 
 
 
 
 
47e3f8f
6fe8523
47e3f8f
6fe8523
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Classifier weights are loaded from a local checkpoint directory; change the
# name below to match your own checkpoint number.
model_id = "checkpoint-2391"
model = AutoModelForSequenceClassification.from_pretrained(model_id)

# The tokenizer is the original one, loaded by its hub name rather than from
# the checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(
    "huawei-noah/TinyBERT_General_4L_312D"
)


def predict(text):
    """Classify ``text`` and return the predicted label with its confidence.

    Args:
        text: The input string to classify.

    Returns:
        A string of the form ``"<label> (Confidence: <percent>)"`` where the
        label is one of ``Left``, ``Right`` or ``Centrist``. If ``text`` is
        empty or whitespace-only, a short prompt asking for input is returned
        instead of running the model.
    """
    # Nothing meaningful to classify; skip the model call entirely.
    if not text or not text.strip():
        return "Please enter some text to analyze."

    label_map = {0: 'Left', 1: 'Right', 2: 'Centrist'}

    # Tokenize a single input; anything beyond 64 tokens is truncated.
    inputs = tokenizer(text,
                       truncation=True,
                       padding=True,
                       max_length=64,
                       return_tensors="pt")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)
        # Take the first (only) row of the batch as class probabilities.
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]

    prediction = int(probs.argmax(-1).item())
    confidence = probs[prediction].item()

    # Return the human-readable result (not the raw tensor) so it matches the
    # interface's "text" output.
    return f"{label_map[prediction]} (Confidence: {confidence:.2%})"


# Build the Gradio UI: one multi-line text box in, one text result out,
# with `predict` as the handler.
text_input = gr.Textbox(lines=4, placeholder="Enter text to analyze...")

demo = gr.Interface(
    fn=predict,
    inputs=text_input,
    outputs="text",
    title="Political Text Classifier",
    description="Classify political text as Left, Right, or Centrist",
)

# Start serving the app.
demo.launch()