Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
class RandomPredictor(nn.Module): | |
def __init__(self, num_classes): | |
super().__init__() | |
self.num_classes = num_classes | |
def forward(self, inputs): | |
""" | |
Parameters | |
---------- | |
inputs: int or dict | |
batch_size or dict of input features | |
Returns | |
------- | |
probs: torch.tensor [batch_size x num_classes] | |
""" | |
batch_size = inputs if isinstance(inputs, int) else inputs["curr_node_id"].size(0) | |
ranom_index = torch.randint(self.num_classes, (batch_size, self.num_classes)) | |
probs = torch.zeros(batch_size, self.num_classes).to(torch.float) | |
probs.scatter_(-1, ranom_index, 1.0) | |
return probs | |
def get_inputs(self, tour, first_explained_step, node_feats): | |
return len(tour[first_explained_step:-1]) | |
class FixedClassPredictor(nn.Module): | |
def __init__(self, predicted_class, num_classes): | |
""" | |
Paramters | |
--------- | |
predicted_class: int | |
a class that this predictor always predicts | |
num_classes: int | |
number of classes | |
""" | |
super().__init__() | |
self.predicted_class = predicted_class | |
self.num_classes = num_classes | |
assert predicted_class < num_classes, f"predicted_class should be 0 - {num_classes}." | |
def forward(self, inputs): | |
""" | |
Parameters | |
---------- | |
inputs: int or dict | |
batch_size or dict of input features | |
Returns | |
------- | |
probs: torch.tensor [batch_size x num_classes] | |
""" | |
batch_size = inputs if isinstance(inputs, int) else inputs["curr_node_id"].size(0) | |
index = torch.full((batch_size, self.num_classes), self.predicted_class) | |
probs = torch.zeros(batch_size, self.num_classes).to(torch.float) | |
probs.scatter_(-1, index, 1.0) | |
return probs | |
def get_inputs(self, tour, first_explained_step, node_feats): | |
return len(tour[first_explained_step:-1]) |