File size: 9,140 Bytes
7252f98
 
 
 
 
 
 
9aaa660
 
183a6ee
42ed840
7252f98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b5abc9b
0e3f268
617b180
9aaa660
 
 
 
b5abc9b
 
 
 
 
 
 
9aaa660
7252f98
 
 
 
 
 
 
 
 
 
 
 
 
 
9d9e261
7252f98
df366f4
7252f98
13b1370
7252f98
13b1370
df366f4
 
13b1370
 
 
 
7252f98
13b1370
7252f98
 
 
 
 
13b1370
7252f98
 
13b1370
7252f98
 
13b1370
2ba8b3f
eca0863
2ba8b3f
cfffc32
 
 
 
 
 
13b1370
6034d83
13b1370
 
 
 
 
 
6034d83
 
 
ae08b25
cfffc32
 
 
 
 
 
 
 
2ba8b3f
 
 
 
cfffc32
 
 
2ba8b3f
 
 
cfffc32
6034d83
ae08b25
92e70ff
7aaa1c3
7252f98
 
 
 
 
 
2ba8b3f
 
 
13b1370
7252f98
3f5293d
df366f4
3f5293d
 
 
 
2ba8b3f
3f5293d
7252f98
 
 
3f5293d
0e1a415
7252f98
3f5293d
 
 
 
 
 
 
7252f98
 
 
 
3f5293d
2ba8b3f
7252f98
 
 
 
 
 
 
3f5293d
 
 
 
 
 
 
 
 
7252f98
 
7c2923c
0e1a415
7252f98
3f5293d
 
 
 
7252f98
 
 
2ba8b3f
eca0863
2ba8b3f
df366f4
2ba8b3f
7252f98
 
 
 
 
3f5293d
7c2923c
3f5293d
 
 
 
 
55b43fa
3f5293d
 
 
 
 
 
2ba8b3f
13b1370
 
df366f4
3f5293d
 
 
 
 
 
 
3f7f1a0
 
f7efac8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
import gradio as gr
import torch
import numpy as np
import json
import time
from transformers import AutoTokenizer
import os
import importlib
from huggingface_hub import hf_hub_download
from llama_diffusion_model import CustomTransformerModel, CustomTransformerConfig, BidirectionalLlamaAttention, disable_dropout
import spaces

hf_token = os.getenv("HF_TOKEN")


# --- Load tokenizer ---
# `token=hf_token` is forwarded for authenticated Hub access (None if HF_TOKEN is unset).
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B", use_fast=True, token=hf_token)
vocab_size = len(tokenizer)
# Fall back to EOS only when the tokenizer defines no pad token at all;
# `pad_token_id or eos_token_id` would also discard a valid pad id of 0.
pad_token = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
eot_token_id = tokenizer.eos_token_id
# Token ids of the marker that separates the prompt from the answer region.
assistant_marker_ids = tokenizer.encode("Assistant:", add_special_tokens=False)

# --- Load token probabilities ---
# Expected format: {"0": p0, "1": p1, ...} with consecutive string keys; the
# resulting array is the sampling distribution for noise tokens.
with open("token_probabilities.json", encoding="utf-8") as f:
    token_probs_dict = json.load(f)
token_probabilities = np.array([token_probs_dict[str(i)] for i in range(len(token_probs_dict))], dtype=np.float32)
# Renormalize: `rng.choice(..., p=token_probabilities)` is called with this
# array directly and raises if the probabilities do not sum to 1.
token_probabilities /= token_probabilities.sum()

