# ThinkFlow-llama / app.py
# (Hugging Face blob-viewer residue, kept as a comment so the file parses:
#  "openfree's picture — Update app.py — a4ca8e9 verified — raw — history blame — 22.8 kB")
import re
import threading
from collections import Counter
import gradio as gr
import spaces
import transformers
from transformers import pipeline
# ๋ชจ๋ธ๊ณผ ํ† ํฌ๋‚˜์ด์ € ๋กœ๋”ฉ
model_name = "CohereForAI/c4ai-command-r7b-arabic-02-2025"
if gr.NO_RELOAD:
pipe = pipeline(
"text-generation",
model=model_name,
device_map="auto",
torch_dtype="auto",
)
# Marker used to detect the final answer section in generated text.
ANSWER_MARKER = "**๋‹ต๋ณ€**"

# Sentences that open each staged reasoning step (Korean prompts, one per step).
rethink_prepends = [
    "์ž, ์ด์ œ ๋‹ค์Œ์„ ํŒŒ์•…ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค ",
    "์ œ ์ƒ๊ฐ์—๋Š” ",
    "์ž ์‹œ๋งŒ์š”, ์ œ ์ƒ๊ฐ์—๋Š” ",
    "๋‹ค์Œ ์‚ฌํ•ญ์ด ๋งž๋Š”์ง€ ํ™•์ธํ•ด ๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค ",
    "๋˜ํ•œ ๊ธฐ์–ตํ•ด์•ผ ํ•  ๊ฒƒ์€ ",
    "๋˜ ๋‹ค๋ฅธ ์ฃผ๋ชฉํ•  ์ ์€ ",
    "๊ทธ๋ฆฌ๊ณ  ์ €๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์€ ์‚ฌ์‹ค๋„ ๊ธฐ์–ตํ•ฉ๋‹ˆ๋‹ค ",
    "์ด์ œ ์ถฉ๋ถ„ํžˆ ์ดํ•ดํ–ˆ๋‹ค๊ณ  ์ƒ๊ฐํ•ฉ๋‹ˆ๋‹ค ",
]

