import torch
import torch.nn as nn


class TransformerBlock(nn.Module):
    """A single post-norm transformer block: self-attention followed by a feed-forward network."""

    def __init__(self, d_model, n_heads, ff_dim):
        super().__init__()
        # Multi-head self-attention; batch_first=True expects inputs of shape (batch, seq, d_model).
        self.attention = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        # Position-wise feed-forward network.
        self.ff = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Self-attention (no mask, so attention is bidirectional) with a residual connection, then layer norm.
        attn_output, _ = self.attention(x, x, x)
        x = self.norm1(x + attn_output)
        # Feed-forward with a residual connection, then layer norm.
        x = self.norm2(x + self.ff(x))
        return x
class TransformerModel(nn.Module):
    """Token embedding + learned positional embedding, a stack of transformer blocks, and an output projection."""

    def __init__(self, vocab_size, d_model, n_heads, n_layers, max_len):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Learned positional embeddings, one d_model-dimensional vector per position up to max_len.
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len, d_model))
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, ff_dim=4 * d_model)
            for _ in range(n_layers)
        ])
        # Project the final hidden states back to vocabulary logits.
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # x: (batch, seq) token ids -> (batch, seq, d_model) embeddings plus positional embeddings.
        x = self.embedding(x) + self.pos_embedding[:, :x.size(1), :]
        for block in self.transformer_blocks:
            x = block(x)
        # Returns logits of shape (batch, seq, vocab_size).
        return self.output(x)
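

# Usage sketch: a minimal smoke test showing how the model might be instantiated and run.
# The hyperparameter values and tensor shapes below are arbitrary example values, not
# settings prescribed by this file.
if __name__ == "__main__":
    model = TransformerModel(vocab_size=1000, d_model=128, n_heads=4, n_layers=2, max_len=256)
    tokens = torch.randint(0, 1000, (2, 16))  # (batch=2, seq_len=16) random token ids
    logits = model(tokens)
    print(logits.shape)  # expected: torch.Size([2, 16, 1000])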