Spaces:
Sleeping
Sleeping
File size: 6,457 Bytes
7252f98 9aaa660 7252f98 9aaa660 183a6ee 42ed840 7252f98 b5abc9b 0e3f268 617b180 9aaa660 b5abc9b 9aaa660 7252f98 0e1a415 7aaa1c3 7252f98 0e1a415 b5abc9b 0e1a415 7252f98 0e1a415 7252f98 84a6c46 7252f98 7aaa1c3 7252f98 84a6c46 7252f98 0e1a415 7252f98 84a6c46 0e1a415 7252f98 0e1a415 55b43fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import gradio as gr
import torch
import numpy as np
import json
import time
from transformers import AutoTokenizer
from llama_diffusion_model import disable_dropout
import os
import importlib
from huggingface_hub import hf_hub_download
from llama_diffusion_model import CustomTransformerModel, CustomTransformerConfig, BidirectionalLlamaAttention, disable_dropout
import spaces
hf_token = os.getenv("HF_TOKEN")
# --- Load tokenizer ---
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B", use_fast=True, token=hf_token)
vocab_size = len(tokenizer)
# Fall back to EOS only when the tokenizer defines no pad token at all; a plain
# `or` would also throw away a legitimate pad id of 0.
pad_token = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
eot_token_id = tokenizer.eos_token_id
# Token ids of the marker that precedes the assistant's answer in the prompt.
assistant_marker_ids = tokenizer.encode("Assistant:", add_special_tokens=False)
# --- Load token probabilities ---
# The JSON maps str(token_id) -> probability; keys are expected to be
# "0" .. str(n - 1) (a missing key raises KeyError below). The resulting array
# is used as the noise distribution in noisify_answer.
with open("token_probabilities.json", encoding="utf-8") as f:
    token_probs_dict = json.load(f)
token_probabilities = np.array(
    [token_probs_dict[str(i)] for i in range(len(token_probs_dict))],
    dtype=np.float32,
)
def load_model():
    """Download the diffusion checkpoint from the Hub and return it ready for inference.

    The file ``diffusion-model.pth`` from the ``ruurd/tini_model`` repo is
    loaded with ``torch.load``, passed through ``disable_dropout``, moved to
    CUDA when available (CPU otherwise) and switched to eval mode.

    Returns:
        The loaded model object.
    """
    ckpt_path = hf_hub_download(
        repo_id="ruurd/tini_model",
        filename="diffusion-model.pth",
        token=os.getenv("HF_TOKEN"),
    )
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The checkpoint is used directly as a model object below, so it is a whole
    # pickled module rather than a state_dict. Since PyTorch 2.6, torch.load
    # defaults to weights_only=True, which refuses arbitrary pickled objects;
    # opt out explicitly so loading keeps working on newer torch versions.
    # NOTE(review): weights_only=False executes pickle code from the file; this
    # presumes the Hub repo above is trusted — confirm.
    model = torch.load(ckpt_path, map_location=device, weights_only=False)
    model = disable_dropout(model)
    model.to(device)
    model.eval()
    return model
# Module-wide random generator used for noise sampling. It is unseeded (PCG64
# draws fresh OS entropy), so every process start produces different noise.
rng = np.random.Generator(np.random.PCG64())
# --- Utility Functions ---
def decode_tokens_safe(token_ids):
    """Decode ``token_ids`` to text without special tokens, newlines flattened to spaces."""
    text = tokenizer.decode(token_ids, skip_special_tokens=True)
    return text.replace("\n", " ")
def find_answer_start(input_ids, marker_ids):
    """Return the index just past the LAST occurrence of ``marker_ids`` in ``input_ids``.

    The chat prompt ends with the assistant marker, so the last occurrence is
    the one that opens the answer. Searching from the end prevents a copy of the
    marker inside the system prompt or user message from being chosen, which
    would make the caller overwrite part of the prompt.

    Args:
        input_ids: Sequence of token ids (list, tuple or 1-D array).
        marker_ids: Sequence of token ids to look for.

    Returns:
        The position right after the last match, or None when the marker is
        absent or empty.
    """
    target = list(marker_ids)
    width = len(target)
    if width == 0:
        return None
    for start in range(len(input_ids) - width, -1, -1):
        if list(input_ids[start:start + width]) == target:
            return start + width
    return None
