import os
import argparse
import json
from datetime import datetime
from typing import Dict, List, Any

try:
    import datasets
    from transformers import AutoTokenizer, TrainingArguments
    from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
    from trl import SFTTrainer
    import torch
except ImportError:
    print("Installing required packages...")
    import subprocess
    subprocess.check_call(["pip", "install", 
                           "transformers>=4.36.0", 
                           "peft>=0.7.0", 
                           "datasets>=2.14.0",
                           "accelerate>=0.25.0",
                           "trl>=0.7.1",
                           "bitsandbytes>=0.40.0",
                           "torch>=2.0.0"])
    import datasets
    from transformers import AutoTokenizer, TrainingArguments
    from peft import get_peft_model, LoraConfig, prepare_model_for_kbit_training
    from trl import SFTTrainer
    import torch

def load_model_and_tokenizer(model_name_or_path: str,
                             adapter_path: str = None,
                             quantize: bool = True,
                             token: str = None):
    """
    Load the model and tokenizer, with optional adapter and quantization.

    This loads the model in 4-bit quantization by default (which is needed
    for such a large model) and can optionally load an existing adapter.
    """
    from transformers import BitsAndBytesConfig, AutoModelForCausalLM
    
    print(f"Loading model: {model_name_or_path}")
    
    # Configure for quantization
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=quantize,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True
    ) if quantize else None
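    # NF4 with double quantization and fp16 compute is the standard QLoRA
    # recipe: the frozen base weights are stored in 4-bit, which roughly
    # quarters their memory footprint, while the LoRA adapters added later
    # train in higher precision.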
    
    # Load the model
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        quantization_config=quantization_config,
        device_map="auto",
        token=token
    )
    
    # Load adapter if provided
    if adapter_path:
        print(f"Loading adapter from {adapter_path}")
        from peft import PeftModel
        model = PeftModel.from_pretrained(model, adapter_path)
    
    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, token=token)
    
    # Ensure we have a pad token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    
    return model, tokenizer
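
# Standalone usage sketch (the adapter path below is a placeholder, not part
# of the training flow in this script):
#   model, tokenizer = load_model_and_tokenizer(
#       "nvidia/Llama-3_3-Nemotron-Super-49B-v1",
#       adapter_path="path/to/existing/adapter",  # or None for the base model
#       quantize=True,
#   )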

def prepare_dataset(data_path: str):
    """Load and prepare datasets from JSON files."""
    # Load datasets
    if os.path.isdir(data_path):
        train_path = os.path.join(data_path, "train.json")
        val_path = os.path.join(data_path, "validation.json")
        
        if not (os.path.exists(train_path) and os.path.exists(val_path)):
            raise ValueError(f"Training data files not found in {data_path}")
    else:
        raise ValueError(f"Data path {data_path} is not a directory")
    
    # Load JSON files
    with open(train_path, 'r', encoding='utf-8') as f:
        train_data = json.load(f)
    
    with open(val_path, 'r', encoding='utf-8') as f:
        val_data = json.load(f)
    
    # Convert to datasets
    train_dataset = datasets.Dataset.from_list(train_data)
    eval_dataset = datasets.Dataset.from_list(val_data)
    
    print(f"Loaded {len(train_dataset)} training examples and {len(eval_dataset)} validation examples")
    return train_dataset, eval_dataset
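
# The JSON files are assumed to be flat lists of records, e.g. (illustrative
# only -- the actual field names depend on how the data was exported and on
# what SFTTrainer is configured to read):
#   [{"text": "### Instruction:\n...\n### Response:\n..."}, ...]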

def finetune(
    model_name: str,
    dataset_path: str,
    output_dir: str,
    hub_model_id: str = None,
    hf_token: str = None,
    use_peft: bool = True,
    num_train_epochs: int = 3,
    learning_rate: float = 2e-5,
    bf16: bool = True,
    quantize: bool = True,
    max_seq_length: int = 2048,
    gradient_accumulation_steps: int = 2
):
    """Fine-tune the model with PEFT on the provided dataset."""
    # Set up output directory
    if not output_dir:
        output_dir = f"llama3-finetuned-{datetime.now().strftime('%Y-%m-%d-%H-%M')}"
    os.makedirs(output_dir, exist_ok=True)
    
    # Load datasets
    train_dataset, eval_dataset = prepare_dataset(dataset_path)
    
    # Load base model
    model, tokenizer = load_model_and_tokenizer(
        model_name, 
        quantize=quantize,
        token=hf_token
    )
    
    # Set up PEFT configuration if using PEFT
    if use_peft:
        print("Setting up PEFT (Parameter-Efficient Fine-Tuning)")
        
        # Prepare model for k-bit training if quantized
        if quantize:
            model = prepare_model_for_kbit_training(model)
        
        # Set up LoRA configuration
        peft_config = LoraConfig(
            r=16,  # Rank dimension
            lora_alpha=32,  # Scale parameter
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=[
                "q_proj", 
                "k_proj", 
                "v_proj", 
                "o_proj", 
                "gate_proj", 
                "up_proj", 
                "down_proj"
            ]
        )
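        # With r=16 and lora_alpha=32 the effective LoRA scaling is alpha/r = 2;
        # the target modules cover every attention and MLP projection in
        # Llama-style blocks. These are common starting values, not tuned ones.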
    else:
        peft_config = None
    
    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=1,  # Adjust based on GPU memory
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        weight_decay=0.01,
        max_grad_norm=0.3,
        logging_steps=10,
        optim="paged_adamw_32bit",
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        evaluation_strategy="steps",
        eval_steps=0.1,  # Evaluate every 10% of training
        save_strategy="steps",
        save_steps=0.1,  # Save every 10% of training
        save_total_limit=3,
        bf16=bf16,  # Use bfloat16 precision if available
        push_to_hub=bool(hub_model_id),
        hub_model_id=hub_model_id,
        hub_token=hf_token,
    )
    
    # Initialize the SFT trainer
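    # Note (assumption about the data format): depending on the trl version,
    # SFTTrainer may need an explicit `dataset_text_field` or `formatting_func`
    # if the records are not in a column/chat format it can detect on its own.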
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=peft_config,
        tokenizer=tokenizer,
        max_seq_length=max_seq_length,
    )
    
    # Train the model
    print("Starting training...")
    trainer.train()
    
    # Save the fine-tuned model
    print(f"Saving model to {output_dir}")
    trainer.save_model()
    
    # Push to hub if specified
    if hub_model_id and hf_token:
        print(f"Pushing model to Hugging Face Hub: {hub_model_id}")
        trainer.push_to_hub()
    
    return output_dir

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Fine-tune Llama 3.3 with your data")
    parser.add_argument("--model_name", type=str, default="nvidia/Llama-3_3-Nemotron-Super-49B-v1", 
                        help="Base model to fine-tune")
    parser.add_argument("--dataset_path", type=str, required=True, 
                        help="Path to the directory containing train.json and validation.json")
    parser.add_argument("--output_dir", type=str, default=None, 
                        help="Directory to save the fine-tuned model")
    parser.add_argument("--hub_model_id", type=str, default=None, 
                        help="Hugging Face Hub model ID to push the model to")
    parser.add_argument("--hf_token", type=str, default=None, 
                        help="Hugging Face token for accessing gated models and pushing to hub")
    parser.add_argument("--no_peft", action='store_true', 
                        help="Disable PEFT/LoRA (not recommended for large models)")
    parser.add_argument("--no_quantize", action='store_true', 
                        help="Disable quantization (requires much more VRAM)")
    parser.add_argument("--no_bf16", action='store_true', 
                        help="Disable bf16 precision")
    parser.add_argument("--epochs", type=int, default=3, 
                        help="Number of training epochs")
    parser.add_argument("--learning_rate", type=float, default=2e-5, 
                        help="Learning rate")
    parser.add_argument("--max_seq_length", type=int, default=2048, 
                        help="Maximum sequence length for training")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=2, 
                        help="Gradient accumulation steps")
    
    args = parser.parse_args()
    
    # Get token from environment if not provided
    hf_token = args.hf_token or os.environ.get("HF_TOKEN")
    
    finetune(
        model_name=args.model_name,
        dataset_path=args.dataset_path,
        output_dir=args.output_dir,
        hub_model_id=args.hub_model_id,
        hf_token=hf_token,
        use_peft=not args.no_peft,
        num_train_epochs=args.epochs,
        learning_rate=args.learning_rate,
        bf16=not args.no_bf16,
        quantize=not args.no_quantize,
        max_seq_length=args.max_seq_length,
        gradient_accumulation_steps=args.gradient_accumulation_steps
    )
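
# Example invocation (script name, paths, and IDs below are placeholders):
#   python finetune.py \
#       --dataset_path ./data \
#       --output_dir ./llama3-finetuned \
#       --hub_model_id your-username/llama3-finetuned \
#       --epochs 3 --learning_rate 2e-5 --max_seq_length 2048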