# Hugging Face Space: running on ZeroGPU hardware
import re
import threading
import gc
import os
import torch
import gradio as gr
import spaces
import transformers
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import login
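
# The `spaces` package enables Hugging Face ZeroGPU support (its spaces.GPU
# decorator is the usual way to mark GPU-bound functions on ZeroGPU hardware).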

# Settings for model memory management and optimization
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
MAX_GPU_MEMORY = 80 * 1024 * 1024 * 1024  # based on an 80GB A100 (usable memory is less than this)

# List of available models, filtered to those that run efficiently on an A100
available_models = {
    "meta-llama/Llama-3.2-3B-Instruct": "Llama 3.2 (3B)",
    "NousResearch/Hermes-3-Llama-3.1-8B": "Hermes 3 Llama 3.1 (8B)",
    "nvidia/Llama-3.1-Nemotron-Nano-8B-v1": "Nvidia Nemotron Nano (8B)",
    "mistralai/Mistral-Small-3.1-24B-Instruct-2503": "Mistral Small 3.1 (24B)",
    "google/gemma-3-27b-it": "Google Gemma 3 (27B)",
    "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen 2.5 Coder (32B)",
    "open-r1/OlympicCoder-32B": "Olympic Coder (32B)"
}
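
# The keys are Hub repo ids and the values are display names shown in the UI;
# get_model_names() further below inverts this mapping to turn the radio
# selection back into a repo id before loading.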

# Global variables used for model loading
pipe = None
current_model_name = None

# Try to log in with the Hugging Face token
try:
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        login(token=hf_token)
        print("Successfully logged in to Hugging Face.")
    else:
        print("Warning: the HF_TOKEN environment variable is not set.")
except Exception as e:
    print(f"Hugging Face login error: {str(e)}")

# Marker used to detect the final answer
ANSWER_MARKER = "**Answer**"

# Sentences that start each step of the reasoning
rethink_prepends = [
    "Okay, now I need to figure out the following: ",
    "I think ",
    "Wait a moment, I think ",
    "Let me check whether the following is correct: ",
    "I should also remember that ",
    "Another point worth noting is ",
    "And I also remember the following fact: ",
    "Now I think I understand the problem well enough. ",
    "Based on the information so far, I will answer the original question in the language it was asked in:"
    "\n{question}\n"
    f"\n{ANSWER_MARKER}\n",
]
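
# Each prepend is appended to the running message in bot() to force one
# reasoning step per generation call; the last entry re-injects the original
# question together with ANSWER_MARKER, which is how the final-answer step is
# detected and given the larger final_num_tokens budget.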

# Settings to fix formula rendering issues
latex_delimiters = [
    {"left": "$$", "right": "$$", "display": True},
    {"left": "$", "right": "$", "display": False},
]

# Size-based model configuration: optimal settings per model size category
MODEL_CONFIG = {
    "small": {  # <10B
        "max_memory": {0: "20GiB"},
        "offload": False,
        "quantization": None
    },
    "medium": {  # 10B-30B
        "max_memory": {0: "40GiB"},
        "offload": False,
        "quantization": "4bit"
    },
    "large": {  # >30B
        "max_memory": {0: "70GiB"},
        "offload": True,
        "quantization": "4bit"
    }
}
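
# Note: these per-GPU memory budgets assume a single 80GB A100 (see
# MAX_GPU_MEMORY above); adjust max_memory, offload and quantization when
# running on different hardware.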

def get_model_size_category(model_name):
    """Determine the model size category from the model name."""
    name = model_name.upper()
    if "3B" in name or "8B" in name:
        return "small"
    elif "24B" in name or "27B" in name:
        return "medium"
    elif "32B" in name or "70B" in name:
        return "large"
    else:
        # Fall back to medium by default
        return "medium"

def clear_gpu_memory():
    """Release the current model and clear GPU memory."""
    global pipe

    if pipe is not None:
        del pipe
        pipe = None

    # Clear the CUDA cache
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

def reformat_math(text):
    """Rewrite MathJax delimiters to the Gradio (KaTeX) syntax."""
    text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
    text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
    return text
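
# Note: reformat_math() turns r"\[ ... \]" into "$$ ... $$" and r"\( ... \)"
# into "$ ... $", matching the latex_delimiters configured on the Chatbot.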

def user_input(message, history: list):
    """Append the user input to the history and clear the input textbox."""
    return "", history + [
        gr.ChatMessage(role="user", content=message.replace(ANSWER_MARKER, ""))
    ]

def rebuild_messages(history: list):
    """Rebuild the messages the model will use from the history, without the earlier intermediate thinking steps."""
    messages = []
    for h in history:
        if isinstance(h, dict) and not h.get("metadata", {}).get("title", False):
            messages.append(h)
        elif (
            isinstance(h, gr.ChatMessage)
            and h.metadata.get("title")
            and isinstance(h.content, str)
        ):
            messages.append({"role": h.role, "content": h.content})
    return messages
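
# Completed turns typically come back from the Chatbot component as plain dicts,
# and any dict whose metadata carries a "title" (a collapsed thinking bubble
# from an earlier turn) is dropped; the gr.ChatMessage placeholder appended in
# bot() for the current turn is converted to a plain dict and kept, so the
# reasoning prepends attach to that in-progress assistant turn.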

def load_model(model_names):
    """Load the model selected by name, using settings optimized for an A100."""
    global pipe, current_model_name

    # Clean up any previously loaded model
    clear_gpu_memory()

    # Fall back to a default if no model was selected
    if not model_names:
        model_name = "meta-llama/Llama-3.2-3B-Instruct"  # use a smaller model as the default
    else:
        # Use the first selected model
        model_name = model_names[0]

    # Determine the model size category
    size_category = get_model_size_category(model_name)
    config = MODEL_CONFIG[size_category]

    # Load the model with settings optimized for its size
    try:
        # Check the HF_TOKEN environment variable
        hf_token = os.getenv("HF_TOKEN")

        # Parameters shared by all loading paths
        common_params = {
            "token": hf_token,  # token for gated models
            "trust_remote_code": True,
        }

        # Use BF16 precision (optimized for the A100)
        if config["quantization"]:
            # Load with quantization
            from transformers import BitsAndBytesConfig

            quantization_config = BitsAndBytesConfig(
                load_in_4bit=config["quantization"] == "4bit",
                bnb_4bit_compute_dtype=DTYPE
            )

            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map="auto",
                max_memory=config["max_memory"],
                torch_dtype=DTYPE,
                quantization_config=quantization_config,
                offload_folder="offload" if config["offload"] else None,
                **common_params
            )
            tokenizer = AutoTokenizer.from_pretrained(model_name, **common_params)

            pipe = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                torch_dtype=DTYPE,
                device_map="auto"
            )
        else:
            # Load without quantization
            pipe = pipeline(
                "text-generation",
                model=model_name,
                device_map="auto",
                torch_dtype=DTYPE,
                **common_params
            )

        current_model_name = model_name
        return f"Model '{model_name}' loaded successfully (optimization: {size_category} category)."

    except Exception as e:
        return f"Failed to load model: {str(e)}"

def bot(
    history: list,
    max_num_tokens: int,
    final_num_tokens: int,
    do_sample: bool,
    temperature: float,
):
    """Make the model answer the question, one reasoning step at a time."""
    global pipe

    # Show an error message if no model has been loaded yet
    if pipe is None:
        history.append(
            gr.ChatMessage(
                role="assistant",
                content="No model is loaded. Please select at least one model.",
            )
        )
        yield history
        return

    # Automatically adjust the token budgets based on the model size
    size_category = get_model_size_category(current_model_name)

    # For large models, reduce the token counts to improve memory efficiency
    if size_category == "large":
        max_num_tokens = min(max_num_tokens, 1000)
        final_num_tokens = min(final_num_tokens, 1500)

    # Used to stream tokens out of the generation thread later on
    streamer = transformers.TextIteratorStreamer(
        pipe.tokenizer,
        skip_special_tokens=True,
        skip_prompt=True,
    )
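
    # pipe() runs in a worker thread (started below) and pushes generated text
    # into this streamer, while the loop in this function consumes it so that
    # partial output can be yielded to the Chatbot as it is produced.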

    # Keep the question so it can be re-inserted into the reasoning when needed
    question = history[-1]["content"]

    # Prepare the assistant message
    history.append(
        gr.ChatMessage(
            role="assistant",
            content="",
            metadata={"title": "🧠 Thinking...", "status": "pending"},
        )
    )

    # The reasoning process that will be shown in the current chat
    messages = rebuild_messages(history)

    try:
        for i, prepend in enumerate(rethink_prepends):
            if i > 0:
                messages[-1]["content"] += "\n\n"
            messages[-1]["content"] += prepend.format(question=question)

            num_tokens = int(
                max_num_tokens if ANSWER_MARKER not in prepend else final_num_tokens
            )

            # Run the model in a separate thread
            t = threading.Thread(
                target=pipe,
                args=(messages,),
                kwargs=dict(
                    max_new_tokens=num_tokens,
                    streamer=streamer,
                    do_sample=do_sample,
                    temperature=temperature,
                    # Extra parameters for memory efficiency
                    repetition_penalty=1.2,  # discourage repetition
                    use_cache=True,          # use the KV cache
                ),
            )
            t.start()

            # Rebuild the history with the new content
            history[-1].content += prepend.format(question=question)
            if ANSWER_MARKER in prepend:
                history[-1].metadata = {"title": "💭 Thought process", "status": "done"}
                # Thinking is done; what follows is the answer (no metadata for intermediate steps)
                history.append(gr.ChatMessage(role="assistant", content=""))

            # Stream the generated tokens
            for token in streamer:
                history[-1].content += token
                history[-1].content = reformat_math(history[-1].content)
                yield history
            t.join()

            # For large models, partially free memory after each step
            if size_category == "large" and torch.cuda.is_available():
                torch.cuda.empty_cache()

    except Exception as e:
        # Notify the user if an error occurs
        if len(history) > 0 and history[-1].role == "assistant":
            history[-1].content += f"\n\n⚠️ An error occurred during processing: {str(e)}"
            yield history

    yield history

# Helper that reports the available GPUs
def get_gpu_info():
    if not torch.cuda.is_available():
        return "No GPU available."

    gpu_info = []
    for i in range(torch.cuda.device_count()):
        gpu_name = torch.cuda.get_device_name(i)
        total_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
        gpu_info.append(f"GPU {i}: {gpu_name} ({total_memory:.1f} GB)")

    return "\n".join(gpu_info)

# Gradio interface
with gr.Blocks(fill_height=True, title="ThinkFlow - Step-by-step Reasoning Service") as demo:
    # Title and description at the top
    gr.Markdown("""
    # ThinkFlow
    ## A thought amplification service that implants step-by-step reasoning abilities into LLMs without model modification
    """)

    with gr.Row(scale=1):
        with gr.Column(scale=5):
            # Chat interface
            chatbot = gr.Chatbot(
                scale=1,
                type="messages",
                latex_delimiters=latex_delimiters,
                height=600,
            )
            msg = gr.Textbox(
                submit_btn=True,
                label="",
                show_label=False,
                placeholder="Type your question here.",
                autofocus=True,
            )
        with gr.Column(scale=1):
            # Display hardware information
            gpu_info = gr.Markdown(f"**Available hardware:**\n{get_gpu_info()}")

            # Model selection section
            gr.Markdown("""## Model Selection""")
            model_selector = gr.Radio(
                choices=list(available_models.values()),
                value=available_models["meta-llama/Llama-3.2-3B-Instruct"],  # small model as the default
                label="Select the LLM model to use",
            )

            # Model load button
            load_model_btn = gr.Button("Load Model", variant="primary")
            model_status = gr.Textbox(label="Model status", interactive=False)

            # GPU memory cleanup button
            clear_memory_btn = gr.Button("Clear GPU Memory", variant="secondary")

            gr.Markdown("""## Parameter Settings""")
            with gr.Accordion("Advanced settings", open=False):
                num_tokens = gr.Slider(
                    50,
                    2000,
                    1000,  # reduced default
                    step=50,
                    label="Max tokens per reasoning step",
                    interactive=True,
                )
                final_num_tokens = gr.Slider(
                    50,
                    3000,
                    1500,  # reduced default
                    step=50,
                    label="Max tokens for the final answer",
                    interactive=True,
                )
                do_sample = gr.Checkbox(True, label="Use sampling")
                temperature = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="Temperature")

    # Wire up loading of the selected model
    def get_model_names(selected_model):
        # Map the display name back to the original model name
        inverse_map = {v: k for k, v in available_models.items()}
        return [inverse_map[selected_model]] if selected_model else []

    load_model_btn.click(
        lambda selected: load_model(get_model_names(selected)),
        inputs=[model_selector],
        outputs=[model_status]
    )

    # Wire up GPU memory cleanup
    def clear_memory_and_report():
        clear_gpu_memory()
        return "GPU memory has been cleared."

    clear_memory_btn.click(
        clear_memory_and_report,
        inputs=[],
        outputs=[model_status]
    )

    # When the user submits a message, the bot responds
    msg.submit(
        user_input,
        [msg, chatbot],  # inputs
        [msg, chatbot],  # outputs
    ).then(
        bot,
        [
            chatbot,
            num_tokens,
            final_num_tokens,
            do_sample,
            temperature,
        ],  # in effect, the "history" input
        chatbot,  # save the new history in the output
    )

if __name__ == "__main__":
    # Print debugging information
    print(f"GPU available: {torch.cuda.is_available()}")
    if torch.cuda.is_available():
        print(f"Number of available GPUs: {torch.cuda.device_count()}")
        print(f"Current GPU: {torch.cuda.current_device()}")
        print(f"GPU name: {torch.cuda.get_device_name(0)}")

    # Check the HF_TOKEN environment variable
    hf_token = os.getenv("HF_TOKEN")
    if hf_token:
        print("The HF_TOKEN environment variable is set.")
    else:
        print("Warning: the HF_TOKEN environment variable is not set. Gated models may not be accessible.")

    # Enable the queue and launch the app
    demo.queue(max_size=10).launch()
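
# A rough local-run note (assumptions: this file is saved as app.py, a CUDA GPU
# with enough memory is available, and HF_TOKEN is exported for gated models):
#   HF_TOKEN=hf_... python app.py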