import gradio as gr import torch import concurrent.futures from transformers import AutoTokenizer, AutoModelForCausalLM # Load the model and tokenizer (using GPT-2 as an example) model_name = "gpt2" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained(model_name) model.eval() torch.set_num_threads(2) def min_p_sampling(logits, pbase=0.1): """ Perform min-p sampling on the logits. As described in https://arxiv.org/abs/2407.01082 Args: logits (torch.Tensor): 1D tensor of logits for the next token. pbase (float): Base probability to scale pmax. Returns: int: The sampled token index. """ # Convert logits to probabilities. probs = torch.softmax(logits, dim=-1) # 1. Find maximum probability. pmax = probs.max() # 2. Compute the dynamic threshold. pscaled = pbase * pmax # 3. Create a mask of tokens with probability >= pscaled. mask = probs >= pscaled # In the unlikely event that no token meets the threshold, use the full distribution. if mask.sum() == 0: mask = torch.ones_like(probs, dtype=torch.bool) probs_filtered = probs * mask.float() # 4. Normalize and sample. probs_normalized = probs_filtered / probs_filtered.sum() sampled_index = torch.multinomial(probs_normalized, num_samples=1) return sampled_index.item() def generate_laconic_completion(prompt: str, n: int = 5, max_length: int = 100): # generate n completions greedily and return the shortest one with torch.no_grad(): # Encode the prompt and get the attention mask. encoded = tokenizer(prompt, return_tensors="pt") input_ids = encoded["input_ids"] attention_mask = encoded["attention_mask"] # Generate the output. outputs = model.generate( input_ids, attention_mask=attention_mask, max_length=max_length, num_return_sequences=n, do_sample=True, ) completions = [ tokenizer.decode(output, skip_special_tokens=True) for output in outputs ] return min(completions, key=len) def generate_with_confidence(input_ids, max_length): """ Generate a sequence using greedy decoding while returning the scores. """ outputs = model.generate( input_ids, max_length=max_length, do_sample=False, output_scores=True, return_dict_in_generate=True, ) return outputs def compute_answer_confidence(outputs): """ Compute the answer confidence over the generated tokens. For each generated token, compute the difference between the top-1 and top-2 logits. Returns the average difference. """ diffs = [] for score in outputs.scores: # Get top-2 logit values top2 = torch.topk(score[0], 2) diff = top2.values[0] - top2.values[1] diffs.append(diff.item()) return sum(diffs) / len(diffs) if diffs else 0.0 def cot_decoding(prompt, k=5, max_length=100): """ Perform Chain-of-Thought (CoT) decoding by exploring top-k alternative paths. """ input_ids = tokenizer.encode(prompt, return_tensors="pt") # Get logits for the next token with torch.no_grad(): outputs = model(input_ids) logits = outputs.logits[0, -1, :] # Get top-k candidate tokens topk = torch.topk(logits, k) candidate_tokens = topk.indices paths = [] for token in candidate_tokens: # Append the candidate token to the prompt new_input_ids = torch.cat([input_ids, token.view(1, 1)], dim=1) # Generate a full sequence with output scores gen_outputs = generate_with_confidence( new_input_ids, max_length=new_input_ids.shape[1] + max_length ) # Decode the generated sequence generated_text = tokenizer.decode( gen_outputs.sequences[0], skip_special_tokens=True ) # Compute answer confidence confidence = compute_answer_confidence(gen_outputs) paths.append({"text": generated_text, "confidence": confidence}) return max(paths, key=lambda x: x["confidence"])["text"] def generate_completion(prompt, strategy, params): """ Generate a complete answer using model.generate with specified parameters. """ with torch.no_grad(): # Encode the prompt and get the attention mask. encoded = tokenizer(prompt, return_tensors="pt") input_ids = encoded["input_ids"] attention_mask = encoded["attention_mask"] # Generate the output. output_ids = model.generate( input_ids, attention_mask=attention_mask, max_length=100, **params ) return tokenizer.decode(output_ids[0], skip_special_tokens=True) def generate_min_p_completion(prompt, pbase=0.1, max_length=100): input_ids = tokenizer.encode(prompt, return_tensors="pt") past = None with torch.no_grad(): for _ in range(max_length - input_ids.size(1)): # Only pass the last token if past is available outputs = ( model(input_ids[:, -1:], past_key_values=past) if past is not None else model(input_ids) ) past = outputs.past_key_values logits = outputs.logits[:, -1, :] next_token = min_p_sampling(logits, pbase=pbase) input_ids = torch.cat([input_ids, torch.tensor([[next_token]])], dim=-1) if next_token == tokenizer.eos_token_id: break return tokenizer.decode(input_ids[0], skip_special_tokens=True) def generate_all(prompt): """ Run multiple decoding strategies concurrently and yield updates as each completes. """ # Define each decoding strategy and its parameters. methods = { "Greedy": {"type": "default", "params": {"do_sample": False}}, "Top-k Sampling": { "type": "default", "params": {"do_sample": True, "top_k": 100}, }, "Top-p Sampling": { "type": "default", "params": {"do_sample": True, "top_p": 0.95}, }, "Beam Search": { "type": "default", "params": {"num_beams": 5, "early_stopping": True}, }, "Eta Sampling": { "type": "default", "params": {"do_sample": True, "eta_cutoff": 0.3}, }, "Epsilon Sampling": { "type": "default", "params": {"do_sample": True, "epsilon_cutoff": 0.2}, }, "Min-p Sampling": {"type": "min_p", "pbase": 0.1}, "laconic": { "type": "default", "params": {"do_sample": True, "num_return_sequences": 5}, }, "COT Decoding": { "type": "cot_decoding", "params": {"k": 5, "max_length": 100}, }, } # Define the order for display. method_order = [ "Greedy", "Top-k Sampling", "Top-p Sampling", "Beam Search", "Min-p Sampling", "Eta Sampling", "Epsilon Sampling", "laconic", "COT Decoding", ] results = {method: None for method in methods} # Yield an initial placeholder state. yield tuple("Processing..." for _ in method_order) # Use a thread pool to run each generation concurrently. with concurrent.futures.ThreadPoolExecutor() as executor: future_to_method = {} for method, info in methods.items(): if info["type"] == "default": future = executor.submit( generate_completion, prompt, method, info["params"] ) elif info["type"] == "min_p": future = executor.submit( generate_min_p_completion, prompt, info["pbase"] ) elif method == "laconic": future = executor.submit(generate_laconic_completion, prompt) elif method == "COT Decoding": future = executor.submit(cot_decoding, prompt, **info["params"]) future_to_method[future] = method # As each future completes, update its result and yield the current state. for future in concurrent.futures.as_completed(future_to_method): method = future_to_method[future] try: result = future.result() except Exception as exc: result = f"Error: {exc}" results[method] = result # Yield the results in the pre-defined order; pending methods show "Processing..." yield tuple( results[m] if results[m] is not None else "Processing..." for m in method_order ) # Create the Gradio interface. interface = gr.Interface( fn=generate_all, inputs=gr.Textbox(lines=3, placeholder="Enter your prompt here...", label="Prompt"), outputs=[ gr.Textbox(label="Greedy"), gr.Textbox(label="Top-k Sampling"), gr.Textbox(label="Top-p Sampling"), gr.Textbox(label="Beam Search"), gr.Textbox(label="Min-p Sampling (as in https://arxiv.org/abs/2407.01082)"), gr.Textbox(label="Eta Sampling"), gr.Textbox(label="Epsilon Sampling"), gr.Textbox( label="laconic decoding (by Alex Dimakis, 2025, search for twitter thread)" ), gr.Textbox( label="COT Decoding (Chain-of-Thought Reasoning without Prompting, Wang, Zhou, 2024)" ), ], title="Decoding Methods Comparison", description="Each decoding method's final answer is printed as soon as it is done. Model used: GPT-2.", ) if __name__ == "__main__": interface.launch()