import logging
import math
import os
import time
from functools import partial

import numpy as np
import torch
import yaml
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

from deepseek_v3 import DeepSeekV3Model
from utils import upload_file_to_s3


logger = logging.getLogger(__name__)
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler('training.log')
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logger.setLevel(logging.INFO)

def encode_text(examples, tokenizer, seq_length):
    """Tokenize and prepare text examples for training."""
    tokens = tokenizer(
        examples["text"],
        truncation=True,
        padding="max_length",
        max_length=seq_length + 1,  # one extra token so the shift below still yields seq_length pairs
        return_tensors="pt",
    )

    input_ids = tokens["input_ids"].squeeze(0).clone().detach()
    # Guard against token ids outside the model's embedding range
    input_ids = torch.clamp(input_ids, min=0, max=tokenizer.vocab_size - 1)
    labels = input_ids.clone().detach()
    # Next-token prediction: labels are the inputs shifted left by one position
    labels = labels[1:].to(torch.int64)
    input_ids = input_ids[:-1].to(torch.int64)

    return {"input_ids": input_ids, "labels": labels}

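# Shape sketch (hypothetical seq_length=512): the tokenizer pads/truncates to 513 tokens,
# and the one-token shift above yields input_ids = tokens[:-1] and labels = tokens[1:],
# so labels[i] is the next-token target for input_ids[i] and both have length 512.
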
def load_cosmopedia_dataset(batch_size=8, seq_length=1024, tokenizer=None):
    """
    Returns a torch DataLoader for the streaming cosmopedia-v2 subset of the smollm corpus.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    logger.info("tokenizer parallelism set to false")
    try:
        from datasets import config
        config.HF_DATASETS_TIMEOUT = 300
        config.MAX_RETRIES = 10
        logger.info("dataset loading config set")
        # Stream the split so the full corpus never has to be downloaded up front
        train_dataset = load_dataset(
            "HuggingFaceTB/smollm-corpus",
            name="cosmopedia-v2",
            split="train",
            streaming=True,
        )
        logger.info("dataset loaded")

        encode_fn = partial(encode_text, tokenizer=tokenizer, seq_length=seq_length)

        train_dataset = train_dataset.map(
            encode_fn,
            remove_columns=["text"],
            batched=False,
        )
        train_dataset = train_dataset.with_format("torch")

        train_dataloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            num_workers=2,
            pin_memory=True,
            prefetch_factor=4,
            persistent_workers=True,
        )
        return train_dataloader
    except Exception as e:
        logger.error(f"Error loading dataset: {str(e)}")
        return None

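# Usage sketch (assumes a tokenizer has already been loaded, as in the main block below):
#   loader = load_cosmopedia_dataset(batch_size=8, seq_length=512, tokenizer=tokenizer)
#   batch = next(iter(loader))
#   batch["input_ids"].shape, batch["labels"].shape   # both torch.Size([8, 512])
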
def generate(model, idx, max_new_tokens, context_length, temperature=1.0, top_k=None, eos_token=None, device=None):
    """Autoregressively sample up to max_new_tokens token ids, starting from the prompt ids in idx."""
    logger.info(f"Generating on device {device}")
    model = model.to(device)
    idx = idx.to(device)
    model.eval()
    for _ in range(max_new_tokens):
        # Condition on at most the last context_length tokens
        idx_cond = idx[:, -context_length:]
        with torch.no_grad():
            logits, _ = model(idx_cond)
            logits = logits.view(idx_cond.shape[0], -1, model.config['vocab_size'])

        # Keep only the logits for the final position
        logits = logits[:, -1, :]

        if top_k is not None:
            # Top-k filtering: mask everything below the k-th largest logit
            top_logits, _ = torch.topk(logits, top_k)
            min_logit = top_logits[:, -1].unsqueeze(-1)
            logits = torch.where(logits < min_logit,
                                 torch.tensor(float('-inf')).to(logits.device),
                                 logits)

        if temperature > 0.0:
            logits /= temperature
            probs = torch.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
        else:
            # Greedy decoding when temperature is zero
            idx_next = torch.argmax(logits, dim=-1, keepdim=True)

        # Stop at the end-of-sequence token (assumes a batch of size 1)
        if idx_next.item() == eos_token:
            break

        idx = torch.cat((idx, idx_next), dim=1)
    model.train()
    return idx

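# Sampling sketch (hypothetical prompt; tokenizer, model and device as in the main block below):
#   idx = tokenizer.encode("Once upon a time", return_tensors="pt")
#   out = generate(model, idx, max_new_tokens=64, context_length=256,
#                  temperature=0.8, top_k=5, eos_token=tokenizer.eos_token_id, device=device)
#   print(tokenizer.decode(out.squeeze(0)))
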
def sync_device(device):
    """Block until all pending work on the given device has completed."""
    if device.startswith('cuda'):
        torch.cuda.synchronize()
    elif device == 'cpu':
        # torch.cpu.synchronize is only available in newer PyTorch releases
        if hasattr(torch.cpu, 'synchronize'):
            torch.cpu.synchronize()
    elif device.startswith('mps'):
        torch.mps.synchronize()

def print_gpu_memory(step_name=""):
    """
    Print GPU memory statistics with a specified step name
    """
    if torch.cuda.is_available():
        logger.info(f"\nGPU Memory Stats {step_name}:")
        logger.info(f"GPU Memory allocated: {torch.cuda.memory_allocated() / 1024**2:.2f} MB")
        logger.info(f"GPU Memory reserved: {torch.cuda.memory_reserved() / 1024**2:.2f} MB")
        logger.info(f"Max GPU Memory allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")

def get_lr_lambda(current_step, warmup_steps, max_steps, max_lr):
    """
    Learning rate schedule, expressed as the multiplicative factor that LambdaLR
    applies to the optimizer's base lr (which equals max_lr here):
    1. Linear warmup for the first warmup_steps steps
    2. Cosine decay from warmup_steps to max_steps
    3. Floor at 5% of max_lr (1.5e-5 for max_lr=3e-4)
    """
    min_lr = max_lr * 0.05

    if current_step < warmup_steps:
        # Warmup: factor grows linearly from 0 to 1
        return float(current_step) / float(max(1, warmup_steps))

    # Cosine decay: clamp progress so steps beyond max_steps stay at the minimum rate
    progress = float(current_step - warmup_steps) / float(max(1, max_steps - warmup_steps))
    progress = min(progress, 1.0)
    lr = min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
    # LambdaLR expects a factor relative to the base lr, not an absolute learning rate
    return lr / max_lr

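# LambdaLR multiplies the optimizer's base lr by the factor returned above, so with the
# values used in train_model (base lr = max_lr = 3e-4, warmup_steps = 3000, max_steps = 60000)
# the effective rate ramps linearly to 3e-4 by step 3000, then follows a cosine decay down
# to 1.5e-5 (5% of max_lr) by step 60000 and stays there for any later steps.
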
def train_model(config, model, train_loader, test_loader, optimizer, device, num_epochs, eval_freq, eval_iter, start_context="Jack Gisburn rather a cheap genius- ", tokenizer=None):
    total_loss = 0
    tokens_seen, global_step = 0, -1

    # Gradient accumulation setup: with a multiplier of 1 the effective batch size
    # equals the micro batch size, so every batch triggers an optimizer step
    actual_batch_size = config['tokens']['micro_batch_size']
    effective_batch_size_multiplier = 1
    target_batch_size = effective_batch_size_multiplier * config['tokens']['micro_batch_size']
    gradient_accumulation_steps = target_batch_size // actual_batch_size

    # Learning rate schedule hyperparameters
    max_lr = 3e-4
    warmup_steps = 3000
    max_steps = 60000
    min_lr = max_lr * 0.05

    lr_lambda = lambda step: get_lr_lambda(step, warmup_steps, max_steps, max_lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    logger.info("Training with learning rate schedule:")
    logger.info(f"Max LR: {max_lr}")
    logger.info(f"Warmup Steps: {warmup_steps}")
    logger.info(f"Max Steps: {max_steps}")
    logger.info(f"Min LR: {min_lr}")
    logger.info(f"Gradient Accumulation Steps: {gradient_accumulation_steps}")
    logger.info(f"Effective Batch Size: {actual_batch_size * gradient_accumulation_steps}")

    print_gpu_memory("at start of training")

    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True
    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()

        for batch_idx, batch in enumerate(train_loader):
            input_batch = batch['input_ids'].to(device)
            target_batch = batch['labels'].to(device)

            # Forward pass under bfloat16 autocast; the model returns (logits, loss)
            with torch.autocast(device_type=device, dtype=torch.bfloat16):
                logits, original_loss = model(input_batch, target_batch)

            # Scale the loss so gradients average correctly across accumulation steps
            scaled_loss = original_loss / gradient_accumulation_steps
            scaled_loss.backward()

            total_loss += original_loss.item()
            tokens_seen += input_batch.numel()

            total_batches = batch_idx + 1
            avg_loss = total_loss / total_batches
            if batch_idx % 25 == 0:
                logger.info(f"Batch {batch_idx + 1}, Running Avg Loss: {avg_loss:.5f}")

            if (batch_idx + 1) % gradient_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

                if global_step % eval_freq == 0 and global_step > 0:
                    current_lr = scheduler.get_last_lr()[0]
                    optimizer_lr = optimizer.param_groups[0]['lr']

                    print_gpu_memory(f"at step {global_step}")
                    logger.info(f"learning rate: {current_lr:.8f}")
                    logger.info(f"Ep {epoch+1} (Step {global_step:06d}): "
                                f"Avg loss {avg_loss:.3f} | {tokens_seen} tokens seen")
                    logger.info(f"optimizer lr: {optimizer_lr:.8f}")
                    logger.info(f"scheduler lr: {current_lr:.8f}")
                    start_context_list = [
                        "In today's ever-evolving world, technology has become an integral part of our lives",
                        "Once upon a time, there was a friendly agency called Gaudette Insurance Agency, Inc. They help",
                        "A couple of years ago, I was working as an extra on the set of a low-budget British film.",
                        "Introduction: The Art of Crafting Vegan Sandwich Delights Sandwiches occupy a unique space in",
                        "Meet Chris, a superhero of supplies! Just like how Batman protects Gotham City",
                        "Identity formation is a complex and multifaceted process that involves the development of",
                        "With the development of science and technology, computer has become more and more ",
                        "Just as there are many variants and forms of electronic malware and Internet-based ",
                        "Correctly identifying what is causing a problem is the most important step in pest control.",
                        "Lobster, California spiny The California Spiny Lobster fishery is a small but locally ",
                        "Bees are vital for pollination. You can buy leafcutter bee houses to attract ",
                        "Feeling Alone Together: Exploring Alienation and Isolation in Literature",
                        "Imagine if someone got their hands on dangerous weapons",
                        "Once upon a time, in a colorful town called Popville, ",
                        "he bell above the door jangled as Sarah walked into her family's hardware store",
                    ]

                    random_prompt = np.random.choice(start_context_list)
                    logger.info(f"Selected prompt: {random_prompt}")
                    logger.info("+++" * 30)
                    encoded_text = tokenizer.encode(random_prompt, return_tensors="pt")
                    random_topk = np.random.randint(1, 10)
                    logger.info(f"random_topk: {random_topk}")
                    random_temperature = np.random.uniform(0.7, 0.9)
                    logger.info(f"random_temperature: {random_temperature}")
                    logger.info(f"global step {global_step} , batch_idx {batch_idx} => generating text")
                    generated_text = generate(model,
                                              idx=encoded_text,
                                              max_new_tokens=256,
                                              context_length=256,
                                              temperature=random_temperature,
                                              top_k=random_topk,
                                              eos_token=tokenizer.eos_token_id,
                                              device=device)
                    logger.info("+++" * 30)
                    logger.info(tokenizer.decode(generated_text.squeeze(0)))
                    logger.info("+++" * 30)
                    # Save a checkpoint and push it (plus the log file) to S3
                    model_file_name = f"model_{global_step}_steps_avg_loss_{avg_loss:.5f}_optimizer_lr_{optimizer_lr:.8f}.pth"
                    torch.save({
                        'step': global_step,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict(),
                        'loss': avg_loss,
                    }, model_file_name)

                    s3_path = upload_file_to_s3(model_file_name,
                                                config['model']['model_config']['s3_bucket'],
                                                config['model']['model_config']['s3_checkpoint_folder'])
                    logger.info(f"Model saved to S3: {s3_path}")

                    log_path = upload_file_to_s3(config['model']['model_config']['s3_log_file_name'],
                                                 config['model']['model_config']['s3_bucket'],
                                                 config['model']['model_config']['s3_log_folder'])
                    logger.info(f"Log saved to S3: {log_path}")

            if batch_idx % 100 == 0:
                logger.info(f"Batch {batch_idx} finished")
                logger.info("+++" * 30)

    logger.info("Training complete")

if __name__ == "__main__":
    with open("config_smollm2_135M.yaml", "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    logger.info(config)

    torch.set_float32_matmul_precision('high')
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    # PYTORCH_CUDA_ALLOC_CONF must be set before the first CUDA allocation to take effect
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:64'
    torch.cuda.empty_cache()

    model = DeepSeekV3Model(config['model'])
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model.to(device)

    logger.info(model)
    logger.info("++" * 30)
    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Total parameters: {total_params}")

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=3e-4,
        weight_decay=0.15,
        betas=(0.9, 0.95)
    )

    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/cosmo2-tokenizer")
    tokenizer.pad_token = tokenizer.eos_token
    vocab_size = tokenizer.vocab_size

    train_loader = load_cosmopedia_dataset(
        batch_size=8,
        seq_length=512,
        tokenizer=tokenizer
    )

    t1 = time.time()

    # The same streaming loader is reused as the test_loader argument
    train_model(
        config,
        model,
        train_loader,
        train_loader,
        optimizer=optimizer,
        device=device,
        num_epochs=1,
        eval_freq=2500,
        eval_iter=2500,
        start_context="Once Upon a Time far far away in a galaxy",
        tokenizer=tokenizer
    )
    t2 = time.time()
    logger.info(f"Time taken for training: {t2 - t1:.2f} seconds")