File size: 3,700 Bytes
9282ee1 0de37c8 9282ee1 0de37c8 9282ee1 99a265f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import torch
import gradio as gr
import numpy as np
import pickle
import pandas as pd
from model import VotePredictor
from transformers import AutoTokenizer, AutoModel
# === Vectorizer wrapper (replaces sentence-transformers) ===
class BertVectorizer:
    """Embed text with a Hugging Face encoder, returning the first-token vector.

    NOTE(review): the sentence-transformers pipeline for all-MiniLM-L6-v2
    mean-pools token embeddings rather than taking the first ([CLS]) token.
    Whatever pooling the VotePredictor checkpoints were trained on must be
    used here -- confirm before changing either side.
    """

    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
        """Load the tokenizer and encoder for *model_name*; put the encoder in eval mode."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()

    @torch.no_grad()
    def encode(self, text):
        """Return the first-token hidden state(s) of *text* as a NumPy array.

        Inputs are truncated to 512 tokens and padded to a common length.
        Size-1 dimensions are squeezed away, so a single string yields a
        1-D vector of the encoder's hidden size.
        """
        batch = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        hidden = self.model(**batch).last_hidden_state
        first_token = hidden[:, 0]
        return first_token.squeeze().numpy()
# === Load Models ===
def _load_predictor(checkpoint_path, country_count):
    """Build a VotePredictor, load its state dict from *checkpoint_path*, set eval mode.

    Args:
        checkpoint_path: Path to a file holding a ``state_dict`` (tensors only).
        country_count: Number of country ids the model's constructor expects.

    Returns:
        The loaded model, on CPU, in eval mode.

    ``weights_only=True`` restricts unpickling to tensors and plain
    containers, so a tampered checkpoint cannot run code on load. This is
    PyTorch's default from 2.6 and needs PyTorch >= 1.13.
    """
    model = VotePredictor(country_count=country_count)
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


# 193 ids for the main model (presumably len(country_encoder.classes_) -- TODO
# confirm); 46 for the problem model, one per entry of problem_countries.
main_model = _load_predictor("vote_predictor_epoch27.pt", country_count=193)
problem_model = _load_predictor("problem_country_model.pt", country_count=46)
# === Load Encoder ===
# NOTE(review): pickle can execute arbitrary code while loading, so only ship
# a country_encoder.pkl you produced yourself. Unpickling also needs the
# library that defined the encoder (presumably scikit-learn) at a compatible
# version.
with open("country_encoder.pkl", "rb") as encoder_file:
    country_encoder = pickle.load(encoder_file)

# === Initialize Vectorizer ===
# Built once at import time and shared by every prediction request.
vectorizer = BertVectorizer()
# === List of problem countries ===
# Countries scored by ``problem_model`` instead of ``main_model``.
# ORDER MATTERS: predict_votes uses each name's position in this list as the
# problem model's country id (0-45), so presumably it must match the order
# used when problem_country_model.pt was trained -- do not reorder or insert
# without retraining. A name is routed here only if it equals an entry of
# country_encoder.classes_ exactly (accents included).
problem_countries = [
    "SURINAME",
    "TURKMENISTAN",
    "MARSHALL ISLANDS",
    "MYANMAR",
    "GABON",
    "CENTRAL AFRICAN REPUBLIC",
    "ISRAEL",
    "REPUBLIC OF THE CONGO",
    "LIBERIA",
    "SOMALIA",
    "CANADA",
    "LAO PEOPLE'S DEMOCRATIC REPUBLIC",
    "TUVALU",
    "DEMOCRATIC REPUBLIC OF THE CONGO",
    "MONTENEGRO",
    "VANUATU",
    "UNITED STATES",
    "TÜRKİYE",
    "SEYCHELLES",
    "SERBIA",
    "CABO VERDE",
    "VENEZUELA (BOLIVARIAN REPUBLIC OF)",
    "KIRIBATI",
    "IRAN (ISLAMIC REPUBLIC OF)",
    "SOUTH SUDAN",
    "ALBANIA",
    "CZECHIA",
    "DOMINICA",
    "SAO TOME AND PRINCIPE",
    "ESWATINI",
    "CHAD",
    "EQUATORIAL GUINEA",
    "GAMBIA",
    "LIBYA",
    "CÔTE D'IVOIRE",
    "SAINT CHRISTOPHER AND NEVIS",
    "RWANDA",
    "TONGA",
    "NIGER",
    "MICRONESIA (FEDERATED STATES OF)",
    "SYRIAN ARAB REPUBLIC",
    "NAURU",
    "PALAU",
    "NORTH MACEDONIA",
    "NETHERLANDS",
    "BOLIVIA (PLURINATIONAL STATE OF)",
]
# === Prediction Function ===
def predict_votes(resolution_text):
    """Predict each country's vote on a UN resolution text.

    The text is embedded once with the module-level ``vectorizer``. Each
    country in ``country_encoder.classes_`` is then scored by one of two
    models:
    - names listed in ``problem_countries`` go to ``problem_model``, using
      their position in that list as the country id;
    - all others go to ``main_model``, using ``country_encoder``'s integer
      code.

    Args:
        resolution_text: Free-form resolution text from the Gradio textbox.

    Returns:
        A ``pandas.DataFrame`` with columns ``"Country"`` and ``"Vote"``,
        sorted by country. ``"Vote"`` is ``"✅ Yes"`` when the model's
        sigmoid probability exceeds 0.5, otherwise ``"❌ Not Yes"``. Blank
        or missing text yields an empty frame with the same columns.
    """
    # An empty prompt carries no signal; don't fabricate predictions for it.
    if not (resolution_text or "").strip():
        return pd.DataFrame({"Country": [], "Vote": []})

    vec = vectorizer.encode(resolution_text)
    x_tensor = torch.tensor(vec, dtype=torch.float32).unsqueeze(0)

    # Name -> position in problem_countries (the problem model's country id).
    # setdefault keeps the first occurrence, matching list.index semantics.
    problem_index = {}
    for position, name in enumerate(problem_countries):
        problem_index.setdefault(name, position)

    all_countries = country_encoder.classes_
    # One batched transform; element-wise identical to calling
    # transform([country])[0] for each country.
    main_ids = country_encoder.transform(all_countries)

    countries = []
    votes = []
    with torch.no_grad():
        for country, main_id in zip(all_countries, main_ids):
            if country in problem_index:
                model = problem_model
                country_id = problem_index[country]  # 0-45
            else:
                model = main_model
                country_id = int(main_id)  # 0-192
            c_tensor = torch.tensor([country_id], dtype=torch.long)
            logit = model(x_tensor, c_tensor).squeeze()
            prob = torch.sigmoid(logit).item()
            countries.append(country)
            votes.append("✅ Yes" if prob > 0.5 else "❌ Not Yes")

    return pd.DataFrame({"Country": countries, "Vote": votes}).sort_values("Country")
# === Interface ===
# One textbox in, one country/vote table out; predict_votes does the work.
resolution_box = gr.Textbox(lines=15, label="Paste UN Resolution Text Here")
vote_table = gr.Dataframe(label="Predicted Votes by Country")

iface = gr.Interface(
    fn=predict_votes,
    inputs=resolution_box,
    outputs=vote_table,
    title="UN Resolution Vote Predictor",
    description=(
        "Predicts how each UN country might vote on your custom resolution "
        "text. Two models: one for stable democracies, one for spicy outliers."
    ),
    live=False,
)
iface.launch()