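"""Gradio demo: run a prompt through the ByteLatent (BLT) patcher and show the decoded output."""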
import os

import gradio as gr
import torch

# Removed matplotlib and plot_entropies imports.
# Assumes the bytelatent library and its dependencies are installed.
from bytelatent.data.file_util import get_fs
# from bytelatent.distributed import DistributedArgs, setup_torch_distributed  # Not needed.
from bytelatent.generate_patcher import patcher_nocache
from bytelatent.tokenizers.blt_tokenizer import BltTokenizer
# Removed: from bytelatent.plotting.entropy_figure_via_matplot_lib import plot_entropies
from bytelatent.args import TrainArgs
from download_blt_weights import main as ensure_present

# --- Global Setup (consider loading models outside the handler if necessary) ---
# Kept inside the function for simplicity, as before.
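from functools import lru_cache


# A minimal sketch of the "load once" idea (hypothetical helper, not wired in
# below): cache the expensive tokenizer/patcher construction per model name so
# repeated Gradio calls can skip reloading. Mirrors the steps in process_text.
@lru_cache(maxsize=1)
def load_tokenizer_and_patcher(model_name: str):
    consolidated_path = os.path.join("hf-weights", model_name)
    train_args_path = os.path.join(consolidated_path, "params.json")
    fs = get_fs(train_args_path)
    train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))
    tokenizer = train_args.data.tokenizer_args.build()
    patcher_args = train_args.data.patcher_args.model_copy(deep=True)
    patcher_args.realtime_patching = True
    device = "cuda" if torch.cuda.is_available() else "cpu"
    patcher_args.patching_device = device
    patcher_args.device = device
    patcher_args.entropy_model_checkpoint_dir = os.path.join(consolidated_path, "entropy_model")
    return tokenizer, patcher_args.build()
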
def process_text(prompt: str, model_name: str = "blt-1b"):
    """
    Processes the input prompt with the ByteLatent model and returns the decoded characters.

    Args:
        prompt: The input text string from the Gradio interface.
        model_name: The name of the model to use.

    Returns:
        A string containing the decoded characters after processing, or an error message.
    """
    try:
        # --- Model and Tokenizer Loading ---
        consolidated_path = os.path.join("hf-weights", model_name)
        train_args_path = os.path.join(consolidated_path, "params.json")
        if not os.path.exists(train_args_path):
            raise FileNotFoundError(
                f"Training args not found at {train_args_path}. "
                f"Ensure model '{model_name}' is downloaded/available."
            )

        fs = get_fs(train_args_path)
        train_args = TrainArgs.model_validate_json(fs.read_text(train_args_path))

        tokenizer = train_args.data.tokenizer_args.build()
        assert isinstance(tokenizer, BltTokenizer)

        patcher_args = train_args.data.patcher_args.model_copy(deep=True)
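        # Realtime patching has the entropy model segment the byte stream into
        # patches on the fly, rather than relying on precomputed patch lengths.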
        patcher_args.realtime_patching = True

        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")
        patcher_args.patching_device = device
        patcher_args.device = device

        print("Loading entropy model and patcher...")
        entropy_model_dir = os.path.join(consolidated_path, "entropy_model")
        if not os.path.exists(entropy_model_dir):
            raise FileNotFoundError(f"Entropy model directory not found at {entropy_model_dir}.")
        patcher_args.entropy_model_checkpoint_dir = entropy_model_dir
        patcher = patcher_args.build()
        # --- End Loading ---
        # --- Processing ---
        prompts = [prompt]
        print(f"Processing prompt: '{prompt}'")
        results = patcher_nocache(prompts, tokenizer=tokenizer, patcher=patcher)

        if not results:
            print("Processing returned no results.")
            return "Processing completed, but no results were generated."  # Info message.

        batch_patch_lengths, batch_scores, batch_tokens = results
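        # The triple appears to hold per-prompt patch lengths, per-byte entropy
        # scores, and the byte token ids; only the tokens are used below.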
        # Decode the first (and only) result in the batch.
        decoded_chars_list = [tokenizer.decode(row_tokens.tolist()) for row_tokens in batch_tokens]
        decoded_output = decoded_chars_list[0] if decoded_chars_list else "No characters decoded."
        print("Processing and decoding complete.")
        # --- End Processing ---

        # Return the decoded text string.
        return decoded_output
    except FileNotFoundError as e:
        print(f"Error: {e}")
        # raise gr.Error(str(e))  # Alternative: display the specific error in the Gradio UI.
        return f"Error: {str(e)}"  # Return the error as text output.
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        import traceback
        traceback.print_exc()
        # raise gr.Error(f"An error occurred during processing: {e}")
        return f"An unexpected error occurred: {e}"  # Return the error as text output.

# --- Gradio Interface Definition ---
iface = gr.Interface(
    fn=process_text,
    inputs=gr.Textbox(
        label="Input Prompt",
        placeholder="Enter your text here...",
    ),
    # Output displays the decoded text.
    outputs=gr.Text(label="Decoded Output"),
    title="ByteLatent Text Processor",
    description="Enter text to process it with the ByteLatent model ('blt-1b' by default). The decoded output will be shown.",
    allow_flagging="never",
)
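
# A possible extension (sketch, not enabled): expose the model choice as a
# second input so process_text's model_name argument is user-selectable, e.g.:
#
#   iface = gr.Interface(
#       fn=process_text,
#       inputs=[
#           gr.Textbox(label="Input Prompt"),
#           gr.Dropdown(choices=["blt-1b"], value="blt-1b", label="Model"),
#       ],
#       outputs=gr.Text(label="Decoded Output"),
#   )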

# --- Launch the Gradio App ---
if __name__ == "__main__":
    ensure_present(["blt-1b"])
    iface.launch()
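    # iface.launch(share=True) would additionally create a temporary public URL.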