def get_noising_schedule(i, max_it, sharpness=5.0):
    """Fraction of answer tokens to re-noise at step ``i`` of ``max_it``.

    Normalized exponential decay in ``x = i / max_it``::

        (exp(-s * x) - exp(-s)) / (1 - exp(-s))

    It equals 1.0 at ``i == 0`` and 0.0 at ``i == max_it``; larger ``sharpness``
    makes it fall off faster. At ``sharpness == 0`` the formula is 0/0, so its
    analytic limit ``1 - x`` (a linear schedule) is returned instead.

    Raises:
        ZeroDivisionError: if ``max_it`` is 0.
    """
    x = i / max_it
    if sharpness == 0:
        return 1.0 - x
    return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))
def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0):
    """Return a copy of ``input_ids`` with part of the answer replaced by random tokens.

    ``int(threshold * answer_length)`` distinct positions at or after
    ``answer_start`` are chosen uniformly and overwritten with ids drawn from
    ``token_probabilities``. Before drawing, the EOS probability is scaled by
    ``eot_weight`` and the distribution is renormalized. The prompt prefix is
    never touched, and the input itself is not modified.
    """
    corrupted = input_ids.copy()
    n_answer = len(input_ids) - answer_start
    n_replace = int(threshold * n_answer)
    if n_replace <= 0:
        # Nothing to corrupt (zero threshold, or no answer region at all).
        return corrupted
    positions = rng.choice(np.arange(answer_start, len(input_ids)), size=n_replace, replace=False)
    weights = token_probabilities.copy()
    weights[eot_token_id] *= eot_weight
    weights /= weights.sum()
    replacements = rng.choice(np.arange(vocab_size), size=n_replace, p=weights)
    for pos, tok in zip(positions, replacements):
        corrupted[pos] = tok
    return corrupted
# Load the model once at import time; generate_diffusion_text reads this
# module-level `model` on every call.
print("Loading model...")
model = load_model()
print("✅ Model loaded.")
def generate_diffusion_text(input_ids, answer_start):
    """Run one denoising pass and resample every answer position from the model.

    The prompt prefix ``input_ids[:answer_start]`` is returned verbatim; every
    later position takes an id sampled from the model's softmax at that
    position. Probabilities are clamped to [1e-8, 1] before sampling.
    """
    with torch.no_grad():
        batch = torch.tensor([input_ids], dtype=torch.long).to(model.device)
        logits = model(input_ids=batch)["logits"]
        # Presumably logits are (1, seq, vocab); squeeze drops the batch axis — TODO confirm.
        distribution = torch.clamp(
            torch.nn.functional.softmax(logits, dim=-1).squeeze(), min=1e-8, max=1.0
        )
        draws = torch.multinomial(distribution, num_samples=1).squeeze().tolist()
        prefix = input_ids[:answer_start]
        return prefix + draws[answer_start:]
# --- Diffusion Chat Function ---
def _highlight_answer(answer_ids, prev_answer_ids):
    """Render answer ids as escaped HTML, wrapping changed tokens in green spans.

    A token counts as changed when it differs from the id at the same position
    in ``prev_answer_ids``, or when there is no previous iteration
    (``prev_answer_ids is None``). EOS ids are not displayed. Positions are
    compared before EOS is skipped, so an EOS that moves does not shift the
    comparison for every token after it.
    """
    import html

    tokens = tokenizer.convert_ids_to_tokens(answer_ids)
    pieces = []
    for pos, (tok_id, tok) in enumerate(zip(answer_ids, tokens)):
        if tok_id == eot_token_id:
            continue
        text = html.escape(tokenizer.convert_tokens_to_string([tok]))
        if prev_answer_ids is None or prev_answer_ids[pos] != tok_id:
            pieces.append(f"<span style='color:green'>{text}</span>")
        else:
            pieces.append(text)
    return "".join(pieces)


def _answer_text(answer_ids):
    """Decode answer ids to HTML-escaped text, dropping EOS ids."""
    import html

    kept = [tok_id for tok_id in answer_ids if tok_id != eot_token_id]
    return html.escape(tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(kept)))


