import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

# --- Configuration ---
MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"

# Try 'cuda' if you have a GPU Space, 'cpu' otherwise (will be slow)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# --- Load Model and Tokenizer ---
# Note: Loading might require trust_remote_code=True or other flags depending on
# the model implementation. Check the model card on Hugging Face. You might also
# need specific quantization configs if they are not handled automatically.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # Adjust loading parameters as needed (e.g., torch_dtype, device_map)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,  # Or float16; adjust based on hardware/model requirements
        device_map="auto",           # Automatically distribute across available devices (GPU/CPU)
        trust_remote_code=True,      # May be required for some custom model code
    )
    # model.to(DEVICE)  # Usually handled by device_map="auto"
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading model or tokenizer: {e}")
    # Exit if loading fails so the Space surfaces the error instead of serving a broken app
    raise SystemExit("Failed to load model/tokenizer.")
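
# Note (assumption, per the model card; verify for your setup): plain transformers runs
# the 1.58-bit weights without the specialized low-bit kernels, so memory use and speed
# here are roughly bf16-level; the dedicated bitnet.cpp runtime is what provides the
# actual 1.58-bit efficiency gains.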

# --- Chat Processing Function ---
def predict(message, history):
    """
    Generates a streaming response to the user's message, using the chat history.
    """
    history_transformer_format = []
    for human, assistant in history:
        # Basic alternating user/assistant format - adjust if the model expects something different
        history_transformer_format.append({"role": "user", "content": human})
        history_transformer_format.append({"role": "assistant", "content": assistant})

    # Add the current user message
    history_transformer_format.append({"role": "user", "content": message})

    # Use the tokenizer's chat template if available, otherwise fall back to manual formatting.
    # Base models might not have a chat template.
    try:
        prompt = tokenizer.apply_chat_template(
            history_transformer_format,
            tokenize=False,
            add_generation_prompt=True,  # Appends the assistant prefix so the model continues as the assistant
        )
    except Exception:
        # Manual fallback prompt formatting (example - adjust as needed!)
        print("Warning: Using basic manual prompt formatting.")
        prompt_parts = ["Chat History:"]
        for turn in history_transformer_format:
            prompt_parts.append(f"{turn['role'].capitalize()}: {turn['content']}")
        prompt = "\n".join(prompt_parts) + "\nAssistant:"  # End with the assistant prefix, ready for generation
print(f"\n--- Prompt Sent to Model ---\n{prompt}\n---------------------------\n") | |
# Use a streamer for interactive generation | |
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) | |
inputs = tokenizer(prompt, return_tensors="pt").to(model.device) | |
generation_kwargs = dict( | |
inputs, | |
streamer=streamer, | |
max_new_tokens=512, | |
do_sample=True, | |
top_p=0.9, | |
temperature=0.7, | |
# Add other generation parameters as needed | |
# eos_token_id=tokenizer.eos_token_id # Important if model needs it | |
pad_token_id=tokenizer.eos_token_id # Often set for open-end generation | |
) | |
# Run generation in a separate thread for streaming | |
thread = Thread(target=model.generate, kwargs=generation_kwargs) | |
thread.start() | |
# Yield tokens as they become available | |
partial_message = "" | |
for new_token in streamer: | |
partial_message += new_token | |
yield partial_message | |

# --- Gradio Interface ---
# gr.ChatInterface handles history management automatically
chatbot_interface = gr.ChatInterface(
    fn=predict,
    chatbot=gr.Chatbot(height=500),
    textbox=gr.Textbox(placeholder="Ask me anything...", container=False, scale=7),
    title="Chat with microsoft/bitnet-b1.58-2B-4T",
    description="A basic chat interface for the BitNet 1.58-bit 2B parameter model. Remember it's a base model, so prompting matters!",
    theme="soft",
    examples=[["Hello!"], ["Explain the concept of 1.58-bit quantization."]],
    cache_examples=False,  # Set to True to cache example results
    # The three button kwargs below exist in Gradio 4.x but were removed in Gradio 5;
    # drop them if the Space runs a newer Gradio version.
    retry_btn=None,
    undo_btn="Delete Previous Turn",
    clear_btn="Clear Chat",
)

# --- Launch the Interface ---
if __name__ == "__main__":
    chatbot_interface.launch()  # Use share=True for a public link when running locally
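    # If streaming appears to stall on an older Gradio release, enabling the queue
    # explicitly may help (it is on by default in recent versions):
    # chatbot_interface.queue().launch()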