# Residue of the Hugging Face Spaces page header ("Spaces: Running on Zero"):
# this app is a Hugging Face Space running on ZeroGPU hardware.
import re | |
import threading | |
from collections import Counter | |
import gradio as gr | |
import spaces | |
import transformers | |
from transformers import pipeline | |
# Load the model and tokenizer as a text-generation pipeline.
model_name = "CohereForAI/c4ai-command-r7b-arabic-02-2025"

# gr.NO_RELOAD keeps this expensive pipeline construction from re-running on
# every hot-reload cycle under `gradio` dev mode; device/dtype are auto-picked.
if gr.NO_RELOAD:
    pipe = pipeline(
        "text-generation",
        model=model_name,
        device_map="auto",
        torch_dtype="auto",
    )
# Marker used to flag the final answer inside the generated text.
ANSWER_MARKER = "**๋ต๋ณ**"

# Sentences that kick off each successive reasoning step.
# NOTE(review): these are later passed through .format(question=...) although
# none of them contains a {question} placeholder — the kwarg is silently
# ignored; confirm whether a "{question}" entry was lost.
rethink_prepends = [
    "์, ์ด์ ๋ค์์ ํ์ ํด์ผ ํฉ๋๋ค ",
    "์ ์๊ฐ์๋ ",
    "์ ์๋ง์, ์ ์๊ฐ์๋ ",
    "๋ค์ ์ฌํญ์ด ๋ง๋์ง ํ์ธํด ๋ณด๊ฒ ์ต๋๋ค ",
    "๋ํ ๊ธฐ์ตํด์ผ ํ ๊ฒ์ ",
    "๋ ๋ค๋ฅธ ์ฃผ๋ชฉํ ์ ์ ",
    "๊ทธ๋ฆฌ๊ณ ์ ๋ ๋ค์๊ณผ ๊ฐ์ ์ฌ์ค๋ ๊ธฐ์ตํฉ๋๋ค ",
    "์ด์ ์ถฉ๋ถํ ์ดํดํ๋ค๊ณ ์๊ฐํฉ๋๋ค ",
]
# Generic step-by-step reasoning guide, prepended once at the start of the
# first thinking step (see bot_thinking, i == 0).
general_reasoning_guide = """
์ด ๋ฌธ์ ๋ฅผ ํด๊ฒฐํ๊ธฐ ์ํ ์ฒด๊ณ์ ์ธ ์ ๊ทผ ๋ฐฉ๋ฒ์ ์ฌ์ฉํ๊ฒ ์ต๋๋ค:
1. ๋ฌธ์ ์์ ์ ๊ณต๋ ๋ชจ๋ ์ ๋ณด์ ์กฐ๊ฑด์ ๋ช ํํ ์ดํดํฉ๋๋ค.
2. ๊ฐ ๋ณ์์ ๊ด๊ณ๋ฅผ ์๋ณํ๊ณ ํ์ํ ๋ฐฉ์ ์์ ์ธ์๋๋ค.
3. ๋จ๊ณ๋ณ๋ก ๊ณ์ฐ์ ์ํํ๋ฉฐ, ๊ฐ ๋จ๊ณ์ ๊ฒฐ๊ณผ๋ฅผ ํ์ธํฉ๋๋ค.
4. ์ค๊ฐ ๊ฒฐ๊ณผ๊ฐ ํฉ๋ฆฌ์ ์ธ์ง ๊ฒํ ํ๋ฉฐ ์งํํฉ๋๋ค.
5. ์ต์ข ๋ต์์ ๋์ถํ๊ณ ๋ฌธ์ ์ ์๊ตฌ์ฌํญ์ ์ถฉ์กฑํ๋์ง ํ์ธํฉ๋๋ค.
์ด์ ๋ฌธ์ ๋ฅผ ํ์ด๋ณด๊ฒ ์ต๋๋ค:
"""
# Helpers for extracting and validating candidate results.
def extract_calculation_results(reasoning_text):
    """Extract every candidate numeric answer from a reasoning transcript.

    Three passes over *reasoning_text*:
      1. answer-phrase patterns ("the result is N", "= N", ...),
      2. number-plus-unit patterns (currency, counts, time, length, weight),
      3. any bare number in the final paragraph (closest to the conclusion).

    Returns a list of int/float values in discovery order (duplicates kept —
    frequency is meaningful to determine_best_result).
    """
    # Answer-phrase patterns (several phrasings considered).
    numeric_patterns = [
        r'๊ฒฐ๊ณผ๋ (\d+[\.,]?\d*)',
        r'๋ต(์|๋|์ด) (\d+[\.,]?\d*)',
        r'์ ๋ต(์|๋|์ด) (\d+[\.,]?\d*)',
        r'๋ต์(์|๋|์ด) (\d+[\.,]?\d*)',
        r'์์ต(์|๋|์ด) (\d+[\.,]?\d*)',
        r'๊ฐ(์|๋|์ด) (\d+[\.,]?\d*)',
        r'๊ฒฐ๋ก (์|๋|์ด) (\d+[\.,]?\d*)',
        r'๊ฐ์(๋|์|๊ฐ) (\d+[\.,]?\d*)',
        r'์ด (\d+[\.,]?\d*)๊ฐ',
        r'์ด์ก(์|๋|์ด) (\d+[\.,]?\d*)',
        r'์ดํฉ(์|๋|์ด) (\d+[\.,]?\d*)',
        r'ํฉ๊ณ(๋|์|๊ฐ) (\d+[\.,]?\d*)',
        r'=\s*(\d+[\.,]?\d*)\s*$',
        r':\s*(\d+[\.,]?\d*)\s*$',
        r'์ด๊ณ:\s*(\d+[\.,]?\d*)',
        r'์ต์ข ๊ฒฐ๊ณผ:\s*(\d+[\.,]?\d*)',
        r'์ต์ข ๊ฐ:\s*(\d+[\.,]?\d*)',
        r'์ต์ข ๋ต๋ณ:\s*(\d+[\.,]?\d*)',
    ]
    # Number-plus-unit patterns (currency, counts, sets, etc.).
    unit_patterns = [
        r'(\d+[\.,]?\d*)\s*(๋ฌ๋ฌ|์|์ ๋ก|ํ์ด๋|์)',
        r'(\d+[\.,]?\d*)\s*(๊ฐ|๋ช |์ธํธ|์|ํ|๊ทธ๋ฃน)',
        r'(\d+[\.,]?\d*)\s*(๋ถ|์๊ฐ|์ด|์ผ|์ฃผ|๊ฐ์|๋ )',
        r'(\d+[\.,]?\d*)\s*(๋ฏธํฐ|ํฌ๋ก๋ฏธํฐ|์ผํฐ๋ฏธํฐ|์ธ์น|ํผํธ)',
        r'(\d+[\.,]?\d*)\s*(๊ทธ๋จ|ํฌ๋ก๊ทธ๋จ|ํ์ด๋|์จ์ค)',
    ]

    def _parse_number(raw):
        # Strip thousands separators, then return int when there is no
        # decimal point, float when there is, or None for unparseable text.
        raw = raw.replace(',', '')
        try:
            return float(raw) if '.' in raw else int(raw)
        except ValueError:
            return None

    results = []
    # Pass 1: answer-phrase patterns.
    for pattern in numeric_patterns:
        for match in re.findall(pattern, reasoning_text, re.IGNORECASE):
            # Multi-group patterns capture a particle first; the numeric
            # value is always the last group.
            raw = match[-1] if isinstance(match, tuple) else match
            value = _parse_number(raw)
            if value is not None:
                results.append(value)
    # Pass 2: number-plus-unit patterns (number is group 0).
    for pattern in unit_patterns:
        for match in re.findall(pattern, reasoning_text, re.IGNORECASE):
            value = _parse_number(match[0])
            if value is not None:
                results.append(value)
    # Pass 3: bare numbers in the final paragraph (closest to the answer).
    last_paragraph = reasoning_text.split('\n\n')[-1]
    for num in re.findall(r'(\d+[\.,]?\d*)', last_paragraph):
        value = _parse_number(num)
        if value is not None:
            results.append(value)
    return results
