import torch
import torch.nn as nn


class TransformerBlock(nn.Module):
    """A post-norm Transformer encoder block: multi-head self-attention followed by a
    position-wise feed-forward network, each wrapped in a residual connection and LayerNorm."""

    def __init__(self, d_model, n_heads, ff_dim):
        super().__init__()
        # Multi-head self-attention; batch_first=True expects inputs of shape (batch, seq, d_model).
        self.attention = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        # Position-wise feed-forward network.
        self.ff = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Self-attention over the full sequence (no attention mask is applied).
        attn_output, _ = self.attention(x, x, x)
        # Residual connection + LayerNorm after attention.
        x = self.norm1(x + attn_output)
        # Residual connection + LayerNorm after the feed-forward network.
        x = self.norm2(x + self.ff(x))
        return x


class TransformerModel(nn.Module):
    """Token embedding plus learned positional embedding, a stack of Transformer blocks,
    and a linear projection to per-token vocabulary logits."""

    def __init__(self, vocab_size, d_model, n_heads, n_layers, max_len):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned positional embeddings for sequences of up to max_len tokens.
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, d_model))
        # Stack of identical blocks; the feed-forward width is the conventional 4 * d_model.
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, ff_dim=4 * d_model)
            for _ in range(n_layers)
        ])
        # Project final hidden states back to the vocabulary.
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # x: (batch, seq) token indices. Add positional embeddings truncated to the input length.
        x = self.embedding(x) + self.pos_embedding[:, :x.size(1), :]
        for block in self.transformer_blocks:
            x = block(x)
        return self.output(x)  # (batch, seq, vocab_size) logits
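

# A minimal usage sketch. The hyperparameter values and the dummy input below are
# illustrative assumptions, not values prescribed by the model definitions above.
if __name__ == "__main__":
    model = TransformerModel(
        vocab_size=10_000,  # assumed vocabulary size
        d_model=256,
        n_heads=8,          # must divide d_model evenly
        n_layers=4,
        max_len=128,
    )
    tokens = torch.randint(0, 10_000, (2, 32))  # dummy batch: 2 sequences of 32 token ids
    logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([2, 32, 10000])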