Spaces:
Running
Running
from flask import Flask, request, jsonify, render_template_string | |
from sentence_transformers import SentenceTransformer, util | |
import logging | |
import sys | |
import signal | |
# Create the Flask application object.
app = Flask(__name__)
# Configure root logging at INFO level, then swap Flask's default logger for a
# dedicated named logger so every message below is tagged "CodeSearchAPI".
logging.basicConfig(level="INFO")
app.logger = logging.getLogger(name="CodeSearchAPI")
# Predefined corpus of Python code snippets that search queries are matched
# against. The whole list is embedded once at startup and a query returns the
# single most similar entry verbatim.
# NOTE: every entry must end with a comma -- adjacent string literals without a
# comma are silently concatenated by Python, merging two snippets into one.
# NOTE(review): a few entries are exact duplicates (e.g. "set1 & set2" and the
# BeautifulSoup "soup.get_text()" snippet); harmless for top-1 search, but they
# only add embedding work.
CODE_SNIPPETS = [
    "print('Hello, World!')",
    "def add(a, b): return a + b",
    "import random; def generate_random(): return random.randint(1, 100)",
    "def is_even(n): return n % 2 == 0",
    "def string_length(s): return len(s)",
    "from datetime import date; def get_current_date(): return date.today()",
    "import os; def file_exists(path): return os.path.exists(path)",
    "def read_file(path): return open(path, 'r').read()",
    "def write_file(path, content): open(path, 'w').write(content)",
    "from datetime import datetime; def get_current_time(): return datetime.now()",
    "def to_upper(s): return s.upper()",
    "def to_lower(s): return s.lower()",
    "def reverse_string(s): return s[::-1]",
    "def list_length(lst): return len(lst)",
    "def list_max(lst): return max(lst)",
    "def list_min(lst): return min(lst)",
    "def sort_list(lst): return sorted(lst)",
    "def merge_lists(lst1, lst2): return lst1 + lst2",
    "def remove_element(lst, element): lst.remove(element)",
    "def is_list_empty(lst): return len(lst) == 0",
    "def count_char(s, char): return s.count(char)",
    "def contains_substring(s, sub): return sub in s",
    "def int_to_str(n): return str(n)",
    "def str_to_int(s): return int(s)",
    "def is_numeric(s): return s.isdigit()",
    "def get_index(lst, element): return lst.index(element)",
    "def clear_list(lst): lst.clear()",
    "def reverse_list(lst): lst.reverse()",
    "def remove_duplicates(lst): return list(set(lst))",
    "def is_in_list(lst, value): return value in lst",
    "def create_dict(): return {}",
    "def add_to_dict(d, key, value): d[key] = value",
    "def delete_key(d, key): del d[key]",
    "def get_keys(d): return list(d.keys())",
    "def get_values(d): return list(d.values())",
    "def merge_dicts(d1, d2): return {**d1, **d2}",
    "def is_dict_empty(d): return len(d) == 0",
    "def get_value(d, key): return d[key]",
    "def key_exists(d, key): return key in d",
    "def clear_dict(d): d.clear()",
    "def count_lines(path): return len(open(path).readlines())",
    "def write_list_to_file(path, lst): open(path, 'w').write('\\n'.join(map(str, lst)))",
    "def read_list_from_file(path): return open(path, 'r').read().splitlines()",
    "def count_words(path): return len(open(path, 'r').read().split())",
    "def is_leap_year(year): return year % 4 == 0 and (year % 100 != 0 or year % 400 == 0)",
    "from datetime import datetime; def format_time(dt): return dt.strftime('%Y-%m-%d %H:%M:%S')",
    "from datetime import date; def days_between(d1, d2): return (d2 - d1).days",
    "import os; def get_current_dir(): return os.getcwd()",
    "import os; def list_files(path): return os.listdir(path)",
    "import os; def create_dir(path): os.mkdir(path)",  # comma was missing here
    "import os; def remove_dir(path): os.rmdir(path)",
    "import os; def is_file(path): return os.path.isfile(path)",
    "import os; def is_dir(path): return os.path.isdir(path)",
    "import os; def get_file_size(path): return os.path.getsize(path)",
    "import os; def rename_file(src, dst): os.rename(src, dst)",
    "import shutil; def copy_file(src, dst): shutil.copy(src, dst)",
    "import shutil; def move_file(src, dst): shutil.move(src, dst)",
    "import os; def delete_file(path): os.remove(path)",
    "import os; def get_env_var(key): return os.getenv(key)",
    "import os; def set_env_var(key, value): os.environ[key] = value",
    "import webbrowser; def open_url(url): webbrowser.open(url)",
    "import requests; def send_get_request(url): return requests.get(url).text",
    "import json; def parse_json(data): return json.loads(data)",
    "import json; def write_json(data, path): open(path, 'w').write(json.dumps(data))",
    "import json; def read_json(path): return json.loads(open(path, 'r').read())",
    "def list_to_string(lst): return ''.join(lst)",
    "def string_to_list(s): return list(s)",
    "def join_with_comma(lst): return ','.join(lst)",
    "def join_with_newline(lst): return '\\n'.join(lst)",
    "def split_by_space(s): return s.split()",
    "def split_by_char(s, char): return s.split(char)",
    "def split_to_chars(s): return list(s)",
    "def replace_string(s, old, new): return s.replace(old, new)",
    "def remove_spaces(s): return s.replace(' ', '')",
    "import string; def remove_punctuation(s): return s.translate(str.maketrans('', '', string.punctuation))",
    "def is_string_empty(s): return len(s) == 0",
    "def is_palindrome(s): return s == s[::-1]",
    "import csv; def write_csv(data, path): open(path, 'w', newline='').write('\\n'.join([','.join(map(str, row)) for row in data]))",
    "import csv; def read_csv(path): return [row for row in csv.reader(open(path, 'r'))]",
    "def count_csv_lines(path): return len(open(path).readlines())",
    "import random; def shuffle_list(lst): random.shuffle(lst)",
    "import random; def random_choice(lst): return random.choice(lst)",
    "import random; def random_sample(lst, k): return random.sample(lst, k)",
    "import random; def roll_dice(): return random.randint(1, 6)",
    "import random; def flip_coin(): return random.choice(['Heads', 'Tails'])",
    "import random; import string; def generate_password(length=8): return ''.join(random.choices(string.ascii_letters + string.digits, k=length))",
    "import random; def generate_color(): return '#%06x' % random.randint(0, 0xFFFFFF)",
    "import uuid; def generate_uuid(): return str(uuid.uuid4())",
    "class MyClass: pass",
    "def create_instance(): return MyClass()",
    "class MyClass: def my_method(self): pass",
    "class MyClass: def __init__(self): self.my_attr = None",
    "class ChildClass(MyClass): pass",
    "class ChildClass(MyClass): def my_method(self): pass",
    "class MyClass: @classmethod def my_class_method(cls): pass",
    "class MyClass: @staticmethod def my_static_method(): pass",
    "def check_type(obj): return type(obj)",
    "def get_attr(obj, attr): return getattr(obj, attr)",
    "def set_attr(obj, attr, value): setattr(obj, attr, value)",
    "def del_attr(obj, attr): delattr(obj, attr)",
    """try:
x = 1 / 0
except ZeroDivisionError:
pass""",
    """class CustomError(Exception): pass
def raise_custom_error(): raise CustomError('Error occurred')""",
    """try:
x = 1 / 0
except Exception as e: return str(e)""",
    """import logging; logging.basicConfig(filename='error.log', level=logging.ERROR); logging.error('Error occurred')""",
    """import time; def timer(): start = time.time(); return lambda: time.time() - start""",
    """import time; def run_time(): start = time.time(); return lambda: time.time() - start""",
    """import sys; def print_progress(progress): sys.stdout.write(f'\\rProgress: {progress}%'); sys.stdout.flush()""",
    """import time; def delay(seconds): time.sleep(seconds)""",
    "lambda x: x * 2",
    "map(lambda x: x * 2, [1, 2, 3])",
    "filter(lambda x: x > 2, [1, 2, 3])",
    "from functools import reduce; reduce(lambda x, y: x + y, [1, 2, 3])",
    "[x * 2 for x in [1, 2, 3]]",
    "{x: x * 2 for x in [1, 2, 3]}",
    "{x for x in [1, 2, 3]}",
    "set1 & set2",
    "set1 | set2",
    "set1 - set2",
    "[x for x in lst if x is not None]",
    """try:
with open('file.txt', 'r') as f: pass
except IOError: pass""",
    "type(var)",
    "bool(s)",
    "if condition: pass",
    "while condition: pass",
    "for item in lst: pass",
    "for key, value in d.items(): pass",
    "for char in s: pass",
    "for item in lst: if condition: break",
    "for item in lst: if condition: continue",
    "def my_func(): pass",
    "def my_func(param=1): pass",
    "def my_func(): return 1, 2",
    "def my_func(*args): pass",
    "def my_func(**kwargs): pass",
    """import time; def timer(func):
def wrapper(*args, **kwargs):
start = time.time()
result = func(*args, **kwargs)
print(f'Time: {time.time() - start}'); return result
return wrapper""",
    """def decorator(func):
def wrapper(*args, **kwargs): return func(*args, **kwargs)
return wrapper""",
    """from functools import lru_cache; @lru_cache(maxsize=None)
def my_func(): pass""",
    "def my_generator(): yield 1",
    "gen = my_generator(); next(gen)",
    "class MyIterator: def __iter__(self): return self; def __next__(self): pass",
    "it = iter([1, 2, 3]); next(it)",
    "for i, val in enumerate(lst): pass",
    "list(zip(lst1, lst2))",
    "dict(zip(keys, values))",
    "lst1 == lst2",
    "dict1 == dict2",
    "set1 == set2",
    "set(lst)",
    "set.clear()",
    "len(set) == 0",
    "set.add(item)",
    "set.remove(item)",
    "item in set",
    "len(set)",
    "set1 & set2",
    "set(lst1).issubset(lst2)",
    "sub in s",
    "s[0]",
    "s[-1]",
    "import mimetypes; mimetypes.guess_type(path)[0] == 'text/plain'",
    "import mimetypes; mimetypes.guess_type(path)[0].startswith('image/')",
    "round(num)",
    "import math; math.ceil(num)",
    "import math; math.floor(num)",
    "f'{num:.2f}'",
    "import random; import string; ''.join(random.choices(string.ascii_letters + string.digits, k=8))",
    "import os; os.path.exists(path)",
    "import os; for root, dirs, files in os.walk(path): pass",
    "import os; os.path.splitext(path)[1]",
    "import os; os.path.basename(path)",
    "import os; os.path.abspath(path)",
    "import platform; platform.python_version()",
    "import platform; platform.system()",
    "import multiprocessing; multiprocessing.cpu_count()",
    "import psutil; psutil.virtual_memory().total",
    "import psutil; psutil.disk_usage('/')",
    "import socket; socket.gethostbyname(socket.gethostname())",
    "import requests; try: requests.get('http://www.google.com'); return True; except: return False",
    "import requests; def download_file(url, path): with open(path, 'wb') as f: f.write(requests.get(url).content)",
    "def upload_file(path): with open(path, 'rb') as f: requests.post('http://example.com/upload', files={'file': f})",
    "import requests; requests.post(url, data={'key': 'value'})",
    "import requests; requests.get(url, params={'key': 'value'})",
    "import requests; requests.get(url, headers={'key': 'value'})",
    "from bs4 import BeautifulSoup; BeautifulSoup(html, 'html.parser')",
    "from bs4 import BeautifulSoup; soup.title.text",
    "from bs4 import BeautifulSoup; [a['href'] for a in soup.find_all('a')]",
    "from bs4 import BeautifulSoup; import requests; for img in soup.find_all('img'): requests.get(img['src']).content",
    "from collections import Counter; Counter(text.split())",
    "import requests; session = requests.Session(); session.post(login_url, data={'username': 'user', 'password': 'pass'})",
    "from bs4 import BeautifulSoup; soup.get_text()",
    "import re; re.findall(r'[\\w.-]+@[\\w.-]+', text)",
    "import re; re.findall(r'\\+?\\d[\\d -]{8,12}\\d', text)",
    "import re; re.findall(r'\\d+', text)",
    "import re; re.sub(pattern, repl, text)",
    "import re; re.match(pattern, text)",
    "from bs4 import BeautifulSoup; soup.get_text()",
    "import html; html.escape(text)",
    "import html; html.unescape(text)",
    "import tkinter as tk; root = tk.Tk(); root.mainloop()",
    "import tkinter as tk; def add_button(window, text): return tk.Button(window, text=text)",
    """def bind_click(button, func): button.config(command=func)""",
    """import tkinter.messagebox; def show_alert(message): tkinter.messagebox.showinfo('Info', message)""",
    """def get_entry_text(entry): return entry.get()""",
    "def set_title(window, title): window.title(title)",
    "def set_size(window, width, height): window.geometry(f'{width}x{height}')",
    """def center_window(window):
window.update_idletasks()
width = window.winfo_width()
height = window.winfo_height()
x = (window.winfo_screenwidth() // 2) - (width // 2)
y = (window.winfo_screenheight() // 2) - (height // 2)
window.geometry(f'{width}x{height}+{x}+{y}')""",
    """def create_menu(window): return tk.Menu(window)""",
    "def create_combobox(window): return ttk.Combobox(window)",
    "def create_radiobutton(window, text): return tk.Radiobutton(window, text=text)",
    "def create_checkbutton(window, text): return tk.Checkbutton(window, text=text)",
    """from PIL import ImageTk, Image; def show_image(window, path):
img = Image.open(path)
photo = ImageTk.PhotoImage(img)
label = tk.Label(window, image=photo)
label.image = photo
return label""",
    "import pygame; def play_audio(path): pygame.mixer.init(); pygame.mixer.music.load(path); pygame.mixer.music.play()",
    "import cv2; def play_video(path): cap = cv2.VideoCapture(path); while cap.isOpened(): ret, frame = cap.read()",
    "def get_playback_time(): return pygame.mixer.music.get_pos()",
    "import pyautogui; def screenshot(): return pyautogui.screenshot()",
    "import pyautogui; import time; def record_screen(duration): return [pyautogui.screenshot() for _ in range(duration)]",
    "def get_mouse_pos(): return pyautogui.position()",
    "import pyautogui; def type_text(text): pyautogui.write(text)",
    "import pyautogui; def click_mouse(x, y): pyautogui.click(x, y)",
    "import time; def get_timestamp(): return int(time.time())",
    "import datetime; def timestamp_to_date(ts): return datetime.datetime.fromtimestamp(ts)",
    "import time; def date_to_timestamp(dt): return int(time.mktime(dt.timetuple()))",
    "def get_weekday(): return datetime.datetime.now().strftime('%A')",
    "import calendar; def get_month_days(): return calendar.monthrange(datetime.datetime.now().year, datetime.datetime.now().month)[1]",
    "def first_day_of_year(): return datetime.date(datetime.datetime.now().year, 1, 1)",
    "def last_day_of_year(): return datetime.date(datetime.datetime.now().year, 12, 31)",
    "def first_day_of_month(month): return datetime.date(datetime.datetime.now().year, month, 1)",
    "import calendar; def last_day_of_month(month): return datetime.date(datetime.datetime.now().year, month, calendar.monthrange(datetime.datetime.now().year, month)[1])",
    "def is_weekday(): return datetime.datetime.now().weekday() < 5",
    "def is_weekend(): return datetime.datetime.now().weekday() >= 5",
    "def current_hour(): return datetime.datetime.now().hour",
    "def current_minute(): return datetime.datetime.now().minute",
    "def current_second(): return datetime.datetime.now().second",
    "import time; def delay_1s(): time.sleep(1)",
    "import time; def millis_timestamp(): return int(time.time() * 1000)",
    "def format_time(dt, fmt='%Y-%m-%d %H:%M:%S'): return dt.strftime(fmt)",
    "from dateutil.parser import parse; def parse_time(s): return parse(s)",
    "import threading; def create_thread(target): return threading.Thread(target=target)",
    "import time; def thread_pause(seconds): time.sleep(seconds)",
    "def run_threads(*threads): [t.start() for t in threads]",
    "import threading; def current_thread_name(): return threading.current_thread().name",
    "def set_daemon(thread): thread.daemon = True",
    "import threading; lock = threading.Lock()",
    "import multiprocessing; def create_process(target): return multiprocessing.Process(target=target)",
    "import os; def get_pid(): return os.getpid()",
    "import psutil; def is_process_alive(pid): return psutil.pid_exists(pid)",
    "def run_processes(*procs): [p.start() for p in procs]",
    "from queue import Queue; q = Queue()",
    "from multiprocessing import Pipe; parent_conn, child_conn = Pipe()",
    "import os; def limit_cpu_usage(percent): os.system(f'cpulimit -p {os.getpid()} -l {percent}')",
    "import subprocess; def run_command(cmd): subprocess.run(cmd, shell=True)",
    "import subprocess; def get_command_output(cmd): return subprocess.check_output(cmd, shell=True).decode()",
    "def get_exit_code(cmd): return subprocess.call(cmd, shell=True)",
    "def is_success(code): return code == 0",
    "import os; def script_path(): return os.path.realpath(__file__)",
    "import sys; def get_cli_args(): return sys.argv[1:]",
    "import argparse; parser = argparse.ArgumentParser()",
    "parser.print_help()",
    "help('modules')",
    "import pip; def install_pkg(pkg): pip.main(['install', pkg])",
    "import pip; def uninstall_pkg(pkg): pip.main(['uninstall', pkg])",
    "import pkg_resources; def get_pkg_version(pkg): return pkg_resources.get_distribution(pkg).version",
    "import venv; def create_venv(path): venv.create(path)",
    "import pip; def list_pkgs(): return pip.get_installed_distributions()",
    "import pip; def upgrade_pkg(pkg): pip.main(['install', '--upgrade', pkg])",
    "import sqlite3; conn = sqlite3.connect(':memory:')",
    "def execute_query(conn, query): return conn.execute(query)",
    """def insert_record(conn, table, data): conn.execute(f'INSERT INTO {table} VALUES ({",".join("?"*len(data))})', data)""",
    "def delete_record(conn, table, condition): conn.execute(f'DELETE FROM {table} WHERE {condition}')",
    "def update_record(conn, table, set_clause, condition): conn.execute(f'UPDATE {table} SET {set_clause} WHERE {condition}')",
    "def fetch_all(conn, query): return conn.execute(query).fetchall()",
    "def safe_query(conn, query, params): return conn.execute(query, params)",
    "def close_db(conn): conn.close()",
    "def create_table(conn, name, columns): conn.execute(f'CREATE TABLE {name} ({columns})')",
    "def drop_table(conn, name): conn.execute(f'DROP TABLE {name}')",
    "def table_exists(conn, name): return conn.execute(f\"SELECT name FROM sqlite_master WHERE type='table' AND name='{name}'\").fetchone()",
    "def list_tables(conn): return conn.execute(\"SELECT name FROM sqlite_master WHERE type='table'\").fetchall()",
    """from sqlalchemy import Column, Integer, String
class User(Base):
__tablename__ = 'users'
id = Column(Integer, primary_key=True)
name = Column(String)""",
    "session.add(User(name='John'))",
    "session.query(User).filter_by(name='John')",
    "session.query(User).filter_by(name='John').delete()",
    "session.query(User).filter_by(name='John').update({'name': 'Bob'})",
    "Base = declarative_base()",
    "class Admin(User): pass",
    "id = Column(Integer, primary_key=True)",
    "name = Column(String, unique=True)",
    "name = Column(String, default='Unknown')",
    "import csv; def export_csv(data, path): open(path, 'w').write('\\n'.join([','.join(map(str, row)) for row in data]))",
    "import pandas as pd; pd.DataFrame(data).to_excel(path)",
    "import json; json.dump(data, open(path, 'w'))",
    "pd.read_excel(path).values.tolist()",
    "pd.concat([pd.read_excel(f) for f in files])",
    "with pd.ExcelWriter(path, mode='a') as writer: df.to_excel(writer, sheet_name='New')",
    "from openpyxl.styles import copy; copy.copy(style)",
    "from openpyxl.styles import PatternFill; cell.fill = PatternFill(start_color='FFFF00', fill_type='solid')",
    "from openpyxl.styles import Font; cell.font = Font(bold=True)",
    "sheet['A1'].value",
    "sheet['A1'] = value",
    "from PIL import Image; Image.open(path).size",
    "from PIL import Image; Image.open(path).resize((w, h))"
]
# Global readiness flag. It starts False and is flipped to True once the model
# and the snippet embeddings have been loaded; the request handlers check it
# before doing any work.
service_ready: bool = False
# Graceful shutdown: on SIGTERM or SIGINT, log a message and exit with status 0.
def handle_shutdown(signum, frame):
    """Signal handler that logs the shutdown request and terminates.

    ``signum`` and ``frame`` are the standard arguments passed to ``signal``
    handlers; neither is used. Termination happens by raising
    ``SystemExit(0)`` via ``sys.exit``.
    """
    # Log text: "Received termination signal, starting shutdown..."
    app.logger.info("收到终止信号,开始关闭...")
    sys.exit(0)


