pfe_site / FrModel.py
YsnHdn's picture
Adding the result side legend
834b97c
raw
history blame contribute delete
7.83 kB
import torch
import torch.nn as nn
import numpy as np
from transformers import AutoTokenizer
import pickle
# French tokenizer, loaded from the pretrained "camembert-base" checkpoint.
fr_tokenizer = AutoTokenizer.from_pretrained("camembert-base")
# Register '<MULT>' as an additional special token so that its id can be
# looked up below (FR_BERT.forward searches the input ids for this id).
fr_tokenizer.add_special_tokens({'additional_special_tokens': ['<MULT>']})
# Special-token ids, resolved once at import time.
fr_mult_token_id = fr_tokenizer.convert_tokens_to_ids('<MULT>')
fr_cls_token_id = fr_tokenizer.cls_token_id
fr_sep_token_id = fr_tokenizer.sep_token_id
# NOTE(review): get_attn_pad_mask masks the literal id 1 by default rather than
# reading this value -- confirm that fr_pad_token_id == 1 for this tokenizer.
fr_pad_token_id = fr_tokenizer.pad_token_id
## Hyper-parameters (module-level constants read by the model classes below)
maxlen = 255 # number of rows in the position-embedding table, i.e. the longest supported sequence
batch_size = 32 # not read by the model code in this file; presumably used by training/inference scripts -- TODO confirm
max_pred = 5 # max tokens of prediction
n_layers = 6 # number of stacked EncoderLayer modules
n_heads = 12 # number of heads in Multi-Head Attention
d_model = 768 # Embedding Size
d_ff = 768 * 4 # 4*d_model, FeedForward dimension
d_k = d_v = 64 # dimension of K(=Q), V
n_segments = 2 # number of rows in the segment (token type) embedding table
# Number of rows in the token-embedding table: the tokenizer's vocab_size plus
# two extra rows. NOTE(review): presumably headroom for the added '<MULT>'
# token -- confirm against the saved checkpoint before changing this value.
fr_vocab_size = fr_tokenizer.vocab_size +2
def get_attn_pad_mask(seq_q: torch.Tensor, seq_k: torch.Tensor, pad_id: int = 1) -> torch.Tensor:
    """Build the boolean attention mask that hides padding positions of the keys.

    Args:
        seq_q: 2-D tensor of query token ids, ``[batch_size, len_q]``. Only its
            length is used.
        seq_k: 2-D tensor of key token ids, ``[batch_size, len_k]``.
        pad_id: Token id treated as padding. Defaults to ``1``, the value this
            function previously hard-coded (the old comment incorrectly said
            zero).

    Returns:
        Bool tensor of shape ``[batch_size, len_q, len_k]`` in which ``True``
        marks key positions that hold ``pad_id`` and must be masked out.

    Raises:
        ValueError: If either input is not 2-D (raised by tuple unpacking).
    """
    # Unpacking (rather than .size(1)) keeps the implicit "must be 2-D" check.
    _, len_q = seq_q.size()
    n_rows, len_k = seq_k.size()
    # [n_rows, 1, len_k]: one mask row per sequence, shared by every query.
    key_is_pad = seq_k.eq(pad_id).unsqueeze(1)
    # expand() creates a broadcast view; no memory is copied.
    return key_is_pad.expand(n_rows, len_q, len_k)
class Embedding(nn.Module):
    """Input embedding: token + learned position + segment, then LayerNorm."""

    def __init__(self):
        super().__init__()
        # Three lookup tables that all project into the d_model space.
        self.tok_embed = nn.Embedding(fr_vocab_size, d_model)  # token ids
        self.pos_embed = nn.Embedding(maxlen, d_model)  # absolute positions
        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment (token type) ids
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, seg):
        """Embed a batch of token ids.

        Args:
            x: Token ids, indexed as ``[batch_size, seq_len]``.
            seg: Segment ids looked up in the segment table; the result is
                added in place, so it must broadcast onto the token embedding.

        Returns:
            The layer-normalised sum of the three embeddings.
        """
        # Positions 0..seq_len-1, repeated (as a view) for every row of x.
        positions = torch.arange(x.size(1), dtype=torch.long, device=x.device)
        positions = positions.unsqueeze(0).expand_as(x)
        # Accumulate in place on the freshly created token-embedding tensor.
        summed = self.tok_embed(x)
        summed.add_(self.pos_embed(positions))
        summed.add_(self.seg_embed(seg))
        return self.norm(summed)
