Spaces:
Running
on
Zero
Running
on
Zero
jenbenarye
commited on
Commit
·
0a375ac
1
Parent(s):
d151abe
added languages selection to kto training
Browse files- ml/kto_lora.py +26 -8
ml/kto_lora.py
CHANGED
@@ -4,9 +4,12 @@ from dataclasses import dataclass
|
|
4 |
from accelerate import PartialState
|
5 |
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
|
6 |
from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format
|
7 |
-
from kto_dataset_processor import process_feel_dataset
|
8 |
from datetime import datetime
|
9 |
import wandb
|
|
|
|
|
|
|
10 |
|
11 |
# PEFT library: attach and load adapters
|
12 |
from peft import get_peft_model, PeftModel
|
@@ -15,15 +18,28 @@ from peft import get_peft_model, PeftModel
|
|
15 |
# CONFIGURATION
|
16 |
####################################
|
17 |
|
|
|
18 |
@dataclass
|
19 |
class ScriptArguments:
|
20 |
"""
|
21 |
Configuration for the script.
|
22 |
"""
|
23 |
-
process_dataset_func: callable = process_feel_dataset
|
24 |
-
checkpoint_path: str = None
|
25 |
-
push_to_hub: bool =
|
26 |
-
language: str = "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
@dataclass
|
29 |
class ModelArguments(ModelConfig):
|
@@ -121,7 +137,7 @@ def main():
|
|
121 |
# Data Preparation and Training
|
122 |
# -----------------------------
|
123 |
print("Processing dataset...")
|
124 |
-
dataset = script_args.process_dataset_func()
|
125 |
print("Dataset processed.")
|
126 |
|
127 |
print("Initializing trainer...")
|
@@ -173,8 +189,10 @@ def main():
|
|
173 |
print(f"Adapter for language '{script_args.language}' saved to: {adapter_dir}")
|
174 |
|
175 |
if script_args.push_to_hub:
|
176 |
-
|
177 |
-
|
|
|
|
|
178 |
|
179 |
print("Process completed.")
|
180 |
|
|
|
4 |
from accelerate import PartialState
|
5 |
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser
|
6 |
from trl import KTOConfig, KTOTrainer, ModelConfig, get_peft_config, maybe_unpair_preference_dataset, setup_chat_format
|
7 |
+
from kto_dataset_processor import process_feel_dataset, SupportedLanguages
|
8 |
from datetime import datetime
|
9 |
import wandb
|
10 |
+
from enum import Enum
|
11 |
+
from typing import Optional
|
12 |
+
|
13 |
|
14 |
# PEFT library: attach and load adapters
|
15 |
from peft import get_peft_model, PeftModel
|
|
|
18 |
# CONFIGURATION
|
19 |
####################################
|
20 |
|
21 |
+
|
22 |
@dataclass
|
23 |
class ScriptArguments:
|
24 |
"""
|
25 |
Configuration for the script.
|
26 |
"""
|
27 |
+
process_dataset_func: callable = process_feel_dataset
|
28 |
+
checkpoint_path: str = None
|
29 |
+
push_to_hub: bool = True
|
30 |
+
language: str = "English" # Default to English
|
31 |
+
|
32 |
+
def __post_init__(self):
|
33 |
+
"""Validate the language after initialization"""
|
34 |
+
try:
|
35 |
+
# This will raise ValueError if language is not in the enum
|
36 |
+
SupportedLanguages(self.language)
|
37 |
+
except ValueError:
|
38 |
+
supported_langs = "\n- ".join([lang.value for lang in SupportedLanguages])
|
39 |
+
raise ValueError(
|
40 |
+
f"Invalid language: '{self.language}'\n"
|
41 |
+
f"Supported languages are:\n- {supported_langs}"
|
42 |
+
)
|
43 |
|
44 |
@dataclass
|
45 |
class ModelArguments(ModelConfig):
|
|
|
137 |
# Data Preparation and Training
|
138 |
# -----------------------------
|
139 |
print("Processing dataset...")
|
140 |
+
dataset = script_args.process_dataset_func(script_args.language)
|
141 |
print("Dataset processed.")
|
142 |
|
143 |
print("Initializing trainer...")
|
|
|
189 |
print(f"Adapter for language '{script_args.language}' saved to: {adapter_dir}")
|
190 |
|
191 |
if script_args.push_to_hub:
|
192 |
+
# Using a consistent naming pattern that links to the FEEL project
|
193 |
+
repo_id = f"feel-fl/kto-lora-adapter-{script_args.language}"
|
194 |
+
print(f"Pushing adapter to Hugging Face Hub at {repo_id}...")
|
195 |
+
model.push_to_hub(repo_id=repo_id)
|
196 |
|
197 |
print("Process completed.")
|
198 |
|