File size: 8,969 Bytes
7c5aa99
f94c5ea
c259678
0430da2
 
f94c5ea
4a4435c
5d41434
1e3138f
 
c259678
72577d1
 
 
f94c5ea
5d41434
 
 
c259678
 
72577d1
3d37920
72577d1
 
 
 
9dd78b5
c259678
9dd78b5
0430da2
5c35386
c259678
5c35386
 
 
 
 
 
c259678
5c35386
c259678
 
5c35386
a9ae246
c259678
a9ae246
 
 
c259678
 
 
 
 
72577d1
c259678
 
 
 
 
 
 
 
 
a9ae246
72577d1
c259678
 
 
 
 
72577d1
c259678
 
 
 
 
 
72577d1
9dd78b5
4410500
d6a5933
72577d1
f94c5ea
c259678
 
 
 
 
 
 
 
 
 
 
 
72577d1
c259678
 
 
 
9dd78b5
c259678
 
9dd78b5
c259678
5c35386
d6a5933
72577d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c259678
 
72577d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c35386
08046e4
72577d1
08046e4
 
 
 
 
 
4a4435c
3d37920
72577d1
5d41434
c259678
 
 
 
 
72577d1
 
5d41434
72577d1
c259678
 
 
 
 
 
 
 
 
 
 
 
 
 
72577d1
 
c259678
 
72577d1
c259678
cb19003
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
# -*- coding: utf-8 -*-
import os
import gc
import gradio as gr
from datasets import load_dataset
from train_tokenizer import train_tokenizer
from tokenizers import Tokenizer
from langdetect import detect, DetectorFactory
from PIL import Image
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
import matplotlib.pyplot as plt
from io import BytesIO
import traceback

# Για επαναληψιμότητα στο langdetect
DetectorFactory.seed = 0

# Ρυθμίσεις
CHECKPOINT_FILE = "checkpoint.txt"
TOKENIZER_DIR = "./tokenizer_model"
TOKENIZER_FILE = os.path.join(TOKENIZER_DIR, "tokenizer.json")
MAX_SAMPLES = 5000000
DEFAULT_CHUNK_SIZE = 200000
BATCH_SIZE = 1000
NUM_WORKERS = 4

# Παγκόσμια μεταβλητή ελέγχου
STOP_COLLECTION = False

def load_checkpoint():
    """Φόρτωση δεδομένων από το checkpoint."""
    if os.path.exists(CHECKPOINT_FILE):
        with open(CHECKPOINT_FILE, "r", encoding="utf-8") as f:
            return f.read().splitlines()
    return []

def append_to_checkpoint(texts):
    """Αποθήκευση δεδομένων με ομαδοποίηση."""
    with open(CHECKPOINT_FILE, "a", encoding="utf-8") as f:
        batch = "\n".join(texts) + "\n"
        f.write(batch)

def create_iterator(dataset_name, configs, split):
    """Βελτιωμένο iterator με batch φόρτωση και caching."""
    configs_list = [c.strip() for c in configs.split(",") if c.strip()]
    for config in configs_list:
        try:
            dataset = load_dataset(
                dataset_name,
                name=config,
                split=split,
                streaming=True,
                cache_dir="./dataset_cache"
            )
            while True:
                batch = list(dataset.take(BATCH_SIZE))
                if not batch:
                    break
                dataset = dataset.skip(BATCH_SIZE)
                with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
                    processed_texts = list(executor.map(process_example, batch))
                yield from filter(None, processed_texts)
        except Exception as e:
            print(f"⚠️ Σφάλμα φόρτωσης {config}: {e}")

def process_example(example):
    """Επεξεργασία ενός παραδείγματος με έλεγχο γλώσσας."""
    try:
        text = example.get('text', '').strip()
        if text and detect(text) in ['el', 'en']:
            return text
        return None
    except:
        return None

def collect_samples(dataset_name, configs, split, chunk_size, max_samples):
    """Συλλογή δεδομένων με streaming και checkpoints."""
    global STOP_COLLECTION
    STOP_COLLECTION = False
    total_processed = len(load_checkpoint())
    progress_messages = [f"🚀 Εκκίνηση συλλογής... Πρόοδος: {total_processed}/{max_samples}"]
    dataset_iterator = create_iterator(dataset_name, configs, split)
    chunk = []
    while not STOP_COLLECTION and total_processed < max_samples:
        try:
            while len(chunk) < chunk_size:
                text = next(dataset_iterator)
                if text:
                    chunk.append(text)
                    total_processed += 1
                    if total_processed >= max_samples:
                        break
            if chunk:
                append_to_checkpoint(chunk)
                progress_messages.append(f"✅ Αποθηκεύτηκαν {len(chunk)} δείγματα (Σύνολο: {total_processed})")
                chunk = []
                gc.collect()
        except StopIteration:
            progress_messages.append("🏁 Ολοκληρώθηκε η επεξεργασία όλων των δεδομένων!")
            break
        except Exception as e:
            progress_messages.append(f"⛔ Σφάλμα: {str(e)}")
            break
    return "\n".join(progress_messages)

