Mihai committed
Commit 9591850 · 1 Parent(s): ed0697d

Initial commit

Files changed (2)
  1. app.py +109 -0
  2. requirements.txt +1 -0
app.py ADDED
@@ -0,0 +1,109 @@
+ import gradio as gr
+ import spaces  # provides the @spaces.GPU decorator used below
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from backtrack_sampler import BacktrackSampler, CreativeWritingStrategy
+ from backtrack_sampler.provider.transformers_provider import TransformersProvider
+ import torch
+ import asyncio
+
+ description = """## Compare Creative Writing: Custom Sampler vs. Backtrack Sampler with Creative Writing Strategy
+ This is a demo of [Backtrack Sampler](https://github.com/Mihaiii/backtrack_sampler) using one of its algorithms, named "Creative Writing Strategy".
+ <br />On the left you have the output of standard sampling and on the right the output provided by Backtrack Sampler.
+ """
+ # Load tokenizer
+ model_name = "unsloth/Llama-3.2-1B-Instruct"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ # Load two instances of the model on CUDA for parallel inference
+ model1 = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")
+
+ model2 = AutoModelForCausalLM.from_pretrained(model_name)
+ device = torch.device("cuda")
+
+ strategy = CreativeWritingStrategy(top_p_flat=0.8, top_k_threshold_flat=2, min_prob_second_highest=0.2)
+ provider = TransformersProvider(model2, tokenizer, device)
+ creative_sampler = BacktrackSampler(strategy, provider)
+
+ # Helper function to build the message list for the chat template
+ def create_chat_template_messages(history, prompt):
+     messages = []
+
+     for input_text, response_text in history:
+         messages.append({"role": "user", "content": input_text})
+         messages.append({"role": "assistant", "content": response_text})
+
+     messages.append({"role": "user", "content": prompt})  # the new prompt goes last
+     return messages
+
+ # Async function for generating responses using two models
+ @spaces.GPU(duration=60)
+ async def generate_responses(prompt, history):
+     # Create messages array for chat history and apply template
+     messages = create_chat_template_messages(history, prompt)
+     wrapped_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_special_tokens=True, add_generation_prompt=True)
+
+     # already has special tokens
+     inputs = tokenizer.encode(wrapped_prompt, add_special_tokens=False, return_tensors="pt").to("cuda")
+     # Standard sampler task
+     standard_task = asyncio.to_thread(
+         model1.generate, inputs, max_length=2048, temperature=1
+     )
+
+     # Custom sampler task: loop over generator and collect outputs in a list
+     async def custom_sampler_task():
+         generated_list = []
+         generator = creative_sampler.generate(wrapped_prompt, max_length=2048, temperature=1)
+         for token in generator:
+             generated_list.append(token)
+         return tokenizer.decode(generated_list, skip_special_tokens=True)
+
+     # Wait for both responses
+     standard_output, custom_output = await asyncio.gather(standard_task, custom_sampler_task())
+     # Decode standard output and remove the prompt from the generated response
+     standard_response = tokenizer.decode(standard_output[0][len(inputs[0]):], skip_special_tokens=True)
+
+     return standard_response.strip(), custom_output.strip()
+
+ # Create the Gradio interface with the Citrus theme
+ with gr.Blocks(theme=gr.themes.Citrus()) as demo:
+     gr.Markdown(description)
+
+     # Chatbot components
+     with gr.Row():
+         standard_chat = gr.Chatbot(label="Standard Sampler")
+         custom_chat = gr.Chatbot(label="Creative Writing Strategy")
+
+     # Input components
+     with gr.Row():
+         prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Type your message here...", lines=1)
+
+     # Example prompts
+     examples = [
+         "Write me a short story about a talking dog who wants to be a detective.",
+         "Tell me a short tale of a dragon who is afraid of heights.",
+         "Create a short story where aliens land on Earth, but they just want to throw a party."
+     ]
+
+     # Add example buttons
+     gr.Examples(examples=examples, inputs=prompt_input)
+
+     # Button to submit the prompt
+     submit_button = gr.Button("Submit")
+
+     # Function to handle chat updates
+     async def update_chat(prompt, standard_history, custom_history):
+         standard_response, custom_response = await generate_responses(prompt, standard_history)
+
+         # Append new responses to chat histories
+         standard_history = standard_history + [(prompt, standard_response)]
+         custom_history = custom_history + [(prompt, custom_response)]
+
+         # Clear the input field after submission
+         return standard_history, custom_history, ""
+
+     # Bind the submit button to the update function and allow pressing Enter to submit
+     prompt_input.submit(fn=update_chat, inputs=[prompt_input, standard_chat, custom_chat], outputs=[standard_chat, custom_chat, prompt_input])
+     submit_button.click(fn=update_chat, inputs=[prompt_input, standard_chat, custom_chat], outputs=[standard_chat, custom_chat, prompt_input])
+
+ # Launch the app with queueing and sharing enabled
+ demo.queue().launch(share=True, debug=True)
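The two outputs above are produced side by side by pushing the blocking `model1.generate` call to a worker thread and awaiting it together with the custom-sampler coroutine. Below is a minimal, self-contained sketch of that `asyncio.to_thread` + `asyncio.gather` pattern; the two stand-in functions are hypothetical placeholders, not part of this commit.

```python
import asyncio

def blocking_generate() -> str:
    # Hypothetical stand-in for the blocking model1.generate(...) call
    return "standard sampler output"

async def custom_sampler_task() -> str:
    # Hypothetical stand-in for iterating the creative_sampler generator
    await asyncio.sleep(0)
    return "creative writing strategy output"

async def main() -> None:
    # Run the blocking call in a worker thread and the coroutine concurrently
    standard, custom = await asyncio.gather(
        asyncio.to_thread(blocking_generate),
        custom_sampler_task(),
    )
    print(standard, custom)

asyncio.run(main())
```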
requirements.txt ADDED
@@ -0,0 +1 @@
+ backtrack_sampler
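Note that app.py also imports gradio, spaces, transformers, and torch. On a Hugging Face Space these typically come with the selected SDK image, which is presumably why only backtrack_sampler is pinned here; running the demo outside of Spaces would likely require installing those packages as well.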