import ast
import json
import os
import sys

import gradio as gr
import numpy as np

# Add the current directory to the Python path so local packages resolve
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)

from tokenizers.basic import BasicTokenizer
def load_tokenizer(model_path, vocab_path):
    """Load the trained tokenizer and its vocabulary from disk."""
    tokenizer = BasicTokenizer()
    try:
        # Check that both files exist before attempting to load
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found at: {model_path}")
        if not os.path.exists(vocab_path):
            raise FileNotFoundError(f"Vocabulary file not found at: {vocab_path}")

        # Load the trained model
        tokenizer.load(model_path)

        # Load the vocabulary and restore the lookup tables
        with open(vocab_path, 'r', encoding='utf-8') as f:
            vocab_data = json.load(f)
        tokenizer.token_to_id = {k: int(v) for k, v in vocab_data['token_to_id'].items()}
        tokenizer.id_to_token = {int(k): v for k, v in vocab_data['id_to_token'].items()}
        tokenizer.merges = {tuple(map(int, k.split(','))): int(v)
                            for k, v in vocab_data['merges'].items()}
        return tokenizer
    except Exception as e:
        raise RuntimeError(f"Error loading tokenizer: {e}") from e
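
# Illustrative sketch of the vocabulary.json layout this loader expects.
# The real file is produced by the training script; the keys below are
# inferred from the parsing code above, and the sample values are made up:
#
# {
#     "token_to_id": {"అ": "256", "ఆ": "257"},
#     "id_to_token": {"256": "అ", "257": "ఆ"},
#     "merges": {"97,98": "256", "256,99": "257"}
# }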
def encode_text(text, tokenizer):
    """Encode text and return the token IDs, statistics, and visualization data."""
    if not text.strip():
        return ("Please enter some Telugu text",
                "No statistics available",
                [])  # Empty list for the visualization component
    try:
        # Encode the text
        encoded = tokenizer.encode(text)

        # Calculate the compression ratio, assuming 2 bytes per token ID
        # (valid for vocabularies of up to 65,536 tokens)
        original_size = len(text.encode('utf-8'))
        encoded_size = len(encoded) * 2
        compression_ratio = original_size / encoded_size

        # Prepare statistics
        stats = f"""
        📊 Encoding Statistics:
        • Original text length: {len(text)} characters
        • Encoded length: {len(encoded)} tokens
        • Compression ratio: {compression_ratio:.2f}X
        • Original size: {original_size} bytes
        • Encoded size: {encoded_size} bytes
        • Space saved: {(1 - encoded_size / original_size) * 100:.1f}%
        """

        # Assign each unique token a deterministic hex color derived from its ID
        unique_tokens = set(encoded)
        color_map = {token: f"#{hash(str(token)) % 0xFFFFFF:06x}" for token in unique_tokens}

        # Build (token_text, color) pairs for the HighlightedText component
        visualization = []
        for token_id in encoded:
            token_bytes = tokenizer.vocab[token_id]
            token_text = token_bytes.decode('utf-8', errors='replace')
            visualization.append((token_text, color_map[token_id]))

        return (
            str(encoded),   # encoded_ids for the first textbox
            stats,          # stats for the second textbox
            visualization,  # for the HighlightedText component
        )
    except Exception as e:
        return (
            f"Error: {str(e)}",
            "Error occurred during encoding",
            [],  # Empty list for the visualization on error
        )
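
# Worked example of the statistics above (hypothetical numbers): a
# 30-character Telugu string occupies roughly 90 bytes in UTF-8 (about
# 3 bytes per character); if it encodes to 20 tokens at the assumed
# 2 bytes per token ID, the ratio is 90 / (20 * 2) = 2.25x and about
# 55.6% of the space is saved.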
def decode_ids(encoded_ids_str, tokenizer):
    """Decode the encoded IDs back to text."""
    if not encoded_ids_str.strip():
        return "Please enter encoded IDs"
    try:
        # Safely parse the string representation of the list
        # (ast.literal_eval avoids the code-execution risk of eval)
        encoded_ids = ast.literal_eval(encoded_ids_str)
        if not isinstance(encoded_ids, list):
            return "Invalid input: Please enter a list of integers"

        # Decode the IDs
        decoded_text = tokenizer.decode(encoded_ids)
        return decoded_text
    except Exception as e:
        return f"Error during decoding: {str(e)}"
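
# Example round trip (a sketch; assumes the model files loaded below are
# present and that BasicTokenizer.decode inverts encode):
#
#   ids = tokenizer.encode("తెలుగు")
#   decode_ids(str(ids), tokenizer)  # -> "తెలుగు"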
def visualize_encoding(text, encoded_ids, tokenizer):
    """Create a visual representation of the encoding."""
    tokens = []
    colors = []

    # Assign each unique token a random RGB color
    unique_tokens = set(encoded_ids)
    color_map = {token: np.random.rand(3).tolist() for token in unique_tokens}

    for token_id in encoded_ids:
        token_bytes = tokenizer.vocab[token_id]
        token_text = token_bytes.decode('utf-8', errors='replace')
        tokens.append(token_text)
        colors.append(color_map[token_id])

    return {
        "tokens": tokens,
        "colors": colors,
    }
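
# Note: visualize_encoding is currently unused; the Gradio UI builds its
# highlight data inline in encode_text. It is kept as an alternative that
# returns raw RGB triples instead of hex color strings.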
# Load the tokenizer with proper path handling
try:
    model_path = os.path.join(current_dir, "models", "version_2", "checkpoints", "telugu_basic.model")
    vocab_path = os.path.join(current_dir, "models", "version_2", "vocabulary", "vocabulary.json")
    print(f"Loading model from: {model_path}")
    print(f"Loading vocabulary from: {vocab_path}")
    tokenizer = load_tokenizer(model_path, vocab_path)
    print("Tokenizer loaded successfully")
except Exception as e:
    print(f"Error loading tokenizer: {str(e)}")
    raise
# Create the Gradio interface
with gr.Blocks(title="Telugu Text Tokenizer", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🤖 Telugu Text Tokenizer

    This tool helps you encode Telugu text into tokens and decode them back.
    It uses a trained BPE (Byte Pair Encoding) tokenizer optimized for the Telugu language.

    ## Features:
    - 📝 Encode Telugu text to token IDs
    - 📊 View compression statistics
    - 🎨 Visualize token segmentation
    - ⚡ Fast and efficient encoding/decoding
    """)
with gr.Tab("Encoder"): | |
with gr.Row(): | |
with gr.Column(): | |
input_text = gr.Textbox( | |
label="Enter Telugu Text", | |
placeholder="Type or paste Telugu text here...", | |
lines=5 | |
) | |
encode_btn = gr.Button("π Encode", variant="primary") | |
with gr.Column(): | |
encoded_output = gr.Textbox( | |
label="Encoded Token IDs", | |
lines=5, | |
interactive=False | |
) | |
stats_output = gr.Textbox( | |
label="Statistics", | |
lines=8, | |
interactive=False | |
) | |
with gr.Row(): | |
gr.Markdown("### Token Visualization") | |
token_viz = gr.HighlightedText( | |
label="Token Segmentation", | |
show_legend=True, | |
combine_adjacent=True, | |
color_map={} # Let Gradio handle the color mapping | |
) | |
with gr.Tab("Decoder"): | |
with gr.Row(): | |
with gr.Column(): | |
encoded_input = gr.Textbox( | |
label="Enter Encoded Token IDs", | |
placeholder="Paste the encoded token IDs here...", | |
lines=5 | |
) | |
decode_btn = gr.Button("π Decode", variant="primary") | |
with gr.Column(): | |
decoded_output = gr.Textbox( | |
label="Decoded Telugu Text", | |
lines=5, | |
interactive=False | |
) | |
# Set up event handlers | |
encode_btn.click( | |
fn=encode_text, # Now using the function directly | |
inputs=[input_text, gr.State(tokenizer)], # Pass tokenizer as state | |
outputs=[encoded_output, stats_output, token_viz] | |
) | |
decode_btn.click( | |
fn=lambda ids: decode_ids(ids, tokenizer), | |
inputs=encoded_input, | |
outputs=decoded_output | |
) | |
gr.Markdown(""" | |
### π Instructions: | |
1. **Encoding**: Enter Telugu text in the encoder tab and click "Encode" | |
2. **Decoding**: Copy the encoded IDs and paste them in the decoder tab | |
3. **Visualization**: View token segmentation with color coding | |
### βΉοΈ Notes: | |
- The tokenizer uses BPE (Byte Pair Encoding) algorithm | |
- Compression ratio shows how efficiently the text is encoded | |
- Different colors in visualization represent different tokens | |
""") | |
# Launch the app with additional configuration
if __name__ == "__main__":
    demo.launch(
        share=True,
        debug=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )
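
# Programmatic usage sketch (an assumption, not part of this app): once the
# server is running, its endpoints can be called with gradio_client, e.g.
#
#   from gradio_client import Client
#   client = Client("http://localhost:7860")
#   client.view_api()  # lists the endpoint names Gradio registered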