donsek commited on
Commit
9282ee1
·
verified ·
1 Parent(s): db67a1e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -114
app.py CHANGED
@@ -1,115 +1,98 @@
1
- import torch
2
- import gradio as gr
3
- import numpy as np
4
- import pickle
5
- import pandas as pd
6
- from sentence_transformers import SentenceTransformer
7
- from model import VotePredictor # <-- make sure this matches your model file
8
-
9
- # Load models
10
- main_model = VotePredictor(country_count=193)
11
- main_model.load_state_dict(torch.load("vote_predictor_epoch27.pt", map_location="cpu"))
12
- main_model.eval()
13
-
14
- problem_model = VotePredictor(country_count=46)
15
- problem_model.load_state_dict(torch.load("problem_country_model.pt", map_location="cpu"))
16
- problem_model.eval()
17
-
18
- # Load country encoder
19
- with open("country_encoder.pkl", "rb") as f:
20
- country_encoder = pickle.load(f)
21
-
22
- # Vectorizer
23
- vectorizer = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
24
-
25
- # Define problem countries (same as used during training)
26
- problem_countries = [
27
- 'SURINAME',
28
- 'TURKMENISTAN',
29
- 'MARSHALL ISLANDS',
30
- 'MYANMAR',
31
- 'GABON',
32
- 'CENTRAL AFRICAN REPUBLIC',
33
- 'ISRAEL',
34
- 'REPUBLIC OF THE CONGO',
35
- 'LIBERIA',
36
- 'SOMALIA',
37
- 'CANADA',
38
- "LAO PEOPLE'S DEMOCRATIC REPUBLIC",
39
- 'TUVALU',
40
- 'DEMOCRATIC REPUBLIC OF THE CONGO',
41
- 'MONTENEGRO',
42
- 'VANUATU',
43
- 'UNITED STATES',
44
- 'TÜRKİYE',
45
- 'SEYCHELLES',
46
- 'SERBIA',
47
- 'CABO VERDE',
48
- 'VENEZUELA (BOLIVARIAN REPUBLIC OF)',
49
- 'KIRIBATI',
50
- 'IRAN (ISLAMIC REPUBLIC OF)',
51
- 'SOUTH SUDAN',
52
- 'ALBANIA',
53
- 'CZECHIA',
54
- 'DOMINICA',
55
- 'SAO TOME AND PRINCIPE',
56
- 'ESWATINI',
57
- 'CHAD',
58
- 'EQUATORIAL GUINEA',
59
- 'GAMBIA',
60
- 'LIBYA',
61
- "CÔTE D'IVOIRE",
62
- 'SAINT CHRISTOPHER AND NEVIS',
63
- 'RWANDA',
64
- 'TONGA',
65
- 'NIGER',
66
- 'MICRONESIA (FEDERATED STATES OF)',
67
- 'SYRIAN ARAB REPUBLIC',
68
- 'NAURU',
69
- 'PALAU',
70
- 'NORTH MACEDONIA',
71
- 'NETHERLANDS',
72
- 'BOLIVIA (PLURINATIONAL STATE OF)'
73
- ]
74
-
75
- # Vote function
76
- def predict_votes(resolution_text):
77
- # Vectorize once
78
- vec = vectorizer.encode([resolution_text])
79
- x_tensor = torch.tensor(vec, dtype=torch.float32)
80
-
81
- countries = []
82
- votes = []
83
-
84
- for country in country_encoder.classes_:
85
- country_id = country_encoder.transform([country])[0]
86
- c_tensor = torch.tensor([country_id], dtype=torch.long)
87
-
88
- model = problem_model if country in problem_countries else main_model
89
-
90
- with torch.no_grad():
91
- logit = model(x_tensor, c_tensor).squeeze()
92
- prob = torch.sigmoid(logit).item()
93
- vote = " Yes" if prob > 0.5 else "❌ Not Yes"
94
-
95
- countries.append(country)
96
- votes.append(vote)
97
-
98
- df = pd.DataFrame({
99
- "Country": countries,
100
- "Vote": votes
101
- }).sort_values("Country")
102
-
103
- return df
104
-
105
- # Gradio UI
106
- iface = gr.Interface(
107
- fn=predict_votes,
108
- inputs=gr.Textbox(lines=15, label="Paste UN Resolution Text Here"),
109
- outputs=gr.Dataframe(label="Predicted Votes by Country"),
110
- title="UN Resolution Vote Predictor",
111
- description="This model predicts how each UN country will vote on a given resolution based on the text. Uses BERT embeddings and two models: one for normal countries, one for chaos monkeys.",
112
- live=False
113
- )
114
-
115
  iface.launch()
 
