Upload Flowformer
Browse files- model_flowformer.py +2 -2
model_flowformer.py
CHANGED
@@ -96,7 +96,7 @@ class Flowformer(PreTrainedModel):
|
|
96 |
self.dec = nn.Sequential(*dec_layers)
|
97 |
|
98 |
def markers(self):
|
99 |
-
return self.
|
100 |
|
101 |
def forward(self, tensor: torch.Tensor, labels: torch.Tensor=None, markers: list=None):
|
102 |
r"""
|
@@ -110,7 +110,7 @@ class Flowformer(PreTrainedModel):
|
|
110 |
"""
|
111 |
B, L, M = tensor.shape
|
112 |
if markers is not None:
|
113 |
-
assert len(markers) == M, "
|
114 |
|
115 |
zeros = torch.zeros((B, L, len(self.markers())), device=tensor.device)
|
116 |
valid_markers = [m for m in markers if m in set(self.markers()).intersection(markers)]
|
|
|
96 |
self.dec = nn.Sequential(*dec_layers)
|
97 |
|
98 |
def markers(self):
|
99 |
+
return self._markers
|
100 |
|
101 |
def forward(self, tensor: torch.Tensor, labels: torch.Tensor=None, markers: list=None):
|
102 |
r"""
|
|
|
110 |
"""
|
111 |
B, L, M = tensor.shape
|
112 |
if markers is not None:
|
113 |
+
assert len(markers) == M, "last dimension of input must be equal to number of markers"
|
114 |
|
115 |
zeros = torch.zeros((B, L, len(self.markers())), device=tensor.device)
|
116 |
valid_markers = [m for m in markers if m in set(self.markers()).intersection(markers)]
|