def determine_best_result(results, full_reasoning):
    """Pick the most trustworthy value from the candidate *results*.

    Heuristic cascade: unanimity -> dominant frequency (>= 80% of the
    leader) -> presence near the final paragraphs -> appearance next to an
    equation/answer marker -> plain highest frequency.

    Returns a single int/float, or None when *results* is empty.
    """
    if not results:
        return None
    # Unanimous candidates: nothing to arbitrate.
    if len(set(results)) == 1:
        return results[0]
    # Frequency analysis (the most repeated value is likely most reliable).
    counter = Counter(results)
    most_common = counter.most_common()
    # Keep candidates whose frequency is close (>= 80%) to the leader's.
    top_results = [result for result, count in most_common if count >= most_common[0][1] * 0.8]
    if len(top_results) == 1:
        return top_results[0]
    # Give extra weight to values appearing near the final conclusion.
    paragraphs = full_reasoning.split('\n\n')
    last_paragraphs = '\n\n'.join(paragraphs[-2:])  # last two paragraphs
    # Candidates that literally appear in the closing paragraphs.
    final_results = [result for result in top_results if str(result) in last_paragraphs]
    if final_results:
        # Most frequent among the candidates seen near the end.
        final_counter = Counter([r for r in results if r in final_results])
        if final_counter:
            return final_counter.most_common(1)[0][0]
    # Values introduced by an equation/answer marker (e.g. "= 78").
    for result in top_results:
        # re.escape keeps the '.' of float candidates from acting as a
        # regex wildcard (otherwise "3.5" would also match "305").
        result_str = re.escape(str(result))
        if re.search(r'=\s*' + result_str + r'(?!\d)', full_reasoning) or \
           re.search(r'๊ฒฐ๊ณผ[:๋์์ด๊ฐ]\s*' + result_str, full_reasoning) or \
           re.search(r'๋ต[:๋์์ด๊ฐ]\s*' + result_str, full_reasoning) or \
           re.search(r'์ ๋ต[:๋์์ด๊ฐ]\s*' + result_str, full_reasoning):
            return result
    # Fall back to the most frequent candidate overall.
    return most_common[0][0]
