import gradio as gr
import torch
import numpy as np
import json
import time
from transformers import AutoTokenizer
import os
from huggingface_hub import hf_hub_download
from llama_diffusion_model import CustomTransformerModel, CustomTransformerConfig, BidirectionalLlamaAttention, disable_dropout
import spaces

hf_token = os.getenv("HF_TOKEN")


# --- Load tokenizer ---
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B", use_fast=True, token=hf_token)
vocab_size = len(tokenizer)
pad_token = tokenizer.pad_token_id or tokenizer.eos_token_id
eot_token_id = tokenizer.eos_token_id
assistant_marker_ids = tokenizer.encode("Assistant:", add_special_tokens=False)

# --- Load token probabilities ---
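# token_probabilities.json is assumed to map token-id strings ("0", "1", ...) to sampling
# probabilities; it serves as the noise distribution when corrupting answer tokens.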
with open("token_probabilities.json") as f:
    token_probs_dict = json.load(f)
token_probabilities = np.array([token_probs_dict[str(i)] for i in range(len(token_probs_dict))], dtype=np.float32)

def load_model():
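    """Download the pickled diffusion checkpoint from the Hub, disable dropout, and return it in eval mode."""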
    ckpt_path = hf_hub_download(
        repo_id="ruurd/tini_model",
        filename="diffusion-model.pth",
        token=hf_token
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = torch.load(ckpt_path, map_location=device)
    model = disable_dropout(model)
    model.to(device)
    model.eval()
    return model


rng = np.random.default_rng()

# --- Utility Functions ---
def decode_tokens_safe(token_ids):
    return tokenizer.decode(token_ids, skip_special_tokens=True).replace("\n", " ")

def find_answer_start(input_ids, marker_ids):
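    """Return the index just past the first occurrence of marker_ids in input_ids, or None if absent."""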
    for i in range(len(input_ids) - len(marker_ids) + 1):
        if input_ids[i:i + len(marker_ids)] == marker_ids:
            return i + len(marker_ids)
    return None

def get_noising_schedule(i, max_it, sharpness=5.0):
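    """Fraction of answer tokens to re-noise at step i: a normalized exponential decay from 1.0
    at i == 0 to 0.0 at i == max_it; larger sharpness makes the early drop steeper."""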
    x = i / max_it
    return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))

def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0):
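    """Replace a `threshold` fraction of the tokens after answer_start with tokens drawn from the
    precomputed token distribution, with the EOS probability rescaled by eot_weight."""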
    noised = input_ids.copy()
    answer_len = len(input_ids) - answer_start
    num_to_noise = int(threshold * answer_len)
    if num_to_noise > 0:
        indices = rng.choice(np.arange(answer_start, len(input_ids)), size=num_to_noise, replace=False)

        mixed_probs = token_probabilities.copy()
        mixed_probs[eot_token_id] *= eot_weight
        mixed_probs /= mixed_probs.sum()

        noise = rng.choice(np.arange(vocab_size), size=num_to_noise, p=mixed_probs)
        for idx, val in zip(indices, noise):
            noised[idx] = val
    return noised

def generate_diffusion_text(input_ids, answer_start):
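    """Single denoising step: sample every position from the model's softmax, then splice the
    original prompt tokens back in up to answer_start so only the answer changes."""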
    with torch.no_grad():
        input_tensor = torch.tensor([input_ids], dtype=torch.long).to(model.device)
        logits = model(input_ids=input_tensor)["logits"]
        probs = torch.nn.functional.softmax(logits, dim=-1).squeeze()
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        sampled = torch.multinomial(probs, num_samples=1).squeeze().tolist()
    return input_ids[:answer_start] + sampled[answer_start:]

# --- Model loading ---

print("Loading model...")
model = load_model()
print("✅ Model loaded.")

# --- Generation logic ---
@spaces.GPU
def run_diffusion_loop(question, eot_weight, max_it, sharpness):
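    """Iterative diffusion-style decoding: start from a fully noised answer, alternate model
    sampling with re-noising on a decaying schedule, and stop early once three consecutive
    iterations produce identical token sequences."""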
    placeholder = "What do you know about the city of New York?"
    if question.strip() == "":
        question = placeholder

    prompt = f"User: {question}\nAssistant:"
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    answer_start = find_answer_start(input_ids, assistant_marker_ids)
    if answer_start is None:
        error_msg = "Error: could not find the Assistant marker in the tokenized prompt."
        return [[question, error_msg]], error_msg

    input_ids = (input_ids + [pad_token] * (256 - len(input_ids)))[:256]
    current_tokens = noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=eot_weight)
    prev_decoded_tokens = []
    last_tokens = []
    # gr.Chatbot expects a list of [user, assistant] pairs.
    history = [[question, None]]

    for i in range(max_it):
        generated_tokens = generate_diffusion_text(current_tokens, answer_start)
        current_tokens = generated_tokens

        decoded_ids = current_tokens[answer_start:]
        decoded_tokens = tokenizer.convert_ids_to_tokens(decoded_ids)
        filtered_tokens = [tok for tok in decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
        filtered_prev_tokens = [tok for tok in prev_decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id] if prev_decoded_tokens else []

        # Mark tokens that changed since the previous iteration (intended for a streaming
        # diff view; the highlighted strings are currently built but not returned).
        highlighted = []
        for tok_new, tok_old in zip(filtered_tokens, filtered_prev_tokens):
            text = tokenizer.convert_tokens_to_string([tok_new])
            if tok_new != tok_old:
                highlighted.append(f"<span style='color:green'>{text}</span>")
            else:
                highlighted.append(text)

        prev_decoded_tokens = decoded_tokens
        last_tokens.append(generated_tokens)
        if len(last_tokens) > 3:
            last_tokens.pop(0)
        # Stop early once three consecutive iterations return identical token sequences.
        if len(last_tokens) == 3 and all(t == last_tokens[0] for t in last_tokens):
            break

        threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
        current_tokens = noisify_answer(generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight)
        time.sleep(0.01)

    final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
    final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
    final_output = tokenizer.convert_tokens_to_string(final_tokens)
    history[-1][1] = final_output
    return history, final_output

# --- UI Layout ---
css = ".category-legend{display:none}"
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Tini Diffusion LLM 🌀")
    with gr.Row():
        with gr.Column(scale=3):
            chatbox = gr.Chatbot(label="Conversation", value=[], height=400)
            question_input = gr.Textbox(label="Your Question", placeholder="What do you want to ask?", scale=8)
            send_btn = gr.Button("Generate")
        with gr.Column(scale=2):
            eot_weight = gr.Slider(0, 1, value=0.4, step=0.05, label="EOT weight")
            max_iters = gr.Slider(1, 512, value=64, step=1, label="Iterations")
            sharpness = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Sharpness")

    def handle_submit(question, eot, max_it, sharp):
        history, _ = run_diffusion_loop(question, eot, int(max_it), sharp)
        return history

    send_btn.click(
        fn=handle_submit,
        inputs=[question_input, eot_weight, max_iters, sharpness],
        outputs=[chatbox]
    )

demo.queue().launch()