import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from datasets import load_dataset, concatenate_datasets
from tokenizers import Tokenizer
from tokenizers.models import BPE, WordLevel
from tokenizers.trainers import BpeTrainer, WordLevelTrainer
from tokenizers.pre_tokenizers import ByteLevel, Whitespace
from tokenizers.processors import TemplateProcessing
from tokenizers import decoders
from torch.cuda.amp import autocast, GradScaler
import time
from torch.utils.tensorboard import SummaryWriter
from itertools import islice
from config import get_weights_file_path, get_config
from tqdm import tqdm
from pathlib import Path
import warnings
from fine_tune_dataset import BilingualDataset
from model import build_gpt

try:
    # Prefer the project's own mask helper when the dataset module provides one.
    from fine_tune_dataset import causal_mask
except ImportError:
    def causal_mask(size):
        """Return a (1, size, size) boolean mask that is True where attention is allowed.

        Position i may attend to positions <= i; future positions are False.
        NOTE(review): this fallback assumes the model masks where ``mask == 0``
        (True = keep). Confirm against the attention code in ``model``.
        """
        upper = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
        return upper == 0


# NOTE(review): every special-token literal in this file is the empty string '',
# so tokenizer.token_to_id('') most likely returns None. These look like
# angle-bracket token names lost in transit — confirm the intended names.

# Seeded RNG; it is not referenced anywhere else in this file.
g = torch.Generator()
g.manual_seed(23)


