import re
import threading
import time
import os
import logging
from datetime import datetime
import torch
import numpy as np
from typing import List, Optional, Tuple, Dict
import networkx as nx
import gradio as gr
import transformers
from transformers import (
pipeline,
AutoModelForCausalLM,
AutoTokenizer,
BartForConditionalGeneration,
BartTokenizer,
BitsAndBytesConfig
)
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ===================== RLRetrievalPolicy =====================
class RLRetrievalPolicy:
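    """Maintains a cumulative reward score per retrieved context and blends it with
    retrieval similarity when re-ranking candidates."""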
def __init__(self):
self.policy_data = {}
        self.alpha = 0.5  # weight between similarity score and RL score
def update_policy(self, contexts: List[str], reward: float):
for ctx in contexts:
if ctx not in self.policy_data:
self.policy_data[ctx] = 0.0
self.policy_data[ctx] += reward
def re_rank(self, candidates: List[Tuple[float, str]]) -> List[str]:
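        """Combine each candidate's similarity score with its accumulated RL reward
        (weighted by self.alpha) and return the texts sorted by the combined score."""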
reweighted = []
for sim, txt in candidates:
rl_score = self.policy_data.get(txt, 0.0)
reweighted_score = sim * (1 - self.alpha) + rl_score * self.alpha
reweighted.append((reweighted_score, txt))
reweighted.sort(key=lambda x: x[0], reverse=True)
return [t for _, t in reweighted]
# ===================== GraphMemory =====================
class GraphMemory:
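    """Small directed-graph knowledge base (networkx DiGraph) seeded with math topics;
    supports keyword search over nodes and breadth-first expansion of connected context."""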
def __init__(self):
self.graph = nx.DiGraph()
        # Add basic nodes that help with solving math problems
        self.add_node("math", "General approaches to solving math problems")
        self.add_node("algebra", "A branch of mathematics dealing with equations, functions, and proportional relationships")
        self.add_node("geometry", "A branch of mathematics dealing with space, shapes, and angles")
        self.add_node("arithmetic", "A field covering basic numeric operations, ratios, and percentages")
        self.add_node("probability", "A branch of mathematics measuring how likely events are to occur")
        # Set up relations
        self.add_edge("algebra", "math")
        self.add_edge("geometry", "math")
        self.add_edge("arithmetic", "math")
        self.add_edge("probability", "math")
def add_node(self, node_id: str, text: str = ""):
self.graph.add_node(node_id, text=text)
def add_edge(self, src: str, dst: str):
self.graph.add_edge(src, dst)
def get_text_by_node(self, node_id: str) -> str:
return self.graph.nodes[node_id].get('text', "")
def has_node(self, node_id: str) -> bool:
return node_id in self.graph.nodes
def search_nodes(self, keyword: str, max_nodes: int = 3) -> List[str]:
matches = []
for n in self.graph.nodes():
node_text = self.get_text_by_node(n).lower()
n_lower = n.lower()
if keyword.lower() in node_text or keyword.lower() in n_lower:
score = node_text.count(keyword.lower()) + n_lower.count(keyword.lower())
matches.append((score, n))
matches.sort(key=lambda x: x[0], reverse=True)
top_nodes = [m[1] for m in matches[:max_nodes]]
return top_nodes
def get_connected_context(self, start_node: str, steps: int = 1) -> List[str]:
contexts = []
visited = set()
queue = [(start_node, 0)]
while queue:
current, depth = queue.pop(0)
if current not in visited:
visited.add(current)
contexts.append(self.get_text_by_node(current))
if depth < steps:
for neighbor in self.graph.successors(current):
queue.append((neighbor, depth + 1))
for neighbor in self.graph.predecessors(current):
queue.append((neighbor, depth + 1))
return contexts
# ===================== SimpleSummarizer =====================
class SimpleSummarizer:
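    """Lazy wrapper around facebook/bart-large-cnn used to summarize accumulated memory text."""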
def __init__(self, model_name="facebook/bart-large-cnn"):
self.model_name = model_name
self.model = None
self.tokenizer = None
def load_summarization_model(self):
if self.model is None:
try:
self.tokenizer = BartTokenizer.from_pretrained(self.model_name)
self.model = BartForConditionalGeneration.from_pretrained(self.model_name)
if torch.cuda.is_available():
self.model = self.model.cuda()
except Exception as e:
logger.error(f"Error loading summarization model: {str(e)}")
raise
def summarize_text(self, text: str, max_length: int = 100) -> str:
try:
self.load_summarization_model()
inputs = self.tokenizer([text], max_length=1024, return_tensors='pt', truncation=True)
if torch.cuda.is_available():
inputs = {k: v.cuda() for k, v in inputs.items()}
with torch.no_grad():
summary_ids = self.model.generate(
inputs["input_ids"],
num_beams=4,
max_length=max_length,
early_stopping=True
)
summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
return summary
except Exception as e:
logger.error(f"Error in summarization: {str(e)}")
return "μμ½μ μμ±ν μ μμ΅λλ€."
# ===================== SemanticMemory =====================
class SemanticMemory:
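    """Bounded list of (text, embedding) entries; retrieval ranks candidates by cosine
    similarity and re-ranks them with the RL policy before returning the top-k contexts."""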
def __init__(self, max_entries: int = 4000):
self.memories: List[dict] = []
self.max_entries = max_entries
self.rl_policy = RLRetrievalPolicy()
def add_memory(self, text: str, embedding: torch.Tensor):
if len(self.memories) >= self.max_entries:
self.memories.pop(0)
self.memories.append({
'text': text,
'embedding': embedding,
'timestamp': time.time()
})
def get_candidates(self, query_embedding: torch.Tensor) -> List[Tuple[float, str]]:
candidates = []
for mem in self.memories:
if mem['embedding'].shape == query_embedding.shape:
sim = torch.cosine_similarity(
query_embedding.float(),
mem['embedding'].float(),
dim=-1
)
candidates.append((sim.item(), mem['text']))
candidates.sort(key=lambda x: x[0], reverse=True)
return candidates
def get_relevant_context(self, query_embedding: torch.Tensor, top_k: int = 3) -> List[str]:
candidates = self.get_candidates(query_embedding)
re_ranked = self.rl_policy.re_rank(candidates)
return re_ranked[:top_k]
def update_retrieval_reward(self, texts: List[str], reward: float):
self.rl_policy.update_policy(texts, reward)
def clear(self):
self.memories = []
# ===================== GenericInferenceBuffer =====================
MAX_TOKEN_BUFFER = 1024
class GenericInferenceBuffer:
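    """Per-layer key/value cache that is appended along the sequence dimension,
    truncated to MAX_TOKEN_BUFFER tokens, and periodically compressed with a low-rank SVD."""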
def __init__(self, layer_idx: int, compression_rank: int = 128):
self.layer_idx = layer_idx
self.key_buffer: Optional[torch.Tensor] = None
self.value_buffer: Optional[torch.Tensor] = None
self.semantic_context: Optional[torch.Tensor] = None
self.last_update: float = 0
self.compression_rank = compression_rank
def update_buffer(
self,
key: torch.Tensor,
value: torch.Tensor,
semantic_context: Optional[torch.Tensor] = None
):
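        """Append new key/value tensors along the sequence dimension (dim=2) and drop the
        oldest entries once the buffer exceeds MAX_TOKEN_BUFFER tokens."""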
try:
if self.key_buffer is None:
self.key_buffer = key.detach().clone()
self.value_buffer = value.detach().clone()
if semantic_context is not None:
self.semantic_context = semantic_context.detach().clone()
else:
self.key_buffer = torch.cat([self.key_buffer, key.detach()], dim=2)
self.value_buffer = torch.cat([self.value_buffer, value.detach()], dim=2)
if semantic_context is not None and self.semantic_context is not None:
self.semantic_context = torch.cat([self.semantic_context, semantic_context.detach()], dim=0)
if self.key_buffer.shape[2] > MAX_TOKEN_BUFFER:
excess = self.key_buffer.shape[2] - MAX_TOKEN_BUFFER
self.key_buffer = self.key_buffer[:, :, excess:, :]
self.value_buffer = self.value_buffer[:, :, excess:, :]
if self.semantic_context is not None:
self.semantic_context = self.semantic_context[excess:, :]
self.last_update = time.time()
except Exception as e:
logger.error(f"Buffer update error in layer {self.layer_idx}: {str(e)}")
def compress_buffer_svd(self):
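        """Compress the key/value buffers with a truncated SVD: flatten each buffer to 2D,
        keep at most `compression_rank` singular values, and reshape back to the original shape."""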
if self.key_buffer is None or self.value_buffer is None:
return
try:
k_shape = self.key_buffer.shape
v_shape = self.value_buffer.shape
k_2d = self.key_buffer.reshape(k_shape[0]*k_shape[1], k_shape[2]*k_shape[3]).float()
v_2d = self.value_buffer.reshape(v_shape[0]*v_shape[1], v_shape[2]*v_shape[3]).float()
device = k_2d.device
k_2d_cpu = k_2d.cpu()
v_2d_cpu = v_2d.cpu()
U_k, S_k, V_k = torch.linalg.svd(k_2d_cpu, full_matrices=False)
U_v, S_v, V_v = torch.linalg.svd(v_2d_cpu, full_matrices=False)
rank_k = min(self.compression_rank, S_k.shape[0])
rank_v = min(self.compression_rank, S_v.shape[0])
k_approx = (U_k[:, :rank_k] * S_k[:rank_k]) @ V_k[:rank_k, :]
v_approx = (U_v[:, :rank_v] * S_v[:rank_v]) @ V_v[:rank_v, :]
k_approx = k_approx.to(device)
v_approx = v_approx.to(device)
self.key_buffer = k_approx.reshape(k_shape).type(self.key_buffer.dtype)
self.value_buffer = v_approx.reshape(v_shape).type(self.value_buffer.dtype)
except Exception as e:
logger.error(f"SVD compression error in layer {self.layer_idx}: {str(e)}")
def get_buffer(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return self.key_buffer, self.value_buffer, self.semantic_context
def clear(self):
self.key_buffer = None
self.value_buffer = None
self.semantic_context = None
self.last_update = 0
# ===================== InferenceBufferManager =====================
class InferenceBufferManager:
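    """Coordinates the per-layer KV buffers, semantic memory, graph memory, and summarizer:
    updates buffers from past_key_values, compresses them periodically, and serves retrieval."""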
def __init__(self, num_layers: int, hidden_size: int):
self.num_layers = num_layers
self.hidden_size = hidden_size
self.layer_buffers = [
GenericInferenceBuffer(i, compression_rank=128) for i in range(num_layers)
]
self.semantic_memory = SemanticMemory()
self.graph_memory = GraphMemory()
self.summarizer = SimpleSummarizer()
self.summarize_threshold = 1500
self.generated_tokens_count = 0
self.compression_interval = 512
self.token_count_since_compress = 0
def _compute_semantic_embedding(self, key: Optional[torch.Tensor], value: Optional[torch.Tensor]) -> torch.Tensor:
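        """Build a coarse semantic embedding by multiplying the key and value buffers,
        averaging over the sequence dimension, flattening, and L2-normalizing the result."""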
device = "cuda" if torch.cuda.is_available() else "cpu"
if key is None or value is None:
return torch.zeros((1, self.hidden_size), dtype=torch.float32, device=device)
combined = key * value
combined = combined.mean(dim=2)
combined = combined.reshape(combined.shape[0], -1)
combined = torch.nn.functional.normalize(combined, dim=-1)
return combined
def update_buffer(self, layer_outputs, current_tokens: List[int], semantic_context: torch.Tensor, tokenizer):
try:
if hasattr(layer_outputs, 'past_key_values'):
for layer_idx, past_kv in enumerate(layer_outputs.past_key_values):
if isinstance(past_kv, tuple) and len(past_kv) == 2:
key, value = past_kv
if key is not None and value is not None:
self.layer_buffers[layer_idx].update_buffer(
key.detach(),
value.detach(),
semantic_context
)
self.generated_tokens_count += len(current_tokens)
self.token_count_since_compress += len(current_tokens)
if self.token_count_since_compress >= self.compression_interval:
self.compress_all_buffers()
self.token_count_since_compress = 0
except Exception as e:
logger.error(f"Buffer update error: {str(e)}")
def compress_all_buffers(self):
for buf in self.layer_buffers:
buf.compress_buffer_svd()
def finalize_semantic_memory(self, tokenizer, generated_tokens: List[int]):
if self.layer_buffers and len(self.layer_buffers) > 0 and self.layer_buffers[-1].key_buffer is not None:
text_chunk = tokenizer.decode(generated_tokens, skip_special_tokens=True)
key_buffer = self.layer_buffers[-1].key_buffer
value_buffer = self.layer_buffers[-1].value_buffer
embedding = self._compute_semantic_embedding(key_buffer, value_buffer)
self.semantic_memory.add_memory(text_chunk, embedding)
def get_relevant_context(self, query_embedding: torch.Tensor, top_k: int = 3) -> List[str]:
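        """Merge semantic-memory candidates with graph-memory contexts found via a fixed
        keyword list, re-rank them with the RL policy, and return the top_k texts."""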
candidates_sem = self.semantic_memory.get_candidates(query_embedding)
        # Keyword extraction (simple implementation)
        possible_keywords = ["math", "algebra", "geometry", "arithmetic", "probability"]
text_candidates = []
for kw in possible_keywords:
nodes = self.graph_memory.search_nodes(kw)
for n in nodes:
context_list = self.graph_memory.get_connected_context(n, steps=1)
cscore = 1.0
for ctxt in context_list:
text_candidates.append((cscore, ctxt))
merged_candidates = candidates_sem + text_candidates
re_ranked = self.semantic_memory.rl_policy.re_rank(merged_candidates)
return re_ranked[:top_k]
def update_retrieval_reward(self, contexts: List[str], reward: float):
self.semantic_memory.update_retrieval_reward(contexts, reward)
def maybe_summarize_memory(self):
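        """Once enough tokens have been generated, summarize all semantic-memory text with the
        BART summarizer and replace the memory contents with that single summary."""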
if self.generated_tokens_count < self.summarize_threshold:
return
all_text = "\n".join([m['text'] for m in self.semantic_memory.memories])
if len(all_text) < 300:
return
summary = self.summarizer.summarize_text(all_text, max_length=120)
device = "cuda" if torch.cuda.is_available() else "cpu"
summary_embedding = torch.zeros((1, self.hidden_size), dtype=torch.float32, device=device)
self.semantic_memory.clear()
self.semantic_memory.add_memory(summary, summary_embedding)
self.generated_tokens_count = 0
def clear(self):
for layer in self.layer_buffers:
layer.clear()
self.semantic_memory.clear()
# ===================== Enhanced ThinkFlow Implementation =====================
# Marker used to detect the final answer
ANSWER_MARKER = "**Answer**"
# Sentences that begin each step of the reasoning
rethink_prepends = [
    "OK, now I need to figure out the following ",
    "I think ",
    "Wait, I think ",
    "Let me check whether the following is correct ",
    "Another thing I should remember is ",
    "Another point worth noting is ",
    "And I also recall the following fact ",
    "Now I think I understand it well enough ",
]
# Prompt used to generate the final answer
final_answer_prompt = """
Based on the reasoning so far, I will now answer the original question in the language it was asked in:
{question}
Below is the conclusion I have reached through reasoning:
{reasoning_conclusion}
Based on the reasoning above, the final answer:
{ANSWER_MARKER}
"""
# Settings for displaying math formulas correctly
latex_delimiters = [
{"left": "$$", "right": "$$", "display": True},
{"left": "$", "right": "$", "display": False},
]
def reformat_math(text):
"""Gradio ꡬ문(Katex)μ μ¬μ©νλλ‘ MathJax κ΅¬λΆ κΈ°νΈ μμ ."""
text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
return text
def extract_keywords(text: str) -> List[str]:
"""ν
μ€νΈμμ κ°λ¨ν ν€μλ μΆμΆ ν¨μ"""
# κ°λ¨ν ꡬν - μ€μ λ‘λ λ 볡μ‘ν NLP κΈ°λ²μ μ¬μ©ν μ μμ
common_math_keywords = [
"μν", "λμν", "κΈ°νν", "μ°μ ", "νλ₯ ", "곡μ", "λ°©μ μ",
"ν¨μ", "μ λΆ", "λ―ΈλΆ", "κΈ°ν", "μΌκ°ν", "μ", "κ°λ", "λΉμ¨",
"λΉλ‘", "νκ· ", "λΆμ°", "νμ€νΈμ°¨"
]
keywords = []
for kw in common_math_keywords:
if kw in text:
keywords.append(kw)
return keywords[:5] # μ΅λ 5κ° ν€μλλ§ λ°ν
def get_embedding_for_text(text: str, hidden_size: int = 768) -> torch.Tensor:
"""
ν
μ€νΈλ₯Ό μν μμ μλ² λ© μμ± ν¨μ
μ€μ ꡬνμμλ μ μ ν μΈμ΄ λͺ¨λΈμ μ¬μ©ν΄μΌ ν¨
"""
# μμ ꡬν: ν
μ€νΈμ ν΄μ κ°μ κΈ°λ°μΌλ‘ ν μλ² λ©
device = "cuda" if torch.cuda.is_available() else "cpu"
hash_val = hash(text)
    np.random.seed(abs(hash_val) % (2**32))  # numpy seeds must fit in 32 bits; note Python string hashes are salted per process
    # Generate a random embedding
embedding = np.random.rand(1, hidden_size).astype(np.float32)
    # Normalize
norm = np.linalg.norm(embedding)
if norm > 0:
embedding = embedding / norm
return torch.tensor(embedding, device=device)
def user_input(message, history_original, history_thinking):
"""μ¬μ©μ μ
λ ₯μ νμ€ν 리μ μΆκ°νκ³ μ
λ ₯ ν
μ€νΈ μμ λΉμ°κΈ°"""
return "", history_original + [
gr.ChatMessage(role="user", content=message.replace(ANSWER_MARKER, ""))
], history_thinking + [
gr.ChatMessage(role="user", content=message.replace(ANSWER_MARKER, ""))
]
def rebuild_messages(history: list):
"""μ€κ° μκ° κ³Όμ μμ΄ λͺ¨λΈμ΄ μ¬μ©ν νμ€ν 리μμ λ©μμ§ μ¬κ΅¬μ±"""
messages = []
for h in history:
if isinstance(h, dict) and not h.get("metadata", {}).get("title", False):
messages.append(h)
elif (
isinstance(h, gr.ChatMessage)
and h.metadata.get("title", None) is None
and isinstance(h.content, str)
):
messages.append({"role": h.role, "content": h.content})
return messages
# Model and buffer manager initialization
def initialize_model_and_manager(model_name):
"""λͺ¨λΈκ³Ό λ²νΌ λ§€λμ μ΄κΈ°ν ν¨μ"""
try:
pipe = pipeline(
"text-generation",
model=model_name,
device_map="auto",
torch_dtype="auto",
)
        # Extract the number of layers and the hidden size from the model config
config = pipe.model.config
if hasattr(config, "n_layer"):
num_layers = config.n_layer
elif hasattr(config, "num_layers"):
num_layers = config.num_layers
elif hasattr(config, "num_hidden_layers"):
num_layers = config.num_hidden_layers
else:
            num_layers = 12  # default
        if hasattr(config, "n_embd"):
            hidden_size = config.n_embd
        elif hasattr(config, "hidden_size"):
            hidden_size = config.hidden_size
        else:
            hidden_size = 768  # default
        # Initialize the buffer manager
buffer_manager = InferenceBufferManager(num_layers, hidden_size)
return pipe, buffer_manager
except Exception as e:
logger.error(f"λͺ¨λΈ μ΄κΈ°ν μ€λ₯: {str(e)}")
raise
def bot_original(
history: list,
max_num_tokens: int,
do_sample: bool,
temperature: float,
pipe=None
):
"""μλ³Έ λͺ¨λΈμ΄ μ§λ¬Έμ λ΅λ³νλλ‘ νκΈ° (μΆλ‘ κ³Όμ μμ΄)"""
if pipe is None:
# μ΄ λΆλΆμ μ€μ ꡬνμμλ μ μ λ³μλ μΈμ
μνλ‘ κ΄λ¦¬ν΄μΌ ν¨
return history
# λμ€μ μ€λ λμμ ν ν°μ μ€νΈλ¦ΌμΌλ‘ κ°μ Έμ€κΈ° μν¨
streamer = transformers.TextIteratorStreamer(
pipe.tokenizer,
skip_special_tokens=True,
skip_prompt=True,
)
    # Prepare the assistant message
history.append(
gr.ChatMessage(
role="assistant",
content=str(""),
)
)
    # Messages shown in the current chat
    messages = rebuild_messages(history[:-1])  # exclude the last empty message
    # The original model answers directly, without reasoning
t = threading.Thread(
target=pipe,
args=(messages,),
kwargs=dict(
max_new_tokens=max_num_tokens,
streamer=streamer,
do_sample=do_sample,
temperature=temperature,
),
)
t.start()
for token in streamer:
history[-1].content += token
history[-1].content = reformat_math(history[-1].content)
yield history
t.join()
yield history
def bot_thinking_enhanced(
history: list,
max_num_tokens: int,
final_num_tokens: int,
do_sample: bool,
temperature: float,
pipe=None,
buffer_manager=None
):
"""μΆλ‘ κ³Όμ μ ν¬ν¨νμ¬ λͺ¨λΈμ΄ μ§λ¬Έμ λ΅λ³νλλ‘ νκΈ° - DeepSeek κΈ°λ₯ ν΅ν©"""
if pipe is None or buffer_manager is None:
# μ΄ λΆλΆμ μ€μ ꡬνμμλ μ μ λ³μλ μΈμ
μνλ‘ κ΄λ¦¬ν΄μΌ ν¨
return history
# λμ€μ μ€λ λμμ ν ν°μ μ€νΈλ¦ΌμΌλ‘ κ°μ Έμ€κΈ° μν¨
streamer = transformers.TextIteratorStreamer(
pipe.tokenizer,
skip_special_tokens=True,
skip_prompt=True,
)
    # So the question can be re-inserted into the reasoning if needed
    question = history[-1]["content"]
    # Create an embedding for the query
query_embedding = get_embedding_for_text(question, buffer_manager.hidden_size)
    # Retrieve relevant context
    relevant_contexts = buffer_manager.get_relevant_context(query_embedding, top_k=3)
    # Extract keywords and pull related context from graph memory
keywords = extract_keywords(question)
graph_contexts = []
for keyword in keywords:
nodes = buffer_manager.graph_memory.search_nodes(keyword)
for node in nodes:
contexts = buffer_manager.graph_memory.get_connected_context(node)
graph_contexts.extend(contexts)
    # Merge all contexts
    all_contexts = relevant_contexts + graph_contexts
    all_contexts = list(set(all_contexts))  # remove duplicates
    all_contexts = all_contexts[:5]  # limit to at most 5 contexts
    # Prepare the assistant message
history.append(
gr.ChatMessage(
role="assistant",
content=str(""),
metadata={"title": "π§ μκ° μ€...", "status": "pending"},
)
)
    # The reasoning process shown in the current chat
    messages = rebuild_messages(history)
    # If there is relevant context, add it to the messages
    if all_contexts:
        context_str = "\n\nRelevant context:\n" + "\n".join(all_contexts)
        messages[-1]["content"] += context_str
        history[-1].content += context_str
    # Variable that accumulates the full reasoning process
full_reasoning = ""
    # Track the generated tokens
    generated_tokens = []
    # Run the reasoning steps
for i, prepend in enumerate(rethink_prepends):
if i > 0:
messages[-1]["content"] += "\n\n"
messages[-1]["content"] += prepend.format(question=question)
t = threading.Thread(
target=pipe,
args=(messages,),
kwargs=dict(
max_new_tokens=max_num_tokens,
streamer=streamer,
do_sample=do_sample,
temperature=temperature,
),
)
t.start()
        # Reflect the new prefix in the visible history
history[-1].content += prepend.format(question=question)
step_tokens = []
for token in streamer:
history[-1].content += token
history[-1].content = reformat_math(history[-1].content)
step_tokens.append(token)
generated_tokens.append(token)
yield history
t.join()
        # Store the result of each reasoning step in full_reasoning
        full_reasoning = history[-1].content
        # If the reasoning is getting long, produce an intermediate summary
if i > 0 and i % 3 == 0 and len(generated_tokens) > 500:
try:
summary = buffer_manager.summarizer.summarize_text(full_reasoning, max_length=150)
                summary_text = f"\n\n**Intermediate summary:**\n{summary}\n\n"
history[-1].content += summary_text
messages[-1]["content"] += summary_text
yield history
except Exception as e:
logger.error(f"μμ½ μμ± μ€λ₯: {str(e)}")
# KV μΊμ μμΆ
if i > 0 and i % 2 == 0:
buffer_manager.compress_all_buffers()
        # Update semantic memory with this reasoning step
        step_text = "".join(step_tokens)
        step_embedding = get_embedding_for_text(step_text, buffer_manager.hidden_size)
        buffer_manager.semantic_memory.add_memory(step_text, step_embedding)
    # Reasoning complete; now generate the final answer
    history[-1].metadata = {"title": "📝 Thought process", "status": "done"}
    # Store the reasoning in semantic and graph memory
full_embedding = get_embedding_for_text(full_reasoning, buffer_manager.hidden_size)
buffer_manager.semantic_memory.add_memory(full_reasoning, full_embedding)
    # Update graph memory for the extracted keywords
    for keyword in keywords:
        if not buffer_manager.graph_memory.has_node(keyword):
            buffer_manager.graph_memory.add_node(keyword, f"Concept related to {keyword}: reasoning about this topic has been performed.")
        # Link related nodes
for related_kw in keywords:
if related_kw != keyword and buffer_manager.graph_memory.has_node(related_kw):
buffer_manager.graph_memory.add_edge(keyword, related_kw)
    # Extract the conclusion from the reasoning (roughly the last one or two paragraphs)
    reasoning_parts = full_reasoning.split("\n\n")
    reasoning_conclusion = "\n\n".join(reasoning_parts[-2:]) if len(reasoning_parts) > 2 else full_reasoning
    # Add the final answer message
    history.append(gr.ChatMessage(role="assistant", content=""))
    # Build the messages for the final answer
    final_messages = rebuild_messages(history[:-1])  # exclude the last empty message
final_prompt = final_answer_prompt.format(
question=question,
reasoning_conclusion=reasoning_conclusion,
ANSWER_MARKER=ANSWER_MARKER
)
final_messages[-1]["content"] += final_prompt
    # Generate the final answer
t = threading.Thread(
target=pipe,
args=(final_messages,),
kwargs=dict(
max_new_tokens=final_num_tokens,
streamer=streamer,
do_sample=do_sample,
temperature=temperature,
),
)
t.start()
    # Stream the final answer
final_tokens = []
for token in streamer:
history[-1].content += token
history[-1].content = reformat_math(history[-1].content)
final_tokens.append(token)
yield history
t.join()
    # Store the final answer in semantic memory
    final_text = "".join(final_tokens)
    final_embedding = get_embedding_for_text(final_text, buffer_manager.hidden_size)
    buffer_manager.semantic_memory.add_memory(final_text, final_embedding)
    # Periodically check whether memory should be summarized
buffer_manager.maybe_summarize_memory()
yield history
with gr.Blocks(fill_height=True, title="Enhanced ThinkFlow") as demo:
    # Title and description
    gr.Markdown("# Enhanced ThinkFlow with DeepSeek Features")
    gr.Markdown("### An LLM reasoning platform enhanced with semantic memory, graph memory, and KV cache compression")
    # Model and buffer manager initialization (in a real implementation, manage via session state)
    model_name = "CohereForAI/c4ai-command-r7b-arabic-02-2025"
    # Session variables (a real implementation should use gr.State())
pipe = None
buffer_manager = None
current_contexts = []
    # Tab interface
    with gr.Tabs() as tabs:
        # Chat tab
        with gr.TabItem("Integrated Reasoning Interface"):
with gr.Row(scale=1):
with gr.Column(scale=2):
gr.Markdown("## Before (Original)")
chatbot_original = gr.Chatbot(
scale=1,
type="messages",
latex_delimiters=latex_delimiters,
label="Original Model (No Reasoning)"
)
with gr.Column(scale=2):
gr.Markdown("## After (Enhanced Thinking)")
chatbot_thinking = gr.Chatbot(
scale=1,
type="messages",
latex_delimiters=latex_delimiters,
label="Model with Enhanced Reasoning"
)
with gr.Row():
                # Define the msg textbox first
                msg = gr.Textbox(
                    submit_btn=True,
                    label="",
                    show_label=False,
                    placeholder="Type your question here.",
autofocus=True,
)
            # Feedback buttons
            with gr.Row():
                with gr.Column(scale=1):
                    feedback_btn_pos = gr.Button("👍 This reasoning was helpful")
                with gr.Column(scale=1):
                    feedback_btn_neg = gr.Button("👎 This reasoning needs improvement")
                with gr.Column(scale=1):
                    clear_memory_btn = gr.Button("🧹 Clear memory")
        # Memory visualization tab
        with gr.TabItem("Memory Visualization"):
            gr.Markdown("## Semantic Memory Contents")
            semantic_memory_display = gr.Textbox(
                label="Current semantic memory contents",
                placeholder="No memories stored yet.",
                lines=10,
                max_lines=20,
                interactive=False
            )
            gr.Markdown("## Graph Knowledge Base")
            graph_memory_display = gr.Textbox(
                label="Current graph memory contents",
                placeholder="No graph nodes yet.",
lines=10,
max_lines=20,
interactive=False
)
    # Examples section - placed after the msg variable is defined
with gr.Accordion("EXAMPLES", open=False):
examples = gr.Examples(
examples=[
"[μΆμ²: MATH-500)] μ²μ 100κ°μ μμ μ μ μ€μμ 3, 4, 5λ‘ λλμ΄ λ¨μ΄μ§λ μλ λͺ κ°μ
λκΉ?",
"[μΆμ²: MATH-500)] μν¬μ λ
μμ λ μμ€ν
μ λ
νΉν©λλ€. νΈλ§ν· 1κ°λ λΈλ§ν· 4κ°μ κ°κ³ , λΈλ§ν· 3κ°λ λλ§ν¬ 7κ°μ κ°μ΅λλ€. νΈλ§ν·μμ λλ§ν¬ 56κ°μ κ°μΉλ μΌλ§μ
λκΉ?",
"[μΆμ²: MATH-500)] μμ΄λ―Έ, λ²€, ν¬λ¦¬μ€μ νκ· λμ΄λ 6μ΄μ
λλ€. 4λ
μ ν¬λ¦¬μ€λ μ§κΈ μμ΄λ―Έμ κ°μ λμ΄μμ΅λλ€. 4λ
ν λ²€μ λμ΄λ κ·Έλ μμ΄λ―Έμ λμ΄μ $\\frac{3}{5}$κ° λ κ²μ
λλ€. ν¬λ¦¬μ€λ μ§κΈ λͺ μ΄μ
λκΉ?",
"[μΆμ²: MATH-500)] λ
Έλμκ³Ό νλμ ꡬμ¬μ΄ λ€μ΄ μλ κ°λ°©μ΄ μμ΅λλ€. νμ¬ νλμ ꡬμ¬κ³Ό λ
Έλμ ꡬμ¬μ λΉμ¨μ 4:3μ
λλ€. νλμ κ΅¬μ¬ 5κ°λ₯Ό λνκ³ λ
Έλμ κ΅¬μ¬ 3κ°λ₯Ό μ κ±°νλ©΄ λΉμ¨μ 7:3μ΄ λ©λλ€. λ λ£κΈ° μ μ κ°λ°©μ νλμ ꡬμ¬μ΄ λͺ κ° μμμ΅λκΉ?"
],
inputs=msg
)
with gr.Accordion("λ§€κ°λ³μ μ‘°μ ", open=False):
with gr.Row():
with gr.Column():
model_dropdown = gr.Dropdown(
["CohereForAI/c4ai-command-r7b-arabic-02-2025", "meta-llama/Meta-Llama-3-8B-Instruct"],
label="λͺ¨λΈ μ ν",
value="CohereForAI/c4ai-command-r7b-arabic-02-2025"
)
num_tokens = gr.Slider(
50,
4000,
2000,
step=1,
label="μΆλ‘ λ¨κ³λΉ μ΅λ ν ν° μ",
interactive=True,
)
final_num_tokens = gr.Slider(
50,
4000,
2000,
step=1,
label="μ΅μ’
λ΅λ³μ μ΅λ ν ν° μ",
interactive=True,
)
with gr.Column():
                do_sample = gr.Checkbox(True, label="Use sampling")
                temperature = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="Temperature")
                memory_weight = gr.Slider(0.0, 1.0, 0.5, step=0.1, label="Memory weighting")
    # Feedback handling functions
    def process_positive_feedback():
        global buffer_manager, current_contexts
        if buffer_manager:
            buffer_manager.update_retrieval_reward(current_contexts, reward=1.0)
        return "Thanks for the feedback! This approach will be used more often for similar questions."
def process_negative_feedback():
global buffer_manager, current_contexts
if buffer_manager:
buffer_manager.update_retrieval_reward(current_contexts, reward=-0.5)
return "νΌλλ°± κ°μ¬ν©λλ€! μ΄ μ κ·Ό λ°©μμ κ°μ νκ² μ΅λλ€."
def clear_memory():
global buffer_manager
if buffer_manager:
buffer_manager.clear()
return "λ©λͺ¨λ¦¬κ° μ΄κΈ°νλμμ΅λλ€."
def update_memory_displays():
global buffer_manager
if not buffer_manager:
return "λ©λͺ¨λ¦¬κ° μ΄κΈ°νλμ§ μμμ΅λλ€.", "κ·Έλνκ° μ΄κΈ°νλμ§ μμμ΅λλ€."
semantic_text = "νμ¬ μ μ₯λ λ©λͺ¨λ¦¬:\n\n"
for i, mem in enumerate(buffer_manager.semantic_memory.memories[:5]): # μ΅λ 5κ°λ§ νμ
semantic_text += f"{i+1}. {mem['text'][:100]}...\n\n"
graph_text = "νμ¬ κ·Έλν λ
Έλ:\n\n"
for node in buffer_manager.graph_memory.graph.nodes():
node_text = buffer_manager.graph_memory.get_text_by_node(node)
neighbors = list(buffer_manager.graph_memory.graph.neighbors(node))
            graph_text += f"Node: {node}\nDescription: {node_text[:50]}...\nConnections: {', '.join(neighbors[:3])}\n\n"
return semantic_text, graph_text
    # Initialization function
def initialize_models():
global pipe, buffer_manager, model_name
try:
pipe, buffer_manager = initialize_model_and_manager(model_name)
semantic_text, graph_text = update_memory_displays()
return "λͺ¨λΈμ΄ μ΄κΈ°νλμμ΅λλ€.", semantic_text, graph_text
except Exception as e:
return f"λͺ¨λΈ μ΄κΈ°ν μ€λ₯: {str(e)}", "", ""
    # Handle model selection changes
def change_model(new_model_name):
global model_name
model_name = new_model_name
status, semantic_text, graph_text = initialize_models()
return status, semantic_text, graph_text
    # Re-initialize when the model selection changes
model_dropdown.change(
change_model,
[model_dropdown],
[gr.Textbox(visible=False), semantic_memory_display, graph_memory_display]
)
    # Connect the feedback buttons
feedback_btn_pos.click(process_positive_feedback, [], gr.Textbox(visible=False))
feedback_btn_neg.click(process_negative_feedback, [], gr.Textbox(visible=False))
clear_memory_btn.click(clear_memory, [], gr.Textbox(visible=False))
    # Update the memory displays when the tab changes
    tabs.change(update_memory_displays, [], [semantic_memory_display, graph_memory_display])
    # When the user submits a message, both bots respond simultaneously
msg.submit(
user_input,
        [msg, chatbot_original, chatbot_thinking],  # inputs
        [msg, chatbot_original, chatbot_thinking],  # outputs
).then(
        lambda h, n, d, t, p: bot_original(h, n, d, t, p),  # pass the pipe parameter
[
chatbot_original,
num_tokens,
do_sample,
temperature,
            gr.Textbox(value=lambda: pipe, visible=False),  # pass pipe
],
        chatbot_original,  # save the new history to the output
    ).then(
        lambda h, n, f, d, t, p, b: bot_thinking_enhanced(h, n, f, d, t, p, b),  # pass the extra parameters
[
chatbot_thinking,
num_tokens,
final_num_tokens,
do_sample,
temperature,
            gr.Textbox(value=lambda: pipe, visible=False),  # pass pipe
            gr.Textbox(value=lambda: buffer_manager, visible=False),  # pass buffer_manager
],
        chatbot_thinking,  # save the new history to the output
).then(
update_memory_displays,
[],
[semantic_memory_display, graph_memory_display]
)
# Code that initializes the model at startup
def load_on_startup():
global pipe, buffer_manager
try:
        # Initialize the default model
pipe, buffer_manager = initialize_model_and_manager(
"CohereForAI/c4ai-command-r7b-arabic-02-2025"
)
logger.info("λͺ¨λΈ λ° λ²νΌ λ§€λμ κ° μ±κ³΅μ μΌλ‘ μ΄κΈ°νλμμ΅λλ€.")
except Exception as e:
logger.error(f"μμ μ λͺ¨λΈ μ΄κΈ°ν μ€ν¨: {str(e)}")
if __name__ == "__main__":
    # Initialize the model when the application starts
load_on_startup()
    # Start the queue and the server
    demo.queue().launch(
        share=False,
        debug=True,
    )