import torch
import torch.nn as nn


class VotePredictor(nn.Module):
    """Predict a single score from a text vector combined with a country embedding.

    Each country id is mapped to a learned embedding. That embedding is
    concatenated to the supplied text vector, and the result goes through a
    one-hidden-layer MLP (Linear -> ReLU -> Dropout -> Linear). The MLP emits
    one raw value per example, with no output activation.

    NOTE(review): the defaults (``text_dim=384``, ``country_count=193``) look
    like sentence-transformer embeddings and UN member states respectively.
    Confirm this against the data pipeline. Whether the output is treated as a
    logit (e.g. fed to ``BCEWithLogitsLoss``) is decided by the caller, not here.

    Args:
        text_dim: Size of the last dimension of ``text_vecs``.
        country_count: Number of distinct country ids; valid ids are
            ``0 .. country_count - 1``.
        country_emb_dim: Width of the learned country embedding.
        hidden_dim: Width of the hidden layer.
        dropout: Dropout probability applied after the hidden ReLU
            (defaults to the previously hard-coded 0.3).
    """

    def __init__(
        self,
        text_dim: int = 384,
        country_count: int = 193,
        country_emb_dim: int = 32,
        hidden_dim: int = 256,
        dropout: float = 0.3,
    ) -> None:
        super().__init__()
        # Attribute names and layer order are kept stable so existing
        # checkpoints (state_dict keys "country_embedding.*", "model.0.*",
        # "model.3.*") continue to load.
        self.country_embedding = nn.Embedding(country_count, country_emb_dim)
        self.model = nn.Sequential(
            nn.Linear(text_dim + country_emb_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, text_vecs: torch.Tensor, country_ids: torch.Tensor) -> torch.Tensor:
        """Score each (text vector, country id) pair.

        Args:
            text_vecs: Float tensor of shape ``(*batch, text_dim)``.
            country_ids: Integer tensor of shape ``(*batch,)`` holding country
                indices (``nn.Embedding`` requires an integer dtype).

        Returns:
            Tensor of shape ``(*batch, 1)`` containing raw scores.
        """
        country_vecs = self.country_embedding(country_ids)
        # Concatenate along the feature (last) axis so any number of leading
        # batch dimensions works; for 2-D input this is identical to dim=1.
        x = torch.cat([text_vecs, country_vecs], dim=-1)
        return self.model(x)