import gradio as gr
import torch
import numpy as np
import json
import time
from transformers import AutoTokenizer
from llama_diffusion_model import disable_dropout
import os
import importlib
from huggingface_hub import hf_hub_download
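
# Hub token read from the environment; used below for the (gated) Llama 3.2 tokenizer and the checkpoint download.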
hf_token = os.getenv("HF_TOKEN")
# --- Load tokenizer ---
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B", use_fast=True, token=hf_token)
vocab_size = len(tokenizer)
pad_token = tokenizer.pad_token_id or tokenizer.eos_token_id
eot_token_id = tokenizer.eos_token_id
assistant_marker_ids = tokenizer.encode("Assistant:", add_special_tokens=False)
# --- Load token probabilities ---
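# These per-token probabilities are used by noisify_answer as the sampling distribution for noise tokens.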
with open("token_probabilities.json") as f:
    token_probs_dict = json.load(f)
token_probabilities = np.array([token_probs_dict[str(i)] for i in range(len(token_probs_dict))], dtype=np.float32)
def load_model():
    # 1. Download the checkpoint
    checkpoint_path = hf_hub_download(
        repo_id="ruurd/diffusion-llama",
        filename="diffusion-model.pth",
        token=os.getenv("HF_TOKEN")
    )

    # 2. Register the checkpoint's custom classes as safe globals so the full object can be unpickled
    torch.serialization.clear_safe_globals()
    unsafe_globals = torch.serialization.get_unsafe_globals_in_checkpoint(checkpoint_path)
    missing_class_names = [name.split(".")[-1] for name in unsafe_globals]
    safe_classes = [cls for name, cls in globals().items() if name in missing_class_names]
    for class_path in unsafe_globals:
        try:
            module_name, class_name = class_path.rsplit(".", 1)
            module = importlib.import_module(module_name)
            cls = getattr(module, class_name)
            safe_classes.append(cls)
        except (ImportError, AttributeError) as e:
            print(f"⚠️ Warning: Could not import {class_path} - {e}")
    torch.serialization.add_safe_globals(safe_classes)

    # 3. Load the full model object (weights_only=True with the allow-listed classes)
    model = torch.load(checkpoint_path, weights_only=True)

    # 4. Final setup
    model = disable_dropout(model)
    model.to("cuda")
    model.eval()
    return model
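
# Shared NumPy RNG used by the noising utilities below.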
rng = np.random.default_rng()
# --- Utility Functions ---
def decode_tokens_safe(token_ids):
    return tokenizer.decode(token_ids, skip_special_tokens=True).replace("\n", " ")

def find_answer_start(input_ids, marker_ids):
    for i in range(len(input_ids) - len(marker_ids) + 1):
        if input_ids[i:i + len(marker_ids)] == marker_ids:
            return i + len(marker_ids)
    return None
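
# Fraction of answer tokens to re-noise at iteration i: decays from 1 to 0 over max_it steps;
# larger sharpness makes the early drop steeper.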
def get_noising_schedule(i, max_it, sharpness=5.0):
    x = i / max_it
    return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))
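
# Replace a random `threshold` fraction of answer-region tokens with tokens drawn from
# token_probabilities; eot_weight rescales the EOT probability, which steers answer length.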
def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0):
    noised = input_ids.copy()
    answer_len = len(input_ids) - answer_start
    num_to_noise = int(threshold * answer_len)
    if num_to_noise > 0:
        indices = rng.choice(np.arange(answer_start, len(input_ids)), size=num_to_noise, replace=False)
        mixed_probs = token_probabilities.copy()
        mixed_probs[eot_token_id] *= eot_weight
        mixed_probs /= mixed_probs.sum()
        noise = rng.choice(np.arange(vocab_size), size=num_to_noise, p=mixed_probs)
        for idx, val in zip(indices, noise):
            noised[idx] = val
    return noised
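
# One denoising step: a single forward pass, independent per-position sampling from the softmax,
# keeping the prompt region and taking only the answer region from the samples.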
def generate_diffusion_text(model, input_ids, answer_start):
    with torch.no_grad():
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
        logits = model(input_ids=input_tensor)["logits"]
        probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        sampled = torch.multinomial(probs, num_samples=1).squeeze().tolist()
    return input_ids[:answer_start] + sampled[answer_start:]

# --- Inference Wrapper ---
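# Generator used for Gradio streaming: start from a fully noised answer, then alternate
# denoising (generate_diffusion_text) and re-noising with a decaying threshold, yielding
# intermediate HTML and stopping early once the output is unchanged for three iterations.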
def diffusion_chat(question, eot_weight, max_it, sharpness, model):
    placeholder = "What do you know about the city of New York?"
    if question.strip() == "":
        question = placeholder

    prompt = f"User: {question}\nAssistant:"
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    answer_start = find_answer_start(input_ids, assistant_marker_ids)
    if answer_start is None:
        yield "Error: Could not find Assistant marker in input."
        return

    # Pad or truncate the prompt to the fixed 256-token context
    if len(input_ids) < 256:
        input_ids += [pad_token] * (256 - len(input_ids))
    else:
        input_ids = input_ids[:256]

    ori_input_tokens = input_ids
    current_tokens = noisify_answer(ori_input_tokens, answer_start, threshold=1.0, eot_weight=eot_weight)
    prev_decoded_tokens = []
    last_tokens = []

    for i in range(max_it):
        # Denoise, then decode the answer region for display
        generated_tokens = generate_diffusion_text(model, current_tokens, answer_start)
        current_tokens = generated_tokens
        decoded_ids = current_tokens[answer_start:]
        decoded_tokens = tokenizer.convert_ids_to_tokens(decoded_ids)
        filtered_tokens = [tok for tok in decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
        filtered_prev_tokens = [tok for tok in prev_decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id] if prev_decoded_tokens else []

        # Highlight tokens that changed since the previous iteration
        if filtered_prev_tokens:
            highlighted = []
            for tok_new, tok_old in zip(filtered_tokens, filtered_prev_tokens):
                if tok_new != tok_old:
                    highlighted.append(f'<span style="color:green">{tokenizer.convert_tokens_to_string([tok_new])}</span>')
                else:
                    highlighted.append(tokenizer.convert_tokens_to_string([tok_new]))
        else:
            highlighted = [tokenizer.convert_tokens_to_string([tok]) for tok in filtered_tokens]

        prev_decoded_tokens = decoded_tokens
        yield f"<b>Iteration {i+1}/{max_it} (running):</b><br>" + "".join(highlighted)

        # Stop early if the output has been identical for three consecutive iterations
        last_tokens.append(generated_tokens)
        if len(last_tokens) > 3:
            last_tokens.pop(0)
        if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
            yield f"<b>Stopped early after {i+1} iterations.</b>"
            break

        # Re-noise a decaying fraction of the answer before the next iteration
        threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
        current_tokens = noisify_answer(generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight)
        time.sleep(0.01)

    final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
    final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
    final_output = tokenizer.convert_tokens_to_string(final_tokens)
    yield f"<b>Final Output (after {i+1} iterations):</b><br>" + final_output

# --- Gradio Interface ---
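# The model is loaded once at startup and handed to diffusion_chat through gr.State.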
model_state = gr.State(load_model())
demo = gr.Interface(
    fn=diffusion_chat,
    inputs=[
        gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of New York?"),
        gr.Slider(0, 1, value=0.4, step=0.05, label="↓ = longer answers (EOT weight)"),
        gr.Slider(1, 512, value=64, step=1, label="↑ = more iterations"),
        gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="↓ = more noising (sharpness)"),
        model_state
    ],
    outputs=gr.HTML(label="Diffusion Output"),
    title="Diffusion Language Model Chat",
    description="This interface runs a diffusion-based language model to generate answers progressively."
)
demo.launch()