WillHeld committed on
Commit ab30e5f · verified · 1 Parent(s): a60fda2

Update app.py

Files changed (1):
    app.py +16 -3
app.py CHANGED
@@ -8,15 +8,28 @@ tokenizer = AutoTokenizer.from_pretrained(checkpoint)
 model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

 @spaces.GPU(duration=120)
-def predict(message, history):
+def predict(message, history, temperature, top_p):
     history.append({"role": "user", "content": message})
     input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
     inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
-    outputs = model.generate(inputs, max_new_tokens=1024, temperature=0.7, top_p=0.9, do_sample=True)
+    outputs = model.generate(
+        inputs,
+        max_new_tokens=1024,
+        temperature=temperature,
+        top_p=top_p,
+        do_sample=True
+    )
     decoded = tokenizer.decode(outputs[0])
     response = decoded.split("<|assistant|>")[-1]
     return response

-demo = gr.ChatInterface(predict, type="messages")
+with gr.Blocks() as demo:
+    chatbot = gr.ChatInterface(
+        predict,
+        additional_inputs=[
+            gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
+            gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
+        ]
+    )

 demo.launch()
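For context on the wiring above: Gradio's ChatInterface calls its fn with each additional_inputs component's current value appended as a positional argument after (message, history), which is how the two sliders reach the new temperature and top_p parameters of predict. With do_sample=True, temperature rescales the next-token distribution and top_p restricts sampling to the smallest set of tokens whose cumulative probability exceeds top_p. A minimal, self-contained sketch of the same pattern (echo is a hypothetical stand-in for predict, used only so the sketch runs without loading a model):

import gradio as gr

# Hypothetical stand-in for predict(): echoes the message together with the
# sampling settings it received, to show how additional_inputs flow into fn.
def echo(message, history, temperature, top_p):
    return f"{message} (temperature={temperature}, top_p={top_p})"

demo = gr.ChatInterface(
    echo,
    type="messages",
    additional_inputs=[
        gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P"),
    ],
)

demo.launch()

By default the sliders render in an "Additional Inputs" accordion below the chat box; the gr.Blocks wrapper in the commit is not strictly required for this, but it leaves room to place other components around the chat UI later.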