route-explainer / models /classifiers /meaningless_models.py
daisuke.kikuta
first commit
719d0db
import torch
import torch.nn as nn
class RandomPredictor(nn.Module):
    """Baseline classifier that predicts one uniformly random class per sample.

    The module has no learnable parameters; its output depends only on the
    batch size and on torch's global random number generator.
    """

    def __init__(self, num_classes):
        """
        Parameters
        ----------
        num_classes: int
            number of classes
        """
        super().__init__()
        self.num_classes = num_classes

    def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: int or dict
            batch_size or dict of input features
            (when a dict, the batch size is read from inputs["curr_node_id"].size(0))

        Returns
        -------
        probs: torch.tensor [batch_size x num_classes]
            one-hot rows; the position of the single 1.0 in each row is
            drawn uniformly from [0, num_classes)
        """
        batch_size = inputs if isinstance(inputs, int) else inputs["curr_node_id"].size(0)
        # Draw exactly ONE class index per sample. Drawing a
        # [batch_size x num_classes] index matrix would set several entries
        # per row to 1.0, so rows would not sum to one and an argmax over
        # them would be biased towards low class ids.
        random_index = torch.randint(self.num_classes, (batch_size, 1))
        probs = torch.zeros(batch_size, self.num_classes, dtype=torch.float)
        probs.scatter_(-1, random_index, 1.0)
        return probs

    def get_inputs(self, tour, first_explained_step, node_feats):
        """
        Build the input for forward(): only the number of predictions is needed.

        Parameters
        ----------
        tour: sequence
            sliced as tour[first_explained_step:-1]
            (presumably a list of visited node ids -- confirm with the caller)
        first_explained_step: int
            start index of the slice
        node_feats: any
            unused; accepted for interface compatibility

        Returns
        -------
        batch_size: int
            length of tour[first_explained_step:-1]
        """
        return len(tour[first_explained_step:-1])
class FixedClassPredictor(nn.Module):
    """Baseline classifier that always predicts the same class.

    The module has no learnable parameters; every output row is the one-hot
    vector of `predicted_class`.
    """

    def __init__(self, predicted_class, num_classes):
        """
        Parameters
        ----------
        predicted_class: int
            a class that this predictor always predicts
        num_classes: int
            number of classes

        Raises
        ------
        AssertionError
            if predicted_class is not in [0, num_classes - 1]
        """
        super().__init__()
        # Raised explicitly (instead of an `assert` statement) so the check
        # is not stripped under `python -O`; the exception type is kept as
        # AssertionError so existing callers see the same error class.
        # Negative values are rejected here as well, rather than failing
        # later inside forward().
        if not 0 <= predicted_class < num_classes:
            raise AssertionError(f"predicted_class should be 0 - {num_classes - 1}.")
        self.predicted_class = predicted_class
        self.num_classes = num_classes

    def forward(self, inputs):
        """
        Parameters
        ----------
        inputs: int or dict
            batch_size or dict of input features
            (when a dict, the batch size is read from inputs["curr_node_id"].size(0))

        Returns
        -------
        probs: torch.tensor [batch_size x num_classes]
            one-hot rows with 1.0 in column `predicted_class`
        """
        batch_size = inputs if isinstance(inputs, int) else inputs["curr_node_id"].size(0)
        probs = torch.zeros(batch_size, self.num_classes, dtype=torch.float)
        # Every row selects the same column, so a direct column assignment
        # replaces building a full index tensor for scatter_.
        probs[:, self.predicted_class] = 1.0
        return probs

    def get_inputs(self, tour, first_explained_step, node_feats):
        """
        Build the input for forward(): only the number of predictions is needed.

        Parameters
        ----------
        tour: sequence
            sliced as tour[first_explained_step:-1]
            (presumably a list of visited node ids -- confirm with the caller)
        first_explained_step: int
            start index of the slice
        node_feats: any
            unused; accepted for interface compatibility

        Returns
        -------
        batch_size: int
            length of tour[first_explained_step:-1]
        """
        return len(tour[first_explained_step:-1])