George-API committed on
Commit 467f05c · verified · 1 Parent(s): 87d150b

Upload run_cloud_training.py with huggingface_hub

Files changed (1)
  1. run_cloud_training.py +353 -751
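The commit message indicates the file was pushed with the huggingface_hub client. Below is a minimal sketch of such an upload call, not the exact command used for this commit: the repo_id and repo_type values are assumptions for illustration, since the target repository is not named on this page, and authentication is expected to come from a prior huggingface-cli login or an explicit token argument.

# Minimal sketch of a huggingface_hub upload (repo_id/repo_type are assumed, not shown on this page)
from huggingface_hub import HfApi

api = HfApi()
api.upload_file(
    path_or_fileobj="run_cloud_training.py",   # local file to push
    path_in_repo="run_cloud_training.py",      # destination path inside the repo
    repo_id="George-API/finetune-space",       # assumed repository name
    repo_type="space",                         # assumed: a Hugging Face Space
    commit_message="Upload run_cloud_training.py with huggingface_hub",
)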
run_cloud_training.py CHANGED
@@ -1,751 +1,353 @@
1
- #!/usr/bin/env python
2
- # -*- coding: utf-8 -*-
3
-
4
- """
5
- Fine-tuning script for DeepSeek-R1-Distill-Qwen-14B-unsloth-bnb-4bit using unsloth
6
- RESEARCH TRAINING PHASE ONLY - No output generation
7
- WORKS WITH PRE-TOKENIZED DATASET - No re-tokenization
8
- OPTIMIZED FOR L40S GPU (48GB VRAM)
9
- SUPPORTS ENVIRONMENTS WITHOUT MPI
10
- """
11
-
12
- # Set critical environment variables before any imports
13
- import os
14
- # Configure PyTorch memory allocator for better memory management with L40S GPU
15
- os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:256"
16
- os.environ["XFORMERS_DISABLED"] = "1"
17
- os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
18
- # L40S-specific CUDA optimization
19
- os.environ["CUDA_AUTO_BOOST"] = "1"
20
-
21
- # Completely disable DeepSpeed for Hugging Face Spaces to avoid compatibility issues
22
- os.environ["DISABLE_DEEPSPEED"] = "1"
23
-
24
- import json
25
- import logging
26
- import argparse
27
- import numpy as np
28
- from dotenv import load_dotenv
29
- import torch
30
- import sys
31
- from datasets import load_dataset
32
- import transformers
33
- from transformers import AutoTokenizer, TrainingArguments, Trainer, AutoModelForCausalLM, AutoConfig
34
- from transformers.data.data_collator import DataCollatorMixin
35
- from peft import LoraConfig
36
- from unsloth import FastLanguageModel
37
-
38
- # Configure logging first (before any potential errors with imports)
39
- logging.basicConfig(
40
- level=logging.INFO,
41
- format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
42
- handlers=[
43
- logging.StreamHandler(),
44
- logging.FileHandler("training.log")
45
- ]
46
- )
47
- logger = logging.getLogger(__name__)
48
-
49
- # Set up environment variables
50
- os.environ["MASTER_ADDR"] = "localhost"
51
- os.environ["MASTER_PORT"] = "9994"
52
- os.environ["RANK"] = "0"
53
- os.environ["LOCAL_RANK"] = "0"
54
- os.environ["WORLD_SIZE"] = "1"
55
-
56
- # DeepSpeed is disabled for Hugging Face Spaces due to compatibility issues
57
- logger.info("DeepSpeed is disabled for Hugging Face Spaces to avoid compatibility issues")
58
- deepspeed_available = False
59
-
60
- # Disable all attention optimizations that might cause issues
61
- os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
62
- os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
63
- os.environ["XFORMERS_DISABLED"] = "1"
64
-
65
- # Completely disable xformers by removing it from sys.modules if it's loaded
66
- if 'xformers' in sys.modules:
67
- del sys.modules['xformers']
68
- if 'xformers.ops' in sys.modules:
69
- del sys.modules['xformers.ops']
70
-
71
- # Patch Python's import system to prevent xformers from being imported
72
- class XFormersBlocker:
73
- def __init__(self, original_importer):
74
- self.original_importer = original_importer
75
-
76
- def find_spec(self, fullname, path, target=None):
77
- if 'xformers' in fullname:
78
- # Block xformers imports
79
- return None
80
- # Use the original importer for everything else
81
- return self.original_importer.find_spec(fullname, path, target)
82
-
83
- # Add our import blocker to sys.meta_path
84
- sys.meta_path.insert(0, XFormersBlocker(sys.meta_path[0]))
85
-
86
- # Make sure torch is installed and available before proceeding
87
- try:
88
- logger.info("Importing torch...")
89
- import torch
90
- logger.info(f"PyTorch version: {torch.__version__}")
91
- logger.info(f"CUDA available: {torch.cuda.is_available()}")
92
- if torch.cuda.is_available():
93
- logger.info(f"CUDA version: {torch.version.cuda}")
94
- logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
95
- except ImportError:
96
- logger.error("PyTorch not found. Installing torch first...")
97
- try:
98
- import subprocess
99
- import sys
100
- subprocess.check_call([sys.executable, "-m", "pip", "install", "torch"])
101
- logger.info("PyTorch installed successfully. Importing...")
102
- import torch
103
- logger.info(f"PyTorch version: {torch.__version__}")
104
- except Exception as e:
105
- logger.error(f"Failed to install PyTorch: {e}")
106
- logger.error("Cannot proceed without PyTorch. Exiting.")
107
- raise
108
-
109
- # Now try to install flash-attention (for systems that support it)
110
- try:
111
- import subprocess
112
- import sys
113
-
114
- # Make sure torch is installed before attempting flash-attn
115
- try:
116
- logger.info("Ensuring PyTorch is installed before flash-attention...")
117
- subprocess.check_call([sys.executable, "-m", "pip", "install", "torch", "--quiet"])
118
- logger.info("PyTorch installation verified")
119
- except Exception as torch_error:
120
- logger.warning(f"PyTorch installation check failed: {torch_error}")
121
- logger.info("Will continue with flash-attention installation anyway")
122
-
123
- logger.info("Attempting to install flash-attention...")
124
-
125
- # Try multiple installation approaches for flash-attention
126
- try:
127
- # First try with pip install
128
- logger.info("Trying standard pip install for flash-attn")
129
- subprocess.check_call([sys.executable, "-m", "pip", "install", "flash-attn"])
130
- except Exception as pip_error:
131
- logger.warning(f"Standard installation failed: {pip_error}")
132
- logger.info("Trying alternative installation approach...")
133
-
134
- # Try the PIP_EXTRA_INDEX_URL approach
135
- env = os.environ.copy()
136
- if "PIP_EXTRA_INDEX_URL" not in env:
137
- env["PIP_EXTRA_INDEX_URL"] = "https://download.pytorch.org/whl/cu118"
138
-
139
- subprocess.check_call(
140
- [sys.executable, "-m", "pip", "install", "flash-attn"],
141
- env=env
142
- )
143
-
144
- logger.info("Successfully installed flash-attention")
145
- except Exception as e:
146
- logger.warning(f"Failed to install flash-attention: {e}")
147
- logger.info("Continuing without flash-attention")
148
-
149
- # Check if flash attention was successfully installed
150
- flash_attention_available = False
151
- try:
152
- import flash_attn
153
- flash_attention_available = True
154
- logger.info(f"Flash Attention will be used (version: {flash_attn.__version__})")
155
- # We'll handle flash attention configuration during model loading
156
- except ImportError:
157
- logger.info("Flash Attention not available, will use standard attention mechanism")
158
-
159
- # Check if tensorboard is available
160
- try:
161
- import tensorboard
162
- TENSORBOARD_AVAILABLE = True
163
- except ImportError:
164
- TENSORBOARD_AVAILABLE = False
165
- print("Tensorboard not available. Will skip tensorboard logging.")
166
-
167
- # Default dataset path - use the correct path with username
168
- DEFAULT_DATASET = "George-API/phi4-cognitive-dataset"
169
-
170
- def load_config(config_path):
171
- """Load the transformers config from JSON file"""
172
- logger.info(f"Loading config from {config_path}")
173
- with open(config_path, 'r') as f:
174
- config = json.load(f)
175
- return config
176
-
177
- def load_and_prepare_dataset(dataset_name, config):
178
- """
179
- Load and prepare the dataset for fine-tuning.
180
- Sort entries by prompt_number as required.
181
- Handles both pre-tokenized and string content.
182
- """
183
- # Use the default dataset path if no specific path is provided
184
- if dataset_name == "phi4-cognitive-dataset":
185
- dataset_name = DEFAULT_DATASET
186
-
187
- logger.info(f"Loading dataset: {dataset_name}")
188
-
189
- try:
190
- # Load dataset
191
- dataset = load_dataset(dataset_name)
192
-
193
- # Extract the split we want to use (usually 'train')
194
- if 'train' in dataset:
195
- dataset = dataset['train']
196
-
197
- # Get the dataset config
198
- dataset_config = config.get("dataset_config", {})
199
- sort_field = dataset_config.get("sort_by_field", "prompt_number")
200
-
201
- # Always sort in ascending order by prompt_number
202
- logger.info(f"Sorting dataset by {sort_field} in ascending order")
203
- dataset = dataset.sort(sort_field)
204
-
205
- # Verify sorting
206
- if len(dataset) > 1:
207
- first_prompt = dataset[0].get(sort_field, None)
208
- last_prompt = dataset[-1].get(sort_field, None)
209
- logger.info(f"Dataset sorted: first {sort_field}={first_prompt}, last {sort_field}={last_prompt}")
210
-
211
- # Additional verification of a few samples
212
- sample_indices = [0, len(dataset)//2, len(dataset)-1]
213
- sample_prompts = [dataset[i].get(sort_field, None) for i in sample_indices]
214
- logger.info(f"Sample prompt numbers: {sample_prompts}")
215
-
216
- # Verify order is ascending
217
- if not all(sample_prompts[i] <= sample_prompts[i+1] for i in range(len(sample_prompts)-1)):
218
- logger.warning("Dataset may not be properly sorted! Please check the ordering.")
219
-
220
- # Print dataset structure for debugging
221
- logger.info(f"Dataset loaded with {len(dataset)} entries")
222
- logger.info(f"Dataset columns: {dataset.column_names}")
223
-
224
- # Print a sample entry to understand structure
225
- if len(dataset) > 0:
226
- sample = dataset[0]
227
- logger.info(f"Sample entry structure: {list(sample.keys())}")
228
-
229
- # Check if dataset is pre-tokenized or contains string content
230
- is_pre_tokenized = False
231
-
232
- if 'input_ids' in sample and isinstance(sample['input_ids'], list) and all(isinstance(x, int) for x in sample['input_ids']):
233
- logger.info("Dataset appears to be pre-tokenized with input_ids field")
234
- is_pre_tokenized = True
235
- elif 'conversations' in sample:
236
- logger.info(f"Sample conversations structure: {sample['conversations'][:1]}")
237
-
238
- # Check if conversations contain pre-tokenized data
239
- if isinstance(sample['conversations'], list) and len(sample['conversations']) > 0:
240
- conv = sample['conversations'][0]
241
- if isinstance(conv, dict) and 'input_ids' in conv and isinstance(conv['input_ids'], list):
242
- logger.info("Dataset appears to be pre-tokenized in conversations.input_ids")
243
- is_pre_tokenized = True
244
- elif isinstance(conv, dict) and 'content' in conv:
245
- content = conv['content']
246
- if isinstance(content, list) and all(isinstance(x, int) for x in content):
247
- logger.info("Dataset appears to be pre-tokenized in conversations.content")
248
- is_pre_tokenized = True
249
- else:
250
- logger.info("Dataset appears to contain string content that will need tokenization")
251
-
252
- if is_pre_tokenized:
253
- logger.info("Using pre-tokenized dataset - tokenizer will only be used as fallback")
254
- else:
255
- logger.info("Dataset contains string content - tokenizer will be used")
256
-
257
- return dataset
258
-
259
- except Exception as e:
260
- logger.error(f"Error loading dataset: {str(e)}")
261
- logger.info("Available datasets in the Hub:")
262
- # Print a more helpful error message
263
- print(f"Failed to load dataset: {dataset_name}")
264
- print(f"Make sure the dataset exists and is accessible.")
265
- print(f"If it's a private dataset, ensure your HF_TOKEN has access to it.")
266
- raise
267
-
268
- def tokenize_string(text, tokenizer):
269
- """Tokenize a string using the provided tokenizer"""
270
- if not text:
271
- return []
272
-
273
- # Tokenize the text
274
- tokens = tokenizer.encode(text, add_special_tokens=False)
275
- return tokens
276
-
277
- # Data collator for pre-tokenized dataset
278
- class PreTokenizedCollator(DataCollatorMixin):
279
- """
280
- Data collator that can handle both pre-tokenized datasets and string content.
281
- Will tokenize strings if necessary, but logs warnings.
282
- """
283
- def __init__(self, pad_token_id=0, tokenizer=None):
284
- self.pad_token_id = pad_token_id
285
- self.tokenizer = tokenizer # Keep a reference to the tokenizer for fallback tokenization
286
-
287
- def __call__(self, features):
288
- # Print a sample feature to understand structure
289
- if len(features) > 0:
290
- logger.info(f"Sample feature keys: {list(features[0].keys())}")
291
-
292
- # Extract input_ids from conversations if needed
293
- processed_features = []
294
- for feature in features:
295
- # If input_ids is directly available, use it without tokenization
296
- if 'input_ids' in feature and isinstance(feature['input_ids'], list):
297
- # Already tokenized, no processing needed
298
- processed_features.append(feature)
299
- continue
300
-
301
- # If input_ids is not directly available, try to extract from conversations
302
- if 'input_ids' not in feature and 'conversations' in feature:
303
- # Extract from conversations based on your dataset structure
304
- conversations = feature['conversations']
305
-
306
- # Debug the conversations structure (only for first batch)
307
- if len(processed_features) == 0:
308
- logger.info(f"Conversations type: {type(conversations)}")
309
- if isinstance(conversations, list) and len(conversations) > 0:
310
- logger.info(f"First conversation type: {type(conversations[0])}")
311
-
312
- # Try different approaches to extract input_ids
313
- if isinstance(conversations, list) and len(conversations) > 0:
314
- # Case 1: If conversations is a list of dicts with 'input_ids' field (pre-tokenized)
315
- if isinstance(conversations[0], dict) and 'input_ids' in conversations[0]:
316
- feature['input_ids'] = conversations[0]['input_ids']
317
-
318
- # Case 2: If conversations itself contains the input_ids (pre-tokenized)
319
- elif all(isinstance(x, int) for x in conversations):
320
- feature['input_ids'] = conversations
321
-
322
- # Case 3: If conversations is a list of dicts with 'content' field
323
- elif isinstance(conversations[0], dict) and 'content' in conversations[0]:
324
- content = conversations[0]['content']
325
-
326
- # If content is already a list of integers, use it directly
327
- if isinstance(content, list) and all(isinstance(x, int) for x in content):
328
- feature['input_ids'] = content
329
- # If content is a string, tokenize it with a warning
330
- elif isinstance(content, str) and self.tokenizer:
331
- logger.warning("Found string content in dataset. Tokenizing as fallback.")
332
- feature['input_ids'] = self.tokenizer.encode(content, add_special_tokens=False)
333
- else:
334
- logger.warning(f"Unexpected content format: {type(content)}")
335
- continue
336
-
337
- # Case 4: If conversations is a list of strings
338
- elif all(isinstance(x, str) for x in conversations) and self.tokenizer:
339
- # Join all strings and tokenize
340
- logger.warning("Found string conversations in dataset. Tokenizing as fallback.")
341
- full_text = " ".join(conversations)
342
- feature['input_ids'] = self.tokenizer.encode(full_text, add_special_tokens=False)
343
-
344
- # Ensure input_ids is a list of integers
345
- if 'input_ids' in feature:
346
- # If input_ids is a string, tokenize it
347
- if isinstance(feature['input_ids'], str) and self.tokenizer:
348
- logger.warning("Found string input_ids in dataset. Tokenizing as fallback.")
349
- feature['input_ids'] = self.tokenizer.encode(feature['input_ids'], add_special_tokens=False)
350
- # If input_ids is not a list, convert it
351
- elif not isinstance(feature['input_ids'], list):
352
- try:
353
- feature['input_ids'] = list(feature['input_ids'])
354
- except:
355
- logger.error(f"Could not convert input_ids to list: {type(feature['input_ids'])}")
356
- continue
357
- else:
358
- logger.warning("No input_ids found in this example. Skipping.")
359
- continue
360
-
361
- processed_features.append(feature)
362
-
363
- # If we still don't have input_ids, log an error
364
- if len(processed_features) == 0:
365
- logger.error("No valid examples found in batch. Check dataset format.")
366
- raise ValueError("No valid examples found. Please check dataset structure.")
367
-
368
- if 'input_ids' not in processed_features[0]:
369
- logger.error(f"Could not find input_ids in features. Available keys: {list(processed_features[0].keys())}")
370
- if 'conversations' in processed_features[0]:
371
- logger.error(f"Conversations structure: {processed_features[0]['conversations'][:1]}")
372
- raise ValueError("Could not find input_ids in dataset. Please check dataset structure.")
373
-
374
- # Determine max length in this batch
375
- batch_max_len = max(len(x["input_ids"]) for x in processed_features)
376
-
377
- # Initialize batch tensors
378
- batch = {
379
- "input_ids": torch.ones((len(processed_features), batch_max_len), dtype=torch.long) * self.pad_token_id,
380
- "attention_mask": torch.zeros((len(processed_features), batch_max_len), dtype=torch.long),
381
- "labels": torch.ones((len(processed_features), batch_max_len), dtype=torch.long) * -100 # -100 is ignored in loss
382
- }
383
-
384
- # Fill batch tensors
385
- for i, feature in enumerate(processed_features):
386
- input_ids = feature["input_ids"]
387
- seq_len = len(input_ids)
388
-
389
- # Convert to tensor if it's a list
390
- if isinstance(input_ids, list):
391
- input_ids = torch.tensor(input_ids, dtype=torch.long)
392
-
393
- # Copy data to batch tensors
394
- batch["input_ids"][i, :seq_len] = input_ids
395
- batch["attention_mask"][i, :seq_len] = 1
396
-
397
- # If there are labels, use them, otherwise use input_ids
398
- if "labels" in feature:
399
- labels = feature["labels"]
400
- if isinstance(labels, list):
401
- labels = torch.tensor(labels, dtype=torch.long)
402
- batch["labels"][i, :len(labels)] = labels
403
- else:
404
- batch["labels"][i, :seq_len] = input_ids
405
-
406
- return batch
407
-
408
- def create_training_marker(output_dir):
409
- """Create a marker file to indicate training is active"""
410
- # Create in current directory for app.py to find
411
- with open("TRAINING_ACTIVE", "w") as f:
412
- f.write(f"Training active in {output_dir}")
413
-
414
- # Also create in output directory
415
- os.makedirs(output_dir, exist_ok=True)
416
- with open(os.path.join(output_dir, "RESEARCH_TRAINING_ONLY"), "w") as f:
417
- f.write("This model is for research training only. No interactive outputs.")
418
-
419
- def remove_training_marker():
420
- """Remove the training marker file"""
421
- if os.path.exists("TRAINING_ACTIVE"):
422
- os.remove("TRAINING_ACTIVE")
423
- logger.info("Removed training active marker")
424
-
425
- def load_model_safely(model_name, max_seq_length, dtype=None, use_flash_attention=False, use_deepspeed=False):
426
- """
427
- Load the model directly with HuggingFace, bypassing Unsloth optimizations
428
- to avoid memory-efficient attention issues
429
- """
430
- logger.info(f"Loading model: {model_name}")
431
-
432
- # Create BitsAndBytesConfig for 4-bit quantization
433
- from transformers import BitsAndBytesConfig
434
- bnb_config = BitsAndBytesConfig(
435
- load_in_4bit=True,
436
- bnb_4bit_compute_dtype=torch.float16,
437
- bnb_4bit_quant_type="nf4",
438
- bnb_4bit_use_double_quant=True
439
- )
440
-
441
- # Force eager implementation to avoid BMGHK format issues
442
- attn_implementation = "eager"
443
- logger.info(f"Forcing eager attention implementation to avoid BMGHK format issues")
444
-
445
- # Skip Unsloth and use standard HuggingFace loading
446
- logger.info("Bypassing Unsloth optimizations to avoid memory-efficient attention issues")
447
-
448
- # Check available GPUs
449
- gpu_count = torch.cuda.device_count()
450
- logger.info(f"Found {gpu_count} GPU(s) available")
451
-
452
- # Load with standard HuggingFace
453
- config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
454
-
455
- # Set attention implementation in config
456
- config.attn_implementation = attn_implementation
457
-
458
- # Disable any custom attention mechanisms
459
- if hasattr(config, "use_flash_attention"):
460
- config.use_flash_attention = False
461
- if hasattr(config, "use_memory_efficient_attention"):
462
- config.use_memory_efficient_attention = False
463
-
464
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
465
-
466
- # Set device mapping based on whether DeepSpeed is used
467
- # When using DeepSpeed, we should use 'cpu' or 'meta' for initial loading
468
- # to avoid OOM issues, as DeepSpeed will handle the device placement
469
- if use_deepspeed:
470
- logger.info("Using DeepSpeed - loading model initially on CPU to avoid OOM issues")
471
- device_map = "cpu" # Load on CPU first, DeepSpeed will handle distribution
472
- else:
473
- # Always use auto device mapping for cloud hardware when not using DeepSpeed
474
- device_map = "auto"
475
-
476
- logger.info(f"Using device_map={device_map} for initial model loading")
477
-
478
- # Load the model
479
- model = AutoModelForCausalLM.from_pretrained(
480
- model_name,
481
- config=config,
482
- device_map=device_map,
483
- torch_dtype=dtype or torch.float16,
484
- quantization_config=bnb_config,
485
- trust_remote_code=True,
486
- attn_implementation=attn_implementation
487
- )
488
-
489
- logger.info("Model loaded successfully with standard HF loading")
490
-
491
- # If using DeepSpeed, ensure model is properly prepared
492
- if use_deepspeed:
493
- logger.info("Model loaded on CPU - DeepSpeed will handle device placement during training")
494
-
495
- return model, tokenizer
496
-
497
- def train(config_path, dataset_name, output_dir):
498
- """Main training function - RESEARCH TRAINING PHASE ONLY"""
499
- # Load environment variables
500
- load_dotenv()
501
- config = load_config(config_path)
502
-
503
- # Set CUDA launch blocking for better error reporting
504
- os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
505
-
506
- # Try to unload xformers if it's loaded
507
- if 'xformers' in sys.modules:
508
- logger.info("Removing xformers from sys.modules")
509
- del sys.modules['xformers']
510
-
511
- # Patch torch.nn.functional to avoid memory_efficient_attention
512
- try:
513
- import torch.nn.functional as F
514
- if hasattr(F, 'scaled_dot_product_attention'):
515
- logger.info("Patching torch.nn.functional.scaled_dot_product_attention")
516
- original_sdpa = F.scaled_dot_product_attention
517
-
518
- def safe_sdpa(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None):
519
- # Force disable memory efficient attention
520
- logger.info("Using safe scaled_dot_product_attention (no xformers)")
521
- return original_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale)
522
-
523
- F.scaled_dot_product_attention = safe_sdpa
524
- except Exception as e:
525
- logger.warning(f"Failed to patch scaled_dot_product_attention: {e}")
526
-
527
- # Extract configs
528
- model_config = config.get("model_config", {})
529
- training_config = config.get("training_config", {})
530
- hardware_config = config.get("hardware_config", {})
531
- lora_config = config.get("lora_config", {})
532
- dataset_config = config.get("dataset_config", {})
533
-
534
- # Set the output directory
535
- output_dir = output_dir or training_config.get("output_dir", "fine_tuned_model")
536
- os.makedirs(output_dir, exist_ok=True)
537
-
538
- # Create training marker
539
- create_training_marker(output_dir)
540
-
541
- try:
542
- # Print configuration summary
543
- logger.info("RESEARCH TRAINING PHASE ACTIVE - No output generation")
544
- logger.info("Configuration Summary:")
545
- model_name = model_config.get("model_name_or_path")
546
- logger.info(f"Model: {model_name}")
547
- logger.info(f"Dataset: {dataset_name if dataset_name != 'phi4-cognitive-dataset' else DEFAULT_DATASET}")
548
- logger.info(f"Output directory: {output_dir}")
549
- logger.info("IMPORTANT: Using already 4-bit quantized model - not re-quantizing")
550
-
551
- # Check GPU availability
552
- gpu_count = torch.cuda.device_count()
553
- logger.info(f"Found {gpu_count} GPU(s) available")
554
- for i in range(gpu_count):
555
- logger.info(f"GPU {i}: {torch.cuda.get_device_name(i)}")
556
-
557
- # Load and prepare the dataset
558
- dataset = load_and_prepare_dataset(dataset_name, config)
559
-
560
- # Initialize tokenizer (just for model initialization, not for tokenizing data)
561
- logger.info("Loading tokenizer (for model initialization only, not for tokenizing data)")
562
- tokenizer = AutoTokenizer.from_pretrained(
563
- model_name,
564
- trust_remote_code=True
565
- )
566
- tokenizer.pad_token = tokenizer.eos_token
567
-
568
- # Initialize model
569
- logger.info("Initializing model (preserving 4-bit quantization)")
570
-
571
- # Use full sequence length of 2048 as required for pre-tokenized dataset
572
- max_seq_length = training_config.get("max_seq_length", 2048)
573
- logger.info(f"Using sequence length: {max_seq_length} as required for pre-tokenized dataset")
574
-
575
- # Create LoRA config directly
576
- logger.info("Creating LoRA configuration")
577
- lora_config_obj = LoraConfig(
578
- r=lora_config.get("r", 16),
579
- lora_alpha=lora_config.get("lora_alpha", 32),
580
- lora_dropout=lora_config.get("lora_dropout", 0.05),
581
- bias=lora_config.get("bias", "none"),
582
- target_modules=lora_config.get("target_modules", ["q_proj", "k_proj", "v_proj", "o_proj"])
583
- )
584
-
585
- # Force eager attention implementation
586
- use_flash_attention = False # Override to force eager implementation
587
-
588
- # Initialize ds_config_path to None before checking
589
- ds_config_path = None
590
-
591
- # Optimize batch size for L40S GPU
592
- gpu_info = torch.cuda.get_device_properties(0)
593
- logger.info(f"GPU Model: {gpu_info.name}, VRAM: {gpu_info.total_memory / 1e9:.2f} GB")
594
-
595
- # For L40S GPU, we can use a larger batch size and shard model across the single GPU
596
- if "L40S" in gpu_info.name or gpu_info.total_memory > 40e9: # Check if it's L40S (>40GB VRAM)
597
- logger.info("Detected L40S GPU - optimizing for high-memory GPU")
598
- per_device_train_batch_size = training_config.get("per_device_train_batch_size", 4)
599
- logger.info(f"Using optimized batch size for L40S: {per_device_train_batch_size}")
600
- else:
601
- # Default to a smaller batch size for other GPUs
602
- per_device_train_batch_size = 2
603
- logger.info(f"Using conservative batch size for non-L40S GPU: {per_device_train_batch_size}")
604
-
605
- # Check if DeepSpeed config is available and if DeepSpeed is available
606
- # Note: DeepSpeed is now disabled by default for HF Spaces
607
- deepspeed_config = None
608
- logger.info("DeepSpeed is disabled for Hugging Face Spaces to avoid compatibility issues")
609
- ds_config_path = None
610
- using_deepspeed = False
611
-
612
- # Initialize model with our safe loading function
613
- logger.info("Loading pre-quantized model with eager attention")
614
- dtype = torch.float16 if hardware_config.get("fp16", True) else None
615
- model, tokenizer = load_model_safely(model_name, max_seq_length, dtype, use_flash_attention, use_deepspeed=using_deepspeed)
616
-
617
- # Disable generation capabilities for research training
618
- logger.info("Disabling generation capabilities - Research training only")
619
- model.config.is_decoder = False
620
- model.config.task_specific_params = None
621
-
622
- # Apply LoRA to model
623
- logger.info("Applying LoRA to model")
624
- from peft import get_peft_model
625
- model = get_peft_model(model, lora_config_obj)
626
- logger.info("Successfully applied LoRA with standard PEFT")
627
-
628
- # Explicitly set attention implementation in model config again after PEFT
629
- model.config.attn_implementation = "eager"
630
-
631
- # No need to format the dataset - it's already pre-tokenized
632
- logger.info("Using dataset with flexible tokenization handling")
633
- logger.info("Will use pre-tokenized data if available, or tokenize strings as fallback")
634
- training_dataset = dataset
635
-
636
- # Configure reporting backends with fallbacks
637
- reports = []
638
- if TENSORBOARD_AVAILABLE:
639
- reports.append("tensorboard")
640
- logger.info("Tensorboard available and enabled for reporting")
641
- else:
642
- logger.warning("Tensorboard not available - metrics won't be logged to tensorboard")
643
-
644
- if os.getenv("WANDB_API_KEY"):
645
- reports.append("wandb")
646
- logger.info("Wandb API key found, enabling wandb reporting")
647
-
648
- # Default to "none" if no reporting backends are available
649
- if not reports:
650
- reports = ["none"]
651
- logger.warning("No reporting backends available - training metrics won't be logged")
652
-
653
- training_args_dict = {
654
- "output_dir": output_dir,
655
- "num_train_epochs": training_config.get("num_train_epochs", 3),
656
- "per_device_train_batch_size": per_device_train_batch_size,
657
- "gradient_accumulation_steps": training_config.get("gradient_accumulation_steps", 4),
658
- "learning_rate": training_config.get("learning_rate", 2e-5),
659
- "lr_scheduler_type": training_config.get("lr_scheduler_type", "cosine"),
660
- "warmup_ratio": training_config.get("warmup_ratio", 0.03),
661
- "weight_decay": training_config.get("weight_decay", 0.01),
662
- "optim": training_config.get("optim", "adamw_torch"),
663
- "logging_steps": training_config.get("logging_steps", 10),
664
- "save_steps": training_config.get("save_steps", 200),
665
- "save_total_limit": training_config.get("save_total_limit", 3),
666
- "fp16": hardware_config.get("fp16", True),
667
- "bf16": hardware_config.get("bf16", False),
668
- "max_grad_norm": training_config.get("max_grad_norm", 0.3),
669
- "report_to": reports,
670
- "logging_first_step": training_config.get("logging_first_step", True),
671
- "disable_tqdm": training_config.get("disable_tqdm", False),
672
- "remove_unused_columns": False,
673
- "seed": 42,
674
- "dataloader_num_workers": 4, # Use multiple workers for data loading
675
- }
676
-
677
- # Add DeepSpeed config path if available and enabled
678
- # DeepSpeed is disabled for Hugging Face Spaces
679
- logger.info("DeepSpeed is disabled - using standard training")
680
-
681
- # Create TrainingArguments with validated parameters
682
- try:
683
- training_args = TrainingArguments(**training_args_dict)
684
- except Exception as e:
685
- logger.error(f"Failed to create training arguments: {e}")
686
- if "deepspeed" in training_args_dict:
687
- logger.warning("Removing any DeepSpeed configuration")
688
- del training_args_dict["deepspeed"]
689
- training_args = TrainingArguments(**training_args_dict)
690
-
691
- # Create trainer with pre-tokenized collator
692
- trainer = Trainer(
693
- model=model,
694
- args=training_args,
695
- train_dataset=training_dataset,
696
- data_collator=PreTokenizedCollator(pad_token_id=tokenizer.pad_token_id, tokenizer=tokenizer),
697
- )
698
-
699
- # Start training
700
- logger.info("Starting training - RESEARCH PHASE ONLY")
701
- trainer.train()
702
-
703
- # Save the model
704
- logger.info(f"Saving model to {output_dir}")
705
- trainer.save_model(output_dir)
706
-
707
- # Save LoRA adapter separately for easier deployment
708
- lora_output_dir = os.path.join(output_dir, "lora_adapter")
709
- model.save_pretrained(lora_output_dir)
710
- logger.info(f"Saved LoRA adapter to {lora_output_dir}")
711
-
712
- # Save tokenizer for completeness
713
- tokenizer_output_dir = os.path.join(output_dir, "tokenizer")
714
- tokenizer.save_pretrained(tokenizer_output_dir)
715
- logger.info(f"Saved tokenizer to {tokenizer_output_dir}")
716
-
717
- # Copy config file for reference
718
- with open(os.path.join(output_dir, "training_config.json"), "w") as f:
719
- json.dump(config, f, indent=2)
720
-
721
- logger.info("Training complete - RESEARCH PHASE ONLY")
722
- return output_dir
723
-
724
- finally:
725
- # Always remove the training marker when done
726
- remove_training_marker()
727
-
728
- if __name__ == "__main__":
729
- parser = argparse.ArgumentParser(description="Fine-tune Unsloth/DeepSeek-R1-Distill-Qwen-14B-unsloth-bnb-4bit model (RESEARCH ONLY)")
730
- parser.add_argument("--config", type=str, default="transformers_config.json",
731
- help="Path to the transformers config JSON file")
732
- parser.add_argument("--dataset", type=str, default="phi4-cognitive-dataset",
733
- help="Dataset name or path")
734
- parser.add_argument("--output_dir", type=str, default=None,
735
- help="Output directory for the fine-tuned model")
736
- parser.add_argument("--use_flash_attention", action="store_true",
737
- help="Use Flash Attention if available (NOT RECOMMENDED)")
738
-
739
- args = parser.parse_args()
740
-
741
- # Override flash attention setting to force eager implementation
742
- args.use_flash_attention = False
743
-
744
- # Run training - Research phase only
745
- try:
746
- output_path = train(args.config, args.dataset, args.output_dir)
747
- print(f"Research training completed. Model saved to: {output_path}")
748
- except Exception as e:
749
- logger.error(f"Training failed: {str(e)}")
750
- remove_training_marker() # Clean up marker if training fails
751
- raise
 
1
+ #!/usr/bin/env python
2
+
3
+ """
4
+ Simplified fine-tuning script for DeepSeek-R1-Distill-Qwen-14B-unsloth-bnb-4bit
5
+ - Optimized for L40S GPU
6
+ - Works with pre-tokenized datasets
7
+ - Research training only (no inference)
8
+ """
9
+
10
+ import os
11
+ import logging
12
+ import json
13
+ import torch
14
+ import argparse
15
+ from datasets import load_dataset
16
+ from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, AutoConfig, BitsAndBytesConfig
17
+ from transformers.data.data_collator import DataCollatorMixin
18
+ from peft import LoraConfig, get_peft_model
19
+ from dotenv import load_dotenv
20
+
21
+ # Basic environment setup for L40S
22
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:256"
23
+ os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
24
+
25
+ # Set up logging
26
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
27
+ logger = logging.getLogger(__name__)
28
+
29
+ # Create a marker file to indicate training is active
30
+ def create_training_marker(output_dir):
31
+ os.makedirs(output_dir, exist_ok=True)
32
+ with open("TRAINING_ACTIVE", "w") as f:
33
+ f.write(f"Training active in {output_dir}")
34
+
35
+ with open(os.path.join(output_dir, "RESEARCH_TRAINING_ONLY"), "w") as f:
36
+ f.write("This model is for research training only. No interactive outputs.")
37
+
38
+ # Remove the training marker file
39
+ def remove_training_marker():
40
+ if os.path.exists("TRAINING_ACTIVE"):
41
+ os.remove("TRAINING_ACTIVE")
42
+ logger.info("Removed training active marker")
43
+
44
+ # Custom data collator for pre-tokenized data
45
+ class PreTokenizedCollator(DataCollatorMixin):
46
+ def __init__(self, pad_token_id=0, tokenizer=None):
47
+ self.pad_token_id = pad_token_id
48
+ self.tokenizer = tokenizer # Keep reference to tokenizer for fallback
49
+
50
+ def __call__(self, features):
51
+ # Extract features properly from the batch
52
+ processed_features = []
53
+ for feature in features:
54
+ # If input_ids is directly available, use it
55
+ if 'input_ids' in feature and isinstance(feature['input_ids'], list):
56
+ processed_features.append(feature)
57
+ continue
58
+
59
+ # If input_ids is not available, try to extract from conversations
60
+ if 'input_ids' not in feature and 'conversations' in feature:
61
+ conversations = feature['conversations']
62
+
63
+ if isinstance(conversations, list) and len(conversations) > 0:
64
+ # Case 1: If conversations has 'input_ids' field (pre-tokenized)
65
+ if isinstance(conversations[0], dict) and 'input_ids' in conversations[0]:
66
+ feature['input_ids'] = conversations[0]['input_ids']
67
+
68
+ # Case 2: If conversations itself contains input_ids
69
+ elif all(isinstance(x, int) for x in conversations):
70
+ feature['input_ids'] = conversations
71
+
72
+ # Case 3: If conversations has 'content' field
73
+ elif isinstance(conversations[0], dict) and 'content' in conversations[0]:
74
+ content = conversations[0]['content']
75
+
76
+ # If content is already tokens, use directly
77
+ if isinstance(content, list) and all(isinstance(x, int) for x in content):
78
+ feature['input_ids'] = content
79
+ # If content is a string and we have tokenizer, tokenize as fallback
80
+ elif isinstance(content, str) and self.tokenizer:
81
+ logger.warning("Tokenizing string content as fallback")
82
+ feature['input_ids'] = self.tokenizer.encode(content, add_special_tokens=False)
83
+
84
+ # Ensure input_ids is present and is a list of integers
85
+ if 'input_ids' in feature:
86
+ if isinstance(feature['input_ids'], str) and self.tokenizer:
87
+ feature['input_ids'] = self.tokenizer.encode(feature['input_ids'], add_special_tokens=False)
88
+ elif not isinstance(feature['input_ids'], list):
89
+ try:
90
+ feature['input_ids'] = list(feature['input_ids'])
91
+ except Exception as e:
92
+ logger.error(f"Could not convert input_ids to list: {e}")
93
+ continue
94
+
95
+ processed_features.append(feature)
96
+
97
+ if len(processed_features) == 0:
98
+ raise ValueError("No valid examples found. Check dataset structure.")
99
+
100
+ # Determine max length in this batch
101
+ batch_max_len = max(len(x["input_ids"]) for x in processed_features)
102
+
103
+ # Initialize batch tensors
104
+ batch = {
105
+ "input_ids": torch.ones((len(processed_features), batch_max_len), dtype=torch.long) * self.pad_token_id,
106
+ "attention_mask": torch.zeros((len(processed_features), batch_max_len), dtype=torch.long),
107
+ "labels": torch.ones((len(processed_features), batch_max_len), dtype=torch.long) * -100 # -100 is ignored in loss
108
+ }
109
+
110
+ # Fill batch tensors
111
+ for i, feature in enumerate(processed_features):
112
+ input_ids = feature["input_ids"]
113
+ seq_len = len(input_ids)
114
+
115
+ # Convert to tensor if it's a list
116
+ if isinstance(input_ids, list):
117
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
118
+
119
+ # Copy data to batch tensors
120
+ batch["input_ids"][i, :seq_len] = input_ids
121
+ batch["attention_mask"][i, :seq_len] = 1
122
+
123
+ # If there are labels, use them, otherwise use input_ids
124
+ if "labels" in feature:
125
+ labels = feature["labels"]
126
+ if isinstance(labels, list):
127
+ labels = torch.tensor(labels, dtype=torch.long)
128
+ batch["labels"][i, :len(labels)] = labels
129
+ else:
130
+ batch["labels"][i, :seq_len] = input_ids
131
+
132
+ return batch
133
+
134
+ # Load and prepare dataset with proper sorting
135
+ def load_and_prepare_dataset(dataset_name, config):
136
+ """Load and prepare the dataset for fine-tuning with proper sorting"""
137
+ logger.info(f"Loading dataset: {dataset_name}")
138
+
139
+ try:
140
+ # Load dataset
141
+ dataset = load_dataset(dataset_name)
142
+
143
+ # Extract the split we want to use (usually 'train')
144
+ if 'train' in dataset:
145
+ dataset = dataset['train']
146
+
147
+ # Get the dataset config
148
+ dataset_config = config.get("dataset_config", {})
149
+ sort_field = dataset_config.get("sort_by_field", "prompt_number")
150
+
151
+ # Sort in ascending order by specified field
152
+ logger.info(f"Sorting dataset by {sort_field} in ascending order")
153
+ dataset = dataset.sort(sort_field)
154
+
155
+ # Print dataset info
156
+ logger.info(f"Dataset loaded with {len(dataset)} entries")
157
+ logger.info(f"Dataset columns: {dataset.column_names}")
158
+
159
+ # Print sample for debugging
160
+ if len(dataset) > 0:
161
+ logger.info(f"Sample entry structure: {list(dataset[0].keys())}")
162
+
163
+ return dataset
164
+
165
+ except Exception as e:
166
+ logger.error(f"Error loading dataset: {str(e)}")
167
+ raise
168
+
169
+ # Main training function
170
+ def train(config_path, dataset_name, output_dir):
171
+ # Load environment variables
172
+ load_dotenv()
173
+
174
+ # Load config
175
+ with open(config_path, 'r') as f:
176
+ config = json.load(f)
177
+
178
+ # Create training marker
179
+ create_training_marker(output_dir)
180
+
181
+ try:
182
+ # Extract configs
183
+ model_config = config.get("model_config", {})
184
+ training_config = config.get("training_config", {})
185
+ hardware_config = config.get("hardware_config", {})
186
+ lora_config = config.get("lora_config", {})
187
+ dataset_config = config.get("dataset_config", {})
188
+
189
+ # Load and prepare dataset with proper sorting
190
+ dataset = load_and_prepare_dataset(dataset_name, config)
191
+
192
+ # Load model settings
193
+ model_name = model_config.get("model_name_or_path")
194
+ logger.info(f"Using model: {model_name}")
195
+
196
+ # Initialize tokenizer
197
+ logger.info("Loading tokenizer")
198
+ tokenizer = AutoTokenizer.from_pretrained(
199
+ model_name,
200
+ trust_remote_code=True
201
+ )
202
+ tokenizer.pad_token = tokenizer.eos_token
203
+
204
+ # Create quantization config
205
+ quant_config = config.get("quantization_config", {})
206
+ bnb_config = BitsAndBytesConfig(
207
+ load_in_4bit=quant_config.get("load_in_4bit", True),
208
+ bnb_4bit_compute_dtype=torch.float16,
209
+ bnb_4bit_quant_type=quant_config.get("bnb_4bit_quant_type", "nf4"),
210
+ bnb_4bit_use_double_quant=quant_config.get("bnb_4bit_use_double_quant", True)
211
+ )
212
+
213
+ # Create model with proper configuration
214
+ logger.info("Loading pre-quantized model")
215
+ model = AutoModelForCausalLM.from_pretrained(
216
+ model_name,
217
+ quantization_config=bnb_config,
218
+ device_map="auto",
219
+ torch_dtype=torch.float16,
220
+ trust_remote_code=True,
221
+ use_cache=model_config.get("use_cache", False),
222
+ attn_implementation=hardware_config.get("attn_implementation", "eager")
223
+ )
224
+
225
+ # Apply rope scaling if configured
226
+ if "rope_scaling" in model_config:
227
+ logger.info(f"Applying rope scaling: {model_config['rope_scaling']}")
228
+ if hasattr(model.config, "rope_scaling"):
229
+ model.config.rope_scaling = model_config["rope_scaling"]
230
+
231
+ # Create LoRA config
232
+ logger.info("Creating LoRA configuration")
233
+ lora_config_obj = LoraConfig(
234
+ r=lora_config.get("r", 16),
235
+ lora_alpha=lora_config.get("lora_alpha", 32),
236
+ lora_dropout=lora_config.get("lora_dropout", 0.05),
237
+ bias=lora_config.get("bias", "none"),
238
+ target_modules=lora_config.get("target_modules", ["q_proj", "k_proj", "v_proj", "o_proj"])
239
+ )
240
+
241
+ # Apply LoRA to model
242
+ logger.info("Applying LoRA to model")
243
+ model = get_peft_model(model, lora_config_obj)
244
+ logger.info("Successfully applied LoRA")
245
+
246
+ # Check for L40S GPU and optimize batch size
247
+ if torch.cuda.is_available():
248
+ gpu_info = torch.cuda.get_device_properties(0)
249
+ logger.info(f"GPU: {gpu_info.name}, VRAM: {gpu_info.total_memory / 1e9:.2f} GB")
250
+
251
+ # Check if it's an L40S or high-memory GPU
252
+ if "L40S" in gpu_info.name or gpu_info.total_memory > 40e9:
253
+ logger.info("Detected L40S GPU - optimizing for high-memory GPU")
254
+ per_device_train_batch_size = training_config.get("per_device_train_batch_size", 4)
255
+ else:
256
+ # Use a smaller batch size for other GPUs
257
+ per_device_train_batch_size = 2
258
+ logger.info(f"Using conservative batch size for non-L40S GPU: {per_device_train_batch_size}")
259
+ else:
260
+ per_device_train_batch_size = 1
261
+ logger.warning("No GPU detected - using minimal batch size")
262
+
263
+ # Configure reporting backends
264
+ reports = training_config.get("report_to", ["tensorboard"])
265
+
266
+ # Create training arguments
267
+ logger.info("Creating training arguments")
268
+ training_args = TrainingArguments(
269
+ output_dir=output_dir,
270
+ num_train_epochs=training_config.get("num_train_epochs", 3),
271
+ per_device_train_batch_size=per_device_train_batch_size,
272
+ gradient_accumulation_steps=training_config.get("gradient_accumulation_steps", 4),
273
+ learning_rate=training_config.get("learning_rate", 2e-5),
274
+ lr_scheduler_type=training_config.get("lr_scheduler_type", "cosine"),
275
+ warmup_ratio=training_config.get("warmup_ratio", 0.03),
276
+ weight_decay=training_config.get("weight_decay", 0.01),
277
+ optim=training_config.get("optim", "adamw_torch"),
278
+ fp16=hardware_config.get("fp16", True),
279
+ bf16=hardware_config.get("bf16", False),
280
+ max_grad_norm=training_config.get("max_grad_norm", 0.3),
281
+ logging_steps=training_config.get("logging_steps", 10),
282
+ save_steps=training_config.get("save_steps", 200),
283
+ save_total_limit=training_config.get("save_total_limit", 3),
284
+ evaluation_strategy=training_config.get("evaluation_strategy", "no"),  # no eval_dataset is passed to the Trainer below, so evaluation is off by default
285
+ eval_steps=training_config.get("eval_steps", 200),
286
+ load_best_model_at_end=training_config.get("load_best_model_at_end", False),  # requires evaluation, which needs an eval_dataset
287
+ report_to=reports,
288
+ logging_first_step=training_config.get("logging_first_step", True),
289
+ disable_tqdm=training_config.get("disable_tqdm", False),
290
+ remove_unused_columns=False,
291
+ gradient_checkpointing=hardware_config.get("gradient_checkpointing", True),
292
+ dataloader_num_workers=training_config.get("dataloader_num_workers", 4)
293
+ )
294
+
295
+ # Create trainer with pre-tokenized collator
296
+ logger.info("Creating trainer with pre-tokenized collator")
297
+ trainer = Trainer(
298
+ model=model,
299
+ args=training_args,
300
+ train_dataset=dataset,
301
+ data_collator=PreTokenizedCollator(
302
+ pad_token_id=tokenizer.pad_token_id,
303
+ tokenizer=tokenizer
304
+ ),
305
+ )
306
+
307
+ # Start training
308
+ logger.info("Starting training - RESEARCH PHASE ONLY")
309
+ trainer.train()
310
+
311
+ # Save the model
312
+ logger.info(f"Saving model to {output_dir}")
313
+ trainer.save_model(output_dir)
314
+
315
+ # Save LoRA adapter separately
316
+ lora_output_dir = os.path.join(output_dir, "lora_adapter")
317
+ model.save_pretrained(lora_output_dir)
318
+ logger.info(f"Saved LoRA adapter to {lora_output_dir}")
319
+
320
+ # Save tokenizer
321
+ tokenizer_output_dir = os.path.join(output_dir, "tokenizer")
322
+ tokenizer.save_pretrained(tokenizer_output_dir)
323
+ logger.info(f"Saved tokenizer to {tokenizer_output_dir}")
324
+
325
+ # Save config for reference
326
+ with open(os.path.join(output_dir, "training_config.json"), "w") as f:
327
+ json.dump(config, f, indent=2)
328
+
329
+ logger.info("Training complete - RESEARCH PHASE ONLY")
330
+ return output_dir
331
+
332
+ finally:
333
+ # Always remove the training marker when done
334
+ remove_training_marker()
335
+
336
+ if __name__ == "__main__":
337
+ parser = argparse.ArgumentParser(description="Fine-tune DeepSeek model (Research Only)")
338
+ parser.add_argument("--config", type=str, default="transformers_config.json",
339
+ help="Path to the configuration file")
340
+ parser.add_argument("--dataset", type=str, default="phi4-cognitive-dataset",
341
+ help="Dataset name or path")
342
+ parser.add_argument("--output_dir", type=str, default="fine_tuned_model",
343
+ help="Output directory for the fine-tuned model")
344
+
345
+ args = parser.parse_args()
346
+
347
+ try:
348
+ output_path = train(args.config, args.dataset, args.output_dir)
349
+ print(f"Research training completed. Model saved to: {output_path}")
350
+ except Exception as e:
351
+ logging.error(f"Training failed: {str(e)}")
352
+ remove_training_marker() # Clean up marker if training fails
353
+ raise
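
As a usage note, the new script reads its settings from a JSON config file (transformers_config.json by default) organized into the sections queried via config.get(...) in the diff above. The sketch below shows one plausible minimal layout built from those keys; all values are illustrative and consistent with the script's own defaults, and the model repository id is an assumption based on the model named in the docstring.

# Plausible minimal transformers_config.json for run_cloud_training.py.
# Keys mirror the config.get(...) lookups in the script; values are illustrative.
import json

config = {
    "model_config": {
        # assumed repo id, based on the model named in the script's docstring
        "model_name_or_path": "unsloth/DeepSeek-R1-Distill-Qwen-14B-unsloth-bnb-4bit",
        "use_cache": False,
    },
    "training_config": {
        "num_train_epochs": 3,
        "per_device_train_batch_size": 4,
        "gradient_accumulation_steps": 4,
        "learning_rate": 2e-5,
        "lr_scheduler_type": "cosine",
        "logging_steps": 10,
        "save_steps": 200,
        "save_total_limit": 3,
        "report_to": ["tensorboard"],
    },
    "hardware_config": {
        "fp16": True,
        "bf16": False,
        "gradient_checkpointing": True,
        "attn_implementation": "eager",
    },
    "lora_config": {
        "r": 16,
        "lora_alpha": 32,
        "lora_dropout": 0.05,
        "bias": "none",
        "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"],
    },
    "dataset_config": {"sort_by_field": "prompt_number"},
    "quantization_config": {
        "load_in_4bit": True,
        "bnb_4bit_quant_type": "nf4",
        "bnb_4bit_use_double_quant": True,
    },
}

with open("transformers_config.json", "w") as f:
    json.dump(config, f, indent=2)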