# -*- coding: utf-8 -*- import os import gc import gradio as gr from datasets import load_dataset from train_tokenizer import train_tokenizer from tokenizers import Tokenizer from langdetect import detect, DetectorFactory from PIL import Image from datetime import datetime from concurrent.futures import ThreadPoolExecutor import matplotlib.pyplot as plt from io import BytesIO import traceback # Για επαναληψιμότητα στο langdetect DetectorFactory.seed = 0 # Ρυθμίσεις CHECKPOINT_FILE = "checkpoint.txt" TOKENIZER_DIR = "./tokenizer_model" TOKENIZER_FILE = os.path.join(TOKENIZER_DIR, "tokenizer.json") MAX_SAMPLES = 5000000 DEFAULT_CHUNK_SIZE = 200000 BATCH_SIZE = 1000 NUM_WORKERS = 4 # Παγκόσμια μεταβλητή ελέγχου STOP_COLLECTION = False def load_checkpoint(): """Φόρτωση δεδομένων από το checkpoint.""" if os.path.exists(CHECKPOINT_FILE): with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f: return f.read().splitlines() return [] def append_to_checkpoint(texts): """Αποθήκευση δεδομένων με ομαδοποίηση.""" with open(CHECKPOINT_FILE, "a", encoding="utf-8") as f: batch = "\n".join(texts) + "\n" f.write(batch) def create_iterator(dataset_name, configs, split): """Βελτιωμένο iterator με batch φόρτωση και caching.""" configs_list = [c.strip() for c in configs.split(",") if c.strip()] for config in configs_list: try: dataset = load_dataset( dataset_name, name=config, split=split, streaming=True, cache_dir="./dataset_cache" ) while True: batch = list(dataset.take(BATCH_SIZE)) if not batch: break dataset = dataset.skip(BATCH_SIZE) with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor: processed_texts = list(executor.map(process_example, batch)) yield from filter(None, processed_texts) except Exception as e: print(f"⚠️ Σφάλμα φόρτωσης {config}: {e}") def process_example(example): """Επεξεργασία ενός παραδείγματος με έλεγχο γλώσσας.""" try: text = example.get('text', '').strip() if text and detect(text) in ['el', 'en']: return text return None except: return None def collect_samples(dataset_name, configs, split, chunk_size, max_samples): """Συλλογή δεδομένων με streaming και checkpoints.""" global STOP_COLLECTION STOP_COLLECTION = False total_processed = len(load_checkpoint()) progress_messages = [f"🚀 Εκκίνηση συλλογής... Πρόοδος: {total_processed}/{max_samples}"] dataset_iterator = create_iterator(dataset_name, configs, split) chunk = [] while not STOP_COLLECTION and total_processed < max_samples: try: while len(chunk) < chunk_size: text = next(dataset_iterator) if text: chunk.append(text) total_processed += 1 if total_processed >= max_samples: break if chunk: append_to_checkpoint(chunk) progress_messages.append(f"✅ Αποθηκεύτηκαν {len(chunk)} δείγματα (Σύνολο: {total_processed})") chunk = [] gc.collect() except StopIteration: progress_messages.append("🏁 Ολοκληρώθηκε η επεξεργασία όλων των δεδομένων!") break except Exception as e: progress_messages.append(f"⛔ Σφάλμα: {str(e)}") break return "\n".join(progress_messages) def train_tokenizer_fn(dataset_name, configs, split, vocab_size, min_freq, test_text): """Εκπαίδευση του tokenizer και έλεγχος ποιότητας.""" messages = ["🚀 Εκκίνηση εκπαίδευσης..."] try: all_texts = load_checkpoint() messages.append("📚 Φόρτωση δεδομένων από checkpoint...") tokenizer = train_tokenizer(all_texts, vocab_size, min_freq, TOKENIZER_DIR, NUM_WORKERS) messages.append("✅ Εκπαίδευση ολοκληρώθηκε!") trained_tokenizer = Tokenizer.from_file(TOKENIZER_FILE) encoded = trained_tokenizer.encode(test_text) decoded = trained_tokenizer.decode(encoded.ids) fig, ax = plt.subplots() ax.hist([len(t) for t in encoded.tokens], bins=20) ax.set_xlabel('Μήκος Token') ax.set_ylabel('Συχνότητα') img_buffer = BytesIO() plt.savefig(img_buffer, format='png') plt.close() return ("\n".join(messages), decoded, Image.open(img_buffer)) except Exception as e: messages.append(f"❌ Σφάλμα: {str(e)}") return ("\n".join(messages), "", None) def analyze_checkpoint(): """Ανάλυση δεδομένων από το checkpoint.""" messages = ["🔍 Έναρξη ανάλυσης..."] try: texts = load_checkpoint() if not texts: return "Δεν βρέθηκαν δεδομένα για ανάλυση." total_chars = sum(len(t) for t in texts) avg_length = total_chars / len(texts) if texts else 0 languages = {} for t in texts[:1000]: if len(t) > 20: try: lang = detect(t) languages[lang] = languages.get(lang, 0) + 1 except Exception as e: print(f"⚠️ Σφάλμα ανίχνευσης γλώσσας: {e}") report = [ f"📊 Σύνολο δειγμάτων: {len(texts)}", f"📝 Μέσο μήκος: {avg_length:.1f} χαρακτήρες", "🌍 Γλώσσες (δείγμα 1000):", *[f"- {k}: {v} ({v/10:.1f}%)" for k, v in languages.items()] ] return "\n".join(messages + report) except Exception as e: messages.append(f"❌ Σφάλμα: {str(e)}") return "\n".join(messages) def restart_collection(): """Διαγραφή checkpoint και επανεκκίνηση.""" global STOP_COLLECTION STOP_COLLECTION = False if os.path.exists(CHECKPOINT_FILE): os.remove(CHECKPOINT_FILE) return "🔄 Το checkpoint διαγράφηκε. Έτοιμο για νέα συλλογή." # Gradio Interface with gr.Blocks() as demo: gr.Markdown("## Custom Tokenizer Trainer για GPT-2") with gr.Row(): with gr.Column(scale=2): dataset_name = gr.Textbox(value="wikimedia/wikipedia", label="Dataset") configs = gr.Textbox(value="20231101.el,20231101.en", label="Configurations") split = gr.Dropdown(["train"], value="train", label="Split") chunk_size = gr.Slider(10000, 500000, value=200000, step=10000, label="Chunk Size") vocab_size = gr.Slider(20000, 50000, value=30000, step=1000, label="Μέγεθος Λεξιλογίου") min_freq = gr.Slider(1, 10, value=3, label="Ελάχιστη Συχνότητα") test_text = gr.Textbox(value="Η Ακρόπολη είναι σύμβολο της αρχαίας Ελλάδας.", label="Test Text") max_samples = gr.Slider(10000, 10000000, value=5000000, step=100000, label="Μέγιστα Δείγματα") with gr.Row(): start_btn = gr.Button("Start", variant="primary") stop_btn = gr.Button("Stop", variant="stop") restart_btn = gr.Button("Restart") analyze_btn = gr.Button("Analyze Data") train_btn = gr.Button("Train Tokenizer", variant="primary") with gr.Column(scale=3): progress = gr.Textbox(label="Πρόοδος", lines=10, interactive=False) gr.Markdown("### Αποτελέσματα") decoded_text = gr.Textbox(label="Αποκωδικοποιημένο Κείμενο") token_distribution = gr.Image(label="Κατανομή Tokens") # Event handlers start_btn.click(collect_samples, [dataset_name, configs, split, chunk_size, max_samples], progress) stop_btn.click(lambda: globals().update(STOP_COLLECTION=True) or "⏹️ Διακοπή συλλογής...", None, progress, queue=False) restart_btn.click(restart_collection, None, progress) analyze_btn.click(analyze_checkpoint, None, progress) train_btn.click(train_tokenizer_fn, [dataset_name, configs, split, vocab_size, min_freq, test_text], [progress, decoded_text, token_distribution]) demo.queue().launch()