import os
from threading import Thread

import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, BitsAndBytesConfig

# Per-device memory caps consumed by device_map="auto": layers that do not
# fit in the 30GiB GPU budget are offloaded to CPU RAM.
max_memory = {
    0: "30GiB",
    "cpu": "64GiB",
}
MODEL_LIST = ["THUDM/GLM-4-Z1-32B-0414"]

HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = MODEL_LIST[0]
MODEL_NAME = "GLM-4-Z1-32B-0414"

TITLE = "<h1>3ML-bot (Text Only)</h1>"

DESCRIPTION = f"""
<center>
<p>😊 A multilingual analytical chatbot.
<br>
🚀 Current model: <a href="https://hf.co/{MODEL_ID}">{MODEL_NAME}</a></p>
</center>"""

CSS = """
h1 {
    text-align: center;
    display: block;
}
"""

# Configure BitsAndBytes for 4-bit quantization
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)
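
# Rough footprint: NF4 stores weights at ~0.5 bytes per parameter, so the
# 32B-parameter checkpoint needs roughly 16 GB for weights plus activation
# and KV-cache overhead, which is why the 30GiB GPU cap above suffices.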

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, token=HF_TOKEN)

@spaces.GPU()
def stream_chat(message, history: list, temperature: float, max_new_tokens: int, top_p: float, top_k: int, penalty: float):
    # Load inside the @spaces.GPU-decorated function so the weights are placed
    # while the ZeroGPU slice is attached. Note this reloads the checkpoint on
    # every request; a module-level cache would avoid the repeated load.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
        quantization_config=quantization_config,
        device_map="auto",
        max_memory=max_memory,
        token=HF_TOKEN,
    )

    print(f'message is - {message}')
    print(f'history is - {history}')
    
    # Rebuild the chat-template conversation from Gradio's (user, assistant)
    # history pairs; iterating an empty history is a no-op, so no guard is needed.
    conversation = []
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])

    conversation.append({"role": "user", "content": message})
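    # e.g. history=[("Hi", "Hello!")], message="How are you?" yields:
    # [{"role": "user", "content": "Hi"},
    #  {"role": "assistant", "content": "Hello!"},
    #  {"role": "user", "content": "How are you?"}]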
    
    print(f"Conversation is -\n{conversation}")

    # apply_chat_template with return_dict=True returns a BatchEncoding
    # (input_ids + attention_mask), so name it `inputs` rather than `input_ids`.
    inputs = tokenizer.apply_chat_template(
        conversation, tokenize=True, add_generation_prompt=True,
        return_tensors="pt", return_dict=True,
    ).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        max_new_tokens=max_new_tokens,  # budget for generated tokens only, excluding the prompt
        streamer=streamer,
        do_sample=True,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        repetition_penalty=penalty,
        # GLM-4 stop-token IDs, matching the upstream THUDM demo code.
        eos_token_id=[151329, 151336, 151338],
    )
    gen_kwargs = {**inputs, **generate_kwargs}

    # model.generate already runs in inference mode internally, and a no_grad
    # context here would not propagate into the worker thread anyway.
    thread = Thread(target=model.generate, kwargs=gen_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
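
# Hypothetical smoke test outside Gradio (assumes the GPU session that
# @spaces.GPU normally provides is available):
#   for partial in stream_chat("Hello", [], 0.8, 256, 1.0, 10, 1.0):
#       print(partial)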

chatbot = gr.Chatbot()
chat_input = gr.Textbox(
    interactive=True,
    placeholder="Enter your message here...",
    show_label=False,
)

# Example prompts in several languages (English, Spanish, Chinese, Arabic,
# Hindi, Bulgarian) to showcase the multilingual model.
EXAMPLES = [
    ["Analyze the geopolitical implications of recent technological advancements in AI."],
    ["¿Cuáles son los desafíos éticos más importantes en el desarrollo de la inteligencia artificial general?"],
    ["从经济学和社会学角度分析,人工智能将如何改变未来的就业市场?"],
    ["ما هي التحديات الرئيسية التي تواجه تطوير الذكاء الاصطناعي في العالم العربي؟"],
    ["नैतिक कृत्रिम बुद्धिमत्ता विकास में सबसे बड़ी चुनौतियाँ क्या हैं? विस्तार से समझाइए।"],
    ["Кои са основните предизвикателства пред разработването на изкуствен интелект в България и Източна Европа?"],
    ["Explain the potential risks and benefits of quantum computing in national security contexts."],
    ["分析气候变化对全球经济不平等的影响,并提出可能的解决方案。"],
]

with gr.Blocks(css=CSS, theme="soft", fill_height=True) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
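    # additional_inputs are passed positionally to stream_chat after
    # (message, history): temperature, max new tokens, top_p, top_k, penalty.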
    gr.ChatInterface(
        fn=stream_chat,
        textbox=chat_input,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(
                minimum=0.1,  # temperature must be strictly positive when do_sample=True
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=1024,
                maximum=8192,
                step=1,
                value=4096,
                label="Max Length",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=10,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=2.0,
                step=0.1,
                value=1.0,
                label="Repetition penalty",
                render=False,
            ),
        ],
        examples=EXAMPLES,
    )

if __name__ == "__main__":
    demo.queue(api_open=False).launch(show_api=False, share=False)