lihongze8 commited on
Commit
8d257f5
·
verified ·
1 Parent(s): 13b9b88

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -76
app.py CHANGED
@@ -1,92 +1,84 @@
1
  import os
2
  import subprocess
3
- import sys
4
- import json
5
- import numpy as np
6
- # 设置环境
7
- def setup_environment():
8
- if not os.path.exists("skywork-o1-prm-inference"):
9
- print("Cloning repository...")
10
- subprocess.run(["git", "clone", "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"], check=True)
11
- repo_path = os.path.abspath("skywork-o1-prm-inference")
12
- else:
13
- repo_path = os.path.abspath("skywork-o1-prm-inference")
14
 
15
- if repo_path not in sys.path:
16
- sys.path.append(repo_path)
17
- print(f"Added {repo_path} to Python path")
18
 
19
- setup_environment()
 
 
20
 
21
- import gradio as gr
22
- from transformers import AutoTokenizer
23
- from model_utils.prm_model import PRM_MODEL
24
- from model_utils.io_utils import prepare_input, prepare_batch_input_for_model, derive_step_rewards
25
- import torch
 
 
26
 
27
- model_id = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"
28
- tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
29
- model = PRM_MODEL.from_pretrained(model_id).to("cpu").eval()
30
 
31
- def evaluate(problem, response):
32
- try:
33
- # 处理输入数据
34
- processed_data = prepare_input(problem, response, tokenizer=tokenizer, step_token="\n")
35
- input_ids = [processed_data[0]] # 第一个元素
36
- steps = [processed_data[1]] # 第二个元素
37
- reward_flags = [processed_data[2]] # 第三个元素
38
 
39
- # 准备批处理输入
40
- input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
41
- input_ids,
42
- reward_flags,
43
- tokenizer.pad_token_id
44
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- # 确保在CPU上
47
- input_ids = input_ids.to("cpu")
48
- attention_mask = attention_mask.to("cpu")
49
- if isinstance(reward_flags, torch.Tensor):
50
- reward_flags = reward_flags.to("cpu")
 
 
51
 
52
- # 模型推理
53
- with torch.no_grad():
54
- _, _, rewards = model(
55
- input_ids=input_ids,
56
- attention_mask=attention_mask,
57
- return_probs=True
58
- )
59
 
60
- # 计算步骤奖励
61
- step_rewards = derive_step_rewards(rewards, reward_flags)
 
62
 
63
- # 确保返回的是有效的JSON字符串
64
- if isinstance(step_rewards[0], torch.Tensor):
65
- return json.dumps(step_rewards[0].cpu().numpy().tolist())
66
- elif isinstance(step_rewards[0], np.ndarray):
67
- return json.dumps(step_rewards[0].tolist())
68
- else:
69
- return json.dumps(list(step_rewards[0])) # 转换为列表
70
- except Exception as e:
71
- return json.dumps({"error": str(e)})
72
- # 创建Gradio界面
73
- iface = gr.Interface(
74
- fn=evaluate,
75
  inputs=[
76
- gr.Textbox(label="Problem", lines=4),
77
- gr.Textbox(label="Response", lines=8)
78
- ],
79
- outputs=gr.JSON(),
80
- title="Problem Response Evaluation",
81
- description="Enter a problem and its response to get step-wise rewards",
82
- examples=[
83
- [
84
- "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
85
- "To determine how much money Janet makes every day at the farmers' market, we need to follow these steps:\n1. Calculate the total number of eggs laid by the ducks per day.\n Janet's ducks lay 16 eggs per day.\n2. Determine the number of eggs Janet uses each day.\n - She eats 3 eggs for breakfast every morning.\n - She bakes muffins for her friends every day with 4 eggs.\n So, the total number of eggs used per day is:\n 3 + 4 = 7 eggs\n3. Calculate the number of eggs Janet sells at the farmers' market each day.\n Subtract the number of eggs used from the total number of eggs laid:\n 16 - 7 = 9 eggs\n4. Determine how much money Janet makes from selling the eggs.\n She sells each egg for $2, so the total amount of money she makes is:\n 9 ×2 = 18 dollars\nTherefore, the amount of money Janet makes every day at the farmers' market is $18."
86
- ]
87
  ],
88
- cache_examples=False # 禁用示例缓存
 
 
 
 
 
89
  )