def load_model():
    """Download the diffusion checkpoint from the Hub and prepare it for inference.

    The checkpoint holds a pickled model object (the loaded value is used
    directly as a module). It is mapped to CUDA when available, otherwise to
    CPU. Dropout is disabled via ``disable_dropout`` and the model is put
    into eval mode.

    Returns:
        The loaded model on the selected device.
    """
    ckpt_path = hf_hub_download(
        repo_id="ruurd/tini_model",
        filename="diffusion-model.pth",
        token=os.getenv("HF_TOKEN"),
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The checkpoint is a whole pickled module, not a state_dict, so it must be
    # loaded with `weights_only=False`. Since PyTorch 2.6 the default is
    # `weights_only=True`, which refuses to unpickle custom model classes.
    # SECURITY: unpickling can execute arbitrary code; only load checkpoints
    # from a trusted repository.
    model = torch.load(ckpt_path, map_location=device, weights_only=False)
    model = disable_dropout(model)
    model.to(device)
    model.eval()
    return model


# Shared random generator for all noise sampling. No seed is given, so it is
# initialised from fresh OS entropy and runs are not reproducible.
rng = np.random.default_rng(seed=None)

# --- Utility Functions ---
def decode_tokens_safe(token_ids):
    """Decode ``token_ids`` to text without special tokens, flattened onto one line."""
    text = tokenizer.decode(token_ids, skip_special_tokens=True)
    # Newlines become spaces so the result fits on a single line.
    return text.replace("\n", " ")

def find_answer_start(input_ids, marker_ids):
    """Return the index just past the first occurrence of ``marker_ids`` in ``input_ids``.

    Matching uses slice equality, so both arguments should be the same kind
    of sequence (lists in this file). Returns ``None`` if the marker is absent.
    """
    width = len(marker_ids)
    matches = (
        pos + width
        for pos in range(len(input_ids) - width + 1)
        if input_ids[pos:pos + width] == marker_ids
    )
    return next(matches, None)

def get_noising_schedule(i, max_it, sharpness=5.0):
    """Return the fraction of answer tokens to re-noise after iteration ``i``.

    Normalised exponential decay, equal to 1.0 at ``i == 0`` and 0.0 at
    ``i == max_it``::

        f(x) = (exp(-s*x) - exp(-s)) / (1 - exp(-s)),   x = i / max_it

    A larger ``sharpness`` makes the fraction fall off faster in early
    iterations. ``sharpness == 0`` makes the formula 0/0 (NaN). In that case
    the analytic limit ``1 - x`` (a linear schedule) is returned instead.
    """
    x = i / max_it
    if sharpness == 0:
        return 1.0 - x
    return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))

def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0, num_inserts=0):
    """Return a copy of ``input_ids`` with part of the answer region replaced by noise.

    Args:
        input_ids: Token ids (list or 1-D array) of the full sequence.
        answer_start: Index of the first answer token; positions before it are
            never modified.
        threshold: Fraction of the answer region to overwrite with sampled
            noise tokens.
        eot_weight: Multiplier on the EOT token's noise probability before
            renormalising (lower values make EOT noise rarer).
        num_inserts: Number of random tokens inserted strictly after
            ``answer_start``. Each insert shifts later tokens right, and the
            sequence is truncated back to its original length.

    Returns:
        The noised sequence, same length as ``input_ids``. It is a list when a
        list was passed and nothing was inserted; inserts produce a NumPy array.
    """
    noised = input_ids.copy()
    seq_len = len(input_ids)
    answer_len = seq_len - answer_start
    if answer_len <= 0:
        # No answer region (e.g. the prompt filled the whole window).
        return noised

    # Inserts land in [answer_start + 1, seq_len); rng.integers requires
    # low < high, so skip them when no such position exists.
    if answer_start + 1 < seq_len:
        for _ in range(num_inserts):
            insert_idx = rng.integers(answer_start + 1, seq_len)
            insert_token = rng.choice(np.arange(vocab_size), p=token_probabilities)
            noised = np.concatenate([noised[:insert_idx], [insert_token], noised[insert_idx:]])
            noised = noised[:seq_len]

    num_to_noise = min(int(threshold * answer_len), answer_len)
    if num_to_noise > 0:
        indices = rng.choice(np.arange(answer_start, seq_len), size=num_to_noise, replace=False)

        # Noise distribution with the EOT token reweighted.
        mixed_probs = token_probabilities.copy()
        mixed_probs[eot_token_id] *= eot_weight
        mixed_probs /= mixed_probs.sum()

        noise = rng.choice(np.arange(vocab_size), size=len(indices), p=mixed_probs)
        for idx, val in zip(indices, noise):
            noised[idx] = val

    return noised


