|
import torch |
|
import gradio as gr |
|
import numpy as np |
|
import pickle |
|
import pandas as pd |
|
from model import VotePredictor |
|
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
|
|
|
|
|
class BertVectorizer:
    """Turn free text into one fixed-size embedding vector.

    Wraps a Hugging Face tokenizer/encoder pair and represents a text by the
    encoder's hidden state at the first token position.

    NOTE(review): sentence-transformers checkpoints are usually mean-pooled;
    the vote predictors were presumably trained on this first-token pooling,
    so don't change it without retraining them — confirm.
    """

    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
        """Load the tokenizer and encoder for *model_name*; put the encoder in eval mode."""
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        encoder = AutoModel.from_pretrained(model_name)
        encoder.eval()
        self.tokenizer = tokenizer
        self.model = encoder

    def encode(self, text):
        """Return the first-token embedding of *text* as a NumPy array.

        Input longer than 512 tokens is truncated. Size-1 axes are squeezed
        away, so a single string yields a 1-D vector.
        """
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        with torch.no_grad():
            hidden = self.model(**encoded).last_hidden_state
        first_token = hidden[:, 0, :]
        return first_token.squeeze().numpy()
|
|
|
|
|
|
|
def _load_predictor(checkpoint_path, country_count):
    """Build a VotePredictor, load its CPU state_dict from *checkpoint_path*, and set eval mode."""
    predictor = VotePredictor(country_count=country_count)
    # A state_dict only contains tensors/containers, so restrict unpickling to
    # those (weights_only=True) instead of executing arbitrary pickled code.
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    predictor.load_state_dict(state_dict)
    predictor.eval()
    return predictor


# Used for every country NOT in ``problem_countries``; ids come from country_encoder.
main_model = _load_predictor("vote_predictor_epoch27.pt", country_count=193)

# Used for the countries in ``problem_countries``. predict_votes feeds each
# country's position in that list as its id, so 46 must equal its length.
problem_model = _load_predictor("problem_country_model.pt", country_count=46)

# NOTE(review): pickle.load can execute arbitrary code from the file; only
# ship a country_encoder.pkl you produced yourself.
with open("country_encoder.pkl", "rb") as f:
    country_encoder = pickle.load(f)

vectorizer = BertVectorizer()
|
|
|
|
|
# Countries routed to ``problem_model`` instead of ``main_model``.
# ORDER MATTERS: predict_votes uses each name's position in this list as the
# country id passed to problem_model, and the list length (46) must match the
# country_count problem_model was built with. Presumably the order also has to
# match how problem_model was trained — confirm before reordering/inserting.
problem_countries = [
    'SURINAME',
    'TURKMENISTAN',
    'MARSHALL ISLANDS',
    'MYANMAR',
    'GABON',
    'CENTRAL AFRICAN REPUBLIC',
    'ISRAEL',
    'REPUBLIC OF THE CONGO',
    'LIBERIA',
    'SOMALIA',
    'CANADA',
    "LAO PEOPLE'S DEMOCRATIC REPUBLIC",
    'TUVALU',
    'DEMOCRATIC REPUBLIC OF THE CONGO',
    'MONTENEGRO',
    'VANUATU',
    'UNITED STATES',
    'TÜRKİYE',
    'SEYCHELLES',
    'SERBIA',
    'CABO VERDE',
    'VENEZUELA (BOLIVARIAN REPUBLIC OF)',
    'KIRIBATI',
    'IRAN (ISLAMIC REPUBLIC OF)',
    'SOUTH SUDAN',
    'ALBANIA',
    'CZECHIA',
    'DOMINICA',
    'SAO TOME AND PRINCIPE',
    'ESWATINI',
    'CHAD',
    'EQUATORIAL GUINEA',
    'GAMBIA',
    'LIBYA',
    "CÔTE D'IVOIRE",
    'SAINT CHRISTOPHER AND NEVIS',
    'RWANDA',
    'TONGA',
    'NIGER',
    'MICRONESIA (FEDERATED STATES OF)',
    'SYRIAN ARAB REPUBLIC',
    'NAURU',
    'PALAU',
    'NORTH MACEDONIA',
    'NETHERLANDS',
    'BOLIVIA (PLURINATIONAL STATE OF)',
]
|
|
|
|
|
|
|
def predict_votes(resolution_text):
    """Predict a "Yes" / "Not Yes" vote for every country known to the encoder.

    Countries listed in ``problem_countries`` are scored by ``problem_model``
    with their position in that list as the country id; every other country
    is scored by ``main_model`` with the id assigned by ``country_encoder``.
    A sigmoid probability above 0.5 is reported as a Yes.

    Args:
        resolution_text: Free-text resolution pasted into the UI.

    Returns:
        pandas.DataFrame with columns ``Country`` and ``Vote``, sorted by
        country name with a fresh 0..n-1 index. When the input is empty,
        whitespace-only, or None, the frame has those columns and no rows.
    """
    if resolution_text is None or not resolution_text.strip():
        # Nothing to embed: predictions from an empty prompt would be noise.
        return pd.DataFrame(columns=["Country", "Vote"])

    vec = vectorizer.encode(resolution_text)
    x_tensor = torch.tensor(vec, dtype=torch.float32).unsqueeze(0)

    # Build the lookups once rather than per country: list membership and
    # list.index are O(n), and transform() was called once per country.
    problem_index = {}
    for position, name in enumerate(problem_countries):
        problem_index.setdefault(name, position)  # first occurrence, like list.index
    classes = list(country_encoder.classes_)
    encoded_ids = country_encoder.transform(country_encoder.classes_)

    rows = []
    with torch.no_grad():
        for country, country_id in zip(classes, encoded_ids):
            if country in problem_index:
                model = problem_model
                model_id = problem_index[country]
            else:
                model = main_model
                model_id = int(country_id)
            c_tensor = torch.tensor([model_id], dtype=torch.long)
            logit = model(x_tensor, c_tensor).squeeze()
            prob = torch.sigmoid(logit).item()
            rows.append((country, "✅ Yes" if prob > 0.5 else "❌ Not Yes"))

    df = pd.DataFrame(rows, columns=["Country", "Vote"])
    return df.sort_values("Country").reset_index(drop=True)
|
|
|
|
|
|
|
# Single-textbox UI: resolution text in, per-country vote table out.
iface = gr.Interface(
    fn=predict_votes,
    inputs=gr.Textbox(lines=15, label="Paste UN Resolution Text Here"),
    outputs=gr.Dataframe(label="Predicted Votes by Country"),
    title="UN Resolution Vote Predictor",
    description="Predicts how each UN country might vote on your custom resolution text. Two models: one for stable democracies, one for spicy outliers.",
    live=False,
)

# launch() blocks, so only start the server when this file is run as a
# script; importing the module (e.g. to reuse predict_votes) stays side-effect free.
if __name__ == "__main__":
    iface.launch()