# Prompt used to summarize intermediate results between reasoning steps.
structured_reasoning_prompt = """
์ง๊ธ๊น์ง์ ์ถ๋ก ์ ๋จ๊ณ๋ณ๋ก ์ ๋ฆฌํด๋ณด๊ฒ ์ต๋๋ค:
1. ๋ฌธ์ ๋ถ์:
- ์ฃผ์ด์ง ์ ๋ณด: {given_info}
- ๊ตฌํด์ผ ํ ๊ฒ: {goal}
2. ๊ณ์ฐ ๊ณผ์ :
{calculation_steps}
3. ํ์ฌ๊น์ง์ ๊ฒฐ๋ก :
{current_conclusion}
์ด์ ๋ค์ ๋จ๊ณ๋ก ์งํํ๊ฒ ์ต๋๋ค.
"""
# Prompt used to verify conflicting results extracted from the reasoning.
verification_prompt = """
์ง๊ธ๊น์ง์ ์ถ๋ก ๊ณผ์ ์์ ์ฌ๋ฌ ๊ฒฐ๊ณผ๊ฐ ๋์ถ๋์์ต๋๋ค:
{different_results}
์ด ์ค์์ ๊ฐ์ฅ ์ ํํ ๋ต๋ณ์ ์ฐพ๊ธฐ ์ํด ๊ณ์ฐ ๊ณผ์ ์ ์ฒ์๋ถํฐ ๋ค์ ๊ฒํ ํ๊ฒ ์ต๋๋ค:
1. ๋ฌธ์ ๋ถ์:
- ์ฃผ์ด์ง ์ ๋ณด: {given_info}
- ๊ตฌํด์ผ ํ ๊ฒ: {goal}
2. ๋จ๊ณ๋ณ ๊ณ์ฐ ๊ณผ์ :
{calculation_steps}
3. ๊ฒฐ๋ก :
์ ๊ณ์ฐ ๊ณผ์ ์ ํตํด ์ ํํ ๋ต์ {result}์ ๋๋ค.
"""
# Prompt used to produce the final answer (ends at ANSWER_MARKER).
final_answer_prompt = """
์ง๊ธ๊น์ง์ ์ฒด๊ณ์ ์ธ ์ถ๋ก ๊ณผ์ ์ ์ข ํฉํ์ฌ, ์๋ ์ง๋ฌธ์ ๋ต๋ณํ๊ฒ ์ต๋๋ค:
{question}
์ถ๋ก ๊ณผ์ ์ ๊ฒํ ํ ๊ฒฐ๊ณผ, ๋ค์๊ณผ ๊ฐ์ ๊ฒฐ๋ก ์ ๋๋ฌํ์ต๋๋ค:
{reasoning_conclusion}
๋ฐ๋ผ์ ์ต์ข ๋ต๋ณ์:
{ANSWER_MARKER}
"""
# KaTeX delimiter configuration for the gr.Chatbot components, so that
# formulas rendered by reformat_math() display correctly.
latex_delimiters = [
    {"left": "$$", "right": "$$", "display": True},
    {"left": "$", "right": "$", "display": False},
]
def reformat_math(text):
    """Rewrite MathJax delimiters into the KaTeX ones Gradio renders.

    Display math ``\\[ ... \\]`` becomes ``$$ ... $$`` and inline math
    ``\\( ... \\)`` becomes ``$ ... $``. This is a temporary workaround:
    no way has been found yet to make Gradio's ``latex_delimiters`` accept
    the MathJax forms directly.
    """
    with_display = re.sub(r"\\\[\s*(.*?)\s*\\\]", r"$$\1$$", text, flags=re.DOTALL)
    return re.sub(r"\\\(\s*(.*?)\s*\\\)", r"$\1$", with_display, flags=re.DOTALL)