# NOTE(review): these handlers are installed at import time. If the module is
# imported by a server process that has already installed its own SIGTERM /
# SIGINT handlers (possibly a gunicorn worker), they get replaced here --
# confirm that is intended.
for _shutdown_signal in (signal.SIGTERM, signal.SIGINT):
    signal.signal(_shutdown_signal, handle_shutdown)
del _shutdown_signal
# Load the embedding model and pre-compute embeddings for every snippet once,
# at import time. Any failure is logged and re-raised, so the module import
# fails instead of the service running without a model.
try:
    # Log text: "Start loading model..."
    app.logger.info("开始加载模型...")
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base",
        cache_folder="/model-cache",
    )
    # Embed the whole corpus up front, pinned to the CPU, so each request only
    # needs to embed its own query.
    code_emb = model.encode(
        CODE_SNIPPETS,
        device="cpu",
        convert_to_tensor=True,
    )
    service_ready = True
    # Log text: "Service initialization complete"
    app.logger.info("服务初始化完成")
except Exception as init_error:
    # Log text: "Initialization failed: <error>"
    app.logger.error("初始化失败: %s", str(init_error))
    raise
# Hugging Face health-check endpoint; the Space must respond on the root path,
# so the view is explicitly registered for "/".
@app.route("/")
def hf_health_check():
    """Report whether the service is ready to answer searches.

    Clients that accept HTML (e.g. browsers) get a small status page with a
    hint on how to call ``/search``; every other client gets JSON. Both forms
    return HTTP 200 when ``service_ready`` is True and 503 otherwise, so a
    probe sending ``Accept: */*`` cannot see "healthy" during initialization.
    """
    ready = service_ready
    http_status = 200 if ready else 503

    if request.accept_mimetypes.accept_html:
        # Page text: "Service status: ..." / "Type /search?query=<your query>
        # in the address bar to test the API".
        html = """
<h2>CodeSearch API</h2>
<p>服务状态:{{ status }}</p>
<p>你可以在地址栏输入 /search?query=你的查询 来测试接口</p>
"""
        status = "ready" if ready else "initializing"
        return render_template_string(html, status=status), http_status

    # Otherwise answer with a JSON health payload.
    if ready:
        return jsonify({"status": "ready"}), http_status
    return jsonify({"status": "initializing"}), http_status
