import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from threading import Thread

# --- Configuration ---
MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
# 'cuda' is used if a GPU is available, 'cpu' otherwise (CPU generation will be slow)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# --- Load Model and Tokenizer ---
# Note: Loading may require trust_remote_code=True or other flags depending on
# the model implementation. Check the model card on Hugging Face.
# You might also need a specific quantization config if it is not handled automatically.
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # Adjust loading parameters as needed (e.g., torch_dtype, device_map)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.bfloat16,  # Or float16; adjust to your hardware/model requirements
        device_map="auto",           # Automatically distribute across available devices (GPU/CPU)
        trust_remote_code=True,      # May be required for custom model code
    )
    # model.to(DEVICE)  # Usually unnecessary when device_map="auto" is set
    print("Model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading model or tokenizer: {e}")
    # Exit if loading fails; nothing below can run without the model
    raise SystemExit("Failed to load model/tokenizer.")

# --- Chat Processing Function ---
def predict(message, history):
    """
    Generates a streaming response to the user's message using the chat history.
    `history` is a list of (user, assistant) tuples, Gradio's tuple-style format.
    """
    history_transformer_format = []
    for human, assistant in history:
        # Basic alternating format - adjust if the model expects something different
        history_transformer_format.append({"role": "user", "content": human})
        history_transformer_format.append({"role": "assistant", "content": assistant})
    # Add the current user message
    history_transformer_format.append({"role": "user", "content": message})

    # Use the tokenizer's chat template if available, otherwise fall back to
    # manual formatting. Base models might not ship a chat template.
    try:
        prompt = tokenizer.apply_chat_template(
            history_transformer_format,
            tokenize=False,
            add_generation_prompt=True,  # Appends the assistant marker so the model continues the chat
        )
    except Exception:
        # Manual fallback prompt formatting (example only - adjust to the model's expected format!)
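        # For reference, the fallback below produces a prompt shaped like this
        # (an illustrative assumption, not a format taken from the model card):
        #   Chat History:
        #   User: Hello!
        #   Assistant: Hi there.
        #   User: <current message>
        #   Assistant: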
        print("Warning: Using basic manual prompt formatting.")
        prompt_parts = ["Chat History:"]
        for turn in history_transformer_format:
            prompt_parts.append(f"{turn['role'].capitalize()}: {turn['content']}")
        prompt = "\n".join(prompt_parts) + "\nAssistant:"  # End the prompt so the model continues as the assistant

    print(f"\n--- Prompt Sent to Model ---\n{prompt}\n---------------------------\n")

    # Stream tokens as they are generated instead of waiting for the full output
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=512,
        do_sample=True,
        top_p=0.9,
        temperature=0.7,
        # Add other generation parameters as needed
        # eos_token_id=tokenizer.eos_token_id,  # Set explicitly if the model needs it
        pad_token_id=tokenizer.eos_token_id,  # Commonly set for open-ended generation
    )

    # Run generation in a separate thread so we can consume the streamer as tokens arrive
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Yield the accumulated message as tokens become available
    partial_message = ""
    for new_token in streamer:
        partial_message += new_token
        yield partial_message

# --- Gradio Interface ---
# gr.ChatInterface handles history management automatically.
# Note: retry_btn/undo_btn/clear_btn are Gradio 4.x parameters; they were removed in Gradio 5.
chatbot_interface = gr.ChatInterface(
    fn=predict,
    chatbot=gr.Chatbot(height=500),
    textbox=gr.Textbox(placeholder="Ask me anything...", container=False, scale=7),
    title="Chat with microsoft/bitnet-b1.58-2B-4T",
    description="A basic chat interface for the BitNet 1.58-bit 2B parameter model. Remember it's a base model, so prompting matters!",
    theme="soft",
    examples=[["Hello!"], ["Explain the concept of 1.58-bit quantization."]],
    cache_examples=False,  # Set to True to cache example results
    retry_btn=None,
    undo_btn="Delete Previous Turn",
    clear_btn="Clear Chat",
)

# --- Launch the Interface ---
if __name__ == "__main__":
    chatbot_interface.launch()  # Pass share=True for a public link when running locally
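# --- Optional: quick smoke test without the UI ---
# A minimal sketch (not part of the original app) for exercising predict() directly
# from the terminal, assuming the model and tokenizer above loaded successfully.
# Swap it in for the launch() call above to check generation without Gradio:
#
#     response = ""
#     for partial in predict("Hello!", history=[]):
#         response = partial  # Each yield is the accumulated message so far
#     print(f"Model response: {response}")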