90
 
91
- # 启动接口
92
- iface.launch(server_name="0.0.0.0")
 
 
1
  import os
2
  import subprocess
3
+ import gradio as gr
4
+ import torch
5
+ from transformers import AutoTokenizer
 
 
 
 
 
 
 
 
6
 
7
+ # 克隆项目仓库
8
+ REPO_URL = "https://github.com/SkyworkAI/skywork-o1-prm-inference.git"
9
+ REPO_DIR = "skywork-o1-prm-inference"
10
 
11
+ # 在 Hugging Face Spaces 上首次启动时克隆
12
+ if not os.path.exists(REPO_DIR):
13
+ subprocess.run(["git", "clone", REPO_URL], check=True)
14
 
15
+ # 导入必要的模块
16
+ from skywork-o1-prm-inference.model_utils.prm_model import PRM_MODEL
17
+ from skywork-o1-prm-inference.model_utils.io_utils import (
18
+ prepare_input,
19
+ prepare_batch_input_for_model,
20
+ derive_step_rewards
21
+ )
22
 
23
+ # 模型配置
24
+ MODEL_ID = "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-1.5B"
 
25
 
26
+ # 初始化模型和tokenizer
27
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
28
+ model = PRM_MODEL.from_pretrained(MODEL_ID).to("cpu").eval()
 
 
 
 
29
 
30
+ def compute_step_rewards(problem_text, response_text):
31
+ # 准备输入数据
32
+ processed_input = prepare_input(
33
+ problem_text,
34
+ response_text,
35
+ tokenizer=tokenizer,
36
+ step_token="\n"
37
+ )
38
+ input_ids, steps, reward_flags = processed_input
39
+
40
+ # 转换为批处理形式
41
+ input_ids, attention_mask, reward_flags = prepare_batch_input_for_model(
42
+ [input_ids],
43
+ [reward_flags],
44
+ tokenizer.pad_token_id
45
+ )
46
+ input_ids = input_ids.to("cpu")
47
+ attention_mask = attention_mask.to("cpu")
48
+ if isinstance(reward_flags, torch.Tensor):
49
+ reward_flags = reward_flags.to("cpu")
50
 
51
+ # 模型推理
52
+ with torch.no_grad():
53
+ _, _, rewards = model(
54
+ input_ids=input_ids,
55
+ attention_mask=attention_mask,
56
+ return_probs=True
57
+ )
58
 
59
+ # 获取step rewards
60
+ step_rewards = derive_step_rewards(rewards, reward_flags)
61
+ return step_rewards[0]
 
 
 
 
62
 
63
+ def inference_interface(problem, response):
64
+ rewards = compute_step_rewards(problem, response)
65
+ return rewards
66
 
67
+ # Gradio界面配置
68
+ interface = gr.Interface(
69
+ fn=inference_interface,
 
 
 
 
 
 
 
 
 
70
  inputs=[
71
+ gr.Textbox(lines=4, label="Problem (题目)"),
72
+ gr.Textbox(lines=6, label="Response (回答)"),
 
 
 
 
 
 
 
 
 
73
  ],
74
+ outputs="json",
75
+ title="Skywork-o1-prm-inference Demo",
76
+ description=(
77
+ "输入题目和回答,点击提交查看其每个 step 对应的 reward,"
78
+ "这些 reward 值用于度量回答每一步的质量。"
79
+ ),
80
  )
81
 
82
+ if __name__ == "__main__":
83
+ interface.launch(server_name="0.0.0.0", server_port=7860)
84
+