@torch.no_grad()
def greedy_decode(model, text, mask, tokenizer, max_len, device):
    """Greedily decode from a lone start token until EOS or ``max_len`` tokens.

    Args:
        model: object exposing ``decode(ids, mask)`` and ``project(hidden)``.
        text: tensor whose dtype is used for the generated ids.
        mask: tensor whose dtype is used for the causal mask.
        tokenizer: ``tokenizers.Tokenizer`` that supplies the start/end token ids.
        max_len: maximum length of the returned sequence (start token included).
        device: device on which the ids and masks are built.

    Returns:
        1-D tensor of ids, starting with the start token (EOS included if produced).
    """
    sos_idx = tokenizer.token_to_id('')
    eos_idx = tokenizer.token_to_id('')
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(text).to(device)
    # '<' rather than '==' so a max_len below 1 cannot loop forever.
    while decoder_input.size(1) < max_len:
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(mask).to(device)
        out = model.decode(decoder_input, decoder_mask)
        prob = model.project(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        next_id = next_word.item()
        next_col = torch.empty(1, 1).type_as(text).fill_(next_id).to(device)
        decoder_input = torch.cat([decoder_input, next_col], dim=1)
        if next_id == eos_idx:
            break
    return decoder_input.squeeze(0)


def _sample_next_token(logits, temperature=0.7, top_k=50):
    """Draw one token id per row of ``logits`` using temperature + top-k sampling.

    Returns a (rows, 1) tensor of token ids. A non-positive ``temperature``
    falls back to greedy argmax instead of dividing by zero. ``top_k`` is
    clamped to the vocabulary size so ``torch.topk`` cannot fail on small
    vocabularies.
    """
    if temperature <= 0:
        return logits.argmax(dim=-1, keepdim=True)
    k = max(1, min(top_k, logits.size(-1)))
    top_k_logits, top_k_indices = torch.topk(logits / temperature, k)
    probs = torch.softmax(top_k_logits, dim=-1)
    choice = torch.multinomial(probs, num_samples=1)
    # Map the position within the top-k back to the vocabulary id.
    return top_k_indices.gather(-1, choice)


@torch.no_grad()
def generate_text(
    model, text, mask, tokenizer, max_len, device, temperature=0.7, top_k=50
):
    """Stream a sampled continuation of the prompt ``text`` to stdout.

    Args:
        model: object exposing ``decode(ids, mask)`` and ``project(hidden)``.
        text: prompt token ids, either 1-D or with a leading batch dim of 1.
        mask: tensor whose dtype is used for the causal mask.
        tokenizer: ``tokenizers.Tokenizer`` used to print the prompt and tokens.
        max_len: generation stops once the sequence reaches ``max_len - 3``.
        device: device on which generation runs.
        temperature: softmax temperature; <= 0 means greedy.
        top_k: number of highest-scoring tokens to sample from.

    Returns:
        1-D tensor: the prompt followed by the generated ids (EOS included if sampled).
    """
    eos_idx = tokenizer.token_to_id('')

    # The prompt itself is the initial decoder input.
    decoder_input = text.to(device)
    if decoder_input.dim() == 1:
        decoder_input = decoder_input.unsqueeze(0)

    prompt_text = tokenizer.decode(text.squeeze(0).tolist())
    print(prompt_text, end="", flush=True)

    # The 3-token margin comes from the original code; its reason is not visible here.
    while decoder_input.size(1) < max_len - 3:
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(mask).to(device)
        out = model.decode(decoder_input, decoder_mask)
        logits = model.project(out[:, -1])  # logits for the last position only
        next_token = _sample_next_token(logits, temperature, top_k)

        print(tokenizer.decode([next_token.item()]), end="", flush=True)
        decoder_input = torch.cat([decoder_input, next_token], dim=1)
        if next_token.item() == eos_idx:
            break
    print()
    return decoder_input.squeeze(0)


@torch.no_grad()
def generate_text_(model, text, m, tokenizer, max_len, device, temperature=0.7, top_k=50):
    """Variant of ``generate_text`` that takes a raw prompt string.

    The string is encoded, prefixed with the start-token id and truncated to
    ``max_len - 1`` ids. Tokens are then sampled (streamed to stdout) until EOS
    or until the sequence holds ``max_len`` ids. ``m`` is unused and is kept
    only for signature compatibility.

    Returns:
        1-D tensor of prompt + generated ids on ``device``.
    """
    sos_idx = tokenizer.token_to_id('')
    eos_idx = tokenizer.token_to_id('')

    # Truncate to leave room for at least one generated token.
    input_tokens = ([sos_idx] + tokenizer.encode(text).ids)[:max_len - 1]
    decoder_input = torch.tensor(input_tokens, device=device).unsqueeze(0)

    for _ in range(max_len - len(input_tokens)):
        decoder_mask = causal_mask(decoder_input.size(1)).to(device)
        out = model.decode(decoder_input, decoder_mask)
        next_token = _sample_next_token(model.project(out[:, -1]), temperature, top_k)

        print(tokenizer.decode([next_token.item()]), end="", flush=True)
        # next_token is already 2-D (1, 1); the former extra unsqueeze(0) made it
        # 3-D, and torch.cat with the 2-D decoder_input then failed.
        decoder_input = torch.cat([decoder_input, next_token], dim=1)
        if next_token.item() == eos_idx:
            break
    print()
    return decoder_input.squeeze(0)


def run_validation(model, validation_ds, tokenizer, max_len, device, print_msg, global_state, writer, num_examples=2):
    """Print sampled continuations for the first ``num_examples`` batches.

    Each batch's ``'input'`` ids are used as the prompt for ``generate_text``.
    ``global_state`` and ``writer`` are accepted for interface compatibility
    but are not used. The model's previous train/eval mode is restored on exit.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for count, batch in enumerate(validation_ds, start=1):
                # as_tensor avoids the copy-construct warning when 'input' is already a tensor.
                input_tokens = torch.as_tensor(batch['input'])
                print("TOKENIZED INPUT : ", input_tokens)

                # generate_text only consumes this mask's dtype.
                mask = causal_mask(input_tokens.size(-1))
                # NOTE(review): if BilingualDataset pads 'input' to seq_len and
                # max_len == seq_len, generate_text's length check produces no
                # new tokens — confirm against the dataset.
                model_output = generate_text(model, input_tokens, mask, tokenizer, max_len, device)

                print_msg("Model Output Embedding : ")
                print_msg(str(model_output.tolist()))
                # Tokenizer.decode expects a list of ints.
                model_out_text = tokenizer.decode(model_output.detach().cpu().tolist())
                print_msg(f'SOURCE : {input_tokens}')
                print_msg(f'PREDICTED : {model_out_text}')

                if count == num_examples:
                    break
    finally:
        model.train(was_training)


def get_all_sentences(ds):
    """Yield the ``'text'`` field of every item in ``ds`` (tokenizer training corpus)."""
    for item in ds:
        yield item['text']


def get_or_build_tokenizer_(config, ds):
    """Load the word-level tokenizer at ``config['tokenizer_file']``, or train and save one.

    Training uses whitespace pre-tokenization over ``get_all_sentences(ds)``
    and keeps tokens seen at least twice.
    """
    tokenizer_path = Path(config['tokenizer_file'])
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    tokenizer = Tokenizer(WordLevel(unk_token=""))
    tokenizer.pre_tokenizer = Whitespace()
    trainer = WordLevelTrainer(special_tokens=["", "", "", "", "", "", "", "", "", "", ""], min_frequency=2)
    tokenizer.train_from_iterator(get_all_sentences(ds), trainer=trainer)
    # Create the target directory so saving does not fail on a fresh checkout.
    tokenizer_path.parent.mkdir(parents=True, exist_ok=True)
    tokenizer.save(str(tokenizer_path))
    return tokenizer
def get_or_build_tokenizer(config, ds):
    """Load the byte-level BPE tokenizer at ``config['tokenizer_file']``, or train and save one.

    Training streams ``get_all_sentences(ds)`` into a 30k-vocabulary BPE model
    (min_frequency 2) with a ByteLevel pre-tokenizer and decoder and a
    TemplateProcessing post-processor.
    """
    tokenizer_path = Path(config['tokenizer_file'])
    if tokenizer_path.exists():
        return Tokenizer.from_file(str(tokenizer_path))

    tokenizer = Tokenizer(BPE(unk_token=""))
    tokenizer.pre_tokenizer = ByteLevel(add_prefix_space=True)
    tokenizer.decoder = decoders.ByteLevel()
    # NOTE(review): the special-token names here are empty strings '' (both in
    # the template and in the trainer list). They look like lost angle-bracket
    # tokens — confirm the intended names.
    tokenizer.post_processor = TemplateProcessing(
        single=" $A ",
        pair=" $A $B ",
        special_tokens=[
            ("", 0),
            ("", 1),
        ],
    )
    trainer = BpeTrainer(
        vocab_size=30000,
        min_frequency=2,
        special_tokens=["", "", "", "", "", "", "", "", "", "", ""],
    )
    tokenizer.train_from_iterator(get_all_sentences(ds), trainer=trainer)
    # Create the target directory so saving does not fail on a fresh checkout.
    tokenizer_path.parent.mkdir(parents=True, exist_ok=True)
    tokenizer.save(str(tokenizer_path))
    return tokenizer


def get_ds(config):
    """Build the train/validation dataloaders and the tokenizer from tatsu-lab/alpaca.

    Training uses the first 50,000 rows and validation uses the last 2,002 rows
    of the ``train`` split. The tokenizer is built from the training rows only.
    """
    ds_raw = load_dataset("tatsu-lab/alpaca", split="train[:50000]")
    ds_test = load_dataset("tatsu-lab/alpaca", split="train[-2002:]")
    tokenizer = get_or_build_tokenizer(config, ds_raw)

    train_ds = BilingualDataset(ds_raw, tokenizer, config['seq_len'])
    val_ds = BilingualDataset(ds_test, tokenizer, config['seq_len'])

    # NOTE(review): shuffle is off; presumably deliberate so the islice-based
    # mid-epoch resume in train_model skips the same batches — confirm.
    train_dataloader = DataLoader(train_ds, num_workers=6, prefetch_factor=2, pin_memory=True, batch_size=config['batch_size'])
    val_dataloader = DataLoader(val_ds, batch_size=1)
    return train_dataloader, val_dataloader, tokenizer


def get_model(config, vocab_size):
    """Instantiate the GPT model from the hyper-parameters in ``config``."""
    return build_gpt(vocab_size, config['seq_len'], config['d_model'], config['N'], config['h'], config['d_ff'], config['dropout'])


def validate_model(val_dataloader, model, device, loss_fn, vocab_size):
    """Print and return the mean per-batch loss of ``model`` over ``val_dataloader``.

    The model's previous train/eval mode is restored on exit. Without this,
    calling the function mid-epoch would leave dropout disabled for the rest of
    training. Returns ``None`` (and prints a notice) when the loader is empty.
    """
    was_training = model.training
    model.eval()
    total_loss = 0
    num_batches = 0
    try:
        with torch.no_grad():
            for batch in val_dataloader:
                input_tokens = batch['input'].to(device, non_blocking=True)
                label = batch['label'].to(device, non_blocking=True)
                decoder_output = model.decode(input_tokens)
                project_output = model.project(decoder_output)
                # Accumulate on-device; a single .item() at the end avoids per-batch syncs.
                total_loss = total_loss + loss_fn(project_output.view(-1, vocab_size), label.view(-1))
                num_batches += 1
    finally:
        model.train(was_training)

    if num_batches == 0:
        print("Validation loss : n/a (empty dataloader)")
        return None
    mean_loss = float(total_loss) / num_batches
    print(f"Validation loss : {mean_loss:.4f}")
    return mean_loss


def _save_checkpoint(path, epoch, model, optimizer, scaler, global_step, tqdm_state):
    """Write model, optimizer and AMP-scaler state plus training progress to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'global_step': global_step,
        'tqdm_state': tqdm_state,
    }, path)