1
+ import torch
2
+ import gradio as gr
3
+ import numpy as np
4
+ import pickle
5
+ import pandas as pd
6
+ from model import VotePredictor
7
+
8
+ from transformers import AutoTokenizer, AutoModel
9
+
10
+
11
+ # === Vectorizer wrapper (replaces sentence-transformers) ===
12
+ class BertVectorizer:
13
+ def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
14
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
15
+ self.model = AutoModel.from_pretrained(model_name)
16
+ self.model.eval()
17
+
18
+ def encode(self, text):
19
+ with torch.no_grad():
20
+ inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
21
+ outputs = self.model(**inputs)
22
+ cls_embedding = outputs.last_hidden_state[:, 0, :]
23
+ return cls_embedding.squeeze().numpy()
24
+
25
+
26
+ # === Load Models ===
27
+ main_model = VotePredictor(country_count=193)
28
+ main_model.load_state_dict(torch.load("vote_predictor_epoch27.pt", map_location="cpu"))
29
+ main_model.eval()
30
+
31
+ problem_model = VotePredictor(country_count=46)
32
+ problem_model.load_state_dict(torch.load("problem_country_model.pt", map_location="cpu"))
33
+ problem_model.eval()
34
+
35
+ # === Load Encoder ===
36
+ with open("country_encoder.pkl", "rb") as f:
37
+ country_encoder = pickle.load(f)
38
+
39
+ # === Initialize Vectorizer ===
40
+ vectorizer = BertVectorizer()
41
+
42
+ # === List of problem countries ===
43
+ problem_countries = [
44
+ 'SURINAME', 'TURKMENISTAN', 'MARSHALL ISLANDS', 'MYANMAR', 'GABON',
45
+ 'CENTRAL AFRICAN REPUBLIC', 'ISRAEL', 'REPUBLIC OF THE CONGO', 'LIBERIA',
46
+ 'SOMALIA', 'CANADA', "LAO PEOPLE'S DEMOCRATIC REPUBLIC", 'TUVALU',
47
+ 'DEMOCRATIC REPUBLIC OF THE CONGO', 'MONTENEGRO', 'VANUATU', 'UNITED STATES',
48
+ 'TÜRKİYE', 'SEYCHELLES', 'SERBIA', 'CABO VERDE',
49
+ 'VENEZUELA (BOLIVARIAN REPUBLIC OF)', 'KIRIBATI', 'IRAN (ISLAMIC REPUBLIC OF)',
50
+ 'SOUTH SUDAN', 'ALBANIA', 'CZECHIA', 'DOMINICA', 'SAO TOME AND PRINCIPE',
51
+ 'ESWATINI', 'CHAD', 'EQUATORIAL GUINEA', 'GAMBIA', 'LIBYA',
52
+ "CÔTE D'IVOIRE", 'SAINT CHRISTOPHER AND NEVIS', 'RWANDA', 'TONGA', 'NIGER',
53
+ 'MICRONESIA (FEDERATED STATES OF)', 'SYRIAN ARAB REPUBLIC', 'NAURU',
54
+ 'PALAU', 'NORTH MACEDONIA', 'NETHERLANDS', 'BOLIVIA (PLURINATIONAL STATE OF)'
55
+ ]
56
+
57
+
58
+ # === Prediction Function ===
59
+ def predict_votes(resolution_text):
60
+ vec = vectorizer.encode(resolution_text)
61
+ x_tensor = torch.tensor(vec, dtype=torch.float32).unsqueeze(0) # batchify
62
+
63
+ countries = []
64
+ votes = []
65
+
66
+ for country in country_encoder.classes_:
67
+ country_id = country_encoder.transform([country])[0]
68
+ c_tensor = torch.tensor([country_id], dtype=torch.long)
69
+
70
+ model = problem_model if country in problem_countries else main_model
71
+
72
+ with torch.no_grad():
73
+ logit = model(x_tensor, c_tensor).squeeze()
74
+ prob = torch.sigmoid(logit).item()
75
+ vote = "✅ Yes" if prob > 0.5 else "❌ Not Yes"
76
+
77
+ countries.append(country)
78
+ votes.append(vote)
79
+
80
+ df = pd.DataFrame({
81
+ "Country": countries,
82
+ "Vote": votes
83
+ }).sort_values("Country")
84
+
85
+ return df
86
+
87
+
88
+ # === Interface ===
89
+ iface = gr.Interface(
90
+ fn=predict_votes,
91
+ inputs=gr.Textbox(lines=15, label="Paste UN Resolution Text Here"),
92
+ outputs=gr.Dataframe(label="Predicted Votes by Country"),
93
+ title="UN Resolution Vote Predictor",
94
+ description="Predicts how each UN country might vote on your custom resolution text. Two models: one for stable democracies, one for spicy outliers.",
95
+ live=False
96
+ )
97
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  iface.launch()