class ScaledDotProductAttention(nn.Module):
    """Parameter-free scaled dot-product attention with a boolean mask."""

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, attn_mask):
        """Attend from queries to keys/values.

        Args:
            Q: Queries, ``[batch_size x n_heads x len_q x d_k]``.
            K: Keys, ``[batch_size x n_heads x len_k x d_k]``.
            V: Values, ``[batch_size x n_heads x len_k x d_v]``.
            attn_mask: Boolean mask; positions that are ``True`` are excluded
                from attention.

        Returns:
            Tuple ``(scores, context, attn)``: the masked pre-softmax scores,
            the attention-weighted values, and the attention probabilities.
        """
        similarity = torch.matmul(Q, K.transpose(-2, -1))
        scores = similarity / np.sqrt(d_k)
        # In place: masked positions get a large negative score so that the
        # softmax assigns them (near) zero probability.
        scores.masked_fill_(attn_mask, -1e9)
        attn = torch.softmax(scores, dim=-1)
        context = torch.matmul(attn, V)
        return scores, context, attn
class MultiHeadAttention(nn.Module):
    """Multi-head attention followed by a residual connection and LayerNorm."""

    def __init__(self):
        super().__init__()
        # Per-head projections are fused into one Linear per role.
        self.W_Q = nn.Linear(d_model, n_heads * d_k)
        self.W_K = nn.Linear(d_model, n_heads * d_k)
        self.W_V = nn.Linear(d_model, n_heads * d_v)
        self.fc = nn.Linear(n_heads * d_v, d_model)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, Q, K, V, attn_mask):
        """Run multi-head attention.

        Args:
            Q: ``[batch_size x len_q x d_model]``.
            K: ``[batch_size x len_k x d_model]``.
            V: ``[batch_size x len_k x d_model]``.
            attn_mask: ``[batch_size x len_q x len_k]`` boolean mask.

        Returns:
            Tuple ``(output, attn)`` where ``output`` is
            ``[batch_size x len_q x d_model]`` and ``attn`` is
            ``[batch_size x n_heads x len_q x len_k]``.
        """
        shortcut = Q
        n_batch = Q.size(0)
        # Keys and values follow the queries' device.
        target = Q.device
        K = K.to(target)
        V = V.to(target)

        def split_heads(projected, width):
            # (B, S, H*W) -> (B, S, H, W) -> (B, H, S, W)
            return projected.view(n_batch, -1, n_heads, width).transpose(1, 2)

        q_heads = split_heads(self.W_Q(Q), d_k)
        k_heads = split_heads(self.W_K(K), d_k)
        v_heads = split_heads(self.W_V(V), d_v)

        # Give every head its own copy of the mask: (B, Lq, Lk) -> (B, H, Lq, Lk).
        head_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)

        _, context, attn = ScaledDotProductAttention()(q_heads, k_heads, v_heads, head_mask)

        # Merge the heads back: (B, H, Lq, W) -> (B, Lq, H*W).
        merged = context.transpose(1, 2).contiguous().view(n_batch, -1, n_heads * d_v)
        projected = self.fc(merged)
        return self.norm(projected + shortcut), attn
