# Hugging Face Space: BitNet b1.58 2B-4T chatbot (Gradio app)
# (original header was page chrome: "Spaces: Running Running")
import os
import threading

import torch
import torch._dynamo

# Suppress torch.compile/dynamo errors and fall back to eager execution.
torch._dynamo.config.suppress_errors = True

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)
import gradio as gr
import spaces

# If needed, install the transformers branch with BitNet support.
# On Hugging Face Spaces it is more common to preinstall it via a Dockerfile
# or requirements file; for local testing the install below may be required.
# print("Installing required transformers branch...")
# try:
#     os.system("pip install git+https://github.com/shumingma/transformers.git -q")
#     print("transformers branch installed.")
# except Exception as e:
#     print(f"Error installing transformers branch: {e}")
#     print("Proceeding with potentially default transformers version.")
# os.system("pip install accelerate bitsandbytes -q")  # accelerate/bitsandbytes may also be needed.
model_id = "microsoft/bitnet-b1.58-2B-4T"

# Load the model and tokenizer.
print(f"Loading model: {model_id}")
try:
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # device_map="auto" spreads the model automatically across available
    # GPUs/CPU; bfloat16 is the dtype used for the model weights.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        # load_in_8bit=True  # BitNet is 1.58-bit, so 8-bit loading is likely pointless.
    )
    print(f"Model loaded successfully on device: {model.device}")
except Exception as e:
    print(f"Error loading model: {e}")

    # Fallback: define a dummy model so the UI can still start and report
    # the failure instead of crashing at import time.
    class DummyModel:
        def generate(self, **kwargs):
            # Emit a canned response through the provided streamer, if any.
            input_ids = kwargs.get('input_ids')
            streamer = kwargs.get('streamer')
            if streamer:
                # Simple dummy response streaming, character by character.
                # (Original message was mojibake Korean; re-written in English.)
                dummy_response = "Model loading failed; serving a dummy response. Check the configuration/paths."
                for char in dummy_response:
                    streamer.put(char)
                streamer.end()

    model = DummyModel()
    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # fallback tokenizer
    print("Using dummy model due to loading failure.")
# NOTE(review): the original comment here says GPU usage is designated for
# Hugging Face Spaces — on a ZeroGPU Space this function would normally be
# decorated with @spaces.GPU; confirm against the deployment target.
def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """
    Generate a chat response using streaming with TextIteratorStreamer.

    Args:
        message: User's current message.
        history: List of (user, assistant) tuples from previous turns.
        system_message: Initial system prompt guiding the assistant.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus sampling probability.

    Yields:
        The growing response text as new tokens are generated.
    """
    # Guard against streaming errors when the dummy fallback model is in use.
    # BUGFIX: compare by class name instead of `isinstance(model, DummyModel)` —
    # DummyModel is only defined when model loading failed, so the isinstance
    # form raises NameError on the successful-load path.
    if type(model).__name__ == "DummyModel":
        yield "Model loading failed; no response can be generated."
        return

    # Build the chat transcript: system prompt, prior turns, current message.
    messages = [{"role": "system", "content": system_message}]
    for user_msg, bot_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if bot_msg:
            messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})

    try:
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            # Extra BitNet-specific arguments may be needed — check the
            # model card (e.g. quantize_config).
        )

        # Run generation in a background thread so the streamer can be
        # consumed on this thread as tokens arrive.
        thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()

        # Read text from the streamer and yield the accumulated response.
        response = ""
        for new_text in streamer:
            response += new_text
            yield response
    except Exception as e:
        print(f"Error during response generation: {e}")
        yield f"An error occurred while generating the response: {e}"
