George-API committed
Commit e278512 · verified · 1 Parent(s): dba9417

Fix: Remove unsupported attn_implementation parameter
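The change drops the unsupported attn_implementation keyword from the unsloth loading call; as committed, load_model_safely() below disables flash attention through an environment variable and the use_flash_attention flag instead. Excerpted from the diff below (not a new API):

    os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        dtype=dtype,
        load_in_4bit=True,          # This should work for already quantized models
        use_flash_attention=False   # Explicitly disable flash attention
    )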

Files changed (1)
  1. run_cloud_training.py +503 -0
run_cloud_training.py ADDED
@@ -0,0 +1,503 @@
+ #!/usr/bin/env python
+ # -*- coding: utf-8 -*-
+
+ """
+ Fine-tuning script for DeepSeek-R1-Distill-Qwen-14B-bnb-4bit using unsloth
+ RESEARCH TRAINING PHASE ONLY - No output generation
+ WORKS WITH PRE-TOKENIZED DATASET - No re-tokenization
+ """
+
+ import os
+ import json
+ import logging
+ import argparse
+ import numpy as np
+ from dotenv import load_dotenv
+ import torch
+ from datasets import load_dataset
+ import transformers
+ from transformers import AutoTokenizer, TrainingArguments, Trainer, AutoModelForCausalLM, AutoConfig
+ from transformers.data.data_collator import DataCollatorMixin
+ from peft import LoraConfig
+ from unsloth import FastLanguageModel
+
+ # Disable flash attention globally
+ os.environ["TRANSFORMERS_NO_FLASH_ATTENTION"] = "1"
+
+ # Check if tensorboard is available
+ try:
+     import tensorboard
+     TENSORBOARD_AVAILABLE = True
+ except ImportError:
+     TENSORBOARD_AVAILABLE = False
+     print("Tensorboard not available. Will skip tensorboard logging.")
+
+ # Configure logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
+     handlers=[
+         logging.StreamHandler(),
+         logging.FileHandler("training.log")
+     ]
+ )
+ logger = logging.getLogger(__name__)
+
+ # Default dataset path - use the correct path with username
+ DEFAULT_DATASET = "George-API/phi4-cognitive-dataset"
+
+ def load_config(config_path):
+     """Load the transformers config from JSON file"""
+     logger.info(f"Loading config from {config_path}")
+     with open(config_path, 'r') as f:
+         config = json.load(f)
+     return config
+
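+ # For reference, a minimal transformers_config.json consistent with the keys read in this
+ # script might look like the following (illustrative values only, not part of the original
+ # commit; adjust to your own run):
+ #
+ # {
+ #   "model_config":    {"model_name_or_path": "unsloth/DeepSeek-R1-Distill-Qwen-14B-bnb-4bit"},
+ #   "training_config": {"num_train_epochs": 3, "per_device_train_batch_size": 2,
+ #                       "gradient_accumulation_steps": 4, "max_seq_length": 2048},
+ #   "hardware_config": {"fp16": true, "bf16": false},
+ #   "lora_config":     {"r": 16, "lora_alpha": 32, "lora_dropout": 0.05,
+ #                       "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"]},
+ #   "dataset_config":  {"sort_by_field": "prompt_number", "sort_direction": "ascending"}
+ # }
+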
+ def load_and_prepare_dataset(dataset_name, config):
+     """
+     Load and prepare the dataset for fine-tuning.
+     Sort entries by prompt_number as required.
+     NO TOKENIZATION - DATASET IS ALREADY TOKENIZED
+     """
+     # Use the default dataset path if no specific path is provided
+     if dataset_name == "phi4-cognitive-dataset":
+         dataset_name = DEFAULT_DATASET
+
+     logger.info(f"Loading dataset: {dataset_name}")
+
+     try:
+         # Load dataset
+         dataset = load_dataset(dataset_name)
+
+         # Extract the split we want to use (usually 'train')
+         if 'train' in dataset:
+             dataset = dataset['train']
+
+         # Get the dataset config
+         dataset_config = config.get("dataset_config", {})
+         sort_field = dataset_config.get("sort_by_field", "prompt_number")
+         sort_direction = dataset_config.get("sort_direction", "ascending")
+
+         # Sort the dataset by prompt_number
+         logger.info(f"Sorting dataset by {sort_field} in {sort_direction} order")
+         if sort_direction == "ascending":
+             dataset = dataset.sort(sort_field)
+         else:
+             dataset = dataset.sort(sort_field, reverse=True)
+
+         # Add shuffle with fixed seed if specified
+         if "shuffle_seed" in dataset_config:
+             shuffle_seed = dataset_config.get("shuffle_seed")
+             logger.info(f"Shuffling dataset with seed {shuffle_seed}")
+             dataset = dataset.shuffle(seed=shuffle_seed)
+
+         # Print dataset structure for debugging
+         logger.info(f"Dataset loaded with {len(dataset)} entries")
+         logger.info(f"Dataset columns: {dataset.column_names}")
+
+         # Print a sample entry to understand structure
+         if len(dataset) > 0:
+             sample = dataset[0]
+             logger.info(f"Sample entry structure: {list(sample.keys())}")
+             if 'conversations' in sample:
+                 logger.info(f"Sample conversations structure: {sample['conversations'][:1]}")
+
+         return dataset
+
+     except Exception as e:
+         logger.error(f"Error loading dataset: {str(e)}")
+         logger.info("Available datasets in the Hub:")
+         # Print a more helpful error message
+         print(f"Failed to load dataset: {dataset_name}")
+         print(f"Make sure the dataset exists and is accessible.")
+         print(f"If it's a private dataset, ensure your HF_TOKEN has access to it.")
+         raise
+
+ def tokenize_string(text, tokenizer):
+     """Tokenize a string using the provided tokenizer"""
+     if not text:
+         return []
+
+     # Tokenize the text
+     tokens = tokenizer.encode(text, add_special_tokens=False)
+     return tokens
+
+ # Data collator for pre-tokenized dataset
+ class PreTokenizedCollator(DataCollatorMixin):
+     """
+     Data collator for pre-tokenized datasets.
+     Expects input_ids and labels already tokenized.
+     """
+     def __init__(self, pad_token_id=0, tokenizer=None):
+         self.pad_token_id = pad_token_id
+         self.tokenizer = tokenizer  # Keep a reference to the tokenizer for string conversion
+
+     def __call__(self, features):
+         # Print a sample feature to understand structure
+         if len(features) > 0:
+             logger.info(f"Sample feature keys: {list(features[0].keys())}")
+
+         # Extract input_ids from conversations if needed
+         processed_features = []
+         for feature in features:
+             # If input_ids is not directly available, try to extract from conversations
+             if 'input_ids' not in feature and 'conversations' in feature:
+                 # Extract from conversations based on your dataset structure
+                 conversations = feature['conversations']
+
+                 # Debug the conversations structure
+                 logger.info(f"Conversations type: {type(conversations)}")
+                 if isinstance(conversations, list) and len(conversations) > 0:
+                     logger.info(f"First conversation type: {type(conversations[0])}")
+                     logger.info(f"First conversation: {conversations[0]}")
+
+                 # Try different approaches to extract input_ids
+                 if isinstance(conversations, list) and len(conversations) > 0:
+                     # Case 1: If conversations is a list of dicts with 'content' field
+                     if isinstance(conversations[0], dict) and 'content' in conversations[0]:
+                         content = conversations[0]['content']
+                         logger.info(f"Found content field: {type(content)}")
+
+                         # If content is a string, tokenize it
+                         if isinstance(content, str) and self.tokenizer:
+                             logger.info(f"Tokenizing string content: {content[:50]}...")
+                             feature['input_ids'] = self.tokenizer.encode(content, add_special_tokens=False)
+                         # If content is already a list of integers, use it directly
+                         elif isinstance(content, list) and all(isinstance(x, int) for x in content):
+                             feature['input_ids'] = content
+                         # If content is already tokenized in some other format
+                         else:
+                             logger.warning(f"Unexpected content format: {type(content)}")
+
+                     # Case 2: If conversations is a list of dicts with 'input_ids' field
+                     elif isinstance(conversations[0], dict) and 'input_ids' in conversations[0]:
+                         feature['input_ids'] = conversations[0]['input_ids']
+
+                     # Case 3: If conversations itself contains the input_ids
+                     elif all(isinstance(x, int) for x in conversations):
+                         feature['input_ids'] = conversations
+
+                     # Case 4: If conversations is a list of strings
+                     elif all(isinstance(x, str) for x in conversations) and self.tokenizer:
+                         # Join all strings and tokenize
+                         full_text = " ".join(conversations)
+                         feature['input_ids'] = self.tokenizer.encode(full_text, add_special_tokens=False)
+
+             # Ensure input_ids is a list of integers
+             if 'input_ids' in feature:
+                 # If input_ids is a string, tokenize it
+                 if isinstance(feature['input_ids'], str) and self.tokenizer:
+                     logger.info(f"Converting string input_ids to tokens: {feature['input_ids'][:50]}...")
+                     feature['input_ids'] = self.tokenizer.encode(feature['input_ids'], add_special_tokens=False)
+                 # If input_ids is not a list, convert it
+                 elif not isinstance(feature['input_ids'], list):
+                     try:
+                         feature['input_ids'] = list(feature['input_ids'])
+                     except:
+                         logger.error(f"Could not convert input_ids to list: {type(feature['input_ids'])}")
+
+             processed_features.append(feature)
+
+         # If we still don't have input_ids, log an error
+         if len(processed_features) > 0 and 'input_ids' not in processed_features[0]:
+             logger.error(f"Could not find input_ids in features. Available keys: {list(processed_features[0].keys())}")
+             if 'conversations' in processed_features[0]:
+                 logger.error(f"Conversations structure: {processed_features[0]['conversations'][:1]}")
+             raise ValueError("Could not find input_ids in dataset. Please check dataset structure.")
+
+         # Determine max length in this batch
+         batch_max_len = max(len(x["input_ids"]) for x in processed_features)
+
+         # Initialize batch tensors
+         batch = {
+             "input_ids": torch.ones((len(processed_features), batch_max_len), dtype=torch.long) * self.pad_token_id,
+             "attention_mask": torch.zeros((len(processed_features), batch_max_len), dtype=torch.long),
+             "labels": torch.ones((len(processed_features), batch_max_len), dtype=torch.long) * -100  # -100 is ignored in loss
+         }
+
+         # Fill batch tensors
+         for i, feature in enumerate(processed_features):
+             input_ids = feature["input_ids"]
+             seq_len = len(input_ids)
+
+             # Convert to tensor if it's a list
+             if isinstance(input_ids, list):
+                 input_ids = torch.tensor(input_ids, dtype=torch.long)
+
+             # Copy data to batch tensors
+             batch["input_ids"][i, :seq_len] = input_ids
+             batch["attention_mask"][i, :seq_len] = 1
+
+             # If there are labels, use them, otherwise use input_ids
+             if "labels" in feature:
+                 labels = feature["labels"]
+                 if isinstance(labels, list):
+                     labels = torch.tensor(labels, dtype=torch.long)
+                 batch["labels"][i, :len(labels)] = labels
+             else:
+                 batch["labels"][i, :seq_len] = input_ids
+
+         return batch
+
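+ # A toy illustration of the collator's padding behaviour (not executed; assumed values):
+ #
+ #   collator = PreTokenizedCollator(pad_token_id=0)
+ #   batch = collator([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}])
+ #   # batch["input_ids"]      -> [[1, 2, 3], [4, 5, 0]]      (right-padded with pad_token_id)
+ #   # batch["attention_mask"] -> [[1, 1, 1], [1, 1, 0]]
+ #   # batch["labels"]         -> [[1, 2, 3], [4, 5, -100]]   (padding ignored by the loss)
+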
+ def create_training_marker(output_dir):
+     """Create a marker file to indicate training is active"""
+     # Create in current directory for app.py to find
+     with open("TRAINING_ACTIVE", "w") as f:
+         f.write(f"Training active in {output_dir}")
+
+     # Also create in output directory
+     os.makedirs(output_dir, exist_ok=True)
+     with open(os.path.join(output_dir, "RESEARCH_TRAINING_ONLY"), "w") as f:
+         f.write("This model is for research training only. No interactive outputs.")
+
+ def remove_training_marker():
+     """Remove the training marker file"""
+     if os.path.exists("TRAINING_ACTIVE"):
+         os.remove("TRAINING_ACTIVE")
+         logger.info("Removed training active marker")
+
+ def load_model_safely(model_name, max_seq_length, dtype=None):
+     """
+     Load the model in a safe way that works with Qwen models
+     by trying different loading strategies.
+     """
+     try:
+         logger.info(f"Attempting to load model with unsloth optimizations: {model_name}")
+         # First try the standard unsloth loading
+         try:
+             # Try loading with unsloth but without the problematic parameter
+             logger.info("Loading model with flash attention DISABLED")
+             model, tokenizer = FastLanguageModel.from_pretrained(
+                 model_name=model_name,
+                 max_seq_length=max_seq_length,
+                 dtype=dtype,
+                 load_in_4bit=True,  # This should work for already quantized models
+                 use_flash_attention=False  # Explicitly disable flash attention
+             )
+             logger.info("Model loaded successfully with unsloth with 4-bit quantization and flash attention disabled")
+             return model, tokenizer
+
+         except TypeError as e:
+             # If we get a TypeError about unexpected keyword arguments
+             if "unexpected keyword argument" in str(e):
+                 logger.warning(f"Unsloth loading error with 4-bit: {e}")
+                 logger.info("Trying alternative loading method for Qwen model...")
+
+                 # Try loading with different parameters for Qwen model
+                 model, tokenizer = FastLanguageModel.from_pretrained(
+                     model_name=model_name,
+                     max_seq_length=max_seq_length,
+                     dtype=dtype,
+                     use_flash_attention=False,  # Explicitly disable flash attention
+                 )
+                 logger.info("Model loaded successfully with unsloth using alternative method")
+                 return model, tokenizer
+             else:
+                 # Re-raise if it's a different type error
+                 raise
+
+     except Exception as e:
+         # Fallback to standard loading if unsloth methods fail
+         logger.warning(f"Unsloth loading failed: {e}")
+         logger.info("Falling back to standard Hugging Face loading...")
+
+         # Disable flash attention in transformers config
+         config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+         if hasattr(config, "use_flash_attention"):
+             config.use_flash_attention = False
+             logger.info("Disabled flash attention in model config")
+
+         tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+         model = AutoModelForCausalLM.from_pretrained(
+             model_name,
+             config=config,
+             device_map="auto",
+             torch_dtype=dtype or torch.float16,
+             load_in_4bit=True
+         )
+         logger.info("Model loaded successfully with standard HF loading and flash attention disabled")
+         return model, tokenizer
+
+ def train(config_path, dataset_name, output_dir):
+     """Main training function - RESEARCH TRAINING PHASE ONLY"""
+     # Load environment variables
+     load_dotenv()
+     config = load_config(config_path)
+
+     # Extract configs
+     model_config = config.get("model_config", {})
+     training_config = config.get("training_config", {})
+     hardware_config = config.get("hardware_config", {})
+     lora_config = config.get("lora_config", {})
+     dataset_config = config.get("dataset_config", {})
+
+     # Override flash attention setting to disable it
+     hardware_config["use_flash_attention"] = False
+     logger.info("Flash attention has been DISABLED due to GPU compatibility issues")
+
+     # Verify this is training phase only
+     training_phase_only = dataset_config.get("training_phase_only", True)
+     if not training_phase_only:
+         logger.warning("This script is meant for research training phase only")
+         logger.warning("Setting training_phase_only=True")
+
+     # Verify dataset is pre-tokenized
+     logger.info("IMPORTANT: Using pre-tokenized dataset - No tokenization will be performed")
+
+     # Set the output directory
+     output_dir = output_dir or training_config.get("output_dir", "fine_tuned_model")
+     os.makedirs(output_dir, exist_ok=True)
+
+     # Create training marker
+     create_training_marker(output_dir)
+
+     try:
+         # Print configuration summary
+         logger.info("RESEARCH TRAINING PHASE ACTIVE - No output generation")
+         logger.info("Configuration Summary:")
+         model_name = model_config.get("model_name_or_path")
+         logger.info(f"Model: {model_name}")
+         logger.info(f"Dataset: {dataset_name if dataset_name != 'phi4-cognitive-dataset' else DEFAULT_DATASET}")
+         logger.info(f"Output directory: {output_dir}")
+         logger.info("IMPORTANT: Using already 4-bit quantized model - not re-quantizing")
+
+         # Load and prepare the dataset
+         dataset = load_and_prepare_dataset(dataset_name, config)
+
+         # Initialize tokenizer (just for model initialization, not for tokenizing data)
+         logger.info("Loading tokenizer (for model initialization only, not for tokenizing data)")
+         tokenizer = AutoTokenizer.from_pretrained(
+             model_name,
+             trust_remote_code=True
+         )
+         tokenizer.pad_token = tokenizer.eos_token
+
+         # Initialize model with unsloth
+         logger.info("Initializing model with unsloth (preserving 4-bit quantization)")
+         max_seq_length = training_config.get("max_seq_length", 2048)
+
+         # Create LoRA config directly
+         logger.info("Creating LoRA configuration")
+         lora_config_obj = LoraConfig(
+             r=lora_config.get("r", 16),
+             lora_alpha=lora_config.get("lora_alpha", 32),
+             lora_dropout=lora_config.get("lora_dropout", 0.05),
+             bias=lora_config.get("bias", "none"),
+             target_modules=lora_config.get("target_modules", ["q_proj", "k_proj", "v_proj", "o_proj"])
+         )
+
+         # Initialize model with our safe loading function
+         logger.info("Loading pre-quantized model safely")
+         dtype = torch.float16 if hardware_config.get("fp16", True) else None
+         model, tokenizer = load_model_safely(model_name, max_seq_length, dtype)
+
+         # Try different approaches to apply LoRA
+         logger.info("Applying LoRA to model")
+
+         # Skip unsloth's method and go directly to PEFT
+         logger.info("Using standard PEFT method to apply LoRA")
+         from peft import get_peft_model
+         model = get_peft_model(model, lora_config_obj)
+         logger.info("Successfully applied LoRA with standard PEFT")
+
+         # No need to format the dataset - it's already pre-tokenized
+         logger.info("Using pre-tokenized dataset - skipping tokenization step")
+         training_dataset = dataset
+
+         # Configure reporting backends with fallbacks
+         reports = []
+         if TENSORBOARD_AVAILABLE:
+             reports.append("tensorboard")
+             logger.info("Tensorboard available and enabled for reporting")
+         else:
+             logger.warning("Tensorboard not available - metrics won't be logged to tensorboard")
+
+         if os.getenv("WANDB_API_KEY"):
+             reports.append("wandb")
+             logger.info("Wandb API key found, enabling wandb reporting")
+
+         # Default to "none" if no reporting backends are available
+         if not reports:
+             reports = ["none"]
+             logger.warning("No reporting backends available - training metrics won't be logged")
+
+         # Set up training arguments with flash attention disabled
+         training_args = TrainingArguments(
+             output_dir=output_dir,
+             num_train_epochs=training_config.get("num_train_epochs", 3),
+             per_device_train_batch_size=training_config.get("per_device_train_batch_size", 2),
+             gradient_accumulation_steps=training_config.get("gradient_accumulation_steps", 4),
+             learning_rate=training_config.get("learning_rate", 2e-5),
+             lr_scheduler_type=training_config.get("lr_scheduler_type", "cosine"),
+             warmup_ratio=training_config.get("warmup_ratio", 0.03),
+             weight_decay=training_config.get("weight_decay", 0.01),
+             optim=training_config.get("optim", "adamw_torch"),
+             logging_steps=training_config.get("logging_steps", 10),
+             save_steps=training_config.get("save_steps", 200),
+             save_total_limit=training_config.get("save_total_limit", 3),
+             fp16=hardware_config.get("fp16", True),
+             bf16=hardware_config.get("bf16", False),
+             max_grad_norm=training_config.get("max_grad_norm", 0.3),
+             report_to=reports,
+             logging_first_step=training_config.get("logging_first_step", True),
+             disable_tqdm=training_config.get("disable_tqdm", False),
+             # Important: Don't remove columns that don't match model's forward method
+             remove_unused_columns=False
+         )
+
+         # Create trainer with pre-tokenized collator
+         trainer = Trainer(
+             model=model,
+             args=training_args,
+             train_dataset=training_dataset,
+             data_collator=PreTokenizedCollator(pad_token_id=tokenizer.pad_token_id, tokenizer=tokenizer),
+         )
+
+         # Start training
+         logger.info("Starting training - RESEARCH PHASE ONLY")
+         trainer.train()
+
+         # Save the model
+         logger.info(f"Saving model to {output_dir}")
+         trainer.save_model(output_dir)
+
+         # Save LoRA adapter separately for easier deployment
+         lora_output_dir = os.path.join(output_dir, "lora_adapter")
+         model.save_pretrained(lora_output_dir)
+         logger.info(f"Saved LoRA adapter to {lora_output_dir}")
+
+         # Save tokenizer for completeness
+         tokenizer_output_dir = os.path.join(output_dir, "tokenizer")
+         tokenizer.save_pretrained(tokenizer_output_dir)
+         logger.info(f"Saved tokenizer to {tokenizer_output_dir}")
+
+         # Copy config file for reference
+         with open(os.path.join(output_dir, "training_config.json"), "w") as f:
+             json.dump(config, f, indent=2)
+
+         logger.info("Training complete - RESEARCH PHASE ONLY")
+         return output_dir
+
+     finally:
+         # Always remove the training marker when done
+         remove_training_marker()
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(description="Fine-tune Unsloth/DeepSeek-R1-Distill-Qwen-14B-4bit model (RESEARCH ONLY)")
+     parser.add_argument("--config", type=str, default="transformers_config.json",
+                         help="Path to the transformers config JSON file")
+     parser.add_argument("--dataset", type=str, default="phi4-cognitive-dataset",
+                         help="Dataset name or path")
+     parser.add_argument("--output_dir", type=str, default=None,
+                         help="Output directory for the fine-tuned model")
+
+     args = parser.parse_args()
+
+     # Run training - Research phase only
+     try:
+         output_path = train(args.config, args.dataset, args.output_dir)
+         print(f"Research training completed. Model saved to: {output_path}")
+     except Exception as e:
+         logger.error(f"Training failed: {str(e)}")
+         remove_training_marker()  # Clean up marker if training fails
+         raise
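
For reference, a typical invocation using the argparse defaults above (the config file name and dataset are the script's defaults and are assumed to be available in your environment):

    python run_cloud_training.py --config transformers_config.json --dataset phi4-cognitive-dataset --output_dir fine_tuned_model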