# OpAI1.1 / transformer_block.py
import torch
import torch.nn as nn

from multi_head_attention import MultiHeadAttention
from feedforward import FeedForward


class TransformerBlock(nn.Module):
    """Post-layer-norm Transformer block: multi-head self-attention followed by a
    feedforward network, each wrapped in dropout, a residual connection, and LayerNorm."""

    def __init__(self, d_model, n_heads, ff_dim):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, n_heads)
        self.ffn = FeedForward(d_model, ff_dim)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(0.1)
        self.dropout2 = nn.Dropout(0.1)

    def forward(self, x, mask=None):
        # Multi-head self-attention sublayer: residual connection + LayerNorm
        attn_out = self.attention(x, x, x, mask)
        x = self.norm1(x + self.dropout1(attn_out))

        # Position-wise feedforward sublayer: residual connection + LayerNorm
        ff_out = self.ffn(x)
        x = self.norm2(x + self.dropout2(ff_out))
        return x
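

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the original module. It assumes the sibling
# files multi_head_attention.py and feedforward.py are importable, that
# MultiHeadAttention(d_model, n_heads) and FeedForward(d_model, ff_dim) match the
# constructor calls above, and that attention output has shape
# (batch, seq_len, d_model). The dimensions below are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    block = TransformerBlock(d_model=512, n_heads=8, ff_dim=2048)
    x = torch.randn(2, 16, 512)   # (batch, seq_len, d_model)
    out = block(x)                # mask=None -> unmasked self-attention
    print(out.shape)              # expected: torch.Size([2, 16, 512])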