def user_input(message, history_original, history_thinking):
    """Append the user's message to both histories and clear the textbox."""
    # Strip the answer marker so the user cannot spoof the final-answer flag.
    cleaned = message.replace(ANSWER_MARKER, "")
    original_history = history_original + [gr.ChatMessage(role="user", content=cleaned)]
    thinking_history = history_thinking + [gr.ChatMessage(role="user", content=cleaned)]
    return "", original_history, thinking_history
def rebuild_messages(history: list):
    """Rebuild plain chat messages from *history*, dropping thought bubbles.

    Entries whose metadata carries a "title" are intermediate reasoning
    steps rendered by the UI; they are excluded so the model only sees the
    real conversation. Handles both raw dicts and gr.ChatMessage objects.
    """
    messages = []
    for h in history:
        if isinstance(h, dict):
            # `metadata` may be missing or explicitly None; either way an
            # entry without a title is a genuine conversation message.
            if not (h.get("metadata") or {}).get("title"):
                messages.append(h)
        elif (
            isinstance(h, gr.ChatMessage)
            and h.metadata.get("title", None) is None
            and isinstance(h.content, str)
        ):
            messages.append({"role": h.role, "content": h.content})
    return messages
def extract_info_from_question(question): | |
"""๋ฌธ์ ์์ ์ฃผ์ด์ง ์ ๋ณด์ ๋ชฉํ๋ฅผ ์ถ์ถํฉ๋๋ค.""" | |
# ๊ธฐ๋ณธ ๊ฐ | |
given_info = "๋ฌธ์ ์์ ์ ๊ณต๋ ๋ชจ๋ ์กฐ๊ฑด๊ณผ ์์น" | |
goal = "๋ฌธ์ ์์ ์๊ตฌํ๋ ๊ฐ์ด๋ ๊ฒฐ๊ณผ" | |
# ์ผ๋ฐ์ ์ธ ์ ๋ณด ์ถ์ถ ํจํด | |
if "๋ช ๊ฐ" in question or "๊ฐ์" in question: | |
goal = "ํน์ ์กฐ๊ฑด์ ๋ง์กฑํ๋ ํญ๋ชฉ์ ๊ฐ์" | |
elif "์ผ๋ง" in question: | |
goal = "ํน์ ๊ฐ ๋๋ ๊ธ์ก" | |
elif "๋์ด" in question: | |
goal = "์ฌ๋์ ๋์ด" | |
elif "ํ๋ฅ " in question: | |
goal = "ํน์ ์ฌ๊ฑด์ ํ๋ฅ " | |
elif "ํ๊ท " in question: | |
goal = "๊ฐ๋ค์ ํ๊ท " | |
return given_info, goal | |
def bot_original(
    history: list,
    max_num_tokens: int,
    do_sample: bool,
    temperature: float,
):
    """Stream a direct answer from the base model (no reasoning scaffold).

    Yields the updated *history* after every streamed token so Gradio can
    render the answer incrementally.
    """
    # Generation runs on a worker thread; tokens are consumed here as a stream.
    streamer = transformers.TextIteratorStreamer(
        pipe.tokenizer,  # pyright: ignore
        skip_special_tokens=True,
        skip_prompt=True,
    )
    # Empty assistant placeholder that the stream fills in.
    history.append(gr.ChatMessage(role="assistant", content=str("")))
    # Conversation as the model should see it (placeholder excluded).
    prompt_messages = rebuild_messages(history[:-1])
    # Answer directly, without any reasoning steps.
    generation = threading.Thread(
        target=pipe,
        args=(prompt_messages,),
        kwargs=dict(
            max_new_tokens=max_num_tokens,
            streamer=streamer,
            do_sample=do_sample,
            temperature=temperature,
        ),
    )
    generation.start()
    for chunk in streamer:
        history[-1].content = reformat_math(history[-1].content + chunk)
        yield history
    generation.join()
    yield history