@spaces.GPU
def diffusion_chat(message, system_prompt, eot_weight, max_it, sharpness):
    """Stream the iterative denoising of an assistant reply as HTML snippets.

    The prompt ``"{system_prompt}\\nUser: {message}\\nAssistant:"`` is
    tokenized and padded/truncated to 256 tokens. Every position after the
    ``Assistant:`` marker is then replaced with sampled noise. Each iteration:

    1. runs one model pass (``generate_diffusion_text``);
    2. yields an HTML view of the answer, with tokens that changed since the
       previous iteration in green;
    3. re-noises a shrinking fraction of the answer (``get_noising_schedule``).

    Generation stops early once three consecutive passes produce identical
    tokens. The final message shows the last *denoised* pass.

    Args:
        message: The user's message.
        system_prompt: Text placed before the ``User:`` turn.
        eot_weight: Multiplier on the EOS token's noise-sampling probability.
        max_it: Maximum number of denoising iterations (coerced to int, at least 1).
        sharpness: Sharpness of the exponential noising schedule.

    Yields:
        HTML strings, each one a complete progress snapshot.
    """
    seq_len = 256  # fixed length of prompt + answer fed to the model
    prompt = f"{system_prompt}\nUser: {message}\nAssistant:"
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    answer_start = find_answer_start(input_ids, assistant_marker_ids)
    if answer_start is None:
        yield "<span style='color:red'><b>Error:</b> Could not find Assistant marker in input.</span>"
        return
    if answer_start >= seq_len:
        # Truncating to seq_len would cut into the prompt and leave no answer slots.
        yield (
            f"<span style='color:red'><b>Error:</b> Prompt is too long ({answer_start} tokens); "
            f"no room left for an answer within {seq_len} tokens.</span>"
        )
        return

    input_ids = (input_ids + [pad_token] * (seq_len - len(input_ids)))[:seq_len]
    current_tokens = noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=eot_weight)
    # Gradio sliders may deliver floats; range() needs an int.
    num_iterations = max(1, int(max_it))

    generated_tokens = current_tokens
    prev_answer = None
    recent = []  # the last three generations, for convergence detection
    converged_after = None
    for i in range(num_iterations):
        generated_tokens = generate_diffusion_text(current_tokens, answer_start)
        answer = generated_tokens[answer_start:]
        yield ("<div style='padding:0.5em'><b>Iteration {}</b><br>"
               "<div style='background:#f5f5f5;padding:0.5em;border-radius:0.5em'>{}</div></div>").format(
            i + 1, _highlight_answer(answer, prev_answer))
        prev_answer = answer

        recent.append(generated_tokens)
        del recent[:-3]  # keep only the most recent three
        if len(recent) == 3 and all(t == recent[0] for t in recent):
            converged_after = i + 1
            break

        # Iteration i has just finished, so the remaining noise level is the
        # schedule at step i + 1. Using i would re-noise everything (threshold
        # 1.0) after the first pass and leave residual noise after the last.
        threshold = get_noising_schedule(i + 1, num_iterations, sharpness=sharpness)
        current_tokens = noisify_answer(generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight)
        time.sleep(0.01)

    # Fold the early-stop note into the final message: a separate yield would be
    # replaced immediately by the next one.
    note = ""
    if converged_after is not None:
        note = f"<div style='color:gray'><i>Stopped early after {converged_after} iterations (converged).</i></div>"
    final_output = _answer_text(generated_tokens[answer_start:])
    yield f"{note}<div style='padding:0.5em'><b>Final Output:</b><br><div style='background:#e0ffe0;padding:0.5em;border-radius:0.5em'>{final_output}</div></div>"
# --- Chat Interface ---
def _chat_fn(message, history, system_prompt, eot_weight, max_it, sharpness):
    """Adapter between ``gr.ChatInterface`` and ``diffusion_chat``.

    ChatInterface calls ``fn(message, history, *additional_inputs)``, but
    ``diffusion_chat`` takes no ``history`` argument. Passing it directly would
    put the chat history into ``system_prompt`` and shift every slider value by
    one position. The history is ignored here: each reply is generated from the
    current message alone. This is a generator function, so ChatInterface
    streams each intermediate snapshot.
    """
    yield from diffusion_chat(message, system_prompt, eot_weight, max_it, sharpness)


demo = gr.ChatInterface(
    _chat_fn,
    additional_inputs=[
        gr.Textbox(value="You are a helpful assistant.", label="System message"),
        gr.Slider(0, 1, value=0.4, step=0.05, label="EOT token weight (lower = longer output)"),
        gr.Slider(1, 512, value=64, step=1, label="Max Iterations"),
        gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Noising sharpness (lower = more noise)"),
    ],
    title="Diffusion Language Model Chat",
    description="Iterative denoising chat interface using a fine-tuned LLaMA model.",
)

if __name__ == "__main__":
    demo.launch()