Spaces: Build error
import gradio as gr
import subprocess
import os
import logging

from huggingface_hub import hf_hub_download

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Paths
BITNET_BINARY = "/home/user/app/bitnet.cpp/build/bin/main"
MODEL_DIR = "/home/user/app/models"
MODEL_PATH = os.path.join(MODEL_DIR, "bitnet-b1.58-2B-4T.gguf")
MODEL_REPO = "microsoft/bitnet-b1.58-2B-4T-gguf"
MODEL_FILE = "bitnet-b1.58-2B-4T.gguf"
# Download model weights if not present
def download_model():
    if not os.path.exists(MODEL_PATH):
        logger.info("Downloading model weights...")
        os.makedirs(MODEL_DIR, exist_ok=True)
        hf_hub_download(
            repo_id=MODEL_REPO,
            filename=MODEL_FILE,
            local_dir=MODEL_DIR,
            local_dir_use_symlinks=False  # deprecated and ignored by recent huggingface_hub versions
        )
        logger.info("Model weights downloaded successfully.")
    else:
        logger.info("Model weights already exist.")

# Run model download on startup
download_model()
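# Hedged alternative (illustrative, not in the original app): if inference
# later cannot find the weights, one option is to trust the absolute path
# that hf_hub_download itself returns instead of reconstructing it by hand.
# `download_model_v2` is a hypothetical name for this sketch.
def download_model_v2() -> str:
    path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, local_dir=MODEL_DIR)
    logger.info("Model available at %s", path)
    return path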
def run_bitnet_inference(prompt, max_tokens=50, temperature=0.7, top_p=0.9, top_k=50):
    # Prepare the command to call the bitnet.cpp binary. The flag names below
    # assume the llama.cpp-style CLI that bitnet.cpp inherits (-n for the
    # number of tokens to generate, --temp for temperature); adjust them if
    # your build exposes different options.
    cmd = [
        BITNET_BINARY,
        "-m", MODEL_PATH,
        "-p", prompt,
        "-n", str(max_tokens),
        "--temp", str(temperature),
        "--top-p", str(top_p),
        "--top-k", str(top_k)
    ]
    try:
        # Run the command and capture output
        result = subprocess.run(cmd, capture_output=True, text=True, check=True)
        return result.stdout.strip()
    except subprocess.CalledProcessError as e:
        logger.error(f"Inference error: {e.stderr}")
        return f"Error during inference: {e.stderr}"
def manage_history(history):
    # Limit to 3 turns (user + assistant = 2 messages per turn)
    max_messages = 6
    if len(history) > max_messages:
        history = history[-max_messages:]
    # Limit total character count to 300
    total_chars = sum(len(msg["content"]) for msg in history)
    while total_chars > 300 and history:
        history.pop(0)
        total_chars = sum(len(msg["content"]) for msg in history)
    return history
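# Quick sanity check (illustrative, not in the original): the 300-character
# cap dominates the 6-message cap, so two 200-character messages already
# force an eviction.
h = [{"role": "user", "content": "x" * 200},
     {"role": "assistant", "content": "y" * 200}]
print(len(manage_history(h)))  # -> 1: the oldest message is dropped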
def generate_response(user_input, system_prompt, max_new_tokens, temperature, top_p, top_k, history):
    # Format the prompt for bitnet.cpp. Note that earlier turns are kept in
    # `history` for display but are not replayed into the prompt here.
    full_prompt = f"{system_prompt}\n\nUser: {user_input}\nAssistant: "
    # Run inference
    response = run_bitnet_inference(
        full_prompt,
        max_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k
    )
    # Update history
    history.append({"role": "user", "content": user_input})
    history.append({"role": "assistant", "content": response})
    # Manage history limits
    history = manage_history(history)
    return history, history
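# Hedged sketch (illustrative, not in the original): if multi-turn context is
# wanted, prior messages could be replayed into the prompt in the same
# User:/Assistant: format. `build_prompt` is a hypothetical helper name.
def build_prompt(system_prompt, history, user_input):
    lines = [system_prompt, ""]
    for msg in history:
        speaker = "User" if msg["role"] == "user" else "Assistant"
        lines.append(f"{speaker}: {msg['content']}")
    lines.append(f"User: {user_input}")
    lines.append("Assistant: ")
    return "\n".join(lines)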
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# BitNet b1.58 2B4T Demo")
    with gr.Row():
        with gr.Column():
            gr.Markdown("""
            ## About BitNet b1.58 2B4T
            BitNet b1.58 2B4T is the first open-source, native 1-bit Large Language Model with 2 billion parameters, developed by Microsoft Research. Trained on 4 trillion tokens, it matches the performance of full-precision models while offering significant efficiency gains in memory, energy, and latency. Features include:
            - Transformer-based architecture with BitLinear layers
            - Native 1.58-bit weights and 8-bit activations
            - Maximum context length of 4096 tokens
            - Optimized for efficient inference with bitnet.cpp
            """)
        with gr.Column():
            gr.Markdown("""
            ## About Tonic AI
            Tonic AI is a vibrant community of AI enthusiasts and developers always building cool demos and pushing the boundaries of what's possible with AI. We're passionate about creating innovative, accessible, and engaging AI experiences for everyone. Join us in exploring the future of AI!
            """)
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(label="Your Message", placeholder="Type your message here...")
            system_prompt = gr.Textbox(
                label="System Prompt",
                value="You are a helpful AI assistant.",
                placeholder="Enter system prompt..."
            )
            with gr.Accordion("Advanced Options", open=False):
                max_new_tokens = gr.Slider(
                    minimum=10,
                    maximum=500,
                    value=50,
                    step=10,
                    label="Max New Tokens"
                )
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05,
                    label="Top P"
                )
                top_k = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=50,
                    step=1,
                    label="Top K"
                )
            submit_btn = gr.Button("Send")
        with gr.Column():
            chatbot = gr.Chatbot(label="Conversation", type="messages")
    chat_history = gr.State([])
    submit_btn.click(
        fn=generate_response,
        inputs=[
            user_input,
            system_prompt,
            max_new_tokens,
            temperature,
            top_p,
            top_k,
            chat_history
        ],
        outputs=[chatbot, chat_history]
    )
if __name__ == "__main__":
    # share=True is not supported on Hugging Face Spaces (Gradio logs a
    # warning and ignores it), so it is dropped; the Space serves on port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)