tini / app.py
Ruurd's picture
Change version of loading model
098132b
raw
history blame
6.65 kB
import gradio as gr
import torch
import numpy as np
import json
import time
from transformers import AutoTokenizer
from llama_diffusion_model import disable_dropout
import os
import importlib
from huggingface_hub import hf_hub_download
from llama_diffusion_model import CustomTransformerModel, CustomTransformerConfig, BidirectionalLlamaAttention, disable_dropout
import spaces
# Access token for the Hugging Face Hub, read from the environment (None when unset).
hf_token = os.getenv("HF_TOKEN")

# --- Load tokenizer ---
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B", use_fast=True, token=hf_token)
vocab_size = len(tokenizer)
# Compare against None explicitly: a pad id of 0 is valid but falsy, so
# `pad_token_id or eos_token_id` would silently replace it with the EOS id.
pad_token = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
eot_token_id = tokenizer.eos_token_id
# Token ids of the literal "Assistant:" marker; used to locate where the answer begins.
assistant_marker_ids = tokenizer.encode("Assistant:", add_special_tokens=False)

# --- Load token probabilities ---
# The JSON maps stringified indices "0", "1", ... to a probability each; it is
# flattened into a dense float32 array ordered by index.
with open("token_probabilities.json", encoding="utf-8") as f:
    token_probs_dict = json.load(f)
token_probabilities = np.array(
    [token_probs_dict[str(i)] for i in range(len(token_probs_dict))],
    dtype=np.float32,
)
@spaces.GPU
def load_weights():
    """Download the diffusion-model checkpoint from the Hub and deserialize it on the CPU.

    Returns:
        The object ``torch.load`` reconstructs from ``diffusion-model.pth``,
        with tensor storages mapped to the CPU.
    """
    checkpoint_file = hf_hub_download(
        repo_id="ruurd/tini_model",
        filename="diffusion-model.pth",
        token=os.getenv("HF_TOKEN"),
    )
    # NOTE(review): torch.load unpickles the downloaded file, so it must come
    # from a trusted repo. If the checkpoint is a plain state dict, consider
    # passing weights_only=True — confirm before changing.
    state = torch.load(checkpoint_file, map_location="cpu")
    return state
# NOTE(review): the literal `...` (Ellipsis) is passed as the only constructor
# argument — this looks like an unfilled placeholder for the real config/arguments;
# confirm what CustomTransformerModel expects.
model = CustomTransformerModel(...)
model.load_state_dict(load_weights())
# NOTE(review): moved to CUDA at import time, outside any @spaces.GPU function —
# confirm this is permitted in the deployment environment.
model.to("cuda")
# This app only runs inference: switch to eval mode so dropout (and any other
# train-only behaviour) is disabled while sampling.
# Presumes the model is a torch.nn.Module, as its load_state_dict/.to usage suggests.
model.eval()
# Shared NumPy random generator used by the noising step (unseeded, so runs differ).
rng = np.random.default_rng()
# --- Utility Functions ---
def decode_tokens_safe(token_ids):
    """Decode token ids to a single-line string.

    Special tokens are skipped by the tokenizer and every newline in the
    decoded text is replaced with a space.
    """
    text = tokenizer.decode(token_ids, skip_special_tokens=True)
    return text.replace("\n", " ")
def find_answer_start(input_ids, marker_ids):
    """Return the index just past the first occurrence of ``marker_ids`` in ``input_ids``.

    Args:
        input_ids: Sequence of token ids to search.
        marker_ids: Contiguous sub-sequence to look for (same sequence type,
            since slices are compared with ``==``).

    Returns:
        The position of the first element after the marker, or ``None`` when
        the marker does not occur.
    """
    width = len(marker_ids)
    last_start = len(input_ids) - width
    # Lazily scan candidate start positions; the first hit wins.
    hits = (
        start + width
        for start in range(last_start + 1)
        if input_ids[start:start + width] == marker_ids
    )
    return next(hits, None)
def get_noising_schedule(i, max_it, sharpness=5.0):
    """Return the noising fraction for iteration ``i`` out of ``max_it``.

    The value follows an exponential decay in ``i / max_it``, rescaled so that
    it equals 1 at ``i == 0`` and 0 at ``i == max_it``. A larger ``sharpness``
    makes the curve fall off faster.
    """
    progress = i / max_it
    floor = np.exp(-sharpness)  # raw decay value at progress == 1
    decayed = np.exp(-sharpness * progress)
    return (decayed - floor) / (1 - floor)
def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0):
    """Return a copy of ``input_ids`` with part of the answer region replaced by random tokens.

    Args:
        input_ids: Token ids; only positions from ``answer_start`` onward may change.
        answer_start: Index of the first answer position.
        threshold: Fraction of the answer positions to overwrite (truncated to an int count).
        eot_weight: Multiplier applied to the end-of-text token's probability
            before the sampling distribution is renormalised.

    Returns:
        The noised copy; the input itself is not modified.
    """
    corrupted = input_ids.copy()
    total_len = len(input_ids)
    how_many = int(threshold * (total_len - answer_start))
    if how_many <= 0:
        return corrupted

    # Pick distinct answer positions to overwrite.
    positions = rng.choice(np.arange(answer_start, total_len), size=how_many, replace=False)

    # Re-weight the end-of-text token, then renormalise to a valid distribution.
    weights = token_probabilities.copy()
    weights[eot_token_id] *= eot_weight
    weights /= weights.sum()
    replacements = rng.choice(np.arange(vocab_size), size=how_many, p=weights)

    for position, token in zip(positions, replacements):
        corrupted[position] = token
    return corrupted
