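# File: codesearchBase / app.py (21.4 kB)
# Last commit: "Update app.py" (81c9a50, verified); uploader avatar: Forrest99
# NOTE: this header is web-page text captured together with the file; it is
# kept as comments so that the module parses.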
from flask import Flask, request, jsonify, render_template_string
from sentence_transformers import SentenceTransformer, util
import logging
import sys
import signal
# Configure logging at INFO level.
logging.basicConfig(level=logging.INFO)
# Create the Flask application and replace its logger with the dedicated
# "CodeSearchAPI" logger.
app = Flask(__name__)
app.logger = logging.getLogger("CodeSearchAPI")
# Predefined code snippets that make up the search corpus.
# Each entry is one independent string; its position in the list is the
# corpus id used to look the snippet up after a search. Every entry must be
# followed by a comma: two adjacent string literals without one are silently
# concatenated into a single entry.
CODE_SNIPPETS = [
    "print('Hello, World!')",
    "def add(a, b): return a + b",
    "import random; def generate_random(): return random.randint(1, 100)",
    "def is_even(n): return n % 2 == 0",
    "def string_length(s): return len(s)",
    "from datetime import date; def get_current_date(): return date.today()",
    "import os; def file_exists(path): return os.path.exists(path)",
    "def read_file(path): return open(path, 'r').read()",
    "def write_file(path, content): open(path, 'w').write(content)",
    "from datetime import datetime; def get_current_time(): return datetime.now()",
    "def to_upper(s): return s.upper()",
    "def to_lower(s): return s.lower()",
    "def reverse_string(s): return s[::-1]",
    "def list_length(lst): return len(lst)",
    "def list_max(lst): return max(lst)",
    "def list_min(lst): return min(lst)",
    "def sort_list(lst): return sorted(lst)",
    "def merge_lists(lst1, lst2): return lst1 + lst2",
    "def remove_element(lst, element): lst.remove(element)",
    "def is_list_empty(lst): return len(lst) == 0",
    "def count_char(s, char): return s.count(char)",
    "def contains_substring(s, sub): return sub in s",
    "def int_to_str(n): return str(n)",
    "def str_to_int(s): return int(s)",
    "def is_numeric(s): return s.isdigit()",
    "def get_index(lst, element): return lst.index(element)",
    "def clear_list(lst): lst.clear()",
    "def reverse_list(lst): lst.reverse()",
    "def remove_duplicates(lst): return list(set(lst))",
    "def is_in_list(lst, value): return value in lst",
    "def create_dict(): return {}",
    "def add_to_dict(d, key, value): d[key] = value",
    "def delete_key(d, key): del d[key]",
    "def get_keys(d): return list(d.keys())",
    "def get_values(d): return list(d.values())",
    "def merge_dicts(d1, d2): return {**d1, **d2}",
    "def is_dict_empty(d): return len(d) == 0",
    "def get_value(d, key): return d[key]",
    "def key_exists(d, key): return key in d",
    "def clear_dict(d): d.clear()",
    "def count_lines(path): return len(open(path).readlines())",
    "def write_list_to_file(path, lst): open(path, 'w').write('\\n'.join(map(str, lst)))",
    "def read_list_from_file(path): return open(path, 'r').read().splitlines()",
    "def count_words(path): return len(open(path, 'r').read().split())",
    "def is_leap_year(year): return year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)",
    "from datetime import datetime; def format_time(dt): return dt.strftime('%Y-%m-%d %H:%M:%S')",
    "from datetime import date; def days_between(d1, d2): return (d2 - d1).days",
    "import os; def get_current_dir(): return os.getcwd()",
    "import os; def list_files(path): return os.listdir(path)",
    "import os; def create_dir(path): os.mkdir(path)",
    "import os; def remove_dir(path): os.rmdir(path)",
    "import os; def is_file(path): return os.path.isfile(path)",
    "import os; def is_dir(path): return os.path.isdir(path)",
    "import os; def get_file_size(path): return os.path.getsize(path)",
    "import os; def rename_file(src, dst): os.rename(src, dst)",
    "import shutil; def copy_file(src, dst): shutil.copy(src, dst)",
    "import shutil; def move_file(src, dst): shutil.move(src, dst)",
    "import os; def delete_file(path): os.remove(path)",
    "import os; def get_env_var(key): return os.getenv(key)",
    "import os; def set_env_var(key, value): os.environ[key] = value",
    "import webbrowser; def open_url(url): webbrowser.open(url)",
    "import requests; def send_get_request(url): return requests.get(url).text",
    "import json; def parse_json(data): return json.loads(data)",
    "import json; def write_json(data, path): open(path, 'w').write(json.dumps(data))",
    "import json; def read_json(path): return json.loads(open(path, 'r').read())",
    "def list_to_string(lst): return ''.join(lst)",
    "def string_to_list(s): return list(s)",
    "def join_with_comma(lst): return ','.join(lst)",
    "def join_with_newline(lst): return '\\n'.join(lst)",
    "def split_by_space(s): return s.split()",
    "def split_by_char(s, char): return s.split(char)",
    "def split_to_chars(s): return list(s)",
    "def replace_string(s, old, new): return s.replace(old, new)",
    "def remove_spaces(s): return s.replace(' ', '')",
    "import string; def remove_punctuation(s): return s.translate(str.maketrans('', '', string.punctuation))",
    "def is_string_empty(s): return len(s) == 0",
    "def is_palindrome(s): return s == s[::-1]",
    "import csv; def write_csv(data, path): open(path, 'w', newline='').write('\\n'.join([','.join(map(str, row)) for row in data]))",
    "import csv; def read_csv(path): return [row for row in csv.reader(open(path, 'r'))]",
    "def count_csv_lines(path): return len(open(path).readlines())",
    "import random; def shuffle_list(lst): random.shuffle(lst)",
    "import random; def random_choice(lst): return random.choice(lst)",
    "import random; def random_sample(lst, k): return random.sample(lst, k)",
    "import random; def roll_dice(): return random.randint(1, 6)",
    "import random; def flip_coin(): return random.choice(['Heads', 'Tails'])",
    "import random; import string; def generate_password(length=8): return ''.join(random.choices(string.ascii_letters + string.digits, k=length))",
    "import random; def generate_color(): return '#%06x' % random.randint(0, 0xFFFFFF)",
    "import uuid; def generate_uuid(): return str(uuid.uuid4())",
    "class MyClass: pass",
    "def create_instance(): return MyClass()",
    "class MyClass: def my_method(self): pass",
    "class MyClass: def __init__(self): self.my_attr = None",
    "class ChildClass(MyClass): pass",
    "class ChildClass(MyClass): def my_method(self): pass",
    "class MyClass: @classmethod def my_class_method(cls): pass",
    "class MyClass: @staticmethod def my_static_method(): pass",
    "def check_type(obj): return type(obj)",
    "def get_attr(obj, attr): return getattr(obj, attr)",
    "def set_attr(obj, attr, value): setattr(obj, attr, value)",
    "def del_attr(obj, attr): delattr(obj, attr)",
    """try:
x = 1 / 0
except ZeroDivisionError:
pass""",
    """class CustomError(Exception): pass
def raise_custom_error(): raise CustomError('Error occurred')""",
    """try:
x = 1 / 0
except Exception as e: return str(e)""",
    """import logging; logging.basicConfig(filename='error.log', level=logging.ERROR); logging.error('Error occurred')""",
    """import time; def timer(): start = time.time(); return lambda: time.time() - start""",
    """import time; def run_time(): start = time.time(); return lambda: time.time() - start""",
    """import sys; def print_progress(progress): sys.stdout.write(f'\\rProgress: {progress}%'); sys.stdout.flush()""",
    """import time; def delay(seconds): time.sleep(seconds)""",
    "lambda x: x * 2",
    "map(lambda x: x * 2, [1, 2, 3])",
    "filter(lambda x: x > 2, [1, 2, 3])",
    "from functools import reduce; reduce(lambda x, y: x + y, [1, 2, 3])",
    "[x * 2 for x in [1, 2, 3]]",
    "{x: x * 2 for x in [1, 2, 3]}",
    "{x for x in [1, 2, 3]}",
    "set1 & set2",
    "set1 | set2",
    "set1 - set2",
    "[x for x in lst if x is not None]",
    """try:
with open('file.txt', 'r') as f: pass
except IOError: pass""",
    "type(var)",
    "bool(s)",
    "if condition: pass",
    "while condition: pass",
    "for item in lst: pass",
    "for key, value in d.items(): pass",
    "for char in s: pass",
    "for item in lst: if condition: break",
    "for item in lst: if condition: continue",
    "def my_func(): pass",
    "def my_func(param=1): pass",
    "def my_func(): return 1, 2",
    "def my_func(*args): pass",
    "def my_func(**kwargs): pass",
    """import time; def timer(func):
def wrapper(*args, **kwargs):
start = time.time()
result = func(*args, **kwargs)
print(f'Time: {time.time() - start}'); return result
return wrapper""",
    """def decorator(func):
def wrapper(*args, **kwargs): return func(*args, **kwargs)
return wrapper""",
    """from functools import lru_cache; @lru_cache(maxsize=None)
def my_func(): pass""",
    "def my_generator(): yield 1",
    "gen = my_generator(); next(gen)",
    "class MyIterator: def __iter__(self): return self; def __next__(self): pass",
    "it = iter([1, 2, 3]); next(it)",
    "for i, val in enumerate(lst): pass",
    "list(zip(lst1, lst2))",
    "dict(zip(keys, values))",
    "lst1 == lst2",
    "dict1 == dict2",
    "set1 == set2",
    "set(lst)",
    "set.clear()",
    "len(set) == 0",
    "set.add(item)",
    "set.remove(item)",
    "item in set",
    "len(set)",
    "set1 & set2",
    "set(lst1).issubset(lst2)",
    "sub in s",
    "s[0]",
    "s[-1]",
    "import mimetypes; mimetypes.guess_type(path)[0] == 'text/plain'",
    "import mimetypes; mimetypes.guess_type(path)[0].startswith('image/')",
    "round(num)",
    "import math; math.ceil(num)",
    "import math; math.floor(num)",
    "f'{num:.2f}'",
    "import random; import string; ''.join(random.choices(string.ascii_letters + string.digits, k=8))",
    "import os; os.path.exists(path)",
    "import os; for root, dirs, files in os.walk(path): pass",
    "import os; os.path.splitext(path)[1]",
    "import os; os.path.basename(path)",
    "import os; os.path.abspath(path)",
    "import platform; platform.python_version()",
    "import platform; platform.system()",
    "import multiprocessing; multiprocessing.cpu_count()",
    "import psutil; psutil.virtual_memory().total",
    "import psutil; psutil.disk_usage('/')",
    "import socket; socket.gethostbyname(socket.gethostname())",
    "import requests; try: requests.get('http://www.google.com'); return True; except: return False",
    "import requests; def download_file(url, path): with open(path, 'wb') as f: f.write(requests.get(url).content)",
    "def upload_file(path): with open(path, 'rb') as f: requests.post('http://example.com/upload', files={'file': f})",
    "import requests; requests.post(url, data={'key': 'value'})",
    "import requests; requests.get(url, params={'key': 'value'})",
    "import requests; requests.get(url, headers={'key': 'value'})",
    "from bs4 import BeautifulSoup; BeautifulSoup(html, 'html.parser')",
    "from bs4 import BeautifulSoup; soup.title.text",
    "from bs4 import BeautifulSoup; [a['href'] for a in soup.find_all('a')]",
    "from bs4 import BeautifulSoup; import requests; for img in soup.find_all('img'): requests.get(img['src']).content",
    "from collections import Counter; Counter(text.split())",
    "import requests; session = requests.Session(); session.post(login_url, data={'username': 'user', 'password': 'pass'})",
    "from bs4 import BeautifulSoup; soup.get_text()",
    "import re; re.findall(r'[\\w.-]+@[\\w.-]+', text)",
    "import re; re.findall(r'\\+?\\d[\\d -]{8,12}\\d', text)",
    "import re; re.findall(r'\\d+', text)",
    "import re; re.sub(pattern, repl, text)",
    "import re; re.match(pattern, text)",
    "from bs4 import BeautifulSoup; soup.get_text()",
    "import html; html.escape(text)",
    "import html; html.unescape(text)",
    "import tkinter as tk; root = tk.Tk(); root.mainloop()",
    "import tkinter as tk; def add_button(window, text): return tk.Button(window, text=text)",
    """def bind_click(button, func): button.config(command=func)""",
    """import tkinter.messagebox; def show_alert(message): tkinter.messagebox.showinfo('Info', message)""",
    """def get_entry_text(entry): return entry.get()""",
    "def set_title(window, title): window.title(title)",
    "def set_size(window, width, height): window.geometry(f'{width}x{height}')",
    """def center_window(window):
window.update_idletasks()
width = window.winfo_width()
height = window.winfo_height()
x = (window.winfo_screenwidth() // 2) - (width // 2)
y = (window.winfo_screenheight() // 2) - (height // 2)
window.geometry(f'{width}x{height}+{x}+{y}')""",
    """def create_menu(window): return tk.Menu(window)""",
    "def create_combobox(window): return ttk.Combobox(window)",
    "def create_radiobutton(window, text): return tk.Radiobutton(window, text=text)",
    "def create_checkbutton(window, text): return tk.Checkbutton(window, text=text)",
    """from PIL import ImageTk, Image; def show_image(window, path):
img = Image.open(path)
photo = ImageTk.PhotoImage(img)
label = tk.Label(window, image=photo)
label.image = photo
return label""",
    "import pygame; def play_audio(path): pygame.mixer.init(); pygame.mixer.music.load(path); pygame.mixer.music.play()",
    "import cv2; def play_video(path): cap = cv2.VideoCapture(path); while cap.isOpened(): ret, frame = cap.read()",
    "def get_playback_time(): return pygame.mixer.music.get_pos()",
    "import pyautogui; def screenshot(): return pyautogui.screenshot()",
    "import pyautogui; import time; def record_screen(duration): return [pyautogui.screenshot() for _ in range(duration)]",
    "def get_mouse_pos(): return pyautogui.position()",
    "import pyautogui; def type_text(text): pyautogui.write(text)",
    "import pyautogui; def click_mouse(x, y): pyautogui.click(x, y)",
    "import time; def get_timestamp(): return int(time.time())",
    "import datetime; def timestamp_to_date(ts): return datetime.datetime.fromtimestamp(ts)",
    "import time; def date_to_timestamp(dt): return int(time.mktime(dt.timetuple()))",
    "def get_weekday(): return datetime.datetime.now().strftime('%A')",
    "import calendar; def get_month_days(): return calendar.monthrange(datetime.datetime.now().year, datetime.datetime.now().month)[1]",
    "def first_day_of_year(): return datetime.date(datetime.datetime.now().year, 1, 1)",
    "def last_day_of_year(): return datetime.date(datetime.datetime.now().year, 12, 31)",
    "def first_day_of_month(month): return datetime.date(datetime.datetime.now().year, month, 1)",
    "import calendar; def last_day_of_month(month): return datetime.date(datetime.datetime.now().year, month, calendar.monthrange(datetime.datetime.now().year, month)[1])",
    "def is_weekday(): return datetime.datetime.now().weekday() < 5",
    "def is_weekend(): return datetime.datetime.now().weekday() >= 5",
    "def current_hour(): return datetime.datetime.now().hour",
    "def current_minute(): return datetime.datetime.now().minute",
    "def current_second(): return datetime.datetime.now().second",
    "import time; def delay_1s(): time.sleep(1)",
    "import time; def millis_timestamp(): return int(time.time() * 1000)",
    "def format_time(dt, fmt='%Y-%m-%d %H:%M:%S'): return dt.strftime(fmt)",
    "from dateutil.parser import parse; def parse_time(s): return parse(s)",
    "import threading; def create_thread(target): return threading.Thread(target=target)",
    "import time; def thread_pause(seconds): time.sleep(seconds)",
    "def run_threads(*threads): [t.start() for t in threads]",
    "import threading; def current_thread_name(): return threading.current_thread().name",
    "def set_daemon(thread): thread.daemon = True",
    "import threading; lock = threading.Lock()",
    "import multiprocessing; def create_process(target): return multiprocessing.Process(target=target)",
    "import os; def get_pid(): return os.getpid()",
    "import psutil; def is_process_alive(pid): return psutil.pid_exists(pid)",
    "def run_processes(*procs): [p.start() for p in procs]",
    "from queue import Queue; q = Queue()",
    "from multiprocessing import Pipe; parent_conn, child_conn = Pipe()",
    "import os; def limit_cpu_usage(percent): os.system(f'cpulimit -p {os.getpid()} -l {percent}')",
    "import subprocess; def run_command(cmd): subprocess.run(cmd, shell=True)",
    "import subprocess; def get_command_output(cmd): return subprocess.check_output(cmd, shell=True).decode()",
    "def get_exit_code(cmd): return subprocess.call(cmd, shell=True)",
    "def is_success(code): return code == 0",
    "import os; def script_path(): return os.path.realpath(__file__)",
    "import sys; def get_cli_args(): return sys.argv[1:]",
    "import argparse; parser = argparse.ArgumentParser()",
    "parser.print_help()",
    "help('modules')",
    "import pip; def install_pkg(pkg): pip.main(['install', pkg])",
    "import pip; def uninstall_pkg(pkg): pip.main(['uninstall', pkg])",
    "import pkg_resources; def get_pkg_version(pkg): return pkg_resources.get_distribution(pkg).version",
    "import venv; def create_venv(path): venv.create(path)",
    "import pip; def list_pkgs(): return pip.get_installed_distributions()",
    "import pip; def upgrade_pkg(pkg): pip.main(['install', '--upgrade', pkg])",
    "import sqlite3; conn = sqlite3.connect(':memory:')",
    "def execute_query(conn, query): return conn.execute(query)",
    """def insert_record(conn, table, data): conn.execute(f'INSERT INTO {table} VALUES ({",".join("?"*len(data))})', data)""",
    "def delete_record(conn, table, condition): conn.execute(f'DELETE FROM {table} WHERE {condition}')",
    "def update_record(conn, table, set_clause, condition): conn.execute(f'UPDATE {table} SET {set_clause} WHERE {condition}')",
    "def fetch_all(conn, query): return conn.execute(query).fetchall()",
    "def safe_query(conn, query, params): return conn.execute(query, params)",
    "def close_db(conn): conn.close()",
    "def create_table(conn, name, columns): conn.execute(f'CREATE TABLE {name} ({columns})')",
    "def drop_table(conn, name): conn.execute(f'DROP TABLE {name}')",
    "def table_exists(conn, name): return conn.execute(f\"SELECT name FROM sqlite_master WHERE type='table' AND name='{name}'\").fetchone()",
    "def list_tables(conn): return conn.execute(\"SELECT name FROM sqlite_master WHERE type='table'\").fetchall()",
    """from sqlalchemy import Column, Integer, String
class User(Base):
__tablename__ = 'users'
id = Column(Integer, primary_key=True)
name = Column(String)""",
    "session.add(User(name='John'))",
    "session.query(User).filter_by(name='John')",
    "session.query(User).filter_by(name='John').delete()",
    "session.query(User).filter_by(name='John').update({'name': 'Bob'})",
    "Base = declarative_base()",
    "class Admin(User): pass",
    "id = Column(Integer, primary_key=True)",
    "name = Column(String, unique=True)",
    "name = Column(String, default='Unknown')",
    "import csv; def export_csv(data, path): open(path, 'w').write('\\n'.join([','.join(map(str, row)) for row in data]))",
    "import pandas as pd; pd.DataFrame(data).to_excel(path)",
    "import json; json.dump(data, open(path, 'w'))",
    "pd.read_excel(path).values.tolist()",
    "pd.concat([pd.read_excel(f) for f in files])",
    "with pd.ExcelWriter(path, mode='a') as writer: df.to_excel(writer, sheet_name='New')",
    "from openpyxl.styles import copy; copy.copy(style)",
    "from openpyxl.styles import PatternFill; cell.fill = PatternFill(start_color='FFFF00', fill_type='solid')",
    "from openpyxl.styles import Font; cell.font = Font(bold=True)",
    "sheet['A1'].value",
    "sheet['A1'] = value",
    "from PIL import Image; Image.open(path).size",
    "from PIL import Image; Image.open(path).resize((w, h))",
]
# Global readiness flag; starts out False and is read by the request handlers.
service_ready = False


# Graceful shutdown handling.
def handle_shutdown(signum, frame):
    """Log the termination request and exit the process with status 0.

    Installed as a signal handler, so it receives the signal number and the
    current stack frame; neither argument is used.
    """
    app.logger.info("收到终止信号,开始关闭...")
    raise SystemExit(0)


# Route both SIGTERM and SIGINT to the same handler.
for _shutdown_signal in (signal.SIGTERM, signal.SIGINT):
    signal.signal(_shutdown_signal, handle_shutdown)
# Load the model and pre-compute the snippet embeddings at import time.
try:
    app.logger.info("开始加载模型...")
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base",
        cache_folder="/model-cache",
    )
    # Encode every snippet once up front (forced onto the CPU) so that a
    # request only has to encode its own query.
    code_emb = model.encode(
        CODE_SNIPPETS,
        convert_to_tensor=True,
        device="cpu",
    )
    service_ready = True
    app.logger.info("服务初始化完成")
except Exception as init_error:
    # Log the failure, then re-raise so the import itself fails.
    app.logger.error("初始化失败: %s", str(init_error))
    raise
# Hugging Face health-check endpoint; it must answer on the root path.
@app.route('/')
def hf_health_check():
    """Report whether the service is ready.

    Clients that accept HTML receive a small status page that also explains
    how to try the search endpoint. All other clients receive JSON:
    ``{"status": "ready"}`` with 200, or ``{"status": "initializing"}`` with
    503 while ``service_ready`` is still False.
    """
    if not request.accept_mimetypes.accept_html:
        # Non-HTML clients get the JSON health check.
        if not service_ready:
            return jsonify({"status": "initializing"}), 503
        return jsonify({"status": "ready"}), 200

    # HTML clients get a simple page that includes a hint for testing.
    page_template = """
<h2>CodeSearch API</h2>
<p>服务状态:{{ status }}</p>
<p>你可以在地址栏输入 /search?query=你的查询 来测试接口</p>
"""
    if service_ready:
        status = "ready"
    else:
        status = "initializing"
    return render_template_string(page_template, status=status)
# Search API endpoint; accepts both GET and POST requests.
@app.route('/search', methods=['GET', 'POST'])
def handle_search():
    """Return the corpus snippet that best matches the caller's query.

    The query is read from the ``query`` URL parameter on GET and from the
    ``query`` field of a JSON object body on POST.

    Responses (all JSON):
        200: ``{"code": <snippet>, "score": <score rounded to 4 places>}``
        400: the query is missing, blank, or not a string
        500: encoding or searching raised an exception
        503: ``service_ready`` is still False
    """
    if not service_ready:
        app.logger.info("服务未就绪")
        return jsonify({"error": "服务正在初始化"}), 503

    # Extract the query according to the request method.
    if request.method == 'GET':
        query = request.args.get('query', '')
    else:
        # silent=True makes a missing or malformed JSON body yield None
        # instead of raising, so bad client input ends up as a 400 below
        # rather than being reported as a 500 server error.
        payload = request.get_json(silent=True)
        # A JSON body that is not an object (list, string, number, ...) has
        # no "query" field; treat it the same as a missing query.
        query = payload.get('query', '') if isinstance(payload, dict) else ''

    # Reject null / numeric / nested values before calling str methods on them.
    if not isinstance(query, str):
        app.logger.info("收到非字符串的查询请求")
        return jsonify({"error": "查询必须是字符串"}), 400

    query = query.strip()
    if not query:
        app.logger.info("收到空的查询请求")
        return jsonify({"error": "查询不能为空"}), 400

    # Log the incoming query.
    app.logger.info("收到查询请求: %s", query)

    try:
        # Encode the query and run the semantic search against the
        # pre-computed snippet embeddings, keeping only the best hit.
        query_emb = model.encode(query, convert_to_tensor=True, device="cpu")
        hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
        best = hits[0]
        result = {
            "code": CODE_SNIPPETS[best['corpus_id']],
            "score": round(float(best['score']), 4)
        }
    except Exception as e:
        # logger.exception logs at ERROR level and appends the traceback,
        # which the plain message alone would lose.
        app.logger.exception("请求处理失败: %s", str(e))
        return jsonify({"error": "服务器内部错误"}), 500

    # Log the result being returned.
    app.logger.info("返回结果: %s", result)
    return jsonify(result)
if __name__ == "__main__":
    # Local testing only; Hugging Face Spaces usually starts the app through
    # gunicorn instead of this entry point.
    app.run(port=7860, host='0.0.0.0')