# bitnet / app.py
# (Hugging Face Space file header, preserved as a comment so the file parses:
#  kimhyunwoo — "Update app.py" — commit 0c5d476 (verified) — raw / history / blame — 13.2 kB)
import os
import threading
import torch
import torch._dynamo
torch._dynamo.config.suppress_errors = True
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
)
import gradio as gr
import spaces
# Install a transformers build with Bitnet support, if needed.
# On Hugging Face Spaces it is more common to pre-install this via a Dockerfile etc.
# It may be required when testing locally.
# print("Installing required transformers branch...")
# try:
#     os.system("pip install git+https://github.com/shumingma/transformers.git -q")
#     print("transformers branch installed.")
# except Exception as e:
#     print(f"Error installing transformers branch: {e}")
#     print("Proceeding with potentially default transformers version.")
# os.system("pip install accelerate bitsandbytes -q")  # bitsandbytes / accelerate may also be needed.
model_id = "microsoft/bitnet-b1.58-2B-4T"


class DummyModel:
    """Fallback stand-in used when the real model fails to load.

    Defined unconditionally (the original declared it only inside the
    ``except`` branch, so ``isinstance(model, DummyModel)`` in ``respond``
    raised ``NameError`` whenever the real model loaded successfully).
    """

    def generate(self, **kwargs):
        # Stream a fixed notice instead of performing real generation.
        streamer = kwargs.get("streamer")
        if streamer is not None:
            # TextIteratorStreamer.put() expects token-id tensors, not str
            # characters (the original fed it one char at a time).
            # on_finalized_text() enqueues raw text directly, and
            # stream_end=True also terminates the iterator (replaces .end()).
            streamer.on_finalized_text(
                "๋ชจ๋ธ ๋กœ๋”ฉ์— ์‹คํŒจํ•˜์—ฌ ๋”๋ฏธ ์‘๋‹ต์„ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค. ์„ค์ •/๊ฒฝ๋กœ๋ฅผ ํ™•์ธํ•˜์„ธ์š”.",
                stream_end=True,
            )


# Load the model and tokenizer.
print(f"Loading model: {model_id}")
try:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # device_map="auto" automatically shards the model across the available
    # GPU(s)/CPU; bfloat16 is the dtype used for the model weights.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    print(f"Model loaded successfully on device: {model.device}")
except Exception as e:
    print(f"Error loading model: {e}")
    # Fall back to the dummy model so the UI can still start.
    model = DummyModel()
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # dummy tokenizer
    print("Using dummy model due to loading failure.")
@spaces.GPU  # request a GPU on Hugging Face Spaces for this call
def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """
    Generate a chat response using streaming with TextIteratorStreamer.

    Args:
        message: User's current message.
        history: List of (user, assistant) tuples from previous turns.
        system_message: Initial system prompt guiding the assistant.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus sampling probability.

    Yields:
        The growing response text as new tokens are generated.
    """
    # Guard against the dummy fallback. Compare by class name rather than
    # `isinstance(model, DummyModel)`: in the original, DummyModel is only
    # defined when loading failed, so referencing the class directly raised
    # NameError on the success path.
    if type(model).__name__ == "DummyModel":
        yield "๋ชจ๋ธ ๋กœ๋”ฉ์— ์‹คํŒจํ•˜์—ฌ ์‘๋‹ต์„ ์ƒ์„ฑํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค."
        return

    # Build the conversation in the chat-template message format.
    messages = [{"role": "system", "content": system_message}]
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})

    try:
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
        )

        # Run generation on a worker thread so this generator can consume
        # the streamer concurrently.
        thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()

        # Read decoded text from the streamer and yield the accumulated reply.
        response = ""
        for new_text in streamer:
            response += new_text
            yield response
        thread.join()  # reap the worker thread once streaming finishes
    except Exception as e:
        print(f"Error during response generation: {e}")
        yield f"์‘๋‹ต ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜๊ฐ€ ๋ฐœ์ƒํ–ˆ์Šต๋‹ˆ๋‹ค: {e}"
