File size: 3,700 Bytes
9282ee1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0de37c8
9282ee1
 
 
 
 
0de37c8
 
 
 
 
 
 
 
 
9282ee1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99a265f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import torch
import gradio as gr
import numpy as np
import pickle
import pandas as pd
from model import VotePredictor

from transformers import AutoTokenizer, AutoModel


# === Vectorizer wrapper (replaces sentence-transformers) ===
class BertVectorizer:
    """Turn text into a fixed-size vector with a Hugging Face encoder.

    The vector is the final-layer hidden state of the first token (position 0)
    of each tokenized sequence.

    NOTE(review): the section header says this replaces sentence-transformers.
    For all-MiniLM-L6-v2, sentence-transformers mean-pools and L2-normalises the
    token embeddings (with a 256-token limit), whereas this takes the first
    token and allows 512. If the vote models were trained on
    sentence-transformers embeddings, this is a train/serve mismatch. Confirm
    against the training code before changing either side.
    """

    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
        """Load the tokenizer and encoder for *model_name*; put the encoder in eval mode."""
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        encoder = AutoModel.from_pretrained(model_name)
        encoder.eval()  # inference only: disables dropout
        self.tokenizer, self.model = tokenizer, encoder

    def encode(self, text):
        """Return the first-token embedding of *text* as a NumPy array.

        *text* is handed straight to the tokenizer (a string, or presumably a
        list of strings for a batch). Sequences are padded and truncated to at
        most 512 tokens. All size-1 dimensions are squeezed out of the result,
        so a single string yields a 1-D vector.
        """
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        with torch.no_grad():
            hidden_states = self.model(**encoded).last_hidden_state
            first_token = hidden_states[:, 0, :]
            return first_token.squeeze().numpy()


# === Load Models ===
# NOTE: torch.load and pickle.load can execute code embedded in the file. The
# paths below are hard-coded local files, which is fine; never point these
# loaders at user-supplied files.
def _load_vote_predictor(checkpoint_path, country_count):
    """Build a VotePredictor with *country_count* country slots, restore its
    weights from *checkpoint_path* (mapped onto the CPU) and set eval mode."""
    predictor = VotePredictor(country_count=country_count)
    state = torch.load(checkpoint_path, map_location="cpu")
    predictor.load_state_dict(state)
    predictor.eval()
    return predictor


# Main model: 193 country slots, addressed by country_encoder's integer labels.
main_model = _load_vote_predictor("vote_predictor_epoch27.pt", 193)
# Problem-country model: 46 slots, addressed by position in problem_countries.
problem_model = _load_vote_predictor("problem_country_model.pt", 46)


# === Load Encoder ===
def _read_pickle(path):
    """Return the object unpickled from the file at *path*."""
    with open(path, "rb") as stream:
        return pickle.load(stream)


# Fitted country label encoder; predict_votes uses its classes_ and transform().
country_encoder = _read_pickle("country_encoder.pkl")

# === Initialize Vectorizer ===
vectorizer = BertVectorizer()

