jenbenarye commited on
Commit
0a375ac
·
1 Parent(s): d151abe

added languages selection to kto training

Browse files
Files changed (1) hide show
  1. ml/kto_lora.py +26 -8
ml/kto_lora.py CHANGED
@@ -4,9 +4,12 @@ from dataclasses import dataclass
4
  from accelerate import PartialState
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
6
  from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format
7
- from kto_dataset_processor import process_feel_dataset
8
  from datetime import datetime
9
  import wandb
 
 
 
10
 
11
  # PEFT library: attach and load adapters
12
  from peft import get_peft_model, PeftModel
@@ -15,15 +18,28 @@ from peft import get_peft_model, PeftModel
15
  # CONFIGURATION
16
  ####################################
17
 
 
18
  @dataclass
19
  class ScriptArguments:
20
  """
21
  Configuration for the script.
22
  """
23
- process_dataset_func: callable = process_feel_dataset # Function to process dataset
24
- checkpoint_path: str = None # Checkpoint path if needed
25
- push_to_hub: bool = False # Whether to push the adapter to the HF Hub after training
26
- language: str = "en" # Language identifier (e.g., "en", "fr", etc.)
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
  @dataclass
29
  class ModelArguments(ModelConfig):
@@ -121,7 +137,7 @@ def main():
121
  # Data Preparation and Training
122
  # -----------------------------
123
  print("Processing dataset...")
124
- dataset = script_args.process_dataset_func()
125
  print("Dataset processed.")
126
 
127
  print("Initializing trainer...")
@@ -173,8 +189,10 @@ def main():
173
  print(f"Adapter for language '{script_args.language}' saved to: {adapter_dir}")
174
 
175
  if script_args.push_to_hub:
176
- print("Pushing adapter to Hugging Face Hub...")
177
- model.push_to_hub(repo_id=f"your_hf_org/{script_args.language}-adapter")
 
 
178
 
179
  print("Process completed.")
180
 
 
4
  from accelerate import PartialState
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
6
  from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format
7
+ from kto_dataset_processor import process_feel_dataset, SupportedLanguages
8
  from datetime import datetime
9
  import wandb
10
+ from enum import Enum
11
+ from typing import Optional
12
+
13
 
14
  # PEFT library: attach and load adapters
15
  from peft import get_peft_model, PeftModel
 
18
  # CONFIGURATION
19
  ####################################
20
 
21
+
22
  @dataclass
23
  class ScriptArguments:
24
  """
25
  Configuration for the script.
26
  """
27
+ process_dataset_func: callable = process_feel_dataset
28
+ checkpoint_path: str = None
29
+ push_to_hub: bool = True
30
+ language: str = "English" # Default to English
31
+
32
+ def __post_init__(self):
33
+ """Validate the language after initialization"""
34
+ try:
35
+ # This will raise ValueError if language is not in the enum
36
+ SupportedLanguages(self.language)
37
+ except ValueError:
38
+ supported_langs = "\n- ".join([lang.value for lang in SupportedLanguages])
39
+ raise ValueError(
40
+ f"Invalid language: '{self.language}'\n"
41
+ f"Supported languages are:\n- {supported_langs}"
42
+ )
43
 
44
  @dataclass
45
  class ModelArguments(ModelConfig):
 
137
  # Data Preparation and Training
138
  # -----------------------------
139
  print("Processing dataset...")
140
+ dataset = script_args.process_dataset_func(script_args.language)
141
  print("Dataset processed.")
142
 
143
  print("Initializing trainer...")
 
189
  print(f"Adapter for language '{script_args.language}' saved to: {adapter_dir}")
190
 
191
  if script_args.push_to_hub:
192
+ # Using a consistent naming pattern that links to the FEEL project
193
+ repo_id = f"feel-fl/kto-lora-adapter-{script_args.language}"
194
+ print(f"Pushing adapter to Hugging Face Hub at {repo_id}...")
195
+ model.push_to_hub(repo_id=repo_id)
196
 
197
  print("Process completed.")
198