donsek committed on
Commit
99a265f
·
verified ·
1 Parent(s): f3f6110

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +115 -0
  2. problem_country_encoder.pkl +3 -0
app.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ import numpy as np
4
+ import pickle
5
+ import pandas as pd
6
+ from sentence_transformers import SentenceTransformer
7
+ from model import VotePredictor # <-- make sure this matches your model file
8
+
9
def _load_predictor(weights_path, country_count):
    """Build a VotePredictor, restore its weights on CPU and put it in eval mode."""
    predictor = VotePredictor(country_count=country_count)
    state_dict = torch.load(weights_path, map_location="cpu")
    predictor.load_state_dict(state_dict)
    predictor.eval()
    return predictor


# Two checkpoints are served side by side: one sized for 193 countries and a
# second one sized for 46 countries.
main_model = _load_predictor("vote_predictor_epoch27.pt", 193)
problem_model = _load_predictor("problem_country_model.pt", 46)

# Encoder that maps country names to the integer ids fed to the models.
# NOTE(review): pickle.load runs arbitrary code from the file; only ship
# encoder files that come from a trusted source.
with open("country_encoder.pkl", "rb") as encoder_file:
    country_encoder = pickle.load(encoder_file)

# Sentence embedding model used to turn resolution text into a vector.
vectorizer = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
24
+
25
# Countries whose predictions are produced by `problem_model` rather than
# `main_model`. The names must stay spelled exactly as in the training data,
# because they are compared against the encoder's class names.
problem_countries = [
    "SURINAME",
    "TURKMENISTAN",
    "MARSHALL ISLANDS",
    "MYANMAR",
    "GABON",
    "CENTRAL AFRICAN REPUBLIC",
    "ISRAEL",
    "REPUBLIC OF THE CONGO",
    "LIBERIA",
    "SOMALIA",
    "CANADA",
    "LAO PEOPLE'S DEMOCRATIC REPUBLIC",
    "TUVALU",
    "DEMOCRATIC REPUBLIC OF THE CONGO",
    "MONTENEGRO",
    "VANUATU",
    "UNITED STATES",
    "TÜRKİYE",
    "SEYCHELLES",
    "SERBIA",
    "CABO VERDE",
    "VENEZUELA (BOLIVARIAN REPUBLIC OF)",
    "KIRIBATI",
    "IRAN (ISLAMIC REPUBLIC OF)",
    "SOUTH SUDAN",
    "ALBANIA",
    "CZECHIA",
    "DOMINICA",
    "SAO TOME AND PRINCIPE",
    "ESWATINI",
    "CHAD",
    "EQUATORIAL GUINEA",
    "GAMBIA",
    "LIBYA",
    "CÔTE D'IVOIRE",
    "SAINT CHRISTOPHER AND NEVIS",
    "RWANDA",
    "TONGA",
    "NIGER",
    "MICRONESIA (FEDERATED STATES OF)",
    "SYRIAN ARAB REPUBLIC",
    "NAURU",
    "PALAU",
    "NORTH MACEDONIA",
    "NETHERLANDS",
    "BOLIVIA (PLURINATIONAL STATE OF)",
]
74
+
75
# Encoder for the problem-country model. `problem_model` is built with
# country_count=46, so it cannot be indexed with ids from the 193-class
# `country_encoder`; it needs ids from its own encoder.
# NOTE(review): presumably this file was fitted on the same 46 countries that
# problem_model was trained on -- confirm against the training script.
# NOTE(review): pickle.load runs arbitrary code from the file; only ship
# encoder files that come from a trusted source.
with open("problem_country_encoder.pkl", "rb") as f:
    problem_country_encoder = pickle.load(f)


def _country_ids(encoder):
    """Return a dict mapping every class name of *encoder* to its integer id."""
    names = list(encoder.classes_)
    # One transform call for all names instead of one call per country.
    return dict(zip(names, (int(i) for i in encoder.transform(names))))


# Vote function
def predict_votes(resolution_text):
    """Predict a Yes / Not Yes vote for every country known to the encoder.

    The resolution text is embedded once with `vectorizer`. Each country in
    `country_encoder.classes_` is then scored by `problem_model` when it is
    listed in `problem_countries`, and by `main_model` otherwise. Each model
    receives the country id produced by its own encoder.

    Args:
        resolution_text: Full text of the resolution to score.

    Returns:
        A pandas DataFrame with columns "Country" and "Vote", sorted by
        country name. Blank input yields an empty frame with the same
        columns.
    """
    # Nothing to score: return an empty table instead of embedding "".
    if not resolution_text or not resolution_text.strip():
        return pd.DataFrame({"Country": [], "Vote": []})

    # Vectorize once; the same embedding is reused for every country.
    vec = vectorizer.encode([resolution_text])
    x_tensor = torch.tensor(vec, dtype=torch.float32)

    main_ids = _country_ids(country_encoder)
    problem_ids = _country_ids(problem_country_encoder)
    problem_set = set(problem_countries)  # O(1) membership inside the loop

    countries = []
    votes = []

    with torch.no_grad():
        for country, main_id in main_ids.items():
            if country in problem_set and country in problem_ids:
                model, country_id = problem_model, problem_ids[country]
            else:
                # Also covers a listed country that the problem encoder does
                # not know: fall back to the main model rather than feed
                # problem_model an id it was not trained with.
                model, country_id = main_model, main_id

            c_tensor = torch.tensor([country_id], dtype=torch.long)
            logit = model(x_tensor, c_tensor).squeeze()
            prob = torch.sigmoid(logit).item()

            countries.append(country)
            votes.append("✅ Yes" if prob > 0.5 else "❌ Not Yes")

    df = pd.DataFrame({
        "Country": countries,
        "Vote": votes,
    }).sort_values("Country")

    return df
104
+
105
# Gradio UI: one multi-line text input, one table of predicted votes.
resolution_box = gr.Textbox(lines=15, label="Paste UN Resolution Text Here")
votes_table = gr.Dataframe(label="Predicted Votes by Country")

iface = gr.Interface(
    fn=predict_votes,
    inputs=resolution_box,
    outputs=votes_table,
    title="UN Resolution Vote Predictor",
    description="This model predicts how each UN country will vote on a given resolution based on the text. Uses BERT embeddings and two models: one for normal countries, one for chaos monkeys.",
    live=False,  # only run when the user submits
)

# Start the web server for the interface.
iface.launch()
problem_country_encoder.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:805fcaa77f0176fa8a3745700ddafde4df90278ed1bcfacb83353423dba4bb75
3
+ size 1038