def bot_thinking(
    history: list,
    max_num_tokens: int,
    final_num_tokens: int,
    do_sample: bool,
    temperature: float,
):
    """Answer with an explicit multi-step reasoning process.

    Runs the model once per entry of `rethink_prepends`, accumulating the
    reasoning transcript, then optionally runs a verification pass when the
    extracted candidate answers conflict, and finally generates the answer
    after ANSWER_MARKER. Yields *history* after every streamed token.
    """
    # Generation runs on worker threads; tokens are consumed here as a stream.
    streamer = transformers.TextIteratorStreamer(
        pipe.tokenizer,  # pyright: ignore
        skip_special_tokens=True,
        skip_prompt=True,
    )
    # Keep the question around to re-insert into the reasoning if needed.
    # NOTE(review): dict-style access here vs. attribute access below —
    # assumes Gradio delivers incoming history entries as dicts; confirm.
    question = history[-1]["content"]
    # Heuristic extraction of the problem's givens and goal.
    given_info, goal = extract_info_from_question(question)
    # Assistant placeholder that will hold the streamed reasoning.
    history.append(
        gr.ChatMessage(
            role="assistant",
            content=str(""),
            metadata={"title": "๐ง ์๊ฐ ์ค...", "status": "pending"},
        )
    )
    # Conversation as the model should see it.
    messages = rebuild_messages(history)
    # Full reasoning transcript accumulated across steps.
    full_reasoning = ""
    # Calculation steps collected from the reasoning so far.
    calculation_steps = ""
    current_conclusion = "์์ง ์ต์ข ๊ฒฐ๋ก ์ ๋๋ฌํ์ง ์์์ต๋๋ค."
    # Run the reasoning steps one prepend at a time.
    for i, prepend in enumerate(rethink_prepends):
        if i > 0:
            messages[-1]["content"] += "\n\n"
        # First step: prepend the general reasoning guide.
        if i == 0:
            messages[-1]["content"] += general_reasoning_guide + "\n\n"
        # Later steps: prepend a structured summary of progress so far.
        if i > 1 and calculation_steps:
            structured_summary = structured_reasoning_prompt.format(
                given_info=given_info,
                goal=goal,
                calculation_steps=calculation_steps,
                current_conclusion=current_conclusion
            )
            messages[-1]["content"] += structured_summary + "\n\n"
        messages[-1]["content"] += prepend.format(question=question)
        t = threading.Thread(
            target=pipe,
            args=(messages,),
            kwargs=dict(
                max_new_tokens=max_num_tokens,
                streamer=streamer,
                do_sample=do_sample,
                temperature=temperature,
            ),
        )
        t.start()
        # Mirror the same prompt text into the visible history.
        if i == 0:
            history[-1].content += general_reasoning_guide + "\n\n"
        if i > 1 and calculation_steps:
            # `structured_summary` is defined above under the same condition.
            history[-1].content += structured_summary + "\n\n"
        history[-1].content += prepend.format(question=question)
        for token in streamer:
            history[-1].content += token
            history[-1].content = reformat_math(history[-1].content)
            yield history
        t.join()
        # Save this step's output into the running transcript.
        full_reasoning = history[-1].content
        # Extract and record calculation steps from the new content only.
        new_content = history[-1].content.split(prepend.format(question=question))[-1]
        if "=" in new_content or ":" in new_content:
            # Treat "=" / ":" as evidence of a calculation step.
            calculation_steps += f"\n - {new_content.strip()}"
            # Pull a provisional conclusion out of this step, if any.
            results = extract_calculation_results(new_content)
            if results:
                current_conclusion = f"ํ์ฌ ๊ณ์ฐ๋ ๊ฐ: {results[-1]}"
    # Reasoning finished; mark the thought bubble as done.
    history[-1].metadata = {"title": "๐ญ ์ฌ๊ณ ๊ณผ์ ", "status": "done"}
    # Collect every candidate answer found anywhere in the transcript.
    all_results = extract_calculation_results(full_reasoning)
    # Conflicting candidates trigger an extra verification pass.
    if all_results and len(set(all_results)) > 1:
        # Frequency of each candidate, rendered for the prompt.
        result_counter = Counter(all_results)
        different_results = "\n".join([f"{result} (๋น๋: {freq}ํ)" for result, freq in result_counter.most_common()])
        # Heuristically pick the best candidate to anchor the verification.
        best_result = determine_best_result(all_results, full_reasoning)
        # Ask the model to re-check the calculation end to end.
        verify_prompt = verification_prompt.format(
            different_results=different_results,
            given_info=given_info,
            goal=goal,
            calculation_steps=calculation_steps,
            result=best_result
        )
        messages[-1]["content"] += "\n\n" + verify_prompt
        # Run the verification step.
        t = threading.Thread(
            target=pipe,
            args=(messages,),
            kwargs=dict(
                max_new_tokens=max_num_tokens // 2,
                streamer=streamer,
                do_sample=False,  # sampling off for a deterministic check
                temperature=0.3,  # low temperature
            ),
        )
        t.start()
        history[-1].content += "\n\n" + verify_prompt
        for token in streamer:
            history[-1].content += token
            history[-1].content = reformat_math(history[-1].content)
            yield history
        t.join()
        # Refresh the transcript with the verification output.
        full_reasoning = history[-1].content
    # Decide the final result from the (possibly verified) transcript.
    final_results = extract_calculation_results(full_reasoning)
    best_result = determine_best_result(final_results, full_reasoning) if final_results else None
    # Compose the final conclusion text.
    if best_result is not None:
        reasoning_conclusion = f"์ถ๋ก ๊ณผ์ ์ ์ข ํฉํ ๊ฒฐ๊ณผ, ์ ํํ ๋ต๋ณ์ {best_result}์ ๋๋ค."
    else:
        # Fallback when no numeric result could be extracted: reuse the
        # closing paragraphs of the transcript.
        reasoning_parts = full_reasoning.split("\n\n")
        reasoning_conclusion = "\n\n".join(reasoning_parts[-2:]) if len(reasoning_parts) > 2 else full_reasoning
    # Fresh assistant message for the final answer.
    history.append(gr.ChatMessage(role="assistant", content=""))
    # Build the final-answer prompt (placeholder message excluded).
    final_messages = rebuild_messages(history[:-1])
    final_prompt = final_answer_prompt.format(
        question=question,
        reasoning_conclusion=reasoning_conclusion,
        ANSWER_MARKER=ANSWER_MARKER
    )
    final_messages[-1]["content"] += "\n\n" + final_prompt
    # Generate the final answer.
    t = threading.Thread(
        target=pipe,
        args=(final_messages,),
        kwargs=dict(
            max_new_tokens=final_num_tokens,
            streamer=streamer,
            do_sample=do_sample,
            temperature=temperature * 0.8,  # slightly lower for confidence
        ),
    )
    t.start()
    # Stream the final answer.
    for token in streamer:
        history[-1].content += token
        history[-1].content = reformat_math(history[-1].content)
        yield history
    t.join()
    yield history