def train_model(config):
    """Train the GPT model described by ``config`` with fp16 autocast.

    A checkpoint is saved at the end of every epoch. When ``config['preload']``
    is set, training resumes from that checkpoint. A checkpoint whose path
    contains ``'mid-'`` was written on Ctrl-C: it resumes the same epoch and
    skips the ``tqdm_state['n']`` batches already processed. Any other
    checkpoint resumes at the next epoch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device : {device}")

    # Enable TF32 (optional, speeds up matmul)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    Path(config['model_folder']).mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, tokenizer = get_ds(config)
    vocab_size = tokenizer.get_vocab_size()
    print(vocab_size)
    model = get_model(config, vocab_size).to(device)

    writer = SummaryWriter(config['experiment_name'])
    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)
    scaler = GradScaler()  # loss scaling for fp16 mixed precision

    initial_epoch = 0
    global_step = 0
    tqdm_state = {'n': 0}
    resuming_mid_epoch = False
    if config['preload']:
        model_filename = get_weights_file_path(config, config['preload'])
        print(f"Preloading Model {model_filename}")
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        state = torch.load(model_filename, map_location=device)
        model.load_state_dict(state['model_state_dict'])
        optimizer.load_state_dict(state['optimizer_state_dict'])
        if 'scaler_state_dict' in state:  # absent in checkpoints written before this key existed
            scaler.load_state_dict(state['scaler_state_dict'])
        resuming_mid_epoch = 'mid-' in str(model_filename)
        initial_epoch = state['epoch'] if resuming_mid_epoch else state['epoch'] + 1
        global_step = state['global_step']
        tqdm_state = state['tqdm_state'] if resuming_mid_epoch else {'n': 0}
    else:
        print("No Model to preload. Setting from scratch.")

    # NOTE(review): token_to_id('') likely returns None, which CrossEntropyLoss
    # rejects as ignore_index when called — confirm the intended padding token name.
    loss_fn = nn.CrossEntropyLoss(
        ignore_index=tokenizer.token_to_id(''), label_smoothing=0.05
    ).to(device)

    current_epoch = initial_epoch
    try:
        for epoch in range(initial_epoch, config['num_epochs']):
            current_epoch = epoch
            model.train()
            # islice skips the batches already consumed before a mid-epoch interrupt.
            batch_iterator = tqdm(
                islice(train_dataloader, tqdm_state['n'], None),
                desc=f'Processing epoch {epoch:02d}',
                initial=tqdm_state['n'],
                total=len(train_dataloader),
            )
            if resuming_mid_epoch and 'elapsed' in tqdm_state:
                # Continue the progress-bar clock from the interrupted run.
                batch_iterator.start_t = time.time() - tqdm_state['elapsed']

            for batch in batch_iterator:
                input_tokens = batch['input'].to(device, non_blocking=True)
                label = batch['label'].to(device, non_blocking=True)
                optimizer.zero_grad(set_to_none=True)

                # Mixed-precision forward pass
                with autocast(dtype=torch.float16):
                    decoder_output = model.decode(input_tokens)
                    project_output = model.project(decoder_output)  # (B, seq_len, vocab_size)
                    loss = loss_fn(project_output.view(-1, vocab_size), label.view(-1))

                if global_step % 10 == 0:
                    loss_value = loss.item()
                    batch_iterator.set_postfix({"loss": f"{loss_value:6.3f}"})
                    writer.add_scalar("train loss", loss_value, global_step)
                    writer.flush()

                # Mixed-precision backward pass
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                # Validate after the step so this step's autograd graph is already freed;
                # validate_model restores train mode afterwards.
                if global_step % 10000 == 0 and global_step != 0:
                    val_loss = validate_model(val_dataloader, model, device, loss_fn, vocab_size)
                    if val_loss is not None:
                        writer.add_scalar("val loss", val_loss, global_step)

                global_step += 1
                tqdm_state = {'n': batch_iterator.n, 'elapsed': batch_iterator.format_dict["elapsed"]}

            # Epoch finished: the next epoch starts from its first batch.
            # Rebuilding the dict avoids a KeyError when no batch ran.
            tqdm_state = {'n': 0}
            model_filename = get_weights_file_path(config, f'{epoch:02d}')
            _save_checkpoint(model_filename, epoch, model, optimizer, scaler, global_step, tqdm_state)
            # Evaluate on the held-out split (previously this re-ran the whole training loader).
            val_loss = validate_model(val_dataloader, model, device, loss_fn, vocab_size)
            if val_loss is not None:
                writer.add_scalar("val loss", val_loss, global_step)
    except KeyboardInterrupt:
        print("You are stoping training : ... ")
        model_filename = get_weights_file_path(config, f'mid-{current_epoch:02d}{input("Checkpoint Name: ")}')
        _save_checkpoint(model_filename, current_epoch, model, optimizer, scaler, global_step, tqdm_state)
    finally:
        writer.close()


if __name__ == "__main__":
    warnings.filterwarnings('ignore')
    config = get_config("./openweb.config.json")
    train_model(config)