donsek commited on
Commit
f3f6110
·
verified ·
1 Parent(s): 566cbba

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -115
app.py DELETED
@@ -1,115 +0,0 @@
1
- import torch
2
- import gradio as gr
3
- import numpy as np
4
- import pickle
5
- import pandas as pd
6
- from sentence_transformers import SentenceTransformer
7
- from model import VotePredictor # <-- make sure this matches your model file
8
-
9
- # Load models
10
- main_model = VotePredictor(country_count=193)
11
- main_model.load_state_dict(torch.load("vote_predictor_epoch27.pt", map_location="cpu"))
12
- main_model.eval()
13
-
14
- problem_model = VotePredictor(country_count=193)
15
- problem_model.load_state_dict(torch.load("problem_country_model.pt", map_location="cpu"))
16
- problem_model.eval()
17
-
18
- # Load country encoder
19
- with open("country_encoder.pkl", "rb") as f:
20
- country_encoder = pickle.load(f)
21
-
22
- # Vectorizer
23
- vectorizer = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
24
-
25
- # Define problem countries (same as used during training)
26
- problem_countries = [
27
- 'SURINAME',
28
- 'TURKMENISTAN',
29
- 'MARSHALL ISLANDS',
30
- 'MYANMAR',
31
- 'GABON',
32
- 'CENTRAL AFRICAN REPUBLIC',
33
- 'ISRAEL',
34
- 'REPUBLIC OF THE CONGO',
35
- 'LIBERIA',
36
- 'SOMALIA',
37
- 'CANADA',
38
- "LAO PEOPLE'S DEMOCRATIC REPUBLIC",
39
- 'TUVALU',
40
- 'DEMOCRATIC REPUBLIC OF THE CONGO',
41
- 'MONTENEGRO',
42
- 'VANUATU',
43
- 'UNITED STATES',
44
- 'TÜRKİYE',
45
- 'SEYCHELLES',
46
- 'SERBIA',
47
- 'CABO VERDE',
48
- 'VENEZUELA (BOLIVARIAN REPUBLIC OF)',
49
- 'KIRIBATI',
50
- 'IRAN (ISLAMIC REPUBLIC OF)',
51
- 'SOUTH SUDAN',
52
- 'ALBANIA',
53
- 'CZECHIA',
54
- 'DOMINICA',
55
- 'SAO TOME AND PRINCIPE',
56
- 'ESWATINI',
57
- 'CHAD',
58
- 'EQUATORIAL GUINEA',
59
- 'GAMBIA',
60
- 'LIBYA',
61
- "CÔTE D'IVOIRE",
62
- 'SAINT CHRISTOPHER AND NEVIS',
63
- 'RWANDA',
64
- 'TONGA',
65
- 'NIGER',
66
- 'MICRONESIA (FEDERATED STATES OF)',
67
- 'SYRIAN ARAB REPUBLIC',
68
- 'NAURU',
69
- 'PALAU',
70
- 'NORTH MACEDONIA',
71
- 'NETHERLANDS',
72
- 'BOLIVIA (PLURINATIONAL STATE OF)'
73
- ]
74
-
75
- # Vote function
76
- def predict_votes(resolution_text):
77
- # Vectorize once
78
- vec = vectorizer.encode([resolution_text])
79
- x_tensor = torch.tensor(vec, dtype=torch.float32)
80
-
81
- countries = []
82
- votes = []
83
-
84
- for country in country_encoder.classes_:
85
- country_id = country_encoder.transform([country])[0]
86
- c_tensor = torch.tensor([country_id], dtype=torch.long)
87
-
88
- model = problem_model if country in problem_countries else main_model
89
-
90
- with torch.no_grad():
91
- logit = model(x_tensor, c_tensor).squeeze()
92
- prob = torch.sigmoid(logit).item()
93
- vote = "✅ Yes" if prob > 0.5 else "❌ Not Yes"
94
-
95
- countries.append(country)
96
- votes.append(vote)
97
-
98
- df = pd.DataFrame({
99
- "Country": countries,
100
- "Vote": votes
101
- }).sort_values("Country")
102
-
103
- return df
104
-
105
- # Gradio UI
106
- iface = gr.Interface(
107
- fn=predict_votes,
108
- inputs=gr.Textbox(lines=15, label="Paste UN Resolution Text Here"),
109
- outputs=gr.Dataframe(label="Predicted Votes by Country"),
110
- title="UN Resolution Vote Predictor",
111
- description="This model predicts how each UN country will vote on a given resolution based on the text. Uses BERT embeddings and two models: one for normal countries, one for chaos monkeys.",
112
- live=False
113
- )
114
-
115
- iface.launch()