lemonteaa committed on
Commit 7997f6a · verified · 1 Parent(s): 48a1839

Create chat_demo.py

Files changed (1)
  1. chat_demo.py +84 -0
chat_demo.py ADDED
import gradio as gr
from openai import OpenAI
import uuid
import json
import os
import tempfile  # only needed for the commented-out NamedTemporaryFile variant below

BASE_URL = "http://localhost:8080/v1"
MODEL_NAME = "bn"

cli = OpenAI(api_key="sk-nokey", base_url=BASE_URL)

# Stream a chat completion from the local OpenAI-compatible server,
# yielding the partial reply as it grows so the UI updates live.
def openai_call(message, history, system_prompt, max_new_tokens):
    #print(history) # DEBUG
    history.insert(0, {
        "role": "system",
        "content": system_prompt
    })
    history.append({
        "role": "user",
        "content": message
    })
    response = cli.chat.completions.create(
        model=MODEL_NAME,
        messages=history,
        max_tokens=max_new_tokens,
        stop=["<|im_end|>", "</s>"],
        stream=True
    )
    reply = ""
    for chunk in response:
        delta = chunk.choices[0].delta.content
        if delta is not None:
            reply = reply + delta
        yield reply, None
    # Record the finished reply, then hand the full conversation back
    # through the additional output so it can be exported later.
    history.append({ "role": "assistant", "content": reply })
    yield reply, gr.State(history)

# Serialize the conversation held in conv_state to a uniquely named
# JSON file and expose it through the gr.File component.
def gen_file(conv_state):
    #print(conv_state) # DEBUG
    fname = f"{str(uuid.uuid4())}.json"
    #with tempfile.NamedTemporaryFile(prefix=str(uuid.uuid4()), suffix=".json", mode="w", encoding="utf-8", delete_on_close=False) as f:
    with open(fname, mode="w", encoding="utf-8") as f:
        json.dump(conv_state.value, f, indent=4, ensure_ascii=False)
    return gr.File(fname), gr.State(fname)

def rm_file_wrap(path: str):
    # Try to delete the file.
    try:
        os.remove(path)
    except OSError as e:
        # If it fails, inform the user.
        print("Error: %s - %s." % (e.filename, e.strerror))

# Delete Gradio's cached copy of the export once the browser has fetched it.
def on_download(download_data: gr.DownloadData):
    print(f"deleting {download_data.file.path}")
    rm_file_wrap(download_data.file.path)

# Delete the original export on disk (Gradio serves its own cached copy).
def clean_file(orig_path):
    print(f"Deleting {orig_path.value}")
    rm_file_wrap(orig_path.value)

with gr.Blocks() as demo:
    #download=gr.DownloadButton(label="Download Conversation", value=None)
    conv_state = gr.State()
    orig_path = gr.State()
    chat = gr.ChatInterface(
        openai_call,
        type="messages",
        additional_inputs=[
            gr.Textbox("You are a helpful AI assistant.", label="System Prompt"),
            gr.Slider(30, 2048, label="Max new tokens"),
        ],
        additional_outputs=[conv_state],
        title="Chat with bitnet using ik_llama",
        description="Warning: Do not input sensitive info - assume everything is public! Also note this is experimental and the ik_llama server doesn't seem to support arbitrary chat templates; we're using vicuna as an approximate match, so there may be some intelligence degradation."
    )
    download_file = gr.File()
    # Export button: generate the JSON file, then delete the on-disk
    # original once the click chain succeeds.
    download_btn = gr.Button("Export Conversation for Download") \
        .click(fn=gen_file, inputs=[conv_state], outputs=[download_file, orig_path]) \
        .success(fn=clean_file, inputs=[orig_path])
    download_file.download(on_download, None, None)

demo.queue(max_size=10, api_open=True).launch()
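
The script assumes an OpenAI-compatible server (e.g. one started by ik_llama) is already listening at BASE_URL before the Gradio app launches. A minimal smoke test for that assumption is sketched below; the endpoint and model name mirror the constants above and are deployment-specific, not fixed by this script.

# smoke_test.py - minimal sketch; run before chat_demo.py to verify the
# local server responds. BASE_URL and MODEL_NAME values are assumptions
# taken from chat_demo.py, not part of the OpenAI client API itself.
from openai import OpenAI

cli = OpenAI(api_key="sk-nokey", base_url="http://localhost:8080/v1")

# Non-streaming request: if this prints text, the streaming demo should work too.
resp = cli.chat.completions.create(
    model="bn",
    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
    max_tokens=32,
)
print(resp.choices[0].message.content)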