eBlessings committed on
Commit
a6899c7
·
verified ·
1 Parent(s): 8fe113f

Upload chkpt2.py

Files changed (1)
  1. chkpt2.py +625 -0
chkpt2.py ADDED
@@ -0,0 +1,625 @@
+ #!/usr/bin/env python3
+ """
+ app.py – Quranic Data Training Pipeline Endpoint for ZeroGPU Spaces
+ --------------------------------------------------------------------
+ This script integrates a full Quranic data processing and training pipeline
+ into a Gradio interface endpoint. It is optimized for CPU-based training on
+ Hugging Face ZeroGPU (using the Gradio SDK) and uses chunked incremental
+ training, memory management, and gradient checkpointing to efficiently update
+ Google's Gemma-2-2b model with Quranic data.
+
+ After every training chunk, the checkpoint folder is compressed into a ZIP file
+ and a download link is generated. Progress messages are streamed as the pipeline
+ runs.
+
+ Requirements:
+ - Transformers (>=4.42.0)
+ - Gradio (>=5.12.0)
+ - PyTorch (==2.2.2)
+ - psutil (==5.9.5)
+ - Accelerate (>=0.26.0)
+ - Hugging Face PRO subscription with ZeroGPU enabled (set your HF token in the HF_TOKEN environment variable)
+ - Linux CPU host with access to ZeroGPU hardware via Spaces
+ - Input data files placed in the project root
+ - Sufficient storage in "working_directory"
+
+ Author: [M-Saddam Hussain]
+ Date: March 2025
+ Data References: [Tanzil.net, IslamSource, QuranicCorpus]
+ """
+
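+ # The three source files named in QuranicDataProcessor.load_source_files below
+ # must sit next to this script:
+ #   quranic-corpus-morphology-0.4.txt  (tab-separated morphology entries)
+ #   en.sample.quran-maududi.txt        (pipe-separated verse translations)
+ #   en.w4w.qurandev.txt                (pipe-separated word-by-word glosses)
+ # All outputs (processed verses, checkpoints, logs, state) land under
+ # "working_directory".
+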
+ import json
+ import logging
+ import os
+ import traceback
+ import gc
+ import time
+ import psutil
+ import math
+ import zipfile
+ from datetime import datetime
+ from typing import Dict, List, Optional
+ from dataclasses import dataclass, asdict
+
+ import torch
+ # Limit PyTorch threads for CPU stability.
+ torch.set_num_threads(8)
+
+ from torch.utils.data import Dataset
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForCausalLM,
+     TrainingArguments,
+     Trainer,
+     DataCollatorForLanguageModeling,
+     __version__ as transformers_version
+ )
+ from threading import Lock
+
+ import gradio as gr
+ import spaces
+
+ # Configure logging first so that the handlers below take effect before
+ # anything (including the version check) emits a log record.
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s',
+     handlers=[
+         logging.FileHandler('pipeline.log'),
+         logging.StreamHandler()
+     ]
+ )
+ logger = logging.getLogger(__name__)
+
+ # Check for the minimum required Transformers version for 'gemma2' support.
+ # Only leading numeric components are compared, so dev/rc suffixes such as
+ # "4.43.0.dev0" do not raise a ValueError.
+ MIN_TRANSFORMERS_VERSION = "4.42.0"
+
+ def _version_tuple(version: str) -> tuple:
+     parts = []
+     for part in version.split("."):
+         if not part.isdigit():
+             break
+         parts.append(int(part))
+     return tuple(parts)
+
+ if _version_tuple(transformers_version) < _version_tuple(MIN_TRANSFORMERS_VERSION):
+     logger.warning(f"Transformers version {transformers_version} detected. Please upgrade to at least {MIN_TRANSFORMERS_VERSION} for proper support of the 'gemma2' architecture.")
+
+ def manage_memory(threshold_percent: int = 90, min_available_mb: int = 500, sleep_duration: int = 10):
+     """
+     Check memory usage; if usage is high or available memory is low,
+     force garbage collection and sleep briefly.
+     """
+     vm = psutil.virtual_memory()
+     used_percent = vm.percent
+     available_mb = vm.available / (1024 * 1024)
+     logger.info(f"Memory usage: {used_percent}% used, {available_mb:.2f} MB available")
+     if used_percent > threshold_percent or available_mb < min_available_mb:
+         logger.warning("High memory usage detected, forcing garbage collection and sleeping...")
+         gc.collect()
+         time.sleep(sleep_duration)
+
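+ # Example call with stricter (illustrative) thresholds than the defaults:
+ #   manage_memory(threshold_percent=85, min_available_mb=1024, sleep_duration=5)
+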
+ @dataclass
+ class WordAnalysis:
+     """Structured representation of word-level analysis."""
+     arabic: str
+     translation: str
+     position: str
+     morphology: Dict
+     features: List[str]
+     root: str
+     location: str
+     metadata: Dict
+
+ @dataclass
+ class VerseData:
+     """Structured representation of verse-level data."""
+     chapter: int
+     verse: int
+     arabic_text: str
+     translation: str
+     words: List[WordAnalysis]
+     metadata: Dict
+
+ class QuranicDataset(Dataset):
+     """Custom dataset for Quranic text training."""
+     def __init__(self, processed_data: List[Dict], tokenizer):
+         self.examples = []
+         self.tokenizer = tokenizer
+         for verse_data in processed_data:
+             self.examples.extend(self._create_training_examples(verse_data))
+
+     def _create_training_examples(self, verse_data: Dict) -> List[Dict]:
+         examples = []
+         text_block = (
+             f"[VERSE {verse_data['chapter']}:{verse_data['verse']}]\n"
+             f"Arabic: {verse_data['arabic_text']}\n"
+             f"Translation: {verse_data['translation']}\n"
+             "Morphological Analysis:\n"
+         )
+         for word in verse_data['words']:
+             text_block += (
+                 f"[WORD] {word['arabic']}\n"
+                 f"Root: {word['root']}\n"
+                 f"Features: {', '.join(word['features'])}\n"
+             )
+         examples.append(self._format_example(text_block))
+         return examples
+
+     def _format_example(self, text: str) -> Dict:
+         encodings = self.tokenizer(
+             text,
+             truncation=True,
+             max_length=64,
+             padding="max_length",
+             return_tensors="pt"
+         )
+         return {
+             "input_ids": encodings["input_ids"][0],
+             "attention_mask": encodings["attention_mask"][0]
+         }
+
+     def __len__(self):
+         return len(self.examples)
+
+     def __getitem__(self, idx):
+         return self.examples[idx]
+
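+ # Shape of one rendered training example (content illustrative; real verse text
+ # replaces the ellipses before tokenization truncates to 64 tokens):
+ #   [VERSE 1:1]
+ #   Arabic: ...
+ #   Translation: ...
+ #   Morphological Analysis:
+ #   [WORD] ...
+ #   Root: ...
+ #   Features: ...
+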
+ class QuranicDataProcessor:
+     """Processes Quranic data into structured training examples."""
+     def __init__(self, source_dir: str, output_dir: str):
+         self.source_dir = source_dir
+         self.output_dir = output_dir
+         self.morphological_data: Dict[str, Dict] = {}
+         self.word_by_word_data: Dict[str, List[str]] = {}
+         self.translation_data: Dict[str, str] = {}
+         self.processing_lock = Lock()
+         os.makedirs(output_dir, exist_ok=True)
+         os.makedirs(os.path.join(output_dir, 'json'), exist_ok=True)
+         os.makedirs(os.path.join(output_dir, 'txt'), exist_ok=True)
+         os.makedirs(os.path.join(output_dir, 'checkpoints'), exist_ok=True)
+         logger.info(f"Initialized processor with source dir: {source_dir}")
+
+     def load_source_files(self) -> bool:
+         """Loads morphological, translation, and word-by-word data from the project root."""
+         try:
+             logger.info("Loading morphological data...")
+             morph_path = os.path.join(self.source_dir, 'quranic-corpus-morphology-0.4.txt')
+             with open(morph_path, 'r', encoding='utf-8') as f:
+                 next(f)  # Skip the header line.
+                 for line in f:
+                     if line.strip() and not line.startswith('#'):
+                         parts = line.strip().split('\t')
+                         if len(parts) >= 4:
+                             location = parts[0].strip('()')
+                             self.morphological_data[location] = {
+                                 'form': parts[1],
+                                 'tag': parts[2],
+                                 'features': parts[3]
+                             }
+             logger.info(f"Loaded {len(self.morphological_data)} morphological entries")
+             logger.info("Loading translation data...")
+             trans_path = os.path.join(self.source_dir, 'en.sample.quran-maududi.txt')
+             with open(trans_path, 'r', encoding='utf-8') as f:
+                 next(f)  # Skip the header line.
+                 for line in f:
+                     if line.strip():
+                         parts = line.strip().split('|')
+                         if len(parts) >= 3:
+                             key = f"{parts[0]}:{parts[1]}"
+                             self.translation_data[key] = parts[2].strip()
+             logger.info(f"Loaded {len(self.translation_data)} verse translations")
+             logger.info("Loading word-by-word data...")
+             word_path = os.path.join(self.source_dir, 'en.w4w.qurandev.txt')
+             with open(word_path, 'r', encoding='utf-8-sig') as f:
+                 lines = [line.strip() for line in f if line.strip()]
+             sorted_keys = sorted(self.translation_data.keys(), key=lambda x: (int(x.split(':')[0]), int(x.split(':')[1])))
+             if len(lines) != len(sorted_keys):
+                 logger.warning("Mismatch between word-by-word file and translation data")
+             for i, verse_key in enumerate(sorted_keys):
+                 if i < len(lines):
+                     words = [w.strip() for w in lines[i].split('|') if w.strip()]
+                     self.word_by_word_data[verse_key] = words
+             logger.info(f"Loaded word-by-word data for {len(self.word_by_word_data)} verses")
+             return True
+         except Exception as e:
+             logger.error(f"Error loading source files: {str(e)}")
+             logger.error(traceback.format_exc())
+             return False
+
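+     # Assumed line formats for the source files, inferred from the parsing above
+     # (field contents are placeholders):
+     #   morphology:   (1:1:1:1)<TAB>form<TAB>tag<TAB>features
+     #   translation:  1|1|<verse translation>
+     #   word-by-word: <word 1>|<word 2>|...
+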
+     def process_verse(self, chapter: int, verse: int) -> Optional[VerseData]:
+         """Processes a single verse into structured format."""
+         try:
+             verse_ref = f"{chapter}:{verse}"
+             logger.info(f"Processing verse {verse_ref}")
+             translation = self.translation_data.get(verse_ref)
+             if not translation:
+                 logger.warning(f"No translation for verse {verse_ref}")
+                 return None
+             verse_word_list = self.word_by_word_data.get(verse_ref, [])
+             if not verse_word_list:
+                 logger.warning(f"No word-by-word data for verse {verse_ref}")
+                 return None
+             verse_words: List[WordAnalysis] = []
+             arabic_text = ""
+             for pos in range(1, len(verse_word_list) + 1):
+                 pattern = f"{chapter}:{verse}:{pos}:"
+                 matching_entries = [data for loc, data in self.morphological_data.items() if loc.startswith(pattern)]
+                 if not matching_entries:
+                     logger.debug(f"No morphological data for {pattern}")
+                     continue
+                 combined_form = " ".join(entry['form'] for entry in matching_entries)
+                 combined_features = []
+                 root = ""
+                 for entry in matching_entries:
+                     features = entry['features'].split('|')
+                     combined_features.extend(features)
+                     if not root:
+                         for f in features:
+                             if 'ROOT:' in f:
+                                 root = f.split('ROOT:')[1]
+                                 break
+                 word_translation = verse_word_list[pos - 1]
+                 word = WordAnalysis(
+                     arabic=combined_form,
+                     translation=word_translation,
+                     position=str(pos),
+                     morphology=matching_entries[0],
+                     features=combined_features,
+                     root=root,
+                     location=f"{chapter}:{verse}:{pos}",
+                     metadata={}
+                 )
+                 verse_words.append(word)
+                 arabic_text += f" {combined_form}"
+             verse_data = VerseData(
+                 chapter=chapter,
+                 verse=verse,
+                 arabic_text=arabic_text.strip(),
+                 translation=translation,
+                 words=verse_words,
+                 metadata={
+                     "processed_timestamp": datetime.now().isoformat(),
+                     "word_count": len(verse_words)
+                 }
+             )
+             self._save_verse_data(verse_data)
+             return verse_data
+         except Exception as e:
+             logger.error(f"Error processing verse {chapter}:{verse}: {str(e)}")
+             logger.error(traceback.format_exc())
+             return None
+
+     def _save_verse_data(self, verse_data: VerseData):
+         """Saves processed verse data as JSON and TXT."""
+         try:
+             verse_ref = f"{verse_data.chapter}:{verse_data.verse}"
+             json_path = os.path.join(self.output_dir, 'json', f'verse_{verse_ref.replace(":", "_")}.json')
+             with open(json_path, 'w', encoding='utf-8') as f:
+                 json.dump(asdict(verse_data), f, ensure_ascii=False, indent=2)
+             txt_path = os.path.join(self.output_dir, 'txt', f'verse_{verse_ref.replace(":", "_")}.txt')
+             with open(txt_path, 'w', encoding='utf-8') as f:
+                 f.write(f"=== Verse {verse_ref} ===\n\n")
+                 f.write(f"Arabic Text:\n{verse_data.arabic_text}\n\n")
+                 f.write(f"Translation:\n{verse_data.translation}\n\n")
+                 f.write("Word Analysis:\n")
+                 for i, word in enumerate(verse_data.words, 1):
+                     f.write(f"\nWord {i}:\n")
+                     f.write(f"  Arabic: {word.arabic}\n")
+                     f.write(f"  Translation: {word.translation}\n")
+                     f.write(f"  Root: {word.root}\n")
+                     f.write("  Features:\n")
+                     for feature in word.features:
+                         f.write(f"    - {feature}\n")
+                 f.write("\n")
+             logger.info(f"Saved verse data to {json_path} and {txt_path}")
+         except Exception as e:
+             logger.error(f"Error saving verse data: {str(e)}")
+             logger.error(traceback.format_exc())
+
+ def zip_directory(folder_path: str, zip_path: str):
+     """Compress the given folder into a zip file."""
+     with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zipf:
+         for root_dir, _, files in os.walk(folder_path):
+             for file in files:
+                 file_path = os.path.join(root_dir, file)
+                 arcname = os.path.relpath(file_path, folder_path)
+                 zipf.write(file_path, arcname)
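+
+ # Invoked below as, for example:
+ #   zip_directory("working_directory/checkpoints/chapter_1/chunk_0",
+ #                 "working_directory/checkpoints/chapter_1/chunk_0.zip")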
+
+ class QuranicModelTrainer:
+     """Trains the Gemma-2-2b model on Quranic data using chunked incremental updates."""
+     def __init__(self,
+                  model_name: str = "google/gemma-2-2b",
+                  processed_data_dir: str = "processed_data",
+                  checkpoint_dir: str = "checkpoints"):
+         self.processed_data_dir = processed_data_dir
+         self.checkpoint_dir = checkpoint_dir
+         self.device = "cpu"  # Training on CPU; ZeroGPU will handle GPU access.
+         logger.info("Loading tokenizer and model...")
+
+         # Load tokenizer with additional special tokens and HF token from environment
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             model_name,
+             token=os.environ.get("HF_TOKEN"),
+             additional_special_tokens=["[VERSE]", "[WORD]", "[ROOT]", "[FEATURES]"],
+             trust_remote_code=True
+         )
+         if self.tokenizer.pad_token is None:
+             self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+         # Load model using eager attention for Gemma2 and low_cpu_mem_usage.
+         try:
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 model_name,
+                 token=os.environ.get("HF_TOKEN"),
+                 torch_dtype=torch.float32,
+                 low_cpu_mem_usage=True,
+                 trust_remote_code=True,
+                 attn_implementation="eager"
+             )
+         except Exception as e:
+             logger.error(f"Error loading model directly: {str(e)}")
+             logger.info("Attempting to load with fallback parameters...")
+             from transformers import AutoConfig
+             config = AutoConfig.from_pretrained(
+                 model_name,
+                 token=os.environ.get("HF_TOKEN"),
+                 trust_remote_code=True
+             )
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 model_name,
+                 token=os.environ.get("HF_TOKEN"),
+                 config=config,
+                 torch_dtype=torch.float32,
+                 low_cpu_mem_usage=True,
+                 trust_remote_code=True,
+                 revision="main",
+                 attn_implementation="eager"
+             )
+
+         # Resize token embeddings to match the tokenizer vocabulary size.
+         self.model.resize_token_embeddings(len(self.tokenizer))
+         self.model.train()
+         self.model.config.use_cache = False
+
+         if hasattr(self.model, "gradient_checkpointing_enable"):
+             self.model.gradient_checkpointing_enable()
+         else:
+             logger.warning("Gradient checkpointing not available for this model")
+
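+     # Note: google/gemma-2-2b is a gated checkpoint on the Hugging Face Hub, so
+     # the HF_TOKEN used above must belong to an account that has accepted the
+     # Gemma license; otherwise loading will typically fail with an
+     # authorization error.
+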
+     def prepare_training_data(self, chapter_data: List[Dict]) -> Dataset:
+         """Creates a QuranicDataset from processed chapter data."""
+         return QuranicDataset(chapter_data, self.tokenizer)
+
+     def train_chapter(self,
+                       chapter_num: int,
+                       processed_verses: List[Dict],
+                       chunk_size: int = 5,        # Reduced chunk size to help with memory
+                       num_train_epochs: int = 5,  # Fewer epochs for testing
+                       per_device_train_batch_size: int = 1,
+                       learning_rate: float = 3e-5,
+                       weight_decay: float = 0.01,
+                       gradient_accumulation_steps: int = 32):
+         """
+         Splits chapter data into chunks and trains incrementally.
+         Yields progress messages and download links for each checkpoint.
+         """
+         total_examples = len(processed_verses)
+         total_chunks = math.ceil(total_examples / chunk_size)
+         yield f"Chapter {chapter_num}: {total_examples} examples, {total_chunks} chunks."
+         for chunk_index in range(total_chunks):
+             chunk_data = processed_verses[chunk_index * chunk_size: (chunk_index + 1) * chunk_size]
+             dataset = self.prepare_training_data(chunk_data)
+             chunk_output_dir = os.path.join(self.checkpoint_dir, f"chapter_{chapter_num}", f"chunk_{chunk_index}")
+             os.makedirs(chunk_output_dir, exist_ok=True)
+             training_args = TrainingArguments(
+                 output_dir=chunk_output_dir,
+                 overwrite_output_dir=True,
+                 num_train_epochs=num_train_epochs,
+                 per_device_train_batch_size=per_device_train_batch_size,
+                 learning_rate=learning_rate,
+                 weight_decay=weight_decay,
+                 gradient_accumulation_steps=gradient_accumulation_steps,
+                 fp16=False,
+                 remove_unused_columns=False,
+                 logging_steps=50,
+                 report_to="none",
+                 evaluation_strategy="no",
+                 use_cpu=True,  # Using CPU flag (replaces no_cuda)
+                 dataloader_num_workers=0,
+                 dataloader_pin_memory=False
+             )
+             data_collator = DataCollatorForLanguageModeling(
+                 tokenizer=self.tokenizer,
+                 mlm=False
+             )
+             trainer = Trainer(
+                 model=self.model,
+                 args=training_args,
+                 train_dataset=dataset,
+                 tokenizer=self.tokenizer,
+                 data_collator=data_collator
+             )
+             yield f"Training chunk {chunk_index+1}/{total_chunks} for Chapter {chapter_num}..."
+             trainer.train()
+             trainer.save_model(chunk_output_dir)
+             # Clean up trainer and dataset
+             del trainer, dataset
+             gc.collect()
+             manage_memory()
+             # Compress the checkpoint folder into a zip file
+             zip_path = f"{chunk_output_dir}.zip"
+             zip_directory(chunk_output_dir, zip_path)
+             # Generate a download link (assuming working_directory is served)
+             download_link = f"<a href='{zip_path}' target='_blank'>Download Chapter {chapter_num} Chunk {chunk_index+1} Checkpoint</a>"
+             yield f"Checkpoint saved for Chapter {chapter_num} Chunk {chunk_index+1}: {download_link}"
+         yield f"Completed training for Chapter {chapter_num}."
+
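+     # Usage sketch (hypothetical caller): train_chapter is a generator, so a
+     # caller can stream its progress strings:
+     #   for status in trainer.train_chapter(1, processed_verses):
+     #       print(status)
+     # With per-device batch size 1 and gradient_accumulation_steps=32, each
+     # optimizer step effectively averages over 32 examples.
+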
+ class QuranicPipeline:
+     """Integrates data processing and incremental model training for all chapters."""
+     def __init__(self,
+                  source_dir: str = ".",
+                  working_dir: str = "working_directory",
+                  start_chapter: int = 1,
+                  end_chapter: int = 114):
+         self.source_dir = source_dir
+         self.working_dir = working_dir
+         self.start_chapter = start_chapter
+         self.end_chapter = end_chapter
+         self.setup_directories()
+         global logger
+         logger = logging.getLogger(__name__)
+         self.state = {
+             "last_processed_chapter": 0,
+             "last_trained_chapter": 0,
+             "current_state": "initialized",
+             "errors": [],
+             "start_time": datetime.now().isoformat()
+         }
+         self.load_state()
+         try:
+             logger.info("Initializing Quranic Data Processor...")
+             self.processor = QuranicDataProcessor(
+                 source_dir=self.source_dir,
+                 output_dir=os.path.join(self.working_dir, "processed_data")
+             )
+             logger.info("Initializing Quranic Model Trainer...")
+             self.trainer = QuranicModelTrainer(
+                 model_name="google/gemma-2-2b",
+                 processed_data_dir=os.path.join(self.working_dir, "processed_data"),
+                 checkpoint_dir=os.path.join(self.working_dir, "checkpoints")
+             )
+             self.state["current_state"] = "ready"
+             self.save_state()
+         except Exception as e:
+             self.handle_error("Initialization failed", e)
+             raise
+
+     def setup_directories(self):
+         dirs = [
+             self.working_dir,
+             os.path.join(self.working_dir, "processed_data"),
+             os.path.join(self.working_dir, "checkpoints"),
+             os.path.join(self.working_dir, "logs"),
+             os.path.join(self.working_dir, "state")
+         ]
+         for d in dirs:
+             os.makedirs(d, exist_ok=True)
+
+     def load_state(self):
+         state_file = os.path.join(self.working_dir, "state", "pipeline_state.json")
+         if os.path.exists(state_file):
+             try:
+                 with open(state_file, 'r') as f:
+                     saved_state = json.load(f)
+                 self.state.update(saved_state)
+                 logger.info(f"Loaded previous state: Last processed chapter {self.state.get('last_processed_chapter')}, "
+                             f"last trained chapter {self.state.get('last_trained_chapter')}")
+             except Exception as e:
+                 logger.warning(f"Could not load previous state: {str(e)}")
+
+     def save_state(self):
+         state_file = os.path.join(self.working_dir, "state", "pipeline_state.json")
+         with open(state_file, 'w') as f:
+             json.dump(self.state, f, indent=2)
+
+     def handle_error(self, context: str, error: Exception):
+         error_detail = {
+             "timestamp": datetime.now().isoformat(),
+             "context": context,
+             "error": str(error),
+             "traceback": traceback.format_exc()
+         }
+         self.state.setdefault("errors", []).append(error_detail)
+         logger.error(f"{context}: {str(error)}")
+         self.save_state()
+
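+     # The persisted pipeline_state.json mirrors self.state, e.g. (values
+     # illustrative):
+     #   {"last_processed_chapter": 3, "last_trained_chapter": 3,
+     #    "current_state": "ready", "errors": [], "start_time": "..."}
+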
+     def run_pipeline(self):
+         """
+         Runs processing and training for chapters sequentially.
+         Yields progress messages and download links as checkpoints become available.
+         Finally, saves the final model and yields its download link.
+         """
+         yield "Starting pipeline execution..."
+         if not self.processor.load_source_files():
+             yield "Failed to load source files. Stopping pipeline."
+             return
+         for chapter in range(self.start_chapter, self.end_chapter + 1):
+             yield f"=== Processing Chapter {chapter} ==="
+             processed_chapter_data = []
+             verse = 1
+             # Verses are processed sequentially; the first verse that cannot be
+             # processed ends the chapter.
+             while True:
+                 verse_data = self.processor.process_verse(chapter, verse)
+                 if verse_data is None:
+                     break
+                 processed_chapter_data.append(asdict(verse_data))
+                 verse += 1
+             if processed_chapter_data:
+                 # Iterate over checkpoint messages from training this chapter
+                 for msg in self.trainer.train_chapter(chapter, processed_chapter_data):
+                     yield msg
+                 self.state["last_trained_chapter"] = chapter
+                 self.save_state()
+             else:
+                 yield f"No processed data for Chapter {chapter}."
+             self.state["last_processed_chapter"] = chapter
+             self.save_state()
+             manage_memory()
+         yield "Pipeline execution completed."
+         # Save the final model and tokenizer
+         final_model_dir = os.path.join(self.working_dir, "final_model")
+         os.makedirs(final_model_dir, exist_ok=True)
+         self.trainer.model.save_pretrained(final_model_dir)
+         self.trainer.tokenizer.save_pretrained(final_model_dir)
+         # Compress the final model folder into a zip
+         final_zip = f"{final_model_dir}.zip"
+         zip_directory(final_model_dir, final_zip)
+         final_link = f"<a href='{final_zip}' target='_blank'>Download Final Model</a>"
+         yield f"Final model saved. {final_link}"
+
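+     # Note: the saved state records progress, but run_pipeline always restarts
+     # at self.start_chapter; resuming would mean constructing QuranicPipeline
+     # with start_chapter=state["last_processed_chapter"] + 1.
+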
+ @spaces.GPU()  # Request ZeroGPU hardware for the Space
+ def start_pipeline():
+     try:
+         yield "Starting Quranic Training Pipeline with Gemma-2-2b"
+         yield f"PyTorch version: {torch.__version__}"
+         yield f"CUDA available: {torch.cuda.is_available()}"
+         if torch.cuda.is_available():
+             yield f"CUDA device count: {torch.cuda.device_count()}"
+             yield f"CUDA device name: {torch.cuda.get_device_name(0)}"
+         if not os.environ.get("HF_TOKEN"):
+             yield "WARNING: HF_TOKEN environment variable not set. Model loading may fail."
+         required_files = [
+             'quranic-corpus-morphology-0.4.txt',
+             'en.sample.quran-maududi.txt',
+             'en.w4w.qurandev.txt'
+         ]
+         missing_files = [f for f in required_files if not os.path.exists(f)]
+         if missing_files:
+             yield f"Missing required data files: {', '.join(missing_files)}"
+             return
+         pipeline = QuranicPipeline(
+             source_dir=".",
+             working_dir="working_directory",
+             start_chapter=1,
+             end_chapter=114
+         )
+         for msg in pipeline.run_pipeline():
+             yield msg
+     except Exception as e:
+         error_msg = f"Pipeline execution failed: {str(e)}\n{traceback.format_exc()}"
+         logger.error(error_msg)
+         yield error_msg
+
+ iface = gr.Interface(
+     fn=start_pipeline,
+     inputs=[],
+     outputs=gr.HTML(label="Pipeline Status"),
+     title="Quranic Training Pipeline for Gemma-2-2b",
+     description="""This pipeline fine-tunes Google's Gemma-2-2b model on Quranic data.
+
+ Click 'Submit' to trigger the Quranic data processing and training pipeline on ZeroGPU.
+
+ During execution, download links for each checkpoint (and the final model) will be provided as they become available.
+
+ Requirements:
+ - Transformers (>=4.42.0)
+ - Gradio (>=5.12.0)
+ - PyTorch (==2.2.2)
+ - psutil (==5.9.5)
+ - Accelerate (>=0.26.0)
+
+ The pipeline processes all 114 chapters of the Quran sequentially, with memory management optimizations for ZeroGPU environments."""
+ )
+
+ if __name__ == "__main__":
+     iface.launch()