donsek commited on
Commit
566cbba
·
verified ·
1 Parent(s): 837e763

Upload model.py

Browse files
Files changed (1) hide show
  1. model.py +18 -0
model.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class VotePredictor(nn.Module):
5
+ def __init__(self, text_dim=384, country_count=193, country_emb_dim=32, hidden_dim=256):
6
+ super(VotePredictor, self).__init__()
7
+ self.country_embedding = nn.Embedding(country_count, country_emb_dim)
8
+ self.model = nn.Sequential(
9
+ nn.Linear(text_dim + country_emb_dim, hidden_dim),
10
+ nn.ReLU(),
11
+ nn.Dropout(0.3),
12
+ nn.Linear(hidden_dim, 1)
13
+ )
14
+
15
+ def forward(self, text_vecs, country_ids):
16
+ country_vecs = self.country_embedding(country_ids)
17
+ x = torch.cat([text_vecs, country_vecs], dim=1)
18
+ return self.model(x)