File size: 4,169 Bytes
30b1610
 
08681f4
30b1610
4f32597
08681f4
3499425
fc6c268
30b1610
74d43a2
4f32597
74d43a2
4f32597
 
74d43a2
 
4f32597
08681f4
 
30b1610
141e12d
3499425
141e12d
 
3499425
141e12d
 
3499425
 
 
 
e74db4f
 
3499425
 
 
 
 
 
 
 
 
 
 
141e12d
 
 
 
 
 
 
08681f4
3499425
30b1610
08681f4
 
 
 
3499425
 
 
 
 
52d43e7
3499425
 
52d43e7
3499425
 
 
 
 
64ec244
3499425
64ec244
3499425
08681f4
3499425
 
08681f4
 
22cec65
08681f4
 
 
 
 
 
30b1610
08681f4
 
 
 
fc6c268
08681f4
3f1394b
74d43a2
30b1610
a2bac48
 
4770e23
 
 
a2bac48
 
 
 
30b1610
a2bac48
 
30b1610
08681f4
30b1610
4f32597
 
08681f4
 
30b1610
08681f4
 
30b1610
08681f4
 
 
2f2f63e
30b1610
 
08681f4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import gradio as gr
import json
import importlib
import os
import sys
from pathlib import Path
import concurrent.futures
import multiprocessing

# Make this file's directory and its "src" subdirectory importable.
# Each one is appended to the module search path only when it is not
# already listed, so repeated imports do not grow sys.path.
current_dir = os.path.dirname(os.path.abspath(__file__))
src_dir = os.path.join(current_dir, "src")
sys.path.extend(
    candidate for candidate in (current_dir, src_dir) if candidate not in sys.path
)

def evaluate(input_data):
    """Evaluate a batch of test cases concurrently.

    Args:
        input_data: A list of test cases. Each dict item is passed to
            ``evaluate_single_case``.

    Returns:
        list: One entry per input item, in input order. Dict items are
            updated in place with their evaluation result and returned;
            non-dict items yield a standalone error dict.
        dict: A single ``{"status": "Exception", "error": ...}`` dict when
            ``input_data`` is not a list or an unexpected error occurs.
    """
    if not isinstance(input_data, list):
        return {"status": "Exception", "error": "Input must be a list"}
    if not input_data:
        return []

    try:
        results = []
        max_workers = multiprocessing.cpu_count()
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Only dict items are submitted: a non-dict item cannot be merged
            # with its result via ``update`` and is reported per item instead
            # of aborting the whole batch.
            futures = [
                executor.submit(evaluate_single_case, item)
                if isinstance(item, dict)
                else None
                for item in input_data
            ]
            # Walk the futures in submission order so the output order
            # matches the input order regardless of completion order.
            for item, future in zip(input_data, futures):
                if future is None:
                    results.append(
                        {"status": "Exception", "error": "Input item must be a dictionary"}
                    )
                    continue
                try:
                    outcome = future.result()
                except Exception as e:
                    outcome = {"status": "Exception", "error": str(e)}
                item.update(outcome)
                results.append(item)
        return results

    except Exception as e:
        # Top-level boundary for the service: report instead of raising.
        return {"status": "Exception", "error": str(e)}

def evaluate_single_case(input_data):
    """Evaluate a single test case, trying each completion in turn.

    For every entry of ``processed_completions`` the program
    ``prompt + completion + "\\n" + tests`` is built and passed to
    ``evaluate_code``. Evaluation stops at the first completion whose
    result status is ``"OK"``.

    Args:
        input_data: A dict with the keys ``language``, ``prompt``,
            ``tests`` and ``processed_completions`` (a list of strings).

    Returns:
        dict: ``{"status": "pass", "compiled_code": ...}`` for the first
            passing completion; otherwise the result of the first failing
            completion with ``compiled_code`` added; or
            ``{"status": "Exception", "error": ...}`` for invalid input or
            an unexpected error.
    """
    try:
        if not isinstance(input_data, dict):
            return {"status": "Exception", "error": "Input item must be a dictionary"}

        completions = input_data.get('processed_completions', [])
        if not completions:
            return {"status": "Exception", "error": "No code provided"}

        language = input_data.get('language')
        prompt = input_data.get('prompt')
        tests = input_data.get('tests')
        # Fail with a clear message instead of a TypeError from ``+`` when a
        # required text field is missing or not a string.
        for field, value in (('prompt', prompt), ('tests', tests)):
            if not isinstance(value, str):
                return {
                    "status": "Exception",
                    "error": f"Missing or non-string field: {field}",
                }

        first_failure = None
        for comp in completions:
            code = prompt + comp + '\n' + tests
            result = evaluate_code(code, language)
            if result["status"] == "OK":
                return {"status": "pass", "compiled_code": code}
            print(f'Code failed to compile: \n{code}')
            result["compiled_code"] = code
            # Only the first failure is reported when no completion passes.
            if first_failure is None:
                first_failure = result

        return first_failure

    except Exception as e:
        return {"status": "Exception", "error": str(e)}

def evaluate_code(code, language):
    """Evaluate source code with the evaluator module for its language.

    The code is written to a temporary file and passed to the
    ``eval_script`` function of the module ``src.eval_<language>``. The
    temporary file is removed afterwards, whether or not evaluation
    succeeded.

    Args:
        code (str): The source code to evaluate.
        language (str): The programming language. Only the part after the
            last ``.`` is used.

    Returns:
        dict: The result returned by the evaluator's ``eval_script``, or
            ``{"status": "Exception", "error": ...}`` when the language is
            invalid or unsupported, or when evaluation raises.
    """
    import tempfile

    if not isinstance(language, str) or not language:
        return {"status": "Exception", "error": "Language must be a non-empty string"}

    temp_file_path = None
    try:
        # Keep only the last dotted component (the original author notes
        # this is needed for Go).
        language = language.split('.')[-1]
        module_name = f"src.eval_{language.lower()}"
        print(f'module_name: {module_name}')
        module = importlib.import_module(module_name)

        # For Go, the file name must end with _test.go.
        suffix = "_test.go" if language.lower() == "go" else f".{language}"
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file:
            temp_file_path = temp_file.name
            temp_file.write(code.encode('utf-8'))

        return module.eval_script(temp_file_path)

    except ImportError as e:
        return {"status": "Exception", "error": f"Language {language} not supported: {str(e)}"}
    except Exception as e:
        return {"status": "Exception", "error": str(e)}
    finally:
        # The file is created with delete=False, so it has to be removed
        # here on every path, including when eval_script raises.
        if temp_file_path is not None:
            try:
                os.unlink(temp_file_path)
            except OSError:
                # Best-effort cleanup: never mask the evaluation result.
                pass

# Create the Gradio interface: it takes a JSON payload, passes it to
# `evaluate`, and returns the result as JSON.
demo = gr.Interface(
    fn=evaluate,
    inputs=gr.JSON(),
    outputs=gr.JSON(),
    title="代码评估服务",
    description="支持多种编程语言的代码评估服务"
)

# Launch the interface only when this file is run as a script, not when it
# is imported.
if __name__ == "__main__":
    demo.launch()