"""Gradio app that predicts how each UN member might vote on a resolution.

The resolution text is embedded with ``BertVectorizer`` and passed, together
with a per-country integer id, to one of two ``VotePredictor`` checkpoints:
``problem_model`` for the countries listed in ``problem_countries`` and
``main_model`` for every other country in ``country_encoder.classes_``.
"""

import pickle

import gradio as gr
import numpy as np
import pandas as pd
import torch
from transformers import AutoModel, AutoTokenizer

from model import VotePredictor


# === Vectorizer wrapper (replaces sentence-transformers) ===
class BertVectorizer:
    """Embed text with a Hugging Face encoder the way sentence-transformers does.

    ``sentence-transformers/all-MiniLM-L6-v2`` is published as
    Transformer -> mean pooling -> L2 normalization, truncating at 256 tokens.
    The defaults reproduce that pipeline, so the vectors match
    ``SentenceTransformer(model_name).encode(text)``.

    NOTE(review): the vote checkpoints must have been trained on embeddings
    produced the same way. If they were trained on this wrapper's former
    CLS-token output, construct it with
    ``pooling="cls", normalize=False, max_length=512``.
    """

    def __init__(
        self,
        model_name="sentence-transformers/all-MiniLM-L6-v2",
        *,
        pooling="mean",
        normalize=True,
        max_length=256,
    ):
        if pooling not in ("mean", "cls"):
            raise ValueError(f"pooling must be 'mean' or 'cls', got {pooling!r}")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()
        self.pooling = pooling
        self.normalize = normalize
        self.max_length = max_length

    def encode(self, text):
        """Return the embedding of *text* as a squeezed float NumPy array."""
        with torch.no_grad():
            inputs = self.tokenizer(
                text,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=self.max_length,
            )
            outputs = self.model(**inputs)
            hidden = outputs.last_hidden_state
            if self.pooling == "cls":
                embedding = hidden[:, 0, :]
            else:
                # Average only over real tokens; padding positions are masked out.
                mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
                summed = (hidden * mask).sum(dim=1)
                embedding = summed / mask.sum(dim=1).clamp(min=1e-9)
            if self.normalize:
                embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)
        return embedding.squeeze().numpy()


# === List of problem countries ===
# A country's position in this list is its id for ``problem_model``.
problem_countries = [
    'SURINAME', 'TURKMENISTAN', 'MARSHALL ISLANDS', 'MYANMAR', 'GABON',
    'CENTRAL AFRICAN REPUBLIC', 'ISRAEL', 'REPUBLIC OF THE CONGO', 'LIBERIA',
    'SOMALIA', 'CANADA', "LAO PEOPLE'S DEMOCRATIC REPUBLIC", 'TUVALU',
    'DEMOCRATIC REPUBLIC OF THE CONGO', 'MONTENEGRO', 'VANUATU',
    'UNITED STATES', 'TÜRKİYE', 'SEYCHELLES', 'SERBIA', 'CABO VERDE',
    'VENEZUELA (BOLIVARIAN REPUBLIC OF)', 'KIRIBATI',
    'IRAN (ISLAMIC REPUBLIC OF)', 'SOUTH SUDAN', 'ALBANIA', 'CZECHIA',
    'DOMINICA', 'SAO TOME AND PRINCIPE', 'ESWATINI', 'CHAD',
    'EQUATORIAL GUINEA', 'GAMBIA', 'LIBYA', "CÔTE D'IVOIRE",
    'SAINT CHRISTOPHER AND NEVIS', 'RWANDA', 'TONGA', 'NIGER',
    'MICRONESIA (FEDERATED STATES OF)', 'SYRIAN ARAB REPUBLIC', 'NAURU',
    'PALAU', 'NORTH MACEDONIA', 'NETHERLANDS',
    'BOLIVIA (PLURINATIONAL STATE OF)',
]

# O(1) lookup of a problem country's model id (its index in the list above).
_problem_index = {name: i for i, name in enumerate(problem_countries)}


# === Load Models ===
def _load_vote_model(path, country_count):
    """Build a VotePredictor for *country_count* ids, load *path* on CPU, set eval mode."""
    model = VotePredictor(country_count=country_count)
    model.load_state_dict(torch.load(path, map_location="cpu"))
    model.eval()
    return model


main_model = _load_vote_model("vote_predictor_epoch27.pt", 193)
# The problem model's id space is exactly the problem_countries list (46 entries).
problem_model = _load_vote_model("problem_country_model.pt", len(problem_countries))

# === Load Encoder ===
# Trusted local artifact only: unpickling can execute arbitrary code.
with open("country_encoder.pkl", "rb") as f:
    country_encoder = pickle.load(f)

# === Initialize Vectorizer ===
vectorizer = BertVectorizer()


def _build_country_routes():
    """Return (country, model, country_id) for every class in ``country_encoder``.

    Problem countries are routed to ``problem_model`` with their list index;
    all others go to ``main_model`` with the id from ``country_encoder``.
    """
    classes = country_encoder.classes_
    # One vectorized transform instead of one call per country per request.
    main_ids = country_encoder.transform(classes)
    routes = []
    for country, main_id in zip(classes, main_ids):
        if country in _problem_index:
            routes.append((country, problem_model, _problem_index[country]))
        else:
            routes.append((country, main_model, int(main_id)))
    return routes


# Routing is fixed once the encoder and models are loaded, so compute it once.
_country_routes = _build_country_routes()


# === Prediction Function ===
def predict_votes(resolution_text):
    """Predict each country's vote on *resolution_text*.

    Returns a DataFrame with columns "Country" and "Vote", sorted by country.
    A vote is "✅ Yes" when the routed model's sigmoid probability exceeds
    0.5 and "❌ Not Yes" otherwise. Blank input yields an empty table.
    """
    if resolution_text is None or not resolution_text.strip():
        return pd.DataFrame({"Country": [], "Vote": []})

    vec = vectorizer.encode(resolution_text)
    x_tensor = torch.tensor(vec, dtype=torch.float32).unsqueeze(0)

    countries = []
    votes = []
    with torch.no_grad():
        for country, model, country_id in _country_routes:
            c_tensor = torch.tensor([country_id], dtype=torch.long)
            logit = model(x_tensor, c_tensor).squeeze()
            prob = torch.sigmoid(logit).item()
            countries.append(country)
            votes.append("✅ Yes" if prob > 0.5 else "❌ Not Yes")

    df = pd.DataFrame({
        "Country": countries,
        "Vote": votes
    }).sort_values("Country")
    return df


# === Interface ===
iface = gr.Interface(
    fn=predict_votes,
    inputs=gr.Textbox(lines=15, label="Paste UN Resolution Text Here"),
    outputs=gr.Dataframe(label="Predicted Votes by Country"),
    title="UN Resolution Vote Predictor",
    description="Predicts how each UN country might vote on your custom resolution text. Two models: one for stable democracies, one for spicy outliers.",
    live=False
)

if __name__ == "__main__":
    iface.launch()