# bitnet / app.py
# Install the required transformers branch. This runs once, at script startup.
import os

print("Installing required transformers branch...")
os.system("pip install git+https://github.com/shumingma/transformers.git")
print("Installation complete.")
# ν•„μš”ν•œ λΌμ΄λΈŒλŸ¬λ¦¬λ“€μ„ import ν•©λ‹ˆλ‹€.
import threading
import torch
import torch._dynamo
import gradio as gr
import spaces # Hugging Face Spaces κ΄€λ ¨ μœ ν‹Έλ¦¬ν‹°
# torch._dynamo μ„€μ • (선택 사항, μ„±λŠ₯ ν–₯상 μ‹œλ„)
torch._dynamo.config.suppress_errors = True
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
TextIteratorStreamer,
)
# --- λͺ¨λΈ λ‘œλ“œ ---
# λͺ¨λΈ 경둜 μ„€μ • (Hugging Face λͺ¨λΈ ID)
model_id = "microsoft/bitnet-b1.58-2B-4T"
# λͺ¨λΈ λ‘œλ“œ μ‹œ κ²½κ³  λ©”μ‹œμ§€λ₯Ό μ΅œμ†Œν™”ν•˜κΈ° μœ„ν•΄ λ‘œκΉ… 레벨 μ„€μ •
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
# AutoModelForCausalLMκ³Ό AutoTokenizerλ₯Ό λ‘œλ“œν•©λ‹ˆλ‹€.
# trust_remote_code=Trueκ°€ ν•„μš”ν•˜λ©°, device_map="auto"λ₯Ό μ‚¬μš©ν•˜μ—¬ μžλ™μœΌλ‘œ λ””λ°”μ΄μŠ€ μ„€μ •
try:
    print(f"Loading model: {model_id}...")
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,  # use bf16 (recommended on GPU)
        device_map="auto",  # place the model on available devices automatically
        trust_remote_code=True,
    )
    print(f"Model device: {model.device}")
    print("Model load complete.")
except Exception as e:
    print(f"Error while loading model: {e}")
    tokenizer = None
    model = None
    print("Model load failed. The application may not work correctly.")
# --- ν…μŠ€νŠΈ 생성 ν•¨μˆ˜ (Gradio ChatInterface용) ---
@spaces.GPU # 이 ν•¨μˆ˜κ°€ GPU μžμ›μ„ μ‚¬μš©ν•˜λ„λ‘ λͺ…μ‹œ (Hugging Face Spaces)
def respond(
message: str,
history: list[tuple[str, str]],
system_message: str,
max_tokens: int,
temperature: float,
top_p: float,
):
if model is None or tokenizer is None:
yield "λͺ¨λΈ λ‘œλ“œμ— μ‹€νŒ¨ν•˜μ—¬ ν…μŠ€νŠΈ 생성을 ν•  수 μ—†μŠ΅λ‹ˆλ‹€."
return # 생성기 ν•¨μˆ˜μ΄λ―€λ‘œ return λŒ€μ‹  빈 yield λ˜λŠ” κ·Έλƒ₯ return
    try:
        # Build the message list in the model's chat-template format.
        messages = [{"role": "system", "content": system_message}]
        for user_msg, bot_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if bot_msg:
                messages.append({"role": "assistant", "content": bot_msg})
        messages.append({"role": "user", "content": message})

        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
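        # Note: apply_chat_template renders the message dicts into a single
        # prompt string using the template bundled with the tokenizer; the
        # exact markup is model-specific. Printing `prompt` here is a simple
        # way to inspect what the model actually receives.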
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Streamer for token-by-token text streaming.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,  # use EOS as the padding token
        )

        # Run generation in a separate thread so this thread can consume the streamer.
        thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()

        # Read text from the streamer as it is produced and yield the growing response.
        response = ""
        for new_text in streamer:
            response += new_text
            yield response  # stream the partial response to the Gradio interface
        thread.join()  # make sure the generation thread has finished
    except Exception as e:
        yield f"Error during text generation: {e}"
        # (Optional) additional thread cleanup on error could be added here.
# --- Gradio μΈν„°νŽ˜μ΄μŠ€ μ„€μ • ---
if model is not None and tokenizer is not None:
demo = gr.ChatInterface(
fn=respond,
title="Bitnet-b1.58-2B-4T Chatbot",
description="Microsoft Bitnet-b1.58-2B-4T λͺ¨λΈμ„ μ‚¬μš©ν•œ μ±„νŒ… 데λͺ¨μž…λ‹ˆλ‹€.",
examples=[
[
"μ•ˆλ…•ν•˜μ„Έμš”! μžκΈ°μ†Œκ°œ ν•΄μ£Όμ„Έμš”.",
"당신은 유λŠ₯ν•œ AI λΉ„μ„œμž…λ‹ˆλ‹€.", # System message μ˜ˆμ‹œ
512, # Max new tokens μ˜ˆμ‹œ
0.7, # Temperature μ˜ˆμ‹œ
0.95, # Top-p μ˜ˆμ‹œ
],
[
"파이썬으둜 κ°„λ‹¨ν•œ μ›Ή μ„œλ²„ λ§Œλ“œλŠ” μ½”λ“œ μ•Œλ €μ€˜",
"당신은 유λŠ₯ν•œ AI κ°œλ°œμžμž…λ‹ˆλ‹€.", # System message μ˜ˆμ‹œ
1024, # Max new tokens μ˜ˆμ‹œ
0.8, # Temperature μ˜ˆμ‹œ
0.9, # Top-p μ˜ˆμ‹œ
],
],
        additional_inputs=[
            gr.Textbox(
                value="You are a capable AI assistant.",  # default system message
                label="System message",
                lines=1,
            ),
            gr.Slider(
                minimum=1,
                maximum=4096,  # consider the model's maximum context length
                value=512,
                step=1,
                label="Max new tokens",
            ),
            gr.Slider(
                minimum=0.1,
                maximum=2.0,  # adjust the temperature range if needed
                value=0.7,
                step=0.1,
                label="Temperature",
            ),
            gr.Slider(
                minimum=0.0,  # adjust the top-p range if needed
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)",
            ),
        ],
    )
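    # Note: Gradio passes the `additional_inputs` values to `respond` as extra
    # positional arguments after (message, history), in the order listed above,
    # so this list must stay in sync with the function signature.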
    # Launch the Gradio app.
    # On Hugging Face Spaces the app is served publicly by the platform,
    # so share=True is unnecessary. Set debug=True for detailed logs.
    demo.launch(debug=True)
else:
    print("Cannot launch the Gradio interface because the model failed to load.")