import re
import threading
import time
import os
import logging
from datetime import datetime
import torch
import numpy as np
from typing import List, Optional, Tuple, Dict
import networkx as nx
import gradio as gr
import transformers
from transformers import (
pipeline,
AutoModelForCausalLM,
AutoTokenizer,
BartForConditionalGeneration,
BartTokenizer,
BitsAndBytesConfig
)
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ===================== RLRetrievalPolicy =====================
class RLRetrievalPolicy:
def __init__(self):
self.policy_data = {}
        self.alpha = 0.5  # weighting between similarity and the RL score
def update_policy(self, contexts: List[str], reward: float):
for ctx in contexts:
if ctx not in self.policy_data:
self.policy_data[ctx] = 0.0
self.policy_data[ctx] += reward
def re_rank(self, candidates: List[Tuple[float, str]]) -> List[str]:
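        # Blend the raw similarity with the accumulated RL reward: score = (1 - alpha) * sim + alpha * rl_score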
reweighted = []
for sim, txt in candidates:
rl_score = self.policy_data.get(txt, 0.0)
reweighted_score = sim * (1 - self.alpha) + rl_score * self.alpha
reweighted.append((reweighted_score, txt))
reweighted.sort(key=lambda x: x[0], reverse=True)
return [t for _, t in reweighted]
# ===================== GraphMemory =====================
class GraphMemory:
def __init__(self):
self.graph = nx.DiGraph()
        # Add basic nodes that help with solving math problems
self.add_node("μˆ˜ν•™", "μˆ˜ν•™ 문제 해결을 μœ„ν•œ 일반적인 접근법")
self.add_node("λŒ€μˆ˜ν•™", "방정식, ν•¨μˆ˜, λΉ„λ‘€ 관계 등을 λ‹€λ£¨λŠ” μˆ˜ν•™μ˜ ν•œ λΆ„μ•Ό")
self.add_node("κΈ°ν•˜ν•™", "곡간, λ„ν˜•, 각도 등을 λ‹€λ£¨λŠ” μˆ˜ν•™μ˜ ν•œ λΆ„μ•Ό")
self.add_node("μ‚°μˆ ", "기본적인 수 μ—°μ‚°, λΉ„μœ¨, λ°±λΆ„μœ¨ 등을 λ‹€λ£¨λŠ” λΆ„μ•Ό")
self.add_node("ν™•λ₯ ", "μ‚¬κ±΄μ˜ λ°œμƒ κ°€λŠ₯성을 μΈ‘μ •ν•˜λŠ” μˆ˜ν•™μ˜ ν•œ λΆ„μ•Ό")
        # Define the relationships
self.add_edge("λŒ€μˆ˜ν•™", "μˆ˜ν•™")
self.add_edge("κΈ°ν•˜ν•™", "μˆ˜ν•™")
self.add_edge("μ‚°μˆ ", "μˆ˜ν•™")
self.add_edge("ν™•λ₯ ", "μˆ˜ν•™")
def add_node(self, node_id: str, text: str = ""):
self.graph.add_node(node_id, text=text)
def add_edge(self, src: str, dst: str):
self.graph.add_edge(src, dst)
def get_text_by_node(self, node_id: str) -> str:
return self.graph.nodes[node_id].get('text', "")
def has_node(self, node_id: str) -> bool:
return node_id in self.graph.nodes
def search_nodes(self, keyword: str, max_nodes: int = 3) -> List[str]:
matches = []
for n in self.graph.nodes():
node_text = self.get_text_by_node(n).lower()
n_lower = n.lower()
if keyword.lower() in node_text or keyword.lower() in n_lower:
score = node_text.count(keyword.lower()) + n_lower.count(keyword.lower())
matches.append((score, n))
matches.sort(key=lambda x: x[0], reverse=True)
top_nodes = [m[1] for m in matches[:max_nodes]]
return top_nodes
def get_connected_context(self, start_node: str, steps: int = 1) -> List[str]:
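        # Breadth-first traversal over both successors and predecessors, up to `steps` hops from start_node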
contexts = []
visited = set()
queue = [(start_node, 0)]
while queue:
current, depth = queue.pop(0)
if current not in visited:
visited.add(current)
contexts.append(self.get_text_by_node(current))
if depth < steps:
for neighbor in self.graph.successors(current):
queue.append((neighbor, depth + 1))
for neighbor in self.graph.predecessors(current):
queue.append((neighbor, depth + 1))
return contexts
# ===================== SimpleSummarizer =====================
class SimpleSummarizer:
def __init__(self, model_name="facebook/bart-large-cnn"):
self.model_name = model_name
self.model = None
self.tokenizer = None
def load_summarization_model(self):
if self.model is None:
try:
self.tokenizer = BartTokenizer.from_pretrained(self.model_name)
self.model = BartForConditionalGeneration.from_pretrained(self.model_name)
if torch.cuda.is_available():
self.model = self.model.cuda()
except Exception as e:
logger.error(f"Error loading summarization model: {str(e)}")
raise
def summarize_text(self, text: str, max_length: int = 100) -> str:
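        # Truncate the input to BART's 1024-token limit and decode a beam-search summary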
try:
self.load_summarization_model()
inputs = self.tokenizer([text], max_length=1024, return_tensors='pt', truncation=True)
if torch.cuda.is_available():
inputs = {k: v.cuda() for k, v in inputs.items()}
with torch.no_grad():
summary_ids = self.model.generate(
inputs["input_ids"],
num_beams=4,
max_length=max_length,
early_stopping=True
)
summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
return summary
except Exception as e:
logger.error(f"Error in summarization: {str(e)}")
return "μš”μ•½μ„ 생성할 수 μ—†μŠ΅λ‹ˆλ‹€."
# ===================== SemanticMemory =====================
class SemanticMemory:
def __init__(self, max_entries: int = 4000):
self.memories: List[dict] = []
self.max_entries = max_entries
self.rl_policy = RLRetrievalPolicy()
def add_memory(self, text: str, embedding: torch.Tensor):
if len(self.memories) >= self.max_entries:
self.memories.pop(0)
self.memories.append({
'text': text,
'embedding': embedding,
'timestamp': time.time()
})
def get_candidates(self, query_embedding: torch.Tensor) -> List[Tuple[float, str]]:
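        # Score every stored memory by cosine similarity against the query embedding, most similar first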
candidates = []
for mem in self.memories:
if mem['embedding'].shape == query_embedding.shape:
sim = torch.cosine_similarity(
query_embedding.float(),
mem['embedding'].float(),
dim=-1
)
candidates.append((sim.item(), mem['text']))
candidates.sort(key=lambda x: x[0], reverse=True)
return candidates
def get_relevant_context(self, query_embedding: torch.Tensor, top_k: int = 3) -> List[str]:
candidates = self.get_candidates(query_embedding)
re_ranked = self.rl_policy.re_rank(candidates)
return re_ranked[:top_k]
def update_retrieval_reward(self, texts: List[str], reward: float):
self.rl_policy.update_policy(texts, reward)
def clear(self):
self.memories = []
# ===================== GenericInferenceBuffer =====================
MAX_TOKEN_BUFFER = 1024
class GenericInferenceBuffer:
def __init__(self, layer_idx: int, compression_rank: int = 128):
self.layer_idx = layer_idx
self.key_buffer: Optional[torch.Tensor] = None
self.value_buffer: Optional[torch.Tensor] = None
self.semantic_context: Optional[torch.Tensor] = None
self.last_update: float = 0
self.compression_rank = compression_rank
def update_buffer(
self,
key: torch.Tensor,
value: torch.Tensor,
semantic_context: Optional[torch.Tensor] = None
):
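        # Append the new key/value tensors along the sequence dimension and trim the oldest entries beyond MAX_TOKEN_BUFFER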
try:
if self.key_buffer is None:
self.key_buffer = key.detach().clone()
self.value_buffer = value.detach().clone()
if semantic_context is not None:
self.semantic_context = semantic_context.detach().clone()
else:
self.key_buffer = torch.cat([self.key_buffer, key.detach()], dim=2)
self.value_buffer = torch.cat([self.value_buffer, value.detach()], dim=2)
if semantic_context is not None and self.semantic_context is not None:
self.semantic_context = torch.cat([self.semantic_context, semantic_context.detach()], dim=0)
if self.key_buffer.shape[2] > MAX_TOKEN_BUFFER:
excess = self.key_buffer.shape[2] - MAX_TOKEN_BUFFER
self.key_buffer = self.key_buffer[:, :, excess:, :]
self.value_buffer = self.value_buffer[:, :, excess:, :]
if self.semantic_context is not None:
self.semantic_context = self.semantic_context[excess:, :]
self.last_update = time.time()
except Exception as e:
logger.error(f"Buffer update error in layer {self.layer_idx}: {str(e)}")
def compress_buffer_svd(self):
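        # Low-rank approximation of the KV cache: flatten to 2D, keep only the top `compression_rank` singular values, then reshape back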
if self.key_buffer is None or self.value_buffer is None:
return
try:
k_shape = self.key_buffer.shape
v_shape = self.value_buffer.shape
k_2d = self.key_buffer.reshape(k_shape[0]*k_shape[1], k_shape[2]*k_shape[3]).float()
v_2d = self.value_buffer.reshape(v_shape[0]*v_shape[1], v_shape[2]*v_shape[3]).float()
device = k_2d.device
k_2d_cpu = k_2d.cpu()
v_2d_cpu = v_2d.cpu()
U_k, S_k, V_k = torch.linalg.svd(k_2d_cpu, full_matrices=False)
U_v, S_v, V_v = torch.linalg.svd(v_2d_cpu, full_matrices=False)
rank_k = min(self.compression_rank, S_k.shape[0])
rank_v = min(self.compression_rank, S_v.shape[0])
k_approx = (U_k[:, :rank_k] * S_k[:rank_k]) @ V_k[:rank_k, :]
v_approx = (U_v[:, :rank_v] * S_v[:rank_v]) @ V_v[:rank_v, :]
k_approx = k_approx.to(device)
v_approx = v_approx.to(device)
self.key_buffer = k_approx.reshape(k_shape).type(self.key_buffer.dtype)
self.value_buffer = v_approx.reshape(v_shape).type(self.value_buffer.dtype)
except Exception as e:
logger.error(f"SVD compression error in layer {self.layer_idx}: {str(e)}")
def get_buffer(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
return self.key_buffer, self.value_buffer, self.semantic_context
def clear(self):
self.key_buffer = None
self.value_buffer = None
self.semantic_context = None
self.last_update = 0
# ===================== InferenceBufferManager =====================
class InferenceBufferManager:
def __init__(self, num_layers: int, hidden_size: int):
self.num_layers = num_layers
self.hidden_size = hidden_size
self.layer_buffers = [
GenericInferenceBuffer(i, compression_rank=128) for i in range(num_layers)
]
self.semantic_memory = SemanticMemory()
self.graph_memory = GraphMemory()
self.summarizer = SimpleSummarizer()
self.summarize_threshold = 1500
self.generated_tokens_count = 0
self.compression_interval = 512
self.token_count_since_compress = 0
def _compute_semantic_embedding(self, key: Optional[torch.Tensor], value: Optional[torch.Tensor]) -> torch.Tensor:
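        # Pool key * value over the sequence dimension and L2-normalize to get a fixed-size embedding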
device = "cuda" if torch.cuda.is_available() else "cpu"
if key is None or value is None:
return torch.zeros((1, self.hidden_size), dtype=torch.float32, device=device)
combined = key * value
combined = combined.mean(dim=2)
combined = combined.reshape(combined.shape[0], -1)
combined = torch.nn.functional.normalize(combined, dim=-1)
return combined
def update_buffer(self, layer_outputs, current_tokens: List[int], semantic_context: torch.Tensor, tokenizer):
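        # Harvest per-layer past_key_values into the layer buffers and trigger SVD compression every compression_interval tokens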
try:
if hasattr(layer_outputs, 'past_key_values'):
for layer_idx, past_kv in enumerate(layer_outputs.past_key_values):
if isinstance(past_kv, tuple) and len(past_kv) == 2:
key, value = past_kv
if key is not None and value is not None:
self.layer_buffers[layer_idx].update_buffer(
key.detach(),
value.detach(),
semantic_context
)
self.generated_tokens_count += len(current_tokens)
self.token_count_since_compress += len(current_tokens)
if self.token_count_since_compress >= self.compression_interval:
self.compress_all_buffers()
self.token_count_since_compress = 0
except Exception as e:
logger.error(f"Buffer update error: {str(e)}")
def compress_all_buffers(self):
for buf in self.layer_buffers:
buf.compress_buffer_svd()
def finalize_semantic_memory(self, tokenizer, generated_tokens: List[int]):
if self.layer_buffers and len(self.layer_buffers) > 0 and self.layer_buffers[-1].key_buffer is not None:
text_chunk = tokenizer.decode(generated_tokens, skip_special_tokens=True)
key_buffer = self.layer_buffers[-1].key_buffer
value_buffer = self.layer_buffers[-1].value_buffer
embedding = self._compute_semantic_embedding(key_buffer, value_buffer)
self.semantic_memory.add_memory(text_chunk, embedding)
def get_relevant_context(self, query_embedding: torch.Tensor, top_k: int = 3) -> List[str]:
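        # Merge semantic-memory candidates with graph-memory contexts, then re-rank them with the RL policy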
candidates_sem = self.semantic_memory.get_candidates(query_embedding)
        # Keyword extraction (simple implementation)
possible_keywords = ["μˆ˜ν•™", "λŒ€μˆ˜ν•™", "κΈ°ν•˜ν•™", "μ‚°μˆ ", "ν™•λ₯ "]
text_candidates = []
for kw in possible_keywords:
nodes = self.graph_memory.search_nodes(kw)
for n in nodes:
context_list = self.graph_memory.get_connected_context(n, steps=1)
cscore = 1.0
for ctxt in context_list:
text_candidates.append((cscore, ctxt))
merged_candidates = candidates_sem + text_candidates
re_ranked = self.semantic_memory.rl_policy.re_rank(merged_candidates)
return re_ranked[:top_k]
def update_retrieval_reward(self, contexts: List[str], reward: float):
self.semantic_memory.update_retrieval_reward(contexts, reward)
def maybe_summarize_memory(self):
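        # Once enough tokens have been generated, collapse all stored memories into a single BART summary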
if self.generated_tokens_count < self.summarize_threshold:
return
all_text = "\n".join([m['text'] for m in self.semantic_memory.memories])
if len(all_text) < 300:
return
summary = self.summarizer.summarize_text(all_text, max_length=120)
device = "cuda" if torch.cuda.is_available() else "cpu"
summary_embedding = torch.zeros((1, self.hidden_size), dtype=torch.float32, device=device)
self.semantic_memory.clear()
self.semantic_memory.add_memory(summary, summary_embedding)
self.generated_tokens_count = 0
def clear(self):
for layer in self.layer_buffers:
layer.clear()
self.semantic_memory.clear()
# ===================== Enhanced ThinkFlow Implementation =====================
# Marker used to detect the final answer
ANSWER_MARKER = "**λ‹΅λ³€**"
# Sentences that begin each step of the reasoning process
rethink_prepends = [
"자, 이제 λ‹€μŒμ„ νŒŒμ•…ν•΄μ•Ό ν•©λ‹ˆλ‹€ ",
"제 μƒκ°μ—λŠ” ",
"μž μ‹œλ§Œμš”, 제 μƒκ°μ—λŠ” ",
"λ‹€μŒ 사항이 λ§žλŠ”μ§€ 확인해 λ³΄κ² μŠ΅λ‹ˆλ‹€ ",
"λ˜ν•œ κΈ°μ–΅ν•΄μ•Ό ν•  것은 ",
"또 λ‹€λ₯Έ μ£Όλͺ©ν•  점은 ",
"그리고 μ €λŠ” λ‹€μŒκ³Ό 같은 사싀도 κΈ°μ–΅ν•©λ‹ˆλ‹€ ",
"이제 μΆ©λΆ„νžˆ μ΄ν•΄ν–ˆλ‹€κ³  μƒκ°ν•©λ‹ˆλ‹€ ",
]
# Prompt used to generate the final answer
final_answer_prompt = """
μ§€κΈˆκΉŒμ§€μ˜ μΆ”λ‘  과정을 λ°”νƒ•μœΌλ‘œ, μ›λž˜ μ§ˆλ¬Έμ— μ‚¬μš©λœ μ–Έμ–΄λ‘œ λ‹΅λ³€ν•˜κ² μŠ΅λ‹ˆλ‹€:
{question}
μ•„λž˜λŠ” λ‚΄κ°€ μΆ”λ‘ ν•œ κ²°λ‘ μž…λ‹ˆλ‹€:
{reasoning_conclusion}
μœ„ 좔둠을 기반으둜 μ΅œμ’… λ‹΅λ³€:
{ANSWER_MARKER}
"""
# Settings for rendering math expressions
latex_delimiters = [
{"left": "$$", "right": "$$", "display": True},
{"left": "$", "right": "$", "display": False},
]
def reformat_math(text):
"""Gradio ꡬ문(Katex)을 μ‚¬μš©ν•˜λ„λ‘ MathJax ꡬ뢄 기호 μˆ˜μ •."""
text = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
text = re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", text, flags=re.DOTALL)
return text
def extract_keywords(text: str) -> List[str]:
"""ν…μŠ€νŠΈμ—μ„œ κ°„λ‹¨ν•œ ν‚€μ›Œλ“œ μΆ”μΆœ ν•¨μˆ˜"""
# κ°„λ‹¨ν•œ κ΅¬ν˜„ - μ‹€μ œλ‘œλŠ” 더 λ³΅μž‘ν•œ NLP 기법을 μ‚¬μš©ν•  수 있음
common_math_keywords = [
"μˆ˜ν•™", "λŒ€μˆ˜ν•™", "κΈ°ν•˜ν•™", "μ‚°μˆ ", "ν™•λ₯ ", "곡식", "방정식",
"ν•¨μˆ˜", "적뢄", "λ―ΈλΆ„", "κΈ°ν•˜", "μ‚Όκ°ν˜•", "원", "각도", "λΉ„μœ¨",
"λΉ„λ‘€", "평균", "λΆ„μ‚°", "ν‘œμ€€νŽΈμ°¨"
]
keywords = []
for kw in common_math_keywords:
if kw in text:
keywords.append(kw)
    return keywords[:5]  # return at most 5 keywords
def get_embedding_for_text(text: str, hidden_size: int = 768) -> torch.Tensor:
"""
ν…μŠ€νŠΈλ₯Ό μœ„ν•œ μž„μ‹œ μž„λ² λ”© 생성 ν•¨μˆ˜
μ‹€μ œ κ΅¬ν˜„μ—μ„œλŠ” μ μ ˆν•œ μ–Έμ–΄ λͺ¨λΈμ„ μ‚¬μš©ν•΄μ•Ό 함
"""
# μž„μ‹œ κ΅¬ν˜„: ν…μŠ€νŠΈμ˜ ν•΄μ‹œ 값을 기반으둜 ν•œ μž„λ² λ”©
device = "cuda" if torch.cuda.is_available() else "cpu"
    # np.random.seed requires 0 <= seed < 2**32, so clamp the Python hash into that range
    hash_val = abs(hash(text)) % (2**32)
    np.random.seed(hash_val)
    # Generate a random embedding
embedding = np.random.rand(1, hidden_size).astype(np.float32)
    # Normalization
norm = np.linalg.norm(embedding)
if norm > 0:
embedding = embedding / norm
return torch.tensor(embedding, device=device)
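# A real implementation could swap in a sentence encoder instead of the hash-seeded
# placeholder above - a minimal sketch, assuming the optional sentence-transformers package:
#     from sentence_transformers import SentenceTransformer
#     encoder = SentenceTransformer("all-MiniLM-L6-v2")
#     embedding = torch.tensor(encoder.encode([text]), dtype=torch.float32, device=device)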
def user_input(message, history_original, history_thinking):
"""μ‚¬μš©μž μž…λ ₯을 νžˆμŠ€ν† λ¦¬μ— μΆ”κ°€ν•˜κ³  μž…λ ₯ ν…μŠ€νŠΈ μƒμž λΉ„μš°κΈ°"""
return "", history_original + [
gr.ChatMessage(role="user", content=message.replace(ANSWER_MARKER, ""))
], history_thinking + [
gr.ChatMessage(role="user", content=message.replace(ANSWER_MARKER, ""))
]
def rebuild_messages(history: list):
"""쀑간 생각 κ³Όμ • 없이 λͺ¨λΈμ΄ μ‚¬μš©ν•  νžˆμŠ€ν† λ¦¬μ—μ„œ λ©”μ‹œμ§€ μž¬κ΅¬μ„±"""
messages = []
for h in history:
if isinstance(h, dict) and not h.get("metadata", {}).get("title", False):
messages.append(h)
elif (
isinstance(h, gr.ChatMessage)
and h.metadata.get("title", None) is None
and isinstance(h.content, str)
):
messages.append({"role": h.role, "content": h.content})
return messages
# λͺ¨λΈκ³Ό 버퍼 λ§€λ‹ˆμ € μ΄ˆκΈ°ν™” ν•¨μˆ˜
def initialize_model_and_manager(model_name):
"""λͺ¨λΈκ³Ό 버퍼 λ§€λ‹ˆμ € μ΄ˆκΈ°ν™” ν•¨μˆ˜"""
try:
pipe = pipeline(
"text-generation",
model=model_name,
device_map="auto",
torch_dtype="auto",
)
# λͺ¨λΈ κ΅¬μ„±μ—μ„œ λ ˆμ΄μ–΄ 및 은닉 크기 정보 μΆ”μΆœ
config = pipe.model.config
if hasattr(config, "n_layer"):
num_layers = config.n_layer
elif hasattr(config, "num_layers"):
num_layers = config.num_layers
elif hasattr(config, "num_hidden_layers"):
num_layers = config.num_hidden_layers
else:
            num_layers = 12  # default
if hasattr(config, "n_embd"):
hidden_size = config.n_embd
elif hasattr(config, "hidden_size"):
hidden_size = config.hidden_size
else:
            hidden_size = 768  # default
        # Initialize the buffer manager
buffer_manager = InferenceBufferManager(num_layers, hidden_size)
return pipe, buffer_manager
except Exception as e:
logger.error(f"λͺ¨λΈ μ΄ˆκΈ°ν™” 였λ₯˜: {str(e)}")
raise
def bot_original(
history: list,
max_num_tokens: int,
do_sample: bool,
temperature: float,
pipe=None
):
"""원본 λͺ¨λΈμ΄ μ§ˆλ¬Έμ— λ‹΅λ³€ν•˜λ„λ‘ ν•˜κΈ° (μΆ”λ‘  κ³Όμ • 없이)"""
if pipe is None:
# 이 뢀뢄은 μ‹€μ œ κ΅¬ν˜„μ—μ„œλŠ” μ „μ—­ λ³€μˆ˜λ‚˜ μ„Έμ…˜ μƒνƒœλ‘œ 관리해야 함
return history
# λ‚˜μ€‘μ— μŠ€λ ˆλ“œμ—μ„œ 토큰을 슀트림으둜 κ°€μ Έμ˜€κΈ° μœ„ν•¨
streamer = transformers.TextIteratorStreamer(
pipe.tokenizer,
skip_special_tokens=True,
skip_prompt=True,
)
    # Prepare the assistant message
history.append(
gr.ChatMessage(
role="assistant",
content=str(""),
)
)
    # Messages to be displayed in the current chat
    messages = rebuild_messages(history[:-1])  # exclude the last empty message
    # The original model answers directly, without a reasoning phase
t = threading.Thread(
target=pipe,
args=(messages,),
kwargs=dict(
max_new_tokens=max_num_tokens,
streamer=streamer,
do_sample=do_sample,
temperature=temperature,
),
)
t.start()
for token in streamer:
history[-1].content += token
history[-1].content = reformat_math(history[-1].content)
yield history
t.join()
yield history
def bot_thinking_enhanced(
history: list,
max_num_tokens: int,
final_num_tokens: int,
do_sample: bool,
temperature: float,
pipe=None,
buffer_manager=None
):
"""μΆ”λ‘  과정을 ν¬ν•¨ν•˜μ—¬ λͺ¨λΈμ΄ μ§ˆλ¬Έμ— λ‹΅λ³€ν•˜λ„λ‘ ν•˜κΈ° - DeepSeek κΈ°λŠ₯ 톡합"""
if pipe is None or buffer_manager is None:
# 이 뢀뢄은 μ‹€μ œ κ΅¬ν˜„μ—μ„œλŠ” μ „μ—­ λ³€μˆ˜λ‚˜ μ„Έμ…˜ μƒνƒœλ‘œ 관리해야 함
return history
# λ‚˜μ€‘μ— μŠ€λ ˆλ“œμ—μ„œ 토큰을 슀트림으둜 κ°€μ Έμ˜€κΈ° μœ„ν•¨
streamer = transformers.TextIteratorStreamer(
pipe.tokenizer,
skip_special_tokens=True,
skip_prompt=True,
)
    # So the question can be re-inserted into the reasoning when needed
question = history[-1]["content"]
    # Create the query embedding
query_embedding = get_embedding_for_text(question, buffer_manager.hidden_size)
# κ΄€λ ¨ μ»¨ν…μŠ€νŠΈ 검색
relevant_contexts = buffer_manager.get_relevant_context(query_embedding, top_k=3)
    # Extract keywords and fetch context from the graph memory
keywords = extract_keywords(question)
graph_contexts = []
for keyword in keywords:
nodes = buffer_manager.graph_memory.search_nodes(keyword)
for node in nodes:
contexts = buffer_manager.graph_memory.get_connected_context(node)
graph_contexts.extend(contexts)
    # Merge all contexts
    all_contexts = relevant_contexts + graph_contexts
    all_contexts = list(set(all_contexts))  # deduplicate
    all_contexts = all_contexts[:5]  # limit to at most 5 contexts
    # Prepare the assistant message
history.append(
gr.ChatMessage(
role="assistant",
content=str(""),
metadata={"title": "🧠 생각 쀑...", "status": "pending"},
)
)
# ν˜„μž¬ μ±„νŒ…μ— ν‘œμ‹œλ  μΆ”λ‘  κ³Όμ •
messages = rebuild_messages(history)
# κ΄€λ ¨ μ»¨ν…μŠ€νŠΈκ°€ μžˆλ‹€λ©΄ λ©”μ‹œμ§€μ— μΆ”κ°€
if all_contexts:
context_str = "\n\nκ΄€λ ¨ μ»¨ν…μŠ€νŠΈ:\n" + "\n".join(all_contexts)
messages[-1]["content"] += context_str
history[-1].content += context_str
    # Variable for accumulating the full reasoning process
    full_reasoning = ""
    # Variable for tracking the generated tokens
    generated_tokens = []
    # Run the reasoning steps
for i, prepend in enumerate(rethink_prepends):
if i > 0:
messages[-1]["content"] += "\n\n"
messages[-1]["content"] += prepend.format(question=question)
t = threading.Thread(
target=pipe,
args=(messages,),
kwargs=dict(
max_new_tokens=max_num_tokens,
streamer=streamer,
do_sample=do_sample,
temperature=temperature,
),
)
t.start()
        # Rebuild the history with the new content
history[-1].content += prepend.format(question=question)
step_tokens = []
for token in streamer:
history[-1].content += token
history[-1].content = reformat_math(history[-1].content)
step_tokens.append(token)
generated_tokens.append(token)
yield history
t.join()
        # Store the result of each reasoning step in full_reasoning
        full_reasoning = history[-1].content
        # Generate an intermediate summary when the reasoning gets long
if i > 0 and i % 3 == 0 and len(generated_tokens) > 500:
try:
summary = buffer_manager.summarizer.summarize_text(full_reasoning, max_length=150)
summary_text = f"\n\n**쀑간 μš”μ•½:**\n{summary}\n\n"
history[-1].content += summary_text
messages[-1]["content"] += summary_text
yield history
except Exception as e:
logger.error(f"μš”μ•½ 생성 였λ₯˜: {str(e)}")
# KV μΊμ‹œ μ••μΆ•
if i > 0 and i % 2 == 0:
buffer_manager.compress_all_buffers()
# μ‹œλ§¨ν‹± μ»¨ν…μŠ€νŠΈ μ—…λ°μ΄νŠΈ
step_text = "".join(step_tokens)
step_embedding = get_embedding_for_text(step_text, buffer_manager.hidden_size)
buffer_manager.semantic_memory.add_memory(step_text, step_embedding)
    # Reasoning complete, now generate the final answer
    history[-1].metadata = {"title": "πŸ’­ 사고 κ³Όμ •", "status": "done"}
    # Store the reasoning process in semantic memory and graph memory
full_embedding = get_embedding_for_text(full_reasoning, buffer_manager.hidden_size)
buffer_manager.semantic_memory.add_memory(full_reasoning, full_embedding)
# ν‚€μ›Œλ“œμ— λŒ€ν•œ κ·Έλž˜ν”„ λ©”λͺ¨λ¦¬ μ—…λ°μ΄νŠΈ
for keyword in keywords:
if not buffer_manager.graph_memory.has_node(keyword):
buffer_manager.graph_memory.add_node(keyword, f"{keyword}에 κ΄€ν•œ κ°œλ…: 이 μ£Όμ œμ— λŒ€ν•œ 좔둠을 μˆ˜ν–‰ν–ˆμŠ΅λ‹ˆλ‹€.")
            # Connect to related nodes
for related_kw in keywords:
if related_kw != keyword and buffer_manager.graph_memory.has_node(related_kw):
buffer_manager.graph_memory.add_edge(keyword, related_kw)
# μΆ”λ‘  κ³Όμ •μ—μ„œ κ²°λ‘  뢀뢄을 μΆ”μΆœ (λ§ˆμ§€λ§‰ 1-2 문단 정도)
reasoning_parts = full_reasoning.split("\n\n")
reasoning_conclusion = "\n\n".join(reasoning_parts[-2:]) if len(reasoning_parts) > 2 else full_reasoning
    # Add the final answer message
    history.append(gr.ChatMessage(role="assistant", content=""))
    # Build the messages for the final answer
    final_messages = rebuild_messages(history[:-1])  # exclude the last empty message
final_prompt = final_answer_prompt.format(
question=question,
reasoning_conclusion=reasoning_conclusion,
ANSWER_MARKER=ANSWER_MARKER
)
final_messages[-1]["content"] += final_prompt
    # Generate the final answer
t = threading.Thread(
target=pipe,
args=(final_messages,),
kwargs=dict(
max_new_tokens=final_num_tokens,
streamer=streamer,
do_sample=do_sample,
temperature=temperature,
),
)
t.start()
    # Stream the final answer
final_tokens = []
for token in streamer:
history[-1].content += token
history[-1].content = reformat_math(history[-1].content)
final_tokens.append(token)
yield history
t.join()
    # Store the final answer in semantic memory
final_text = "".join(final_tokens)
final_embedding = get_embedding_for_text(final_text, buffer_manager.hidden_size)
buffer_manager.semantic_memory.add_memory(final_text, final_embedding)
    # Periodic memory summarization check
buffer_manager.maybe_summarize_memory()
yield history
with gr.Blocks(fill_height=True, title="Enhanced ThinkFlow") as demo:
    # Title and description
gr.Markdown("# Enhanced ThinkFlow with DeepSeek Features")
gr.Markdown("### μ‹œλ§¨ν‹± λ©”λͺ¨λ¦¬, κ·Έλž˜ν”„ λ©”λͺ¨λ¦¬, 및 KV μΊμ‹œ 압좕을 톡해 ν–₯μƒλœ LLM μΆ”λ‘  생성 ν”Œλž«νΌ")
# λͺ¨λΈ 및 버퍼 λ§€λ‹ˆμ € μ΄ˆκΈ°ν™” (μ‹€μ œ κ΅¬ν˜„μ—μ„œλŠ” μ„Έμ…˜ μƒνƒœλ‘œ 관리)
model_name = "CohereForAI/c4ai-command-r7b-arabic-02-2025"
# μ„Έμ…˜ λ³€μˆ˜ (μ‹€μ œ κ΅¬ν˜„μ—μ„œλŠ” gr.State() μ‚¬μš©)
pipe = None
buffer_manager = None
current_contexts = []
# νƒ­ μΈν„°νŽ˜μ΄μŠ€
with gr.Tabs() as tabs:
        # Chat tab
with gr.TabItem("톡합 μΆ”λ‘  μΈν„°νŽ˜μ΄μŠ€"):
with gr.Row(scale=1):
with gr.Column(scale=2):
gr.Markdown("## Before (Original)")
chatbot_original = gr.Chatbot(
scale=1,
type="messages",
latex_delimiters=latex_delimiters,
label="Original Model (No Reasoning)"
)
with gr.Column(scale=2):
gr.Markdown("## After (Enhanced Thinking)")
chatbot_thinking = gr.Chatbot(
scale=1,
type="messages",
latex_delimiters=latex_delimiters,
label="Model with Enhanced Reasoning"
)
with gr.Row():
                # Define the msg textbox first
msg = gr.Textbox(
submit_btn=True,
label="",
show_label=False,
placeholder="여기에 μ§ˆλ¬Έμ„ μž…λ ₯ν•˜μ„Έμš”.",
autofocus=True,
)
            # Feedback buttons
with gr.Row():
with gr.Column(scale=1):
feedback_btn_pos = gr.Button("πŸ‘ 이 좔둠이 도움이 λ˜μ—ˆμŠ΅λ‹ˆλ‹€")
with gr.Column(scale=1):
feedback_btn_neg = gr.Button("πŸ‘Ž 이 좔둠은 κ°œμ„ μ΄ ν•„μš”ν•©λ‹ˆλ‹€")
with gr.Column(scale=1):
clear_memory_btn = gr.Button("🧹 λ©”λͺ¨λ¦¬ μ΄ˆκΈ°ν™”")
# λ©”λͺ¨λ¦¬ μ‹œκ°ν™” νƒ­
with gr.TabItem("λ©”λͺ¨λ¦¬ μ‹œκ°ν™”"):
gr.Markdown("## μ‹œλ§¨ν‹± λ©”λͺ¨λ¦¬ λ‚΄μš©")
semantic_memory_display = gr.Textbox(
label="ν˜„μž¬ μ‹œλ§¨ν‹± λ©”λͺ¨λ¦¬ λ‚΄μš©",
placeholder="아직 λ©”λͺ¨λ¦¬κ°€ μ—†μŠ΅λ‹ˆλ‹€.",
lines=10,
max_lines=20,
interactive=False
)
gr.Markdown("## κ·Έλž˜ν”„ μ§€μ‹λ² μ΄μŠ€")
graph_memory_display = gr.Textbox(
label="ν˜„μž¬ κ·Έλž˜ν”„ λ©”λͺ¨λ¦¬ λ‚΄μš©",
placeholder="아직 κ·Έλž˜ν”„ λ…Έλ“œκ°€ μ—†μŠ΅λ‹ˆλ‹€.",
lines=10,
max_lines=20,
interactive=False
)
    # Examples section - placed after the msg variable is defined
with gr.Accordion("EXAMPLES", open=False):
examples = gr.Examples(
examples=[
"[좜처: MATH-500)] 처음 100개의 μ–‘μ˜ μ •μˆ˜ μ€‘μ—μ„œ 3, 4, 5둜 λ‚˜λˆ„μ–΄ λ–¨μ–΄μ§€λŠ” μˆ˜λŠ” λͺ‡ κ°œμž…λ‹ˆκΉŒ?",
"[좜처: MATH-500)] μž‰ν¬μ˜ λ•…μ—μ„œ 돈 μ‹œμŠ€ν…œμ€ λ…νŠΉν•©λ‹ˆλ‹€. νŠΈλ§ν‚· 1κ°œλŠ” 블링킷 4κ°œμ™€ κ°™κ³ , 블링킷 3κ°œλŠ” λ“œλ§ν¬ 7κ°œμ™€ κ°™μŠ΅λ‹ˆλ‹€. νŠΈλ§ν‚·μ—μ„œ λ“œλ§ν¬ 56개의 κ°€μΉ˜λŠ” μ–Όλ§ˆμž…λ‹ˆκΉŒ?",
"[좜처: MATH-500)] 에이미, λ²€, 크리슀의 평균 λ‚˜μ΄λŠ” 6μ‚΄μž…λ‹ˆλ‹€. 4λ…„ μ „ ν¬λ¦¬μŠ€λŠ” μ§€κΈˆ 에이미와 같은 λ‚˜μ΄μ˜€μŠ΅λ‹ˆλ‹€. 4λ…„ ν›„ 벀의 λ‚˜μ΄λŠ” κ·Έλ•Œ μ—μ΄λ―Έμ˜ λ‚˜μ΄μ˜ $\\frac{3}{5}$κ°€ 될 κ²ƒμž…λ‹ˆλ‹€. ν¬λ¦¬μŠ€λŠ” μ§€κΈˆ λͺ‡ μ‚΄μž…λ‹ˆκΉŒ?",
"[좜처: MATH-500)] λ…Έλž€μƒ‰κ³Ό νŒŒλž€μƒ‰ ꡬ슬이 λ“€μ–΄ μžˆλŠ” 가방이 μžˆμŠ΅λ‹ˆλ‹€. ν˜„μž¬ νŒŒλž€μƒ‰ ꡬ슬과 λ…Έλž€μƒ‰ ꡬ슬의 λΉ„μœ¨μ€ 4:3μž…λ‹ˆλ‹€. νŒŒλž€μƒ‰ ꡬ슬 5개λ₯Ό λ”ν•˜κ³  λ…Έλž€μƒ‰ ꡬ슬 3개λ₯Ό μ œκ±°ν•˜λ©΄ λΉ„μœ¨μ€ 7:3이 λ©λ‹ˆλ‹€. 더 λ„£κΈ° 전에 가방에 νŒŒλž€μƒ‰ ꡬ슬이 λͺ‡ 개 μžˆμ—ˆμŠ΅λ‹ˆκΉŒ?"
],
inputs=msg
)
with gr.Accordion("λ§€κ°œλ³€μˆ˜ μ‘°μ •", open=False):
with gr.Row():
with gr.Column():
model_dropdown = gr.Dropdown(
["CohereForAI/c4ai-command-r7b-arabic-02-2025", "meta-llama/Meta-Llama-3-8B-Instruct"],
label="λͺ¨λΈ 선택",
value="CohereForAI/c4ai-command-r7b-arabic-02-2025"
)
num_tokens = gr.Slider(
50,
4000,
2000,
step=1,
label="μΆ”λ‘  단계당 μ΅œλŒ€ 토큰 수",
interactive=True,
)
final_num_tokens = gr.Slider(
50,
4000,
2000,
step=1,
label="μ΅œμ’… λ‹΅λ³€μ˜ μ΅œλŒ€ 토큰 수",
interactive=True,
)
with gr.Column():
do_sample = gr.Checkbox(True, label="μƒ˜ν”Œλ§ μ‚¬μš©")
temperature = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="μ˜¨λ„")
memory_weight = gr.Slider(0.0, 1.0, 0.5, step=0.1, label="λ©”λͺ¨λ¦¬ 반영 κ°€μ€‘μΉ˜")
    # Feedback handling functions
def process_positive_feedback():
global buffer_manager, current_contexts
if buffer_manager:
buffer_manager.update_retrieval_reward(current_contexts, reward=1.0)
return "ν”Όλ“œλ°± κ°μ‚¬ν•©λ‹ˆλ‹€! 이 μ ‘κ·Ό 방식을 ν–₯ν›„ μœ μ‚¬ν•œ μ§ˆλ¬Έμ— 더 자주 μ‚¬μš©ν•˜κ² μŠ΅λ‹ˆλ‹€."
def process_negative_feedback():
global buffer_manager, current_contexts
if buffer_manager:
buffer_manager.update_retrieval_reward(current_contexts, reward=-0.5)
return "ν”Όλ“œλ°± κ°μ‚¬ν•©λ‹ˆλ‹€! 이 μ ‘κ·Ό 방식을 κ°œμ„ ν•˜κ² μŠ΅λ‹ˆλ‹€."
def clear_memory():
global buffer_manager
if buffer_manager:
buffer_manager.clear()
return "λ©”λͺ¨λ¦¬κ°€ μ΄ˆκΈ°ν™”λ˜μ—ˆμŠ΅λ‹ˆλ‹€."
def update_memory_displays():
global buffer_manager
if not buffer_manager:
return "λ©”λͺ¨λ¦¬κ°€ μ΄ˆκΈ°ν™”λ˜μ§€ μ•Šμ•˜μŠ΅λ‹ˆλ‹€.", "κ·Έλž˜ν”„κ°€ μ΄ˆκΈ°ν™”λ˜μ§€ μ•Šμ•˜μŠ΅λ‹ˆλ‹€."
semantic_text = "ν˜„μž¬ μ €μž₯된 λ©”λͺ¨λ¦¬:\n\n"
        for i, mem in enumerate(buffer_manager.semantic_memory.memories[:5]):  # show at most 5
semantic_text += f"{i+1}. {mem['text'][:100]}...\n\n"
graph_text = "ν˜„μž¬ κ·Έλž˜ν”„ λ…Έλ“œ:\n\n"
for node in buffer_manager.graph_memory.graph.nodes():
node_text = buffer_manager.graph_memory.get_text_by_node(node)
neighbors = list(buffer_manager.graph_memory.graph.neighbors(node))
graph_text += f"λ…Έλ“œ: {node}\nμ„€λͺ…: {node_text[:50]}...\nμ—°κ²°: {', '.join(neighbors[:3])}\n\n"
return semantic_text, graph_text
    # Initialization function
def initialize_models():
global pipe, buffer_manager, model_name
try:
pipe, buffer_manager = initialize_model_and_manager(model_name)
semantic_text, graph_text = update_memory_displays()
return "λͺ¨λΈμ΄ μ΄ˆκΈ°ν™”λ˜μ—ˆμŠ΅λ‹ˆλ‹€.", semantic_text, graph_text
except Exception as e:
return f"λͺ¨λΈ μ΄ˆκΈ°ν™” 였λ₯˜: {str(e)}", "", ""
# λͺ¨λΈ 선택 λ³€κ²½ μ‹œ 처리
def change_model(new_model_name):
global model_name
model_name = new_model_name
status, semantic_text, graph_text = initialize_models()
return status, semantic_text, graph_text
    # Re-initialize when the model selection changes
model_dropdown.change(
change_model,
[model_dropdown],
[gr.Textbox(visible=False), semantic_memory_display, graph_memory_display]
)
    # Connect handlers to the feedback buttons
feedback_btn_pos.click(process_positive_feedback, [], gr.Textbox(visible=False))
feedback_btn_neg.click(process_negative_feedback, [], gr.Textbox(visible=False))
clear_memory_btn.click(clear_memory, [], gr.Textbox(visible=False))
    # Update the memory displays when the tab changes
tabs.change(update_memory_displays, [], [semantic_memory_display, graph_memory_display])
# μ‚¬μš©μžκ°€ λ©”μ‹œμ§€λ₯Ό μ œμΆœν•˜λ©΄ 두 봇이 λ™μ‹œμ— μ‘λ‹΅ν•©λ‹ˆλ‹€
msg.submit(
user_input,
[msg, chatbot_original, chatbot_thinking], # μž…λ ₯
[msg, chatbot_original, chatbot_thinking], # 좜λ ₯
).then(
lambda h, n, d, t, p: bot_original(h, n, d, t, p), # pipe λ§€κ°œλ³€μˆ˜ μΆ”κ°€
[
chatbot_original,
num_tokens,
do_sample,
temperature,
gr.Textbox(value=lambda: pipe, visible=False), # pipe 전달
],
chatbot_original, # 좜λ ₯μ—μ„œ μƒˆ νžˆμŠ€ν† λ¦¬ μ €μž₯
).then(
lambda h, n, f, d, t, p, b: bot_thinking_enhanced(h, n, f, d, t, p, b), # λ§€κ°œλ³€μˆ˜ μΆ”κ°€
[
chatbot_thinking,
num_tokens,
final_num_tokens,
do_sample,
temperature,
gr.Textbox(value=lambda: pipe, visible=False), # pipe 전달
gr.Textbox(value=lambda: buffer_manager, visible=False), # buffer_manager 전달
],
chatbot_thinking, # 좜λ ₯μ—μ„œ μƒˆ νžˆμŠ€ν† λ¦¬ μ €μž₯
).then(
update_memory_displays,
[],
[semantic_memory_display, graph_memory_display]
)
# Code for initializing the model at startup
def load_on_startup():
global pipe, buffer_manager
try:
# κΈ°λ³Έ λͺ¨λΈ μ΄ˆκΈ°ν™”
pipe, buffer_manager = initialize_model_and_manager(
"CohereForAI/c4ai-command-r7b-arabic-02-2025"
)
logger.info("λͺ¨λΈ 및 버퍼 λ§€λ‹ˆμ €κ°€ μ„±κ³΅μ μœΌλ‘œ μ΄ˆκΈ°ν™”λ˜μ—ˆμŠ΅λ‹ˆλ‹€.")
except Exception as e:
logger.error(f"μ‹œμž‘ μ‹œ λͺ¨λΈ μ΄ˆκΈ°ν™” μ‹€νŒ¨: {str(e)}")
if __name__ == "__main__":
# μ‘μš© ν”„λ‘œκ·Έλž¨ μ‹œμž‘ 전에 λͺ¨λΈ μ΄ˆκΈ°ν™”
load_on_startup()
    # Start the queue and the server (the page title is already set on gr.Blocks above)
    demo.queue().launch(
        share=False,
        debug=True
    )