import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.activations import ACT2FN


class MaskedConv1d(nn.Conv1d):
    """A masked 1-dimensional convolution layer.

    Takes the same arguments as torch.nn.Conv1d, except that the padding is set automatically.

    Shape:
        Input: (N, L, in_channels)
        input_mask: (N, L, 1), optional
        Output: (N, L, out_channels)
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
    ):
        """
        :param in_channels: input channels
        :param out_channels: output channels
        :param kernel_size: the kernel width
        :param stride: filter shift
        :param dilation: dilation factor
        :param groups: perform depth-wise convolutions
        :param bias: adds learnable bias to output
        """
        padding = dilation * (kernel_size - 1) // 2
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding=padding,
        )
    def forward(self, x, input_mask=None):
        # Zero out padded positions before convolving, then move the channel
        # dimension forward because nn.Conv1d expects (N, C, L) input.
        if input_mask is not None:
            x = x * input_mask
        return super().forward(x.transpose(1, 2)).transpose(1, 2)
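# Illustrative sketch (not part of the original module): a MaskedConv1d call on a dummy
# batch, assuming float embeddings and a 0/1 padding mask with a trailing channel dim.
#   conv = MaskedConv1d(in_channels=1280, out_channels=1, kernel_size=1)
#   x = torch.randn(4, 100, 1280)            # (N, L, in_channels)
#   mask = torch.ones(4, 100, 1)             # (N, L, 1), 1 = real token, 0 = padding
#   y = conv(x, input_mask=mask)             # (N, L, out_channels) == (4, 100, 1)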


class Attention1dPooling(nn.Module):
    """Attention-weighted pooling over the sequence dimension."""

    def __init__(self, hidden_size):
        super().__init__()
        self.layer = MaskedConv1d(hidden_size, 1, 1)

    def forward(self, x, input_mask=None):
        batch_size = x.shape[0]
        attn = self.layer(x)
        attn = attn.view(batch_size, -1)
        if input_mask is not None:
            attn = attn.masked_fill_(
                ~input_mask.view(batch_size, -1).bool(), float("-inf")
            )
        attn = F.softmax(attn, dim=-1).view(batch_size, -1, 1)
        out = (attn * x).sum(dim=1)
        return out
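# Illustrative sketch (assumed shapes, not from the original file): pooling a dummy batch
# with Attention1dPooling; note the mask carries a trailing singleton channel dimension.
#   pool = Attention1dPooling(hidden_size=1280)
#   x = torch.randn(4, 100, 1280)            # [batch, seq_len, hidden]
#   mask = torch.ones(4, 100, 1)             # [batch, seq_len, 1]
#   pooled = pool(x, input_mask=mask)        # [batch, hidden] == (4, 1280)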


class Attention1dPoolingProjection(nn.Module):
    def __init__(self, hidden_size, num_labels, dropout=0.25) -> None:
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.final = nn.Linear(hidden_size, num_labels)

    def forward(self, x):
        x = self.linear(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.final(x)
        return x


class Attention1dPoolingHead(nn.Module):
    """Attention1d pooling followed by a projection, producing sequence-level logits.

    Shape: [batch x sequence (e.g. 751) x embedding (e.g. 1280)] --> [batch x embedding] --> [batch x num_labels]
    """

    def __init__(self, hidden_size: int, num_labels: int, dropout: float = 0.25):
        super().__init__()
        self.attention1d = Attention1dPooling(hidden_size)
        self.attention1d_projection = Attention1dPoolingProjection(hidden_size, num_labels, dropout)

    def forward(self, x, input_mask=None):
        # MaskedConv1d expects the mask with a trailing channel dimension.
        mask = input_mask.unsqueeze(-1) if input_mask is not None else None
        x = self.attention1d(x, input_mask=mask)
        x = self.attention1d_projection(x)
        return x
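# Illustrative sketch (assumed shapes, not part of the original code): end-to-end use of
# Attention1dPoolingHead with a 0/1 padding mask of shape [batch, seq_len].
#   head = Attention1dPoolingHead(hidden_size=1280, num_labels=2)
#   x = torch.randn(4, 100, 1280)            # [batch, seq_len, hidden]
#   mask = torch.ones(4, 100)                # [batch, seq_len], 1 = real token
#   logits = head(x, input_mask=mask)        # [batch, num_labels] == (4, 2)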


class MeanPooling(nn.Module):
    """Mean Pooling for sentence-level classification tasks."""

    def __init__(self):
        super().__init__()

    def forward(self, features, input_mask=None):
        if input_mask is not None:
            # Applying input_mask to zero out masked values
            masked_features = features * input_mask.unsqueeze(2)
            sum_features = torch.sum(masked_features, dim=1)
            mean_pooled_features = sum_features / input_mask.sum(dim=1, keepdim=True)
        else:
            mean_pooled_features = torch.mean(features, dim=1)
        return mean_pooled_features
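# Illustrative sketch (assumed shapes, not from the original file): masked mean pooling
# averages only over unmasked positions, so padded tokens do not dilute the result.
#   pool = MeanPooling()
#   feats = torch.randn(4, 100, 1280)        # [batch, seq_len, hidden]
#   mask = torch.ones(4, 100)                # [batch, seq_len], 1 = real token, 0 = pad
#   mask[:, 50:] = 0                         # pretend the second half is padding
#   pooled = pool(feats, input_mask=mask)    # [batch, hidden]; mean over the first 50 tokens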


class MeanPoolingProjection(nn.Module):
    """Mean Pooling with a projection layer for sentence-level classification tasks."""

    def __init__(self, hidden_size, num_labels, dropout=0.25):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(hidden_size, num_labels)

    def forward(self, mean_pooled_features):
        x = self.dropout(mean_pooled_features)
        x = self.dense(x)
        x = ACT2FN["gelu"](x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class MeanPoolingHead(nn.Module):
    """Mean Pooling Head for sentence-level classification tasks."""

    def __init__(self, hidden_size, num_labels, dropout=0.25):
        super().__init__()
        self.mean_pooling = MeanPooling()
        self.mean_pooling_projection = MeanPoolingProjection(hidden_size, num_labels, dropout)

    def forward(self, features, input_mask=None):
        mean_pooling_features = self.mean_pooling(features, input_mask=input_mask)
        x = self.mean_pooling_projection(mean_pooling_features)
        return x
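# Illustrative sketch (assumed shapes, not part of the original code): MeanPoolingHead maps
# token embeddings straight to logits.
#   head = MeanPoolingHead(hidden_size=1280, num_labels=2)
#   feats = torch.randn(4, 100, 1280)        # [batch, seq_len, hidden]
#   mask = torch.ones(4, 100)                # [batch, seq_len]
#   logits = head(feats, input_mask=mask)    # [batch, num_labels] == (4, 2)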


class LightAttentionPoolingHead(nn.Module):
    """Light-attention pooling head: convolutional features weighted by a learned attention
    distribution, concatenated with a max-pooled view, then projected to logits."""

    def __init__(self, hidden_size=1280, num_labels=11, dropout=0.25, kernel_size=9, conv_dropout: float = 0.25):
        super().__init__()
        self.feature_convolution = nn.Conv1d(hidden_size, hidden_size, kernel_size, stride=1,
                                             padding=kernel_size // 2)
        self.attention_convolution = nn.Conv1d(hidden_size, hidden_size, kernel_size, stride=1,
                                               padding=kernel_size // 2)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(conv_dropout)
        self.linear = nn.Sequential(
            nn.Linear(2 * hidden_size, 32),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.BatchNorm1d(32)
        )
        self.output = nn.Linear(32, num_labels)

    def forward(self, x: torch.Tensor, mask, **kwargs) -> torch.Tensor:
        """
        Args:
            x: [batch_size, sequence_length, hidden_size] embedding tensor that should be classified
            mask: [batch_size, sequence_length] mask corresponding to the zero padding used for the
                shorter sequences in the batch. All values corresponding to padding are False; the rest are True.

        Returns:
            classification: [batch_size, num_labels] tensor with logits
        """
        x = x.permute(0, 2, 1)  # [batch_size, hidden_size, sequence_length]
        o = self.feature_convolution(x)  # [batch_size, hidden_size, sequence_length]
        o = self.dropout(o)  # [batch_size, hidden_size, sequence_length]
        attention = self.attention_convolution(x)  # [batch_size, hidden_size, sequence_length]

        # Mask out the padding to which we do not want to pay any attention (the padding exists
        # because the sequences in a batch have different lengths). This padding is added by the
        # dataloader when using the padded_permuted_collate function in utils/general.py.
        attention = attention.masked_fill(mask[:, None, :] == False, -1e9)

        # code used for extracting embeddings for UMAP visualizations
        # extraction = torch.sum(x * self.softmax(attention), dim=-1)
        # extraction = self.id0(extraction)

        o1 = torch.sum(o * self.softmax(attention), dim=-1)  # [batch_size, hidden_size]
        o2, _ = torch.max(o, dim=-1)  # [batch_size, hidden_size]
        o = torch.cat([o1, o2], dim=-1)  # [batch_size, 2 * hidden_size]
        o = self.linear(o)  # [batch_size, 32]
        return self.output(o)  # [batch_size, num_labels]
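

# Minimal smoke test: a sketch added for illustration, not part of the original module.
# The batch size, sequence length, hidden size, and label count below are assumptions chosen
# only to exercise the heads; BatchNorm1d in LightAttentionPoolingHead needs a batch size > 1.
if __name__ == "__main__":
    batch_size, seq_len, hidden_size, num_labels = 4, 100, 1280, 11
    x = torch.randn(batch_size, seq_len, hidden_size)
    mask = torch.ones(batch_size, seq_len)
    mask[:, seq_len // 2:] = 0  # pretend the second half of every sequence is padding

    attn_head = Attention1dPoolingHead(hidden_size, num_labels)
    mean_head = MeanPoolingHead(hidden_size, num_labels)
    light_head = LightAttentionPoolingHead(hidden_size, num_labels)

    print(attn_head(x, input_mask=mask).shape)   # torch.Size([4, 11])
    print(mean_head(x, input_mask=mask).shape)   # torch.Size([4, 11])
    print(light_head(x, mask.bool()).shape)      # torch.Size([4, 11])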