class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU -> Linear."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)  # expand to the inner width
        self.fc2 = nn.Linear(d_ff, d_model)  # project back to the model width
        self.gelu = nn.GELU()

    def forward(self, x):
        """Map ``(batch_size, len_seq, d_model)`` through ``d_ff`` and back to ``d_model``."""
        hidden = self.fc1(x)
        activated = self.gelu(hidden)
        return self.fc2(activated)
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention followed by the feed-forward network."""

    def __init__(self):
        super().__init__()
        self.enc_self_attn = MultiHeadAttention()
        self.pos_ffn = PoswiseFeedForwardNet()

    def forward(self, enc_inputs, enc_self_attn_mask):
        """Apply the layer.

        Args:
            enc_inputs: ``[batch_size x len_q x d_model]`` hidden states; used
                as query, key and value of the self-attention.
            enc_self_attn_mask: Attention mask; it is moved to the device of
                ``enc_inputs`` before use.

        Returns:
            Tuple ``(enc_outputs, attn)`` with ``enc_outputs`` of shape
            ``[batch_size x len_q x d_model]``.
        """
        mask = enc_self_attn_mask.to(enc_inputs.device)
        attended, attn = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, mask)
        return self.pos_ffn(attended), attn
class FR_BERT(nn.Module):
    """BERT-style encoder with four output heads.

    Heads, in the order returned by ``forward``:
      * masked-token prediction (decoder weight tied to the token embedding),
      * a 2-way classifier on the first position (CLS),
      * a 17-way classifier on sequence position 1,
      * the same 17-way classifier on the second ``<MULT>`` token.
    """

    def __init__(self):
        super().__init__()
        # NOTE: attribute names and construction order are kept unchanged so
        # that saved state dicts and seeded initialisation stay compatible.
        self.embedding = Embedding()
        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
        self.fc = nn.Linear(d_model, d_model)
        self.activ1 = nn.Tanh()
        self.linear = nn.Linear(d_model, d_model)
        self.activ2 = nn.GELU()
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, 2)
        # The decoder shares its weight matrix with the token embedding;
        # only the bias is a separate parameter.
        embed_weight = self.embedding.tok_embed.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))
        self.mclassifier = nn.Linear(d_model, 17)

    @staticmethod
    def _second_mult_positions(input_ids):
        """Return, per sequence, the column index of its second ``<MULT>`` token.

        Raises:
            AssertionError: If any sequence does not contain exactly two
                ``<MULT>`` tokens.
        """
        is_mult = input_ids == fr_mult_token_id
        per_row = is_mult.sum(dim=1)
        if not bool((per_row == 2).all()):
            # Raised explicitly (instead of using ``assert``) so the check
            # survives ``python -O``; the exception type is kept so existing
            # callers observe the same error class as before.
            raise AssertionError(
                "each input sequence must contain exactly two '<MULT>' tokens, "
                f"got per-sequence counts {per_row.tolist()}"
            )
        # nonzero() lists (row, col) pairs in row-major order, so with exactly
        # two hits per row every odd entry is the second marker of its row.
        return is_mult.nonzero(as_tuple=False)[1::2, 1]

    def forward(self, input_ids, segment_ids, masked_pos):
        """Encode a batch and compute the logits of all four heads.

        Args:
            input_ids: Token ids, ``[batch_size, seq_len]``; every sequence
                must contain exactly two ``<MULT>`` tokens.
            segment_ids: Segment ids passed to the embedding layer.
            masked_pos: Index tensor ``[batch_size, max_pred]`` of the
                positions whose tokens are to be predicted (used as the index
                of ``torch.gather``).

        Returns:
            Tuple ``(logits_lm, logits_clsf, logits_mclsf1, logits_mclsf2)``:
            ``[batch_size, max_pred, n_vocab]``, ``[batch_size, 2]``,
            ``[batch_size, 17]`` and ``[batch_size, 17]``.

        Raises:
            AssertionError: If a sequence does not contain exactly two
                ``<MULT>`` tokens.
        """
        # Validate first so malformed batches fail before the encoder runs.
        second_mult_pos = self._second_mult_positions(input_ids)

        output = self.embedding(input_ids, segment_ids)
        enc_self_attn_mask = get_attn_pad_mask(input_ids, input_ids).to(output.device)
        for layer in self.layers:
            # The per-layer attention maps are not needed here.
            output, _ = layer(output, enc_self_attn_mask)
        # output: [batch_size, seq_len, d_model]

        # Sentence-level head, decided by the first token (CLS).
        h_pooled = self.activ1(self.fc(output[:, 0]))  # [batch_size, d_model]
        logits_clsf = self.classifier(h_pooled)  # [batch_size, 2]

        # Masked-token head: pick the hidden states at the masked positions.
        gather_index = masked_pos[:, :, None].expand(-1, -1, output.size(-1))
        h_masked = torch.gather(output, 1, gather_index)  # [batch_size, max_pred, d_model]
        h_masked = self.norm(self.activ2(self.linear(h_masked)))
        logits_lm = self.decoder(h_masked) + self.decoder_bias  # [batch_size, max_pred, n_vocab]

        # First 17-way head reads sequence position 1 through fc + tanh.
        # NOTE(review): position 1 is presumably the first '<MULT>' token --
        # confirm against the code that builds the inputs.
        h_mult_sent1 = self.activ1(self.fc(output[:, 1]))
        logits_mclsf1 = self.mclassifier(h_mult_sent1)

        # Second 17-way head reads the raw hidden state of the second
        # '<MULT>' token (no fc/tanh, as in the original implementation).
        batch_index = torch.arange(output.size(0))
        h_mult_sent2 = output[batch_index, second_mult_pos]
        logits_mclsf2 = self.mclassifier(h_mult_sent2)

        return logits_lm, logits_clsf, logits_mclsf1, logits_mclsf2