def confidence_guided_noising(input_ids, answer_start, confidences, threshold, eot_weight, noise_clipping):
    """Re-noise answer tokens, preferring positions the model was least sure about.

    Args:
        input_ids: Token ids of the full sequence (list or 1-D array).
        answer_start: Index of the first answer token; earlier positions are
            never modified.
        confidences: Per-position probability of the token the model sampled,
            aligned with ``input_ids``.
        threshold: Fraction of the answer region to overwrite.
        eot_weight: Multiplier on the EOT token's noise probability.
        noise_clipping: Lower bound applied to the selection weights
            ``1 - confidence``. At 1, every answer position is equally likely
            to be noised. Near 0, selection follows the model's lack of
            confidence almost exactly.

    Returns:
        A copy of ``input_ids`` with ``int(threshold * answer_len)`` answer
        positions (capped at the answer length) replaced by sampled noise.
    """
    noised = input_ids.copy()
    answer_len = len(input_ids) - answer_start
    num_to_noise = min(int(threshold * answer_len), answer_len)

    # `<= 0` also covers an empty or negative answer region, which would
    # otherwise reach rng.choice with a negative size.
    if num_to_noise <= 0:
        return noised

    # Low-confidence positions get higher weight. Work in float64 so the
    # normalised weights sum to 1 within rng.choice's tolerance.
    raw_weights = 1.0 - np.asarray(confidences[answer_start:], dtype=np.float64)
    raw_weights = np.clip(raw_weights, a_min=noise_clipping, a_max=None)

    total = raw_weights.sum()
    if total > 0:
        weights = raw_weights / total
    else:
        # Only reachable with noise_clipping <= 0 and fully confident
        # predictions; use uniform selection instead of dividing by zero.
        weights = np.full(len(raw_weights), 1.0 / len(raw_weights))

    indices = rng.choice(
        np.arange(answer_start, len(input_ids)),
        size=num_to_noise,
        replace=False,
        p=weights,
    )

    # Noise distribution with the EOT token reweighted.
    mixed_probs = token_probabilities.copy()
    mixed_probs[eot_token_id] *= eot_weight
    mixed_probs /= mixed_probs.sum()

    noise = rng.choice(np.arange(vocab_size), size=num_to_noise, p=mixed_probs)
    for idx, val in zip(indices, noise):
        noised[idx] = val

    return noised




@spaces.GPU
def generate_diffusion_text(input_ids, answer_start):
    """Run one denoising pass and sample a token for every position.

    Args:
        input_ids: Token ids of the full (noised) sequence, as a list or 1-D
            array.
        answer_start: Unused here; kept for call-site compatibility.

    Returns:
        ``(sampled, conf)``:
        - ``sampled`` is a list with one sampled token id per input position,
          prompt positions included.
        - ``conf`` is a NumPy array with the (clamped) probability of each
          sampled token.
    """
    with torch.no_grad():
        # np.asarray accepts both lists and arrays (noisify_answer may return
        # either) and avoids torch's slow list-of-ndarray conversion.
        ids = np.asarray(input_ids, dtype=np.int64)
        input_tensor = torch.as_tensor(ids).unsqueeze(0).to(model.device)
        # NOTE(review): assumes logits are (batch=1, seq_len, vocab) — confirm.
        logits = model(input_ids=input_tensor)["logits"]
        # Select the single batch row explicitly; `.squeeze()` would also drop
        # the sequence axis for a length-1 input and break the indexing below.
        probs = torch.nn.functional.softmax(logits, dim=-1)[0]
        # Presumably floors tiny/underflowed probabilities so multinomial
        # never receives an all-zero row.
        probs = torch.clamp(probs, min=1e-8, max=1.0)
        sampled_ids = torch.multinomial(probs, num_samples=1)  # (seq_len, 1)

        # Confidence = probability of the token that was actually sampled.
        conf = torch.gather(probs, -1, sampled_ids).squeeze(-1).cpu().numpy()
        sampled = sampled_ids.squeeze(-1).tolist()
    return sampled, conf

# --- Inference Wrapper ---
def _answer_token_strings(token_ids):
    """Convert answer token ids to token strings, dropping every EOT token."""
    kept_ids = [tok_id for tok_id in token_ids if tok_id != eot_token_id]
    return tokenizer.convert_ids_to_tokens(kept_ids)


def _render_answer_html(tokens, prev_tokens):
    """Render answer tokens as HTML, wrapping tokens that changed since the last iteration in green.

    Token text is HTML-escaped because the result is shown in a ``gr.HTML``
    component, so model output such as ``<b>`` or ``&`` appears literally.
    Comparison is positional. Positions beyond the previous iteration's length
    count as changed. No highlighting is applied when ``prev_tokens`` is empty.
    """
    import html

    pieces = []
    for pos, tok in enumerate(tokens):
        text = html.escape(tokenizer.convert_tokens_to_string([tok]))
        changed = bool(prev_tokens) and (pos >= len(prev_tokens) or tok != prev_tokens[pos])
        pieces.append(f'<span style="color:green">{text}</span>' if changed else text)
    return "".join(pieces).replace('\n', '<br>')


