朱东升 commited on
Commit
fc6c268
·
1 Parent(s): e74db4f

requirements update16

Browse files
Files changed (1) hide show
  1. app.py +2 -15
app.py CHANGED
@@ -5,7 +5,7 @@ import os
5
  import sys
6
  from pathlib import Path
7
  import concurrent.futures
8
- import multiprocessing # 添加multiprocessing模块以获取CPU核心数
9
 
10
  # 添加当前目录和src目录到模块搜索路径
11
  current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -25,18 +25,13 @@ def evaluate(input_data):
25
  list: 包含评估结果的列表
26
  """
27
  try:
28
- # 确保输入是列表
29
  if not isinstance(input_data, list):
30
  return {"status": "Exception", "error": "Input must be a list"}
31
 
32
  results = []
33
- # 获取CPU核心数作为最大线程数
34
  max_workers = multiprocessing.cpu_count()
35
- # 使用线程池并行处理多个测试用例,限制最大线程数为CPU核心数
36
  with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
37
- # 提交所有任务到线程池
38
  future_to_item = {executor.submit(evaluate_single_case, item): item for item in input_data}
39
- # 获取结果
40
  for future in concurrent.futures.as_completed(future_to_item):
41
  item = future_to_item[future]
42
  try:
@@ -61,7 +56,6 @@ def evaluate_single_case(input_data):
61
  dict: 包含评估结果的字典
62
  """
63
  try:
64
- # 只处理字典类型输入
65
  if not isinstance(input_data, dict):
66
  return {"status": "Exception", "error": "Input item must be a dictionary"}
67
 
@@ -71,19 +65,16 @@ def evaluate_single_case(input_data):
71
  if not completions:
72
  return {"status": "Exception", "error": "No code provided"}
73
 
74
- # 评估所有完成的代码
75
  results = []
76
  for comp in completions:
77
  code = input_data.get('prompt') + comp + '\n' + input_data.get('tests')
78
  result = evaluate_code(code, language)
79
- # 如果当前代码执行成功,立即返回pass,不再评估后续代码
80
  if result["status"] == "OK":
81
  return {"status": "pass", "compiled_code": code}
82
  print(f'Code failed to compile: \n{code}')
83
  result["compiled_code"] = code
84
  results.append(result)
85
 
86
- # 所有代码都执行失败,返回第一个失败结果
87
  return results[0]
88
 
89
  except Exception as e:
@@ -100,22 +91,18 @@ def evaluate_code(code, language):
100
  dict: 包含评估结果的字典
101
  """
102
  try:
103
- # 动态导入对应语言的评估模块
104
  module_name = f"src.eval_{language.lower()}"
105
  module = importlib.import_module(module_name)
106
 
107
- # 使用系统临时目录而不是固定的temp目录
108
  import tempfile
109
 
110
- # 创建临时文件
111
  with tempfile.NamedTemporaryFile(suffix=f".{language}", delete=False) as temp_file:
112
  temp_file_path = temp_file.name
113
  temp_file.write(code.encode('utf-8'))
114
 
115
- # 调用对应语言的评估函数
116
  result = module.eval_script(temp_file_path)
117
 
118
- # 清理临时文件
119
  if os.path.exists(temp_file_path):
120
  os.unlink(temp_file_path)
121
 
 
5
  import sys
6
  from pathlib import Path
7
  import concurrent.futures
8
+ import multiprocessing
9
 
10
  # 添加当前目录和src目录到模块搜索路径
11
  current_dir = os.path.dirname(os.path.abspath(__file__))
 
25
  list: 包含评估结果的列表
26
  """
27
  try:
 
28
  if not isinstance(input_data, list):
29
  return {"status": "Exception", "error": "Input must be a list"}
30
 
31
  results = []
 
32
  max_workers = multiprocessing.cpu_count()
 
33
  with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
 
34
  future_to_item = {executor.submit(evaluate_single_case, item): item for item in input_data}
 
35
  for future in concurrent.futures.as_completed(future_to_item):
36
  item = future_to_item[future]
37
  try:
 
56
  dict: 包含评估结果的字典
57
  """
58
  try:
 
59
  if not isinstance(input_data, dict):
60
  return {"status": "Exception", "error": "Input item must be a dictionary"}
61
 
 
65
  if not completions:
66
  return {"status": "Exception", "error": "No code provided"}
67
 
 
68
  results = []
69
  for comp in completions:
70
  code = input_data.get('prompt') + comp + '\n' + input_data.get('tests')
71
  result = evaluate_code(code, language)
 
72
  if result["status"] == "OK":
73
  return {"status": "pass", "compiled_code": code}
74
  print(f'Code failed to compile: \n{code}')
75
  result["compiled_code"] = code
76
  results.append(result)
77
 
 
78
  return results[0]
79
 
80
  except Exception as e:
 
91
  dict: 包含评估结果的字典
92
  """
93
  try:
94
+ language = language.split('.')[-1] if '.' in language else language # just for go
95
  module_name = f"src.eval_{language.lower()}"
96
  module = importlib.import_module(module_name)
97
 
 
98
  import tempfile
99
 
 
100
  with tempfile.NamedTemporaryFile(suffix=f".{language}", delete=False) as temp_file:
101
  temp_file_path = temp_file.name
102
  temp_file.write(code.encode('utf-8'))
103
 
 
104
  result = module.eval_script(temp_file_path)
105
 
 
106
  if os.path.exists(temp_file_path):
107
  os.unlink(temp_file_path)
108