# General reasoning-guide prompt injected before the first reasoning step.
general_reasoning_guide = """
์ด ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•œ ์ฒด๊ณ„์ ์ธ ์ ‘๊ทผ ๋ฐฉ๋ฒ•์„ ์‚ฌ์šฉํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค:
1. ๋ฌธ์ œ์—์„œ ์ œ๊ณต๋œ ๋ชจ๋“  ์ •๋ณด์™€ ์กฐ๊ฑด์„ ๋ช…ํ™•ํžˆ ์ดํ•ดํ•ฉ๋‹ˆ๋‹ค.
2. ๊ฐ ๋ณ€์ˆ˜์™€ ๊ด€๊ณ„๋ฅผ ์‹๋ณ„ํ•˜๊ณ  ํ•„์š”ํ•œ ๋ฐฉ์ •์‹์„ ์„ธ์›๋‹ˆ๋‹ค.
3. ๋‹จ๊ณ„๋ณ„๋กœ ๊ณ„์‚ฐ์„ ์ˆ˜ํ–‰ํ•˜๋ฉฐ, ๊ฐ ๋‹จ๊ณ„์˜ ๊ฒฐ๊ณผ๋ฅผ ํ™•์ธํ•ฉ๋‹ˆ๋‹ค.
4. ์ค‘๊ฐ„ ๊ฒฐ๊ณผ๊ฐ€ ํ•ฉ๋ฆฌ์ ์ธ์ง€ ๊ฒ€ํ† ํ•˜๋ฉฐ ์ง„ํ–‰ํ•ฉ๋‹ˆ๋‹ค.
5. ์ตœ์ข… ๋‹ต์•ˆ์„ ๋„์ถœํ•˜๊ณ  ๋ฌธ์ œ์˜ ์š”๊ตฌ์‚ฌํ•ญ์„ ์ถฉ์กฑํ•˜๋Š”์ง€ ํ™•์ธํ•ฉ๋‹ˆ๋‹ค.
์ด์ œ ๋ฌธ์ œ๋ฅผ ํ’€์–ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค:
"""
# ๊ฒฐ๊ณผ ์ถ”์ถœ ๋ฐ ๊ฒ€์ฆ์„ ์œ„ํ•œ ํ•จ์ˆ˜๋“ค
def extract_calculation_results(reasoning_text):
    """Extract candidate numeric answers from a reasoning trace.

    Three passes, in order:
      1. Korean answer phrases ("the result is N", "the answer is N", ...).
      2. Numbers followed by a unit word (currency, counts, time, length, weight).
      3. Any number in the last paragraph (closest to a final conclusion).

    Duplicates are kept on purpose: callers use frequency as a confidence
    signal (see ``determine_best_result``).

    Args:
        reasoning_text: Accumulated reasoning text to scan.

    Returns:
        list[int | float]: Parsed values in discovery order (possibly empty).
    """
    # Answer-phrase patterns; the LAST capture group is always the number.
    numeric_patterns = [
        r'๊ฒฐ๊ณผ๋Š” (\d+[\.,]?\d*)',
        r'๋‹ต(์€|๋Š”|์ด) (\d+[\.,]?\d*)',
        r'์ •๋‹ต(์€|๋Š”|์ด) (\d+[\.,]?\d*)',
        r'๋‹ต์•ˆ(์€|๋Š”|์ด) (\d+[\.,]?\d*)',
        r'์ˆ˜์ต(์€|๋Š”|์ด) (\d+[\.,]?\d*)',
        r'๊ฐ’(์€|๋Š”|์ด) (\d+[\.,]?\d*)',
        r'๊ฒฐ๋ก (์€|๋Š”|์ด) (\d+[\.,]?\d*)',
        r'๊ฐœ์ˆ˜(๋Š”|์€|๊ฐ€) (\d+[\.,]?\d*)',
        r'์ด (\d+[\.,]?\d*)๊ฐœ',
        r'์ด์•ก(์€|๋Š”|์ด) (\d+[\.,]?\d*)',
        r'์ดํ•ฉ(์€|๋Š”|์ด) (\d+[\.,]?\d*)',
        r'ํ•ฉ๊ณ„(๋Š”|์€|๊ฐ€) (\d+[\.,]?\d*)',
        r'=\s*(\d+[\.,]?\d*)\s*$',
        r':\s*(\d+[\.,]?\d*)\s*$',
        r'์ด๊ณ„:\s*(\d+[\.,]?\d*)',
        r'์ตœ์ข… ๊ฒฐ๊ณผ:\s*(\d+[\.,]?\d*)',
        r'์ตœ์ข… ๊ฐ’:\s*(\d+[\.,]?\d*)',
        r'์ตœ์ข… ๋‹ต๋ณ€:\s*(\d+[\.,]?\d*)',
    ]
    # Number-plus-unit patterns (dollars, counts, time, length, weight);
    # the FIRST capture group is the number.
    unit_patterns = [
        r'(\d+[\.,]?\d*)\s*(๋‹ฌ๋Ÿฌ|์›|์œ ๋กœ|ํŒŒ์šด๋“œ|์—”)',
        r'(\d+[\.,]?\d*)\s*(๊ฐœ|๋ช…|์„ธํŠธ|์Œ|ํŒ€|๊ทธ๋ฃน)',
        r'(\d+[\.,]?\d*)\s*(๋ถ„|์‹œ๊ฐ„|์ดˆ|์ผ|์ฃผ|๊ฐœ์›”|๋…„)',
        r'(\d+[\.,]?\d*)\s*(๋ฏธํ„ฐ|ํ‚ฌ๋กœ๋ฏธํ„ฐ|์„ผํ‹ฐ๋ฏธํ„ฐ|์ธ์น˜|ํ”ผํŠธ)',
        r'(\d+[\.,]?\d*)\s*(๊ทธ๋žจ|ํ‚ฌ๋กœ๊ทธ๋žจ|ํŒŒ์šด๋“œ|์˜จ์Šค)',
    ]
    results = []
    # Pass 1: answer phrases.
    for pattern in numeric_patterns:
        for match in re.findall(pattern, reasoning_text, re.IGNORECASE):
            # Patterns with a particle group return tuples; the value is last.
            value = match[-1] if isinstance(match, tuple) else match
            _append_parsed_number(value, results)
    # Pass 2: numbers with units.
    for pattern in unit_patterns:
        for match in re.findall(pattern, reasoning_text, re.IGNORECASE):
            _append_parsed_number(match[0], results)
    # Pass 3: any number in the final paragraph (closest to the final answer).
    last_paragraph = reasoning_text.split('\n\n')[-1]
    for num in re.findall(r'(\d+[\.,]?\d*)', last_paragraph):
        _append_parsed_number(num, results)
    return results


def _append_parsed_number(value, results):
    """Strip thousands separators, parse *value* as int/float, and append it
    to *results*; silently skip anything unparseable (same behavior as the
    previous inlined try/except blocks)."""
    value = value.replace(',', '')
    try:
        results.append(float(value) if '.' in value else int(value))
    except ValueError:
        pass
