import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread
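
# NOTE: assumed dependencies for this Space: gradio, torch, transformers and accelerate
# (accelerate is required for device_map="auto" below); pin exact versions in requirements.txt.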

# --- Configuration ---
MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
# Use 'cuda' if a GPU is available (e.g. a GPU Space); otherwise fall back to 'cpu' (slow)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# --- Load Model and Tokenizer ---
# Note: Loading might require specific trust_remote_code=True or other flags
# depending on the model implementation. Check the model card on Hugging Face.
# You might also need specific quantization configs if not handled automatically.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # Adjust loading parameters as needed (e.g., torch_dtype, device_map)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16, # Or float16, adjust based on hardware/model reqs
        device_map="auto",         # Automatically distribute across available devices (GPU/CPU)
        trust_remote_code=True # May be required for some custom model code
    )
    # model.to(DEVICE) # Usually handled by device_map="auto"
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading model or tokenizer: {e}")
    # Exit if loading fails rather than serving a broken app
    raise SystemExit("Failed to load model/tokenizer.")
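
# If `accelerate` (needed for device_map="auto") is unavailable, a plain CPU-only load
# is one possible fallback (a minimal sketch, assuming the float32 weights fit in RAM):
#
#   model = AutoModelForCausalLM.from_pretrained(
#       MODEL_ID, torch_dtype=torch.float32, trust_remote_code=True
#   ).to(DEVICE)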

# --- Chat Processing Function ---
def predict(message, history):
    """
    Generates a response to the user's message using the chat history.
    """
    history_transformer_format = []
    for human, assistant in history:
        # Basic alternating format - adjust if the model expects something different
        history_transformer_format.append({"role": "user", "content": human})
        history_transformer_format.append({"role": "assistant", "content": assistant})

    # Add the current user message
    history_transformer_format.append({"role": "user", "content": message})

    # Use the tokenizer's chat template if available, otherwise manual formatting.
    # Base models might not have a specific chat template.
    try:
        prompt = tokenizer.apply_chat_template(
            history_transformer_format,
            tokenize=False,
            add_generation_prompt=True  # Append the assistant turn marker so the model replies as the assistant
        )
    except Exception:
        # Manual fallback prompt formatting (Example - adjust as needed!)
        print("Warning: Using basic manual prompt formatting.")
        prompt_parts = ["Chat History:"]
        for turn in history_transformer_format:
            prompt_parts.append(f"{turn['role'].capitalize()}: {turn['content']}")
        prompt = "\n".join(prompt_parts) + "\nAssistant:" # Ensure it ends ready for generation

    print(f"\n--- Prompt Sent to Model ---\n{prompt}\n---------------------------\n")

    # Use a streamer for interactive generation
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # NOTE: if the chat template already inserts special tokens (e.g. BOS), passing
    # add_special_tokens=False here would avoid duplicating them.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    generation_kwargs = dict(
        inputs,  # unpacks input_ids / attention_mask from the tokenizer output
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
        # Add other generation parameters as needed
        # eos_token_id=tokenizer.eos_token_id # Important if model needs it
        pad_token_id=tokenizer.eos_token_id # Often set for open-end generation
    )

    # Run generation in a separate thread for streaming
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield tokens as they become available
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield partial_message
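
    # Ensure the generation thread has finished before returning.
    thread.join()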

# --- Gradio Interface ---
# gr.ChatInterface handles chat history management automatically. The tuple-style
# history and the retry/undo/clear button arguments below follow the Gradio 4.x
# ChatInterface API; newer Gradio releases changed these, so pin the version accordingly.
chatbot_interface = gr.ChatInterface(
    fn=predict,
    chatbot=gr.Chatbot(height=500),
    textbox=gr.Textbox(placeholder="Ask me anything...", container=False, scale=7),
    title="Chat with microsoft/bitnet-b1.58-2B-4T",
    description="A basic chat interface for the 1.58-bit, 2B-parameter BitNet model. Generation can be slow, especially on CPU.",
    theme="soft",
    examples=[["Hello!"], ["Explain the concept of 1.58-bit quantization."]],
    cache_examples=False, # Set to True to cache example results
    retry_btn=None,
    undo_btn="Delete Previous Turn",
    clear_btn="Clear Chat",
)

# --- Launch the Interface ---
if __name__ == "__main__":
    chatbot_interface.launch() # Use share=True for public link if running locally