mjschock committed
Commit aecd650 · unverified · 1 Parent(s): 7d4f8c8

Refactor model loading in train.py to use a default model name parameter, enhancing flexibility. Adjust configuration for max sequence length and dtype for improved clarity and consistency.

Files changed (1)
  1. train.py +5 -6
train.py CHANGED
@@ -41,11 +41,10 @@ from transformers import (
 from trl import SFTTrainer
 
 # Configuration
-max_seq_length = 2048  # Auto supports RoPE Scaling internally
-dtype = (
-    None  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
-)
+DEFAULT_MODEL_NAME = "unsloth/SmolLM2-135M-Instruct-bnb-4bit"
+dtype = None  # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
 load_in_4bit = True  # Use 4bit quantization to reduce memory usage
+max_seq_length = 2048  # Auto supports RoPE Scaling internally
 validation_split = 0.1  # 10% of data for validation
 
@@ -89,12 +88,12 @@ def install_dependencies():
         raise
 
 
-def load_model() -> tuple[FastLanguageModel, AutoTokenizer]:
+def load_model(model_name: str = DEFAULT_MODEL_NAME) -> tuple[FastLanguageModel, AutoTokenizer]:
     """Load and configure the model."""
     logger.info("Loading model and tokenizer...")
     try:
         model, tokenizer = FastLanguageModel.from_pretrained(
-            model_name="unsloth/SmolLM2-135M-Instruct-bnb-4bit",
+            model_name=model_name,
             max_seq_length=max_seq_length,
             dtype=dtype,
             load_in_4bit=load_in_4bit,
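With this change, a caller can load a different checkpoint without editing the module-level configuration. A minimal usage sketch, assuming train.py is importable and its dependencies (unsloth, trl) are installed; the override checkpoint name below is purely illustrative and not part of the commit:

from train import load_model

# Default call: loads DEFAULT_MODEL_NAME ("unsloth/SmolLM2-135M-Instruct-bnb-4bit")
model, tokenizer = load_model()

# Hypothetical override: any unsloth-compatible 4-bit checkpoint name could go here
model, tokenizer = load_model(model_name="unsloth/SmolLM2-360M-Instruct-bnb-4bit")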