def determine_best_result(results, full_reasoning):
    """Pick the most trustworthy value out of the extracted candidates.

    Heuristics, applied in order:
      1. Unanimous candidates -> return that value.
      2. A single clear frequency winner (within 80% of the top count).
      3. Prefer candidates that appear in the last two paragraphs.
      4. Prefer candidates that appear right after '=' or an answer phrase.
      5. Fall back to the most frequent candidate overall.

    Args:
        results: Candidate values from ``extract_calculation_results``.
        full_reasoning: The full reasoning text the values came from.

    Returns:
        The chosen value, or None when *results* is empty.
    """
    if not results:
        return None
    # Only one distinct value: nothing to arbitrate.
    if len(set(results)) == 1:
        return results[0]
    # Frequency analysis (more repetitions suggest higher confidence).
    counter = Counter(results)
    most_common = counter.most_common()
    # Candidates whose count is within 80% of the best count.
    top_results = [result for result, count in most_common if count >= most_common[0][1] * 0.8]
    if len(top_results) == 1:
        return top_results[0]
    # Give extra weight to results appearing near the conclusion
    # (the last two paragraphs).
    paragraphs = full_reasoning.split('\n\n')
    last_paragraphs = '\n\n'.join(paragraphs[-2:])
    final_results = [result for result in top_results if str(result) in last_paragraphs]
    if final_results:
        # Most frequent among the results seen in the closing paragraphs.
        final_counter = Counter([r for r in results if r in final_results])
        if final_counter:
            return final_counter.most_common(1)[0][0]
    # Results that appear right after an equation or an answer phrase
    # (e.g. "= 78", "์ดํ•ฉ: 78").
    for result in top_results:
        # BUG FIX: escape the value before splicing it into a regex —
        # otherwise the '.' in a float like 1.5 acts as a wildcard and can
        # falsely match unrelated numbers (e.g. "= 105").
        result_str = re.escape(str(result))
        if re.search(r'=\s*' + result_str + r'(?!\d)', full_reasoning) or \
           re.search(r'๊ฒฐ๊ณผ[:๋Š”์€์ด๊ฐ€]\s*' + result_str, full_reasoning) or \
           re.search(r'๋‹ต[:๋Š”์€์ด๊ฐ€]\s*' + result_str, full_reasoning) or \
           re.search(r'์ •๋‹ต[:๋Š”์€์ด๊ฐ€]\s*' + result_str, full_reasoning):
            return result
    # Could not decide by the rules above: return the most frequent result.
    return most_common[0][0]
# Prompt used to summarize intermediate progress in a structured form.
structured_reasoning_prompt = """
์ง€๊ธˆ๊นŒ์ง€์˜ ์ถ”๋ก ์„ ๋‹จ๊ณ„๋ณ„๋กœ ์ •๋ฆฌํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค:
1. ๋ฌธ์ œ ๋ถ„์„:
- ์ฃผ์–ด์ง„ ์ •๋ณด: {given_info}
- ๊ตฌํ•ด์•ผ ํ•  ๊ฒƒ: {goal}
2. ๊ณ„์‚ฐ ๊ณผ์ •:
{calculation_steps}
3. ํ˜„์žฌ๊นŒ์ง€์˜ ๊ฒฐ๋ก :
{current_conclusion}
์ด์ œ ๋‹ค์Œ ๋‹จ๊ณ„๋กœ ์ง„ํ–‰ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.
"""

# Prompt used to re-verify the computation when the extracted candidate
# results disagree with each other.
verification_prompt = """
์ง€๊ธˆ๊นŒ์ง€์˜ ์ถ”๋ก  ๊ณผ์ •์—์„œ ์—ฌ๋Ÿฌ ๊ฒฐ๊ณผ๊ฐ€ ๋„์ถœ๋˜์—ˆ์Šต๋‹ˆ๋‹ค:
{different_results}
์ด ์ค‘์—์„œ ๊ฐ€์žฅ ์ •ํ™•ํ•œ ๋‹ต๋ณ€์„ ์ฐพ๊ธฐ ์œ„ํ•ด ๊ณ„์‚ฐ ๊ณผ์ •์„ ์ฒ˜์Œ๋ถ€ํ„ฐ ๋‹ค์‹œ ๊ฒ€ํ† ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค:
1. ๋ฌธ์ œ ๋ถ„์„:
- ์ฃผ์–ด์ง„ ์ •๋ณด: {given_info}
- ๊ตฌํ•ด์•ผ ํ•  ๊ฒƒ: {goal}
2. ๋‹จ๊ณ„๋ณ„ ๊ณ„์‚ฐ ๊ณผ์ •:
{calculation_steps}
3. ๊ฒฐ๋ก :
์œ„ ๊ณ„์‚ฐ ๊ณผ์ •์„ ํ†ตํ•ด ์ •ํ™•ํ•œ ๋‹ต์€ {result}์ž…๋‹ˆ๋‹ค.
"""