# --- CSS for design improvements ---
css_styles = """
/* Page background and base font */
body {
    font-family: 'Segoe UI', 'Roboto', 'Arial', sans-serif;
    line-height: 1.6;
    margin: 0;
    padding: 20px; /* breathing room around the app */
    background-color: #f4f7f6; /* soft background color */
}
/* Main app container */
.gradio-container {
    max-width: 900px; /* center the app and cap its width */
    margin: 20px auto;
    border-radius: 12px; /* rounded corners */
    overflow: hidden; /* keep children inside the rounded corners */
    box-shadow: 0 8px 16px rgba(0, 0, 0, 0.1); /* drop shadow */
    background-color: #ffffff; /* content area background */
}
/* Title/description area (ChatInterface default header).
   The exact class names depend on the ChatInterface DOM structure, so we
   target the first block inside .gradio-container; combined with the theme
   this covers most cases — only extra padding is adjusted here. */
.gradio-container > .gradio-block:first-child {
    padding: 20px 20px 10px 20px; /* tune the top padding */
}
/* Chatbox area */
.gradio-chatbox {
    /* mostly styled by the theme; extra inner padding etc. can go here */
    padding: 15px;
    background-color: #fefefe; /* chat area background */
    border-radius: 8px; /* inner corner rounding */
    border: 1px solid #e0e0e0; /* border color */
}
/* Chat message bubbles */
.gradio-chatmessage {
    margin-bottom: 12px;
    padding: 10px 15px;
    border-radius: 20px; /* rounded bubbles */
    max-width: 75%; /* cap bubble width */
    word-wrap: break-word; /* wrap long words */
    white-space: pre-wrap; /* preserve whitespace and newlines */
    box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05); /* subtle bubble shadow */
}
/* User messages */
.gradio-chatmessage.user {
    background-color: #007bff; /* blue */
    color: white;
    margin-left: auto; /* align right */
    border-bottom-right-radius: 2px; /* sharpen bottom-right corner */
}
/* Bot messages */
.gradio-chatmessage.bot {
    background-color: #e9ecef; /* light gray */
    color: #333; /* dark text */
    margin-right: auto; /* align left */
    border-bottom-left-radius: 2px; /* sharpen bottom-left corner */
}
/* Input row and buttons area */
.gradio-input-box {
    padding: 15px;
    border-top: 1px solid #eee; /* separator above */
    background-color: #f8f9fa; /* input area background */
}
/* Input textarea */
.gradio-input-box textarea {
    border-radius: 8px;
    padding: 10px;
    border: 1px solid #ccc;
    resize: none !important; /* disable manual resize (optional) */
    min-height: 50px; /* minimum height */
    max-height: 150px; /* maximum height */
    overflow-y: auto; /* scroll when content overflows */
}
/* Scrollbar styling (optional) */
.gradio-input-box textarea::-webkit-scrollbar {
    width: 8px;
}
.gradio-input-box textarea::-webkit-scrollbar-thumb {
    background-color: #ccc;
    border-radius: 4px;
}
.gradio-input-box textarea::-webkit-scrollbar-track {
    background-color: #f1f1f1;
}
/* Buttons */
.gradio-button {
    border-radius: 8px;
    padding: 10px 20px;
    font-weight: bold;
    transition: background-color 0.2s ease, opacity 0.2s ease; /* hover animation */
    border: none; /* drop the default border */
    cursor: pointer;
}
.gradio-button:not(.clear-button) { /* Send button */
    background-color: #28a745; /* green */
    color: white;
}
.gradio-button:not(.clear-button):hover {
    background-color: #218838;
}
.gradio-button:disabled { /* disabled buttons */
    opacity: 0.6;
    cursor: not-allowed;
}
.gradio-button.clear-button { /* Clear button */
    background-color: #dc3545; /* red */
    color: white;
}
.gradio-button.clear-button:hover {
    background-color: #c82333;
}
/* Additional-inputs (extra settings) area.
   Usually rendered as an accordion with the .gradio-accordion class. */
.gradio-accordion {
    border-radius: 12px; /* match the outer container */
    margin-top: 15px; /* gap below the chat area */
    border: 1px solid #ddd; /* border color */
    box-shadow: none; /* no inner shadow */
}
/* Accordion header (label) */
.gradio-accordion .label {
    font-weight: bold;
    color: #007bff; /* blue */
    padding: 15px; /* header padding */
    background-color: #e9ecef; /* header background */
    border-bottom: 1px solid #ddd; /* separator below the header */
    border-top-left-radius: 11px; /* top corners */
    border-top-right-radius: 11px;
}
/* Accordion content */
.gradio-accordion .wrap {
    padding: 15px; /* content padding */
    background-color: #fefefe; /* content background */
    border-bottom-left-radius: 11px; /* bottom corners */
    border-bottom-right-radius: 11px;
}
/* Individual inputs inside the extra settings (sliders, textboxes, ...) */
.gradio-slider, .gradio-textbox, .gradio-number {
    margin-bottom: 10px; /* spacing below each input */
    padding: 8px; /* inner padding */
    border: 1px solid #e0e0e0; /* border color */
    border-radius: 8px; /* rounded corners */
    background-color: #fff; /* background */
}
/* Input labels */
.gradio-label {
    font-weight: normal; /* label font weight */
    margin-bottom: 5px; /* gap between label and input */
    color: #555; /* label color */
    display: block; /* render the label on its own line */
}
/* Slider track and thumb could be styled in more detail, e.g.
   .gradio-slider input[type="range"]::-webkit-slider-thumb {} */
/* Markdown/HTML components */
.gradio-markdown, .gradio-html {
    padding: 10px 0; /* vertical padding */
}
"""
# --- end of CSS for design improvements ---
# Gradio interface configuration
demo = gr.ChatInterface(
    fn=respond,
    # HTML tags are allowed in title/description (e.g. <br>).
    title="<h1 style='text-align: center; color: #007bff;'>Bitnet-b1.58-2B-4T Chatbot</h1>",
    description="<p style='text-align: center; color: #555;'>This chat application is powered by Microsoft's SOTA Bitnet-b1.58-2B-4T and designed for natural and fast conversations.</p>",
    # Each example is [message, system_message, max_tokens, temperature, top_p],
    # matching `respond`'s signature after `history`.
    examples=[
        [
            "Hello! How are you?",
            "You are a helpful AI assistant for everyday tasks.",
            512,
            0.7,
            0.95,
        ],
        [
            "Can you code a snake game in Python?",
            "You are a helpful AI assistant for coding.",
            2048,
            0.7,
            0.95,
        ],
    ],
    additional_inputs=[
        gr.Textbox(
            value="You are a helpful AI assistant.",
            label="System message",
            lines=3,  # height of the system-message input
        ),
        gr.Slider(
            minimum=1,
            maximum=8192,
            value=2048,
            step=1,
            label="Max new tokens",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.7,
            step=0.1,
            label="Temperature",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    # Theme (alternatives: gr.themes.Soft(), gr.themes.Glass(), gr.themes.Default(), ...)
    theme=gr.themes.Soft(),
    # Apply the custom CSS defined above.
    css=css_styles,
)
# Run the application.
if __name__ == "__main__":
    # launch(share=True) would create a public URL (debugging/sharing; keep commented).
    demo.launch()
    # demo.launch(debug=True)  # enable debug mode