import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

# --- Configuration ---
MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
# Use 'cuda' when a GPU is available (e.g. a GPU Space); otherwise fall back to 'cpu' (slow for a 2B model).
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# --- Load Model and Tokenizer ---
# Note: loading may require trust_remote_code=True or other flags depending on
# the model implementation, and possibly an explicit quantization config if it
# is not handled automatically. Check the model card on Hugging Face.
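# NOTE (assumption): native support for the BitNet architecture is relatively
# recent in transformers; if loading fails here, check the model card for the
# exact transformers version or install steps it recommends.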
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # Adjust loading parameters as needed (e.g., torch_dtype, device_map).
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,  # or float16, depending on hardware/model requirements
        device_map="auto",           # distribute weights across available devices (GPU/CPU)
        trust_remote_code=True,      # may be required for custom model code
    )
    # model.to(DEVICE) is unnecessary; device_map="auto" handles placement.
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading model or tokenizer: {e}")
    # Exit if loading fails; the app cannot run without the model.
    raise SystemExit("Failed to load model/tokenizer.") from e

# --- Chat Processing Function ---
def predict(message, history):
    """
    Stream a response to the user's message, conditioning on the chat history.
    """
    # Convert Gradio's [(user, assistant), ...] tuple history into the
    # [{"role": ..., "content": ...}] message format transformers expects.
    # Adjust this if the model expects a different turn structure.
    history_transformer_format = []
    for human, assistant in history:
        history_transformer_format.append({"role": "user", "content": human})
        history_transformer_format.append({"role": "assistant", "content": assistant})
    # Add the current user message.
    history_transformer_format.append({"role": "user", "content": message})
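    # For example, history = [("Hi", "Hello!")] plus message = "How are you?" yields:
    #   [{"role": "user", "content": "Hi"},
    #    {"role": "assistant", "content": "Hello!"},
    #    {"role": "user", "content": "How are you?"}]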
    # Use the tokenizer's chat template if available; base models may not ship one.
    try:
        prompt = tokenizer.apply_chat_template(
            history_transformer_format,
            tokenize=False,
            add_generation_prompt=True,  # append the assistant prefix so the model replies as the assistant
        )
    except Exception:
        # Manual fallback formatting (example only; adjust to the model's expected format!).
        print("Warning: Using basic manual prompt formatting.")
        prompt_parts = ["Chat History:"]
        for turn in history_transformer_format:
            prompt_parts.append(f"{turn['role'].capitalize()}: {turn['content']}")
        prompt = "\n".join(prompt_parts) + "\nAssistant:"  # end ready for generation
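        # With the example history above, this fallback renders:
        #   Chat History:
        #   User: Hi
        #   Assistant: Hello!
        #   User: How are you?
        #   Assistant: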
    print(f"\n--- Prompt Sent to Model ---\n{prompt}\n---------------------------\n")
    # Use a streamer for interactive, token-by-token generation.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    generation_kwargs = dict(
        inputs,  # BatchEncoding is dict-like, so this unpacks input_ids / attention_mask
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
        # Add other generation parameters as needed:
        # eos_token_id=tokenizer.eos_token_id,  # important if the model needs it
        pad_token_id=tokenizer.eos_token_id,  # often set for open-ended generation
    )
    # Run generation in a background thread so tokens can be streamed from here.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    # Yield the growing response as tokens become available.
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield partial_message
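    # Join the worker thread once the streamer is exhausted, so generation has
    # fully finished before the generator returns (a small robustness addition).
    thread.join()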

# --- Gradio Interface ---
# gr.ChatInterface handles chat history management automatically.
chatbot_interface = gr.ChatInterface(
    fn=predict,
    chatbot=gr.Chatbot(height=500),
    textbox=gr.Textbox(placeholder="Ask me anything...", container=False, scale=7),
    title="Chat with microsoft/bitnet-b1.58-2B-4T",
    description="A basic chat interface for the BitNet 1.58-bit, 2B-parameter model. Remember it's a base model, so prompting matters!",
    theme="soft",
    examples=[["Hello!"], ["Explain the concept of 1.58-bit quantization."]],
    cache_examples=False,  # set to True to cache example results
    retry_btn=None,
    undo_btn="Delete Previous Turn",
    clear_btn="Clear Chat",
)
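# NOTE (assumption): retry_btn / undo_btn / clear_btn and the tuple-style history
# consumed in predict() match Gradio 4.x; Gradio 5 removed these kwargs and moved
# toward a messages-style history, so pin gradio<5 or adapt accordingly.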

# --- Launch the Interface ---
if __name__ == "__main__":
    chatbot_interface.launch()  # use share=True for a public link when running locally