import torch
import torch.nn as nn

from torch.utils.data import DataLoader

from datasets import load_dataset

from tokenizers import Tokenizer, decoders
from tokenizers.models import BPE, WordLevel
from tokenizers.trainers import BpeTrainer, WordLevelTrainer
from tokenizers.pre_tokenizers import ByteLevel, Whitespace
from tokenizers.processors import TemplateProcessing

from torch.cuda.amp import autocast, GradScaler

import time

from torch.utils.tensorboard import SummaryWriter
from itertools import islice

from config import get_weights_file_path, get_config

from tqdm import tqdm

from pathlib import Path

import warnings

# causal_mask is used throughout but was never imported; it is assumed to
# live in dataset.py alongside BilingualDataset.
from dataset import BilingualDataset, causal_mask
from model import build_gpt

# Seeded generator for reproducible sampling.
g = torch.Generator()
g.manual_seed(23)
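

# Greedy decoding: feed the current prefix through the decoder, take the
# argmax over the vocabulary at the last position, and append it. Fully
# deterministic, which makes it useful for quick sanity checks; the sampling
# helpers below produce more varied text.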
def greedy_decode(model, text, mask, tokenizer, max_len, device):
    sos_idx = tokenizer.token_to_id('<s>')
    eos_idx = tokenizer.token_to_id('</s>')

    # Start from <s> and extend one token at a time.
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(text).to(device)
    while True:
        if decoder_input.size(1) == max_len:
            break

        decoder_mask = causal_mask(decoder_input.size(1)).type_as(mask).to(device)

        out = model.decode(decoder_input, decoder_mask)

        # Project only the last position and take the most likely token.
        prob = model.project(out[:, -1])
        _, next_word = torch.max(prob, dim=1)

        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(text).fill_(next_word.item()).to(device)],
            dim=1,
        )

        if next_word == eos_idx:
            break

    return decoder_input.squeeze(0)
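

# Temperature + top-k sampling: logits are divided by `temperature` (below 1.0
# sharpens the distribution, above 1.0 flattens it), truncated to the `top_k`
# highest-scoring tokens, renormalized with softmax, and sampled. For example,
# with logits [4.0, 2.0, 0.1], a temperature of 0.5 pushes nearly all of the
# probability mass onto the first token.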
def generate_text(
    model, text, mask, tokenizer, max_len, device,
    temperature=0.7, top_k=50
):
    eos_idx = tokenizer.token_to_id('</s>')

    decoder_input = text.to(device)
    if decoder_input.dim() == 1:
        decoder_input = decoder_input.unsqueeze(0)  # add a batch dimension

    # Echo the prompt so the sampled continuation streams right after it.
    prompt_text = tokenizer.decode(text.squeeze(0).tolist())
    print(prompt_text, end="", flush=True)

    while len(decoder_input[0]) < max_len - 3:
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(mask).to(device)

        out = model.decode(decoder_input, decoder_mask)
        logits = model.project(out[:, -1])

        # Temperature scaling, then top-k truncation and sampling.
        logits = logits / temperature
        top_k_logits, top_k_indices = torch.topk(logits, top_k)
        probs = torch.softmax(top_k_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        next_token = top_k_indices.gather(-1, next_token)  # map back to vocab ids

        word = tokenizer.decode([next_token.item()])
        print(word, end="", flush=True)

        decoder_input = torch.cat([decoder_input, next_token], dim=1)

        if next_token.item() == eos_idx:
            break

    print()
    return decoder_input.squeeze(0)
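

# Example call (hypothetical values; `mask` only supplies the dtype for the
# causal mask built inside the loop):
#   ids = torch.tensor([tokenizer.token_to_id('<s>')] + tokenizer.encode("The").ids)
#   generate_text(model, ids, causal_mask(1), tokenizer, max_len=350, device=device)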
def generate_text_(model, text, m, tokenizer, max_len, device, temperature=0.7, top_k=50):
    # Variant of generate_text that takes a raw string prompt instead of
    # pre-tokenized ids (`m` is kept for signature compatibility but unused).
    sos_idx = tokenizer.token_to_id('<s>')
    eos_idx = tokenizer.token_to_id('</s>')

    input_tokens = [sos_idx] + tokenizer.encode(text).ids

    # Truncate so there is room for at least one generated token.
    input_tokens = input_tokens[:max_len - 1]

    decoder_input = torch.tensor(input_tokens, device=device).unsqueeze(0)

    for _ in range(max_len - len(input_tokens)):
        decoder_mask = causal_mask(decoder_input.size(1)).to(device)

        out = model.decode(decoder_input, decoder_mask)
        logits = model.project(out[:, -1])

        logits = logits / temperature
        top_k_logits, top_k_indices = torch.topk(logits, top_k)
        probs = torch.softmax(top_k_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        next_token = top_k_indices.gather(-1, next_token)

        word = tokenizer.decode([next_token.item()])
        print(word, end="", flush=True)

        # next_token already has shape (1, 1), so it concatenates directly.
        decoder_input = torch.cat([decoder_input, next_token], dim=1)

        if next_token.item() == eos_idx:
            break

    return decoder_input.squeeze(0)
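

# Example usage of the string-prompt variant (hypothetical paths/config):
#   tokenizer = Tokenizer.from_file(config['tokenizer_file'])
#   model = get_model(config, tokenizer.get_vocab_size()).to(device)
#   generate_text_(model, "Once upon a time", None, tokenizer, 350, device)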
def run_validation(model, validation_ds, tokenizer, max_len, device, print_msg, global_state, writer, num_examples=2):
    model.eval()

    count = 0
    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            # With batch_size=1, drop the batch dimension so the causal mask
            # is sized by sequence length.
            input_tokens = batch['input'].squeeze(0)

            print_msg(f"TOKENIZED INPUT : {input_tokens}")

            mask = causal_mask(input_tokens.size(0))

            model_output = generate_text(model, input_tokens, mask, tokenizer, max_len, device)

            print_msg("Model Output Embedding : ")
            print_msg(str(model_output.tolist()))

            model_out_text = tokenizer.decode(model_output.detach().cpu().tolist())

            print_msg(f'SOURCE : {input_tokens}')
            print_msg(f'PREDICTED : {model_out_text}')

            if count == num_examples:
                break
def get_all_sentences(ds):
    # Yield raw text one document at a time for tokenizer training.
    for item in ds:
        yield item['text']
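

# Two tokenizer builders are kept below: a word-level variant
# (get_or_build_tokenizer_) and the byte-level BPE variant that get_ds
# actually uses. Both cache the trained tokenizer at config['tokenizer_file']
# and reload it on later runs.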
def get_or_build_tokenizer_(config, ds):
    tokenizer_path = Path(config['tokenizer_file'])
    if not tokenizer_path.exists():
        tokenizer = Tokenizer(WordLevel(unk_token="<unk>"))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = WordLevelTrainer(
            special_tokens=["<s>", "</s>", "<pad>", "<unk>", "<mask>", "<user>", "<ai>",
                            "<search_start>", "<search_end>", "<think>", "</think>"],
            min_frequency=2,
        )
        tokenizer.train_from_iterator(get_all_sentences(ds), trainer=trainer)
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer
def get_or_build_tokenizer(config, ds):
    tokenizer_path = Path(config['tokenizer_file'])

    if not tokenizer_path.exists():
        tokenizer = Tokenizer(BPE(unk_token="<unk>"))

        # Byte-level pre-tokenization is paired with a byte-level decoder so
        # decoding round-trips the original text.
        tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=True)
        tokenizer.decoder = decoders.ByteLevel()

        # Wrap every sequence in <s> ... </s>; the ids 0 and 1 match the
        # order of the special tokens passed to the trainer below.
        tokenizer.post_processor = TemplateProcessing(
            single="<s> $A </s>",
            pair="<s> $A </s> <s> $B </s>",
            special_tokens=[
                ("<s>", 0),
                ("</s>", 1),
            ],
        )

        trainer = BpeTrainer(
            vocab_size=30000,
            min_frequency=2,
            special_tokens=["<s>", "</s>", "<pad>", "<unk>", "<mask>", "<user>", "<ai>",
                            "<search_start>", "<search_end>", "<think>", "</think>"],
        )

        tokenizer.train_from_iterator(get_all_sentences(ds), trainer=trainer)

        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))

    return tokenizer
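

# Quick round-trip sanity check (assuming the tokenizer file already exists):
#   tok = Tokenizer.from_file(config['tokenizer_file'])
#   ids = tok.encode("hello world").ids   # template adds <s> ... </s>
#   tok.decode(ids)                       # byte-level decoder restores the text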
def get_ds(config):
    # Stream both corpora from disk so the full dataset never has to fit in RAM.
    ds_raw = load_dataset("json", data_files='./dataset/openwebtext_500k_docs.jsonl', split="train", streaming=True)
    ds_test = load_dataset("json", data_files='./dataset/openwebtext_test.jsonl', split="train", streaming=True)

    tokenizer = get_or_build_tokenizer(config, ds_raw)

    train_ds = BilingualDataset(ds_raw, tokenizer, config['seq_len'])
    val_ds = BilingualDataset(ds_test, tokenizer, config['seq_len'])
    train_dataloader = DataLoader(train_ds, num_workers=6, prefetch_factor=2, pin_memory=True, batch_size=config['batch_size'])
    val_dataloader = DataLoader(val_ds, batch_size=1)

    return train_dataloader, val_dataloader, tokenizer
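

# Note: with streaming sources the datasets are iterable-style, and with
# num_workers=6 each DataLoader worker iterates the stream independently.
# BilingualDataset is assumed to shard work across workers; if it does not,
# workers will yield duplicate batches.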
def get_model(config, vocab_size):
    model = build_gpt(vocab_size, config['seq_len'], config['d_model'], config['N'], config['h'], config['d_ff'], config['dropout'])
    return model
def validate_model(val_dataloader, model, device, loss_fn, vocab_size):
    total_loss = 0
    model.eval()
    i = 0
    with torch.no_grad():
        for batch in val_dataloader:
            input_tokens = batch['input'].to(device, non_blocking=True)
            label = batch['label'].to(device, non_blocking=True)
            # model.decode is assumed to apply its own causal mask when none
            # is passed, as in the training loop below.
            decoder_output = model.decode(input_tokens)
            project_output = model.project(decoder_output)
            # .item() keeps the running sum as a Python float rather than a
            # graph-holding tensor.
            total_loss += loss_fn(
                project_output.view(-1, vocab_size),
                label.view(-1)
            ).item()
            i += 1
    print(f"Validation loss : {total_loss / i:.4f}")
def train_model(config):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device : {device}")

    # Allow TF32 matmuls/convolutions for a speedup on Ampere+ GPUs.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, tokenizer = get_ds(config)
    print(f"Vocab size : {tokenizer.get_vocab_size()}")
    model = get_model(config, tokenizer.get_vocab_size()).to(device)

    writer = SummaryWriter(config['experiment_name'])

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)
    scaler = GradScaler()

    initial_epoch = 0
    global_step = 0
    tqdm_state = {'n': 0}

    model_filename = None
    if config['preload']:
        model_filename = get_weights_file_path(config, config['preload'])
        print(f"Preloading Model {model_filename}")
        state = torch.load(model_filename)
        model.load_state_dict(state['model_state_dict'])
        optimizer.load_state_dict(state['optimizer_state_dict'])
        # 'mid-' checkpoints are saved mid-epoch: resume the same epoch and
        # restore the progress-bar position; otherwise start the next epoch.
        initial_epoch = state['epoch'] if 'mid-' in model_filename else state['epoch'] + 1
        global_step = state['global_step']
        tqdm_state = state['tqdm_state'] if 'mid-' in model_filename else {'n': 0}
    else:
        print("No model to preload. Starting from scratch.")

    loss_fn = nn.CrossEntropyLoss(
        ignore_index=tokenizer.token_to_id('<pad>'),
        label_smoothing=0.05
    ).to(device)
    e = 0
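
    # Mixed precision: the forward pass and loss run in float16 under
    # autocast, while GradScaler scales the loss before backward so small
    # gradients don't underflow, unscaling again before each optimizer step.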
    try:
        for epoch in range(initial_epoch, config['num_epochs']):
            model.train()
            # islice skips batches already consumed when resuming a 'mid-'
            # checkpoint; total=140000 is the expected batch count per epoch.
            batch_iterator = tqdm(islice(train_dataloader, tqdm_state['n'], None), desc=f'Processing epoch {epoch:02d}', initial=tqdm_state['n'], total=140000)
            e = epoch
            if 'elapsed' in tqdm_state and model_filename and "mid-" in model_filename:
                batch_iterator.start_t = time.time() - tqdm_state['elapsed']

            for batch in batch_iterator:
                input_tokens = batch['input'].to(device, non_blocking=True)
                label = batch['label'].to(device, non_blocking=True)

                optimizer.zero_grad(set_to_none=True)

                with autocast(dtype=torch.float16):
                    decoder_output = model.decode(input_tokens)
                    project_output = model.project(decoder_output)

                    loss = loss_fn(
                        project_output.view(-1, tokenizer.get_vocab_size()),
                        label.view(-1)
                    )
                if global_step % 10 == 0:
                    batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})
                    writer.add_scalar("train loss", loss.item(), global_step)
                    writer.flush()
                if global_step % 10000 == 0 and global_step != 0:
                    validate_model(val_dataloader, model, device, loss_fn, tokenizer.get_vocab_size())
                    model.train()  # validate_model leaves the model in eval mode

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                global_step += 1
                tqdm_state = {'n': batch_iterator.n, 'elapsed': batch_iterator.format_dict["elapsed"]}

            # Epoch finished: reset the resume state before checkpointing.
            tqdm_state['n'] = 0
            del tqdm_state['elapsed']
            model_filename = get_weights_file_path(config, f'{epoch:02d}')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'global_step': global_step,
                'tqdm_state': tqdm_state
            }, model_filename)
            validate_model(val_dataloader, model, device, loss_fn, tokenizer.get_vocab_size())
    except KeyboardInterrupt:
        print("Stopping training ...")
        # Save a resumable 'mid-' checkpoint carrying the partial epoch's progress.
        model_filename = get_weights_file_path(config, f'mid-{e:02d}{input("Checkpoint Name: ")}')
        torch.save({
            'epoch': e,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step,
            'tqdm_state': tqdm_state
        }, model_filename)


if __name__ == "__main__":
    warnings.filterwarnings('ignore')
    config = get_config("./openweb.config.json")
    train_model(config)