def generate_diffusion_text(input_ids, answer_start):
    """Run one model pass and resample every answer position.

    Args:
        input_ids: List of token ids for the whole sequence (prompt + answer region).
        answer_start: Index of the first answer position.

    Returns:
        A new list: the first ``answer_start`` ids of ``input_ids`` unchanged,
        followed by one token sampled per remaining position from the model's
        output distribution.
    """
    with torch.no_grad():
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
        # assumes logits has shape (1, seq_len, vocab) — TODO confirm against the model
        logits = model(input_ids=input_tensor)["logits"]
        # Squeeze only the batch axis. A bare .squeeze() would also collapse a
        # length-1 sequence axis and break the list slicing below.
        probs = torch.nn.functional.softmax(logits, dim=-1).squeeze(0)
        # Keep every probability strictly positive before sampling.
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        # multinomial yields one sample per row, shape (seq_len, 1); drop only that last axis.
        sampled = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()
    return input_ids[:answer_start] + sampled[answer_start:]
# --- Inference Wrapper ---
def diffusion_chat(question, eot_weight, max_it, sharpness):
    """Generate an answer by iterative denoising, yielding HTML progress updates.

    Args:
        question: User question; a blank question falls back to a built-in example.
        eot_weight: Multiplier for the end-of-text token probability in the noise distribution.
        max_it: Maximum number of denoising iterations (coerced to an int >= 1).
        sharpness: Decay sharpness passed to ``get_noising_schedule``.

    Yields:
        HTML strings: one per iteration (tokens that changed since the previous
        iteration are shown in green), optionally an early-stop notice, and
        lastly the final answer.
    """
    import html  # local import: only needed to escape model text for the HTML output

    max_len = 256  # fixed sequence length the prompt is padded/truncated to
    # Sliders may deliver floats; range() needs an int, and the final message needs >= 1 pass.
    max_it = max(1, int(max_it))

    placeholder = "What do you know about the city of New York?"
    if not question.strip():
        question = placeholder

    prompt = f"User: {question}\nAssistant:"
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)

    answer_start = find_answer_start(input_ids, assistant_marker_ids)
    if answer_start is None:
        yield "Error: Could not find Assistant marker in input."
        return
    if answer_start >= max_len:
        # Truncating to max_len would leave no answer positions to generate into.
        yield "Error: Question is too long; no room left for an answer."
        return

    if len(input_ids) < max_len:
        input_ids += [pad_token] * (max_len - len(input_ids))
    else:
        input_ids = input_ids[:max_len]

    # Start from a fully noised answer region.
    current_tokens = noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=eot_weight)

    prev_answer_ids = None   # answer ids from the previous iteration, for highlighting
    recent_outputs = []      # last (up to) three model outputs, for the convergence check
    generated_tokens = current_tokens

    for i in range(max_it):
        generated_tokens = generate_diffusion_text(current_tokens, answer_start)
        answer_ids = generated_tokens[answer_start:]
        answer_tokens = tokenizer.convert_ids_to_tokens(answer_ids)

        # Compare position by position BEFORE dropping end-of-text tokens, so a
        # token is only marked as changed when the same slot actually changed
        # and no trailing tokens are lost from the display.
        highlighted = []
        for pos, (tok_id, tok) in enumerate(zip(answer_ids, answer_tokens)):
            if tok_id == eot_token_id:
                continue
            # Escape model text so characters like '<' or '&' cannot break the HTML.
            text = html.escape(tokenizer.convert_tokens_to_string([tok]), quote=False)
            if prev_answer_ids is not None and prev_answer_ids[pos] != tok_id:
                highlighted.append(f'<span style="color:green">{text}</span>')
            else:
                highlighted.append(text)
        prev_answer_ids = answer_ids

        yield f"<b>Iteration {i+1}/{max_it} (running):</b><br>" + "".join(highlighted)

        # Stop once three consecutive model outputs are identical.
        recent_outputs = (recent_outputs + [generated_tokens])[-3:]
        if len(recent_outputs) == 3 and recent_outputs[0] == recent_outputs[1] == recent_outputs[2]:
            yield f"<b>Stopped early after {i+1} iterations.</b>"
            break

        # Re-noise a shrinking fraction of the answer for the next pass.
        threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
        current_tokens = noisify_answer(generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight)
        time.sleep(0.01)

    # Report the last model output, not `current_tokens`. After the final pass
    # `current_tokens` has been re-noised, which corrupts the answer
    # (for max_it == 1 it would be entirely random).
    final_ids = [tok_id for tok_id in generated_tokens[answer_start:] if tok_id != eot_token_id]
    final_output = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(final_ids))
    yield f"<b>Final Output (after {i+1} iterations):</b><br>" + html.escape(final_output, quote=False)
# --- Gradio Interface ---
# Input widgets, in the order diffusion_chat receives them:
# question, eot_weight, max_it, sharpness.
question_box = gr.Textbox(
    label="User Question",
    lines=2,
    placeholder="What do you know about the city of New York?",
)
eot_weight_slider = gr.Slider(0, 1, value=0.4, step=0.05, label="↓ = longer answers (EOT weight)")
iterations_slider = gr.Slider(1, 512, value=64, step=1, label="↑ = more iterations")
sharpness_slider = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="↓ = more noising (sharpness)")

# HTML output so the per-iteration highlighting markup is rendered.
output_panel = gr.HTML(label="Diffusion Output")

demo = gr.Interface(
    fn=diffusion_chat,
    inputs=[question_box, eot_weight_slider, iterations_slider, sharpness_slider],
    outputs=output_panel,
    title="Diffusion Language Model Chat",
    description="This interface runs a diffusion-based language model to generate answers progressively.",
)

# Start the web UI when the script is executed.
demo.launch()