import gradio as gr
import torch
import numpy as np
import json
import time
import os

from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
from llama_diffusion_model import (
    CustomTransformerModel,
    CustomTransformerConfig,
    BidirectionalLlamaAttention,
    disable_dropout,
)
import spaces
hf_token = os.getenv("HF_TOKEN")

# --- Load tokenizer ---
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B", use_fast=True, token=hf_token)
vocab_size = len(tokenizer)
pad_token = tokenizer.pad_token_id or tokenizer.eos_token_id
eot_token_id = tokenizer.eos_token_id
assistant_marker_ids = tokenizer.encode("Assistant:", add_special_tokens=False)

# --- Load token probabilities ---
with open("token_probabilities.json") as f:
    token_probs_dict = json.load(f)
token_probabilities = np.array(
    [token_probs_dict[str(i)] for i in range(len(token_probs_dict))], dtype=np.float32
)
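# The file layout is assumed from the lookup above: stringified token ids
# mapping to unigram probabilities, one key per vocabulary entry, e.g.
#   {"0": 1.2e-05, "1": 3.4e-06, ...}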
def load_model():
    ckpt_path = hf_hub_download(
        repo_id="ruurd/tini_model",
        filename="diffusion-model.pth",
        token=hf_token,
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.load(ckpt_path, map_location=device)
    model = disable_dropout(model)
    model.to(device)
    model.eval()
    return model
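# Note: torch.load is assumed to return a full pickled model object here (not
# a state_dict), which is why disable_dropout() and .eval() can be applied to
# the result directly.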
rng = np.random.default_rng()

# --- Utility Functions ---
def decode_tokens_safe(token_ids):
    return tokenizer.decode(token_ids, skip_special_tokens=True).replace("\n", " ")

def find_answer_start(input_ids, marker_ids):
    # Return the index of the first token after the "Assistant:" marker, or None.
    for i in range(len(input_ids) - len(marker_ids) + 1):
        if input_ids[i:i + len(marker_ids)] == marker_ids:
            return i + len(marker_ids)
    return None
def get_noising_schedule(i, max_it, sharpness=5.0):
    # Exponential decay from 1.0 (fully noised) at i=0 to 0.0 at i=max_it.
    x = i / max_it
    return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))
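# Illustrative values computed from the formula above, for the default
# sharpness=5.0 with max_it=64 (the schedule depends only on i/max_it):
#   get_noising_schedule(0, 64)   -> 1.000
#   get_noising_schedule(32, 64)  -> ~0.076
#   get_noising_schedule(64, 64)  -> 0.000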
def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0):
    # Re-noise a fraction `threshold` of the answer tokens by resampling them
    # from the unigram distribution, with the EOT probability rescaled by
    # `eot_weight` so the model can be nudged toward shorter or longer outputs.
    noised = input_ids.copy()
    answer_len = len(input_ids) - answer_start
    num_to_noise = int(threshold * answer_len)
    if num_to_noise > 0:
        indices = rng.choice(np.arange(answer_start, len(input_ids)), size=num_to_noise, replace=False)
        mixed_probs = token_probabilities.copy()
        mixed_probs[eot_token_id] *= eot_weight
        mixed_probs /= mixed_probs.sum()
        noise = rng.choice(np.arange(vocab_size), size=num_to_noise, p=mixed_probs)
        for idx, val in zip(indices, noise):
            noised[idx] = val
    return noised
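# For example, threshold=1.0 replaces every answer token with noise (used to
# initialize generation), while threshold=0.25 resamples roughly a quarter of
# the answer positions, chosen uniformly at random.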
print("Loading model...") | |
model = load_model() | |
print("✅ Model loaded.") | |
def generate_diffusion_text(input_ids, answer_start): | |
with torch.no_grad(): | |
input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device) | |
logits = model(input_ids=input_tensor)["logits"] | |
probs = torch.nn.functional.softmax(logits, dim=-1).squeeze() | |
probs = torch.clamp(probs, min=1e-8, max=1.0) | |
sampled = torch.multinomial(probs, num_samples=1).squeeze().tolist() | |
return input_ids[:answer_start] + sampled[answer_start:] | |
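# One denoising step: a single forward pass samples every position in parallel
# from its per-position distribution; tokens before answer_start are restored
# from the prompt so only the answer region is re-generated.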
# --- Diffusion Chat Function ---
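# Assumption: this Space runs on ZeroGPU (hence the otherwise-unused `spaces`
# import), which requires a @spaces.GPU decorator on the GPU-using entry
# point; spaces.GPU also supports generator functions.
@spaces.GPU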
def diffusion_chat(message, history, system_prompt, eot_weight, max_it, sharpness):
    # `history` is supplied by gr.ChatInterface but unused: each query is
    # answered from scratch.
    prompt = f"{system_prompt}\nUser: {message}\nAssistant:"
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    answer_start = find_answer_start(input_ids, assistant_marker_ids)
    if answer_start is None:
        yield "<span style='color:red'><b>Error:</b> Could not find Assistant marker in input.</span>"
        return

    # Pad (or truncate) to a fixed 256-token window, then fully noise the answer region.
    input_ids = (input_ids + [pad_token] * (256 - len(input_ids)))[:256]
    current_tokens = noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=eot_weight)

    prev_decoded_tokens = []
    last_tokens = []

    for i in range(int(max_it)):
        generated_tokens = generate_diffusion_text(current_tokens, answer_start)
        current_tokens = generated_tokens

        decoded_ids = current_tokens[answer_start:]
        decoded_tokens = tokenizer.convert_ids_to_tokens(decoded_ids)
        filtered_tokens = [tok for tok in decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
        filtered_prev_tokens = (
            [tok for tok in prev_decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
            if prev_decoded_tokens
            else []
        )

        # Highlight tokens that changed since the previous iteration; on the
        # first iteration there is no previous output, so show everything plain.
        highlighted = []
        if filtered_prev_tokens:
            for tok_new, tok_old in zip(filtered_tokens, filtered_prev_tokens):
                text = tokenizer.convert_tokens_to_string([tok_new])
                if tok_new != tok_old:
                    highlighted.append(f"<span style='color:green'>{text}</span>")
                else:
                    highlighted.append(text)
        else:
            highlighted = [tokenizer.convert_tokens_to_string([tok]) for tok in filtered_tokens]

        prev_decoded_tokens = decoded_tokens
        yield (
            "<div style='padding:0.5em'><b>Iteration {}</b><br>"
            "<div style='background:#f5f5f5;padding:0.5em;border-radius:0.5em'>{}</div></div>"
        ).format(i + 1, ''.join(highlighted))

        # Keep only the three most recent outputs and stop early once they
        # are identical (convergence).
        last_tokens.append(generated_tokens)
        if len(last_tokens) > 3:
            last_tokens.pop(0)
        if len(last_tokens) == 3 and all(t == last_tokens[0] for t in last_tokens):
            yield f"<div style='color:gray'><i>Stopped early after {i + 1} iterations (converged).</i></div>"
            break

        # Re-noise a shrinking fraction of the answer before the next step.
        threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
        current_tokens = noisify_answer(generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight)
        time.sleep(0.01)  # let the UI render between iterations

    final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
    final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
    final_output = tokenizer.convert_tokens_to_string(final_tokens)
    yield f"<div style='padding:0.5em'><b>Final Output:</b><br><div style='background:#e0ffe0;padding:0.5em;border-radius:0.5em'>{final_output}</div></div>"
# --- Chat Interface ---
demo = gr.ChatInterface(
    diffusion_chat,
    additional_inputs=[
        gr.Textbox(value="You are a helpful assistant.", label="System message"),
        gr.Slider(0, 1, value=0.4, step=0.05, label="EOT token weight (lower = longer output)"),
        gr.Slider(1, 512, value=64, step=1, label="Max Iterations"),
        gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Noising sharpness (lower = more noise)"),
    ],
    title="Diffusion Language Model Chat",
    description="Iterative denoising chat interface using a fine-tuned LLaMA model.",
)

if __name__ == "__main__":
    demo.launch()