Update app.py
Browse files
app.py
CHANGED
@@ -1,92 +1,84 @@
|
|
1 |
import os
|
2 |
import subprocess
|
3 |
-
import
|
4 |
-
import
|
5 |
-
|
6 |
-
# 设置环境
|
7 |
-
def setup_environment():
|
8 |
-
if not os.path.exists("skywork-o1-prm-inference"):
|
9 |
-
print("Cloning repository...")
|
10 |
-
subprocess.run(["git", "clone", "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"], check=True)
|
11 |
-
repo_path = os.path.abspath("skywork-o1-prm-inference")
|
12 |
-
else:
|
13 |
-
repo_path = os.path.abspath("skywork-o1-prm-inference")
|
14 |
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
|
19 |
-
|
|
|
|
|
20 |
|
21 |
-
|
22 |
-
from
|
23 |
-
from model_utils.
|
24 |
-
|
25 |
-
|
|
|
|
|
26 |
|
27 |
-
|
28 |
-
|
29 |
-
model = PRM_MODEL.from_pretrained(model_id).to("cpu").eval()
|
30 |
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
processed_data = prepare_input(problem, response, tokenizer=tokenizer, step_token="\n")
|
35 |
-
input_ids = [processed_data[0]] # 第一个元素
|
36 |
-
steps = [processed_data[1]] # 第二个元素
|
37 |
-
reward_flags = [processed_data[2]] # 第三个元素
|
38 |
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
|
|
|
|
51 |
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
input_ids=input_ids,
|
56 |
-
attention_mask=attention_mask,
|
57 |
-
return_probs=True
|
58 |
-
)
|
59 |
|
60 |
-
|
61 |
-
|
|
|
62 |
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
elif isinstance(step_rewards[0], np.ndarray):
|
67 |
-
return json.dumps(step_rewards[0].tolist())
|
68 |
-
else:
|
69 |
-
return json.dumps(list(step_rewards[0])) # 转换为列表
|
70 |
-
except Exception as e:
|
71 |
-
return json.dumps({"error": str(e)})
|
72 |
-
# 创建Gradio界面
|
73 |
-
iface = gr.Interface(
|
74 |
-
fn=evaluate,
|
75 |
inputs=[
|
76 |
-
gr.Textbox(label="Problem
|
77 |
-
gr.Textbox(label="Response",
|
78 |
-
],
|
79 |
-
outputs=gr.JSON(),
|
80 |
-
title="Problem Response Evaluation",
|
81 |
-
description="Enter a problem and its response to get step-wise rewards",
|
82 |
-
examples=[
|
83 |
-
[
|
84 |
-
"Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
|
85 |
-
"To determine how much money Janet makes every day at the farmers' market, we need to follow these steps:\n1. Calculate the total number of eggs laid by the ducks per day.\n Janet's ducks lay 16 eggs per day.\n2. Determine the number of eggs Janet uses each day.\n - She eats 3 eggs for breakfast every morning.\n - She bakes muffins for her friends every day with 4 eggs.\n So, the total number of eggs used per day is:\n 3 + 4 = 7 eggs\n3. Calculate the number of eggs Janet sells at the farmers' market each day.\n Subtract the number of eggs used from the total number of eggs laid:\n 16 - 7 = 9 eggs\n4. Determine how much money Janet makes from selling the eggs.\n She sells each egg for $2, so the total amount of money she makes is:\n 9 ×2 = 18 dollars\nTherefore, the amount of money Janet makes every day at the farmers' market is $18."
|
86 |
-
]
|
87 |
],
|
88 |
-
|
|
|
|
|
|
|
|
|
|
|
89 |
)
|
90 |
|
91 |
-
|
92 |
-
|
|
|
|
1 |
import os
|
2 |
import subprocess
|
3 |
+
import gradio as gr
|
4 |
+
import torch
|
5 |
+
from transformers import AutoTokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
+
# 克隆项目仓库
|
8 |
+
REPO_URL = "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"
|
9 |
+
REPO_DIR = "skywork-o1-prm-inference"
|
10 |
|
11 |
+
# 在 Hugging Face Spaces 上首次启动时克隆
|
12 |
+
if not os.path.exists(REPO_DIR):
|
13 |
+
subprocess.run(["git", "clone", REPO_URL], check=True)
|
14 |
|
15 |
+
# 导入必要的模块
|
16 |
+
from skywork-o1-prm-inference.model_utils.prm_model import PRM_MODEL
|
17 |
+
from skywork-o1-prm-inference.model_utils.io_utils import (
|
18 |
+
prepare_input,
|
19 |
+
prepare_batch_input_for_model,
|
20 |
+
derive_step_rewards
|
21 |
+
)
|
22 |
|
23 |
+
# 模型配置
|
24 |
+
MODEL_ID = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"
|
|
|
25 |
|
26 |
+
# 初始化模型和tokenizer
|
27 |
+
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
|
28 |
+
model = PRM_MODEL.from_pretrained(MODEL_ID).to("cpu").eval()
|
|
|
|
|
|
|
|
|
29 |
|
30 |
+
def compute_step_rewards(problem_text, response_text):
|
31 |
+
# 准备输入数据
|
32 |
+
processed_input = prepare_input(
|
33 |
+
problem_text,
|
34 |
+
response_text,
|
35 |
+
tokenizer=tokenizer,
|
36 |
+
step_token="\n"
|
37 |
+
)
|
38 |
+
input_ids, steps, reward_flags = processed_input
|
39 |
+
|
40 |
+
# 转换为批处理形式
|
41 |
+
input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
|
42 |
+
[input_ids],
|
43 |
+
[reward_flags],
|
44 |
+
tokenizer.pad_token_id
|
45 |
+
)
|
46 |
+
input_ids = input_ids.to("cpu")
|
47 |
+
attention_mask = attention_mask.to("cpu")
|
48 |
+
if isinstance(reward_flags, torch.Tensor):
|
49 |
+
reward_flags = reward_flags.to("cpu")
|
50 |
|
51 |
+
# 模型推理
|
52 |
+
with torch.no_grad():
|
53 |
+
_, _, rewards = model(
|
54 |
+
input_ids=input_ids,
|
55 |
+
attention_mask=attention_mask,
|
56 |
+
return_probs=True
|
57 |
+
)
|
58 |
|
59 |
+
# 获取step rewards
|
60 |
+
step_rewards = derive_step_rewards(rewards, reward_flags)
|
61 |
+
return step_rewards[0]
|
|
|
|
|
|
|
|
|
62 |
|
63 |
+
def inference_interface(problem, response):
|
64 |
+
rewards = compute_step_rewards(problem, response)
|
65 |
+
return rewards
|
66 |
|
67 |
+
# Gradio界面配置
|
68 |
+
interface = gr.Interface(
|
69 |
+
fn=inference_interface,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
inputs=[
|
71 |
+
gr.Textbox(lines=4, label="Problem (题目)"),
|
72 |
+
gr.Textbox(lines=6, label="Response (回答)"),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
],
|
74 |
+
outputs="json",
|
75 |
+
title="Skywork-o1-prm-inference Demo",
|
76 |
+
description=(
|
77 |
+
"输入题目和回答,点击提交查看其每个 step 对应的 reward,"
|
78 |
+
"这些 reward 值用于度量回答每一步的质量。"
|
79 |
+
),
|
80 |
)
|
81 |
|
82 |
+
if __name__ == "__main__":
|
83 |
+
interface.launch(server_name="0.0.0.0", server_port=7860)
|
84 |
+
|