import os

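# This installs a transformers fork with BitNet support at startup (a common
# Hugging Face Spaces workaround; pin a specific commit for reproducible builds).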
os.system("pip install git+https://github.com/shumingma/transformers.git")

import threading
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
import gradio as gr
import spaces

# Load model and tokenizer
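# bfloat16 halves the memory footprint relative to fp32, and device_map="auto"
# lets accelerate place the weights on GPU when one is available.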
model_id = "microsoft/bitnet-b1.58-2B-4T"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

@spaces.GPU
def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """
    Generate a chat response using streaming with TextIteratorStreamer.

    Args:
        message: User's current message.
        history: List of (user, assistant) tuples from previous turns.
        system_message: Initial system prompt guiding the assistant.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus sampling probability.

    Yields:
        The growing response text as new tokens are generated.
    """
    # Assemble the conversation: system prompt, prior turns, then the new user message
    messages = [{"role": "system", "content": system_message}]
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})

    # Prepare prompt and tokenize
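    # add_generation_prompt=True appends the assistant turn header so the
    # model continues the conversation as the assistant.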
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Set up streamer for real-time output
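    # skip_prompt drops the echoed input tokens; skip_special_tokens strips
    # markers such as EOS from the streamed text.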
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generate_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )
    # Start generation in a separate thread
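    # (model.generate blocks until it finishes, so it runs in the background
    # while the loop below consumes tokens from the streamer as they arrive)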
    thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    # Stream tokens back to user
    response = ""
    for new_text in streamer:
        response += new_text
        yield response
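
    # Make sure the generation thread has finished before returning
    thread.join()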

# Initialize Gradio chat interface
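# additional_inputs map positionally onto respond()'s extra parameters
# (system_message, max_tokens, temperature, top_p), and each examples row
# supplies [message, *additional_inputs] in the same order.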

demo = gr.ChatInterface(
    fn=respond,
    title="Bitnet-b1.58-2B-4T Chatbot",
    description="This chat application is powered by Microsoft BitNet-B1 and designed for natural conversations.",
    examples=[
        # Each example: [message, system_message, max_new_tokens, temperature, top_p]
        [
            "Hello! How are you?",
            "You are a helpful AI assistant.",
            512,
            0.7,
            0.95,
        ],
        [
            "Can you code a snake game in Python?",
            "You are a helpful AI assistant.",
            512,
            0.7,
            0.95,
        ],
    ],
    additional_inputs=[
        gr.Textbox(
            value="You are a helpful AI assistant.",
            label="System message"
        ),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=512,
            step=1,
            label="Max new tokens"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.7,
            step=0.1,
            label="Temperature"
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)"
        ),
    ],
)

if __name__ == "__main__":
    demo.launch()