openfree committed
Commit aea4015 · verified · 1 Parent(s): 1399380

Update app.py

Files changed (1):
  1. app.py +88 -444

app.py CHANGED
@@ -1,50 +1,20 @@
  import re
  import threading
- import gc
- import os
- import torch
- import time
- import signal
  import gradio as gr
  import spaces
  import transformers
- from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
- from huggingface_hub import login
-
- # Settings for model memory management and optimization
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
- DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
- MAX_GPU_MEMORY = 80 * 1024 * 1024 * 1024  # based on an 80GB A100
-
- # List of available models - changed to start with the smaller models
- available_models = {
-     "google/gemma-2b": "Google Gemma (2B)",  # smaller model set as the default
-     "mistralai/Mistral-7B-Instruct-v0.2": "Mistral 7B Instruct v0.2",
-     "mistralai/Mistral-Small-3.1-24B-Base-2503": "Mistral Small 3.1 (24B)",
-     "google/gemma-3-27b-it": "Google Gemma 3 (27B)",
-     "Qwen/Qwen2.5-Coder-32B-Instruct": "Qwen 2.5 Coder (32B)",
-     "open-r1/OlympicCoder-32B": "Olympic Coder (32B)"
- }
-
- # Default model - set to the smallest one
- DEFAULT_MODEL_KEY = list(available_models.keys())[0]
- DEFAULT_MODEL_VALUE = available_models[DEFAULT_MODEL_KEY]
-
- # Global variables used for model loading
- pipe = None
- current_model_name = None
- loading_in_progress = False
-
- # Try to log in with a Hugging Face token
- try:
-     hf_token = os.getenv("HF_TOKEN")
-     if hf_token:
-         login(token=hf_token)
-         print("Successfully logged in to Hugging Face.")
-     else:
-         print("Warning: the HF_TOKEN environment variable is not set.")
- except Exception as e:
-     print(f"Hugging Face login error: {str(e)}")
 
  # Marker used to detect the final answer
  ANSWER_MARKER = "**Answer**"
@@ -64,69 +34,31 @@ rethink_prepends = [
      f"\n{ANSWER_MARKER}\n",
  ]
 
  # Settings to work around the math rendering issue
  latex_delimiters = [
      {"left": "$$", "right": "$$", "display": True},
      {"left": "$", "right": "$", "display": False},
  ]
 
- # Model-size-based configuration - defines the optimal settings per model size
- MODEL_CONFIG = {
-     "small": {  # <10B
-         "max_memory": {0: "10GiB"},
-         "offload": False,
-         "quantization": None
-     },
-     "medium": {  # 10B-30B
-         "max_memory": {0: "30GiB"},
-         "offload": False,
-         "quantization": None
-     },
-     "large": {  # >30B
-         "max_memory": {0: "60GiB"},
-         "offload": True,
-         "quantization": None
-     }
- }
-
- def get_model_size_category(model_name):
-     """Determine the model size category"""
-     if "2B" in model_name or "3B" in model_name or "7B" in model_name or "8B" in model_name:
-         return "small"
-     elif "15B" in model_name or "24B" in model_name or "27B" in model_name:
-         return "medium"
-     elif "32B" in model_name or "70B" in model_name:
-         return "large"
-     else:
-         # Return small by default (to be safe)
-         return "small"
-
- def clear_gpu_memory():
-     """Free GPU memory"""
-     global pipe
-
-     if pipe is not None:
-         del pipe
-         pipe = None
-
-     # Clear the CUDA cache
-     gc.collect()
-     if torch.cuda.is_available():
-         torch.cuda.empty_cache()
-         torch.cuda.synchronize()
 
  def reformat_math(text):
-     """Convert MathJax delimiters to the Gradio (KaTeX) syntax."""
      text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
      text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
      return text
 
  def user_input(message, history: list):
      """Append the user input to the history and clear the input textbox"""
      return "", history + [
          gr.ChatMessage(role="user", content=message.replace(ANSWER_MARKER, ""))
      ]
 
  def rebuild_messages(history: list):
      """Rebuild the messages the model will use from the history, without the intermediate thinking steps"""
      messages = []
@@ -141,122 +73,6 @@ def rebuild_messages(history: list):
          messages.append({"role": h.role, "content": h.content})
      return messages
 
- def load_model(model_names):
-     """Load the model selected by name (using settings tuned for an A100)"""
-     global pipe, current_model_name, loading_in_progress
-
-     # Another model is already being loaded
-     if loading_in_progress:
-         return "Another model is already being loaded. Please wait a moment."
-
-     loading_in_progress = True
-     status_messages = []
-
-     try:
-         # Clean up the previous model
-         clear_gpu_memory()
-
-         # Fall back to the default if no model was selected
-         if not model_names:
-             model_name = DEFAULT_MODEL_KEY
-         else:
-             # Use the first selected model
-             model_name = model_names[0]
-
-         # Determine the model size category
-         size_category = get_model_size_category(model_name)
-         config = MODEL_CONFIG[size_category]
-
-         # Update the loading status
-         status_messages.append(f"Loading model '{model_name}'... (size: {size_category})")
-
-         # Load the model (applying settings optimized for its size)
-         # Check the HF_TOKEN environment variable
-         hf_token = os.getenv("HF_TOKEN")
-         # Common parameters
-         common_params = {
-             "token": hf_token,  # token for gated models
-             "trust_remote_code": True,
-         }
-
-         # Check whether BitsAndBytes is available
-         try:
-             import bitsandbytes
-             has_bitsandbytes = True
-         except ImportError:
-             has_bitsandbytes = False
-             status_messages.append("BitsAndBytes library not found. Loading without quantization.")
-
-         # Set a time limit (depending on the model size)
-         if size_category == "small":
-             load_timeout = 180  # 3 minutes
-         elif size_category == "medium":
-             load_timeout = 300  # 5 minutes
-         else:
-             load_timeout = 600  # 10 minutes
-
-         # Loading start time
-         start_time = time.time()
-
-         # If quantization is requested and BitsAndBytes is available
-         if config["quantization"] and has_bitsandbytes:
-             # Apply quantization
-             from transformers import BitsAndBytesConfig
-             quantization_config = BitsAndBytesConfig(
-                 load_in_4bit=config["quantization"] == "4bit",
-                 bnb_4bit_compute_dtype=DTYPE
-             )
-
-             status_messages.append(f"Loading model '{model_name}'... (with quantization)")
-
-             model = AutoModelForCausalLM.from_pretrained(
-                 model_name,
-                 device_map="auto",
-                 max_memory=config["max_memory"],
-                 torch_dtype=DTYPE,
-                 quantization_config=quantization_config,
-                 offload_folder="offload" if config["offload"] else None,
-                 **common_params
-             )
-             tokenizer = AutoTokenizer.from_pretrained(model_name, **common_params)
-
-             pipe = pipeline(
-                 "text-generation",
-                 model=model,
-                 tokenizer=tokenizer,
-                 torch_dtype=DTYPE,
-                 device_map="auto"
-             )
-         else:
-             # Load without quantization
-             status_messages.append(f"Loading model '{model_name}'... (standard method)")
-
-             pipe = pipeline(
-                 "text-generation",
-                 model=model_name,
-                 device_map="auto",
-                 torch_dtype=DTYPE,
-                 **common_params
-             )
-
-         # Check whether the time limit was exceeded
-         elapsed_time = time.time() - start_time
-         if elapsed_time > load_timeout:
-             clear_gpu_memory()
-             loading_in_progress = False
-             return f"Model load timed out: more than {load_timeout} seconds elapsed. Please try again."
-
-         current_model_name = model_name
-         loading_in_progress = False
-         return f"Model '{model_name}' was loaded successfully. (optimization: {size_category}, elapsed: {elapsed_time:.1f}s)"
-
-     except Exception as e:
-         loading_in_progress = False
-         error_msg = f"Model load failed: {str(e)}"
-         print(f"Error: {error_msg}")
-         return error_msg
-     finally:
-         loading_in_progress = False
 
  @spaces.GPU
  def bot(
@@ -267,187 +83,71 @@
      temperature: float,
  ):
      """Have the model answer the question"""
-     global pipe, current_model_name
-
-     # Show an error message if no model has been loaded
-     if pipe is None:
-         history.append(
-             gr.ChatMessage(
-                 role="assistant",
-                 content="No model is loaded. Please select at least one model and click the 'Load model' button.",
-             )
-         )
-         yield history
-         return
 
-     try:
-         # Automatically adjust the token budget (depending on the model size)
-         size_category = get_model_size_category(current_model_name)
-
-         # For large models, reduce the token counts to improve memory efficiency
-         if size_category == "large":
-             max_num_tokens = min(max_num_tokens, 1000)
-             final_num_tokens = min(final_num_tokens, 1500)
-
-         # Needed to fetch the tokens as a stream from a thread later
-         streamer = transformers.TextIteratorStreamer(
-             pipe.tokenizer,
-             skip_special_tokens=True,
-             skip_prompt=True,
-         )
-
-         # Needed to re-insert the question into the reasoning when necessary
-         question = history[-1]["content"]
-
-         # Prepare the assistant message
-         history.append(
-             gr.ChatMessage(
-                 role="assistant",
-                 content=str(""),
-                 metadata={"title": "🧠 Thinking...", "status": "pending"},
-             )
          )
-
-         # Reasoning steps shown in the current chat
-         messages = rebuild_messages(history)
-
-         # Timeout setup
-         class TimeoutError(Exception):
-             pass
-
-         def timeout_handler(signum, frame):
-             raise TimeoutError("Request processing timed out.")
-
-         # At most 120 seconds per step
-         timeout_seconds = 120
-
-         for i, prepend in enumerate(rethink_prepends):
-             if i > 0:
-                 messages[-1]["content"] += "\n\n"
-             messages[-1]["content"] += prepend.format(question=question)
 
-             num_tokens = int(
-                 max_num_tokens if ANSWER_MARKER not in prepend else final_num_tokens
-             )
-
-             # Run the model in a thread
-             t = threading.Thread(
-                 target=pipe,
-                 args=(messages,),
-                 kwargs=dict(
-                     max_new_tokens=num_tokens,
-                     streamer=streamer,
-                     do_sample=do_sample,
-                     temperature=temperature,
-                     # Extra parameters for memory efficiency
-                     repetition_penalty=1.2,  # prevent repetition
-                     use_cache=True,  # use the KV cache
-                 ),
-             )
-             t.daemon = True  # daemon thread, so it exits together with the main thread
-             t.start()
-
-             # Rebuild the history with the new content
-             history[-1].content += prepend.format(question=question)
-             if ANSWER_MARKER in prepend:
-                 history[-1].metadata = {"title": "💭 Thought process", "status": "done"}
-                 # Thinking is done, now comes the answer (no metadata for the intermediate steps)
-                 history.append(gr.ChatMessage(role="assistant", content=""))
-
-             # Set the timeout (only works on Unix systems)
-             try:
-                 if hasattr(signal, 'SIGALRM'):
-                     signal.signal(signal.SIGALRM, timeout_handler)
-                     signal.alarm(timeout_seconds)
-
-                 # Token streaming
-                 token_count = 0
-                 for token in streamer:
-                     history[-1].content += token
-                     history[-1].content = reformat_math(history[-1].content)
-                     token_count += 1
-
-                     # Yield every 10 tokens (improves UI responsiveness)
-                     if token_count % 10 == 0:
-                         yield history
-
-                 # Yield whatever is left
-                 yield history
-
-                 # Cancel the timeout
-                 if hasattr(signal, 'SIGALRM'):
-                     signal.alarm(0)
-
-             except TimeoutError:
-                 if hasattr(signal, 'SIGALRM'):
-                     signal.alarm(0)
-                 history[-1].content += "\n\n⚠️ Response generation timed out. Moving on to the next step."
-                 yield history
-                 continue
-
-             # Wait up to 30 seconds, then move on to the next step
-             join_start_time = time.time()
-             while t.is_alive() and (time.time() - join_start_time) < 30:
-                 t.join(1)  # check every second
-
-             # If the thread is still running, force progress
-             if t.is_alive():
-                 history[-1].content += "\n\n⚠️ Response generation is taking longer than expected. Moving on to the next step."
-                 yield history
-
-             # For large models, partially free memory after each step
-             if size_category == "large" and torch.cuda.is_available():
-                 torch.cuda.empty_cache()
-
-     except Exception as e:
-         # Notify the user when an error occurs
-         import traceback
-         error_msg = f"\n\n⚠️ An error occurred during processing: {str(e)}\n{traceback.format_exc()}"
-
-         if len(history) > 0 and isinstance(history[-1], gr.ChatMessage) and history[-1].role == "assistant":
-             history[-1].content += error_msg
-         else:
-             history.append(gr.ChatMessage(role="assistant", content=error_msg))
-
-         yield history
-
      yield history
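
The removed per-step timeout above relies on signal.SIGALRM plus a handler that raises a custom exception, which is why the code guards it with hasattr(signal, 'SIGALRM'): the alarm-based idiom only exists on Unix and only fires in the main thread. A minimal sketch of that idiom, separate from app.py (names here are illustrative, not from the commit):

# Sketch of the SIGALRM timeout idiom the removed code used (Unix only, main thread only).
import signal

class StepTimeout(Exception):
    pass

def _on_alarm(signum, frame):
    raise StepTimeout("step took too long")

signal.signal(signal.SIGALRM, _on_alarm)
signal.alarm(2)  # deliver SIGALRM in 2 seconds unless cancelled
try:
    while True:  # stand-in for the token-streaming loop
        pass
except StepTimeout:
    print("timed out, moving on to the next step")
finally:
    signal.alarm(0)  # cancel any pending alarm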
 
- # Function that shows information about the available GPUs
- def get_gpu_info():
-     if not torch.cuda.is_available():
-         return "No GPU is available."
-
-     gpu_info = []
-     for i in range(torch.cuda.device_count()):
-         gpu_name = torch.cuda.get_device_name(i)
-         total_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3
-         gpu_info.append(f"GPU {i}: {gpu_name} ({total_memory:.1f} GB)")
-
-     return "\n".join(gpu_info)
-
- # Load the default model automatically, synchronously instead of asynchronously (simplified)
- def load_default_model():
-     model_key = DEFAULT_MODEL_KEY
-     return load_model([model_key])
-
- # Gradio interface
- with gr.Blocks(fill_height=True, title="ThinkFlow - Step-by-step Reasoning Service") as demo:
-     # Title and description at the top
-     gr.Markdown("""
-     # ThinkFlow
-     ## A thought amplification service that implants step-by-step reasoning abilities into LLMs without model modification
-     """)
-
      with gr.Row(scale=1):
          with gr.Column(scale=5):
-             # Chat interface
              chatbot = gr.Chatbot(
                  scale=1,
                  type="messages",
                  latex_delimiters=latex_delimiters,
-                 height=600,
              )
              msg = gr.Textbox(
                  submit_btn=True,
@@ -456,68 +156,27 @@ with gr.Blocks(fill_height=True, title="ThinkFlow - Step-by-step Reasoning Servi
                  placeholder="Enter your question here.",
                  autofocus=True,
              )
-
          with gr.Column(scale=1):
-             # Show hardware information
-             gpu_info = gr.Markdown(f"**Available hardware:**\n{get_gpu_info()}")
-
-             # Model selection section
-             gr.Markdown("""## Model selection""")
-             model_selector = gr.Radio(
-                 choices=list(available_models.values()),
-                 value=DEFAULT_MODEL_VALUE,
-                 label="Select the LLM model to use",
-             )
-
-             # Model load button
-             load_model_btn = gr.Button("Load model", variant="primary")
-             model_status = gr.Textbox(label="Model status", interactive=False, value="A small model is loaded automatically at startup...")
-
-             # Memory cleanup button
-             clear_memory_btn = gr.Button("Free GPU memory", variant="secondary")
-
              gr.Markdown("""## Parameter tuning""")
-             with gr.Accordion("Advanced settings", open=False):
-                 num_tokens = gr.Slider(
-                     50,
-                     2000,
-                     1000,
-                     step=50,
-                     label="Maximum tokens per reasoning step",
-                     interactive=True,
-                 )
-                 final_num_tokens = gr.Slider(
-                     50,
-                     3000,
-                     1500,
-                     step=50,
-                     label="Maximum tokens for the final answer",
-                     interactive=True,
-                 )
-                 do_sample = gr.Checkbox(True, label="Use sampling")
-                 temperature = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="Temperature")
-
-     # Load the model automatically at startup - now handled synchronously
-     demo.load(load_default_model, [], [model_status])
-
-     # Wire up the event that loads the selected model
-     def get_model_names(selected_model):
-         # Map the display name back to the original model name
-         inverse_map = {v: k for k, v in available_models.items()}
-         return [inverse_map[selected_model]] if selected_model else []
-
-     load_model_btn.click(
-         lambda selected: load_model(get_model_names(selected)),
-         inputs=[model_selector],
-         outputs=[model_status]
-     )
-
-     # Wire up the GPU memory cleanup event
-     clear_memory_btn.click(
-         lambda: (clear_gpu_memory(), "GPU memory has been freed."),
-         inputs=[],
-         outputs=[model_status]
-     )
 
      # The bot responds when the user submits a message
      msg.submit(
@@ -537,19 +196,4 @@ with gr.Blocks(fill_height=True, title="ThinkFlow - Step-by-step Reasoning Servi
      )
 
  if __name__ == "__main__":
-     # Print debugging information
-     print(f"GPU available: {torch.cuda.is_available()}")
-     if torch.cuda.is_available():
-         print(f"Number of available GPUs: {torch.cuda.device_count()}")
-         print(f"Current GPU: {torch.cuda.current_device()}")
-         print(f"GPU name: {torch.cuda.get_device_name(0)}")
-
-     # Check the HF_TOKEN environment variable
-     hf_token = os.getenv("HF_TOKEN")
-     if hf_token:
-         print("The HF_TOKEN environment variable is set.")
-     else:
-         print("Warning: the HF_TOKEN environment variable is not set. Gated models cannot be accessed.")
-
-     # Use a queue and launch the app
-     demo.queue(max_size=10).launch()
 
  import re
  import threading
+
  import gradio as gr
  import spaces
  import transformers
+ from transformers import pipeline
+
+ # Load the model and tokenizer
+ model_name = "Qwen/Qwen2-1.5B-Instruct"
+ if gr.NO_RELOAD:
+     pipe = pipeline(
+         "text-generation",
+         model=model_name,
+         device_map="auto",
+         torch_dtype="auto",
+     )
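
The new startup path loads one fixed pipeline at import time; the if gr.NO_RELOAD: guard keeps that heavy load from re-running when the file is served with Gradio's auto-reload. As a rough illustration of how such a pipeline is driven outside the app (not part of this commit; it assumes a recent transformers release whose "text-generation" pipeline accepts chat-style message lists and applies the model's chat template):

# Standalone sketch, not part of app.py: exercising the same pipeline outside Gradio.
from transformers import pipeline

sketch_pipe = pipeline(
    "text-generation",
    model="Qwen/Qwen2-1.5B-Instruct",
    device_map="auto",
    torch_dtype="auto",
)

messages = [{"role": "user", "content": "What is 17 * 24?"}]
result = sketch_pipe(messages, max_new_tokens=128, do_sample=False)

# result[0]["generated_text"] is the message list with the new assistant turn appended.
print(result[0]["generated_text"][-1]["content"])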
  # Marker used to detect the final answer
  ANSWER_MARKER = "**Answer**"
 
      f"\n{ANSWER_MARKER}\n",
  ]
 
+
  # Settings to work around the math rendering issue
  latex_delimiters = [
      {"left": "$$", "right": "$$", "display": True},
      {"left": "$", "right": "$", "display": False},
  ]
 
  def reformat_math(text):
+     """Convert MathJax delimiters to the Gradio (KaTeX) syntax.
+     This is a workaround for rendering math in Gradio; for now I have not found
+     a way to make it work as expected with other latex_delimiters...
+     """
      text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
      text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
      return text
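
To make the intent of the two substitutions concrete, here is a small self-contained check (illustration only, not in the commit): it converts MathJax-style \[...\] and \(...\) delimiters into the $$...$$ and $...$ forms declared in latex_delimiters above.

# Illustration only: the same two regexes applied to a sample model output.
import re

def reformat_math_demo(text):
    text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
    text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
    return text

sample = r"The roots are \(x_{1,2}\): \[ x = \frac{-b \pm \sqrt{b^2 - 4ac}}{2a} \]"
print(reformat_math_demo(sample))
# -> The roots are $x_{1,2}$: $$x = \frac{-b \pm \sqrt{b^2 - 4ac}}{2a}$$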
 
+
  def user_input(message, history: list):
      """Append the user input to the history and clear the input textbox"""
      return "", history + [
          gr.ChatMessage(role="user", content=message.replace(ANSWER_MARKER, ""))
      ]
 
+
  def rebuild_messages(history: list):
      """Rebuild the messages the model will use from the history, without the intermediate thinking steps"""
      messages = []
 
          messages.append({"role": h.role, "content": h.content})
      return messages
 
  @spaces.GPU
  def bot(
 
      temperature: float,
  ):
      """Have the model answer the question"""
 
+     # Needed to fetch the tokens as a stream from a thread later
+     streamer = transformers.TextIteratorStreamer(
+         pipe.tokenizer,  # pyright: ignore
+         skip_special_tokens=True,
+         skip_prompt=True,
+     )
 
+     # Needed to re-insert the question into the reasoning when necessary
+     question = history[-1]["content"]
 
+     # Prepare the assistant message
+     history.append(
+         gr.ChatMessage(
+             role="assistant",
+             content=str(""),
+             metadata={"title": "🧠 Thinking...", "status": "pending"},
          )
+     )
 
+     # Reasoning steps shown in the current chat
+     messages = rebuild_messages(history)
+     for i, prepend in enumerate(rethink_prepends):
+         if i > 0:
+             messages[-1]["content"] += "\n\n"
+         messages[-1]["content"] += prepend.format(question=question)
 
+         num_tokens = int(
+             max_num_tokens if ANSWER_MARKER not in prepend else final_num_tokens
+         )
+         t = threading.Thread(
+             target=pipe,
+             args=(messages,),
+             kwargs=dict(
+                 max_new_tokens=num_tokens,
+                 streamer=streamer,
+                 do_sample=do_sample,
+                 temperature=temperature,
+             ),
+         )
+         t.start()
+
+         # Rebuild the history with the new content
+         history[-1].content += prepend.format(question=question)
+         if ANSWER_MARKER in prepend:
+             history[-1].metadata = {"title": "💭 Thought process", "status": "done"}
+             # Thinking is done, now comes the answer (no metadata for the intermediate steps)
+             history.append(gr.ChatMessage(role="assistant", content=""))
+         for token in streamer:
+             history[-1].content += token
+             history[-1].content = reformat_math(history[-1].content)
+             yield history
+         t.join()
 
      yield history
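
The loop above is the heart of the new bot(): generation runs in a background thread while the Gradio handler drains transformers.TextIteratorStreamer and yields partial history updates after every token. A condensed, self-contained sketch of that producer/consumer pattern (not part of the commit; the model and prompt are placeholders, and it assumes a transformers version whose text-generation pipeline accepts chat-style message lists):

# Condensed sketch of the thread + TextIteratorStreamer pattern used in bot().
import threading
from transformers import TextIteratorStreamer, pipeline

demo_pipe = pipeline("text-generation", model="Qwen/Qwen2-1.5B-Instruct", torch_dtype="auto")
streamer = TextIteratorStreamer(
    demo_pipe.tokenizer,
    skip_special_tokens=True,
    skip_prompt=True,
)

messages = [{"role": "user", "content": "Explain what a prime number is."}]
worker = threading.Thread(
    target=demo_pipe,
    args=(messages,),
    kwargs=dict(max_new_tokens=64, streamer=streamer),
)
worker.start()

partial = ""
for token in streamer:        # blocks until the worker pushes the next decoded chunk
    partial += token
    print(partial, end="\r")  # in app.py this is where the updated history is yielded to Gradio
worker.join()
print()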
 
+ with gr.Blocks(fill_height=True, title="Give every LLM model reasoning abilities") as demo:
      with gr.Row(scale=1):
          with gr.Column(scale=5):
+
              chatbot = gr.Chatbot(
                  scale=1,
                  type="messages",
                  latex_delimiters=latex_delimiters,
              )
              msg = gr.Textbox(
                  submit_btn=True,
 
                  placeholder="Enter your question here.",
                  autofocus=True,
              )
          with gr.Column(scale=1):
              gr.Markdown("""## Parameter tuning""")
+             num_tokens = gr.Slider(
+                 50,
+                 4000,
+                 2000,
+                 step=1,
+                 label="Maximum tokens per reasoning step",
+                 interactive=True,
+             )
+             final_num_tokens = gr.Slider(
+                 50,
+                 4000,
+                 2000,
+                 step=1,
+                 label="Maximum tokens for the final answer",
+                 interactive=True,
+             )
+             do_sample = gr.Checkbox(True, label="Use sampling")
+             temperature = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="Temperature")
+
 
      # The bot responds when the user submits a message
      msg.submit(
 
      )
 
  if __name__ == "__main__":
+     demo.queue().launch()
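
demo.queue().launch() replaces the old demo.queue(max_size=10).launch(): the request queue is what lets a generator handler like bot() stream its yielded history updates to the browser, and the max_size cap is simply dropped. A minimal, self-contained illustration of that yield-to-stream behavior (not part of the commit; component names are made up):

# Minimal illustration of why queue() matters: Gradio streams each value a generator yields.
import time
import gradio as gr

def count_up(n: float):
    text = ""
    for i in range(int(n)):
        text += f"{i} "
        time.sleep(0.2)
        yield text  # each yield updates the output component, like bot() updating the chat

with gr.Blocks() as tiny_demo:
    n = gr.Slider(1, 20, 10, step=1, label="How far to count")
    out = gr.Textbox(label="Streamed output")
    gr.Button("Run").click(count_up, n, out)

if __name__ == "__main__":
    tiny_demo.queue().launch()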