# Command that installs the required transformers branch.
# This part runs once, at the start of script execution.
import os

print("Installing required transformers branch...")
os.system("pip install git+https://github.com/shumingma/transformers.git")
print("Installation complete.")

# Import the required libraries.
import threading

import torch
import torch._dynamo
import gradio as gr
import spaces  # Hugging Face Spaces utilities (provides the GPU decorator)

# torch._dynamo setting (optional): suppress compilation errors and fall back
# to eager execution instead of raising.
torch._dynamo.config.suppress_errors = True

from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
)

# --- Model loading ---

# Model path (a Hugging Face model ID)
model_id = "microsoft/bitnet-b1.58-2B-4T"

# Lower the logging level to minimize warnings while the model loads.
os.environ["TRANSFORMERS_VERBOSITY"] = "error"

# Load AutoModelForCausalLM and AutoTokenizer.
# trust_remote_code=True is required; device_map="auto" places the model on the
# available device automatically.
try:
    print(f"Loading model: {model_id}...")
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,  # use bf16 (GPU recommended)
        device_map="auto",           # place the model on the available device automatically
        trust_remote_code=True,
    )
    print(f"Model device: {model.device}")
    print("Model load complete.")
except Exception as e:
    print(f"Error while loading model: {e}")
    tokenizer = None
    model = None
    print("Model loading failed. The application may not work properly.")

# --- Text generation function (for gr.ChatInterface) ---
# Mark this function as using GPU resources on Hugging Face Spaces.
@spaces.GPU
def respond(
    message: str,
    history: list[tuple[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    if model is None or tokenizer is None:
        yield "Cannot generate text because the model failed to load."
        return  # a bare return ends the generator
    try:
        # Build the message list to match the model's chat template.
        messages = [{"role": "system", "content": system_message}]
        for user_msg, bot_msg in history:
            if user_msg:
                messages.append({"role": "user", "content": user_msg})
            if bot_msg:
                messages.append({"role": "assistant", "content": bot_msg})
        messages.append({"role": "user", "content": message})

        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Set up a streamer for incremental text output.
        streamer = TextIteratorStreamer(
            tokenizer, skip_prompt=True, skip_special_tokens=True
        )
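        # TextIteratorStreamer acts as a queue: model.generate() (run in the
        # thread below) pushes decoded text fragments into it, and iterating
        # over it blocks until the next fragment arrives or generation ends.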
        generate_kwargs = dict(
            **inputs,
            streamer=streamer,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,  # sampling must be enabled for temperature/top_p to apply
            pad_token_id=tokenizer.eos_token_id,  # use EOS as the padding token
        )

        # Run model generation in a separate thread.
        thread = threading.Thread(target=model.generate, kwargs=generate_kwargs)
        thread.start()

        # Read text from the streamer as it is produced and yield it.
        response = ""
        for new_text in streamer:
            response += new_text
            yield response  # stream the partial response to the Gradio interface
        thread.join()  # ensure the generation thread has finished
    except Exception as e:
        yield f"Error during text generation: {e}"
        # Optionally, add cleanup logic here for the generation thread on error.
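
# A quick local check of the generator, kept commented out; the argument
# values below are examples only:
# for partial in respond("Hello!", [], "You are a helpful assistant.", 64, 0.7, 0.95):
#     print(partial)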

# --- Gradio interface setup ---
if model is not None and tokenizer is not None:
    demo = gr.ChatInterface(
        fn=respond,
        title="Bitnet-b1.58-2B-4T Chatbot",
        description="A chat demo of the Microsoft Bitnet-b1.58-2B-4T model.",
        examples=[
            [
                "Hello! Please introduce yourself.",
                "You are a capable AI assistant.",  # example system message
                512,   # example max new tokens
                0.7,   # example temperature
                0.95,  # example top-p
            ],
            [
                "Show me code for a simple web server in Python.",
                "You are a capable AI developer.",  # example system message
                1024,  # example max new tokens
                0.8,   # example temperature
                0.9,   # example top-p
            ],
        ],
        additional_inputs=[
            gr.Textbox(
                value="You are a capable AI assistant.",  # default system message
                label="System message",
                lines=1,
            ),
            gr.Slider(
                minimum=1,
                maximum=4096,  # keep within the model's maximum context length
                value=512,
                step=1,
                label="Max new tokens",
            ),
            gr.Slider(
                minimum=0.1,
                maximum=2.0,  # adjust the temperature range if needed
                value=0.7,
                step=0.1,
                label="Temperature",
            ),
            gr.Slider(
                minimum=0.0,  # adjust the top-p range if needed
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p (nucleus sampling)",
            ),
        ],
    )
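
    # Each row in `examples` supplies [message, system_message, max_tokens,
    # temperature, top_p]: the chat message first, then one value for each
    # entry in `additional_inputs`, in order.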

    # Launch the Gradio app.
    # On Hugging Face Spaces the app is served automatically; share=True is not needed.
    # debug=True prints more detailed logs.
    demo.launch(debug=True)
else:
    print("Cannot launch the Gradio interface because model loading failed.")