import os
import tempfile
from io import BytesIO

import gradio as gr
import matplotlib
matplotlib.use("Agg")  # render off-screen; the app runs headless behind a web server
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image  # used to hand the histogram to gr.Image as a PIL image

from datasets import load_dataset
from tokenizers import Tokenizer

from train_tokenizer import train_tokenizer

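# `train_tokenizer` is a project-local helper (train_tokenizer.py). Based on how
# it is called below, it is assumed to take (text_iterator, vocab_size, min_freq)
# and return a trained `tokenizers.Tokenizer`.
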
def create_iterator(files=None, dataset_name=None, split="train", streaming=True):
    """
    Create an iterator that reads text either from local files or from a
    Hugging Face dataset.
    """
    if dataset_name:
        # Streaming avoids downloading the full dataset; a 'text' column is assumed.
        dataset = load_dataset(dataset_name, split=split, streaming=streaming)
        for example in dataset:
            yield example["text"]
    elif files:
        for file in files:
            # Gradio may pass plain path strings or tempfile wrappers with a .name.
            path = file if isinstance(file, str) else file.name
            with open(path, "r", encoding="utf-8") as f:
                for line in f:
                    if line.strip():
                        yield line.strip()
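
# A minimal usage sketch ("your/dataset" is a placeholder; a 'text' column is assumed):
#
#     for line in create_iterator(dataset_name="your/dataset", split="train"):
#         print(line)
#
# Note that some Hub datasets also require a config name, which this helper does
# not currently expose.
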
def enhanced_validation(tokenizer, test_text):
    """
    Validate the tokenizer with a roundtrip test and report statistics.
    """
    encoded = tokenizer.encode(test_text)
    decoded = tokenizer.decode(encoded.ids)

    # Share of unknown tokens (assumes "<unk>" is the unknown-token symbol).
    unknown_tokens = sum(1 for t in encoded.tokens if t == "<unk>")
    unknown_percent = (unknown_tokens / len(encoded.tokens) * 100) if encoded.tokens else 0

    token_lengths = [len(t) for t in encoded.tokens]
    avg_length = np.mean(token_lengths) if token_lengths else 0

    # Rough code-awareness check: a symbol counts as "covered" only if it occurs
    # in the test text and survives as a standalone token (merges are not detected).
    code_symbols = ['{', '}', '(', ')', ';', '//', 'printf']
    code_coverage = {sym: (sym in test_text and sym in encoded.tokens) for sym in code_symbols}

    # Histogram of token lengths, rendered to an in-memory PNG.
    fig = plt.figure()
    plt.hist(token_lengths, bins=20, color='skyblue', edgecolor='black')
    plt.xlabel('Token length')
    plt.ylabel('Frequency')
    plt.title('Token Length Distribution')
    img_buffer = BytesIO()
    plt.savefig(img_buffer, format='png')
    plt.close(fig)
    img_buffer.seek(0)

    return {
        "roundtrip_success": test_text == decoded,
        "unknown_tokens": f"{unknown_tokens} ({unknown_percent:.2f}%)",
        "average_token_length": f"{avg_length:.2f}",
        "code_coverage": code_coverage,
        # A PIL image, which gr.Image accepts directly (raw PNG bytes are not).
        "token_length_distribution": Image.open(img_buffer),
    }
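
# Illustrative shape of the returned metrics (values are made up):
#
#     {"roundtrip_success": True,
#      "unknown_tokens": "0 (0.00%)",
#      "average_token_length": "3.42",
#      "code_coverage": {"{": True, "}": True, ...},
#      "token_length_distribution": <PIL.Image.Image>}
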
def train_and_test(files, dataset_name, split, vocab_size, min_freq, test_text,
                   progress=gr.Progress()):
    # gr.Progress is injected by Gradio through this default argument; it is not
    # a context manager, so it is called directly to report progress.
    if not files and not dataset_name:
        raise gr.Error("You must provide files or a dataset name!")

    try:
        iterator = create_iterator(files, dataset_name, split)
        progress(0.1, desc="Preprocessing data...")
        tokenizer = train_tokenizer(iterator, vocab_size, min_freq)
    except Exception as e:
        raise gr.Error(f"Training error: {e}")

    # Round-trip through disk to verify the tokenizer serializes cleanly. Saving
    # after the with-block keeps this working on Windows, where an open
    # NamedTemporaryFile cannot be reopened by name.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as tmp:
        tmp_path = tmp.name
    tokenizer.save(tmp_path)
    trained_tokenizer = Tokenizer.from_file(tmp_path)
    os.unlink(tmp_path)

    validation = enhanced_validation(trained_tokenizer, test_text)

    # Return a tuple matching the click() outputs positionally:
    # (metrics for gr.JSON, histogram image for gr.Image).
    metrics = {k: v for k, v in validation.items() if k != "token_length_distribution"}
    return metrics, validation["token_length_distribution"]
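
# UI layout: data sources and hyperparameters in the left column, metrics and
# the token-length histogram in the right column.
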
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("## Advanced BPE Tokenizer Trainer")

    with gr.Row():
        with gr.Column():
            with gr.Tab("Local Files"):
                file_input = gr.File(file_count="multiple", label="Upload files")
            with gr.Tab("Hugging Face Dataset"):
                dataset_name = gr.Textbox(label="Dataset name (e.g. 'wikitext', 'codeparrot/github-code')")
                split = gr.Textbox(value="train", label="Split")

            vocab_size = gr.Slider(1000, 100000, value=32000, label="Vocabulary size")
            min_freq = gr.Slider(1, 100, value=2, label="Minimum frequency")
            # The default test string deliberately mixes Greek text with code.
            test_text = gr.Textbox(
                value='function helloWorld() { console.log("Γειά σου Κόσμε!"); } // Greek + code',
                label="Test Text"
            )
            train_btn = gr.Button("Train Tokenizer", variant="primary")

        with gr.Column():
            results_json = gr.JSON(label="Metrics")
            results_plot = gr.Image(label="Token Length Distribution")

    train_btn.click(
        fn=train_and_test,
        inputs=[file_input, dataset_name, split, vocab_size, min_freq, test_text],
        outputs=[results_json, results_plot]
    )

if __name__ == "__main__":
    # queue() enables the progress bar; it is the default in recent Gradio
    # versions but harmless to call explicitly.
    demo.queue()
    demo.launch()
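
# To expose a temporary public link instead, launch with demo.launch(share=True)
# (a standard Gradio option).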