# --- Custom CSS for the UI ---
# NOTE: the CSS below is runtime data passed to Gradio; the /* ... */ comments
# inside the string are part of that data and are intentionally left untouched.
css_styles = """
/* ์ „์ฒด ํŽ˜์ด์ง€ ๋ฐฐ๊ฒฝ ๋ฐ ๊ธฐ๋ณธ ํฐํŠธ ์„ค์ • */
body {
font-family: 'Segoe UI', 'Roboto', 'Arial', sans-serif;
line-height: 1.6;
margin: 0;
padding: 20px; /* ์•ฑ ์ฃผ๋ณ€ ์—ฌ๋ฐฑ ์ถ”๊ฐ€ */
background-color: #f4f7f6; /* ๋ถ€๋“œ๋Ÿฌ์šด ๋ฐฐ๊ฒฝ์ƒ‰ */
}
/* ๋ฉ”์ธ ์•ฑ ์ปจํ…Œ์ด๋„ˆ ์Šคํƒ€์ผ */
.gradio-container {
max-width: 900px; /* ์ค‘์•™ ์ •๋ ฌ ๋ฐ ์ตœ๋Œ€ ๋„ˆ๋น„ ์ œํ•œ */
margin: 20px auto;
border-radius: 12px; /* ๋‘ฅ๊ทผ ๋ชจ์„œ๋ฆฌ */
overflow: hidden; /* ์ž์‹ ์š”์†Œ๋“ค์ด ๋ชจ์„œ๋ฆฌ๋ฅผ ๋„˜์ง€ ์•Š๋„๋ก */
box-shadow: 0 8px 16px rgba(0, 0, 0, 0.1); /* ๊ทธ๋ฆผ์ž ํšจ๊ณผ */
background-color: #ffffff; /* ์•ฑ ๋‚ด์šฉ ์˜์—ญ ๋ฐฐ๊ฒฝ์ƒ‰ */
}
/* ํƒ€์ดํ‹€ ๋ฐ ์„ค๋ช… ์˜์—ญ (ChatInterface์˜ ๊ธฐ๋ณธ ํƒ€์ดํ‹€/์„ค๋ช…) */
/* ์ด ์˜์—ญ์€ ChatInterface ๊ตฌ์กฐ์— ๋”ฐ๋ผ ์ •ํ™•ํ•œ ํด๋ž˜์Šค ์ด๋ฆ„์ด ๋‹ค๋ฅผ ์ˆ˜ ์žˆ์œผ๋‚˜,
.gradio-container ๋‚ด๋ถ€์˜ ์ฒซ ๋ธ”๋ก์ด๋‚˜ H1/P ํƒœ๊ทธ๋ฅผ ํƒ€๊ฒŸํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
ํ…Œ๋งˆ์™€ ํ•จ๊ป˜ ์‚ฌ์šฉํ•˜๋ฉด ๋Œ€๋ถ€๋ถ„ ์ž˜ ์ฒ˜๋ฆฌ๋ฉ๋‹ˆ๋‹ค. ์—ฌ๊ธฐ์„œ๋Š” ์ถ”๊ฐ€์ ์ธ ํŒจ๋”ฉ ๋“ฑ๋งŒ ๊ณ ๋ ค */
.gradio-container > .gradio-block:first-child {
padding: 20px 20px 10px 20px; /* ์ƒ๋‹จ ํŒจ๋”ฉ ์กฐ์ • */
}
/* ์ฑ„ํŒ… ๋ฐ•์Šค ์˜์—ญ ์Šคํƒ€์ผ */
.gradio-chatbox {
/* ํ…Œ๋งˆ์— ์˜ํ•ด ์Šคํƒ€์ผ๋ง๋˜์ง€๋งŒ, ์ถ”๊ฐ€์ ์ธ ๋‚ด๋ถ€ ํŒจ๋”ฉ ๋“ฑ ์กฐ์ • ๊ฐ€๋Šฅ */
padding: 15px;
background-color: #fefefe; /* ์ฑ„ํŒ… ์˜์—ญ ๋ฐฐ๊ฒฝ์ƒ‰ */
border-radius: 8px; /* ์ฑ„ํŒ… ์˜์—ญ ๋‚ด๋ถ€ ๋ชจ์„œ๋ฆฌ */
border: 1px solid #e0e0e0; /* ๊ฒฝ๊ณ„์„  */
}
/* ์ฑ„ํŒ… ๋ฉ”์‹œ์ง€ ์Šคํƒ€์ผ */
.gradio-chatmessage {
margin-bottom: 12px;
padding: 10px 15px;
border-radius: 20px; /* ๋‘ฅ๊ทผ ๋ฉ”์‹œ์ง€ ๋ชจ์„œ๋ฆฌ */
max-width: 75%; /* ๋ฉ”์‹œ์ง€ ๋„ˆ๋น„ ์ œํ•œ */
word-wrap: break-word; /* ๊ธด ๋‹จ์–ด ์ค„๋ฐ”๊ฟˆ */
white-space: pre-wrap; /* ๊ณต๋ฐฑ ๋ฐ ์ค„๋ฐ”๊ฟˆ ์œ ์ง€ */
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05); /* ๋ฉ”์‹œ์ง€์— ์•ฝ๊ฐ„์˜ ๊ทธ๋ฆผ์ž */
}
/* ์‚ฌ์šฉ์ž ๋ฉ”์‹œ์ง€ ์Šคํƒ€์ผ */
.gradio-chatmessage.user {
background-color: #007bff; /* ํŒŒ๋ž€์ƒ‰ ๊ณ„์—ด */
color: white;
margin-left: auto; /* ์˜ค๋ฅธ์ชฝ ์ •๋ ฌ */
border-bottom-right-radius: 2px; /* ์˜ค๋ฅธ์ชฝ ์•„๋ž˜ ๋ชจ์„œ๋ฆฌ ๊ฐ์ง€๊ฒŒ */
}
/* ๋ด‡ ๋ฉ”์‹œ์ง€ ์Šคํƒ€์ผ */
.gradio-chatmessage.bot {
background-color: #e9ecef; /* ๋ฐ์€ ํšŒ์ƒ‰ */
color: #333; /* ์–ด๋‘์šด ํ…์ŠคํŠธ */
margin-right: auto; /* ์™ผ์ชฝ ์ •๋ ฌ */
border-bottom-left-radius: 2px; /* ์™ผ์ชฝ ์•„๋ž˜ ๋ชจ์„œ๋ฆฌ ๊ฐ์ง€๊ฒŒ */
}
/* ์ž…๋ ฅ์ฐฝ ๋ฐ ๋ฒ„ํŠผ ์˜์—ญ ์Šคํƒ€์ผ */
.gradio-input-box {
padding: 15px;
border-top: 1px solid #eee; /* ์œ„์ชฝ ๊ฒฝ๊ณ„์„  */
background-color: #f8f9fa; /* ์ž…๋ ฅ ์˜์—ญ ๋ฐฐ๊ฒฝ์ƒ‰ */
}
/* ์ž…๋ ฅ ํ…์ŠคํŠธ ์—์–ด๋ฆฌ์–ด ์Šคํƒ€์ผ */
.gradio-input-box textarea {
border-radius: 8px;
padding: 10px;
border: 1px solid #ccc;
resize: none !important; /* ์ž…๋ ฅ์ฐฝ ํฌ๊ธฐ ์กฐ์ ˆ ๋น„ํ™œ์„ฑํ™” (์„ ํƒ ์‚ฌํ•ญ) */
min-height: 50px; /* ์ตœ์†Œ ๋†’์ด */
max-height: 150px; /* ์ตœ๋Œ€ ๋†’์ด */
overflow-y: auto; /* ๋‚ด์šฉ ๋„˜์น  ๊ฒฝ์šฐ ์Šคํฌ๋กค */
}
/* ์Šคํฌ๋กค๋ฐ” ์Šคํƒ€์ผ (์„ ํƒ ์‚ฌํ•ญ) */
.gradio-input-box textarea::-webkit-scrollbar {
width: 8px;
}
.gradio-input-box textarea::-webkit-scrollbar-thumb {
background-color: #ccc;
border-radius: 4px;
}
.gradio-input-box textarea::-webkit-scrollbar-track {
background-color: #f1f1f1;
}
/* ๋ฒ„ํŠผ ์Šคํƒ€์ผ */
.gradio-button {
border-radius: 8px;
padding: 10px 20px;
font-weight: bold;
transition: background-color 0.2s ease, opacity 0.2s ease; /* ํ˜ธ๋ฒ„ ์• ๋‹ˆ๋ฉ”์ด์…˜ */
border: none; /* ๊ธฐ๋ณธ ํ…Œ๋‘๋ฆฌ ์ œ๊ฑฐ */
cursor: pointer;
}
.gradio-button:not(.clear-button) { /* Send ๋ฒ„ํŠผ */
background-color: #28a745; /* ์ดˆ๋ก์ƒ‰ */
color: white;
}
.gradio-button:not(.clear-button):hover {
background-color: #218838;
}
.gradio-button:disabled { /* ๋น„ํ™œ์„ฑํ™”๋œ ๋ฒ„ํŠผ */
opacity: 0.6;
cursor: not-allowed;
}
.gradio-button.clear-button { /* Clear ๋ฒ„ํŠผ */
background-color: #dc3545; /* ๋นจ๊ฐ„์ƒ‰ */
color: white;
}
.gradio-button.clear-button:hover {
background-color: #c82333;
}
/* Additional inputs (์ถ”๊ฐ€ ์„ค์ •) ์˜์—ญ ์Šคํƒ€์ผ */
/* ์ด ์˜์—ญ์€ ๋ณดํ†ต ์•„์ฝ”๋””์–ธ ํ˜•ํƒœ๋กœ ๋˜์–ด ์žˆ์œผ๋ฉฐ, .gradio-accordion ํด๋ž˜์Šค๋ฅผ ๊ฐ€์ง‘๋‹ˆ๋‹ค. */
.gradio-accordion {
border-radius: 12px; /* ์™ธ๋ถ€ ์ปจํ…Œ์ด๋„ˆ์™€ ๋™์ผํ•œ ๋ชจ์„œ๋ฆฌ */
margin-top: 15px; /* ์ฑ„ํŒ… ์˜์—ญ๊ณผ์˜ ๊ฐ„๊ฒฉ */
border: 1px solid #ddd; /* ๊ฒฝ๊ณ„์„  */
box-shadow: none; /* ๋‚ด๋ถ€ ๊ทธ๋ฆผ์ž ์ œ๊ฑฐ */
}
/* ์•„์ฝ”๋””์–ธ ํ—ค๋” (๋ผ๋ฒจ) ์Šคํƒ€์ผ */
.gradio-accordion .label {
font-weight: bold;
color: #007bff; /* ํŒŒ๋ž€์ƒ‰ ๊ณ„์—ด */
padding: 15px; /* ํ—ค๋” ํŒจ๋”ฉ */
background-color: #e9ecef; /* ํ—ค๋” ๋ฐฐ๊ฒฝ์ƒ‰ */
border-bottom: 1px solid #ddd; /* ํ—ค๋” ์•„๋ž˜ ๊ฒฝ๊ณ„์„  */
border-top-left-radius: 11px; /* ์ƒ๋‹จ ๋ชจ์„œ๋ฆฌ */
border-top-right-radius: 11px;
}
/* ์•„์ฝ”๋””์–ธ ๋‚ด์šฉ ์˜์—ญ ์Šคํƒ€์ผ */
.gradio-accordion .wrap {
padding: 15px; /* ๋‚ด์šฉ ํŒจ๋”ฉ */
background-color: #fefefe; /* ๋‚ด์šฉ ๋ฐฐ๊ฒฝ์ƒ‰ */
border-bottom-left-radius: 11px; /* ํ•˜๋‹จ ๋ชจ์„œ๋ฆฌ */
border-bottom-right-radius: 11px;
}
/* ์ถ”๊ฐ€ ์„ค์ • ๋‚ด ๊ฐœ๋ณ„ ์ž…๋ ฅ ์ปดํฌ๋„ŒํŠธ ์Šคํƒ€์ผ (์Šฌ๋ผ์ด๋”, ํ…์ŠคํŠธ๋ฐ•์Šค ๋“ฑ) */
.gradio-slider, .gradio-textbox, .gradio-number {
margin-bottom: 10px; /* ๊ฐ ์ž…๋ ฅ ์š”์†Œ ์•„๋ž˜ ๊ฐ„๊ฒฉ */
padding: 8px; /* ๋‚ด๋ถ€ ํŒจ๋”ฉ */
border: 1px solid #e0e0e0; /* ๊ฒฝ๊ณ„์„  */
border-radius: 8px; /* ๋‘ฅ๊ทผ ๋ชจ์„œ๋ฆฌ */
background-color: #fff; /* ๋ฐฐ๊ฒฝ์ƒ‰ */
}
/* ์ž…๋ ฅ ํ•„๋“œ ๋ผ๋ฒจ ์Šคํƒ€์ผ */
.gradio-label {
font-weight: normal; /* ๋ผ๋ฒจ ํฐํŠธ ๊ตต๊ธฐ */
margin-bottom: 5px; /* ๋ผ๋ฒจ๊ณผ ์ž…๋ ฅ ํ•„๋“œ ๊ฐ„ ๊ฐ„๊ฒฉ */
color: #555; /* ๋ผ๋ฒจ ์ƒ‰์ƒ */
display: block; /* ๋ผ๋ฒจ์„ ๋ธ”๋ก ์š”์†Œ๋กœ ๋งŒ๋“ค์–ด ์œ„๋กœ ์˜ฌ๋ฆผ */
}
/* ์Šฌ๋ผ์ด๋” ํŠธ๋ž™ ๋ฐ ํ•ธ๋“ค ์Šคํƒ€์ผ (๋” ์„ธ๋ฐ€ํ•œ ์กฐ์ • ๊ฐ€๋Šฅ) */
/* ์˜ˆ: .gradio-slider input[type="range"]::-webkit-slider-thumb {} */
/* ๋งˆํฌ๋‹ค์šด/HTML ์ปดํฌ๋„ŒํŠธ ๋‚ด ์Šคํƒ€์ผ */
.gradio-markdown, .gradio-html {
padding: 10px 0; /* ์ƒํ•˜ ํŒจ๋”ฉ */
}
"""
# --- end of custom CSS ---
# Gradio ์ธํ„ฐํŽ˜์ด์Šค ์„ค์ •
demo = gr.ChatInterface(
fn=respond,
# ํƒ€์ดํ‹€ ๋ฐ ์„ค๋ช…์— HTML ํƒœ๊ทธ ์‚ฌ์šฉ ์˜ˆ์‹œ (<br> ํƒœ๊ทธ ์‚ฌ์šฉ)
title="<h1 style='text-align: center; color: #007bff;'>Bitnet-b1.58-2B-4T Chatbot</h1>",
description="<p style='text-align: center; color: #555;'>This chat application is powered by Microsoft's SOTA Bitnet-b1.58-2B-4T and designed for natural and fast conversations.</p>",
examples=[
[
"Hello! How are you?",
"You are a helpful AI assistant for everyday tasks.",
512,
0.7,
0.95,
],
[
"Can you code a snake game in Python?",
"You are a helpful AI assistant for coding.",
2048,
0.7,
0.95,
],
],
additional_inputs=[
gr.Textbox(
value="You are a helpful AI assistant.",
label="System message",
lines=3 # ์‹œ์Šคํ…œ ๋ฉ”์‹œ์ง€ ์ž…๋ ฅ์ฐฝ ๋†’์ด ์กฐ์ ˆ
),
gr.Slider(
minimum=1,
maximum=8192,
value=2048,
step=1,
label="Max new tokens"
),
gr.Slider(
minimum=0.1,
maximum=4.0,
value=0.7,
step=0.1,
label="Temperature"
),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p (nucleus sampling)"
),
],
# ํ…Œ๋งˆ ์ ์šฉ (์—ฌ๋Ÿฌ ํ…Œ๋งˆ ์ค‘ ์„ ํƒ ๊ฐ€๋Šฅ: gr.themes.Soft(), gr.themes.Glass(), gr.themes.Default(), etc.)
theme=gr.themes.Soft(),
# ์ปค์Šคํ…€ CSS ์ ์šฉ
css=css_styles,
)
# Application entry point.
if __name__ == "__main__":
    # launch(share=True) would create a public URL (for debugging/sharing; use with care)
    demo.launch()
    # demo.launch(debug=True)  # enable debug mode