File size: 2,929 Bytes
7997f6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import gradio as gr
from openai import OpenAI
import uuid
import json
import os
import tempfile

# OpenAI-compatible endpoint of the local inference server (ik_llama).
BASE_URL = "http://localhost:8080/v1"
# Model identifier to request from that server.
MODEL_NAME = "bn"

# The local server does not validate API keys, so any placeholder value works.
cli = OpenAI(api_key="sk-nokey", base_url=BASE_URL)

def openai_call(message, history, system_prompt, max_new_tokens):
    """Stream a chat completion from the local OpenAI-compatible server.

    Args:
        message: The user's newest message text.
        history: Prior conversation as a list of ``{"role", "content"}``
            dicts (Gradio ``type="messages"`` format).
        system_prompt: System instruction prepended to the conversation.
        max_new_tokens: Upper bound on generated tokens.

    Yields:
        ``(partial_reply, None)`` while tokens stream in, then a final
        ``(full_reply, gr.State(messages))`` so the UI can capture the
        complete conversation for export.
    """
    # Build the request payload WITHOUT mutating the list Gradio passed in;
    # the original insert/append leaked the system prompt into the caller's
    # history object.
    messages = [
        {"role": "system", "content": system_prompt},
        *history,
        {"role": "user", "content": message},
    ]
    response = cli.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=max_new_tokens,
        # NOTE(review): server/template-specific stop tokens — confirm they
        # match the chat template actually served.
        stop=["<|im_end|>", "</s>"],
        stream=True
    )
    reply = ""
    for chunk in response:
        delta = chunk.choices[0].delta.content
        if delta is not None:
            reply += delta
            # Intermediate yield: update the chat bubble, leave state as-is.
            yield reply, None
    messages.append({"role": "assistant", "content": reply})
    # Final yield: complete reply plus the full conversation for export.
    yield reply, gr.State(messages)

def gen_file(conv_state):
    """Serialize the conversation in *conv_state* to a JSON file on disk.

    Args:
        conv_state: State object whose ``.value`` is the message list
            (list of ``{"role", "content"}`` dicts).

    Returns:
        tuple: ``(gr.File, gr.State)`` — the file component to offer for
        download, and the file's path so it can be deleted afterwards.
    """
    # Write into the system temp directory instead of the process CWD, with
    # a random name so concurrent sessions cannot collide.
    fname = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4()}.json")
    with open(fname, mode="w", encoding="utf-8") as f:
        json.dump(conv_state.value, f, indent=4, ensure_ascii=False)
    return gr.File(fname), gr.State(fname)

def rm_file_wrap(path : str):
    """Best-effort file removal: delete *path*, printing (not raising) on failure."""
    try:
        os.remove(path)
    except OSError as err:
        # Deletion is non-critical cleanup, so just report and carry on.
        print(f"Error: {err.filename} - {err.strerror}.")

def on_download(download_data: gr.DownloadData):
    """After the user downloads the export, delete the served file."""
    served_path = download_data.file.path
    print(f"deleting {served_path}")
    rm_file_wrap(served_path)

def clean_file(orig_path):
    """Delete the on-disk export whose path is stored in the *orig_path* state."""
    target = orig_path.value
    print(f"Deleting {target}")
    rm_file_wrap(target)

# UI wiring: a streaming ChatInterface plus an export-conversation-to-JSON flow.
with gr.Blocks() as demo:
    #download=gr.DownloadButton(label="Download Conversation", value=None)
    # Full conversation (system prompt included) as yielded by openai_call.
    conv_state = gr.State()
    # Path of the most recently exported JSON file, kept for cleanup.
    orig_path = gr.State()
    chat = gr.ChatInterface(
        openai_call,
        type="messages",
        additional_inputs=[
            gr.Textbox("You are a helpful AI assistant.", label="System Prompt"),
            gr.Slider(30, 2048, label="Max new tokens"),
        ],
        # openai_call's final yield populates conv_state for the export button.
        additional_outputs=[conv_state],
        title="Chat with bitnet using ik_llama",
        description="Warning: Do not input sensitive info - assume everything is public! Also note this is experimental and ik_llama server doesn't seems to support arbitrary chat template, we're using vicuna as approximate match - so there might be intelligence degradation."
    )
    download_file = gr.File()
    # Export flow: gen_file writes the JSON and hands it to download_file,
    # then clean_file deletes the original on success.
    # NOTE(review): this assumes Gradio copies the file into its cache when
    # gr.File is set (so deleting the original is safe) — confirm; the cached
    # copy is then removed by on_download after the user fetches it.
    download_btn = gr.Button("Export Conversation for Download") \
        .click(fn=gen_file, inputs=[conv_state], outputs=[download_file, orig_path]) \
        .success(fn=clean_file, inputs=[orig_path])
    download_file.download(on_download, None, None)

# Queue requests (max 10 pending) and expose the HTTP API alongside the UI.
demo.queue(max_size=10, api_open=True).launch()