# Prompt used to generate the final user-facing answer; {ANSWER_MARKER}
# is the marker the app watches for.
final_answer_prompt = """
์ง€๊ธˆ๊นŒ์ง€์˜ ์ฒด๊ณ„์ ์ธ ์ถ”๋ก  ๊ณผ์ •์„ ์ข…ํ•ฉํ•˜์—ฌ, ์›๋ž˜ ์งˆ๋ฌธ์— ๋‹ต๋ณ€ํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค:
{question}
์ถ”๋ก  ๊ณผ์ •์„ ๊ฒ€ํ† ํ•œ ๊ฒฐ๊ณผ, ๋‹ค์Œ๊ณผ ๊ฐ™์€ ๊ฒฐ๋ก ์— ๋„๋‹ฌํ–ˆ์Šต๋‹ˆ๋‹ค:
{reasoning_conclusion}
๋”ฐ๋ผ์„œ ์ตœ์ข… ๋‹ต๋ณ€์€:
{ANSWER_MARKER}
"""

# KaTeX delimiter configuration so Gradio's Chatbot renders math formulas.
latex_delimiters = [
    {"left": "$$", "right": "$$", "display": True},
    {"left": "$", "right": "$", "display": False},
]
def reformat_math(text):
    """Rewrite MathJax delimiters into the KaTeX forms Gradio renders.

    ``\\[ ... \\]`` becomes ``$$ ... $$`` and ``\\( ... \\)`` becomes
    ``$ ... $``. This is a workaround for displaying math in Gradio; no
    latex_delimiters configuration was found that renders the MathJax
    forms directly.
    """
    conversions = (
        (r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$"),  # display math
        (r"\\\(\s*(.*?)\s*\\\)", r"$\1$"),    # inline math
    )
    for pattern, replacement in conversions:
        text = re.sub(pattern, replacement, text, flags=re.DOTALL)
    return text
def user_input(message, history_original, history_thinking):
    """Append the user's message to both chat histories and clear the textbox.

    The answer marker is stripped from the message — presumably so user text
    cannot trigger the app's answer detection (confirm with ANSWER_MARKER use).

    Returns:
        ("", new original history, new thinking history) — the empty string
        clears the input textbox.
    """
    cleaned = message.replace(ANSWER_MARKER, "")
    entry_original = gr.ChatMessage(role="user", content=cleaned)
    entry_thinking = gr.ChatMessage(role="user", content=cleaned)
    return "", history_original + [entry_original], history_thinking + [entry_thinking]
def rebuild_messages(history: list):
    """Rebuild plain chat messages from the history, dropping intermediate
    "thinking" entries (anything that carries a metadata title).

    Dict entries are passed through unchanged; gr.ChatMessage entries with
    string content and no title are converted to role/content dicts.
    """
    messages = []
    for entry in history:
        if isinstance(entry, dict):
            # Keep plain dict messages unless they are titled (thinking) ones.
            if not entry.get("metadata", {}).get("title", False):
                messages.append(entry)
        elif (
            isinstance(entry, gr.ChatMessage)
            and entry.metadata.get("title", None) is None
            and isinstance(entry.content, str)
        ):
            messages.append({"role": entry.role, "content": entry.content})
    return messages
def extract_info_from_question(question):
    """Heuristically derive the givens and the goal from a question.

    Scans the question for Korean keywords (counts, amounts, ages,
    probabilities, averages) and sets the goal description accordingly;
    generic defaults are returned when nothing matches.

    Returns:
        (given_info, goal) — two descriptive strings.
    """
    # Generic defaults used when no keyword matches.
    given_info = "๋ฌธ์ œ์—์„œ ์ œ๊ณต๋œ ๋ชจ๋“  ์กฐ๊ฑด๊ณผ ์ˆ˜์น˜"
    goal = "๋ฌธ์ œ์—์„œ ์š”๊ตฌํ•˜๋Š” ๊ฐ’์ด๋‚˜ ๊ฒฐ๊ณผ"
    # Ordered keyword heuristics; the first hit wins (mirrors the original
    # if/elif chain order).
    keyword_goals = [
        (("๋ช‡ ๊ฐœ", "๊ฐœ์ˆ˜"), "ํŠน์ • ์กฐ๊ฑด์„ ๋งŒ์กฑํ•˜๋Š” ํ•ญ๋ชฉ์˜ ๊ฐœ์ˆ˜"),
        (("์–ผ๋งˆ",), "ํŠน์ • ๊ฐ’ ๋˜๋Š” ๊ธˆ์•ก"),
        (("๋‚˜์ด",), "์‚ฌ๋žŒ์˜ ๋‚˜์ด"),
        (("ํ™•๋ฅ ",), "ํŠน์ • ์‚ฌ๊ฑด์˜ ํ™•๋ฅ "),
        (("ํ‰๊ท ",), "๊ฐ’๋“ค์˜ ํ‰๊ท "),
    ]
    for keywords, matched_goal in keyword_goals:
        if any(keyword in question for keyword in keywords):
            goal = matched_goal
            break
    return given_info, goal
@spaces.GPU
def bot_original(
    history: list,
    max_num_tokens: int,
    do_sample: bool,
    temperature: float,
):
    """Let the original model answer the question directly (no reasoning stage).

    Generation runs in a worker thread while this generator consumes the
    streamer, yielding the updated history after every token so Gradio can
    render the reply incrementally.
    """
    # Streamer lets us consume tokens here while generation runs in a thread.
    streamer = transformers.TextIteratorStreamer(
        pipe.tokenizer,  # pyright: ignore
        skip_special_tokens=True,
        skip_prompt=True,
    )
    # Prepare an empty assistant message that the stream will fill in.
    history.append(
        gr.ChatMessage(
            role="assistant",
            content=str(""),
        )
    )
    # Messages sent to the model for the current chat.
    messages = rebuild_messages(history[:-1])  # exclude the last (empty) message
    # Run the pipeline in a background thread so the streamer can be
    # consumed here; the thread is joined after the stream is exhausted.
    t = threading.Thread(
        target=pipe,
        args=(messages,),
        kwargs=dict(
            max_new_tokens=max_num_tokens,
            streamer=streamer,
            do_sample=do_sample,
            temperature=temperature,
        ),
    )
    t.start()
    for token in streamer:
        history[-1].content += token
        # Re-normalize math delimiters on every update so partial output renders.
        history[-1].content = reformat_math(history[-1].content)
        yield history
    t.join()
    yield history
@spaces.GPU
def bot_thinking(
    history: list,
    max_num_tokens: int,
    final_num_tokens: int,
    do_sample: bool,
    temperature: float,
):
    """Make the model answer with an explicit multi-step reasoning phase.

    Streams a staged reasoning trace into a collapsible "thinking" message,
    optionally runs a verification pass when the extracted candidate answers
    disagree, then streams a separate final-answer message.

    Args:
        history: Chat history; the last entry must be the user's question.
        max_num_tokens: Token budget per reasoning step.
        final_num_tokens: Token budget for the final answer.
        do_sample: Whether to sample during generation.
        temperature: Sampling temperature (the final pass uses 0.8x).

    Yields:
        The updated history after each streamed token.
    """
    # Streamer lets us consume tokens here while generation runs in a thread.
    streamer = transformers.TextIteratorStreamer(
        pipe.tokenizer,  # pyright: ignore
        skip_special_tokens=True,
        skip_prompt=True,
    )
    # Keep the question so it can be re-injected into the reasoning prompts.
    question = history[-1]["content"]
    # Heuristically extract the givens and the goal from the question.
    given_info, goal = extract_info_from_question(question)
    # Prepare the assistant "thinking" message (the title marks it interim,
    # so rebuild_messages will filter it out of model prompts).
    history.append(
        gr.ChatMessage(
            role="assistant",
            content=str(""),
            metadata={"title": "๐Ÿง  ์ƒ๊ฐ ์ค‘...", "status": "pending"},
        )
    )
    # Prompt messages for the model. NOTE: the entry appended above carries a
    # metadata title, so rebuild_messages drops it — messages[-1] is the
    # user's question, onto which the reasoning prompts are appended below.
    messages = rebuild_messages(history)
    # Accumulates the full reasoning trace across all steps.
    full_reasoning = ""
    # Calculation steps collected from the reasoning, for structured summaries.
    calculation_steps = ""
    current_conclusion = "์•„์ง ์ตœ์ข… ๊ฒฐ๋ก ์— ๋„๋‹ฌํ•˜์ง€ ์•Š์•˜์Šต๋‹ˆ๋‹ค."
    # Run the staged reasoning steps.
    for i, prepend in enumerate(rethink_prepends):
        if i > 0:
            messages[-1]["content"] += "\n\n"
        # First step: inject the general reasoning guide.
        if i == 0:
            messages[-1]["content"] += general_reasoning_guide + "\n\n"
        # Later steps: inject a structured summary of the progress so far.
        if i > 1 and calculation_steps:
            structured_summary = structured_reasoning_prompt.format(
                given_info=given_info,
                goal=goal,
                calculation_steps=calculation_steps,
                current_conclusion=current_conclusion
            )
            messages[-1]["content"] += structured_summary + "\n\n"
        messages[-1]["content"] += prepend.format(question=question)
        t = threading.Thread(
            target=pipe,
            args=(messages,),
            kwargs=dict(
                max_new_tokens=max_num_tokens,
                streamer=streamer,
                do_sample=do_sample,
                temperature=temperature,
            ),
        )
        t.start()
        # Mirror the injected prompt text into the visible history.
        if i == 0:
            history[-1].content += general_reasoning_guide + "\n\n"
        if i > 1 and calculation_steps:
            history[-1].content += structured_summary + "\n\n"
        history[-1].content += prepend.format(question=question)
        for token in streamer:
            history[-1].content += token
            history[-1].content = reformat_math(history[-1].content)
            yield history
        t.join()
        # Save the whole trace produced so far.
        full_reasoning = history[-1].content
        # Text generated by THIS step only (everything after the prepend).
        new_content = history[-1].content.split(prepend.format(question=question))[-1]
        if "=" in new_content or ":" in new_content:
            # Treat it as containing a calculation step.
            calculation_steps += f"\n - {new_content.strip()}"
        # Extract a possible interim conclusion from this step.
        results = extract_calculation_results(new_content)
        if results:
            current_conclusion = f"ํ˜„์žฌ ๊ณ„์‚ฐ๋œ ๊ฐ’: {results[-1]}"
    # Reasoning finished — switch the thinking message to its "done" state.
    history[-1].metadata = {"title": "๐Ÿ’ญ ์‚ฌ๊ณ  ๊ณผ์ •", "status": "done"}
    # All candidate results found anywhere in the reasoning trace.
    all_results = extract_calculation_results(full_reasoning)
    # When candidates disagree, run an extra verification pass.
    if all_results and len(set(all_results)) > 1:
        # Frequency per candidate, shown to the model in the prompt.
        result_counter = Counter(all_results)
        different_results = "\n".join([f"{result} (๋นˆ๋„: {freq}ํšŒ)" for result, freq in result_counter.most_common()])
        # Best guess so far, fed into the verification prompt.
        best_result = determine_best_result(all_results, full_reasoning)
        verify_prompt = verification_prompt.format(
            different_results=different_results,
            given_info=given_info,
            goal=goal,
            calculation_steps=calculation_steps,
            result=best_result
        )
        messages[-1]["content"] += "\n\n" + verify_prompt
        # Run the verification step.
        t = threading.Thread(
            target=pipe,
            args=(messages,),
            kwargs=dict(
                max_new_tokens=max_num_tokens // 2,
                streamer=streamer,
                do_sample=False,  # sampling disabled for a deterministic check
                temperature=0.3,  # NOTE(review): likely a no-op with do_sample=False — confirm
            ),
        )
        t.start()
        history[-1].content += "\n\n" + verify_prompt
        for token in streamer:
            history[-1].content += token
            history[-1].content = reformat_math(history[-1].content)
            yield history
        t.join()
        # Refresh the trace with the verification output included.
        full_reasoning = history[-1].content
    # Decide the final result from the complete trace.
    final_results = extract_calculation_results(full_reasoning)
    best_result = determine_best_result(final_results, full_reasoning) if final_results else None
    # Build the conclusion sentence for the final-answer prompt.
    if best_result is not None:
        reasoning_conclusion = f"์ถ”๋ก  ๊ณผ์ •์„ ์ข…ํ•ฉํ•œ ๊ฒฐ๊ณผ, ์ •ํ™•ํ•œ ๋‹ต๋ณ€์€ {best_result}์ž…๋‹ˆ๋‹ค."
    else:
        # Fallback when no numeric result could be extracted: reuse the tail
        # of the reasoning trace as the conclusion.
        reasoning_parts = full_reasoning.split("\n\n")
        reasoning_conclusion = "\n\n".join(reasoning_parts[-2:]) if len(reasoning_parts) > 2 else full_reasoning
    # Append the (initially empty) final answer message.
    history.append(gr.ChatMessage(role="assistant", content=""))
    # Messages for the final answer pass (exclude the new empty entry).
    final_messages = rebuild_messages(history[:-1])
    final_prompt = final_answer_prompt.format(
        question=question,
        reasoning_conclusion=reasoning_conclusion,
        ANSWER_MARKER=ANSWER_MARKER
    )
    final_messages[-1]["content"] += "\n\n" + final_prompt
    # Generate the final answer.
    t = threading.Thread(
        target=pipe,
        args=(final_messages,),
        kwargs=dict(
            max_new_tokens=final_num_tokens,
            streamer=streamer,
            do_sample=do_sample,
            temperature=temperature * 0.8,  # slightly lower for a more confident final answer
        ),
    )
    t.start()
    # Stream the final answer.
    for token in streamer:
        history[-1].content += token
        history[-1].content = reformat_math(history[-1].content)
        yield history
    t.join()
    yield history
with gr.Blocks(fill_height=True, title="Vidraft ThinkFlow") as demo:
    # Title and description.
    gr.Markdown("# Vidraft ThinkFlow")
    gr.Markdown("### ์ถ”๋ก  ๊ธฐ๋Šฅ์ด ์—†๋Š” LLM ๋ชจ๋ธ์˜ ์ˆ˜์ • ์—†์ด๋„ ์ถ”๋ก  ๊ธฐ๋Šฅ์„ ์ž๋™์œผ๋กœ ์ ์šฉํ•˜๋Š” LLM ์ถ”๋ก  ์ƒ์„ฑ ํ”Œ๋žซํผ")
    # Side-by-side comparison: plain answer vs. reasoning-augmented answer.
    with gr.Row(scale=1):
        with gr.Column(scale=2):
            gr.Markdown("## Before (Original)")
            chatbot_original = gr.Chatbot(
                scale=1,
                type="messages",
                latex_delimiters=latex_delimiters,
                label="Original Model (No Reasoning)"
            )
        with gr.Column(scale=2):
            gr.Markdown("## After (Thinking)")
            chatbot_thinking = gr.Chatbot(
                scale=1,
                type="messages",
                latex_delimiters=latex_delimiters,
                label="Model with Reasoning"
            )
    with gr.Row():
        # Define the msg textbox first — the examples below reference it.
        msg = gr.Textbox(
            submit_btn=True,
            label="",
            show_label=False,
            placeholder="์—ฌ๊ธฐ์— ์งˆ๋ฌธ์„ ์ž…๋ ฅํ•˜์„ธ์š”.",
            autofocus=True,
        )
    # Examples section — placed after the msg variable is defined.
    with gr.Accordion("EXAMPLES", open=False):
        examples = gr.Examples(
            examples=[
                "[์ถœ์ฒ˜: MATH-500)] ์ฒ˜์Œ 100๊ฐœ์˜ ์–‘์˜ ์ •์ˆ˜ ์ค‘์—์„œ 3, 4, 5๋กœ ๋‚˜๋ˆ„์–ด ๋–จ์–ด์ง€๋Š” ์ˆ˜๋Š” ๋ช‡ ๊ฐœ์ž…๋‹ˆ๊นŒ?",
                "[์ถœ์ฒ˜: MATH-500)] ์ž‰ํฌ์˜ ๋•…์—์„œ ๋ˆ ์‹œ์Šคํ…œ์€ ๋…ํŠนํ•ฉ๋‹ˆ๋‹ค. ํŠธ๋งํ‚› 1๊ฐœ๋Š” ๋ธ”๋งํ‚ท 4๊ฐœ์™€ ๊ฐ™๊ณ , ๋ธ”๋งํ‚ท 3๊ฐœ๋Š” ๋“œ๋งํฌ 7๊ฐœ์™€ ๊ฐ™์Šต๋‹ˆ๋‹ค. ํŠธ๋งํ‚ท์—์„œ ๋“œ๋งํฌ 56๊ฐœ์˜ ๊ฐ€์น˜๋Š” ์–ผ๋งˆ์ž…๋‹ˆ๊นŒ?",
                "[์ถœ์ฒ˜: MATH-500)] ์—์ด๋ฏธ, ๋ฒค, ํฌ๋ฆฌ์Šค์˜ ํ‰๊ท  ๋‚˜์ด๋Š” 6์‚ด์ž…๋‹ˆ๋‹ค. 4๋…„ ์ „ ํฌ๋ฆฌ์Šค๋Š” ์ง€๊ธˆ ์—์ด๋ฏธ์™€ ๊ฐ™์€ ๋‚˜์ด์˜€์Šต๋‹ˆ๋‹ค. 4๋…„ ํ›„ ๋ฒค์˜ ๋‚˜์ด๋Š” ๊ทธ๋•Œ ์—์ด๋ฏธ์˜ ๋‚˜์ด์˜ $\\frac{3}{5}$๊ฐ€ ๋  ๊ฒƒ์ž…๋‹ˆ๋‹ค. ํฌ๋ฆฌ์Šค๋Š” ์ง€๊ธˆ ๋ช‡ ์‚ด์ž…๋‹ˆ๊นŒ?",
                "[์ถœ์ฒ˜: MATH-500)] ๋…ธ๋ž€์ƒ‰๊ณผ ํŒŒ๋ž€์ƒ‰ ๊ตฌ์Šฌ์ด ๋“ค์–ด ์žˆ๋Š” ๊ฐ€๋ฐฉ์ด ์žˆ์Šต๋‹ˆ๋‹ค. ํ˜„์žฌ ํŒŒ๋ž€์ƒ‰ ๊ตฌ์Šฌ๊ณผ ๋…ธ๋ž€์ƒ‰ ๊ตฌ์Šฌ์˜ ๋น„์œจ์€ 4:3์ž…๋‹ˆ๋‹ค. ํŒŒ๋ž€์ƒ‰ ๊ตฌ์Šฌ 5๊ฐœ๋ฅผ ๋”ํ•˜๊ณ  ๋…ธ๋ž€์ƒ‰ ๊ตฌ์Šฌ 3๊ฐœ๋ฅผ ์ œ๊ฑฐํ•˜๋ฉด ๋น„์œจ์€ 7:3์ด ๋ฉ๋‹ˆ๋‹ค. ๋” ๋„ฃ๊ธฐ ์ „์— ๊ฐ€๋ฐฉ์— ํŒŒ๋ž€์ƒ‰ ๊ตฌ์Šฌ์ด ๋ช‡ ๊ฐœ ์žˆ์—ˆ์Šต๋‹ˆ๊นŒ?",
                "์ˆ˜ํ•™ ๋™์•„๋ฆฌ์—์„œ ๋‹ค๊ฐ€์˜ฌ ์—ฌํ–‰์„ ์œ„ํ•œ ๊ธฐ๊ธˆ ๋ชจ๊ธˆ์„ ์œ„ํ•ด ๋ฒ ์ดํ‚น ์„ธ์ผ์„ ์—ด๊ณ  ์žˆ์Šต๋‹ˆ๋‹ค. 3๊ฐœ์— 54๋‹ฌ๋Ÿฌ์งœ๋ฆฌ ์ฟ ํ‚ค๋ฅผ 1๋‹ฌ๋Ÿฌ์— ํŒ๋งคํ•˜๊ณ , 20๊ฐœ์— ์ปต์ผ€์ดํฌ๋ฅผ ๊ฐ๊ฐ 2๋‹ฌ๋Ÿฌ์— ํŒ๋งคํ•˜๊ณ , 35๊ฐœ์— ๋ธŒ๋ผ์šฐ๋‹ˆ๋ฅผ ๊ฐ๊ฐ 1๋‹ฌ๋Ÿฌ์— ํŒ๋งคํ•ฉ๋‹ˆ๋‹ค. ์ˆ˜ํ•™ ๋™์•„๋ฆฌ์—์„œ ์ด ์ œํ’ˆ์„ ๊ตฝ๋Š” ๋ฐ 15๋‹ฌ๋Ÿฌ๊ฐ€ ๋“ค์—ˆ๋‹ค๋ฉด, ์ˆ˜์ต์€ ์–ผ๋งˆ์˜€์„๊นŒ์š”?"
            ],
            inputs=msg
        )
    # Generation parameter controls.
    with gr.Row():
        with gr.Column():
            gr.Markdown("""## ๋งค๊ฐœ๋ณ€์ˆ˜ ์กฐ์ •""")
            num_tokens = gr.Slider(
                50,
                4000,
                2000,
                step=1,
                label="์ถ”๋ก  ๋‹จ๊ณ„๋‹น ์ตœ๋Œ€ ํ† ํฐ ์ˆ˜",
                interactive=True,
            )
            final_num_tokens = gr.Slider(
                50,
                4000,
                2000,
                step=1,
                label="์ตœ์ข… ๋‹ต๋ณ€์˜ ์ตœ๋Œ€ ํ† ํฐ ์ˆ˜",
                interactive=True,
            )
            do_sample = gr.Checkbox(True, label="์ƒ˜ํ”Œ๋ง ์‚ฌ์šฉ")
            temperature = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="์˜จ๋„")
    # When the user submits a message, both bots respond: first the plain
    # model, then the reasoning-augmented one.
    msg.submit(
        user_input,
        [msg, chatbot_original, chatbot_thinking],  # inputs
        [msg, chatbot_original, chatbot_thinking],  # outputs
    ).then(
        bot_original,
        [
            chatbot_original,
            num_tokens,
            do_sample,
            temperature,
        ],
        chatbot_original,  # the new history is saved via the output
    ).then(
        bot_thinking,
        [
            chatbot_thinking,
            num_tokens,
            final_num_tokens,
            do_sample,
            temperature,
        ],
        chatbot_thinking,  # the new history is saved via the output
    )

if __name__ == "__main__":
    demo.queue().launch()  # title is set on gr.Blocks above, not here