Mihaiii committed
Commit d1e1697 · verified · 1 Parent(s): 3e181c7

Update app.py

Files changed (1)
  1. app.py +2 -18
app.py CHANGED
@@ -10,18 +10,17 @@ description = """## Compare Creative Writing: Standard Sampler vs. Backtrack Sam
 This is a demo of the [Backtrack Sampler](https://github.com/Mihaiii/backtrack_sampler) framework using "Creative Writing Strategy".
 <br />On the left you have the output of the standard sampling and on the right the output provided by Backtrack Sampler.
 """
-# Load tokenizer
+
 model_name = "unsloth/Llama-3.2-1B-Instruct"
+device = torch.device('cuda')
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
-# Load two instances of the model on CUDA for parallel inference
 model = AutoModelForCausalLM.from_pretrained(model_name).to("cuda")
 
 provider = TransformersProvider(model, tokenizer, device)
 strategy = CreativeWritingStrategy(provider)
 creative_sampler = BacktrackSampler(strategy, provider)
 
-# Helper function to create message array for the chat template
 def create_chat_template_messages(history, prompt):
     messages = [{"role": "user", "content": prompt}]
 
@@ -33,14 +32,11 @@ def create_chat_template_messages(history, prompt):
 
 @spaces.GPU
 def generate_responses(prompt, history):
-    # Create messages array for chat history and apply template
     messages = create_chat_template_messages(history, prompt)
     wrapped_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_special_tokens=True, add_generation_prompt=True)
 
-    #already has special tokens
     inputs = tokenizer.encode(wrapped_prompt, add_special_tokens=False, return_tensors="pt").to("cuda")
 
-    # Custom sampler task: loop over generator and collect outputs in a list
     async def custom_sampler_task():
         generated_list = []
         generator = creative_sampler.generate(wrapped_prompt, max_length=2048, temperature=1)
@@ -50,51 +46,39 @@ def generate_responses(prompt, history):
 
     custom_output = asyncio.run(custom_sampler_task())
     standard_output = model.generate(inputs, max_length=2048, temperature=1)
-    # Decode standard output and remove the prompt from the generated response
     standard_response = tokenizer.decode(standard_output[0][len(inputs[0]):], skip_special_tokens=True)
 
     return standard_response.strip(), custom_output.strip()
 
-# Create the Gradio interface with the Citrus theme
 with gr.Blocks(theme=gr.themes.Citrus()) as demo:
     gr.Markdown(description)
 
-    # Chatbot components
     with gr.Row():
         standard_chat = gr.Chatbot(label="Standard Sampler")
         custom_chat = gr.Chatbot(label="Creative Writing Strategy")
 
-    # Input components
     with gr.Row():
         prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Type your message here...", lines=1)
 
-    # Example prompts
     examples = [
         "Write me a short story about a talking dog who wants to be a detective.",
         "Tell me a short tale of a dragon who is afraid of heights.",
         "Create a short story where aliens land on Earth, but they just want to throw a party."
     ]
 
-    # Add example buttons
     gr.Examples(examples=examples, inputs=prompt_input)
 
-    # Button to submit the prompt
     submit_button = gr.Button("Submit")
 
-    # Function to handle chat updates
     def update_chat(prompt, standard_history, custom_history):
         standard_response, custom_response = generate_responses(prompt, standard_history)
 
-        # Append new responses to chat histories
         standard_history = standard_history + [(prompt, standard_response)]
         custom_history = custom_history + [(prompt, custom_response)]
 
-        # Clear the input field after submission
         return standard_history, custom_history, ""
 
-    # Bind the submit button to the update function and allow pressing Enter to submit
     prompt_input.submit(fn=update_chat, inputs=[prompt_input, standard_chat, custom_chat], outputs=[standard_chat, custom_chat, prompt_input])
     submit_button.click(fn=update_chat, inputs=[prompt_input, standard_chat, custom_chat], outputs=[standard_chat, custom_chat, prompt_input])
 
-# Launch the app with queueing and sharing enabled
 demo.queue().launch(debug=True)
 
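Aside from stripping inline comments, the one substantive change here is defining `device` before it is passed to `TransformersProvider(model, tokenizer, device)`. Assuming `device` was not assigned earlier in the file (the diff does not show the import block), the old code would have raised a `NameError` at startup. Below is a minimal sketch of the corrected setup; the `backtrack_sampler` import paths are assumptions inferred from the class names, not taken from app.py:

```python
# Minimal sketch of the corrected setup, not the full app.py.
# ASSUMPTION: the backtrack_sampler import paths below are inferred from
# the class names in the diff; app.py's real import block is not shown.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from backtrack_sampler import BacktrackSampler, CreativeWritingStrategy
from backtrack_sampler.provider.transformers_provider import TransformersProvider

model_name = "unsloth/Llama-3.2-1B-Instruct"
device = torch.device("cuda")  # the line this commit adds

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)

# Before this commit, `device` was used here without ever being defined.
provider = TransformersProvider(model, tokenizer, device)
strategy = CreativeWritingStrategy(provider)
creative_sampler = BacktrackSampler(strategy, provider)
```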
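The removed `#already has special tokens` comment captured a detail worth preserving: `apply_chat_template(..., tokenize=False)` returns text that already contains the template's special tokens, which is why `generate_responses` re-encodes it with `add_special_tokens=False`. A small self-contained check (the message content is illustrative):

```python
# Demonstrates why app.py encodes the templated prompt with
# add_special_tokens=False: the chat template already emits the special
# tokens (e.g. <|begin_of_text|> for Llama 3 models), so encoding with
# add_special_tokens=True would typically duplicate the BOS token.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("unsloth/Llama-3.2-1B-Instruct")
messages = [{"role": "user", "content": "Hello!"}]

wrapped = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
print(wrapped)  # special tokens are already present in the text

once = tokenizer.encode(wrapped, add_special_tokens=False)
twice = tokenizer.encode(wrapped, add_special_tokens=True)
print(once[:2], twice[:2])  # `twice` starts with a duplicated BOS id
```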
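Finally, `custom_sampler_task` drains the sampler's async generator into a list and joins the result. The hunk cuts off inside that function, so here is the same control-flow pattern in isolation; `fake_token_stream` is a hypothetical stand-in, and the real generator may yield token ids that need decoding rather than ready-made strings:

```python
import asyncio

async def fake_token_stream():
    # Hypothetical stand-in for creative_sampler.generate(...).
    for chunk in ["Once", " upon", " a", " time", "."]:
        yield chunk

async def custom_sampler_task():
    # Same shape as the helper in app.py: iterate the async generator,
    # accumulate the chunks, then join them into a single string.
    generated_list = []
    async for chunk in fake_token_stream():
        generated_list.append(chunk)
    return "".join(generated_list)

print(asyncio.run(custom_sampler_task()))  # -> Once upon a time.
```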