matth commited on
Commit
02fe640
·
1 Parent(s): 133b0ea

Upload Flowformer

Browse files
Files changed (1) hide show
  1. model_flowformer.py +2 -2
model_flowformer.py CHANGED
@@ -96,7 +96,7 @@ class Flowformer(PreTrainedModel):
96
  self.dec = nn.Sequential(*dec_layers)
97
 
98
  def markers(self):
99
- return self._pretrained_markers
100
 
101
  def forward(self, tensor: torch.Tensor, labels: torch.Tensor=None, markers: list=None):
102
  r"""
@@ -110,7 +110,7 @@ class Flowformer(PreTrainedModel):
110
  """
111
  B, L, M = tensor.shape
112
  if markers is not None:
113
- assert len(markers) == M, "Number of markers in x and markers must be identical"
114
 
115
  zeros = torch.zeros((B, L, len(self.markers())), device=tensor.device)
116
  valid_markers = [m for m in markers if m in set(self.markers()).intersection(markers)]
 
96
  self.dec = nn.Sequential(*dec_layers)
97
 
98
  def markers(self):
99
+ return self._markers
100
 
101
  def forward(self, tensor: torch.Tensor, labels: torch.Tensor=None, markers: list=None):
102
  r"""
 
110
  """
111
  B, L, M = tensor.shape
112
  if markers is not None:
113
+ assert len(markers) == M, "last dimension of input must be equal to number of markers"
114
 
115
  zeros = torch.zeros((B, L, len(self.markers())), device=tensor.device)
116
  valid_markers = [m for m in markers if m in set(self.markers()).intersection(markers)]