# Search API endpoint; serves both GET (?query=...) and POST (JSON body
# {"query": ...}) requests.
def _extract_query():
    """Return the stripped ``query`` value of the current request.

    GET requests read it from the URL query string; any other method reads it
    from a JSON object body. A missing value, a non-JSON or malformed body, a
    non-object JSON body, or a non-string ``query`` all yield ``""``, so the
    caller answers 400 instead of failing with a 500.
    """
    if request.method == 'GET':
        raw = request.args.get('query', '')
    else:
        # silent=True makes a wrong Content-Type or unparsable body return
        # None instead of raising.
        payload = request.get_json(silent=True)
        raw = payload.get('query', '') if isinstance(payload, dict) else ''
    return raw.strip() if isinstance(raw, str) else ''


@app.route("/search", methods=["GET", "POST"])
def handle_search():
    """Return the snippet from ``CODE_SNIPPETS`` most similar to the query.

    Responses:
      * 200 ``{"code": <snippet>, "score": <similarity rounded to 4 places>}``
      * 400 when the query is missing, empty or not a string
      * 503 while the service is not ready
      * 500 when encoding or the similarity search fails
    """
    if not service_ready:
        # Log text: "Service not ready"
        app.logger.info("服务未就绪")
        return jsonify({"error": "服务正在初始化"}), 503

    query = _extract_query()
    if not query:
        # Log text: "Received an empty query request"
        app.logger.info("收到空的查询请求")
        return jsonify({"error": "查询不能为空"}), 400

    # Log text: "Received query request: <query>"
    app.logger.info("收到查询请求: %s", query)
    try:
        # Embed the query on the CPU and take the single best match from the
        # pre-computed snippet embeddings.
        query_emb = model.encode(query, convert_to_tensor=True, device="cpu")
        best = util.semantic_search(query_emb, code_emb, top_k=1)[0][0]
        result = {
            "code": CODE_SNIPPETS[best['corpus_id']],
            "score": round(float(best['score']), 4),
        }
    except Exception as e:
        # Request boundary: log with traceback and hide details from clients.
        # Log text: "Request processing failed: <error>"
        app.logger.exception("请求处理失败: %s", str(e))
        return jsonify({"error": "服务器内部错误"}), 500

    # Log text: "Returning result: <result>"
    app.logger.info("返回结果: %s", result)
    return jsonify(result)
if __name__ == "__main__":
    # Local-testing entry point only; on Hugging Face Spaces the app is usually
    # started through gunicorn rather than by running this file directly.
    app.run(port=7860, host="0.0.0.0")