# πŸ€–βš‘ β–„β–€ [ I M P O R T S ]

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

# πŸ§ πŸ”§ β–„β–€ [ M O D E L ]

microsoft_model = None
microsoft_tokenizer = None

def load_model():
    """Lazily load the BitNet model and tokenizer once and reuse them on later calls."""
    global microsoft_model, microsoft_tokenizer
    if microsoft_model is None or microsoft_tokenizer is None:
        model_id = "microsoft/bitnet-b1.58-2B-4T"
        microsoft_tokenizer = AutoTokenizer.from_pretrained(model_id)
        config = AutoConfig.from_pretrained(model_id)
        # bfloat16 roughly halves memory versus fp32 and matches the dtype used
        # in the model card's example code.
        microsoft_model = AutoModelForCausalLM.from_pretrained(
            model_id,
            config=config,
            torch_dtype=torch.bfloat16
        )
    return microsoft_model, microsoft_tokenizer
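# A minimal smoke test for load_model (a sketch to run in a REPL, not part of
# the app; assumes the checkpoint can be downloaded from the Hub):
#   model, tokenizer = load_model()
#   ids = tokenizer("1-bit LLMs are", return_tensors="pt").to(model.device)
#   print(tokenizer.decode(model.generate(**ids, max_new_tokens=10)[0]))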
    
# πŸ—‚οΈπŸ•°οΈ β–„β–€ [ C O N V E R S A T I O N - H I S T O R Y ]

def manage_history(history):
    # Limit to 3 turns (each turn is user + assistant = 2 messages)
    max_messages = 6  # 3 turns * 2 messages per turn
    if len(history) > max_messages:
        history = history[-max_messages:]
    
    # Limit total character count to 300
    total_chars = sum(len(msg["content"]) for msg in history)
    while total_chars > 300 and history:
        history.pop(0)  # Remove oldest message
        total_chars = sum(len(msg["content"]) for msg in history)
    
    return history
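# Illustration with hypothetical messages: two 200-character messages total 400
# characters, so the loop above pops the oldest message until the sum is <= 300:
#   history = [{"role": "user", "content": "a" * 200},
#              {"role": "assistant", "content": "b" * 200}]
#   manage_history(history)  # -> [{"role": "assistant", "content": "b" * 200}]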

# πŸ’¬βœ¨ β–„β–€ [ G E N E R A T E - R E S P O N S E ]

def generate_response(user_input, system_prompt, max_new_tokens, temperature, top_p, top_k, history):
    model, tokenizer = load_model()
    
    # Build the prompt from the system prompt, the retained conversation
    # history, and the new user message, so the model actually sees the
    # earlier turns that manage_history keeps around.
    messages = [{"role": "system", "content": system_prompt}]
    messages.extend(history)
    messages.append({"role": "user", "content": user_input})
    
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    chat_input = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    # Generate Response
    chat_outputs = model.generate(
        **chat_input,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        do_sample=True
    )
    
    # Decode Response
    response = tokenizer.decode(chat_outputs[0][chat_input['input_ids'].shape[-1]:], skip_special_tokens=True)
    
    # Update History
    history.append({"role": "user", "content": user_input})
    history.append({"role": "assistant", "content": response})
    
    # Manage History Limits
    history = manage_history(history)
    
    return history, history
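# For a more responsive UI, the response could be streamed token by token using
# transformers' TextIteratorStreamer with generate() in a background thread
# (a sketch, not wired into this demo):
#   from threading import Thread
#   from transformers import TextIteratorStreamer
#   streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
#   Thread(target=model.generate, kwargs={**chat_input, "streamer": streamer,
#          "max_new_tokens": max_new_tokens}).start()
#   partial = ""
#   for piece in streamer:
#       partial += piece  # yield the growing partial response to Gradio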

# πŸŽ›οΈπŸ–₯️ β–„β–€ [ G R A D I O - I N T E R F A C E ]

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# BitNet b1.58 2B4T Demo")
    
    with gr.Row():
        with gr.Column():
            gr.Markdown("""
            ## About BitNet b1.58 2B4T
            BitNet b1.58 2B4T is the first open-source, native 1-bit Large Language Model at the 2-billion-parameter scale,
            developed by Microsoft Research. Trained on 4 trillion tokens, it achieves performance comparable to
            full-precision models of similar size while offering significant gains in memory, energy, and latency. Features include:
            - Transformer-based architecture with BitLinear layers
            - Native 1.58-bit weights and 8-bit activations
            - Maximum context length of 4096 tokens
            - Optimized for efficient inference with bitnet.cpp
            """)
        
        with gr.Column():
            gr.Markdown("""
            ## About Tonic AI
            Tonic AI is a vibrant community of AI enthusiasts and developers always building cool demos and pushing
            the boundaries of what's possible with AI. We're passionate about creating innovative, accessible, and 
            engaging AI experiences for everyone. Join us in exploring the future of AI!
            """)
    
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
            system_prompt = gr.Textbox(
                label="System Prompt",
                value="You are a helpful AI assistant.",
                placeholder="Enter system prompt..."
            )
            
            with gr.Accordion("Advanced Options", open=False):
                max_new_tokens = gr.Slider(
                    minimum=10,
                    maximum=500,
                    value=50,
                    step=10,
                    label="Max New Tokens"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top P"
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top K"
                )
            
            submit_btn = gr.Button("Send")
        
        with gr.Column():
            chatbot = gr.Chatbot(label="Conversation", type="messages")
    
    chat_history = gr.State([])
    
    submit_btn.click(
        fn=generate_response,
        inputs=[
            user_input,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
            chat_history
        ],
        outputs=[chatbot, chat_history]
    )
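    # The same handler could also be bound to the Enter key via Gradio's
    # Textbox.submit (a sketch, not wired in here):
    #   user_input.submit(
    #       fn=generate_response,
    #       inputs=[user_input, system_prompt, max_new_tokens,
    #               temperature, top_p, top_k, chat_history],
    #       outputs=[chatbot, chat_history]
    #   )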
    
# πŸš€πŸ”₯ β–„β–€ [ M A I N ]

if __name__ == "__main__":
    load_model()
    demo.launch(ssr_mode=False, share=False)