# NOTE(review): removed non-Python file-viewer artifacts ("Spaces", sleep
# status, file size, commit hash, and a line-number gutter) that preceded
# the module source and would make the file unimportable.
import os
from dataclasses import dataclass, field
from typing import Any
import llm_studio.src.datasets.text_dpo_modeling_ds
from llm_studio.app_utils.config import default_cfg
from llm_studio.python_configs.base import DefaultConfigProblemBase
from llm_studio.python_configs.text_causal_language_modeling_config import (
ConfigNLPAugmentation,
ConfigNLPCausalLMArchitecture,
ConfigNLPCausalLMDataset,
ConfigNLPCausalLMEnvironment,
ConfigNLPCausalLMLogging,
ConfigNLPCausalLMPrediction,
ConfigNLPCausalLMTokenizer,
ConfigNLPCausalLMTraining,
)
from llm_studio.src import possible_values
from llm_studio.src.losses import text_dpo_modeling_losses
from llm_studio.src.models import text_dpo_modeling_model
from llm_studio.src.nesting import Dependency
from llm_studio.src.plots import text_dpo_modeling_plots
from llm_studio.src.utils.modeling_utils import generate_experiment_name
@dataclass
class ConfigDPODataset(ConfigNLPCausalLMDataset):
    """Dataset configuration for DPO preference modeling.

    Extends the causal-LM dataset config with chosen/rejected response
    columns and forces full chat history, since the preference pair only
    appears at the end of a conversation.
    """

    dataset_class: Any = llm_studio.src.datasets.text_dpo_modeling_ds.CustomDataset
    # Always have full chat history.
    # Chosen/Rejected prompt are only at the end of a conversation.
    limit_chained_samples: bool = True
    rejected_prompt_column: str = "None"
    answer_column: str = "chosen_response"
    rejected_answer_column: str = "rejected_response"

    def __post_init__(self):
        super().__post_init__()

        # Column names that should be auto-suggested for the rejected side.
        preferred_prompt_names = (
            "rejected_input",
            "rejected_prompt",
            "rejected_instruction",
            "rejected_question",
        )
        preferred_answer_names = (
            "rejected_answer",
            "rejected_response",
            "rejected",
        )

        self._possible_values["rejected_prompt_column"] = possible_values.Columns(
            prefer_with=lambda col: col in preferred_prompt_names,
            add_none=True,
        )
        self._possible_values["rejected_answer_column"] = possible_values.Columns(
            prefer_with=lambda col: col in preferred_answer_names
        )

        # These settings are fixed for DPO, so hide them in the UI.
        self._visibility["limit_chained_samples"] = -1
        self._visibility["mask_prompt_labels"] = -1

        # Show each rejected column right after its chosen counterpart.
        self._order.insert("rejected_prompt_column", after="prompt_column")
        self._order.insert("rejected_answer_column", after="answer_column")
@dataclass
class ConfigDPOTraining(ConfigNLPCausalLMTraining):
    """Training configuration for DPO/SimPO preference optimization."""

    learning_rate: float = 1e-4  # relatively high as we use LORA
    beta: float = 0.2
    simpo_gamma: float = 1.0
    gradient_clip: float = 10.0
    loss_class: Any = text_dpo_modeling_losses.Losses
    loss_function: str = "DPOLoss"
    optimizer: str = "AdamW"
    # Needs to be enabled as we need logits from original model, see forward pass
    lora: bool = True

    def __post_init__(self):
        super().__post_init__()

        # Tunable ranges for the preference-loss hyperparameters.
        self._possible_values["simpo_gamma"] = possible_values.Number(0.05, 2.0, 0.05)
        self._possible_values["beta"] = possible_values.Number(0.05, 1.0, 0.05)

        # Grid-search setup: custom grids for beta/gamma, none for the loss.
        self._grid_search_values["loss_function"] = None
        self._grid_search_values["beta"] = (0.1, 0.15, 0.20, 0.25, 0.4, 0.5)
        self._grid_search_values["simpo_gamma"] = (0.5, 0.75, 1, 1.25, 1.5, 1.75, 2)
        for param in ("beta", "simpo_gamma"):
            self._grid_search_iscustom[param] = True

        # simpo_gamma only applies when the SimPO loss is selected.
        self._nesting.add(
            ["simpo_gamma"],
            [Dependency(key="loss_function", value="SimPOLoss", is_set=True)],
        )

        # Keep the preference hyperparameters next to the learning rate.
        self._order.insert("beta", after="learning_rate")
        self._order.insert("simpo_gamma", after="beta")
@dataclass
class ConfigDPOArchitecture(ConfigNLPCausalLMArchitecture):
    """Architecture config that swaps in the DPO-specific model class."""

    model_class: Any = text_dpo_modeling_model.Model
@dataclass
class ConfigDPOPLogging(ConfigNLPCausalLMLogging):
    """Logging config that swaps in the DPO-specific plots class."""

    plots_class: Any = text_dpo_modeling_plots.Plots
@dataclass
class ConfigProblemBase(DefaultConfigProblemBase):
    """Top-level experiment configuration for DPO preference modeling.

    Wires together the DPO-specific dataset, training, architecture, and
    logging configs with the shared causal-LM tokenizer/prediction/
    environment/augmentation configs.
    """

    # Output folder is named after this module file.
    output_directory: str = f"output/{os.path.basename(__file__).split('.')[0]}"
    experiment_name: str = field(default_factory=generate_experiment_name)
    # Prefer the danube3 chat model when available, otherwise fall back to
    # the first configured default backbone.
    llm_backbone: str = (
        "h2oai/h2o-danube3-500m-chat"
        if "h2oai/h2o-danube3-500m-chat" in default_cfg.default_causal_language_models
        else default_cfg.default_causal_language_models[0]
    )

    dataset: ConfigDPODataset = field(default_factory=ConfigDPODataset)
    tokenizer: ConfigNLPCausalLMTokenizer = field(
        default_factory=ConfigNLPCausalLMTokenizer
    )
    architecture: ConfigDPOArchitecture = field(default_factory=ConfigDPOArchitecture)
    training: ConfigDPOTraining = field(default_factory=ConfigDPOTraining)
    augmentation: ConfigNLPAugmentation = field(default_factory=ConfigNLPAugmentation)
    prediction: ConfigNLPCausalLMPrediction = field(
        default_factory=ConfigNLPCausalLMPrediction
    )
    environment: ConfigNLPCausalLMEnvironment = field(
        default_factory=ConfigNLPCausalLMEnvironment
    )
    logging: ConfigDPOPLogging = field(default_factory=ConfigDPOPLogging)

    def __post_init__(self):
        super().__post_init__()

        # Backbone is selectable from the configured defaults, with free-form
        # entry allowed for custom models.
        self._possible_values["llm_backbone"] = possible_values.String(
            values=default_cfg.default_causal_language_models,
            allow_custom=True,
        )
        # Output directory is derived, not user-facing.
        self._visibility["output_directory"] = -1