import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import spaces

# Load model and tokenizer
model_name = "yuchenlin/Rex-v0.1-0.5B"

device = "cuda" # the device to load the model onto
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, rex_size=3)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto"
)
model.to(device)

@spaces.GPU(enable_queue=True)
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens=1024,
    temperature=0.5,
    top_p=0.95,
    repetition_penalty=1.1,
):
    messages = [{"role": "system", "content": system_message}]

    # Rebuild the conversation from the (user, assistant) history tuples
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})

    messages.append({"role": "user", "content": message})
 
    # Render the conversation into the model's prompt format (the chat
    # template comes from the tokenizer config)
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    generated_ids = model.generate(
        model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        max_new_tokens=max_tokens,
        do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    # Strip the prompt tokens so only the newly generated text is decoded
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response
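
# A streaming variant of `respond` (a minimal sketch, not wired into the demo
# below). It assumes transformers' TextIteratorStreamer, which yields decoded
# text as generate() produces it from a background thread. gr.ChatInterface
# treats generator functions as streaming responses, so passing this function
# instead of `respond` (with the same @spaces.GPU decorator) would stream
# tokens to the UI as they are generated.
def respond_streaming(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens=1024,
    temperature=0.5,
    top_p=0.95,
    repetition_penalty=1.1,
):
    from threading import Thread
    from transformers import TextIteratorStreamer

    messages = [{"role": "system", "content": system_message}]
    for user_turn, assistant_turn in history:
        if user_turn:
            messages.append({"role": "user", "content": user_turn})
        if assistant_turn:
            messages.append({"role": "assistant", "content": assistant_turn})
    messages.append({"role": "user", "content": message})

    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    # skip_prompt=True drops the input tokens, so only new text is streamed
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            **model_inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
        ),
    )
    thread.start()

    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial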

"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a helpful AI assistant and your name is RexLM.", label="System message"),
        gr.Slider(minimum=1, maximum=4096, value=1024, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.5, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
        gr.Slider(minimum=0.5, maximum=1.5, value=1.1, step=0.1, label="Repetition Penalty"),
    ],
)


if __name__ == "__main__":
    demo.launch(share=True)
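
# Programmatic usage (illustrative sketch): gr.ChatInterface exposes a "/chat"
# endpoint for gradio_client. The localhost URL is an assumption about where
# this demo is being served; the positional args follow additional_inputs order.
#
# from gradio_client import Client
#
# client = Client("http://127.0.0.1:7860")
# reply = client.predict(
#     "Hello, who are you?",                                      # message
#     "You are a helpful AI assistant and your name is RexLM.",   # system message
#     1024,                                                       # max new tokens
#     0.5,                                                        # temperature
#     0.95,                                                       # top-p
#     1.1,                                                        # repetition penalty
#     api_name="/chat",
# )
# print(reply)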