import os

import gradio as gr
import requests  # only used by the commented-out endpoint check in __main__
from openai import OpenAI
from huggingface_hub import InferenceClient  # alternative backend; see the sketch below

"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# Each UI choice maps to [OpenAI-compatible client, model id]. The base URLs come
# from env vars (e.g. a vLLM-style server); the api_key is a dummy placeholder.
clients = {
    '3B': [OpenAI(api_key='123', base_url=os.getenv('MODEL_NAME_OR_PATH_3B')),
           'RefalMachine/ruadapt_qwen2.5_3B_ext_u48_instruct'],
    '7B (work in progress)': [OpenAI(api_key='123', base_url=os.getenv('MODEL_NAME_OR_PATH_7B')),
                              'RefalMachine/ruadapt_qwen2.5_7B_ext_u48_instruct'],
}
#client = InferenceClient(os.getenv('MODEL_NAME_OR_PATH'))
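# A minimal sketch of the InferenceClient alternative commented out above, assuming
# MODEL_NAME_OR_PATH holds a model id or endpoint served by the HF Inference API
# (hypothetical; this app uses the OpenAI-compatible clients instead):
#
#   client = InferenceClient(os.getenv('MODEL_NAME_OR_PATH'))
#   for chunk in client.chat_completion(
#       messages=[{"role": "user", "content": "Привет!"}],
#       max_tokens=128,
#       stream=True,
#   ):
#       print(chunk.choices[0].delta.content or "", end="")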


def respond(
    message,
    history: list[tuple[str, str]],
    model_name,
    system_message,
    max_tokens,
    temperature,
    top_p,
    repetition_penalty,
):
    # gr.ChatInterface calls fn(message, history, *additional_inputs), so the extra
    # parameters must follow `history` in the same order as additional_inputs below.
    # Optional system prompt, then prior turns, then the new user message.
    messages = []
    if len(system_message.strip()) > 0:
        messages = [{"role": "system", "content": system_message}]

    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": message})

    response = ""

    res = clients[model_name][0].chat.completions.create(
        model=clients[model_name][1],
        messages=messages,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        stream=True,
        extra_body={
            "repetition_penalty": repetition_penalty,
            "add_generation_prompt": True,
        }
    )

    for chunk in res:
        # Some stream chunks (e.g. the final one) may carry no content; skip None deltas.
        token = chunk.choices[0].delta.content
        if token:
            response += token
            yield response
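# Quick sanity check outside the UI (hypothetical; requires a live endpoint).
# `respond` is a generator that yields the growing response after each token:
#
#   *_, final = respond("Привет!", [], "3B", "", 128, 0.3, 0.95, 1.0)
#   print(final)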


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
options = ["3B", "7B (work in progress)"]
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Radio(choices=options, label="Model:", value=options[0]),
        gr.Textbox(value="", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.3, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(minimum=0.9, maximum=1.2, value=1.0, step=0.05, label="Repetition penalty"),
    ],
)


if __name__ == "__main__":
    #print(requests.get(os.getenv('MODEL_NAME_OR_PATH')[:-3] + '/docs'))
    demo.launch(share=True)
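
# The env vars above must point at OpenAI-compatible endpoints before launch.
# Hypothetical local setup with vLLM (any OpenAI-compatible server would do):
#
#   python -m vllm.entrypoints.openai.api_server \
#       --model RefalMachine/ruadapt_qwen2.5_3B_ext_u48_instruct --port 8000
#   export MODEL_NAME_OR_PATH_3B=http://localhost:8000/v1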