def train_tokenizer_fn(dataset_name, configs, split, vocab_size, min_freq, test_text):
    """Εκπαίδευση του tokenizer και έλεγχος ποιότητας."""
    messages = ["🚀 Εκκίνηση εκπαίδευσης..."]
    try:
        all_texts = load_checkpoint()
        messages.append("📚 Φόρτωση δεδομένων από checkpoint...")
        tokenizer = train_tokenizer(all_texts, vocab_size, min_freq, TOKENIZER_DIR, NUM_WORKERS)
        messages.append("✅ Εκπαίδευση ολοκληρώθηκε!")
        trained_tokenizer = Tokenizer.from_file(TOKENIZER_FILE)
        encoded = trained_tokenizer.encode(test_text)
        decoded = trained_tokenizer.decode(encoded.ids)
        fig, ax = plt.subplots()
        ax.hist([len(t) for t in encoded.tokens], bins=20)
        ax.set_xlabel('Μήκος Token')
        ax.set_ylabel('Συχνότητα')
        img_buffer = BytesIO()
        plt.savefig(img_buffer, format='png')
        plt.close()
        return ("\n".join(messages), decoded, Image.open(img_buffer))
    except Exception as e:
        messages.append(f"❌ Σφάλμα: {str(e)}")
        return ("\n".join(messages), "", None)

def analyze_checkpoint():
    """Ανάλυση δεδομένων από το checkpoint."""
    messages = ["🔍 Έναρξη ανάλυσης..."]
    try:
        texts = load_checkpoint()
        if not texts:
            return "Δεν βρέθηκαν δεδομένα για ανάλυση."
        total_chars = sum(len(t) for t in texts)
        avg_length = total_chars / len(texts) if texts else 0
        languages = {}
        for t in texts[:1000]:
            if len(t) > 20:
                try:
                    lang = detect(t)
                    languages[lang] = languages.get(lang, 0) + 1
                except Exception as e:
                    print(f"⚠️ Σφάλμα ανίχνευσης γλώσσας: {e}")
        report = [
            f"📊 Σύνολο δειγμάτων: {len(texts)}",
            f"📝 Μέσο μήκος: {avg_length:.1f} χαρακτήρες",
            "🌍 Γλώσσες (δείγμα 1000):",
            *[f"- {k}: {v} ({v/10:.1f}%)" for k, v in languages.items()]
        ]
        return "\n".join(messages + report)
    except Exception as e:
        messages.append(f"❌ Σφάλμα: {str(e)}")
        return "\n".join(messages)

def restart_collection():
    """Διαγραφή checkpoint και επανεκκίνηση."""
    global STOP_COLLECTION
    STOP_COLLECTION = False
    if os.path.exists(CHECKPOINT_FILE):
        os.remove(CHECKPOINT_FILE)
    return "🔄 Το checkpoint διαγράφηκε. Έτοιμο για νέα συλλογή."

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Custom Tokenizer Trainer για GPT-2")
    with gr.Row():
        with gr.Column(scale=2):
            dataset_name = gr.Textbox(value="wikimedia/wikipedia", label="Dataset")
            configs = gr.Textbox(value="20231101.el,20231101.en", label="Configurations")
            split = gr.Dropdown(["train"], value="train", label="Split")
            chunk_size = gr.Slider(10000, 500000, value=200000, step=10000, label="Chunk Size")
            vocab_size = gr.Slider(20000, 50000, value=30000, step=1000, label="Μέγεθος Λεξιλογίου")
            min_freq = gr.Slider(1, 10, value=3, label="Ελάχιστη Συχνότητα")
            test_text = gr.Textbox(value="Η Ακρόπολη είναι σύμβολο της αρχαίας Ελλάδας.", label="Test Text")
            max_samples = gr.Slider(10000, 10000000, value=5000000, step=100000, label="Μέγιστα Δείγματα")
            with gr.Row():
                start_btn = gr.Button("Start", variant="primary")
                stop_btn = gr.Button("Stop", variant="stop")
                restart_btn = gr.Button("Restart")
            analyze_btn = gr.Button("Analyze Data")
            train_btn = gr.Button("Train Tokenizer", variant="primary")
        with gr.Column(scale=3):
            progress = gr.Textbox(label="Πρόοδος", lines=10, interactive=False)
            gr.Markdown("### Αποτελέσματα")
            decoded_text = gr.Textbox(label="Αποκωδικοποιημένο Κείμενο")
            token_distribution = gr.Image(label="Κατανομή Tokens")

    # Event handlers
    start_btn.click(collect_samples, [dataset_name, configs, split, chunk_size, max_samples], progress)
    stop_btn.click(lambda: globals().update(STOP_COLLECTION=True) or "⏹️ Διακοπή συλλογής...", None, progress, queue=False)
    restart_btn.click(restart_collection, None, progress)
    analyze_btn.click(analyze_checkpoint, None, progress)
    train_btn.click(train_tokenizer_fn, [dataset_name, configs, split, vocab_size, min_freq, test_text],
                    [progress, decoded_text, token_distribution])

demo.queue().launch()