George-API committed on
Commit 22cec44 · verified · 1 Parent(s): 2c3731c

Upload folder using huggingface_hub

.gitignore ADDED
@@ -0,0 +1,3 @@
1
+ .env
2
+ *.pyc
3
+ __pycache__
DEPLOY_CHECKLIST.md ADDED
@@ -0,0 +1,52 @@
1
+ # Phi-4 Training Critical Deployment Checklist
2
+
3
+ ## Essential Configuration Requirements
4
+
5
+ ### 1. Model Configuration
6
+ - [ ] Model name: `unsloth/phi-4-unsloth-bnb-4bit`
7
+ - [ ] BF16 precision enabled, FP16 disabled
8
+ - [ ] Appropriate sequence length (2048)
9
+ - [ ] LoRA parameters correctly configured (r: 32, alpha: 16)
10
+
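+ A minimal sketch of how these settings map onto the Unsloth calls used later in `run_transformers_training.py` (values shown are the checklist defaults; adjust to your config):
+
+ ```python
+ from unsloth import FastLanguageModel
+
+ # Load the pre-quantized 4-bit model with the required sequence length.
+ model, tokenizer = FastLanguageModel.from_pretrained(
+     model_name="unsloth/phi-4-unsloth-bnb-4bit",
+     max_seq_length=2048,
+     dtype=None,  # Unsloth selects bfloat16 on Ampere+ GPUs, satisfying the BF16 requirement
+ )
+
+ # Attach LoRA adapters with the checklist values (r=32, alpha=16).
+ model = FastLanguageModel.get_peft_model(
+     model,
+     r=32,
+     lora_alpha=16,
+     target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
+                     "gate_proj", "up_proj", "down_proj"],
+     use_gradient_checkpointing=True,
+ )
+ ```
+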
11
+ ### 2. Hardware & Resource Management
12
+ - [ ] Per-device batch size ≤ 16
13
+ - [ ] Gradient accumulation steps ≥ 3
14
+ - [ ] Gradient checkpointing enabled
15
+ - [ ] Memory usage limits properly set (85% of GPU capacity)
16
+
17
+ ### 3. Critical Dataset Handling Rules
18
+ - [ ] **NO REORDERING of dataset entries** - original order must be preserved
19
+ - [ ] **NO COMBINING of separate entries** - each entry must remain distinct
20
+ - [ ] **SEQUENTIAL PROCESSING required** - entries must be processed one after another
21
+ - [ ] `sort_by_id` and `maintain_paper_order` flags properly set to preserve data sequence
22
+ - [ ] Sequential sampler used with no shuffling (`"shuffle": false`)
23
+ - [ ] Dataset sequential integrity verified with validation samples
24
+ - [ ] Conversation structure preserved (original format maintained)
25
+
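+ A minimal sketch of the corresponding `dataset_config.json` fields; the exact nesting is an assumption, the flag names and the `shuffle` setting come from the checklist above, and the dataset name is a placeholder:
+
+ ```json
+ {
+   "dataset": {
+     "name": "your-username/cognitive-dataset",
+     "split": "train",
+     "processing": {
+       "sort_by_id": true,
+       "maintain_paper_order": true,
+       "max_seq_length": 2048
+     }
+   },
+   "data_loading": {
+     "shuffle": false
+   }
+ }
+ ```
+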
26
+ ### 4. Essential Error Handling
27
+ - [ ] Clear error catching for dataset loading issues
28
+ - [ ] Memory tracking at key training points
29
+ - [ ] Low-verbosity logging for HF Space compatibility
30
+
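+ A minimal sketch of the kind of memory check meant here; the 20GB ceiling comes from the verification table below, everything else is illustrative:
+
+ ```python
+ import torch
+
+ def log_gpu_memory(tag: str, limit_gb: float = 20.0) -> None:
+     """Log allocated and peak memory per GPU, warning if the peak exceeds the limit."""
+     for i in range(torch.cuda.device_count()):
+         allocated = torch.cuda.memory_allocated(i) / 1024**3
+         peak = torch.cuda.max_memory_allocated(i) / 1024**3
+         print(f"[{tag}] GPU {i}: {allocated:.1f} GB allocated, {peak:.1f} GB peak")
+         if peak > limit_gb:
+             print(f"[{tag}] WARNING: peak memory exceeds {limit_gb} GB on GPU {i}")
+ ```
+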
31
+ ### 5. Training Core Requirements
32
+ - [ ] Appropriate learning rate (2e-5)
33
+ - [ ] Proper checkpointing frequency
34
+ - [ ] Hub settings correctly configured for model saving
35
+
36
+ ---
37
+
38
+ ## Pre-Deployment Verification
39
+
40
+ | Requirement | Status | Notes |
41
+ |-------------|--------|-------|
42
+ | Data sequential integrity | | Confirm entries processed in order |
43
+ | GPU memory within limits | | Check peak memory doesn't exceed 20GB per GPU |
44
+ | Training batch verification | | Verify first few batches maintain proper order |
45
+
46
+ ---
47
+
48
+ **Current Hardware**: 4× NVIDIA L4 GPUs (24GB VRAM each)
49
+
50
+ **CRITICAL REMINDER**: Data sequence preservation is the highest priority - any shuffling, reordering, or combining of entries will compromise model quality.
51
+
52
+ *Last Updated: 2025-03-09*
README.md CHANGED
@@ -1,12 +1,283 @@
1
- ---
2
- title: Mindmodel Phi4 Unsupervised
3
- emoji: 📈
4
- colorFrom: blue
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 5.20.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ ---
2
+ title: Phi-4 Unsloth Training
3
+ emoji: 🧠
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 5.17.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ # Phi-4 Unsloth Optimized Training
14
+
15
+ This Space is dedicated to training Microsoft's Phi-4 model with Unsloth optimizations for enhanced performance and efficiency. The training process uses 4-bit quantization and advanced memory optimizations.
16
+
17
+ ## Installation
18
+
19
+ This Hugging Face Space automatically installs dependencies from requirements.txt. The sections below describe the staged installation process and the essential packages involved:
20
+
21
+ ### Installation Process
22
+
23
+ For clearer dependency management, the installation is split into multiple files:
24
+
25
+ 1. **Base Dependencies (requirements-base.txt)**:
26
+ - Core packages like torch, transformers, accelerate, etc.
27
+ - Install with: `pip install -r requirements-base.txt`
28
+
29
+ 2. **Standard Dependencies (requirements.txt)**:
30
+ - References base requirements and adds additional packages
31
+ - Install with: `pip install -r requirements.txt`
32
+
33
+ 3. **Flash Attention (requirements-flash.txt)** (Optional):
34
+ - For faster attention computation
35
+ - Install with: `pip install -r requirements-flash.txt --no-build-isolation`
36
+
37
+ Using this staged approach helps prevent dependency conflicts and installation issues.
38
+
39
+ ### Essential Dependencies
40
+
41
+ - **unsloth** (>=2024.3): Required for optimized 4-bit training
42
+ - **peft** (>=0.9.0): Required for parameter-efficient fine-tuning
43
+ - **transformers** (>=4.36.0): For model architecture and tokenization
44
+ - **einops**: Required by Unsloth for tensor manipulation
45
+ - **sentencepiece**: Required for tokenization
46
+
47
+ ### Optional Dependencies
48
+
49
+ - **flash-attn**: Optional for faster attention computation (not included by default as it can cause build issues)
50
+
51
+ ## Features
52
+
53
+ - 4-bit quantization using Unsloth
54
+ - Optimized training pipeline
55
+ - Cognitive dataset integration
56
+ - Advanced memory management
57
+ - Gradient checkpointing
58
+ - Sequential data processing
59
+
60
+ ## Configuration Files
61
+
62
+ - `transformers_config.json`: Model and training parameters
63
+ - `hardware_config.json`: Hardware-specific optimizations
64
+ - `dataset_config.json`: Dataset processing settings
65
+ - `requirements.txt`: Required dependencies
66
+
67
+ ## Training Process
68
+
69
+ The training utilizes the following optimizations:
70
+ - Unsloth's 4-bit quantization
71
+ - Custom chat templates for Phi-4
72
+ - Paper-order preservation
73
+ - Efficient memory usage
74
+ - Gradient accumulation
75
+
76
+ ## Dataset
77
+
78
+ Training uses the cognitive dataset with:
79
+ - Maintained paper order
80
+ - Proper metadata handling
81
+ - Optimized sequence length
82
+ - Efficient batching
83
+
84
+ ## Hardware Requirements
85
+
86
+ - GPU: A10G or better
87
+ - VRAM: 24GB minimum
88
+ - RAM: 32GB recommended
89
+
90
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
91
+
92
+ # Phase 1: Domain Adaptation (Unsupervised)
93
+
94
+ This directory contains the code and configuration for domain adaptation of the phi-4-unsloth-bnb-4bit model to the cognitive science domain. This phase produces our domain-adapted model: [George-API/phi-4-research-assistant](https://huggingface.co/George-API/phi-4-research-assistant).
95
+
96
+ ## Overview
97
+
98
+ Domain adaptation is the first phase of our training process, where we expose the model to a large corpus of cognitive science texts to help it learn domain-specific vocabulary, concepts, and patterns. This phase prepares the model for the more focused supervised fine-tuning in Phase 2.
99
+
100
+ ## Files
101
+
102
+ ### Core Training Files
103
+ - `run_transformers_training.py`: Main script for domain adaptation
104
+ - `transformers_config.json`: Model and training parameters
105
+ - `hardware_config.json`: Hardware-specific optimizations
106
+ - `dataset_config.json`: Dataset loading and processing settings
107
+ - `requirements.txt`: Required Python packages
108
+
109
+ ### Analysis & Utilities
110
+ - `check_tokenization.py`: Script to analyze token distributions
111
+ - `update_space.py`: Hugging Face Space update utility
112
+ - `.env`: Environment variables (API tokens, etc.)
113
+
114
+ ## Setup
115
+
116
+ 1. **Environment Setup**:
117
+ ```bash
118
+ python -m venv venv
119
+ source venv/bin/activate # or `venv\Scripts\activate` on Windows
120
+ pip install -r requirements.txt
121
+ ```
122
+
123
+ 2. **Environment Variables**:
124
+ Create `.env` file with:
125
+ ```
126
+ HUGGINGFACE_TOKEN=your_token_here
127
+ ```
128
+
129
+ 3. **Verify Setup**:
130
+ ```bash
131
+ python check_tokenization.py # Ensures tokenizer works
132
+ ```
133
+
134
+ ## How It Works
135
+
136
+ 1. **Data Loading**: Loads pre-tokenized data from the Hugging Face dataset
137
+ 2. **Sequential Processing**: Processes data in order, maintaining the integrity of research papers
138
+ 3. **Efficient Training**: Uses pre-quantized Unsloth 4-bit model for memory-efficient and faster training
139
+ 4. **Checkpointing**: Saves regular checkpoints and pushes to Hub
140
+ 5. **Monitoring**: Logs detailed metrics and statistics during training
141
+ 6. **Model Publishing**: Pushes the trained model to Hugging Face Hub
142
+
143
+ ## Key Features
144
+
145
+ ### Memory-Efficient Training
146
+
147
+ The training setup is optimized for A10G GPUs:
148
+ - Uses pre-quantized 4-bit model (no additional quantization needed)
149
+ - Gradient checkpointing for memory efficiency
150
+ - Flash attention for faster training
151
+ - bfloat16 mixed precision training
152
+ - Optimized batch sizes for maximum throughput
153
+
154
+ ### Sequential Processing
155
+
156
+ The training script ensures that chunks from the same research paper are processed together by:
157
+ - Sorting the dataset by ID
158
+ - Using a SequentialSampler to maintain order
159
+ - Processing chunks sequentially (average 1,673 tokens per chunk)
160
+
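+ A minimal sketch of that ordering logic, assuming an `id` column and a placeholder dataset name:
+
+ ```python
+ from datasets import load_dataset
+ from torch.utils.data import DataLoader, SequentialSampler
+
+ # Load the pre-tokenized dataset and sort by paper ID so chunks from the same
+ # paper stay adjacent (the "id" column name is an assumption).
+ dataset = load_dataset("your-username/cognitive-dataset", split="train")
+ dataset = dataset.sort("id")
+
+ # A SequentialSampler visits indices 0..N-1 in order, so nothing is shuffled.
+ # In the actual run the Trainer builds this DataLoader and plugs in the collator.
+ loader = DataLoader(dataset, batch_size=16, sampler=SequentialSampler(dataset))
+ ```
+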
161
+ ### Data Collator
162
+
163
+ The `SimpleDataCollator` class:
164
+ - Preserves pre-tokenized data format
165
+ - Processes each entry independently
166
+ - Provides detailed logging of processing statistics
167
+ - Handles errors gracefully
168
+
169
+ ### Checkpointing
170
+
171
+ The training process saves checkpoints:
172
+ - Every 200 steps
173
+ - Pushes to Hub on every save
174
+ - Maintains up to 5 recent checkpoints
175
+ - Automatically resumes from the latest checkpoint if interrupted
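+ A minimal sketch of the `TrainingArguments` fields that produce this behavior (the Hub repo name is a placeholder):
+
+ ```python
+ from transformers import TrainingArguments
+
+ args = TrainingArguments(
+     output_dir="./results",
+     save_strategy="steps",
+     save_steps=200,               # checkpoint every 200 steps
+     save_total_limit=5,           # keep only the 5 most recent checkpoints
+     push_to_hub=True,
+     hub_model_id="your-username/phi-4-research-assistant",  # placeholder repo
+     hub_strategy="every_save",    # push to the Hub on every save
+ )
+ ```
+
+ Resuming after an interruption is then a matter of calling `trainer.train(resume_from_checkpoint=True)`.
+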
176
+
177
+ ## Hardware Requirements
178
+
179
+ This training setup is optimized for:
180
+ - 2x NVIDIA A10G GPUs (24GB VRAM each)
181
+ - 92GB System RAM
182
+ - CUDA 11.8 or higher
183
+
184
+ Memory breakdown per GPU:
185
+ - Model (4-bit): ~3.5GB
186
+ - Optimizer states: ~1GB
187
+ - Batch memory: ~2GB
188
+ - Peak usage: 18-20GB
189
+ - Safe headroom: 4-6GB
190
+
191
+ ## Configuration
192
+
193
+ Key parameters in `transformers_config.json`:
194
+
195
+ - `model_name`: unsloth/phi-4-unsloth-bnb-4bit
196
+ - `learning_rate`: 2e-5
197
+ - `num_train_epochs`: 3
198
+ - `per_device_train_batch_size`: 16
199
+ - `gradient_accumulation_steps`: 4
200
+ - `effective_batch_size`: 128 (16 * 4 * 2 GPUs)
201
+ - `max_seq_length`: 2048
202
+ - `lr_scheduler_type`: "cosine"
203
+ - `warmup_ratio`: 0.03
204
+ - `neftune_noise_alpha`: 5
205
+
206
+ The configuration is optimized for:
207
+ - Maximum memory efficiency with pre-quantized model
208
+ - Stable training with cosine learning rate schedule
209
+ - Effective gradient updates with accumulation
210
+ - Regular checkpointing and Hub updates
211
+
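+ A minimal sketch of how these values might sit in `transformers_config.json`; the grouping mirrors what `run_transformers_training.py` and `app.py` read, but the exact schema of your file may differ:
+
+ ```json
+ {
+   "model_name": "unsloth/phi-4-unsloth-bnb-4bit",
+   "bf16": true,
+   "seed": 42,
+   "training": {
+     "learning_rate": 2e-5,
+     "num_train_epochs": 3,
+     "per_device_train_batch_size": 16,
+     "gradient_accumulation_steps": 4,
+     "lr_scheduler_type": "cosine",
+     "warmup_ratio": 0.03,
+     "neftune_noise_alpha": 5,
+     "gradient_checkpointing": true
+   },
+   "tokenizer": {
+     "max_seq_length": 2048
+   },
+   "unsloth": {
+     "r": 32,
+     "alpha": 16,
+     "dropout": 0.05
+   }
+ }
+ ```
+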
212
+ ## Running Domain Adaptation
213
+
214
+ To start domain adaptation:
215
+
216
+ ```bash
217
+ python run_transformers_training.py
218
+ ```
219
+
220
+ The script will:
221
+ 1. Load the pre-quantized model and dataset
222
+ 2. Apply optimized training parameters
223
+ 3. Process the data sequentially
224
+ 4. Train the model for 3 epochs
225
+ 5. Save and push checkpoints to Hub regularly
226
+
227
+ ## Using the Model
228
+
229
+ After training, you can use the domain-adapted model:
230
+
231
+ ```python
232
+ from transformers import AutoModelForCausalLM, AutoTokenizer
233
+
234
+ # Load the domain-adapted model
235
+ model_name = "George-API/phi-4-research-assistant"
236
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
237
+ model = AutoModelForCausalLM.from_pretrained(model_name,
238
+ device_map="auto",
239
+ torch_dtype="bfloat16")
240
+
241
+ # Generate text
242
+ input_text = "The hippocampus is involved in"
243
+ inputs = tokenizer(input_text, return_tensors="pt")
244
+ outputs = model.generate(**inputs, max_length=100)
245
+ print(tokenizer.decode(outputs[0], skip_special_tokens=True))
246
+ ```
247
+
248
+ ## Chat Format Example
249
+
250
+ Phi-4 works best with its native chat template:
251
+
252
+ ```python
253
+ from transformers import pipeline
254
+
255
+ pipe = pipeline(
256
+ "text-generation",
257
+ model="George-API/phi-4-research-assistant",
258
+ model_kwargs={"torch_dtype": "bfloat16"},
259
+ device_map="auto",
260
+ )
261
+
262
+ messages = [
263
+ {"role": "system", "content": "You are an expert in cognitive science."},
264
+ {"role": "user", "content": "Explain the role of the hippocampus in memory formation."},
265
+ ]
266
+
267
+ outputs = pipe(messages, max_new_tokens=256)
268
+ print(outputs[0]["generated_text"])
269
+ ```
270
+
271
+ ## Expected Outcomes
272
+
273
+ After domain adaptation, the model should:
274
+ - Have a better understanding of cognitive science terminology
275
+ - Show improved performance on domain-specific tasks
276
+ - Be ready for supervised fine-tuning in Phase 2
277
+
278
+ ## Next Steps
279
+
280
+ After completing domain adaptation:
281
+ 1. Evaluate the model's performance on cognitive science texts
282
+ 2. Proceed to Phase 2 (Supervised Fine-Tuning)
283
+ 3. Use TensorBoard to analyze training metrics
app.py ADDED
@@ -0,0 +1,180 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ import os
5
+ import sys
6
+ import json
7
+ import logging
8
+ import subprocess
9
+ import time
10
+ from datetime import datetime
11
+
12
+ # Configure logging to match HF Space logs
13
+ logging.basicConfig(
14
+ level=logging.INFO,
15
+ format="%(asctime)s - %(levelname)s - %(message)s",
16
+ handlers=[logging.StreamHandler(sys.stdout)]
17
+ )
18
+ logger = logging.getLogger(__name__)
19
+
20
+ # Set other loggers to WARNING to reduce noise and ensure our logs are visible
21
+ logging.getLogger("transformers").setLevel(logging.WARNING)
22
+ logging.getLogger("datasets").setLevel(logging.WARNING)
23
+ logging.getLogger("accelerate").setLevel(logging.WARNING)
24
+ logging.getLogger("torch").setLevel(logging.WARNING)
25
+ logging.getLogger("bitsandbytes").setLevel(logging.WARNING)
26
+
27
+ # Define a clean logging function for HF Space compatibility
28
+ def log_info(message):
29
+ """Log information in a format compatible with Hugging Face Spaces"""
30
+ logger.info(message)
31
+ # Ensure output is flushed immediately for streaming
32
+ sys.stdout.flush()
33
+
34
+ # Configuration paths
35
+ CONFIG_DIR = "."
36
+ TRANSFORMERS_CONFIG = os.path.join(CONFIG_DIR, "transformers_config.json")
37
+
38
+ def load_config(config_path):
39
+ """Load configuration from JSON file."""
40
+ try:
41
+ if os.path.exists(config_path):
42
+ with open(config_path, 'r') as f:
43
+ return json.load(f)
44
+ else:
45
+ log_info(f"Config file not found: {config_path}")
46
+ return None
47
+ except Exception as e:
48
+ log_info(f"Error loading config: {str(e)}")
49
+ return None
50
+
51
+ def display_config():
52
+ """Display current training configuration."""
53
+ config = load_config(TRANSFORMERS_CONFIG)
54
+
55
+ if not config:
56
+ return "Error loading configuration file."
57
+
58
+ # Extract sub-configurations
59
+ transformers_config = config
60
+ hardware_config = config.get("hardware", {})
61
+ dataset_config = config.get("dataset", {})
62
+
63
+ model_name = transformers_config.get("model", {}).get("name") or transformers_config.get("model_name_or_path", "")
64
+
65
+ # Training parameters
66
+ training_config = transformers_config.get("training", {})
67
+ batch_size = training_config.get("per_device_train_batch_size", 16)
68
+ grad_accum = training_config.get("gradient_accumulation_steps", 3)
69
+ epochs = training_config.get("num_train_epochs", 3)
70
+ learning_rate = training_config.get("learning_rate", 2e-5)
71
+
72
+ # Hardware settings
73
+ gpu_count = hardware_config.get("specs", {}).get("gpu_count", 4)
74
+ gpu_type = hardware_config.get("specs", {}).get("gpu_type", "L4")
75
+ vram = hardware_config.get("specs", {}).get("vram_per_gpu", 24)
76
+
77
+ # Dataset info
78
+ dataset_name = dataset_config.get("dataset", {}).get("name", "")
79
+
80
+ # Format response as HTML for better display
81
+ html = f"""
82
+ <h2>Training Configuration</h2>
83
+ <h3>Model</h3>
84
+ <ul>
85
+ <li><b>Model:</b> {model_name}</li>
86
+ <li><b>Learning Rate:</b> {training_config.get('learning_rate', '2e-5')}</li>
87
+ <li><b>Per-Device Batch Size:</b> {batch_size}</li>
88
+ <li><b>Gradient Accumulation:</b> {grad_accum}</li>
89
+ <li><b>Total Effective Batch Size:</b> {batch_size} × {gpu_count} × {grad_accum} = {batch_size * gpu_count * grad_accum}</li>
90
+ <li><b>Epochs:</b> {epochs}</li>
91
+ <li><b>Precision:</b> {'BF16' if transformers_config.get('bf16', True) else 'FP16' if transformers_config.get('fp16', False) else 'FP32'}</li>
92
+ <li><b>Max Sequence Length:</b> {transformers_config.get('tokenizer', {}).get('max_seq_length', 2048)}</li>
93
+ </ul>
94
+
95
+ <h3>Hardware</h3>
96
+ <ul>
97
+ <li><b>GPU:</b> {gpu_count}× {gpu_type} ({vram} GB VRAM per GPU, total: {vram * gpu_count} GB)</li>
98
+ <li><b>Multi-GPU Strategy:</b> {hardware_config.get('training_optimizations', {}).get('multi_gpu_strategy', 'data_parallel')}</li>
99
+ <li><b>Memory Optimizations:</b> {'Gradient Checkpointing' if hardware_config.get('training_optimizations', {}).get('memory_optimizations', {}).get('use_gradient_checkpointing', True) else 'None'}</li>
100
+ </ul>
101
+
102
+ <h3>Dataset</h3>
103
+ <ul>
104
+ <li><b>Dataset:</b> {dataset_name}</li>
105
+ <li><b>Dataset Split:</b> {dataset_config.get('dataset', {}).get('split', 'train')}</li>
106
+ </ul>
107
+ """
108
+
109
+ return html
110
+
111
+ def start_training():
112
+ """Start the training process."""
113
+ try:
114
+ # Log configuration check
115
+ log_info("Preparing to start training process...")
116
+ log_info("Using consolidated configuration from transformers_config.json")
117
+
118
+ # Start training
119
+ log_info("Starting training process...")
120
+
121
+ # Run in a background process for HF Space
122
+ cmd = "python run_transformers_training.py"
123
+
124
+ # In HF Spaces, we don't need to handle process management ourselves
125
+ subprocess.Popen(cmd, shell=True, stdout=sys.stdout, stderr=sys.stderr)
126
+
127
+ log_info("Training process has been started. You can monitor progress in the logs.")
128
+
129
+ return "Training started successfully. Monitor progress in the Hugging Face Space logs."
130
+
131
+ except Exception as e:
132
+ error_msg = f"Error starting training: {str(e)}"
133
+ log_info(error_msg)
134
+ return error_msg
135
+
136
+ # Interface setup for gradio
137
+ def create_interface():
138
+ import gradio as gr
139
+
140
+ with gr.Blocks(title="Phi-4 Training Center") as demo:
141
+ gr.Markdown("# Phi-4 Research Assistant Training")
142
+
143
+ with gr.Row():
144
+ with gr.Column():
145
+ gr.Markdown("## Control Panel")
146
+
147
+ # Display current config
148
+ config_html = gr.HTML(display_config())
149
+ refresh_btn = gr.Button("Refresh Configuration")
150
+
151
+ # Training controls
152
+ train_btn = gr.Button("Start Training", variant="primary")
153
+ train_output = gr.Textbox(label="Status", interactive=False)
154
+
155
+ with gr.Column():
156
+ gr.Markdown("## Training Information")
157
+ gr.Markdown("""
158
+ ### Hardware:
159
+ - 4× NVIDIA L4 GPUs (24GB VRAM per GPU, 96GB total)
160
+ - Training with BF16 precision
161
+ - Using Data Parallel for multi-GPU
162
+ - Effective batch size: 16 (per device) × 4 (GPUs) × 3 (gradient accumulation) = 192
163
+
164
+ ### Notes:
165
+ - Training may take several hours depending on dataset size
166
+ - Check the Space logs for real-time progress
167
+ - Model checkpoints will be saved to ./results directory
168
+ """)
169
+
170
+ # Connect buttons to functions
171
+ refresh_btn.click(lambda: gr.update(value=display_config()), outputs=config_html)
172
+ train_btn.click(start_training, outputs=train_output)
173
+
174
+ return demo
175
+
176
+ if __name__ == "__main__":
177
+ # If run directly, create and launch the Gradio interface
178
+ demo = create_interface()
179
+ demo.queue()
180
+ demo.launch()
install_requirements.bat ADDED
@@ -0,0 +1,35 @@
1
+ @echo off
2
+ echo Installing Phi-4 Training Requirements
3
+ echo =====================================
4
+ echo.
5
+
6
+ REM Check if Python is available
7
+ where python >nul 2>&1
8
+ if %ERRORLEVEL% neq 0 (
9
+ echo Python not found! Please make sure Python is installed and in your PATH.
10
+ exit /b 1
11
+ )
12
+
13
+ echo Step 1: Installing base requirements...
14
+ python -m pip install -r requirements-base.txt
15
+ if %ERRORLEVEL% neq 0 (
16
+ echo Failed to install base requirements.
17
+ exit /b 1
18
+ )
19
+ echo Base requirements installed successfully.
20
+ echo.
21
+
22
+ echo Step 2: Installing additional requirements...
23
+ python -m pip install -r requirements.txt
24
+ if %ERRORLEVEL% neq 0 (
25
+ echo Failed to install additional requirements.
26
+ exit /b 1
27
+ )
28
+ echo Additional requirements installed successfully.
29
+ echo.
30
+
31
+ echo All required packages installed successfully!
32
+ echo To install optional flash-attention, run: python -m pip install -r requirements-flash.txt --no-build-isolation
33
+ echo.
34
+
35
+ pause
install_requirements.py ADDED
@@ -0,0 +1,83 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ """
5
+ Script to install requirements in the correct order for the Phi-4 training project.
6
+ This ensures base requirements are installed first, followed by additional requirements.
7
+ """
8
+
9
+ import os
10
+ import sys
11
+ import subprocess
12
+ import argparse
13
+ import logging
14
+ from pathlib import Path
15
+
16
+ # Configure logging
17
+ logging.basicConfig(
18
+ level=logging.INFO,
19
+ format="%(asctime)s - %(levelname)s - %(message)s",
20
+ handlers=[logging.StreamHandler(sys.stdout)]
21
+ )
22
+ logger = logging.getLogger(__name__)
23
+
24
+ def install_requirements(include_flash=False):
25
+ """Install requirements in the correct order."""
26
+ current_dir = Path(__file__).parent
27
+ base_req_path = current_dir / "requirements-base.txt"
28
+ main_req_path = current_dir / "requirements.txt"
29
+ flash_req_path = current_dir / "requirements-flash.txt"
30
+
31
+ if not base_req_path.exists():
32
+ logger.error(f"Base requirements file not found: {base_req_path}")
33
+ return False
34
+
35
+ if not main_req_path.exists():
36
+ logger.error(f"Main requirements file not found: {main_req_path}")
37
+ return False
38
+
39
+ logger.info("Installing dependencies in sequential order...")
40
+
41
+ try:
42
+ # Step 1: Install base requirements
43
+ logger.info(f"Step 1: Installing base requirements from {base_req_path}")
44
+ subprocess.run([sys.executable, "-m", "pip", "install", "-r", str(base_req_path)],
45
+ check=True)
46
+ logger.info("Base requirements installed successfully")
47
+
48
+ # Step 2: Install main requirements
49
+ logger.info(f"Step 2: Installing additional requirements from {main_req_path}")
50
+ subprocess.run([sys.executable, "-m", "pip", "install", "-r", str(main_req_path)],
51
+ check=True)
52
+ logger.info("Additional requirements installed successfully")
53
+
54
+ # Step 3: Optionally install flash-attention
55
+ if include_flash and flash_req_path.exists():
56
+ logger.info(f"Step 3: Installing flash-attention from {flash_req_path}")
57
+ subprocess.run([sys.executable, "-m", "pip", "install", "-r", str(flash_req_path), "--no-build-isolation"],
58
+ check=True)
59
+ logger.info("Flash-attention installed successfully")
60
+ elif include_flash:
61
+ logger.warning(f"Flash requirements file not found: {flash_req_path}")
62
+
63
+ logger.info("All required packages installed successfully!")
64
+ return True
65
+
66
+ except subprocess.CalledProcessError as e:
67
+ logger.error(f"Error installing dependencies: {str(e)}")
68
+ return False
69
+
70
+ def main():
71
+ parser = argparse.ArgumentParser(description="Install requirements for Phi-4 training")
72
+ parser.add_argument("--flash", action="store_true", help="Also install flash-attention (optional)")
73
+ args = parser.parse_args()
74
+
75
+ success = install_requirements(include_flash=args.flash)
76
+ if success:
77
+ logger.info("Installation completed successfully!")
78
+ else:
79
+ logger.error("Installation failed. Please check the logs for details.")
80
+ sys.exit(1)
81
+
82
+ if __name__ == "__main__":
83
+ main()
requirements-base.txt ADDED
@@ -0,0 +1,8 @@
1
+ torch>=2.0.0
2
+ accelerate>=0.27.0
3
+ bitsandbytes>=0.41.0
4
+ datasets>=2.15.0
5
+ gradio>=5.17.0
6
+ huggingface-hub>=0.19.0
7
+ tensorboard>=2.15.0
8
+ transformers>=4.36.0
requirements-flash.txt ADDED
@@ -0,0 +1,2 @@
1
+ -r requirements-base.txt
2
+ flash-attn==2.5.2
requirements.txt ADDED
@@ -0,0 +1,17 @@
1
+ -r requirements-base.txt
2
+ einops>=0.7.0
3
+ filelock>=3.13.1
4
+ matplotlib>=3.7.0
5
+ numpy>=1.24.0
6
+ packaging>=23.0
7
+ peft>=0.9.0
8
+ psutil>=5.9.0
9
+ python-dotenv>=1.0.0
10
+ pyyaml>=6.0.1
11
+ regex>=2023.0.0
12
+ requests>=2.31.0
13
+ safetensors>=0.4.1
14
+ sentencepiece>=0.1.99
15
+ tqdm>=4.65.0
16
+ typing-extensions>=4.8.0
17
+ unsloth>=2024.3
run_transformers_training.py ADDED
@@ -0,0 +1,964 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ # Basic Python imports
5
+ import os
6
+ import sys
7
+ import json
8
+ import argparse
9
+ import logging
10
+ from datetime import datetime
11
+ import time
12
+ import warnings
13
+ from importlib.util import find_spec
14
+
15
+ # Check hardware capabilities first
16
+ import torch
17
+ CUDA_AVAILABLE = torch.cuda.is_available()
18
+ NUM_GPUS = torch.cuda.device_count() if CUDA_AVAILABLE else 0
19
+ DEVICE_TYPE = "cuda" if CUDA_AVAILABLE else "cpu"
20
+
21
+ # Configure logging early
22
+ logging.basicConfig(
23
+ level=logging.INFO,
24
+ format="%(asctime)s - %(levelname)s - %(message)s",
25
+ handlers=[logging.StreamHandler(sys.stdout)]
26
+ )
27
+ logger = logging.getLogger(__name__)
28
+
29
+ # Set other loggers to WARNING to reduce noise and ensure our logs are visible
30
+ logging.getLogger("transformers").setLevel(logging.WARNING)
31
+ logging.getLogger("datasets").setLevel(logging.WARNING)
32
+ logging.getLogger("accelerate").setLevel(logging.WARNING)
33
+ logging.getLogger("torch").setLevel(logging.WARNING)
34
+ logging.getLogger("bitsandbytes").setLevel(logging.WARNING)
35
+
36
+ # Import Unsloth first, before other ML imports
37
+ try:
38
+ from unsloth import FastLanguageModel
39
+ from unsloth.chat_templates import get_chat_template
40
+ unsloth_available = True
41
+ logger.info("Unsloth successfully imported")
42
+ except ImportError:
43
+ unsloth_available = False
44
+ logger.warning("Unsloth not available. Please install with: pip install unsloth")
45
+
46
+ # Now import other ML libraries
47
+ try:
48
+ import transformers
49
+ from transformers import (
50
+ AutoModelForCausalLM,
51
+ AutoTokenizer,
52
+ TrainingArguments,
53
+ Trainer,
54
+ TrainerCallback,
55
+ set_seed,
56
+ BitsAndBytesConfig
57
+ )
58
+ logger.info(f"Transformers version: {transformers.__version__}")
59
+ except ImportError:
60
+ logger.error("Transformers not available. This is a critical dependency.")
61
+
62
+ # Check availability of libraries
63
+ peft_available = find_spec("peft") is not None
64
+ if peft_available:
65
+ import peft
66
+ logger.info(f"PEFT version: {peft.__version__}")
67
+ else:
68
+ logger.warning("PEFT not available. Parameter-efficient fine-tuning will not be used.")
69
+
70
+ # Import datasets library after the main ML libraries
71
+ try:
72
+ from datasets import load_dataset
73
+ logger.info("Datasets library successfully imported")
74
+ except ImportError:
75
+ logger.error("Datasets library not available. This is required for loading training data.")
76
+
77
+ # Define a clean logging function for HF Space compatibility
78
+ def log_info(message):
79
+ """Log information in a format compatible with Hugging Face Spaces"""
80
+ # Just use the logger, but ensure consistent formatting
81
+ logger.info(message)
82
+ # Also ensure output is flushed immediately for streaming
83
+ sys.stdout.flush()
84
+
85
+ # Check for BitsAndBytes
86
+ try:
87
+ from transformers import BitsAndBytesConfig
88
+ bitsandbytes_available = True
89
+ except ImportError:
90
+ bitsandbytes_available = False
91
+ logger.warning("BitsAndBytes not available. 4-bit quantization will not be used.")
92
+
93
+ # Check for PEFT
94
+ try:
95
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
96
+ peft_available = True
97
+ except ImportError:
98
+ peft_available = False
99
+ logger.warning("PEFT not available. Parameter-efficient fine-tuning will not be used.")
100
+
101
+ def load_env_variables():
102
+ """Load environment variables from system, .env file, or Hugging Face Space variables."""
103
+ # Check if we're running in a Hugging Face Space
104
+ if os.environ.get("SPACE_ID"):
105
+ logging.info("Running in Hugging Face Space")
106
+
107
+ # Log the presence of variables (without revealing values)
108
+ logging.info(f"HF_TOKEN available: {bool(os.environ.get('HF_TOKEN'))}")
109
+ logging.info(f"HF_USERNAME available: {bool(os.environ.get('HF_USERNAME'))}")
110
+
111
+ # If username is not set, try to extract from SPACE_ID
112
+ if not os.environ.get("HF_USERNAME") and "/" in os.environ.get("SPACE_ID", ""):
113
+ username = os.environ.get("SPACE_ID").split("/")[0]
114
+ os.environ["HF_USERNAME"] = username
115
+ logging.info(f"Set HF_USERNAME from SPACE_ID: {username}")
116
+ else:
117
+ # Try to load from .env file if not in a Space
118
+ try:
119
+ from dotenv import load_dotenv
120
+ # Updated path to .env file in the new directory structure
121
+ env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "shared", ".env")
122
+ if os.path.exists(env_path):
123
+ load_dotenv(env_path)
124
+ logging.info(f"Loaded environment variables from {env_path}")
125
+ logging.info(f"HF_TOKEN loaded from .env file: {bool(os.environ.get('HF_TOKEN'))}")
126
+ logging.info(f"HF_USERNAME loaded from .env file: {bool(os.environ.get('HF_USERNAME'))}")
127
+ logging.info(f"HF_SPACE_NAME loaded from .env file: {bool(os.environ.get('HF_SPACE_NAME'))}")
128
+ else:
129
+ logging.warning(f"No .env file found at {env_path}")
130
+ except ImportError:
131
+ logging.warning("python-dotenv not installed, not loading from .env file")
132
+
133
+ if not os.environ.get("HF_USERNAME"):
134
+ logger.warning("HF_USERNAME is not set. Using default username.")
135
+
136
+ if not os.environ.get("HF_SPACE_NAME"):
137
+ logger.warning("HF_SPACE_NAME is not set. Using default space name.")
138
+
139
+ # Set HF_TOKEN for huggingface_hub
140
+ if os.environ.get("HF_TOKEN"):
141
+ os.environ["HUGGING_FACE_HUB_TOKEN"] = os.environ.get("HF_TOKEN")
142
+
143
+ def load_configs(base_path):
144
+ """Load configuration from transformers_config.json file."""
145
+ # Using a single consolidated config file
146
+ config_file = base_path
147
+
148
+ try:
149
+ with open(config_file, "r") as f:
150
+ config = json.load(f)
151
+ logger.info(f"Loaded configuration from {config_file}")
152
+ return config
153
+ except Exception as e:
154
+ logger.error(f"Error loading {config_file}: {e}")
155
+ raise
156
+
157
+ def parse_args():
158
+ parser = argparse.ArgumentParser(description="Fine-tune a language model on a text dataset")
159
+ parser.add_argument("--config", type=str, default="transformers_config.json", help="Path to configuration file")
160
+ return parser.parse_args()
161
+
162
+ def load_model_and_tokenizer(config):
163
+ """Load model and tokenizer with proper error handling and optimizations."""
164
+ try:
165
+ if not unsloth_available:
166
+ logger.error("Unsloth is required for training with pre-quantized model")
167
+ logger.error("Please ensure unsloth is in requirements.txt")
168
+ raise ImportError("Unsloth is required for this training setup")
169
+
170
+ # Get model name correctly from config
171
+ model_name = config.get("model_name") or config.get("model", {}).get("name")
172
+ logger.info(f"Loading model: {model_name}")
173
+
174
+ if not model_name:
175
+ raise ValueError("Model name not found in configuration. Please check your transformers_config.json file.")
176
+
177
+ logger.info("Using Unsloth optimizations with pre-quantized model")
178
+
179
+ # First detect if we have a GPU
180
+ if torch.cuda.is_available():
181
+ gpu_count = torch.cuda.device_count()
182
+ logger.info(f"Found {gpu_count} CUDA devices")
183
+ else:
184
+ logger.warning("No CUDA devices detected. Training will be slow on CPU!")
185
+ gpu_count = 0
186
+
187
+ # Set default dtype for better numerics
188
+ if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
189
+ # Use bfloat16 for Ampere or newer
190
+ dtype = torch.bfloat16
191
+ logger.info("Using bfloat16 precision (Ampere+ GPU)")
192
+ elif torch.cuda.is_available():
193
+ # Use float16 for older GPUs
194
+ dtype = torch.float16
195
+ logger.info("Using float16 precision (pre-Ampere GPU)")
196
+ else:
197
+ # CPU, use default dtype
198
+ dtype = None
199
+ logger.info("Using default precision (CPU)")
200
+
201
+ # Check for flash attention as the last dependency check
202
+ use_flash_attention = config.get("use_flash_attention", True)
203
+ if use_flash_attention and not find_spec("flash_attn"):
204
+ logger.warning("flash-attn not found. Will continue without flash attention.")
205
+ logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
206
+ use_flash_attention = False
207
+
208
+ # Set device map based on config or default to "auto"
209
+ device_map = config.get("hardware", {}).get("hardware_setup", {}).get("device_map", "auto")
210
+
211
+ # Calculate max memory settings if multiple GPUs are available
212
+ max_memory = None
213
+ if gpu_count > 1:
214
+ memory_per_gpu = config.get("hardware", {}).get("specs", {}).get("vram_per_gpu", 24)
215
+ max_memory = {i: f"{int(memory_per_gpu * 0.85)}GiB" for i in range(gpu_count)}
216
+ max_memory["cpu"] = "64GiB" # Allow CPU offloading if needed
217
+
218
+ # Load model with proper error handling for out-of-memory
219
+ try:
220
+ # Improved memory settings for multi-GPU setup
221
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
222
+
223
+ model, tokenizer = FastLanguageModel.from_pretrained(
224
+ model_name=model_name,
225
+ max_seq_length=config.get("max_seq_length", 2048) or config.get("tokenizer", {}).get("max_seq_length", 2048),
226
+ dtype=dtype,
227
+ device_map=device_map,
228
+ max_memory=max_memory,
229
+ # Don't explicitly use flash attention config here, let Unsloth handle it
230
+ )
231
+ except RuntimeError as e:
232
+ if "CUDA out of memory" in str(e):
233
+ logger.error("Out of GPU memory. Consider using a smaller batch size or gradient accumulation steps.")
234
+ raise
235
+ else:
236
+ # Try again with CPU placement to see if it's a memory issue
237
+ logger.warning(f"Error loading model on default device: {str(e)}")
238
+ logger.warning("Attempting to load with device_map='cpu' and no specific dtype")
239
+ model, tokenizer = FastLanguageModel.from_pretrained(
240
+ model_name=model_name,
241
+ max_seq_length=config.get("max_seq_length", 2048) or config.get("tokenizer", {}).get("max_seq_length", 2048),
242
+ dtype=None,
243
+ device_map={"": "cpu"},
244
+ )
245
+ logger.warning("Model loaded on CPU. Training will be very slow.")
246
+
247
+ # Ensure model and optimizer init is on the same device
248
+ logger.info(f"Model device map: {model.hf_device_map if hasattr(model, 'hf_device_map') else 'Not available'}")
249
+
250
+ # Apply Unsloth's training optimizations with config parameters
251
+ unsloth_config = config.get("unsloth", {})
252
+ model = FastLanguageModel.get_peft_model(
253
+ model,
254
+ r=unsloth_config.get("r", 32),
255
+ target_modules=unsloth_config.get("target_modules",
256
+ ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]),
257
+ lora_alpha=unsloth_config.get("alpha", 16),
258
+ lora_dropout=unsloth_config.get("dropout", 0.05),
259
+ bias="none",
260
+ use_gradient_checkpointing=config.get("gradient_checkpointing", True) or config.get("training", {}).get("gradient_checkpointing", True),
261
+ random_state=config.get("seed", 42),
262
+ )
263
+ logger.info("Unsloth optimizations applied successfully")
264
+
265
+ # Set up tokenizer settings
266
+ chat_template = config.get("chat_template") or config.get("tokenizer", {}).get("chat_template")
267
+ if chat_template:
268
+ try:
269
+ template = get_chat_template("phi")
270
+ tokenizer.chat_template = template
271
+ logger.info("Set phi chat template")
272
+ except Exception as e:
273
+ logger.warning(f"Failed to set chat template: {str(e)}")
274
+
275
+ # Ensure proper token settings
276
+ if tokenizer.pad_token_id is None:
277
+ tokenizer.pad_token_id = tokenizer.eos_token_id
278
+ logger.info(f"Set pad_token_id to eos_token_id: {tokenizer.pad_token_id}")
279
+
280
+ return model, tokenizer
281
+
282
+ except Exception as e:
283
+ logger.error(f"Error in model/tokenizer loading: {str(e)}")
284
+ logger.error("If missing dependencies, check the requirements.txt file")
285
+ raise
286
+
287
+ def load_dataset_with_mapping(dataset_config):
288
+ """Load dataset and apply appropriate column mappings."""
289
+ try:
290
+ # Load dataset
291
+ dataset_name = dataset_config.get("dataset", {}).get("name", "")
292
+ dataset_split = dataset_config.get("dataset", {}).get("split", "train")
293
+
294
+ if not dataset_name:
295
+ raise ValueError("Dataset name not provided in configuration")
296
+
297
+ logger.info(f"Loading pre-processed dataset {dataset_name}, split {dataset_split}")
298
+ dataset = load_dataset(dataset_name, split=dataset_split)
299
+
300
+ # Apply minimal processing since the dataset has already been properly structured
301
+ # Just perform validation to ensure required fields exist
302
+
303
+ # Check for required fields
304
+ required_fields = ["prompt_number", "article_id", "conversations"]
305
+ missing_fields = [field for field in required_fields if field not in dataset.column_names]
306
+
307
+ if missing_fields:
308
+ logger.warning(f"Dataset is missing required fields: {missing_fields}")
309
+ logger.warning("This may cause issues with sequence integrity and metadata management")
310
+ else:
311
+ logger.info(f"Dataset has all required fields: {required_fields}")
312
+
313
+ # Verify that column order matches our expectation
314
+ expected_order = ["prompt_number", "article_id", "conversations"]
315
+ actual_order = dataset.column_names
316
+
317
+ if actual_order == expected_order:
318
+ logger.info("Dataset column order matches expected order (prompt_number, article_id, conversations)")
319
+ else:
320
+ logger.warning(f"Dataset column order ({', '.join(actual_order)}) differs from expected order ({', '.join(expected_order)})")
321
+ logger.warning("This should not affect processing but is noted for debugging purposes")
322
+
323
+ # Log a few samples for verification
324
+ if len(dataset) > 0:
325
+ sample_indices = range(min(5, len(dataset)))
326
+ sample_records = []
327
+
328
+ for i in sample_indices:
329
+ record = {}
330
+ record["prompt_number"] = dataset[i].get("prompt_number", "N/A")
331
+ record["article_id"] = dataset[i].get("article_id", "N/A")
332
+ if "conversations" in dataset[i]:
333
+ record["conversations_length"] = len(dataset[i]["conversations"])
334
+ sample_records.append(record)
335
+
336
+ logger.info(f"Sample records: {sample_records}")
337
+
338
+ # Verify sequential integrity
339
+ if "prompt_number" in dataset.column_names and len(dataset) > 1:
340
+ first_prompt_numbers = [dataset[i]["prompt_number"] for i in range(min(10, len(dataset)))]
341
+ is_sequential = all(first_prompt_numbers[i] == i + 1 for i in range(len(first_prompt_numbers)))
342
+
343
+ if is_sequential:
344
+ logger.info("Dataset prompt numbers are sequential (1-indexed) - sequence integrity preserved")
345
+ else:
346
+ logger.warning("Dataset prompt numbers are not sequential - sequence integrity may be compromised")
347
+ logger.info(f"First few prompt numbers: {first_prompt_numbers}")
348
+
349
+ logger.info(f"Dataset loaded successfully with {len(dataset)} examples")
350
+ logger.info(f"Dataset columns: {dataset.column_names}")
351
+
352
+ # Data loading configuration - ensure shuffle is disabled
353
+ data_loading_config = dataset_config.get("data_loading", {})
354
+ if data_loading_config.get("shuffle", False):
355
+ logger.error("CRITICAL: shuffle is enabled in the dataset config!")
356
+ logger.error("This will RANDOMIZE your dataset and break sequential order.")
357
+ logger.error("Setting shuffle to False to preserve order")
358
+ data_loading_config["shuffle"] = False
359
+
360
+ return dataset
361
+
362
+ except Exception as e:
363
+ logger.error(f"Error loading dataset: {str(e)}")
364
+ raise
365
+
366
+ def format_phi_chat(messages, dataset_config):
367
+ """Format messages according to phi-4's chat template and dataset config."""
368
+ formatted_chat = ""
369
+
370
+ # Get role templates from config
371
+ roles = dataset_config.get("data_formatting", {}).get("roles", {
372
+ "system": "System: {content}\n\n",
373
+ "human": "Human: {content}\n\n",
374
+ "user": "Human: {content}\n\n",
375
+ "assistant": "Assistant: {content}\n\n"
376
+ })
377
+
378
+ # Handle research introduction metadata first
379
+ metadata = next((msg for msg in messages if isinstance(msg, dict) and
380
+ "[RESEARCH INTRODUCTION]" in msg.get("content", "")), None)
381
+ if metadata:
382
+ system_template = roles.get("system", "System: {content}\n\n")
383
+ formatted_chat = system_template.format(content=metadata['content'])
384
+ messages = [msg for msg in messages if msg != metadata]
385
+
386
+ # Process remaining messages
387
+ for message in messages:
388
+ if not isinstance(message, dict) or "content" not in message:
389
+ logger.warning(f"Skipping invalid message format: {message}")
390
+ continue
391
+
392
+ role = message.get("role", "").lower()
393
+ content = message.get("content", "")
394
+
395
+ # Format based on role
396
+ if role == "human" or role == "user":
397
+ template = roles.get("user", roles.get("human", "Human: {content}\n\n"))
398
+ formatted_chat += template.format(content=content)
399
+ elif role == "assistant" or role == "bot":
400
+ template = roles.get("assistant", "Assistant: {content}\n\n")
401
+ formatted_chat += template.format(content=content)
402
+ elif role == "system":
403
+ # For system messages, prepend them
404
+ template = roles.get("system", "System: {content}\n\n")
405
+ formatted_chat = template.format(content=content) + formatted_chat
406
+ else:
407
+ # Default to system for unknown roles
408
+ logger.warning(f"Unknown role '{role}' - treating as system message")
409
+ template = roles.get("system", "System: {content}\n\n")
410
+ formatted_chat += template.format(content=content)
411
+
412
+ return formatted_chat.strip()
413
+
414
+ class SimpleDataCollator:
415
+ def __init__(self, tokenizer, dataset_config):
416
+ self.tokenizer = tokenizer
417
+ self.dataset_config = dataset_config
418
+ self.stats = {"processed": 0, "skipped": 0, "total_tokens": 0}
419
+ self.pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
420
+ self.max_seq_length = dataset_config.get("dataset", {}).get("processing", {}).get("max_seq_length", 2048)
421
+ logger.info(f"SimpleDataCollator initialized - using pre-audited dataset with max_seq_length={self.max_seq_length}")
422
+ logger.info("Using exact dataset structure without reformatting")
423
+
424
+ # Check if we're on GPU
425
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
426
+ logger.info(f"SimpleDataCollator using device: {self.device}")
427
+
428
+ def __call__(self, features):
429
+ """Process examples preserving exact JSONL structure"""
430
+ batch = {"input_ids": [], "attention_mask": [], "labels": []}
431
+
432
+ for example in features:
433
+ try:
434
+ # Get ID
435
+ paper_id = example.get("id", "")
436
+
437
+ # Get conversations - these should already contain role and content
438
+ conversations = example.get("conversations", [])
439
+ if not conversations:
440
+ self.stats["skipped"] += 1
441
+ continue
442
+
443
+ # Directly use the conversations array as input to the model's chat template
444
+ # This preserves the exact structure with roles and content as they are
445
+ try:
446
+ # Let tokenizer handle the content with the model's chat template
447
+ inputs = self.tokenizer.apply_chat_template(
448
+ conversations,
449
+ return_tensors=None,
450
+ add_generation_prompt=False
451
+ )
452
+ except Exception as chat_error:
453
+ # Fallback if apply_chat_template fails
454
+ logger.warning(f"Chat template application failed for example {paper_id}: {str(chat_error)[:100]}")
455
+
456
+ # Create a basic representation of the conversation
457
+ conversation_text = ""
458
+ for msg in conversations:
459
+ if isinstance(msg, dict) and 'content' in msg:
460
+ conversation_text += msg.get('content', '') + "\n\n"
461
+
462
+ # Basic tokenization
463
+ inputs = self.tokenizer(
464
+ conversation_text,
465
+ add_special_tokens=True,
466
+ return_tensors=None
467
+ )["input_ids"]  # extract the token id list so the length checks below operate on tokens
468
+
469
+ # Apply length cap if needed (shouldn't be necessary for pre-audited data)
470
+ if self.max_seq_length > 0 and len(inputs) > self.max_seq_length:
471
+ logger.warning(f"Example {paper_id} exceeds max_seq_length ({len(inputs)} > {self.max_seq_length})")
472
+ inputs = inputs[:self.max_seq_length]
473
+
474
+ # Create attention mask (1 for all tokens)
475
+ attention_mask = [1] * len(inputs)
476
+
477
+ if len(inputs) > 0:
478
+ # For causal language modeling, labels are the same as inputs
479
+ labels = inputs.copy()
480
+
481
+ batch["input_ids"].append(inputs)
482
+ batch["attention_mask"].append(attention_mask)
483
+ batch["labels"].append(labels)
484
+
485
+ self.stats["processed"] += 1
486
+ self.stats["total_tokens"] += len(inputs)
487
+
488
+ # Debug logging for first few examples
489
+ log_samples = self.dataset_config.get("validation", {}).get("log_samples", 3)
490
+ if self.stats["processed"] <= log_samples:
491
+ logger.info(f"Example {self.stats['processed']}:")
492
+ logger.info(f"Paper ID: {paper_id}")
493
+ logger.info(f"Token count: {len(inputs)}")
494
+ logger.info(f"Conversation entries: {len(conversations)}")
495
+ else:
496
+ self.stats["skipped"] += 1
497
+ except Exception as e:
498
+ logger.warning(f"Error processing example: {str(e)[:100]}...")
499
+ logger.warning(f"Problematic example ID: {example.get('id', 'unknown')}")
500
+ self.stats["skipped"] += 1
501
+ continue
502
+
503
+ if not batch["input_ids"]:
504
+ logger.warning("Empty batch, returning dummy tensors")
505
+ return {
506
+ "input_ids": torch.zeros((1, 1), dtype=torch.long),
507
+ "attention_mask": torch.zeros((1, 1), dtype=torch.long),
508
+ "labels": torch.zeros((1, 1), dtype=torch.long)
509
+ }
510
+
511
+ # Pad the batch
512
+ max_length = max(len(ids) for ids in batch["input_ids"])
513
+
514
+ for i in range(len(batch["input_ids"])):
515
+ padding_length = max_length - len(batch["input_ids"][i])
516
+ if padding_length > 0:
517
+ batch["input_ids"][i].extend([self.pad_token_id] * padding_length)
518
+ batch["attention_mask"][i].extend([0] * padding_length)
519
+ batch["labels"][i].extend([-100] * padding_length)
520
+
521
+ # Convert to tensors
522
+ batch = {k: torch.tensor(v, dtype=torch.long) for k, v in batch.items()}
523
+
524
+ # Log stats periodically
525
+ log_interval = self.dataset_config.get("validation", {}).get("log_interval", 100)
526
+ if self.stats["processed"] % log_interval == 0 and self.stats["processed"] > 0:
527
+ logger.info(f"Data collator stats: processed={self.stats['processed']}, "
528
+ f"skipped={self.stats['skipped']}, "
529
+ f"avg_tokens={self.stats['total_tokens']/self.stats['processed']:.1f}")
530
+
531
+ return batch
532
+
533
+ class LoggingCallback(TrainerCallback):
534
+ def __init__(self):
535
+ super().__init__()
536
+ self.training_started = time.time()
537
+ self.last_log_time = time.time()
538
+ self.last_step = 0
539
+ self.verify_sequence = None
540
+ self.sequence_samples = None
541
+ self.sample_indices = None
542
+
543
+ def on_train_begin(self, args, state, control, **kwargs):
544
+ log_info(f"=== Training started at {time.strftime('%Y-%m-%d %H:%M:%S')} ===")
545
+ log_info(f"Model parameters: {sum(p.numel() for p in model.parameters())/1e6:.2f}M")
546
+
547
+ # Disable sequence verification
548
+ self.verify_sequence = False
549
+
550
+ log_info("=== Training is starting ===")
551
+
552
+ # Log important training parameters for visibility
553
+ total_batch_size = args.per_device_train_batch_size * args.gradient_accumulation_steps * NUM_GPUS
554
+ total_steps = int(len(dataset) / (args.per_device_train_batch_size * NUM_GPUS * args.gradient_accumulation_steps) * args.num_train_epochs)
555
+ log_info(f"Training plan: {len(dataset)} examples over {args.num_train_epochs} epochs ≈ {total_steps} steps")
556
+ log_info(f"Batch size: {args.per_device_train_batch_size} × {args.gradient_accumulation_steps} steps × {NUM_GPUS} GPUs = {total_batch_size} total")
557
+ log_info(f"Learning rate: {args.learning_rate}")
558
+ log_info(f"Epochs: {args.num_train_epochs}")
559
+
560
+ # Log memory information in compact format
561
+ if CUDA_AVAILABLE:
562
+ memory_info = []
563
+ for i in range(NUM_GPUS):
564
+ allocated = torch.cuda.memory_allocated(i) / 1024**2
565
+ max_mem = torch.cuda.max_memory_allocated(i) / 1024**2
566
+ memory_info.append(f"GPU {i}: {allocated:.1f}MB (max: {max_mem:.1f}MB)")
567
+
568
+ log_info(f"Initial memory usage - {', '.join(memory_info)}")
569
+
570
+ def on_step_end(self, args, state, control, **kwargs):
571
+ # Log every 50 steps or every 5 minutes, whichever comes first
572
+ current_time = time.time()
573
+
574
+ # Sequence verification removed
575
+
576
+ # Log progress at regular intervals
577
+ if (state.global_step % 50 == 0) or (current_time - self.last_log_time > 300):
578
+ if state.log_history:
579
+ loss = state.log_history[-1].get('loss', 'N/A')
580
+ # Use simple formatting for better Space log compatibility
581
+ log_info(f"Step {state.global_step}: Loss {loss}")
582
+ else:
583
+ log_info(f"Step {state.global_step}: No loss data available")
584
+ self.last_log_time = current_time
585
+
586
+ def on_train_end(self, args, state, control, **kwargs):
587
+ training_time = time.strftime("%H:%M:%S", time.gmtime(time.time() - self.training_started))
588
+ log_info(f"=== Training completed in {training_time} ===")
589
+
590
+ # Log final memory usage
591
+ if CUDA_AVAILABLE:
592
+ for i in range(NUM_GPUS):
593
+ max_mem = torch.cuda.max_memory_allocated(i) / 1024**3 # GB
594
+ log_info(f"GPU {i} max memory: {max_mem:.2f} GB")
595
+
596
+ # Clear GPU memory
597
+ torch.cuda.empty_cache()
598
+ log_info("GPU memory cleared")
599
+
600
+ log_info(f"Total steps: {state.global_step}")
601
+ log_info(f"Final loss: {state.log_history[-1].get('loss', 'N/A') if state.log_history else 'N/A'}")
602
+
603
+ def check_dependencies():
604
+ """Check if all required dependencies are installed and in the correct order."""
605
+ missing_packages = []
606
+ order_issues = []
607
+
608
+ # Check critical packages in the required order
609
+
610
+ # 1. First check for unsloth as it should be imported before transformers
611
+ if not unsloth_available:
612
+ missing_packages.append("unsloth>=2024.3")
613
+
614
+ # 2. Check transformers (imported at module level)
615
+ try:
616
+ import transformers
617
+ logger.info(f"Using transformers version {transformers.__version__}")
618
+ except ImportError:
619
+ missing_packages.append("transformers>=4.38.0")
620
+
621
+ # 3. Check for peft
622
+ if not peft_available:
623
+ missing_packages.append("peft>=0.9.0")
624
+
625
+ # 4. Check for accelerate
626
+ try:
627
+ import accelerate
628
+ logger.info(f"Using accelerate version {accelerate.__version__}")
629
+ except ImportError:
630
+ missing_packages.append("accelerate>=0.27.0")
631
+
632
+ # Check for order-specific issues
633
+ try:
634
+ import sys
635
+ modules = list(sys.modules.keys())  # list() so .index() below works on the ordered keys
636
+
637
+ # Unsloth should be imported before transformers for optimal performance
638
+ if 'transformers' in modules and 'unsloth' in modules:
639
+ if modules.index('transformers') < modules.index('unsloth'):
640
+ order_issues.append("For optimal performance, unsloth should be imported before transformers")
641
+ except Exception:
642
+ # If we can't check order, just skip this check
643
+ pass
644
+
645
+ # If critical packages are missing, exit with instructions
646
+ if missing_packages:
647
+ logger.error("Critical dependencies missing:")
648
+ for pkg in missing_packages:
649
+ logger.error(f" - {pkg}")
650
+ logger.error("Please install the missing dependencies with:")
651
+ logger.error(f" pip install {' '.join(missing_packages)}")
652
+ return False
653
+
654
+ # Report order issues as warnings
655
+ for issue in order_issues:
656
+ logger.warning(issue)
657
+
658
+ # Optional packages - moved to the end
659
+ if find_spec("flash_attn"):
660
+ logger.info("flash-attn found. Flash attention will be used for faster training.")
661
+ else:
662
+ logger.warning("flash-attn not found. Training will work but may be slower.")
663
+ logger.warning("To use flash attention, install with: pip install flash-attn --no-build-isolation")
664
+
665
+ # Additional optional packages that improve performance
666
+ if find_spec("bitsandbytes"):
667
+ logger.info("bitsandbytes found. Quantization will be available.")
668
+ else:
669
+ logger.warning("bitsandbytes not found. Quantization may not be available.")
670
+ logger.warning("To use quantization, install with: pip install bitsandbytes")
671
+
672
+ return True
673
+
674
+ def main():
675
+ # Set up logging
676
+ logger.info("Starting training process")
677
+
678
+ # Check dependencies first, before any other operations
679
+ if not check_dependencies():
680
+ logger.error("Aborting due to missing critical dependencies")
681
+ return 1
682
+
683
+ # Parse arguments
684
+ args = parse_args()
685
+
686
+ # Load environment variables
687
+ load_env_variables()
688
+
689
+ # Load configuration
690
+ try:
691
+ transformers_config = load_configs(args.config)
692
+ hardware_config = transformers_config.get("hardware", {})
693
+ dataset_config = transformers_config.get("dataset", {})
694
+ logger.info("Configuration loaded successfully")
695
+ except Exception as e:
696
+ logger.error(f"Error loading configuration: {e}")
697
+ return 1
698
+
699
+ # Check if we're in distributed mode
700
+ is_distributed = "WORLD_SIZE" in os.environ and int(os.environ.get("WORLD_SIZE", "1")) > 1
701
+ if is_distributed:
702
+ local_rank = int(os.environ.get("LOCAL_RANK", "0"))
703
+ log_info(f"Running in distributed mode with {os.environ.get('WORLD_SIZE')} processes, local_rank: {local_rank}")
704
+ else:
705
+ log_info("Running in non-distributed mode (single process)")
706
+
707
+ # Set random seed for reproducibility
708
+ seed = transformers_config.get("seed", 42)
709
+ set_seed(seed)
710
+ logger.info(f"Set random seed to {seed}")
711
+
712
+ # Load model and tokenizer using the consolidated config
713
+ model, tokenizer = load_model_and_tokenizer(transformers_config)
714
+
715
+ # Empty CUDA cache to ensure clean state
716
+ if CUDA_AVAILABLE:
717
+ torch.cuda.empty_cache()
718
+ log_info("Cleared CUDA cache")
719
+
720
+ # Setup environment variable for CUDA memory allocation
721
+ if CUDA_AVAILABLE:
722
+ system_settings = hardware_config.get("system_settings", {})
723
+ cuda_memory_fraction = system_settings.get("cuda_memory_fraction", 0.85)
724
+
725
+ if cuda_memory_fraction < 1.0:
726
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128,expandable_segments:True"
727
+ log_info("Set CUDA allocator to expandable segments with max_split_size_mb:128")
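+ # Note: cuda_memory_fraction only gates this branch; it is not enforced here.
+ # A hard per-process cap would need something like torch.cuda.set_per_process_memory_fraction().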
728
+
729
+ try:
730
+ log_info("Loading dataset...")
731
+ dataset = load_dataset_with_mapping(dataset_config)
732
+ log_info(f"Dataset loaded with {len(dataset)} examples")
733
+
734
+ # Minimal validation before proceeding
735
+ if dataset is None or len(dataset) == 0:
736
+ logger.error("Dataset is empty or None! Cannot proceed with training.")
737
+ return 1
738
+
739
+ # Create data collator
740
+ data_collator = SimpleDataCollator(tokenizer, dataset_config)
741
+
742
+ # Verify precision settings - ensure only one of bf16/fp16 is set, with bf16 taking precedence
743
+ # First check hardware config, then transformers config
744
+ use_bf16 = False
745
+ use_fp16 = False
746
+
747
+ # Check hardware config first
748
+ hardware_precision = hardware_config.get("training_optimizations", {}).get("mixed_precision", "")
749
+ if hardware_precision.lower() == "bf16":
750
+ use_bf16 = True
751
+ log_info("Using BF16 precision from hardware config")
752
+ elif hardware_precision.lower() == "fp16":
753
+ use_fp16 = True
754
+ log_info("Using FP16 precision from hardware config")
755
+ else:
756
+ # Fall back to transformers config
757
+ use_bf16 = transformers_config.get("bf16", False) or transformers_config.get("torch_dtype", "") == "bfloat16"
758
+ use_fp16 = transformers_config.get("fp16", False) and not use_bf16 # Only use fp16 if bf16 is not set
759
+ log_info(f"Using precision: {'bf16' if use_bf16 else 'fp16' if use_fp16 else 'full precision'}")
760
+
761
+ # Get per device batch size - from transformers config, but possibly overridden by hardware config
762
+ per_device_batch_size = transformers_config.get("training", {}).get("per_device_train_batch_size", 16)
763
+ gradient_accumulation_steps = transformers_config.get("training", {}).get("gradient_accumulation_steps", 3)
764
+
765
+ # Get multi-GPU strategy from hardware config (default to data_parallel)
766
+ multi_gpu_strategy = hardware_config.get("training_optimizations", {}).get("multi_gpu_strategy", "data_parallel")
767
+ logger.info(f"Multi-GPU strategy: {multi_gpu_strategy}")
768
+
769
+ # For multi-GPU setup, adjust for better balance
770
+ if CUDA_AVAILABLE and NUM_GPUS > 1:
771
+ log_info(f"Multi-GPU setup: Adjusting for {NUM_GPUS} GPUs")
772
+
773
+ # Set up FSDP for multi-GPU training if specified and in distributed mode
774
+ fsdp_config = None
775
+ if multi_gpu_strategy == "fsdp" and is_distributed and NUM_GPUS > 1:
776
+ try:
777
+ from torch.distributed.fsdp import (
778
+ FullyShardedDataParallel as FSDP,
779
+ MixedPrecision,
780
+ BackwardPrefetch,
781
+ ShardingStrategy,
782
+ CPUOffload,
783
+ )
784
+ from torch.distributed.fsdp.wrap import (
785
+ transformer_auto_wrap_policy,
786
+ enable_wrap,
787
+ wrap,
788
+ )
789
+
790
+ log_info("Using FSDP for distributed training")
791
+
792
+ # Configure FSDP
793
+ fsdp_config = {
794
+ "fsdp_transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],
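+ # Assumption: the decoder layer class name may differ for Phi-4 (often
+ # "Phi3DecoderLayer"); verify against the loaded model before relying on FSDP wrapping.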
795
+ "fsdp_offload_params": False,
796
+ "fsdp_backward_prefetch": "BACKWARD_PRE",
797
+ "fsdp_min_num_params": 1e6,
798
+ "fsdp_sharding_strategy": 1, # FULL_SHARD
799
+ }
800
+
801
+ if use_bf16 or use_fp16:
802
+ precision_type = "bf16" if use_bf16 else "fp16"
803
+ fsdp_config["fsdp_state_dict_type"] = "FULL_STATE_DICT"
804
+ log_info(f"FSDP using mixed precision: {precision_type}")
805
+ except ImportError:
806
+ log_info("FSDP imports failed, falling back to standard DDP")
807
+ fsdp_config = None
808
+ elif multi_gpu_strategy == "fsdp" and not is_distributed:
809
+ log_info("FSDP disabled: requires distributed environment (use torchrun or accelerate)")
810
+ log_info("Using DataParallel for multi-GPU training instead")
811
+ else:
812
+ log_info(f"Using {multi_gpu_strategy} for multi-GPU training")
813
+
814
+ # Get system settings from hardware config
815
+ dataloader_workers = hardware_config.get("system_settings", {}).get("dataloader_num_workers", 2)
816
+ pin_memory = hardware_config.get("system_settings", {}).get("dataloader_pin_memory", True)
817
+
818
+ # Set up training arguments
819
+ log_info("Setting up training arguments")
820
+ training_args = TrainingArguments(
821
+ output_dir=transformers_config.get("output_dir") or transformers_config.get("checkpointing", {}).get("output_dir", "./results"),
822
+ num_train_epochs=transformers_config.get("training", {}).get("num_train_epochs", 3),
823
+ per_device_train_batch_size=per_device_batch_size,
824
+ gradient_accumulation_steps=gradient_accumulation_steps,
825
+ learning_rate=transformers_config.get("training", {}).get("learning_rate", 2e-5),
826
+ weight_decay=transformers_config.get("training", {}).get("weight_decay", 0.01),
827
+ warmup_ratio=transformers_config.get("training", {}).get("warmup_ratio", 0.05),
828
+ lr_scheduler_type=transformers_config.get("training", {}).get("lr_scheduler_type", "cosine"),
829
+ logging_steps=transformers_config.get("training", {}).get("logging_steps", 10),
830
+ save_strategy=transformers_config.get("checkpointing", {}).get("save_strategy", "steps"),
831
+ save_steps=transformers_config.get("checkpointing", {}).get("save_steps", 100),
832
+ save_total_limit=transformers_config.get("checkpointing", {}).get("save_total_limit", 3),
833
+ fp16=use_fp16,
834
+ bf16=use_bf16,
835
+ max_grad_norm=transformers_config.get("training", {}).get("max_grad_norm", 1.0),
836
+ push_to_hub=transformers_config.get("huggingface_hub", {}).get("push_to_hub", False),
837
+ hub_model_id=transformers_config.get("huggingface_hub", {}).get("hub_model_id", None),
838
+ hub_token=os.environ.get("HF_TOKEN", None),
839
+ report_to="tensorboard",
840
+ remove_unused_columns=False, # Keep all columns
841
+ gradient_checkpointing=transformers_config.get("training", {}).get("gradient_checkpointing", True),
842
+ dataloader_pin_memory=pin_memory,
843
+ optim=transformers_config.get("training", {}).get("optim", "adamw_torch"),
844
+ ddp_find_unused_parameters=False, # Improve distributed training efficiency
845
+ dataloader_drop_last=False, # Process all examples
846
+ dataloader_num_workers=dataloader_workers,
847
+ no_cuda=not CUDA_AVAILABLE, # Use CUDA if available
848
+ # Only add FSDP if we're in distributed mode with FSDP strategy
849
+ fsdp=fsdp_config if is_distributed and multi_gpu_strategy == "fsdp" else None,
850
+ )
851
+
852
+ # Create sequential sampler to maintain original dataset order
853
+ sequential_sampler = torch.utils.data.SequentialSampler(dataset)
854
+
855
+ # Initialize trainer first
856
+ log_info("Initializing Trainer")
857
+ trainer = Trainer(
858
+ model=model,
859
+ args=training_args,
860
+ train_dataset=dataset, # We'll override this with our custom dataloader
861
+ data_collator=data_collator,
862
+ callbacks=[LoggingCallback()],
863
+ )
864
+
865
+ # Then override the get_train_dataloader method
866
+ def custom_get_train_dataloader():
867
+ """Custom dataloader that preserves original dataset order"""
868
+ log_info("Creating sequential dataloader to maintain original dataset order")
869
+
870
+ # Create a simple sequential sampler
871
+ sequential_sampler = torch.utils.data.SequentialSampler(dataset)
872
+
873
+ # Verification of sequence preservation flags - simplified
874
+ data_loading_config = dataset_config.get("data_loading", {})
875
+ shuffle_enabled = data_loading_config.get("shuffle", False)
876
+
877
+ if shuffle_enabled:
878
+ log_info("WARNING: Shuffle is enabled in configuration! This will be overridden to preserve order.")
879
+ # We enforce sequential processing regardless of config
880
+
881
+ # Log our approach clearly
882
+ log_info("Using SequentialSampler to guarantee dataset order is preserved based on prompt_number")
883
+
884
+ # Verify column order
885
+ expected_order = ["prompt_number", "article_id", "conversations"]
886
+ if hasattr(dataset, 'column_names'):
887
+ actual_order = dataset.column_names
888
+ if actual_order == expected_order:
889
+ log_info(f"Confirmed dataset columns are in expected order: {', '.join(expected_order)}")
890
+ else:
891
+ log_info(f"Note: Dataset columns ({', '.join(actual_order)}) are not in expected order ({', '.join(expected_order)})")
892
+ log_info("This is handled correctly by field-based access; noted here for clarity")
893
+
894
+ log_info("Dataset is pre-processed with prompt_number field indicating the correct sequence")
895
+
896
+ # Calculate batch size based on device availability
897
+ if getattr(training_args, "no_cuda", False):
898
+ batch_size = training_args.per_device_train_batch_size
899
+ else:
900
+ batch_size = max(training_args.per_device_train_batch_size * max(1, NUM_GPUS), 1)
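+ # In the default DataParallel-style setup the Trainer consumes one combined batch and
+ # splits it across GPUs, hence the multiplication by the GPU count here.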
901
+
902
+ log_info(f"Using sequential sampler with batch size {batch_size}")
903
+
904
+ # Return DataLoader with sequential sampler
905
+ return torch.utils.data.DataLoader(
906
+ dataset,
907
+ batch_size=batch_size,
908
+ sampler=sequential_sampler, # Always use sequential sampler
909
+ collate_fn=data_collator,
910
+ drop_last=training_args.dataloader_drop_last,
911
+ num_workers=training_args.dataloader_num_workers,
912
+ pin_memory=training_args.dataloader_pin_memory,
913
+ )
914
+
915
+ # Override the get_train_dataloader method
916
+ trainer.get_train_dataloader = custom_get_train_dataloader
917
+
918
+ # Start training
919
+ log_info("=== Starting Training ===")
920
+ try:
921
+ # Empty cache again right before training
922
+ if CUDA_AVAILABLE:
923
+ torch.cuda.empty_cache()
924
+ log_info("Cleared CUDA cache before training")
925
+
926
+ # Display compact training info
927
+ total_steps = int(len(dataset) / (per_device_batch_size * max(1, NUM_GPUS) * gradient_accumulation_steps) * training_args.num_train_epochs)  # max(1, NUM_GPUS) avoids div-by-zero on CPU-only runs
928
+ log_info(f"Training plan: {len(dataset)} examples over {training_args.num_train_epochs} epochs ≈ {total_steps} steps")
929
+
930
+ trainer.train()
931
+ log_info("Training completed successfully!")
932
+
933
+ # Save the final model
934
+ log_info("Saving final model...")
935
+ trainer.save_model()
936
+ log_info(f"Model saved to {training_args.output_dir}")
937
+
938
+ # Push to hub if enabled
939
+ if transformers_config.get("huggingface_hub", {}).get("push_to_hub", False):
940
+ hub_id = transformers_config.get("huggingface_hub", {}).get("hub_model_id", "model")
941
+ log_info(f"Pushing model to Hugging Face Hub as {hub_id}...")
942
+ trainer.push_to_hub()
943
+ log_info("Model successfully pushed to Hub")
944
+
945
+ return 0
946
+ except Exception as e:
947
+ logger.error(f"Training failed with error: {str(e)}")
948
+ # Log CUDA memory info if available in compact format
949
+ if CUDA_AVAILABLE:
950
+ memory_info = []
951
+ for i in range(NUM_GPUS):
952
+ allocated = torch.cuda.memory_allocated(i) / 1024**2
953
+ reserved = torch.cuda.memory_reserved(i) / 1024**2
954
+ max_mem = torch.cuda.max_memory_allocated(i) / 1024**2
955
+ memory_info.append(f"GPU {i}: {allocated:.1f}MB/{reserved:.1f}MB (max: {max_mem:.1f}MB)")
956
+ logger.error(f"GPU memory at failure: {', '.join(memory_info)}")
957
+ raise
958
+
959
+ except Exception as e:
960
+ logger.error(f"Error in main training loop: {str(e)}")
961
+ return 1
962
+
963
+ if __name__ == "__main__":
964
+ sys.exit(main())
transformers_config.json ADDED
@@ -0,0 +1,162 @@
1
+ {
2
+ "model": {
3
+ "name": "unsloth/phi-4-unsloth-bnb-4bit",
4
+ "trust_remote_code": true,
5
+ "use_fast_tokenizer": true
6
+ },
7
+
8
+ "tokenizer": {
9
+ "chat_template": "phi",
10
+ "max_seq_length": 2048,
11
+ "padding_side": "right",
12
+ "add_eos_token": true
13
+ },
14
+
15
+ "training": {
16
+ "per_device_train_batch_size": 16,
17
+ "gradient_accumulation_steps": 3,
18
+ "learning_rate": 2e-5,
19
+ "num_train_epochs": 3,
20
+ "max_steps": -1,
21
+ "logging_steps": 10,
22
+ "save_steps": 200,
23
+ "save_total_limit": 5,
24
+ "push_to_hub": true,
25
+ "hub_strategy": "every_save",
26
+ "gradient_checkpointing": true,
27
+ "optim": "adamw_torch",
28
+ "lr_scheduler_type": "cosine",
29
+ "warmup_ratio": 0.05,
30
+ "weight_decay": 0.01,
31
+ "max_grad_norm": 1.0,
32
+ "neftune_noise_alpha": 5,
33
+ "fp16": false,
34
+ "bf16": true
35
+ },
36
+
37
+ "checkpointing": {
38
+ "output_dir": "./results",
39
+ "save_strategy": "steps",
40
+ "save_steps": 100,
41
+ "save_total_limit": 3,
42
+ "hub_strategy": "every_save"
43
+ },
44
+
45
+ "unsloth": {
46
+ "enabled": true,
47
+ "r": 32,
48
+ "alpha": 16,
49
+ "dropout": 0.05,
50
+ "target_modules": [
51
+ "q_proj",
52
+ "k_proj",
53
+ "v_proj",
54
+ "o_proj",
55
+ "gate_proj",
56
+ "up_proj",
57
+ "down_proj"
58
+ ]
59
+ },
60
+
61
+ "distributed_training": {
62
+ "fsdp_config": {
63
+ "enabled": false,
64
+ "sharding_strategy": "FULL_SHARD",
65
+ "mixed_precision": "BF16",
66
+ "activation_checkpointing": true,
67
+ "offload_params": false
68
+ },
69
+ "ddp_find_unused_parameters": false,
70
+ "dataloader_num_workers": 2
71
+ },
72
+
73
+ "logging": {
74
+ "logging_steps": 50,
75
+ "log_level": "info"
76
+ },
77
+
78
+ "huggingface_hub": {
79
+ "push_to_hub": true,
80
+ "hub_model_id": "phi-4-cognitive-assistant",
81
+ "hub_private_repo": true
82
+ },
83
+
84
+ "model_name_or_path": "unsloth/phi-4-unsloth-bnb-4bit",
85
+ "model_revision": "main",
86
+ "use_flash_attention": true,
87
+ "torch_dtype": "bfloat16",
88
+ "bf16": true,
89
+ "fp16": false,
90
+
91
+ "hardware": {
92
+ "hardware_name": "4xL4",
93
+ "specs": {
94
+ "gpu_count": 4,
95
+ "gpu_type": "L4",
96
+ "vram_per_gpu": 24,
97
+ "total_vram": 96,
98
+ "vcpu_count": 48,
99
+ "ram": 186
100
+ },
101
+ "hardware_setup": {
102
+ "use_cpu": false,
103
+ "num_gpus": 4,
104
+ "device_map": "auto"
105
+ },
106
+ "training_optimizations": {
107
+ "per_device_batch_size": 16,
108
+ "gradient_accumulation_steps": 3,
109
+ "mixed_precision": "bf16",
110
+ "torch_compile": false,
111
+ "memory_optimizations": {
112
+ "use_gradient_checkpointing": true,
113
+ "use_flash_attention": true
114
+ },
115
+ "multi_gpu_strategy": "data_parallel"
116
+ },
117
+ "system_settings": {
118
+ "cuda_memory_fraction": 0.85,
119
+ "dataloader_num_workers": 2,
120
+ "dataloader_pin_memory": true
121
+ },
122
+ "memory_breakdown": {
123
+ "model_size": "~3.5GB (pre-quantized 4-bit)",
124
+ "optimizer_states": "~1GB",
125
+ "batch_memory_per_gpu": "~3GB",
126
+ "peak_memory_estimate": "~18GB",
127
+ "safe_headroom": "~6GB"
128
+ },
129
+ "compute_environment": "L4_CLOUD"
130
+ },
131
+
132
+ "dataset": {
133
+ "dataset": {
134
+ "name": "George-API/phi4-cognitive-dataset",
135
+ "split": "train"
136
+ },
137
+ "data_formatting": {
138
+ "chat_template": "phi",
139
+ "roles": {
140
+ "system": "System: {content}\n\n",
141
+ "human": "Human: {content}\n\n",
142
+ "assistant": "Assistant: {content}\n\n",
143
+ "user": "Human: {content}\n\n"
144
+ }
145
+ },
146
+ "data_loading": {
147
+ "batch_size": 24,
148
+ "shuffle": false,
149
+ "sequential_processing": true,
150
+ "drop_last": false,
151
+ "num_workers": 4,
152
+ "pin_memory": true,
153
+ "prefetch_factor": 4
154
+ },
155
+ "validation": {
156
+ "log_samples": 3,
157
+ "log_interval": 50,
158
+ "verify_sequence_integrity": true,
159
+ "metrics": ["processed", "skipped", "avg_tokens", "unique_articles"]
160
+ }
161
+ }
162
+ }
update_space.py ADDED
@@ -0,0 +1,278 @@
1
+ #!/usr/bin/env python
2
+
3
+ """
4
+ Quick script to update your Hugging Face Space for phi-4-unsloth-bnb-4bit training.
5
+ This script handles the specific requirements for the 4-bit quantized Phi-4 model training,
6
+ including proper configuration and dependency management.
7
+ """
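+ # Typical usage (flags defined in main() below):
+ # python update_space.py --space_name phi4training --force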
8
+
9
+ import os
10
+ import sys
11
+ import json
12
+ import subprocess
13
+ import argparse
14
+ import logging
15
+ from pathlib import Path
16
+ from huggingface_hub import HfApi, login
17
+ import getpass
18
+
19
+ # Configure logging
20
+ logging.basicConfig(
21
+ level=logging.INFO,
22
+ format="%(asctime)s - %(levelname)s - %(message)s",
23
+ handlers=[logging.StreamHandler(sys.stdout)]
24
+ )
25
+ logger = logging.getLogger(__name__)
26
+
27
+ def load_env_variables():
28
+ """Load environment variables from system or .env file."""
29
+ # First try to load from local .env file
30
+ try:
31
+ from dotenv import load_dotenv
32
+ env_path = Path(__file__).parent / ".env"
33
+ if env_path.exists():
34
+ # Load and explicitly set environment variables
35
+ with open(env_path) as f:
36
+ for line in f:
37
+ if line.strip() and not line.startswith('#') and '=' in line:
38
+ key, value = line.strip().split('=', 1)
39
+ os.environ[key] = value.strip()
40
+ logger.info(f"Loaded environment variables from {env_path}")
41
+ else:
42
+ logger.warning(f"No .env file found at {env_path}")
43
+ except ImportError:
44
+ logger.warning("python-dotenv not installed, skipping .env loading")
45
+
46
+ # Check if we're running in a Hugging Face Space
47
+ if os.environ.get("SPACE_ID"):
48
+ logger.info("Running in Hugging Face Space")
49
+ if "/" in os.environ.get("SPACE_ID", ""):
50
+ username = os.environ.get("SPACE_ID").split("/")[0]
51
+ os.environ["HF_USERNAME"] = username
52
+ logger.info(f"Set HF_USERNAME from SPACE_ID: {username}")
53
+
54
+ # Verify required variables
55
+ required_vars = {
56
+ "HF_TOKEN": os.environ.get("HF_TOKEN"),
57
+ "HF_USERNAME": os.environ.get("HF_USERNAME"),
58
+ "HF_SPACE_NAME": os.environ.get("HF_SPACE_NAME", "phi4training")
59
+ }
60
+
61
+ # Ensure the space name is set correctly
62
+ if "HF_SPACE_NAME" not in os.environ:
63
+ os.environ["HF_SPACE_NAME"] = "phi4training"
64
+
65
+ missing_vars = [k for k, v in required_vars.items() if not v]
66
+ if missing_vars:
67
+ raise ValueError(f"Missing required environment variables: {', '.join(missing_vars)}")
68
+
69
+ logger.info(f"Using environment variables: USERNAME={required_vars['HF_USERNAME']}, SPACE_NAME={required_vars['HF_SPACE_NAME']}")
70
+ return required_vars
71
+
72
+ def verify_configs():
73
+ """Verify that all necessary configuration files exist and are valid."""
74
+ current_dir = Path(__file__).parent
75
+ required_files = [
76
+ "transformers_config.json",
77
+ "requirements.txt",
78
+ "run_transformers_training.py"
79
+ ]
80
+
81
+ missing_files = []
82
+ for file in required_files:
83
+ if not (current_dir / file).exists():
84
+ missing_files.append(file)
85
+
86
+ if missing_files:
87
+ raise FileNotFoundError(f"Missing required files: {', '.join(missing_files)}")
88
+
89
+ # Verify JSON configs
90
+ json_files = [f for f in required_files if f.endswith('.json')]
91
+ for json_file in json_files:
92
+ try:
93
+ with open(current_dir / json_file) as f:
94
+ json.load(f)
95
+ logger.info(f"Verified {json_file} is valid JSON")
96
+ except json.JSONDecodeError as e:
97
+ raise ValueError(f"Invalid JSON in {json_file}: {e}")
98
+
99
+ def update_requirements():
100
+ """Update requirements.txt with necessary packages using a two-stage installation process."""
101
+ logger.info("Setting up requirements files for sequential installation...")
102
+ current_dir = Path(__file__).parent
103
+ base_req_path = current_dir / "requirements-base.txt"
104
+ main_req_path = current_dir / "requirements.txt"
105
+ flash_req_path = current_dir / "requirements-flash.txt"
106
+
107
+ # First ensure base requirements exist
108
+ required_base_packages = {
109
+ "torch>=2.0.0",
110
+ "transformers>=4.36.0",
111
+ "accelerate>=0.27.0",
112
+ "bitsandbytes>=0.41.0",
113
+ "tensorboard>=2.15.0",
114
+ "gradio>=5.17.0",
115
+ "huggingface-hub>=0.19.0",
116
+ "datasets>=2.15.0"
117
+ }
118
+
119
+ # Additional packages for main requirements
120
+ required_additional_packages = {
121
+ "einops>=0.7.0",
122
+ "filelock>=3.13.1",
123
+ "matplotlib>=3.7.0",
124
+ "numpy>=1.24.0",
125
+ "packaging>=23.0",
126
+ "peft>=0.9.0",
127
+ "psutil>=5.9.0",
128
+ "python-dotenv>=1.0.0",
129
+ "pyyaml>=6.0.1",
130
+ "regex>=2023.0.0",
131
+ "requests>=2.31.0",
132
+ "safetensors>=0.4.1",
133
+ "sentencepiece>=0.1.99",
134
+ "tqdm>=4.65.0",
135
+ "typing-extensions>=4.8.0",
136
+ "unsloth>=2024.3"
137
+ }
138
+
139
+ # Read existing base requirements
140
+ existing_requirements = set()
141
+ if base_req_path.exists():
142
+ with open(base_req_path) as f:
143
+ existing_requirements = {line.strip() for line in f if line.strip() and not line.startswith('-r')}
144
+
145
+ # Add new requirements
146
+ updated_requirements = existing_requirements.union(required_base_packages)
147
+
148
+ # 1. Write updated base requirements
149
+ with open(base_req_path, 'w') as f:
150
+ # Ensure torch is first
151
+ torch_req = next((req for req in updated_requirements if req.startswith("torch")), "torch>=2.0.0")
152
+ f.write(f"{torch_req}\n")
153
+
154
+ # Write all other requirements (excluding torch)
155
+ for req in sorted(r for r in updated_requirements if not r.startswith("torch")):
156
+ f.write(f"{req}\n")
157
+
158
+ # 2. Create main requirements file (references base)
159
+ with open(main_req_path, 'w') as f:
160
+ f.write("-r requirements-base.txt\n")
161
+ for req in sorted(required_additional_packages):
162
+ f.write(f"{req}\n")
163
+
164
+ # 3. Create or update flash-attn requirements
165
+ with open(flash_req_path, 'w') as f:
166
+ f.write("-r requirements-base.txt\n")
167
+ f.write("flash-attn==2.5.2\n")
168
+
169
+ logger.info("Updated requirements files for sequential installation:")
170
+ logger.info(f"1. Base requirements in {base_req_path}")
171
+ logger.info(f"2. Main requirements in {main_req_path}")
172
+ logger.info(f"3. Flash-attention requirements in {flash_req_path}")
173
+ logger.info("This ensures packages are installed in the correct order")
174
+
175
+ def create_space(username, space_name):
176
+ """Create or get a Hugging Face Space."""
177
+ try:
178
+ api = HfApi()
179
+ space_id = f"{username}/{space_name}"
180
+ logger.info(f"Checking Space {space_id}...")
181
+
182
+ # First try to get the space
183
+ try:
184
+ space_info = api.space_info(repo_id=space_id)
185
+ logger.info(f"Space {space_id} already exists")
186
+ return space_info
187
+ except Exception as e:
188
+ logger.info(f"Space {space_id} does not exist, creating new space...")
189
+
190
+ # Create new space
191
+ try:
192
+ api.create_repo(
193
+ repo_id=space_id,
194
+ private=False,
195
+ repo_type="space",
196
+ space_sdk="gradio"
197
+ )
198
+ logger.info(f"Created new space: {space_id}")
199
+ return api.space_info(repo_id=space_id)
200
+ except Exception as e:
201
+ logger.error(f"Failed to create space: {str(e)}")
202
+ raise
203
+ except Exception as e:
204
+ raise RuntimeError(f"Error with Space {space_id}: {str(e)}")
205
+
206
+ def main():
207
+ parser = argparse.ArgumentParser(description='Update Hugging Face Space for Phi-4 training')
208
+ parser.add_argument('--space_name', type=str, help='Space name (default: from env)')
209
+ parser.add_argument('--force', action='store_true', help='Skip confirmation')
210
+ args = parser.parse_args()
211
+
212
+ if not args.force:
213
+ print("\n" + "!"*80)
214
+ print("WARNING: Updating the Space will INTERRUPT any ongoing training!")
215
+ print("Make sure all checkpoints are saved before proceeding.")
216
+ print("!"*80 + "\n")
217
+
218
+ confirm = input("Type 'update' to confirm: ")
219
+ if confirm.lower() != 'update':
220
+ logger.info("Update cancelled")
221
+ return False
222
+
223
+ try:
224
+ # Load environment variables
225
+ env_vars = load_env_variables()
226
+ logger.info(f"Environment variables loaded: USERNAME={env_vars['HF_USERNAME']}, SPACE_NAME={env_vars['HF_SPACE_NAME']}")
227
+
228
+ # Verify configurations
229
+ verify_configs()
230
+ logger.info("All configuration files verified successfully")
231
+
232
+ # Update requirements
233
+ update_requirements()
234
+ logger.info("Requirements updated successfully")
235
+
236
+ # Get space name from args or env, prioritize args
237
+ space_name = args.space_name if args.space_name else env_vars["HF_SPACE_NAME"]
238
+ logger.info(f"Using space name: {space_name}")
239
+
240
+ # Login to Hugging Face
241
+ logger.info("Logging in to Hugging Face...")
242
+ login(token=env_vars["HF_TOKEN"])
243
+ logger.info("Successfully logged in to Hugging Face")
244
+
245
+ # Create/get space
246
+ space_info = create_space(env_vars["HF_USERNAME"], space_name)
247
+ logger.info(f"Space info: {space_info}")
248
+
249
+ # Upload files
250
+ current_dir = Path(__file__).parent
251
+ logger.info(f"Uploading files from {current_dir} to Space {env_vars['HF_USERNAME']}/{space_name}...")
252
+
253
+ # Create .gitignore
254
+ with open(current_dir / ".gitignore", "w") as f:
255
+ f.write(".env\n*.pyc\n__pycache__\n")
256
+ logger.info("Created .gitignore file")
257
+
258
+ api = HfApi()
259
+ api.upload_folder(
260
+ folder_path=str(current_dir),
261
+ repo_id=f"{env_vars['HF_USERNAME']}/{space_name}",
262
+ repo_type="space",
263
+ ignore_patterns=[".env", "*.pyc", "__pycache__", "TRAINING_IN_PROGRESS.lock"]
264
+ )
265
+
266
+ logger.info("Files uploaded successfully")
267
+ space_url = f"https://huggingface.co/spaces/{env_vars['HF_USERNAME']}/{space_name}"
268
+ logger.info(f"Space URL: {space_url}")
269
+ print(f"\nSpace created successfully! You can view it at:\n{space_url}")
270
+ return True
271
+
272
+ except Exception as e:
273
+ logger.error(f"Error updating Space: {str(e)}")
274
+ return False
275
+
276
+ if __name__ == "__main__":
277
+ success = main()
278
+ sys.exit(0 if success else 1)