|
import torch
|
|
import torch.nn as nn
|
|
from multi_head_attention import MultiHeadAttention
|
|
from feedforward import FeedForward
|
|
|
|
class TransformerBlock(nn.Module):
    """Post-norm Transformer block: an attention sublayer followed by a feed-forward sublayer.

    Each sublayer's output goes through dropout, is added back to the
    sublayer's input (residual connection), and the sum is layer-normalized::

        x = norm1(x + dropout1(attention(x, x, x, mask)))
        x = norm2(x + dropout2(ffn(x)))

    Args:
        d_model: Size of the last (feature) dimension of the input. It is passed
            to ``MultiHeadAttention``, ``FeedForward`` and both ``nn.LayerNorm``
            modules.
        n_heads: Forwarded to ``MultiHeadAttention`` as its second argument
            (presumably the number of attention heads).
        ff_dim: Forwarded to ``FeedForward`` as its second argument
            (presumably the hidden width of the feed-forward network).
        dropout: Drop probability for the dropout applied to each sublayer's
            output before the residual add. Defaults to ``0.1``, the value this
            block previously hard-coded, so existing callers are unaffected.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        ff_dim: int,
        dropout: float = 0.1,
    ):
        super().__init__()

        # Submodule attribute names are unchanged, so the keys in existing
        # state_dict checkpoints still match.
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.ffn = FeedForward(d_model, ff_dim)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        # Two separate Dropout modules, one per residual branch.
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        """Run the attention and feed-forward sublayers with post-norm residuals.

        Args:
            x: Input tensor. Its last dimension must equal ``d_model``, because
                both ``nn.LayerNorm(d_model)`` modules normalize over it.
                Presumably shaped (batch, seq_len, d_model); confirm against
                ``MultiHeadAttention``.
            mask: Passed unchanged as the fourth positional argument to
                ``MultiHeadAttention``. Its expected shape and meaning are
                defined there.

        Returns:
            The tensor after both residual sublayers. It has the same shape as
            ``x`` when the sublayers return ``x``-shaped outputs, which the
            residual additions assume.
        """
        # Self-attention: the same tensor is passed for all three inputs
        # (presumably query, key, value).
        attn_out = self.attention(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attn_out))

        ff_out = self.ffn(x)
        x = self.norm2(x + self.dropout2(ff_out))

        return x
|
|
|