def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_confidence_noising, num_inserts):
    """Generate an answer by iterative denoising, streaming HTML progress.

    Args:
        question: User question; blank input is replaced by a built-in example.
        eot_weight: Scales the EOT token's probability in the noise distribution.
        max_it: Maximum number of denoising iterations (coerced to an int >= 1).
        sharpness: Sharpness of the re-noising schedule.
        noise_clipping: Weight floor for confidence-guided noising.
        use_confidence_noising: Re-noise with ``confidence_guided_noising``
            instead of ``noisify_answer``.
        num_inserts: Random insertions per re-noising step (plain noising only).

    Yields:
        HTML strings:
        - one snapshot per iteration, with tokens changed since the previous
          iteration in green;
        - an early-stop notice when three consecutive iterations produce
          identical tokens;
        - finally, the answer text.
    """
    import html

    max_len = 256  # fixed sequence window fed to the model

    placeholder = "What do you know about the city of New York?"
    if not question or question.strip() == "":
        question = placeholder
    # Gradio sliders may deliver floats; range() and rng.integers need ints.
    max_it = max(1, int(max_it))
    num_inserts = int(num_inserts)

    print('started generation')
    prompt = f"User: {question}\nAssistant:"
    input_ids = tokenizer.encode(prompt, add_special_tokens=False)
    answer_start = find_answer_start(input_ids, assistant_marker_ids)
    if answer_start is None:
        yield "Error: Could not find Assistant marker in input."
        return
    if answer_start >= max_len:
        # The prompt alone fills the window, leaving no room for an answer.
        yield "Error: Question is too long."
        return

    # Pad short prompts / truncate long ones to the fixed window.
    if len(input_ids) < max_len:
        input_ids += [pad_token] * (max_len - len(input_ids))
    else:
        input_ids = input_ids[:max_len]

    # Start from a fully noised answer region.
    current_tokens = noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=eot_weight)
    generated_tokens = current_tokens
    prev_answer_tokens = []
    last_tokens = []
    iterations_done = 0

    for i in range(max_it):
        iterations_done = i + 1
        print('Generating output')
        generated_tokens, confidences = generate_diffusion_text(current_tokens, answer_start)
        current_tokens = generated_tokens

        answer_tokens = _answer_token_strings(generated_tokens[answer_start:])
        yield f"<b>Iteration {i+1}/{max_it} (running):</b><br>" + _render_answer_html(answer_tokens, prev_answer_tokens)
        prev_answer_tokens = answer_tokens

        # Stop once three consecutive iterations produce identical tokens.
        last_tokens.append(generated_tokens)
        if len(last_tokens) > 3:
            last_tokens.pop(0)
        if len(last_tokens) == 3 and last_tokens[0] == last_tokens[1] == last_tokens[2]:
            yield f"<b>Stopped early after {i+1} iterations.</b>"
            break

        # No re-noising after the last pass. Otherwise the final answer would
        # contain freshly injected noise (and random inserts) instead of the
        # model's prediction.
        if iterations_done == max_it:
            break

        threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
        if use_confidence_noising:
            current_tokens = confidence_guided_noising(generated_tokens, answer_start, confidences, threshold, eot_weight, noise_clipping)
        else:
            current_tokens = noisify_answer(generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight, num_inserts=num_inserts)

        time.sleep(0.01)

    final_tokens = _answer_token_strings(generated_tokens[answer_start:])
    final_output = tokenizer.convert_tokens_to_string(final_tokens)
    print(final_output)
    yield f"<b>Final Output (after {iterations_done} iterations):</b><br>" + html.escape(final_output).replace('\n', '<br>')

# --- Gradio Interface ---
# Load the model once at import time; generate_diffusion_text reads the
# module-level `model` global.
print("Loading model...")
model = load_model()
print("✅ Model loaded.")

# Input components, in the positional order of diffusion_chat's parameters:
# question, eot_weight, max_it, sharpness, noise_clipping,
# use_confidence_noising, num_inserts.
ui_inputs = [
    gr.Textbox(label="User Question", lines=2, placeholder="What do you know about the city of New York?"),
    gr.Slider(0, 1, value=0.4, step=0.05, label="↓ = longer answers (EOT weight)"),
    gr.Slider(1, 512, value=64, step=1, label="↑ = more iterations"),
    gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="↓ = more noising (sharpness)"),
    gr.Slider(0.01, 1.0, value=0.05, step=0.01, label="↓ = more confidence guidance (noise clipping)"),
    gr.Checkbox(value=False, label="Use confidence-guided noising"),
    gr.Slider(0, 100, value=0, step=1, label="Number of tokens to insert randomly (↓ = less structural change)"),
]
ui_outputs = [gr.HTML(label="Diffusion Output")]

demo = gr.Interface(
    fn=diffusion_chat,
    inputs=ui_inputs,
    outputs=ui_outputs,
    title="Diffusion Language Model Chat",
    theme="default",
    description="This interface runs a diffusion-based language model to generate answers progressively.",
)

# NOTE(review): allowed_paths=["."] lets Gradio serve any file under the
# working directory — confirm this exposure is intended.
demo.launch(share=True, allowed_paths=["."], ssr_mode=False)