Jen Ben Arye committed on
Commit 78757b7 · 1 Parent(s): 09e9f82

changed base model and added lora adapters

Files changed (1)
  1. ml/kto_lora.py +185 -0
ml/kto_lora.py ADDED
@@ -0,0 +1,185 @@
import os
import torch
from dataclasses import dataclass
from accelerate import PartialState
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format
from kto_dataset_processor import process_feel_dataset
from datetime import datetime
import wandb

# PEFT library: attach and load adapters
from peft import get_peft_model, PeftModel

####################################
# CONFIGURATION
####################################

@dataclass
class ScriptArguments:
    """
    Configuration for the script.
    """
    process_dataset_func: callable = process_feel_dataset  # Function to process dataset
    checkpoint_path: str = None  # Checkpoint path if needed
    push_to_hub: bool = False  # Whether to push the adapter to the HF Hub after training
    language: str = "en"  # Language identifier (e.g., "en", "fr", etc.)

@dataclass
class ModelArguments(ModelConfig):
    """
    Configuration for the model.
    """
    model_name: str = "CohereForAI/aya-expanse-8b"
    use_peft: bool = True
    lora_target_modules: str = "all-linear"
    lora_r: int = 16
    lora_alpha: int = 16
    trust_remote_code: bool = True

@dataclass
class TrainingArguments(KTOConfig):
    """
    Configuration for the KTO trainer.
    """
    output_dir: str = f"kto_{ModelArguments.model_name}_{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"
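    # NOTE: ModelArguments.model_name is read here as a class-level default; since the
    # name contains "/", the run directory nests under "kto_CohereForAI/".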
    num_train_epochs: int = 1
    per_device_train_batch_size: int = 4
    learning_rate: float = 5e-7
    lr_scheduler_type: str = "cosine"
    gradient_accumulation_steps: int = 1
    logging_steps: int = 10
    eval_steps: int = 500
    warmup_ratio: float = 0.1
    bf16: bool = True
    logging_first_step: bool = True

# Initialize configurations
script_args = ScriptArguments()
training_args = TrainingArguments()
model_args = ModelArguments()

####################################
# HELPER FUNCTIONS
####################################

def load_model_and_tokenizer(model_args):
    """
    Load the base model and tokenizer from the Hugging Face Hub.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name,
        trust_remote_code=model_args.trust_remote_code,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name,
        trust_remote_code=model_args.trust_remote_code
    )

    # Set pad token if it is missing
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Set up chat format if the tokenizer does not define one
    if not getattr(tokenizer, "chat_template", None):
        model, tokenizer = setup_chat_format(model, tokenizer)

    return model, tokenizer

####################################
# MAIN LOGIC
####################################

def main():
    # Initialize wandb for logging
    wandb.init(project="kto")

    print("Loading base model and tokenizer...")
    model, tokenizer = load_model_and_tokenizer(model_args)
    ref_model, _ = load_model_and_tokenizer(model_args)
    print("Models and tokenizer loaded.")

    # -----------------------------
    # Adapter Loading or Initialization
    # -----------------------------
    # Configure the PEFT / LoRA adapter settings
    peft_config = get_peft_config(model_args)
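    # get_peft_config builds a LoraConfig from the LoRA fields on ModelArguments
    # (r=16, lora_alpha=16, target_modules="all-linear"); it returns None when use_peft is False.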
    adapter_dir = os.path.join("adapters", script_args.language)

    if os.path.isdir(adapter_dir):
        # If an adapter for this language already exists, load it into the base model.
        model = PeftModel.from_pretrained(model, adapter_dir)
        print(f"Loaded existing adapter for language '{script_args.language}' from {adapter_dir}.")
    else:
        # Otherwise, initialize a new LoRA adapter.
        model = get_peft_model(model, peft_config)
        print(f"No adapter found for language '{script_args.language}'. Initialized new adapter.")

    # -----------------------------
    # Data Preparation and Training
    # -----------------------------
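    # process_feel_dataset is assumed to return a DatasetDict with "train" and "test"
    # splits in TRL's unpaired KTO format: rows with "prompt", "completion", and a
    # boolean "label" marking desirable vs. undesirable completions.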
    print("Processing dataset...")
    dataset = script_args.process_dataset_func()
    print("Dataset processed.")

    print("Initializing trainer...")
    trainer = KTOTrainer(
        model=model,
        ref_model=ref_model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        processing_class=tokenizer,
        peft_config=peft_config,
    )

    # Training
    print("Starting training...")
    trainer.train()
    print("Training completed.")

    # Evaluation
    print("Evaluating model...")
    metrics = trainer.evaluate()
    print(f"Metrics: {metrics}")
    trainer.log_metrics("eval", metrics)
    trainer.save_metrics("eval", metrics)

    # Log metrics to wandb
    wandb.log({
        "epoch": metrics.get("epoch"),
        "grad_norm": metrics.get("grad_norm"),
        "kl": metrics.get("kl"),
        "learning_rate": metrics.get("learning_rate"),
        "logits/chosen": metrics.get("logits/chosen"),
        "logits/rejected": metrics.get("logits/rejected"),
        "logps/chosen": metrics.get("logps/chosen"),
        "logps/rejected": metrics.get("logps/rejected"),
        "loss": metrics.get("loss"),
        "rewards/chosen": metrics.get("rewards/chosen"),
        "rewards/margins": metrics.get("rewards/margins"),
        "rewards/rejected": metrics.get("rewards/rejected"),
        "step": metrics.get("step")
    })

    # -----------------------------
    # Adapter Saving
    # -----------------------------
    print("Saving adapter...")
    os.makedirs(adapter_dir, exist_ok=True)
    model.save_pretrained(adapter_dir)
    print(f"Adapter for language '{script_args.language}' saved to: {adapter_dir}")

    if script_args.push_to_hub:
        print("Pushing adapter to Hugging Face Hub...")
        model.push_to_hub(repo_id=f"your_hf_org/{script_args.language}-adapter")

    print("Process completed.")

    # Finish wandb run
    wandb.finish()

if __name__ == "__main__":
    main()
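
For later inference, the saved LoRA adapter can be reattached to the base model with PEFT. A minimal sketch, assuming the default language "en" and the adapter directory written by this script (adapters/en):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Load the same base model the script trained against, then attach the saved adapter.
base = AutoModelForCausalLM.from_pretrained(
    "CohereForAI/aya-expanse-8b",
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("CohereForAI/aya-expanse-8b")
model = PeftModel.from_pretrained(base, "adapters/en")  # directory produced by kto_lora.py
model.eval()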