# === List of problem countries ===
# Countries scored by problem_model instead of main_model. A country's position
# in this list is the id passed to problem_model (0-45), so its length must
# equal problem_model's country_count (46). Its order presumably must match the
# order used when problem_model was trained. Confirm before reordering.
problem_countries = [
    "SURINAME",
    "TURKMENISTAN",
    "MARSHALL ISLANDS",
    "MYANMAR",
    "GABON",
    "CENTRAL AFRICAN REPUBLIC",
    "ISRAEL",
    "REPUBLIC OF THE CONGO",
    "LIBERIA",
    "SOMALIA",
    "CANADA",
    "LAO PEOPLE'S DEMOCRATIC REPUBLIC",
    "TUVALU",
    "DEMOCRATIC REPUBLIC OF THE CONGO",
    "MONTENEGRO",
    "VANUATU",
    "UNITED STATES",
    "TÜRKİYE",
    "SEYCHELLES",
    "SERBIA",
    "CABO VERDE",
    "VENEZUELA (BOLIVARIAN REPUBLIC OF)",
    "KIRIBATI",
    "IRAN (ISLAMIC REPUBLIC OF)",
    "SOUTH SUDAN",
    "ALBANIA",
    "CZECHIA",
    "DOMINICA",
    "SAO TOME AND PRINCIPE",
    "ESWATINI",
    "CHAD",
    "EQUATORIAL GUINEA",
    "GAMBIA",
    "LIBYA",
    "CÔTE D'IVOIRE",
    "SAINT CHRISTOPHER AND NEVIS",
    "RWANDA",
    "TONGA",
    "NIGER",
    "MICRONESIA (FEDERATED STATES OF)",
    "SYRIAN ARAB REPUBLIC",
    "NAURU",
    "PALAU",
    "NORTH MACEDONIA",
    "NETHERLANDS",
    "BOLIVIA (PLURINATIONAL STATE OF)",
]


# === Prediction Function ===
def predict_votes(resolution_text):
    """Predict each country's vote on a UN resolution.

    The text is embedded once with the module-level ``vectorizer``. Every
    country in ``country_encoder.classes_`` is then scored:

    - countries listed in ``problem_countries`` use ``problem_model``, with
      their position in that list as the country id (0-45);
    - all others use ``main_model``, with the encoder's label as the country
      id (0-192).

    Args:
        resolution_text: Raw resolution text from the Gradio textbox.

    Returns:
        A DataFrame with columns "Country" and "Vote", sorted by country.
        "Vote" is "✅ Yes" when sigmoid(logit) > 0.5, otherwise "❌ Not Yes".
        Blank or missing input yields an empty DataFrame with those columns.
    """
    if resolution_text is None or not resolution_text.strip():
        # Nothing to embed: scoring an empty string only produces meaningless
        # predictions, and None would crash the tokenizer.
        return pd.DataFrame({"Country": [], "Vote": []})

    vec = vectorizer.encode(resolution_text)
    x_tensor = torch.tensor(vec, dtype=torch.float32).unsqueeze(0)

    # Build lookups once per call instead of once per country. The old
    # per-country list.index() scan and one-element transform() call were
    # O(n) / high-overhead.
    problem_index = {name: i for i, name in enumerate(problem_countries)}
    countries = list(country_encoder.classes_)
    country_ids = country_encoder.transform(countries)

    votes = []
    with torch.no_grad():
        for country, country_id in zip(countries, country_ids):
            if country in problem_index:
                model, idx = problem_model, problem_index[country]  # 0–45
            else:
                model, idx = main_model, int(country_id)  # 0–192
            c_tensor = torch.tensor([idx], dtype=torch.long)
            logit = model(x_tensor, c_tensor).squeeze()
            prob = torch.sigmoid(logit).item()
            votes.append("✅ Yes" if prob > 0.5 else "❌ Not Yes")

    df = pd.DataFrame({
        "Country": countries,
        "Vote": votes
    }).sort_values("Country")

    return df


# === Interface ===
iface = gr.Interface(
    fn=predict_votes,
    inputs=gr.Textbox(lines=15, label="Paste UN Resolution Text Here"),
    outputs=gr.Dataframe(label="Predicted Votes by Country"),
    title="UN Resolution Vote Predictor",
    # The main model is not "for stable democracies": the separate model covers
    # a fixed list of 46 countries that includes the US, Canada and the
    # Netherlands, so the description says what the code actually does.
    description=(
        "Predicts how each UN country might vote on your custom resolution text. "
        "Two models: a main model for most countries, and a separate model for "
        "46 listed 'problem' countries."
    ),
    live=False,  # run only when Submit is pressed, not on every edit
)

# Launch only when executed as a script, so importing this module (e.g. from
# tests or another app) does not start a web server.
if __name__ == "__main__":
    iface.launch()