openfree committed (verified)
Commit 03434f6 · 1 parent: e94718a

Update app.py

Files changed (1)
app.py +230 -80
app.py CHANGED
@@ -1,26 +1,32 @@
 import re
 import threading
+import gc
+import torch
 
 import gradio as gr
 import spaces
 import transformers
-from transformers import pipeline
+from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
 
-# List of available models
+# Settings for model memory management and optimization
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
+MAX_GPU_MEMORY = 80 * 1024 * 1024 * 1024  # based on an 80GB A100 (actual usable memory is less)
+
+# List of available models, filtered to those that run efficiently on an A100
 available_models = {
-    "meta-llama/Llama-3.2-3B-Instruct": "Llama 3.2(3B)",
+    "meta-llama/Llama-3.2-3B-Instruct": "Llama 3.2 (3B)",
     "Hermes-3-Llama-3.1-8B": "Hermes 3 Llama 3.1 (8B)",
     "nvidia/Llama-3.1-Nemotron-Nano-8B-v1": "Nvidia Nemotron Nano (8B)",
     "mistralai/Mistral-Small-3.1-24B-Instruct-2503": "Mistral Small 3.1 (24B)",
-    "bartowski/mistralai_Mistral-Small-3.1-24B-Instruct-2503-GGUF": "Mistral Small GGUF (24B)",
     "google/gemma-3-27b-it": "Google Gemma 3 (27B)",
-    "gemma-3-27b-it-abliterated": "Gemma 3 Abliterated (27B)",
     "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen 2.5 Coder (32B)",
     "open-r1/OlympicCoder-32B": "Olympic Coder (32B)"
 }
 
-# Global variables for loading the model and tokenizer
+# Global variables used for model loading
 pipe = None
+current_model_name = None
 
 # Marker used to detect the final answer
 ANSWER_MARKER = "**Answer**"
@@ -40,31 +46,69 @@ rethink_prepends = [
     f"\n{ANSWER_MARKER}\n",
 ]
 
-
 # Settings to work around math rendering issues
 latex_delimiters = [
     {"left": "$$", "right": "$$", "display": True},
     {"left": "$", "right": "$", "display": False},
 ]
 
+# Size-based model configuration: optimal settings for each model size
+MODEL_CONFIG = {
+    "small": {   # <10B
+        "max_memory": {0: "20GiB"},
+        "offload": False,
+        "quantization": None
+    },
+    "medium": {  # 10B-30B
+        "max_memory": {0: "40GiB"},
+        "offload": False,
+        "quantization": "4bit"
+    },
+    "large": {   # >30B
+        "max_memory": {0: "70GiB"},
+        "offload": True,
+        "quantization": "4bit"
+    }
+}
+
+def get_model_size_category(model_name):
+    """Determine the model size category."""
+    if "3B" in model_name or "8B" in model_name:
+        return "small"
+    elif "24B" in model_name or "27B" in model_name:
+        return "medium"
+    elif "32B" in model_name or "70B" in model_name:
+        return "large"
+    else:
+        # Fall back to medium by default
+        return "medium"
+
+def clear_gpu_memory():
+    """Free GPU memory."""
+    global pipe
+
+    if pipe is not None:
+        del pipe
+        pipe = None
+
+    # Clear the CUDA cache
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
 
 def reformat_math(text):
-    """Fix MathJax delimiters to use Gradio (KaTeX) syntax.
-    This is a temporary workaround for displaying math formulas in Gradio. For now, I
-    have not found a way to make it work as expected with other latex_delimiters...
-    """
+    """Fix MathJax delimiters to use Gradio (KaTeX) syntax."""
     text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
     text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
     return text
 
-
 def user_input(message, history: list):
     """Append the user input to the history and clear the input textbox."""
     return "", history + [
         gr.ChatMessage(role="user", content=message.replace(ANSWER_MARKER, ""))
     ]
 
-
 def rebuild_messages(history: list):
     """Rebuild the messages for the model from the history, without the intermediate thinking steps."""
     messages = []
@@ -79,27 +123,68 @@ def rebuild_messages(history: list):
             messages.append({"role": h.role, "content": h.content})
     return messages
 
-
 def load_model(model_names):
-    """Load the model according to the selected model name."""
-    global pipe
+    """Load the model according to the selected model name (settings optimized for an A100)."""
+    global pipe, current_model_name
+
+    # Clean up any previously loaded model
+    clear_gpu_memory()
 
     # Use a default when no model is selected
    if not model_names:
-        model_name = "Qwen/Qwen2-1.5B-Instruct"
+        model_name = "meta-llama/Llama-3.2-3B-Instruct"  # use a smaller model as the default
     else:
-        # Use the first selected model (could later be extended to an ensemble of several models)
+        # Use the first selected model
         model_name = model_names[0]
 
