import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from datasets import load_dataset, concatenate_datasets
from tokenizers import Tokenizer
from tokenizers.models import BPE, WordLevel
from tokenizers.trainers import BpeTrainer, WordLevelTrainer
from tokenizers.pre_tokenizers import ByteLevel, Whitespace
from tokenizers.processors import TemplateProcessing
from tokenizers import decoders
from torch.cuda.amp import autocast, GradScaler
import time
from torch.utils.tensorboard import SummaryWriter
from itertools import islice
from config import get_weights_file_path, get_config
from tqdm import tqdm
from pathlib import Path
import warnings
from fine_tune_dataset import BilingualDataset
from model import build_gpt
g = torch.Generator()
g.manual_seed(23)
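
# `causal_mask` is used throughout this file but is neither defined nor imported here.
# Minimal sketch of the helper, under the assumption that it builds the usual
# upper-triangular causal attention mask; if fine_tune_dataset already exports one,
# import that instead and drop this definition.
def causal_mask(size):
    # (1, size, size) boolean mask: True at positions a token is allowed to attend to
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0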

def greedy_decode(model, text, mask, tokenizer, max_len, device):
    sos_idx = tokenizer.token_to_id('<sos>')
    eos_idx = tokenizer.token_to_id('<eos>')
    # Start decoding from a single <sos> token
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(text).to(device)
    while True:
        if decoder_input.size(1) == max_len:
            break
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(mask).to(device)
        out = model.decode(decoder_input, decoder_mask)
        prob = model.project(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(text).fill_(next_word.item()).to(device)], dim=1
        )
        if next_word == eos_idx:
            break
    return decoder_input.squeeze(0)

def generate_text(
    model, text, mask, tokenizer, max_len, device,
    temperature=0.7, top_k=50
):
    eos_idx = tokenizer.token_to_id('<eos>')
    # Start with the input text as initial decoder input
    decoder_input = text.to(device)
    if decoder_input.dim() == 1:
        decoder_input = decoder_input.unsqueeze(0)
    # Print the initial prompt
    prompt_text = tokenizer.decode(text.squeeze(0).tolist())
    print(prompt_text, end="", flush=True)
    while len(decoder_input[0]) < max_len - 3:
        # Apply causal mask based on current decoder_input length
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(mask).to(device)
        # Get model output
        out = model.decode(decoder_input, decoder_mask)
        logits = model.project(out[:, -1])  # Logits for the last position
        # Sampling: temperature + top-k
        logits = logits / temperature
        top_k_logits, top_k_indices = torch.topk(logits, top_k)
        probs = torch.softmax(top_k_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        next_token = top_k_indices.gather(-1, next_token)
        # Decode and print token
        word = tokenizer.decode([next_token.item()])
        print(word, end="", flush=True)
        # Append next token
        decoder_input = torch.cat([decoder_input, next_token], dim=1)
        if next_token.item() == eos_idx:
            break
    print()
    return decoder_input.squeeze(0)

def generate_text_(model, text, m, tokenizer, max_len, device, temperature=0.7, top_k=50):
    sos_idx = tokenizer.token_to_id('<sos>')
    eos_idx = tokenizer.token_to_id('<eos>')
    pad_idx = tokenizer.token_to_id('<pad>')
    # Encode input and add <sos> at the beginning
    input_tokens = [sos_idx] + tokenizer.encode(text).ids
    # Truncate if too long
    input_tokens = input_tokens[:max_len - 1]  # Leave room for <eos>
    # Convert to tensor
    decoder_input = torch.tensor(input_tokens, device=device).unsqueeze(0)
    # Generate until max_len
    for _ in range(max_len - len(input_tokens)):
        # Create causal mask for what we've generated so far
        decoder_mask = causal_mask(decoder_input.size(1)).to(device)
        # Get model output
        out = model.decode(decoder_input, decoder_mask)
        logits = model.project(out[:, -1])
        # Apply sampling
        logits = logits / temperature
        top_k_logits, top_k_indices = torch.topk(logits, top_k)
        probs = torch.softmax(top_k_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        next_token = top_k_indices.gather(-1, next_token)
        # Print the generated word
        word = tokenizer.decode([next_token.item()])
        print(word, end="", flush=True)
        # Append to input (next_token is already shaped (1, 1))
        decoder_input = torch.cat([decoder_input, next_token], dim=1)
        if next_token.item() == eos_idx:
            break
    return decoder_input.squeeze(0)

def run_validation(model, validation_ds, tokenizer, max_len, device, print_msg, global_state, writer, num_examples=2):
    model.eval()
    count = 0
    pad_token = torch.tensor([tokenizer.token_to_id('<pad>')], dtype=torch.int64)
    sos_token = torch.tensor([tokenizer.token_to_id('<sos>')], dtype=torch.int64)
    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            input_tokens = batch['input']
            # print("TEXT INPUT : ", text)
            # input_tokens = tokenizer.encode(text).ids[:-1]
            print("TOKENIZED INPUT : ", input_tokens)
            # if len(input_tokens) < config['seq_len']:
            #     input_tokens += [pad_token] * (config['seq_len'] - len(input_tokens))
            # if len(input_tokens) > config['seq_len']:
            #     input_tokens = input_tokens[:config['seq_len']]
            input_tokens = torch.tensor(input_tokens)
            # (input_tokens != pad_token).unsqueeze(0).int() &
            mask = causal_mask(input_tokens.size(0))
            # text = batch['input'].to(device)
            # mask = batch['input_mask'].to(device)
            model_output = generate_text(model, input_tokens, mask, tokenizer, max_len, device)
            # model_output = greedy_decode(model, text, mask, tokenizer, max_len, device)
            print_msg("Model Output Embedding : ")
            print_msg(str(model_output.tolist()))
            model_out_text = tokenizer.decode(model_output.detach().cpu().numpy())
            # text = tokenizer.decode(input_tokens[0].tolist(), skip_special_tokens=True)
            print_msg(f'SOURCE : {input_tokens}')
            print_msg(f'PREDICTED : {model_out_text}')
            if count == num_examples:
                break

def get_all_sentences(ds):
    for item in ds:
        yield item['text']

def get_or_build_tokenizer_(config, ds):
    tokenizer_path = Path(config['tokenizer_file'])
    if not Path.exists(tokenizer_path):
        tokenizer = Tokenizer(WordLevel(unk_token="<unk>"))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = WordLevelTrainer(
            special_tokens=["<sos>", "<eos>", "<pad>", "<unk>"],
            min_frequency=2
        )
        tokenizer.train_from_iterator(get_all_sentences(ds), trainer=trainer)
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer

def get_or_build_tokenizer(config, ds):
    tokenizer_path = Path(config['tokenizer_file'])
    if not tokenizer_path.exists():
        # Define tokenizer with BPE model
        tokenizer = Tokenizer(BPE(unk_token="<unk>"))
        # ByteLevel pre-tokenizer and decoder
        tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=True)
        tokenizer.decoder = decoders.ByteLevel()
        # Optional: Add post-processing for special tokens
        tokenizer.post_processor = TemplateProcessing(
            single="<sos> $A <eos>",
            pair="<sos> $A <eos> $B <eos>",
            special_tokens=[
                ("<sos>", 0),
                ("<eos>", 1),
            ],
        )
        # Trainer
        trainer = BpeTrainer(
            vocab_size=30000,
            min_frequency=2,
            special_tokens=["<sos>", "<eos>", "<pad>", "<unk>"]
        )
        # Train from dataset
        tokenizer.train_from_iterator(get_all_sentences(ds), trainer=trainer)
        # Save as single .json file
        tokenizer.save(str(tokenizer_path))
    else:
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer

def get_ds(config):
    # ds_raw = load_dataset("json", data_files={'train': config['train'], 'test': config['test']})
    ds_raw = load_dataset("tatsu-lab/alpaca", split="train[:50000]")
    ds_test = load_dataset("tatsu-lab/alpaca", split="train[-2002:]")
    # ds_raw = ds_raw[:1]
    # ds_raw = load_dataset("stas/openwebtext-10k")
    tokenizer = get_or_build_tokenizer(config, ds_raw)
    train_ds = BilingualDataset(ds_raw, tokenizer, config['seq_len'])
    val_ds = BilingualDataset(ds_test, tokenizer, config['seq_len'])
    train_dataloader = DataLoader(train_ds, num_workers=6, prefetch_factor=2, pin_memory=True, batch_size=config['batch_size'])
    val_dataloader = DataLoader(val_ds, batch_size=1)
    return train_dataloader, val_dataloader, tokenizer

def get_model(config, vocab_size):
    # model = build_transformer(vocab_src_len, vocab_tgt_len, config['seq_len'], config['seq_len'], config['d_model'], config['N'], config['h'], config['d_ff'])
    model = build_gpt(vocab_size, config['seq_len'], config['d_model'], config['N'], config['h'], config['d_ff'], config['dropout'])
    return model

def validate_model(val_dataloader, model, device, loss_fn, vocab_size):
    total_loss = 0
    model.eval()
    i = 0
    with torch.no_grad():
        for batch in val_dataloader:
            input_tokens = batch['input'].to(device, non_blocking=True)
            label = batch['label'].to(device, non_blocking=True)
            decoder_output = model.decode(input_tokens)
            project_output = model.project(decoder_output)
            total_loss += loss_fn(
                project_output.view(-1, vocab_size),
                label.view(-1)
            ).item()
            i += 1
    print(f"Validation loss : {total_loss / i:.4f}")

def train_model(config):
    # Define the device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device : {device}")
    # Enable TF32 (optional, speeds up matmul)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)
    train_dataloader, val_dataloader, tokenizer = get_ds(config)
    print(tokenizer.get_vocab_size())
    model = get_model(config, tokenizer.get_vocab_size()).to(device)
    # TensorBoard
    writer = SummaryWriter(config['experiment_name'])
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)
    scaler = GradScaler()  # <- added scaler for mixed precision
    initial_epoch = 0
    global_step = 0
    tqdm_state = {'n': 0}
    model_filename = None
    if config['preload']:
        model_filename = get_weights_file_path(config, config['preload'])
        print(f"Preloading Model {model_filename}")
        state = torch.load(model_filename)
        model.load_state_dict(state['model_state_dict'])
        optimizer.load_state_dict(state['optimizer_state_dict'])
        # A 'mid-' checkpoint was saved part-way through an epoch, so resume that same epoch
        initial_epoch = state['epoch'] if 'mid-' in model_filename else state['epoch'] + 1
        global_step = state['global_step']
        tqdm_state = state['tqdm_state'] if 'mid-' in model_filename else {'n': 0}
    else:
        print("No model to preload. Starting from scratch.")
    loss_fn = nn.CrossEntropyLoss(
        ignore_index=tokenizer.token_to_id('<pad>'),
        label_smoothing=0.05
    ).to(device)
    e = 0
    try:
        for epoch in range(initial_epoch, config['num_epochs']):
            model.train()
            batch_iterator = tqdm(
                islice(train_dataloader, tqdm_state['n'], None),
                desc=f'Processing epoch {epoch:02d}',
                initial=tqdm_state['n'],
                total=len(train_dataloader)
            )
            e = epoch
            if model_filename and 'elapsed' in tqdm_state and "mid-" in model_filename:
                batch_iterator.start_t = time.time() - tqdm_state['elapsed']
            for batch in batch_iterator:
                # torch.cuda.empty_cache()
                input_tokens = batch['input'].to(device, non_blocking=True)
                label = batch['label'].to(device, non_blocking=True)
                optimizer.zero_grad(set_to_none=True)
                # 🔥 Mixed precision forward pass
                with autocast(dtype=torch.float16):
                    decoder_output = model.decode(input_tokens)
                    project_output = model.project(decoder_output)  # (B, seq_len, vocab_size)
                    loss = loss_fn(
                        project_output.view(-1, tokenizer.get_vocab_size()),
                        label.view(-1)
                    )
                if global_step % 10 == 0:
                    batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})
                    writer.add_scalar("train loss", loss.item(), global_step)
                    writer.flush()
                if global_step % 10000 == 0 and global_step != 0:
                    validate_model(val_dataloader, model, device, loss_fn, tokenizer.get_vocab_size())
                    model.train()  # validate_model switches to eval mode; switch back before resuming training
                # 🔥 Mixed precision backward pass
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                global_step += 1
                tqdm_state = {'n': batch_iterator.n, 'elapsed': batch_iterator.format_dict["elapsed"]}
            tqdm_state['n'] = 0
            del tqdm_state['elapsed']
            model_filename = get_weights_file_path(config, f'{epoch:02d}')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'global_step': global_step,
                'tqdm_state': tqdm_state
            }, model_filename)
            validate_model(val_dataloader, model, device, loss_fn, tokenizer.get_vocab_size())
    except KeyboardInterrupt:
        print("You are stopping training : ... ")
        model_filename = get_weights_file_path(config, f'mid-{e:02d}{input("Checkpoint Name: ")}')
        torch.save({
            'epoch': e,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step,
            'tqdm_state': tqdm_state
        }, model_filename)

if __name__ == "__main__":
    warnings.filterwarnings('ignore')
    config = get_config("./openweb.config.json")
    train_model(config)
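
# For reference, these are the keys this script reads from the config returned by
# get_config("./openweb.config.json"). The values below are illustrative placeholders,
# not the project's actual settings:
#
#   {
#     "tokenizer_file": "tokenizer.json",
#     "seq_len": 512,
#     "batch_size": 8,
#     "d_model": 512, "N": 6, "h": 8, "d_ff": 2048, "dropout": 0.1,
#     "lr": 1e-4, "num_epochs": 10, "preload": null,
#     "model_folder": "weights", "experiment_name": "runs/gpt"
#   }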