Refactor model loading in train.py to use a default model name parameter, enhancing flexibility. Adjust configuration for max sequence length and dtype for improved clarity and consistency.
train.py CHANGED
@@ -41,11 +41,10 @@ from transformers import (
 from trl import SFTTrainer
 
 # Configuration
-
-dtype = (
-    None  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
-)
+DEFAULT_MODEL_NAME = "unsloth/SmolLM2-135M-Instruct-bnb-4bit"
+dtype = None  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
 load_in_4bit = True  # Use 4bit quantization to reduce memory usage
+max_seq_length = 2048  # Auto supports RoPE Scaling internally
 validation_split = 0.1  # 10% of data for validation
 
 
@@ -89,12 +88,12 @@ def install_dependencies():
         raise
 
 
-def load_model() -> tuple[FastLanguageModel, AutoTokenizer]:
+def load_model(model_name: str = DEFAULT_MODEL_NAME) -> tuple[FastLanguageModel, AutoTokenizer]:
     """Load and configure the model."""
     logger.info("Loading model and tokenizer...")
     try:
         model, tokenizer = FastLanguageModel.from_pretrained(
-            model_name="unsloth/SmolLM2-135M-Instruct-bnb-4bit",
+            model_name=model_name,
             max_seq_length=max_seq_length,
             dtype=dtype,
             load_in_4bit=load_in_4bit,
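
With this change, load_model() falls back to DEFAULT_MODEL_NAME but lets callers pass any Unsloth-compatible checkpoint. A minimal usage sketch of the new signature (the override checkpoint name below is only an illustration, not something this commit references):

    # Load the default 4-bit SmolLM2 checkpoint via DEFAULT_MODEL_NAME
    model, tokenizer = load_model()

    # Or override per run, e.g. when wiring the name to a CLI flag
    # (illustrative checkpoint, not from this commit)
    model, tokenizer = load_model(model_name="unsloth/Llama-3.2-1B-Instruct-bnb-4bit")

Both calls pick up the shared max_seq_length, dtype, and load_in_4bit values from the module-level configuration block above.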