-    pipe = pipeline(
-        "text-generation",
-        model=model_name,
-        device_map="auto",
-        torch_dtype="auto",
-    )
+    # Check the model size category
+    size_category = get_model_size_category(model_name)
+    config = MODEL_CONFIG[size_category]
 
-    return f"Model '{model_name}' has been loaded."
-
+    # Load the model (apply settings optimized for its size)
+    try:
+        # Use BF16 precision (optimized for the A100)
+        if config["quantization"]:
+            # Apply quantization
+            from transformers import BitsAndBytesConfig
+            quantization_config = BitsAndBytesConfig(
+                load_in_4bit=config["quantization"] == "4bit",
+                bnb_4bit_compute_dtype=DTYPE
+            )
+
+            model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                device_map="auto",
+                max_memory=config["max_memory"],
+                torch_dtype=DTYPE,
+                quantization_config=quantization_config if config["quantization"] else None,
+                offload_folder="offload" if config["offload"] else None,
+                trust_remote_code=True
+            )
+            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+            pipe = pipeline(
+                "text-generation",
+                model=model,
+                tokenizer=tokenizer,
+                torch_dtype=DTYPE,
+                device_map="auto"
+            )
+        else:
+            # Load without quantization
+            pipe = pipeline(
+                "text-generation",
+                model=model_name,
+                device_map="auto",
+                torch_dtype=DTYPE,
+                trust_remote_code=True
+            )
+
+        current_model_name = model_name
+        return f"Model '{model_name}' loaded successfully (optimization: {size_category} category)."
+
+    except Exception as e:
+        return f"Failed to load model: {str(e)}"
 
 @spaces.GPU
 def bot(
@@ -123,9 +208,17 @@ def bot(
         yield history
         return
 
+    # Automatically adjust token lengths based on the model size
+    size_category = get_model_size_category(current_model_name)
+
+    # For large models, reduce token counts to improve memory efficiency
+    if size_category == "large":
+        max_num_tokens = min(max_num_tokens, 1000)
+        final_num_tokens = min(final_num_tokens, 1500)
+
     # Used later to stream tokens from the thread
     streamer = transformers.TextIteratorStreamer(
-        pipe.tokenizer,  # pyright: ignore
+        pipe.tokenizer,
         skip_special_tokens=True,
         skip_prompt=True,
     )
@@ -144,41 +237,75 @@
 
     # Reasoning process to be shown in the current chat
     messages = rebuild_messages(history)
-    for i, prepend in enumerate(rethink_prepends):
-        if i > 0:
-            messages[-1]["content"] += "\n\n"
-        messages[-1]["content"] += prepend.format(question=question)
+
+    try:
+        for i, prepend in enumerate(rethink_prepends):
+            if i > 0:
+                messages[-1]["content"] += "\n\n"
+            messages[-1]["content"] += prepend.format(question=question)
 
-        num_tokens = int(
-            max_num_tokens if ANSWER_MARKER not in prepend else final_num_tokens
-        )
-        t = threading.Thread(
-            target=pipe,
-            args=(messages,),
-            kwargs=dict(
-                max_new_tokens=num_tokens,
-                streamer=streamer,
-                do_sample=do_sample,
-                temperature=temperature,
-            ),
-        )
-        t.start()
+            num_tokens = int(
+                max_num_tokens if ANSWER_MARKER not in prepend else final_num_tokens
+            )
+
+            # Run the model in a thread
+            t = threading.Thread(
+                target=pipe,
+                args=(messages,),
+                kwargs=dict(
+                    max_new_tokens=num_tokens,
+                    streamer=streamer,
+                    do_sample=do_sample,
+                    temperature=temperature,
+                    # extra parameters for memory efficiency
+                    repetition_penalty=1.2,  # discourage repetition
+                    use_cache=True,          # use the KV cache
+                ),
+            )
+            t.start()
 
-        # Rebuild the history with the new content
-        history[-1].content += prepend.format(question=question)
-        if ANSWER_MARKER in prepend:
-            history[-1].metadata = {"title": "💭 Thinking process", "status": "done"}
-            # Thinking is done, now comes the answer (no metadata for intermediate steps)
-            history.append(gr.ChatMessage(role="assistant", content=""))
-        for token in streamer:
-            history[-1].content += token
-            history[-1].content = reformat_math(history[-1].content)
+            # Rebuild the history with the new content
+            history[-1].content += prepend.format(question=question)
+            if ANSWER_MARKER in prepend:
+                history[-1].metadata = {"title": "💭 Thinking process", "status": "done"}
+                # Thinking is done, now comes the answer (no metadata for intermediate steps)
+                history.append(gr.ChatMessage(role="assistant", content=""))
+
+            # Stream the tokens
+            for token in streamer:
+                history[-1].content += token
+                history[-1].content = reformat_math(history[-1].content)
+                yield history
+
+            t.join()
+
+            # For large models, partially free memory after each step
+            if size_category == "large" and torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
+    except Exception as e:
+        # Notify the user when an error occurs
+        if len(history) > 0 and history[-1].role == "assistant":
+            history[-1].content += f"\n\n⚠️ An error occurred during processing: {str(e)}"
         yield history
-    t.join()
 
     yield history
 
 
+# Function to display information about the available GPU(s)
+def get_gpu_info():
+    if not torch.cuda.is_available():
+        return "No GPU available."
+
+    gpu_info = []
+    for i in range(torch.cuda.device_count()):
+        gpu_name = torch.cuda.get_device_name(i)
+        total_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
+        gpu_info.append(f"GPU {i}: {gpu_name} ({total_memory:.1f} GB)")
+
+    return "\n".join(gpu_info)
+
+# Gradio interface
 with gr.Blocks(fill_height=True, title="ThinkFlow - Step-by-step Reasoning Service") as demo:
     # Add a title and description at the top
     gr.Markdown("""
@@ -193,6 +320,7 @@ with gr.Blocks(fill_height=True, title="ThinkFlow - Step-by-step Reasoning Service") as demo:
                 scale=1,
                 type="messages",
                 latex_delimiters=latex_delimiters,
+                height=600,
             )
             msg = gr.Textbox(
                 submit_btn=True,
@@ -203,49 +331,63 @@ with gr.Blocks(fill_height=True, title="ThinkFlow - Step-by-step Reasoning Service") as demo:
             )
 
         with gr.Column(scale=1):
+            # Display hardware information
+            gpu_info = gr.Markdown(f"**Available hardware:**\n{get_gpu_info()}")
+
             # Add the model selection section
             gr.Markdown("""## Model Selection""")
-            model_selector = gr.CheckboxGroup(
+            model_selector = gr.Radio(
                 choices=list(available_models.values()),
-                value=[available_models["Qwen/Qwen2-1.5B-Instruct"]],  # default
-                label="Select the LLM models to use (multiple selection allowed)",
+                value=available_models["meta-llama/Llama-3.2-3B-Instruct"],  # default to a small model
+                label="Select the LLM model to use",
             )
 
             # Model load button
-            load_model_btn = gr.Button("Load model")
+            load_model_btn = gr.Button("Load model", variant="primary")
             model_status = gr.Textbox(label="Model status", interactive=False)
 
+            # Memory cleanup button
+            clear_memory_btn = gr.Button("Clear GPU memory", variant="secondary")
+
             gr.Markdown("""## Parameter Settings""")
-            num_tokens = gr.Slider(
-                50,
-                4000,
-                2000,
-                step=1,
-                label="Maximum tokens per reasoning step",
-                interactive=True,
-            )
-            final_num_tokens = gr.Slider(
-                50,
-                4000,
-                2000,
-                step=1,
-                label="Maximum tokens for the final answer",
-                interactive=True,
-            )
-            do_sample = gr.Checkbox(True, label="Use sampling")
-            temperature = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="Temperature")
+            with gr.Accordion("Advanced settings", open=False):
+                num_tokens = gr.Slider(
+                    50,
+                    2000,
+                    1000,  # reduced default
+                    step=50,
+                    label="Maximum tokens per reasoning step",
+                    interactive=True,
+                )
+                final_num_tokens = gr.Slider(
+                    50,
+                    3000,
+                    1500,  # reduced default
+                    step=50,
+                    label="Maximum tokens for the final answer",
+                    interactive=True,
+                )
+                do_sample = gr.Checkbox(True, label="Use sampling")
+                temperature = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="Temperature")
 
     # Connect the event that loads the selected model
-    def get_model_names(selected_models):
+    def get_model_names(selected_model):
         # Convert the display name back to the original model name
         inverse_map = {v: k for k, v in available_models.items()}
-        return [inverse_map[model] for model in selected_models]
+        return [inverse_map[selected_model]] if selected_model else []
 
     load_model_btn.click(
         lambda selected: load_model(get_model_names(selected)),
         inputs=[model_selector],
         outputs=[model_status]
     )
+
+    # Connect the GPU memory cleanup event
+    clear_memory_btn.click(
+        lambda: (clear_gpu_memory(), "GPU memory has been cleared."),
+        inputs=[],
+        outputs=[model_status]
+    )
 
     # When the user submits a message, the bot responds
     msg.submit(
@@ -265,4 +407,12 @@ with gr.Blocks(fill_height=True, title="ThinkFlow - Step-by-step Reasoning Service") as demo:
     )
 
 if __name__ == "__main__":
-    demo.queue().launch()
+    # Print debugging information
+    print(f"GPU available: {torch.cuda.is_available()}")
+    if torch.cuda.is_available():
+        print(f"Number of available GPUs: {torch.cuda.device_count()}")
+        print(f"Current GPU: {torch.cuda.current_device()}")
+        print(f"GPU name: {torch.cuda.get_device_name(0)}")
+
+    # Enable the queue and launch the app
+    demo.queue(max_size=10).launch()