with gr.Blocks(fill_height=True, title="Vidraft ThinkFlow") as demo:
    # Title and description.
    gr.Markdown("# Vidraft ThinkFlow")
    gr.Markdown("### ์ถ๋ก ๊ธฐ๋ฅ์ด ์๋ LLM ๋ชจ๋ธ์ ์์ ์์ด๋ ์ถ๋ก ๊ธฐ๋ฅ์ ์๋์ผ๋ก ์ ์ฉํ๋ LLM ์ถ๋ก ์์ฑ ํ๋ซํผ")
    # Side-by-side comparison: plain answer vs. reasoning-augmented answer.
    with gr.Row(scale=1):
        with gr.Column(scale=2):
            gr.Markdown("## Before (Original)")
            chatbot_original = gr.Chatbot(
                scale=1,
                type="messages",
                latex_delimiters=latex_delimiters,
                label="Original Model (No Reasoning)"
            )
        with gr.Column(scale=2):
            gr.Markdown("## After (Thinking)")
            chatbot_thinking = gr.Chatbot(
                scale=1,
                type="messages",
                latex_delimiters=latex_delimiters,
                label="Model with Reasoning"
            )
    with gr.Row():
        # Define the msg textbox first so later components can reference it.
        msg = gr.Textbox(
            submit_btn=True,
            label="",
            show_label=False,
            placeholder="์ฌ๊ธฐ์ ์ง๋ฌธ์ ์ ๋ ฅํ์ธ์.",
            autofocus=True,
        )
    # Examples section — placed after `msg` is defined (it is the target input).
    with gr.Accordion("EXAMPLES", open=False):
        examples = gr.Examples(
            examples=[
                "[์ถ์ฒ: MATH-500)] ์ฒ์ 100๊ฐ์ ์์ ์ ์ ์ค์์ 3, 4, 5๋ก ๋๋์ด ๋จ์ด์ง๋ ์๋ ๋ช ๊ฐ์ ๋๊น?",
                "[์ถ์ฒ: MATH-500)] ์ํฌ์ ๋ ์์ ๋ ์์คํ ์ ๋ ํนํฉ๋๋ค. ํธ๋งํ 1๊ฐ๋ ๋ธ๋งํท 4๊ฐ์ ๊ฐ๊ณ , ๋ธ๋งํท 3๊ฐ๋ ๋๋งํฌ 7๊ฐ์ ๊ฐ์ต๋๋ค. ํธ๋งํท์์ ๋๋งํฌ 56๊ฐ์ ๊ฐ์น๋ ์ผ๋ง์ ๋๊น?",
                "[์ถ์ฒ: MATH-500)] ์์ด๋ฏธ, ๋ฒค, ํฌ๋ฆฌ์ค์ ํ๊ท ๋์ด๋ 6์ด์ ๋๋ค. 4๋ ์ ํฌ๋ฆฌ์ค๋ ์ง๊ธ ์์ด๋ฏธ์ ๊ฐ์ ๋์ด์์ต๋๋ค. 4๋ ํ ๋ฒค์ ๋์ด๋ ๊ทธ๋ ์์ด๋ฏธ์ ๋์ด์ $\\frac{3}{5}$๊ฐ ๋ ๊ฒ์ ๋๋ค. ํฌ๋ฆฌ์ค๋ ์ง๊ธ ๋ช ์ด์ ๋๊น?",
                "[์ถ์ฒ: MATH-500)] ๋ ธ๋์๊ณผ ํ๋์ ๊ตฌ์ฌ์ด ๋ค์ด ์๋ ๊ฐ๋ฐฉ์ด ์์ต๋๋ค. ํ์ฌ ํ๋์ ๊ตฌ์ฌ๊ณผ ๋ ธ๋์ ๊ตฌ์ฌ์ ๋น์จ์ 4:3์ ๋๋ค. ํ๋์ ๊ตฌ์ฌ 5๊ฐ๋ฅผ ๋ํ๊ณ ๋ ธ๋์ ๊ตฌ์ฌ 3๊ฐ๋ฅผ ์ ๊ฑฐํ๋ฉด ๋น์จ์ 7:3์ด ๋ฉ๋๋ค. ๋ ๋ฃ๊ธฐ ์ ์ ๊ฐ๋ฐฉ์ ํ๋์ ๊ตฌ์ฌ์ด ๋ช ๊ฐ ์์์ต๋๊น?",
                "์ํ ๋์๋ฆฌ์์ ๋ค๊ฐ์ฌ ์ฌํ์ ์ํ ๊ธฐ๊ธ ๋ชจ๊ธ์ ์ํด ๋ฒ ์ดํน ์ธ์ผ์ ์ด๊ณ ์์ต๋๋ค. 3๊ฐ์ 54๋ฌ๋ฌ์ง๋ฆฌ ์ฟ ํค๋ฅผ 1๋ฌ๋ฌ์ ํ๋งคํ๊ณ , 20๊ฐ์ ์ปต์ผ์ดํฌ๋ฅผ ๊ฐ๊ฐ 2๋ฌ๋ฌ์ ํ๋งคํ๊ณ , 35๊ฐ์ ๋ธ๋ผ์ฐ๋๋ฅผ ๊ฐ๊ฐ 1๋ฌ๋ฌ์ ํ๋งคํฉ๋๋ค. ์ํ ๋์๋ฆฌ์์ ์ด ์ ํ์ ๊ตฝ๋ ๋ฐ 15๋ฌ๋ฌ๊ฐ ๋ค์๋ค๋ฉด, ์์ต์ ์ผ๋ง์์๊น์?"
            ],
            inputs=msg
        )
    # Generation parameter controls.
    with gr.Row():
        with gr.Column():
            gr.Markdown("""## ๋งค๊ฐ๋ณ์ ์กฐ์ """)
            num_tokens = gr.Slider(
                50,
                4000,
                2000,
                step=1,
                label="์ถ๋ก ๋จ๊ณ๋น ์ต๋ ํ ํฐ ์",
                interactive=True,
            )
            final_num_tokens = gr.Slider(
                50,
                4000,
                2000,
                step=1,
                label="์ต์ข ๋ต๋ณ์ ์ต๋ ํ ํฐ ์",
                interactive=True,
            )
            do_sample = gr.Checkbox(True, label="์ํ๋ง ์ฌ์ฉ")
            temperature = gr.Slider(0.1, 1.0, 0.7, step=0.1, label="์จ๋")
    # When the user submits a message, both bots respond in sequence.
    msg.submit(
        user_input,
        [msg, chatbot_original, chatbot_thinking],  # inputs
        [msg, chatbot_original, chatbot_thinking],  # outputs
    ).then(
        bot_original,
        [
            chatbot_original,
            num_tokens,
            do_sample,
            temperature,
        ],
        chatbot_original,  # save the new history in the output
    ).then(
        bot_thinking,
        [
            chatbot_thinking,
            num_tokens,
            final_num_tokens,
            do_sample,
            temperature,
        ],
        chatbot_thinking,  # save the new history in the output
    )
if __name__ == "__main__":
    # queue() enables streaming generator callbacks; title param was removed
    # (it is already set on gr.Blocks above).
    demo.queue().launch()