import gradio as gr
import torch
import numpy as np
import json
import time
import os
from transformers import AutoTokenizer
from huggingface_hub import hf_hub_download
from llama_diffusion_model import (
    CustomTransformerModel,
    CustomTransformerConfig,
    BidirectionalLlamaAttention,
    disable_dropout,
)
import spaces

hf_token = os.getenv("HF_TOKEN")

# --- Load tokenizer ---
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B", use_fast=True, token=hf_token)
vocab_size = len(tokenizer)
pad_token = tokenizer.pad_token_id or tokenizer.eos_token_id
eot_token_id = tokenizer.eos_token_id
assistant_marker_ids = tokenizer.encode("Assistant:", add_special_tokens=False)

# --- Load token probabilities ---
with open("token_probabilities.json") as f:
    token_probs_dict = json.load(f)
token_probabilities = np.array([token_probs_dict[str(i)] for i in range(len(token_probs_dict))], dtype=np.float32)


def load_model():
    """Download the diffusion checkpoint from the Hub and prepare it for inference."""
    ckpt_path = hf_hub_download(
        repo_id="ruurd/tini_model",
        filename="diffusion-model.pth",
        token=os.getenv("HF_TOKEN")
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.load(ckpt_path, map_location=device)
    model = disable_dropout(model)
    model.to(device)
    model.eval()
    return model


rng = np.random.default_rng()


# --- Utility Functions ---
def decode_tokens_safe(token_ids):
    return tokenizer.decode(token_ids, skip_special_tokens=True).replace("\n", " ")


def find_answer_start(input_ids, marker_ids):
    # Return the index right after the "Assistant:" marker, or None if it is absent.
    for i in range(len(input_ids) - len(marker_ids) + 1):
        if input_ids[i:i + len(marker_ids)] == marker_ids:
            return i + len(marker_ids)
    return None


def get_noising_schedule(i, max_it, sharpness=5.0):
    # Exponentially decaying noise threshold, normalized to start at 1 and reach 0 at max_it.
    x = i / max_it
    return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))


def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0):
    # Replace a fraction `threshold` of the answer tokens with tokens drawn from the
    # precomputed token distribution, with the EOT probability rescaled by eot_weight.
    noised = input_ids.copy()
    answer_len = len(input_ids) - answer_start
    num_to_noise = int(threshold * answer_len)
    if num_to_noise > 0:
        indices = rng.choice(np.arange(answer_start, len(input_ids)), size=num_to_noise, replace=False)
        mixed_probs = token_probabilities.copy()
        mixed_probs[eot_token_id] *= eot_weight
        mixed_probs /= mixed_probs.sum()
        noise = rng.choice(np.arange(vocab_size), size=num_to_noise, p=mixed_probs)
        for idx, val in zip(indices, noise):
            noised[idx] = val
    return noised


def generate_diffusion_text(input_ids, answer_start):
    # One denoising step: run the model and resample every position from its predicted
    # distribution, keeping the prompt tokens before answer_start unchanged.
    with torch.no_grad():
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
        logits = model(input_ids=input_tensor)["logits"]
        probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        sampled = torch.multinomial(probs, num_samples=1).squeeze().tolist()
    return input_ids[:answer_start] + sampled[answer_start:]


# --- Inference Wrapper ---
@spaces.GPU
def diffusion_chat(question, eot_weight, max_it, sharpness):
    placeholder = "What do you know about the city of New York?"
    if question.strip() == "":
        question = placeholder

    prompt = f"User: {question}\nAssistant:"
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    answer_start = find_answer_start(input_ids, assistant_marker_ids)
    if answer_start is None:
        yield "Error: Could not find Assistant marker in input."
        return

    # Pad or truncate the prompt to a fixed context length of 256 tokens.
    if len(input_ids) < 256:
        input_ids += [pad_token] * (256 - len(input_ids))
    else:
        input_ids = input_ids[:256]

    ori_input_tokens = input_ids
    # Start from a fully noised answer region.
    current_tokens = noisify_answer(ori_input_tokens, answer_start, threshold=1.0, eot_weight=eot_weight)
    prev_decoded_tokens = []
    last_tokens = []

    for i in range(max_it):
        # Denoise, then re-noise with a decreasing threshold on each iteration.
        generated_tokens = generate_diffusion_text(current_tokens, answer_start)
        current_tokens = generated_tokens

        decoded_ids = current_tokens[answer_start:]
        decoded_tokens = tokenizer.convert_ids_to_tokens(decoded_ids)
        filtered_tokens = [tok for tok in decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
        filtered_prev_tokens = [tok for tok in prev_decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id] if prev_decoded_tokens else []

        if filtered_prev_tokens:
            highlighted = []
            for tok_new, tok_old in zip(filtered_tokens, filtered_prev_tokens):
                if tok_new != tok_old:
                    # Token changed since the previous iteration.
                    highlighted.append(f'{tokenizer.convert_tokens_to_string([tok_new])}')
                else:
                    highlighted.append(tokenizer.convert_tokens_to_string([tok_new]))
        else:
            highlighted = [tokenizer.convert_tokens_to_string([tok]) for tok in filtered_tokens]

        prev_decoded_tokens = decoded_tokens
        yield f"Iteration {i+1}/{max_it} (running):<br>" + "".join(highlighted)

        # Stop early if the output has been identical for three consecutive iterations.
        last_tokens.append(generated_tokens)
        if len(last_tokens) > 3:
            last_tokens.pop(0)
        if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
            yield f"Stopped early after {i+1} iterations."
            break

        threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
        current_tokens = noisify_answer(generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight)
        time.sleep(0.01)

    final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
    final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
    final_output = tokenizer.convert_tokens_to_string(final_tokens)
    yield f"Final Output (after {i+1} iterations):<br>" + final_output
" + final_output # --- Gradio Interface --- print("Loading model...") model = load_model() print("✅ Model loaded.") demo = gr.Interface( fn=diffusion_chat, inputs=[ gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of New York?"), gr.Slider(0, 1, value=0.4, step=0.05, label="↓ = longer answers (EOT weight)"), gr.Slider(1, 512, value=64, step=1, label="↑ = more iterations"), gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="↓ = more noising (sharpness)") ], outputs=gr.HTML(label="Diffusion Output"), title="Diffusion Language Model Chat", description="This interface runs a diffusion-based language model to generate answers progressively." ) demo.launch()