# -*- coding: utf-8 -*-
import os
import gc
import gradio as gr
from datasets import load_dataset
from train_tokenizer import train_tokenizer
from tokenizers import Tokenizer
from langdetect import detect, DetectorFactory
from PIL import Image
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
import matplotlib.pyplot as plt
from io import BytesIO
import traceback
# For reproducible langdetect results
DetectorFactory.seed = 0
# Settings
CHECKPOINT_FILE = "checkpoint.txt"
TOKENIZER_DIR = "./tokenizer_model"
TOKENIZER_FILE = os.path.join(TOKENIZER_DIR, "tokenizer.json")
MAX_SAMPLES = 5000000
DEFAULT_CHUNK_SIZE = 200000
BATCH_SIZE = 1000
NUM_WORKERS = 4
# Global flag used to signal the collection loop to stop
STOP_COLLECTION = False
def load_checkpoint():
"""Φόρτωση δεδομένων από το checkpoint."""
if os.path.exists(CHECKPOINT_FILE):
with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
return f.read().splitlines()
return []
def append_to_checkpoint(texts):
"""Αποθήκευση δεδομένων με ομαδοποίηση."""
with open(CHECKPOINT_FILE, "a", encoding="utf-8") as f:
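        # One sample per line; the trailing newline keeps later appends from merging lines.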
batch = "\n".join(texts) + "\n"
f.write(batch)
def create_iterator(dataset_name, configs, split):
"""Βελτιωμένο iterator με batch φόρτωση και caching."""
configs_list = [c.strip() for c in configs.split(",") if c.strip()]
for config in configs_list:
try:
dataset = load_dataset(
dataset_name,
name=config,
split=split,
streaming=True,
cache_dir="./dataset_cache"
)
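            # Pull the stream in batches: take BATCH_SIZE examples, process them, then skip past them.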
while True:
batch = list(dataset.take(BATCH_SIZE))
if not batch:
break
dataset = dataset.skip(BATCH_SIZE)
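                # Language-filter the batch in parallel; process_example returns None for rejected texts.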
with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
processed_texts = list(executor.map(process_example, batch))
yield from filter(None, processed_texts)
except Exception as e:
print(f"⚠️ Σφάλμα φόρτωσης {config}: {e}")
def process_example(example):
"""Επεξεργασία ενός παραδείγματος με έλεγχο γλώσσας."""
try:
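        # langdetect can raise on very short or ambiguous text; such examples are dropped below.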
text = example.get('text', '').strip()
if text and detect(text) in ['el', 'en']:
return text
return None
    except Exception:
        return None
def collect_samples(dataset_name, configs, split, chunk_size, max_samples):
"""Συλλογή δεδομένων με streaming και checkpoints."""
global STOP_COLLECTION
STOP_COLLECTION = False
total_processed = len(load_checkpoint())
progress_messages = [f"🚀 Εκκίνηση συλλογής... Πρόοδος: {total_processed}/{max_samples}"]
dataset_iterator = create_iterator(dataset_name, configs, split)
chunk = []
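    # Fill a chunk with up to chunk_size texts, flush it to the checkpoint, then repeat
    # until the stop flag is set or max_samples is reached.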
while not STOP_COLLECTION and total_processed < max_samples:
try:
while len(chunk) < chunk_size:
text = next(dataset_iterator)
if text:
chunk.append(text)
total_processed += 1
if total_processed >= max_samples:
break
if chunk:
append_to_checkpoint(chunk)
progress_messages.append(f"✅ Αποθηκεύτηκαν {len(chunk)} δείγματα (Σύνολο: {total_processed})")
chunk = []
gc.collect()
except StopIteration:
progress_messages.append("🏁 Ολοκληρώθηκε η επεξεργασία όλων των δεδομένων!")
break
except Exception as e:
progress_messages.append(f"⛔ Σφάλμα: {str(e)}")
break
return "\n".join(progress_messages)
def train_tokenizer_fn(dataset_name, configs, split, vocab_size, min_freq, test_text):
"""Εκπαίδευση του tokenizer και έλεγχος ποιότητας."""
messages = ["🚀 Εκκίνηση εκπαίδευσης..."]
try:
all_texts = load_checkpoint()
messages.append("📚 Φόρτωση δεδομένων από checkpoint...")
tokenizer = train_tokenizer(all_texts, vocab_size, min_freq, TOKENIZER_DIR, NUM_WORKERS)
messages.append("✅ Εκπαίδευση ολοκληρώθηκε!")
trained_tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
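        # Round-trip the test text through the trained tokenizer as a quick sanity check.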
encoded = trained_tokenizer.encode(test_text)
decoded = trained_tokenizer.decode(encoded.ids)
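        # Histogram of token lengths (in characters) for the encoded test text.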
fig, ax = plt.subplots()
ax.hist([len(t) for t in encoded.tokens], bins=20)
        ax.set_xlabel('Token length')
        ax.set_ylabel('Frequency')
img_buffer = BytesIO()
        plt.savefig(img_buffer, format='png')
        plt.close(fig)
        img_buffer.seek(0)  # rewind the buffer so PIL can read the PNG from the start
        return ("\n".join(messages), decoded, Image.open(img_buffer))
except Exception as e:
messages.append(f"❌ Σφάλμα: {str(e)}")
return ("\n".join(messages), "", None)
def analyze_checkpoint():
"""Ανάλυση δεδομένων από το checkpoint."""
messages = ["🔍 Έναρξη ανάλυσης..."]
try:
texts = load_checkpoint()
if not texts:
return "Δεν βρέθηκαν δεδομένα για ανάλυση."
total_chars = sum(len(t) for t in texts)
avg_length = total_chars / len(texts) if texts else 0
languages = {}
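        # Detect languages on a sample of up to 1000 texts, skipping very short ones.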
for t in texts[:1000]:
if len(t) > 20:
try:
lang = detect(t)
languages[lang] = languages.get(lang, 0) + 1
except Exception as e:
print(f"⚠️ Σφάλμα ανίχνευσης γλώσσας: {e}")
        detected_total = sum(languages.values()) or 1
        report = [
            f"📊 Total samples: {len(texts)}",
            f"📝 Average length: {avg_length:.1f} characters",
            "🌍 Languages (sample of up to 1000 texts):",
            *[f"- {k}: {v} ({v / detected_total * 100:.1f}%)" for k, v in languages.items()]
        ]
return "\n".join(messages + report)
except Exception as e:
messages.append(f"❌ Σφάλμα: {str(e)}")
return "\n".join(messages)
def restart_collection():
"""Διαγραφή checkpoint και επανεκκίνηση."""
global STOP_COLLECTION
STOP_COLLECTION = False
if os.path.exists(CHECKPOINT_FILE):
os.remove(CHECKPOINT_FILE)
return "🔄 Το checkpoint διαγράφηκε. Έτοιμο για νέα συλλογή."
# Gradio Interface
with gr.Blocks() as demo:
gr.Markdown("## Custom Tokenizer Trainer για GPT-2")
with gr.Row():
with gr.Column(scale=2):
dataset_name = gr.Textbox(value="wikimedia/wikipedia", label="Dataset")
configs = gr.Textbox(value="20231101.el,20231101.en", label="Configurations")
split = gr.Dropdown(["train"], value="train", label="Split")
chunk_size = gr.Slider(10000, 500000, value=200000, step=10000, label="Chunk Size")
            vocab_size = gr.Slider(20000, 50000, value=30000, step=1000, label="Vocabulary Size")
            min_freq = gr.Slider(1, 10, value=3, label="Minimum Frequency")
            test_text = gr.Textbox(value="Η Ακρόπολη είναι σύμβολο της αρχαίας Ελλάδας.", label="Test Text")
            max_samples = gr.Slider(10000, 10000000, value=5000000, step=100000, label="Max Samples")
with gr.Row():
start_btn = gr.Button("Start", variant="primary")
stop_btn = gr.Button("Stop", variant="stop")
restart_btn = gr.Button("Restart")
analyze_btn = gr.Button("Analyze Data")
train_btn = gr.Button("Train Tokenizer", variant="primary")
with gr.Column(scale=3):
            progress = gr.Textbox(label="Progress", lines=10, interactive=False)
            gr.Markdown("### Results")
            decoded_text = gr.Textbox(label="Decoded Text")
            token_distribution = gr.Image(label="Token Distribution")
# Event handlers
start_btn.click(collect_samples, [dataset_name, configs, split, chunk_size, max_samples], progress)
    stop_btn.click(lambda: globals().update(STOP_COLLECTION=True) or "⏹️ Stopping collection...", None, progress, queue=False)
restart_btn.click(restart_collection, None, progress)
analyze_btn.click(analyze_checkpoint, None, progress)
train_btn.click(train_tokenizer_fn, [dataset_name, configs, split, vocab_size, min_freq, test_text],
[progress, decoded_text, token_distribution])
demo.queue().launch()