openfree commited on
Commit
a4ca8e9
·
verified ·
1 Parent(s): 04bc27d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +371 -736
app.py CHANGED
@@ -1,369 +1,21 @@
1
  import re
2
  import threading
3
- import time
4
- import os
5
- import logging
6
- from datetime import datetime
7
- import torch
8
- import numpy as np
9
- from typing import List, Optional, Tuple, Dict
10
- import networkx as nx
11
 
12
  import gradio as gr
 
13
  import transformers
14
- from transformers import (
15
- pipeline,
16
- AutoModelForCausalLM,
17
- AutoTokenizer,
18
- BartForConditionalGeneration,
19
- BartTokenizer,
20
- BitsAndBytesConfig
21
- )
22
-
23
- # 로깅 설정
24
- logging.basicConfig(level=logging.INFO)
25
- logger = logging.getLogger(__name__)
26
-
27
- # ===================== RLRetrievalPolicy =====================
28
- class RLRetrievalPolicy:
29
- def __init__(self):
30
- self.policy_data = {}
31
- self.alpha = 0.5 # 유사도 vs. RL 점수 간 가중치
32
-
33
- def update_policy(self, contexts: List[str], reward: float):
34
- for ctx in contexts:
35
- if ctx not in self.policy_data:
36
- self.policy_data[ctx] = 0.0
37
- self.policy_data[ctx] += reward
38
-
39
- def re_rank(self, candidates: List[Tuple[float, str]]) -> List[str]:
40
- reweighted = []
41
- for sim, txt in candidates:
42
- rl_score = self.policy_data.get(txt, 0.0)
43
- reweighted_score = sim * (1 - self.alpha) + rl_score * self.alpha
44
- reweighted.append((reweighted_score, txt))
45
- reweighted.sort(key=lambda x: x[0], reverse=True)
46
- return [t for _, t in reweighted]
47
-
48
- # ===================== GraphMemory =====================
49
- class GraphMemory:
50
- def __init__(self):
51
- self.graph = nx.DiGraph()
52
- # 수학 문제 해결에 도움이 되는 기본 노드 추가
53
- self.add_node("수학", "수학 문제 해결을 위한 일반적인 접근법")
54
- self.add_node("대수학", "방정식, 함수, 비례 관계 등을 다루는 수학의 한 분야")
55
- self.add_node("기하학", "공간, 도형, 각도 등을 다루는 수학의 한 분야")
56
- self.add_node("산술", "기본적인 수 연산, 비율, 백분율 등을 다루는 분야")
57
- self.add_node("확률", "사건의 발생 가능성을 측정하는 수학의 한 분야")
58
-
59
- # 관계 설정
60
- self.add_edge("대수학", "수학")
61
- self.add_edge("기하학", "수학")
62
- self.add_edge("산술", "수학")
63
- self.add_edge("확률", "수학")
64
-
65
- def add_node(self, node_id: str, text: str = ""):
66
- self.graph.add_node(node_id, text=text)
67
-
68
- def add_edge(self, src: str, dst: str):
69
- self.graph.add_edge(src, dst)
70
-
71
- def get_text_by_node(self, node_id: str) -> str:
72
- return self.graph.nodes[node_id].get('text', "")
73
-
74
- def has_node(self, node_id: str) -> bool:
75
- return node_id in self.graph.nodes
76
-
77
- def search_nodes(self, keyword: str, max_nodes: int = 3) -> List[str]:
78
- matches = []
79
- for n in self.graph.nodes():
80
- node_text = self.get_text_by_node(n).lower()
81
- n_lower = n.lower()
82
- if keyword.lower() in node_text or keyword.lower() in n_lower:
83
- score = node_text.count(keyword.lower()) + n_lower.count(keyword.lower())
84
- matches.append((score, n))
85
- matches.sort(key=lambda x: x[0], reverse=True)
86
- top_nodes = [m[1] for m in matches[:max_nodes]]
87
- return top_nodes
88
-
89
- def get_connected_context(self, start_node: str, steps: int = 1) -> List[str]:
90
- contexts = []
91
- visited = set()
92
- queue = [(start_node, 0)]
93
- while queue:
94
- current, depth = queue.pop(0)
95
- if current not in visited:
96
- visited.add(current)
97
- contexts.append(self.get_text_by_node(current))
98
- if depth < steps:
99
- for neighbor in self.graph.successors(current):
100
- queue.append((neighbor, depth + 1))
101
- for neighbor in self.graph.predecessors(current):
102
- queue.append((neighbor, depth + 1))
103
- return contexts
104
-
105
- # ===================== SimpleSummarizer =====================
106
- class SimpleSummarizer:
107
- def __init__(self, model_name="facebook/bart-large-cnn"):
108
- self.model_name = model_name
109
- self.model = None
110
- self.tokenizer = None
111
-
112
- def load_summarization_model(self):
113
- if self.model is None:
114
- try:
115
- self.tokenizer = BartTokenizer.from_pretrained(self.model_name)
116
- self.model = BartForConditionalGeneration.from_pretrained(self.model_name)
117
- if torch.cuda.is_available():
118
- self.model = self.model.cuda()
119
- except Exception as e:
120
- logger.error(f"Error loading summarization model: {str(e)}")
121
- raise
122
-
123
- def summarize_text(self, text: str, max_length: int = 100) -> str:
124
- try:
125
- self.load_summarization_model()
126
- inputs = self.tokenizer([text], max_length=1024, return_tensors='pt', truncation=True)
127
- if torch.cuda.is_available():
128
- inputs = {k: v.cuda() for k, v in inputs.items()}
129
-
130
- with torch.no_grad():
131
- summary_ids = self.model.generate(
132
- inputs["input_ids"],
133
- num_beams=4,
134
- max_length=max_length,
135
- early_stopping=True
136
- )
137
- summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
138
- return summary
139
- except Exception as e:
140
- logger.error(f"Error in summarization: {str(e)}")
141
- return "요약을 생성할 수 없습니다."
142
-
143
- # ===================== SemanticMemory =====================
144
- class SemanticMemory:
145
- def __init__(self, max_entries: int = 4000):
146
- self.memories: List[dict] = []
147
- self.max_entries = max_entries
148
- self.rl_policy = RLRetrievalPolicy()
149
-
150
- def add_memory(self, text: str, embedding: torch.Tensor):
151
- if len(self.memories) >= self.max_entries:
152
- self.memories.pop(0)
153
- self.memories.append({
154
- 'text': text,
155
- 'embedding': embedding,
156
- 'timestamp': time.time()
157
- })
158
-
159
- def get_candidates(self, query_embedding: torch.Tensor) -> List[Tuple[float, str]]:
160
- candidates = []
161
- for mem in self.memories:
162
- if mem['embedding'].shape == query_embedding.shape:
163
- sim = torch.cosine_similarity(
164
- query_embedding.float(),
165
- mem['embedding'].float(),
166
- dim=-1
167
- )
168
- candidates.append((sim.item(), mem['text']))
169
- candidates.sort(key=lambda x: x[0], reverse=True)
170
- return candidates
171
-
172
- def get_relevant_context(self, query_embedding: torch.Tensor, top_k: int = 3) -> List[str]:
173
- candidates = self.get_candidates(query_embedding)
174
- re_ranked = self.rl_policy.re_rank(candidates)
175
- return re_ranked[:top_k]
176
-
177
- def update_retrieval_reward(self, texts: List[str], reward: float):
178
- self.rl_policy.update_policy(texts, reward)
179
-
180
- def clear(self):
181
- self.memories = []
182
-
183
- # ===================== GenericInferenceBuffer =====================
184
- MAX_TOKEN_BUFFER = 1024
185
-
186
- class GenericInferenceBuffer:
187
- def __init__(self, layer_idx: int, compression_rank: int = 128):
188
- self.layer_idx = layer_idx
189
- self.key_buffer: Optional[torch.Tensor] = None
190
- self.value_buffer: Optional[torch.Tensor] = None
191
- self.semantic_context: Optional[torch.Tensor] = None
192
- self.last_update: float = 0
193
- self.compression_rank = compression_rank
194
-
195
- def update_buffer(
196
- self,
197
- key: torch.Tensor,
198
- value: torch.Tensor,
199
- semantic_context: Optional[torch.Tensor] = None
200
- ):
201
- try:
202
- if self.key_buffer is None:
203
- self.key_buffer = key.detach().clone()
204
- self.value_buffer = value.detach().clone()
205
- if semantic_context is not None:
206
- self.semantic_context = semantic_context.detach().clone()
207
- else:
208
- self.key_buffer = torch.cat([self.key_buffer, key.detach()], dim=2)
209
- self.value_buffer = torch.cat([self.value_buffer, value.detach()], dim=2)
210
- if semantic_context is not None and self.semantic_context is not None:
211
- self.semantic_context = torch.cat([self.semantic_context, semantic_context.detach()], dim=0)
212
-
213
- if self.key_buffer.shape[2] > MAX_TOKEN_BUFFER:
214
- excess = self.key_buffer.shape[2] - MAX_TOKEN_BUFFER
215
- self.key_buffer = self.key_buffer[:, :, excess:, :]
216
- self.value_buffer = self.value_buffer[:, :, excess:, :]
217
- if self.semantic_context is not None:
218
- self.semantic_context = self.semantic_context[excess:, :]
219
-
220
- self.last_update = time.time()
221
-
222
- except Exception as e:
223
- logger.error(f"Buffer update error in layer {self.layer_idx}: {str(e)}")
224
-
225
- def compress_buffer_svd(self):
226
- if self.key_buffer is None or self.value_buffer is None:
227
- return
228
-
229
- try:
230
- k_shape = self.key_buffer.shape
231
- v_shape = self.value_buffer.shape
232
-
233
- k_2d = self.key_buffer.reshape(k_shape[0]*k_shape[1], k_shape[2]*k_shape[3]).float()
234
- v_2d = self.value_buffer.reshape(v_shape[0]*v_shape[1], v_shape[2]*v_shape[3]).float()
235
-
236
- device = k_2d.device
237
- k_2d_cpu = k_2d.cpu()
238
- v_2d_cpu = v_2d.cpu()
239
-
240
- U_k, S_k, V_k = torch.linalg.svd(k_2d_cpu, full_matrices=False)
241
- U_v, S_v, V_v = torch.linalg.svd(v_2d_cpu, full_matrices=False)
242
- rank_k = min(self.compression_rank, S_k.shape[0])
243
- rank_v = min(self.compression_rank, S_v.shape[0])
244
- k_approx = (U_k[:, :rank_k] * S_k[:rank_k]) @ V_k[:rank_k, :]
245
- v_approx = (U_v[:, :rank_v] * S_v[:rank_v]) @ V_v[:rank_v, :]
246
-
247
- k_approx = k_approx.to(device)
248
- v_approx = v_approx.to(device)
249
-
250
- self.key_buffer = k_approx.reshape(k_shape).type(self.key_buffer.dtype)
251
- self.value_buffer = v_approx.reshape(v_shape).type(self.value_buffer.dtype)
252
-
253
- except Exception as e:
254
- logger.error(f"SVD compression error in layer {self.layer_idx}: {str(e)}")
255
-
256
- def get_buffer(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
257
- return self.key_buffer, self.value_buffer, self.semantic_context
258
-
259
- def clear(self):
260
- self.key_buffer = None
261
- self.value_buffer = None
262
- self.semantic_context = None
263
- self.last_update = 0
264
-
265
- # ===================== InferenceBufferManager =====================
266
- class InferenceBufferManager:
267
- def __init__(self, num_layers: int, hidden_size: int):
268
- self.num_layers = num_layers
269
- self.hidden_size = hidden_size
270
- self.layer_buffers = [
271
- GenericInferenceBuffer(i, compression_rank=128) for i in range(num_layers)
272
- ]
273
- self.semantic_memory = SemanticMemory()
274
- self.graph_memory = GraphMemory()
275
- self.summarizer = SimpleSummarizer()
276
- self.summarize_threshold = 1500
277
- self.generated_tokens_count = 0
278
- self.compression_interval = 512
279
- self.token_count_since_compress = 0
280
-
281
- def _compute_semantic_embedding(self, key: Optional[torch.Tensor], value: Optional[torch.Tensor]) -> torch.Tensor:
282
- device = "cuda" if torch.cuda.is_available() else "cpu"
283
- if key is None or value is None:
284
- return torch.zeros((1, self.hidden_size), dtype=torch.float32, device=device)
285
- combined = key * value
286
- combined = combined.mean(dim=2)
287
- combined = combined.reshape(combined.shape[0], -1)
288
- combined = torch.nn.functional.normalize(combined, dim=-1)
289
- return combined
290
-
291
- def update_buffer(self, layer_outputs, current_tokens: List[int], semantic_context: torch.Tensor, tokenizer):
292
- try:
293
- if hasattr(layer_outputs, 'past_key_values'):
294
- for layer_idx, past_kv in enumerate(layer_outputs.past_key_values):
295
- if isinstance(past_kv, tuple) and len(past_kv) == 2:
296
- key, value = past_kv
297
- if key is not None and value is not None:
298
- self.layer_buffers[layer_idx].update_buffer(
299
- key.detach(),
300
- value.detach(),
301
- semantic_context
302
- )
303
- self.generated_tokens_count += len(current_tokens)
304
- self.token_count_since_compress += len(current_tokens)
305
-
306
- if self.token_count_since_compress >= self.compression_interval:
307
- self.compress_all_buffers()
308
- self.token_count_since_compress = 0
309
- except Exception as e:
310
- logger.error(f"Buffer update error: {str(e)}")
311
-
312
- def compress_all_buffers(self):
313
- for buf in self.layer_buffers:
314
- buf.compress_buffer_svd()
315
-
316
- def finalize_semantic_memory(self, tokenizer, generated_tokens: List[int]):
317
- if self.layer_buffers and len(self.layer_buffers) > 0 and self.layer_buffers[-1].key_buffer is not None:
318
- text_chunk = tokenizer.decode(generated_tokens, skip_special_tokens=True)
319
- key_buffer = self.layer_buffers[-1].key_buffer
320
- value_buffer = self.layer_buffers[-1].value_buffer
321
- embedding = self._compute_semantic_embedding(key_buffer, value_buffer)
322
- self.semantic_memory.add_memory(text_chunk, embedding)
323
-
324
- def get_relevant_context(self, query_embedding: torch.Tensor, top_k: int = 3) -> List[str]:
325
- candidates_sem = self.semantic_memory.get_candidates(query_embedding)
326
-
327
- # 키워드 추출 (간단한 구현)
328
- possible_keywords = ["수학", "대수학", "기하학", "산술", "확률"]
329
- text_candidates = []
330
- for kw in possible_keywords:
331
- nodes = self.graph_memory.search_nodes(kw)
332
- for n in nodes:
333
- context_list = self.graph_memory.get_connected_context(n, steps=1)
334
- cscore = 1.0
335
- for ctxt in context_list:
336
- text_candidates.append((cscore, ctxt))
337
-
338
- merged_candidates = candidates_sem + text_candidates
339
- re_ranked = self.semantic_memory.rl_policy.re_rank(merged_candidates)
340
- return re_ranked[:top_k]
341
-
342
- def update_retrieval_reward(self, contexts: List[str], reward: float):
343
- self.semantic_memory.update_retrieval_reward(contexts, reward)
344
-
345
- def maybe_summarize_memory(self):
346
- if self.generated_tokens_count < self.summarize_threshold:
347
- return
348
-
349
- all_text = "\n".join([m['text'] for m in self.semantic_memory.memories])
350
- if len(all_text) < 300:
351
- return
352
-
353
- summary = self.summarizer.summarize_text(all_text, max_length=120)
354
- device = "cuda" if torch.cuda.is_available() else "cpu"
355
- summary_embedding = torch.zeros((1, self.hidden_size), dtype=torch.float32, device=device)
356
-
357
- self.semantic_memory.clear()
358
- self.semantic_memory.add_memory(summary, summary_embedding)
359
- self.generated_tokens_count = 0
360
-
361
- def clear(self):
362
- for layer in self.layer_buffers:
363
- layer.clear()
364
- self.semantic_memory.clear()
365
-
366
- # ===================== Enhanced ThinkFlow Implementation =====================
367
 
368
  # 최종 답변을 감지하기 위한 마커
369
  ANSWER_MARKER = "**답변**"
@@ -380,15 +32,190 @@ rethink_prepends = [
380
  "이제 충분히 이해했다고 생각합니다 ",
381
  ]
382
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
  # 최종 답변 생성을 위한 프롬프트 추가
384
  final_answer_prompt = """
385
- 지금까지의 추론 과정을 바탕으로, 원래 질문에 사용된 언어로 답변하겠습니다:
 
386
  {question}
387
 
388
- 아래는 내가 추론한 결론입니다:
389
  {reasoning_conclusion}
390
 
391
- 추론을 기반으로 최종 답변:
392
  {ANSWER_MARKER}
393
  """
394
 
@@ -400,50 +227,15 @@ latex_delimiters = [
400
 
401
 
402
  def reformat_math(text):
403
- """Gradio 구문(Katex)을 사용하도록 MathJax 구분 기호 수정."""
 
 
 
404
  text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
405
  text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
406
  return text
407
 
408
 
409
- def extract_keywords(text: str) -> List[str]:
410
- """텍스트에서 간단한 키워드 추출 함수"""
411
- # 간단한 구현 - 실제로는 더 복잡한 NLP 기법을 사용할 수 있음
412
- common_math_keywords = [
413
- "수학", "대수학", "기하학", "산술", "확률", "공식", "방정식",
414
- "함수", "적분", "미분", "기하", "삼각형", "원", "각도", "비율",
415
- "비례", "평균", "분산", "표준편차"
416
- ]
417
-
418
- keywords = []
419
- for kw in common_math_keywords:
420
- if kw in text:
421
- keywords.append(kw)
422
-
423
- return keywords[:5] # 최대 5개 키워드만 반환
424
-
425
-
426
- def get_embedding_for_text(text: str, hidden_size: int = 768) -> torch.Tensor:
427
- """
428
- 텍스트를 위한 임시 임베딩 생성 함수
429
- 실제 구현에서는 적절한 언어 모델을 사용해야 함
430
- """
431
- # 임시 구현: 텍스트의 해시 값을 기반으로 한 임베딩
432
- device = "cuda" if torch.cuda.is_available() else "cpu"
433
- hash_val = hash(text)
434
- np.random.seed(hash_val)
435
-
436
- # 임의의 임베딩 생성
437
- embedding = np.random.rand(1, hidden_size).astype(np.float32)
438
-
439
- # 정규화
440
- norm = np.linalg.norm(embedding)
441
- if norm > 0:
442
- embedding = embedding / norm
443
-
444
- return torch.tensor(embedding, device=device)
445
-
446
-
447
  def user_input(message, history_original, history_thinking):
448
  """사용자 입력을 히��토리에 추가하고 입력 텍스트 상자 비우기"""
449
  return "", history_original + [
@@ -468,59 +260,39 @@ def rebuild_messages(history: list):
468
  return messages
469
 
470
 
471
- # 모델과 버퍼 매니저 초기화 함수
472
- def initialize_model_and_manager(model_name):
473
- """모델과 버퍼 매니저 초기화 함수"""
474
- try:
475
- pipe = pipeline(
476
- "text-generation",
477
- model=model_name,
478
- device_map="auto",
479
- torch_dtype="auto",
480
- )
481
-
482
- # 모델 구성에서 레이어 및 은닉 크기 정보 추출
483
- config = pipe.model.config
484
- if hasattr(config, "n_layer"):
485
- num_layers = config.n_layer
486
- elif hasattr(config, "num_layers"):
487
- num_layers = config.num_layers
488
- elif hasattr(config, "num_hidden_layers"):
489
- num_layers = config.num_hidden_layers
490
- else:
491
- num_layers = 12 # 기본값
492
-
493
- if hasattr(config, "n_embd"):
494
- hidden_size = config.n_embd
495
- elif hasattr(config, "hidden_size"):
496
- hidden_size = config.hidden_size
497
- else:
498
- hidden_size = 768 # 기본값
499
-
500
- # 버퍼 매니저 초기화
501
- buffer_manager = InferenceBufferManager(num_layers, hidden_size)
502
-
503
- return pipe, buffer_manager
504
- except Exception as e:
505
- logger.error(f"모델 초기화 오류: {str(e)}")
506
- raise
507
 
508
 
 
509
  def bot_original(
510
  history: list,
511
  max_num_tokens: int,
512
  do_sample: bool,
513
  temperature: float,
514
- pipe=None
515
  ):
516
  """원본 모델이 질문에 답변하도록 하기 (추론 과정 없이)"""
517
- if pipe is None:
518
- # 이 부분은 실제 구현에서는 전역 변수나 세션 상태로 관리해야 함
519
- return history
520
 
521
  # 나중에 스레드에서 토큰을 스트림으로 가져오기 위함
522
  streamer = transformers.TextIteratorStreamer(
523
- pipe.tokenizer,
524
  skip_special_tokens=True,
525
  skip_prompt=True,
526
  )
@@ -558,23 +330,19 @@ def bot_original(
558
  yield history
559
 
560
 
561
- def bot_thinking_enhanced(
 
562
  history: list,
563
  max_num_tokens: int,
564
  final_num_tokens: int,
565
  do_sample: bool,
566
  temperature: float,
567
- pipe=None,
568
- buffer_manager=None
569
  ):
570
- """추론 과정을 포함하여 모델이 질문에 답변하도록 하기 - DeepSeek 기능 통합"""
571
- if pipe is None or buffer_manager is None:
572
- # 이 부분은 실제 구현에서는 전역 변수나 세션 상태로 관리해야 함
573
- return history
574
 
575
  # 나중에 스레드에서 토큰을 스트림으로 가져오기 위함
576
  streamer = transformers.TextIteratorStreamer(
577
- pipe.tokenizer,
578
  skip_special_tokens=True,
579
  skip_prompt=True,
580
  )
@@ -582,26 +350,9 @@ def bot_thinking_enhanced(
582
  # 필요한 경우 추론에 질문을 다시 삽입하기 위함
583
  question = history[-1]["content"]
584
 
585
- # 쿼리 임베딩 생성
586
- query_embedding = get_embedding_for_text(question, buffer_manager.hidden_size)
587
-
588
- # 관련 컨텍스트 검색
589
- relevant_contexts = buffer_manager.get_relevant_context(query_embedding, top_k=3)
590
-
591
- # 키워드 추출 및 그래프 메모리에서 컨텍스트 가져오기
592
- keywords = extract_keywords(question)
593
- graph_contexts = []
594
- for keyword in keywords:
595
- nodes = buffer_manager.graph_memory.search_nodes(keyword)
596
- for node in nodes:
597
- contexts = buffer_manager.graph_memory.get_connected_context(node)
598
- graph_contexts.extend(contexts)
599
-
600
- # 모든 컨텍스트 병합
601
- all_contexts = relevant_contexts + graph_contexts
602
- all_contexts = list(set(all_contexts)) # 중복 제거
603
- all_contexts = all_contexts[:5] # 최대 5개 컨텍스트로 제한
604
-
605
  # 보조자 메시지 준비
606
  history.append(
607
  gr.ChatMessage(
@@ -614,22 +365,32 @@ def bot_thinking_enhanced(
614
  # 현재 채팅에 표시될 추론 과정
615
  messages = rebuild_messages(history)
616
 
617
- # 관련 컨텍스트가 있다면 메시지에 추가
618
- if all_contexts:
619
- context_str = "\n\n관련 컨텍스트:\n" + "\n".join(all_contexts)
620
- messages[-1]["content"] += context_str
621
- history[-1].content += context_str
622
-
623
  # 전체 추론 과정을 저장할 변수
624
  full_reasoning = ""
625
 
626
- # 생성된 토큰 추적을 위한 변수
627
- generated_tokens = []
 
628
 
629
  # 추론 단계 실행
630
  for i, prepend in enumerate(rethink_prepends):
631
  if i > 0:
632
  messages[-1]["content"] += "\n\n"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
633
  messages[-1]["content"] += prepend.format(question=question)
634
 
635
  t = threading.Thread(
@@ -645,61 +406,91 @@ def bot_thinking_enhanced(
645
  t.start()
646
 
647
  # 새 내용으로 히스토리 재구성
 
 
 
 
648
  history[-1].content += prepend.format(question=question)
649
- step_tokens = []
650
 
651
  for token in streamer:
652
  history[-1].content += token
653
  history[-1].content = reformat_math(history[-1].content)
654
- step_tokens.append(token)
655
- generated_tokens.append(token)
656
  yield history
657
  t.join()
658
 
659
  # 각 추론 단계의 결과를 full_reasoning에 저장
660
  full_reasoning = history[-1].content
661
 
662
- # 추론이 길어지면 중간 요약 생성
663
- if i > 0 and i % 3 == 0 and len(generated_tokens) > 500:
664
- try:
665
- summary = buffer_manager.summarizer.summarize_text(full_reasoning, max_length=150)
666
- summary_text = f"\n\n**중간 요약:**\n{summary}\n\n"
667
- history[-1].content += summary_text
668
- messages[-1]["content"] += summary_text
669
- yield history
670
- except Exception as e:
671
- logger.error(f"요약 생성 오류: {str(e)}")
672
 
673
- # KV 캐시 압축
674
- if i > 0 and i % 2 == 0:
675
- buffer_manager.compress_all_buffers()
676
-
677
- # 시맨틱 컨텍스트 업데이트
678
- step_text = "".join(step_tokens)
679
- step_embedding = get_embedding_for_text(step_text, buffer_manager.hidden_size)
680
- buffer_manager.semantic_memory.add_memory(step_text, step_embedding)
681
-
682
 
683
-
684
- # 추론 완료, 이제 최종 답변을 생성
685
  history[-1].metadata = {"title": "💭 사고 과정", "status": "done"}
686
 
687
- # 추론 과정을 시맨틱 메모리와 그래프 메모리에 저장
688
- full_embedding = get_embedding_for_text(full_reasoning, buffer_manager.hidden_size)
689
- buffer_manager.semantic_memory.add_memory(full_reasoning, full_embedding)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
690
 
691
- # 키워드에 대한 그래프 메모리 업데이트
692
- for keyword in keywords:
693
- if not buffer_manager.graph_memory.has_node(keyword):
694
- buffer_manager.graph_memory.add_node(keyword, f"{keyword}에 관한 개념: 이 주제에 대한 추론을 수행했습니다.")
695
- # 관련 노드와 연결
696
- for related_kw in keywords:
697
- if related_kw != keyword and buffer_manager.graph_memory.has_node(related_kw):
698
- buffer_manager.graph_memory.add_edge(keyword, related_kw)
699
 
700
- # 추론 과정에서 결론 부분을 추출 (마지막 1-2 문단 정도)
701
- reasoning_parts = full_reasoning.split("\n\n")
702
- reasoning_conclusion = "\n\n".join(reasoning_parts[-2:]) if len(reasoning_parts) > 2 else full_reasoning
 
 
 
 
703
 
704
  # 최종 답변 메시지 추가
705
  history.append(gr.ChatMessage(role="assistant", content=""))
@@ -711,7 +502,7 @@ def bot_thinking_enhanced(
711
  reasoning_conclusion=reasoning_conclusion,
712
  ANSWER_MARKER=ANSWER_MARKER
713
  )
714
- final_messages[-1]["content"] += final_prompt
715
 
716
  # 최종 답변 생성
717
  t = threading.Thread(
@@ -721,271 +512,115 @@ def bot_thinking_enhanced(
721
  max_new_tokens=final_num_tokens,
722
  streamer=streamer,
723
  do_sample=do_sample,
724
- temperature=temperature,
725
  ),
726
  )
727
  t.start()
728
 
729
  # 최종 답변 스트리밍
730
- final_tokens = []
731
  for token in streamer:
732
  history[-1].content += token
733
  history[-1].content = reformat_math(history[-1].content)
734
- final_tokens.append(token)
735
  yield history
736
  t.join()
737
-
738
- # 최종 답변을 시맨틱 메모리에 저장
739
- final_text = "".join(final_tokens)
740
- final_embedding = get_embedding_for_text(final_text, buffer_manager.hidden_size)
741
- buffer_manager.semantic_memory.add_memory(final_text, final_embedding)
742
-
743
- # 주기적 메모리 요약 체크
744
- buffer_manager.maybe_summarize_memory()
745
 
746
  yield history
747
 
748
 
749
- with gr.Blocks(fill_height=True, title="Enhanced ThinkFlow") as demo:
750
  # 제목과 설명
751
- gr.Markdown("# Enhanced ThinkFlow with DeepSeek Features")
752
- gr.Markdown("### 시맨틱 메모리, 그래프 메모리, KV 캐시 압축을 통해 향상된 LLM 추론 생성 플랫폼")
753
-
754
- # 모델 및 버퍼 매니저 초기화 (실제 구현에서는 세션 상태로 관리)
755
- model_name = "CohereForAI/c4ai-command-r7b-arabic-02-2025"
756
-
757
- # 세션 변수 (실제 구현에서는 gr.State() 사용)
758
- pipe = None
759
- buffer_manager = None
760
- current_contexts = []
761
-
762
- # 탭 인터페이스
763
- with gr.Tabs() as tabs:
764
- # 채팅 탭
765
- with gr.TabItem("통합 추론 인터페이스"):
766
- with gr.Row(scale=1):
767
- with gr.Column(scale=2):
768
- gr.Markdown("## Before (Original)")
769
- chatbot_original = gr.Chatbot(
770
- scale=1,
771
- type="messages",
772
- latex_delimiters=latex_delimiters,
773
- label="Original Model (No Reasoning)"
774
- )
775
-
776
- with gr.Column(scale=2):
777
- gr.Markdown("## After (Enhanced Thinking)")
778
- chatbot_thinking = gr.Chatbot(
779
- scale=1,
780
- type="messages",
781
- latex_delimiters=latex_delimiters,
782
- label="Model with Enhanced Reasoning"
783
- )
784
-
785
- with gr.Row():
786
- # msg 텍스트박스를 먼저 정의
787
- msg = gr.Textbox(
788
- submit_btn=True,
789
- label="",
790
- show_label=False,
791
- placeholder="여기에 질문을 입력하세요.",
792
- autofocus=True,
793
- )
794
-
795
- # 피드백 버튼
796
- with gr.Row():
797
- with gr.Column(scale=1):
798
- feedback_btn_pos = gr.Button("👍 이 추론이 도움이 되었습니다")
799
- with gr.Column(scale=1):
800
- feedback_btn_neg = gr.Button("👎 이 추론은 개선이 필요합니다")
801
- with gr.Column(scale=1):
802
- clear_memory_btn = gr.Button("🧹 메모리 초기화")
803
-
804
- # 메모리 시각화 탭
805
- with gr.TabItem("메모리 시각화"):
806
- gr.Markdown("## 시맨틱 메모리 내용")
807
- semantic_memory_display = gr.Textbox(
808
- label="현재 시맨틱 메모리 내용",
809
- placeholder="아직 메모리가 없습니다.",
810
- lines=10,
811
- max_lines=20,
812
- interactive=False
813
  )
814
-
815
- gr.Markdown("## 그래프 지식베이스")
816
- graph_memory_display = gr.Textbox(
817
- label="현재 그래프 메모리 내용",
818
- placeholder="아직 그래프 노드가 없습니다.",
819
- lines=10,
820
- max_lines=20,
821
- interactive=False
822
  )
823
 
 
 
 
 
 
 
 
 
 
 
824
  # 예제 섹션 - msg 변수 정의 이후에 배치
825
  with gr.Accordion("EXAMPLES", open=False):
826
  examples = gr.Examples(
827
  examples=[
828
  "[출처: MATH-500)] 처음 100개의 양의 정수 중에서 3, 4, 5로 나누어 떨어지는 수는 몇 개입니까?",
829
- "[출처: MATH-500)] 잉크의 땅에서 돈 시스템은 독특합니다. 트링킷 1개는 블링킷 4개와 같고, 블링킷 3개는 드링크 7개와 같습니다. 트링킷에서 드링크 56개의 가치는 얼마입니까?",
830
  "[출처: MATH-500)] 에이미, 벤, 크리스의 평균 나이는 6살입니다. 4년 전 크리스는 지금 에이미와 같은 나이였습니다. 4년 후 벤의 나이는 그때 에이미의 나이의 $\\frac{3}{5}$가 될 것입니다. 크리스는 지금 몇 살입니까?",
831
- "[출처: MATH-500)] 노란색과 파란색 구슬이 들어 있는 가방이 있습니다. 현재 파란색 구슬과 노란색 구슬의 비율은 4:3입니다. 파란색 구슬 5개를 더하고 노란색 구슬 3개를 제거하면 비율은 7:3이 됩니다. 더 넣기 전에 가방에 파란색 구슬이 몇 개 있었습니까?"
 
832
  ],
833
  inputs=msg
834
  )
835
 
836
- with gr.Accordion("매개변수 조정", open=False):
837
- with gr.Row():
838
- with gr.Column():
839
- model_dropdown = gr.Dropdown(
840
- ["CohereForAI/c4ai-command-r7b-arabic-02-2025", "meta-llama/Meta-Llama-3-8B-Instruct"],
841
- label="모델 선택",
842
- value="CohereForAI/c4ai-command-r7b-arabic-02-2025"
843
- )
844
-
845
- num_tokens = gr.Slider(
846
- 50,
847
- 4000,
848
- 2000,
849
- step=1,
850
- label="추론 단계당 최대 토큰 수",
851
- interactive=True,
852
- )
853
- final_num_tokens = gr.Slider(
854
- 50,
855
- 4000,
856
- 2000,
857
- step=1,
858
- label="최종 답변의 최대 토큰 수",
859
- interactive=True,
860
- )
861
-
862
- with gr.Column():
863
- do_sample = gr.Checkbox(True, label="샘플링 사용")
864
- temperature = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="온도")
865
- memory_weight = gr.Slider(0.0, 1.0, 0.5, step=0.1, label="메모리 반영 가중치")
866
-
867
- # 피드백 처리 함수
868
- def process_positive_feedback():
869
- global buffer_manager, current_contexts
870
- if buffer_manager:
871
- buffer_manager.update_retrieval_reward(current_contexts, reward=1.0)
872
- return "피드백 감사합니다! 이 접근 방식을 향후 유사한 질문에 더 자주 사용하겠습니다."
873
-
874
- def process_negative_feedback():
875
- global buffer_manager, current_contexts
876
- if buffer_manager:
877
- buffer_manager.update_retrieval_reward(current_contexts, reward=-0.5)
878
- return "피드백 감사합니다! 이 접근 방식을 개선하겠습니다."
879
-
880
- def clear_memory():
881
- global buffer_manager
882
- if buffer_manager:
883
- buffer_manager.clear()
884
- return "메모리가 초기화되었습니다."
885
-
886
- def update_memory_displays():
887
- global buffer_manager
888
- if not buffer_manager:
889
- return "메모리가 초기화되지 않았습니다.", "그래프가 초기화되지 않았습니다."
890
-
891
- semantic_text = "현재 저장된 메모리:\n\n"
892
- for i, mem in enumerate(buffer_manager.semantic_memory.memories[:5]): # 최대 5개만 표시
893
- semantic_text += f"{i+1}. {mem['text'][:100]}...\n\n"
894
-
895
- graph_text = "현재 그래프 노드:\n\n"
896
- for node in buffer_manager.graph_memory.graph.nodes():
897
- node_text = buffer_manager.graph_memory.get_text_by_node(node)
898
- neighbors = list(buffer_manager.graph_memory.graph.neighbors(node))
899
- graph_text += f"노드: {node}\n설명: {node_text[:50]}...\n연결: {', '.join(neighbors[:3])}\n\n"
900
-
901
- return semantic_text, graph_text
902
-
903
- # 초기화 함수
904
- def initialize_models():
905
- global pipe, buffer_manager, model_name
906
- try:
907
- pipe, buffer_manager = initialize_model_and_manager(model_name)
908
- semantic_text, graph_text = update_memory_displays()
909
- return "모델이 초기화되었습니다.", semantic_text, graph_text
910
- except Exception as e:
911
- return f"모델 초기화 오류: {str(e)}", "", ""
912
-
913
- # 모델 선택 변경 시 처리
914
- def change_model(new_model_name):
915
- global model_name
916
- model_name = new_model_name
917
- status, semantic_text, graph_text = initialize_models()
918
- return status, semantic_text, graph_text
919
-
920
-
921
-
922
- # 초기화 함수 실행
923
- model_dropdown.change(
924
- change_model,
925
- [model_dropdown],
926
- [gr.Textbox(visible=False), semantic_memory_display, graph_memory_display]
927
- )
928
-
929
- # 피드백 버튼에 함수 연결
930
- feedback_btn_pos.click(process_positive_feedback, [], gr.Textbox(visible=False))
931
- feedback_btn_neg.click(process_negative_feedback, [], gr.Textbox(visible=False))
932
- clear_memory_btn.click(clear_memory, [], gr.Textbox(visible=False))
933
-
934
- # 탭 변경 시 메모리 디스플레이 업데이트
935
- tabs.change(update_memory_displays, [], [semantic_memory_display, graph_memory_display])
936
-
937
  # 사용자가 메시지를 제출하면 두 봇이 동시에 응답합니다
938
  msg.submit(
939
  user_input,
940
  [msg, chatbot_original, chatbot_thinking], # 입력
941
  [msg, chatbot_original, chatbot_thinking], # 출력
942
  ).then(
943
- lambda h, n, d, t, p: bot_original(h, n, d, t, p), # pipe 매개변수 추가
944
  [
945
- chatbot_original,
946
  num_tokens,
947
  do_sample,
948
  temperature,
949
- gr.Textbox(value=lambda: pipe, visible=False), # pipe 전달
950
  ],
951
  chatbot_original, # 출력에서 새 히스토리 저장
952
  ).then(
953
- lambda h, n, f, d, t, p, b: bot_thinking_enhanced(h, n, f, d, t, p, b), # 매개변수 추가
954
  [
955
  chatbot_thinking,
956
  num_tokens,
957
- final_num_tokens,
958
  do_sample,
959
  temperature,
960
- gr.Textbox(value=lambda: pipe, visible=False), # pipe 전달
961
- gr.Textbox(value=lambda: buffer_manager, visible=False), # buffer_manager 전달
962
  ],
963
  chatbot_thinking, # 출력에서 새 히스토리 저장
964
- ).then(
965
- update_memory_displays,
966
- [],
967
- [semantic_memory_display, graph_memory_display]
968
  )
969
 
970
- # 시작 시 모델 초기화를 위한 코드
971
- def load_on_startup():
972
- global pipe, buffer_manager
973
- try:
974
- # 기본 모델 초기화
975
- pipe, buffer_manager = initialize_model_and_manager(
976
- "CohereForAI/c4ai-command-r7b-arabic-02-2025"
977
- )
978
- logger.info("모델 및 버퍼 매니저가 성공적으로 초기화되었습니다.")
979
- except Exception as e:
980
- logger.error(f"시작 시 모델 초기화 실패: {str(e)}")
981
-
982
  if __name__ == "__main__":
983
- # 응용 프로그램 시작 전에 모델 초기화
984
- load_on_startup()
985
-
986
- # 대기열 및 서버 시작
987
- demo.queue().launch(
988
- share=False,
989
- debug=True,
990
- title="Enhanced ThinkFlow with DeepSeek Features"
991
- )
 
1
  import re
2
  import threading
3
+ from collections import Counter
 
 
 
 
 
 
 
4
 
5
  import gradio as gr
6
+ import spaces
7
  import transformers
8
+ from transformers import pipeline
9
+
10
+ # 모델과 토크나이저 로딩
11
+ model_name = "CohereForAI/c4ai-command-r7b-arabic-02-2025"
12
+ if gr.NO_RELOAD:
13
+ pipe = pipeline(
14
+ "text-generation",
15
+ model=model_name,
16
+ device_map="auto",
17
+ torch_dtype="auto",
18
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  # 최종 답변을 감지하기 위한 마커
21
  ANSWER_MARKER = "**답변**"
 
32
  "이제 충분히 이해했다고 생각합니다 ",
33
  ]
34
 
35
+ # 일반적인 추론 가이드 프롬프트
36
+ general_reasoning_guide = """
37
+ 이 문제를 해결하기 위한 체계적인 접근 방법을 사용하겠습니다:
38
+
39
+ 1. 문제에서 제공된 모든 정보와 조건을 명확히 이해합니다.
40
+ 2. 각 변수와 관계를 식별하고 필요한 방정식을 세웁니다.
41
+ 3. 단계별로 계산을 수행하며, 각 단계의 결과를 확인합니다.
42
+ 4. 중간 결과가 합리적인지 검토하며 진행합니다.
43
+ 5. 최종 답안을 도출하고 문제의 요구사항을 충족하는지 확인합니다.
44
+
45
+ 이제 문제를 풀어보겠습니다:
46
+ """
47
+
48
+ # 결과 추출 및 검증을 위한 함수들
49
+ def extract_calculation_results(reasoning_text):
50
+ """추론 과정에서 도출된 가능한 답안 결과를 추출합니다."""
51
+ # 수치 결과 패턴 (다양한 표현 방식 고려)
52
+ numeric_patterns = [
53
+ r'결과는 (\d+[\.,]?\d*)',
54
+ r'답(은|는|이) (\d+[\.,]?\d*)',
55
+ r'정답(은|는|이) (\d+[\.,]?\d*)',
56
+ r'답안(은|는|이) (\d+[\.,]?\d*)',
57
+ r'수익(은|는|이) (\d+[\.,]?\d*)',
58
+ r'값(은|는|이) (\d+[\.,]?\d*)',
59
+ r'결론(은|는|이) (\d+[\.,]?\d*)',
60
+ r'개수(는|은|가) (\d+[\.,]?\d*)',
61
+ r'총 (\d+[\.,]?\d*)개',
62
+ r'총액(은|는|이) (\d+[\.,]?\d*)',
63
+ r'총합(은|는|이) (\d+[\.,]?\d*)',
64
+ r'합계(는|은|가) (\d+[\.,]?\d*)',
65
+ r'=\s*(\d+[\.,]?\d*)\s*$',
66
+ r':\s*(\d+[\.,]?\d*)\s*$',
67
+ r'총계:\s*(\d+[\.,]?\d*)',
68
+ r'최종 결과:\s*(\d+[\.,]?\d*)',
69
+ r'최종 값:\s*(\d+[\.,]?\d*)',
70
+ r'최종 답변:\s*(\d+[\.,]?\d*)',
71
+ ]
72
+
73
+ # 단위를 포함한 패턴 (달러, 개, 세트 등)
74
+ unit_patterns = [
75
+ r'(\d+[\.,]?\d*)\s*(달러|원|유로|파운드|엔)',
76
+ r'(\d+[\.,]?\d*)\s*(개|명|세트|쌍|팀|그룹)',
77
+ r'(\d+[\.,]?\d*)\s*(분|시간|초|일|주|개월|년)',
78
+ r'(\d+[\.,]?\d*)\s*(미터|킬로미터|센티미터|인치|피트)',
79
+ r'(\d+[\.,]?\d*)\s*(그램|킬로그램|파운드|온스)',
80
+ ]
81
+
82
+ results = []
83
+
84
+ # 숫자 결과 추출
85
+ for pattern in numeric_patterns:
86
+ matches = re.findall(pattern, reasoning_text, re.IGNORECASE)
87
+ for match in matches:
88
+ if isinstance(match, tuple):
89
+ # 그룹이 여러 개인 경우 (첫 번째는 조사 등)
90
+ value = match[-1] # 마지막 그룹이 숫자값
91
+ else:
92
+ value = match
93
+ # 콤마 제거 및 소수점 처리
94
+ value = value.replace(',', '')
95
+ try:
96
+ if '.' in value:
97
+ results.append(float(value))
98
+ else:
99
+ results.append(int(value))
100
+ except ValueError:
101
+ continue
102
+
103
+ # 단위가 포함된 결과 추출
104
+ for pattern in unit_patterns:
105
+ matches = re.findall(pattern, reasoning_text, re.IGNORECASE)
106
+ for match in matches:
107
+ value = match[0].replace(',', '')
108
+ try:
109
+ if '.' in value:
110
+ results.append(float(value))
111
+ else:
112
+ results.append(int(value))
113
+ except ValueError:
114
+ continue
115
+
116
+ # 마지막 문단에서 숫자만 추출 (최종 답변에 가까운 숫자)
117
+ last_paragraph = reasoning_text.split('\n\n')[-1]
118
+ numbers_in_last = re.findall(r'(\d+[\.,]?\d*)', last_paragraph)
119
+ for num in numbers_in_last:
120
+ num = num.replace(',', '')
121
+ try:
122
+ if '.' in num:
123
+ results.append(float(num))
124
+ else:
125
+ results.append(int(num))
126
+ except ValueError:
127
+ continue
128
+
129
+ return results
130
+
131
+ def determine_best_result(results, full_reasoning):
132
+ """가장 신뢰할 수 있는 결과를 결정합니다."""
133
+ if not results:
134
+ return None
135
+
136
+ # 결과가 하나밖에 없으면 그것을 반환
137
+ if len(set(results)) == 1:
138
+ return results[0]
139
+
140
+ # 빈도 기반 분석 (가장 자주 등장한 결��가 신뢰성이 높을 가능성)
141
+ counter = Counter(results)
142
+ most_common = counter.most_common()
143
+
144
+ # 빈도가 높은 상위 결과들
145
+ top_results = [result for result, count in most_common if count >= most_common[0][1] * 0.8]
146
+
147
+ if len(top_results) == 1:
148
+ return top_results[0]
149
+
150
+ # 최종 결론 근처에 있는 결과에 더 높은 가중치 부여
151
+ paragraphs = full_reasoning.split('\n\n')
152
+ last_paragraphs = '\n\n'.join(paragraphs[-2:]) # 마지막 두 단락
153
+
154
+ # 마지막 단락에서 등장하는 결과 확인
155
+ final_results = [result for result in top_results if str(result) in last_paragraphs]
156
+ if final_results:
157
+ # 마지막 단락에서 가장 자주 등장한 결과
158
+ final_counter = Counter([r for r in results if r in final_results])
159
+ if final_counter:
160
+ return final_counter.most_common(1)[0][0]
161
+
162
+ # 수식과 함께 등장하는 결과 (예: "= 78", "총합: 78")
163
+ for result in top_results:
164
+ result_str = str(result)
165
+ if re.search(r'=\s*' + result_str + r'(?!\d)', full_reasoning) or \
166
+ re.search(r'결과[:는은이가]\s*' + result_str, full_reasoning) or \
167
+ re.search(r'답[:는은이가]\s*' + result_str, full_reasoning) or \
168
+ re.search(r'정답[:는은이가]\s*' + result_str, full_reasoning):
169
+ return result
170
+
171
+ # 위의 방법으로 결정할 수 없을 경우 가장 빈도가 높은 결과 반환
172
+ return most_common[0][0]
173
+
174
+ # 중간 결과를 요약하기 위한 프롬프트
175
+ structured_reasoning_prompt = """
176
+ 지금까지의 추론을 단계별로 정리해보겠습니다:
177
+
178
+ 1. 문제 분석:
179
+ - 주어진 정보: {given_info}
180
+ - 구해야 할 것: {goal}
181
+
182
+ 2. 계산 과정:
183
+ {calculation_steps}
184
+
185
+ 3. 현재까지의 결론:
186
+ {current_conclusion}
187
+
188
+ 이제 다음 단계로 진행하겠습니다.
189
+ """
190
+
191
+ # 최종 결과 검증을 위한 프롬프트
192
+ verification_prompt = """
193
+ 지금까지의 추론 과정에서 여러 결과가 도출되었습니다:
194
+ {different_results}
195
+
196
+ 이 중에서 가장 정확한 답변을 찾기 위해 계산 과정을 처음부터 다시 검토하겠습니다:
197
+
198
+ 1. 문제 분석:
199
+ - 주어진 정보: {given_info}
200
+ - 구해야 할 것: {goal}
201
+
202
+ 2. 단계별 계산 과정:
203
+ {calculation_steps}
204
+
205
+ 3. 결론:
206
+ 위 계산 과정을 통해 정확한 답은 {result}입니다.
207
+ """
208
+
209
  # 최종 답변 생성을 위한 프롬프트 추가
210
  final_answer_prompt = """
211
+ 지금까지의 체계적인 추론 과정을 종합하여, 원래 질문에 답변하겠습니다:
212
+
213
  {question}
214
 
215
+ 추론 과정을 검토한 결과, 다음과 같은 결론에 도달했습니다:
216
  {reasoning_conclusion}
217
 
218
+ 따라서 최종 답변은:
219
  {ANSWER_MARKER}
220
  """
221
 
 
227
 
228
 
229
  def reformat_math(text):
230
+ """Gradio 구문(Katex)을 사용하도록 MathJax 구분 기호 수정.
231
+ 이것은 Gradio에서 수학 공식을 표시하기 위한 임시 해결책입니다. 현재로서는
232
+ 다른 latex_delimiters를 사용하여 예상대로 작동하게 하는 방법을 찾지 못했습니다...
233
+ """
234
  text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
235
  text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
236
  return text
237
 
238
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  def user_input(message, history_original, history_thinking):
240
  """사용자 입력을 히��토리에 추가하고 입력 텍스트 상자 비우기"""
241
  return "", history_original + [
 
260
  return messages
261
 
262
 
263
+ def extract_info_from_question(question):
264
+ """문제에서 주어진 정보와 목표를 추출합니다."""
265
+ # 기본
266
+ given_info = "문제에서 제공된 모든 조건과 수치"
267
+ goal = "문제에서 요구하는 값이나 결과"
268
+
269
+ # 일반적인 정보 추출 패턴
270
+ if "몇 개" in question or "개수" in question:
271
+ goal = "특정 조건을 만족하는 항목의 개수"
272
+ elif "얼마" in question:
273
+ goal = "특정 값 또는 금액"
274
+ elif "나이" in question:
275
+ goal = "사람의 나이"
276
+ elif "확률" in question:
277
+ goal = "특정 사건의 확률"
278
+ elif "평균" in question:
279
+ goal = "값들의 평균"
280
+
281
+ return given_info, goal
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
 
284
+ @spaces.GPU
285
  def bot_original(
286
  history: list,
287
  max_num_tokens: int,
288
  do_sample: bool,
289
  temperature: float,
 
290
  ):
291
  """원본 모델이 질문에 답변하도록 하기 (추론 과정 없이)"""
 
 
 
292
 
293
  # 나중에 스레드에서 토큰을 스트림으로 가져오기 위함
294
  streamer = transformers.TextIteratorStreamer(
295
+ pipe.tokenizer, # pyright: ignore
296
  skip_special_tokens=True,
297
  skip_prompt=True,
298
  )
 
330
  yield history
331
 
332
 
333
+ @spaces.GPU
334
+ def bot_thinking(
335
  history: list,
336
  max_num_tokens: int,
337
  final_num_tokens: int,
338
  do_sample: bool,
339
  temperature: float,
 
 
340
  ):
341
+ """추론 과정을 포함하여 모델이 질문에 답변하도록 하기"""
 
 
 
342
 
343
  # 나중에 스레드에서 토큰을 스트림으로 가져오기 위함
344
  streamer = transformers.TextIteratorStreamer(
345
+ pipe.tokenizer, # pyright: ignore
346
  skip_special_tokens=True,
347
  skip_prompt=True,
348
  )
 
350
  # 필요한 경우 추론에 질문을 다시 삽입하기 위함
351
  question = history[-1]["content"]
352
 
353
+ # 문제에서 주어진 정보와 목표 추출
354
+ given_info, goal = extract_info_from_question(question)
355
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
  # 보조자 메시지 준비
357
  history.append(
358
  gr.ChatMessage(
 
365
  # 현재 채팅에 표시될 추론 과정
366
  messages = rebuild_messages(history)
367
 
 
 
 
 
 
 
368
  # 전체 추론 과정을 저장할 변수
369
  full_reasoning = ""
370
 
371
+ # 추론 과정에서 수집된 계산 단계 저장
372
+ calculation_steps = ""
373
+ current_conclusion = "아직 최종 결론에 도달하지 않았습니다."
374
 
375
  # 추론 단계 실행
376
  for i, prepend in enumerate(rethink_prepends):
377
  if i > 0:
378
  messages[-1]["content"] += "\n\n"
379
+
380
+ # 첫 단계에서 일반적인 추론 가이드 추가
381
+ if i == 0:
382
+ messages[-1]["content"] += general_reasoning_guide + "\n\n"
383
+
384
+ # 중간 단계에서 구조화된 추론 요약 추가
385
+ if i > 1 and calculation_steps:
386
+ structured_summary = structured_reasoning_prompt.format(
387
+ given_info=given_info,
388
+ goal=goal,
389
+ calculation_steps=calculation_steps,
390
+ current_conclusion=current_conclusion
391
+ )
392
+ messages[-1]["content"] += structured_summary + "\n\n"
393
+
394
  messages[-1]["content"] += prepend.format(question=question)
395
 
396
  t = threading.Thread(
 
406
  t.start()
407
 
408
  # 새 내용으로 히스토리 재구성
409
+ if i == 0:
410
+ history[-1].content += general_reasoning_guide + "\n\n"
411
+ if i > 1 and calculation_steps:
412
+ history[-1].content += structured_summary + "\n\n"
413
  history[-1].content += prepend.format(question=question)
 
414
 
415
  for token in streamer:
416
  history[-1].content += token
417
  history[-1].content = reformat_math(history[-1].content)
 
 
418
  yield history
419
  t.join()
420
 
421
  # 각 추론 단계의 결과를 full_reasoning에 저장
422
  full_reasoning = history[-1].content
423
 
424
+ # 계산 단계 추출 업데이트
425
+ new_content = history[-1].content.split(prepend.format(question=question))[-1]
426
+ if "=" in new_content or ":" in new_content:
427
+ # 계산 단계가 있는 것으로 간주
428
+ calculation_steps += f"\n - {new_content.strip()}"
 
 
 
 
 
429
 
430
+ # 단계에서 가능한 결론 추출
431
+ results = extract_calculation_results(new_content)
432
+ if results:
433
+ current_conclusion = f"현재 계산된 값: {results[-1]}"
 
 
 
 
 
434
 
435
+ # 추론 완료, 이제 최종 답변을 생성
 
436
  history[-1].metadata = {"title": "💭 사고 과정", "status": "done"}
437
 
438
+ # 추론 과정에서 도출된 모든 결과 추출
439
+ all_results = extract_calculation_results(full_reasoning)
440
+
441
+ # 결과가 있는 경우 검증 단계 추가
442
+ if all_results and len(set(all_results)) > 1:
443
+ # 결과별 빈도 계산
444
+ result_counter = Counter(all_results)
445
+ different_results = "\n".join([f"{result} (빈도: {freq}회)" for result, freq in result_counter.most_common()])
446
+
447
+ # 최적의 결과 결정
448
+ best_result = determine_best_result(all_results, full_reasoning)
449
+
450
+ # 모델에게 가장 정확한 결과 선택 요청
451
+ verify_prompt = verification_prompt.format(
452
+ different_results=different_results,
453
+ given_info=given_info,
454
+ goal=goal,
455
+ calculation_steps=calculation_steps,
456
+ result=best_result
457
+ )
458
+ messages[-1]["content"] += "\n\n" + verify_prompt
459
+
460
+ # 검증 단계 실행
461
+ t = threading.Thread(
462
+ target=pipe,
463
+ args=(messages,),
464
+ kwargs=dict(
465
+ max_new_tokens=max_num_tokens // 2,
466
+ streamer=streamer,
467
+ do_sample=False, # 확정적인 결과를 위해 샘플링 비활성화
468
+ temperature=0.3, # 낮은 온도 사용
469
+ ),
470
+ )
471
+ t.start()
472
+
473
+ history[-1].content += "\n\n" + verify_prompt
474
+ for token in streamer:
475
+ history[-1].content += token
476
+ history[-1].content = reformat_math(history[-1].content)
477
+ yield history
478
+ t.join()
479
+
480
+ # 검증 단계 후 full_reasoning 업데이트
481
+ full_reasoning = history[-1].content
482
 
483
+ # 최종 결과 결정
484
+ final_results = extract_calculation_results(full_reasoning)
485
+ best_result = determine_best_result(final_results, full_reasoning) if final_results else None
 
 
 
 
 
486
 
487
+ # 최종 결론 생성
488
+ if best_result is not None:
489
+ reasoning_conclusion = f"추론 과정을 종합한 결과, 정확한 답변은 {best_result}입니다."
490
+ else:
491
+ # 결과를 추출할 수 없는 경우의 대비책
492
+ reasoning_parts = full_reasoning.split("\n\n")
493
+ reasoning_conclusion = "\n\n".join(reasoning_parts[-2:]) if len(reasoning_parts) > 2 else full_reasoning
494
 
495
  # 최종 답변 메시지 추가
496
  history.append(gr.ChatMessage(role="assistant", content=""))
 
502
  reasoning_conclusion=reasoning_conclusion,
503
  ANSWER_MARKER=ANSWER_MARKER
504
  )
505
+ final_messages[-1]["content"] += "\n\n" + final_prompt
506
 
507
  # 최종 답변 생성
508
  t = threading.Thread(
 
512
  max_new_tokens=final_num_tokens,
513
  streamer=streamer,
514
  do_sample=do_sample,
515
+ temperature=temperature * 0.8, # 최종 답변에 더 확신을 주기 위해 온도 약간 낮춤
516
  ),
517
  )
518
  t.start()
519
 
520
  # 최종 답변 스트리밍
 
521
  for token in streamer:
522
  history[-1].content += token
523
  history[-1].content = reformat_math(history[-1].content)
 
524
  yield history
525
  t.join()
 
 
 
 
 
 
 
 
526
 
527
  yield history
528
 
529
 
530
+ with gr.Blocks(fill_height=True, title="Vidraft ThinkFlow") as demo:
531
  # 제목과 설명
532
+ gr.Markdown("# Vidraft ThinkFlow")
533
+ gr.Markdown("### 추론 기능이 없는 LLM 모델의 수정 없이도 추론 기능을 자동으로 적용하는 LLM 추론 생성 플랫폼")
534
+
535
+ with gr.Row(scale=1):
536
+ with gr.Column(scale=2):
537
+ gr.Markdown("## Before (Original)")
538
+ chatbot_original = gr.Chatbot(
539
+ scale=1,
540
+ type="messages",
541
+ latex_delimiters=latex_delimiters,
542
+ label="Original Model (No Reasoning)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
543
  )
544
+
545
+ with gr.Column(scale=2):
546
+ gr.Markdown("## After (Thinking)")
547
+ chatbot_thinking = gr.Chatbot(
548
+ scale=1,
549
+ type="messages",
550
+ latex_delimiters=latex_delimiters,
551
+ label="Model with Reasoning"
552
  )
553
 
554
+ with gr.Row():
555
+ # msg 텍스트박스를 먼저 정의
556
+ msg = gr.Textbox(
557
+ submit_btn=True,
558
+ label="",
559
+ show_label=False,
560
+ placeholder="여기에 질문을 입력하세요.",
561
+ autofocus=True,
562
+ )
563
+
564
  # 예제 섹션 - msg 변수 정의 이후에 배치
565
  with gr.Accordion("EXAMPLES", open=False):
566
  examples = gr.Examples(
567
  examples=[
568
  "[출처: MATH-500)] 처음 100개의 양의 정수 중에서 3, 4, 5로 나누어 떨어지는 수는 몇 개입니까?",
569
+ "[출처: MATH-500)] 잉크의 땅에서 돈 시스템은 독특합니다. 트링킛 1개는 블링킷 4개와 같고, 블링킷 3개는 드링크 7개와 같습니다. 트링킷에서 드링크 56개의 가치는 얼마입니까?",
570
  "[출처: MATH-500)] 에이미, 벤, 크리스의 평균 나이는 6살입니다. 4년 전 크리스는 지금 에이미와 같은 나이였습니다. 4년 후 벤의 나이는 그때 에이미의 나이의 $\\frac{3}{5}$가 될 것입니다. 크리스는 지금 몇 살입니까?",
571
+ "[출처: MATH-500)] 노란색과 파란색 구슬이 들어 있는 가방이 있습니다. 현재 파란색 구슬과 노란색 구슬의 비율은 4:3입니다. 파란색 구슬 5개를 더하고 노란색 구슬 3개를 제거하면 비율은 7:3이 됩니다. 더 넣기 전에 가방에 파란색 구슬이 몇 개 있었습니까?",
572
+ "수학 동아리에서 다가올 여행을 위한 기금 모금을 위해 베이킹 세일을 열고 있습니다. 3개에 54달러짜리 쿠키를 1달러에 판매하고, 20개에 컵케이크를 각각 2달러에 판매하고, 35개에 브라우니를 각각 1달러에 판매합니다. 수학 동아리에서 이 제품을 굽는 데 15달러가 들었다면, 수익은 얼마였을까요?"
573
  ],
574
  inputs=msg
575
  )
576
 
577
+ with gr.Row():
578
+ with gr.Column():
579
+ gr.Markdown("""## 매개변수 조정""")
580
+ num_tokens = gr.Slider(
581
+ 50,
582
+ 4000,
583
+ 2000,
584
+ step=1,
585
+ label="추론 단계당 최대 토큰 수",
586
+ interactive=True,
587
+ )
588
+ final_num_tokens = gr.Slider(
589
+ 50,
590
+ 4000,
591
+ 2000,
592
+ step=1,
593
+ label="최종 답변의 최대 토큰 수",
594
+ interactive=True,
595
+ )
596
+ do_sample = gr.Checkbox(True, label="샘플링 사용")
597
+ temperature = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="온도")
598
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
599
  # 사용자가 메시지를 제출하면 두 봇이 동시에 응답합니다
600
  msg.submit(
601
  user_input,
602
  [msg, chatbot_original, chatbot_thinking], # 입력
603
  [msg, chatbot_original, chatbot_thinking], # 출력
604
  ).then(
605
+ bot_original,
606
  [
607
+ chatbot_original,
608
  num_tokens,
609
  do_sample,
610
  temperature,
 
611
  ],
612
  chatbot_original, # 출력에서 새 히스토리 저장
613
  ).then(
614
+ bot_thinking,
615
  [
616
  chatbot_thinking,
617
  num_tokens,
618
+ final_num_tokens,
619
  do_sample,
620
  temperature,
 
 
621
  ],
622
  chatbot_thinking, # 출력에서 새 히스토리 저장
 
 
 
 
623
  )
624
 
 
 
 
 
 
 
 
 
 
 
 
 
625
  if __name__ == "__main__":
626
+ demo.queue().launch() # title 매개변수 제거