Spaces:
Running
Running
File size: 2,054 Bytes
719d0db |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
import torch
import torch.nn as nn
class RandomPredictor(nn.Module):
    """Baseline predictor that assigns each sample a uniformly random class.

    The output is a one-hot "probability" matrix: exactly one entry per row
    is 1.0 and all other entries are 0.0.
    """

    def __init__(self, num_classes):
        """
        Parameters
        ----------
        num_classes: int
            number of classes to sample from
        """
        super().__init__()
        self.num_classes = num_classes

    def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: int or dict
            batch_size, or dict of input features; for a dict, the size of
            the first dimension of inputs["curr_node_id"] is the batch size
        Returns
        -------
        probs: torch.tensor [batch_size x num_classes]
            float tensor whose rows are one-hot vectors, each marking one
            class drawn uniformly at random
        """
        batch_size = inputs if isinstance(inputs, int) else inputs["curr_node_id"].size(0)
        # One class index per sample, shaped [batch_size x 1], so scatter_
        # writes exactly one 1.0 per row. A [batch_size x num_classes] index
        # would scatter num_classes random hits and produce multi-hot rows.
        random_index = torch.randint(self.num_classes, (batch_size, 1))
        probs = torch.zeros(batch_size, self.num_classes, dtype=torch.float)
        probs.scatter_(-1, random_index, 1.0)
        return probs

    def get_inputs(self, tour, first_explained_step, node_feats):
        """Return the number of tour entries from ``first_explained_step``
        up to, but not including, the last entry.

        forward() accepts this int as its batch size. ``node_feats`` is
        unused by this baseline.
        """
        return len(tour[first_explained_step:-1])
class FixedClassPredictor(nn.Module):
    """Baseline predictor that always predicts the same class.

    The output is a one-hot "probability" matrix whose ``predicted_class``
    column is 1.0 in every row and whose other entries are 0.0.
    """

    def __init__(self, predicted_class, num_classes):
        """
        Parameters
        ----------
        predicted_class: int
            a class that this predictor always predicts;
            must satisfy 0 <= predicted_class < num_classes
        num_classes: int
            number of classes
        Raises
        ------
        ValueError
            if predicted_class is outside [0, num_classes - 1]
        """
        super().__init__()
        # Validate with an exception rather than `assert` (asserts are
        # stripped under `python -O`), and reject negative indices too.
        if not 0 <= predicted_class < num_classes:
            raise ValueError(
                f"predicted_class should be 0 - {num_classes - 1}, got {predicted_class}."
            )
        self.predicted_class = predicted_class
        self.num_classes = num_classes

    def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: int or dict
            batch_size, or dict of input features; for a dict, the size of
            the first dimension of inputs["curr_node_id"] is the batch size
        Returns
        -------
        probs: torch.tensor [batch_size x num_classes]
            float tensor whose rows are all the one-hot vector for
            predicted_class
        """
        batch_size = inputs if isinstance(inputs, int) else inputs["curr_node_id"].size(0)
        probs = torch.zeros(batch_size, self.num_classes, dtype=torch.float)
        # Set the fixed class column directly. This gives the same result as
        # scattering 1.0 at predicted_class, but avoids building a
        # [batch_size x num_classes] index tensor with an inferred dtype.
        probs[:, self.predicted_class] = 1.0
        return probs

    def get_inputs(self, tour, first_explained_step, node_feats):
        """Return the number of tour entries from ``first_explained_step``
        up to, but not including, the last entry.

        forward() accepts this int as its batch size. ``node_feats`` is
        unused by this baseline.
        """
        return len(tour[first_explained_step:-1])