# codesearchBase / app.py
# Forrest99's picture
# Update app.py
# adb528f verified
# raw
# history blame
# 15.9 kB
from flask import Flask, request, jsonify, render_template_string
from sentence_transformers import SentenceTransformer, util
import logging
import sys
import signal
# Configure root logging at INFO level before anything emits log records.
logging.basicConfig(level=logging.INFO)

# Create the Flask application and replace its logger with a named logger
# so every record from this service is tagged "CodeSearchAPI".
app = Flask(__name__)
app.logger = logging.getLogger("CodeSearchAPI")
# Predefined code snippets: the corpus that search queries are matched against.
# Every entry is a single Java snippet stored as one string; the position of
# an entry in this list is the index used to look the snippet up again.
CODE_SNIPPETS = [
    # All snippets below are pure-Java implementations of the listed features.
    # (This sentence used to sit in the list as bare prose, which made the
    # whole module a SyntaxError.)
    "System.out.println(\"Hello, World!\");",
    "public static int sum(int a, int b) { return a + b; }",
    "import java.util.Random; public static int generateRandomNumber() { return new Random().nextInt(); }",
    "public static boolean isEven(int number) { return number % 2 == 0; }",
    "public static int stringLength(String str) { return str.length(); }",
    "import java.time.LocalDate; public static LocalDate getCurrentDate() { return LocalDate.now(); }",
    "import java.io.File; public static boolean fileExists(String path) { return new File(path).exists(); }",
    "import java.nio.file.Files; import java.nio.file.Paths; public static String readFileContent(String path) throws Exception { return new String(Files.readAllBytes(Paths.get(path))); }",
    "import java.nio.file.Files; import java.nio.file.Paths; public static void writeToFile(String path, String content) throws Exception { Files.write(Paths.get(path), content.getBytes()); }",
    "import java.time.LocalTime; public static LocalTime getCurrentTime() { return LocalTime.now(); }",
    "public static String toUpperCase(String str) { return str.toUpperCase(); }",
    "public static String toLowerCase(String str) { return str.toLowerCase(); }",
    "public static String reverseString(String str) { return new StringBuilder(str).reverse().toString(); }",
    "public static int countListElements(List<?> list) { return list.size(); }",
    "public static int findMax(List<Integer> list) { return Collections.max(list); }",
    "public static int findMin(List<Integer> list) { return Collections.min(list); }",
    "public static List<Integer> sortList(List<Integer> list) { Collections.sort(list); return list; }",
    "public static List<?> mergeLists(List<?> list1, List<?> list2) { List<Object> mergedList = new ArrayList<>(list1); mergedList.addAll(list2); return mergedList; }",
    "public static void removeElement(List<?> list, Object element) { list.remove(element); }",
    "public static boolean isListEmpty(List<?> list) { return list.isEmpty(); }",
    "public static int countCharOccurrences(String str, char ch) { return (int) str.chars().filter(c -> c == ch).count(); }",
    "public static boolean containsSubstring(String str, String sub) { return str.contains(sub); }",
    "public static String numberToString(int number) { return Integer.toString(number); }",
    "public static int stringToNumber(String str) { return Integer.parseInt(str); }",
    "public static boolean isNumeric(String str) { return str.matches(\"\\\\d+\"); }",
    "public static int getElementIndex(List<?> list, Object element) { return list.indexOf(element); }",
    "public static void clearList(List<?> list) { list.clear(); }",
    "public static void reverseList(List<?> list) { Collections.reverse(list); }",
    "public static List<?> removeDuplicates(List<?> list) { return new ArrayList<>(new HashSet<>(list)); }",
    "public static boolean isInList(List<?> list, Object element) { return list.contains(element); }",
    "public static Map<Object, Object> createDictionary() { return new HashMap<>(); }",
    "public static void addToDictionary(Map<Object, Object> dict, Object key, Object value) { dict.put(key, value); }",
    "public static void removeFromDictionary(Map<Object, Object> dict, Object key) { dict.remove(key); }",
    "public static Set<Object> getDictionaryKeys(Map<Object, Object> dict) { return dict.keySet(); }",
    "public static Collection<Object> getDictionaryValues(Map<Object, Object> dict) { return dict.values(); }",
    "public static Map<Object, Object> mergeDictionaries(Map<Object, Object> dict1, Map<Object, Object> dict2) { Map<Object, Object> mergedDict = new HashMap<>(dict1); mergedDict.putAll(dict2); return mergedDict; }",
    "public static boolean isDictionaryEmpty(Map<Object, Object> dict) { return dict.isEmpty(); }",
    "public static Object getDictionaryValue(Map<Object, Object> dict, Object key) { return dict.get(key); }",
    "public static boolean keyExistsInDictionary(Map<Object, Object> dict, Object key) { return dict.containsKey(key); }",
    "public static void clearDictionary(Map<Object, Object> dict) { dict.clear(); }",
    "import java.io.BufferedReader; import java.io.FileReader; public static int countFileLines(String path) throws Exception { return (int) new BufferedReader(new FileReader(path)).lines().count(); }",
    "import java.nio.file.Files; import java.nio.file.Paths; public static void writeListToFile(String path, List<?> list) throws Exception { Files.write(Paths.get(path), list.toString().getBytes()); }",
    "import java.nio.file.Files; import java.nio.file.Paths; public static List<String> readListFromFile(String path) throws Exception { return Files.readAllLines(Paths.get(path)); }",
    "import java.io.BufferedReader; import java.io.FileReader; public static int countFileWords(String path) throws Exception { return new BufferedReader(new FileReader(path)).lines().mapToInt(line -> line.split(\"\\\\s+\").length).sum(); }",
    "public static boolean isLeapYear(int year) { return (year % 4 == 0 && year % 100 != 0) || (year % 400 == 0); }",
    "import java.time.format.DateTimeFormatter; public static String formatTime(LocalTime time, String pattern) { return time.format(DateTimeFormatter.ofPattern(pattern)); }",
    "import java.time.LocalDate; import java.time.temporal.ChronoUnit; public static long daysBetween(LocalDate date1, LocalDate date2) { return ChronoUnit.DAYS.between(date1, date2); }",
    "import java.nio.file.Paths; public static String getCurrentWorkingDirectory() { return Paths.get(\"\").toAbsolutePath().toString(); }",
    "import java.io.File; public static List<String> listFilesInDirectory(String path) { return Arrays.asList(new File(path).list()); }",
    "import java.nio.file.Files; import java.nio.file.Paths; public static void createDirectory(String path) throws Exception { Files.createDirectory(Paths.get(path)); }",
    "import java.nio.file.Files; import java.nio.file.Paths; public static void deleteDirectory(String path) throws Exception { Files.delete(Paths.get(path)); }",
    "import java.nio.file.Files; import java.nio.file.Paths; public static boolean isFile(String path) { return Files.isRegularFile(Paths.get(path)); }",
    "import java.nio.file.Files; import java.nio.file.Paths; public static boolean isDirectory(String path) { return Files.isDirectory(Paths.get(path)); }",
    "import java.nio.file.Files; import java.nio.file.Paths; public static long getFileSize(String path) throws Exception { return Files.size(Paths.get(path)); }",
    "import java.nio.file.Files; import java.nio.file.Paths; public static void renameFile(String oldPath, String newPath) throws Exception { Files.move(Paths.get(oldPath), Paths.get(newPath)); }",
    "import java.nio.file.Files; import java.nio.file.Paths; public static void copyFile(String sourcePath, String destinationPath) throws Exception { Files.copy(Paths.get(sourcePath), Paths.get(destinationPath)); }",
    "import java.nio.file.Files; import java.nio.file.Paths; public static void moveFile(String sourcePath, String destinationPath) throws Exception { Files.move(Paths.get(sourcePath), Paths.get(destinationPath)); }",
    "import java.nio.file.Files; import java.nio.file.Paths; public static void deleteFile(String path) throws Exception { Files.delete(Paths.get(path)); }",
    "public static String getEnvVariable(String name) { return System.getenv(name); }",
    "public static void setEnvVariable(String name, String value) { System.setProperty(name, value); }",
    "import java.awt.Desktop; import java.net.URI; public static void openWebLink(String url) throws Exception { Desktop.getDesktop().browse(new URI(url)); }",
    "import java.net.HttpURLConnection; import java.net.URL; public static String sendGetRequest(String url) throws Exception { HttpURLConnection connection = (HttpURLConnection) new URL(url).openConnection(); connection.setRequestMethod(\"GET\"); return new String(connection.getInputStream().readAllBytes()); }",
    "import com.google.gson.JsonParser; public static Object parseJson(String json) { return JsonParser.parseString(json); }",
    "import com.google.gson.Gson; import java.nio.file.Files; import java.nio.file.Paths; public static void writeJsonToFile(String path, Object obj) throws Exception { Files.write(Paths.get(path), new Gson().toJson(obj).getBytes()); }",
    "import com.google.gson.Gson; import java.nio.file.Files; import java.nio.file.Paths; public static Object readJsonFromFile(String path) throws Exception { return new Gson().fromJson(new String(Files.readAllBytes(Paths.get(path))), Object.class); }",
    "public static String listToString(List<?> list) { return list.toString(); }",
    "public static List<String> stringToList(String str) { return Arrays.asList(str.split(\",\")); }",
    "public static String joinWithComma(List<?> list) { return String.join(\",\", list.stream().map(Object::toString).toArray(String[]::new)); }",
    "public static String joinWithNewline(List<?> list) { return String.join(\"\\n\", list.stream().map(Object::toString).toArray(String[]::new)); }",
    "public static String[] splitBySpace(String str) { return str.split(\"\\\\s+\"); }",
    "public static String[] splitByDelimiter(String str, String delimiter) { return str.split(delimiter); }",
    "public static String[] splitIntoChars(String str) { return str.split(\"\"); }",
    "public static String replaceInString(String str, String target, String replacement) { return str.replace(target, replacement); }",
    "public static String removeSpaces(String str) { return str.replaceAll(\"\\\\s\", \"\"); }",
    "public static String removePunctuation(String str) { return str.replaceAll(\"[^a-zA-Z0-9]\", \"\"); }",
    "public static boolean isStringEmpty(String str) { return str.isEmpty(); }",
    "public static boolean isPalindrome(String str) { return str.equals(new StringBuilder(str).reverse().toString()); }",
    "import com.opencsv.CSVWriter; import java.io.FileWriter; public static void writeToCsv(String path, List<String[]> data) throws Exception { CSVWriter writer = new CSVWriter(new FileWriter(path)); writer.writeAll(data); writer.close(); }",
    "import com.opencsv.CSVReader; import java.io.FileReader; public static List<String[]> readFromCsv(String path) throws Exception { CSVReader reader = new CSVReader(new FileReader(path)); return reader.readAll(); }",
    "import com.opencsv.CSVReader; import java.io.FileReader; public static int countCsvLines(String path) throws Exception { return readFromCsv(path).size(); }",
    "public static void shuffleList(List<?> list) { Collections.shuffle(list); }",
    "public static Object getRandomElement(List<?> list) { return list.get(new Random().nextInt(list.size())); }",
    "public static List<?> getRandomElements(List<?> list, int count) { Collections.shuffle(list); return list.subList(0, count); }",
    "public static int rollDice() { return new Random().nextInt(6) + 1; }",
    "public static String flipCoin() { return new Random().nextBoolean() ? \"Heads\" : \"Tails\"; }",
    "import java.util.Random; public static String generateRandomPassword(int length) { String chars = \"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789\"; StringBuilder password = new StringBuilder(); for (int i = 0; i < length; i++) { password.append(chars.charAt(new Random().nextInt(chars.length()))); } return password.toString(); }",
    "import java.util.Random; public static String generateRandomColor() { Random random = new Random(); return String.format(\"#%06x\", random.nextInt(0xFFFFFF + 1)); }",
    "import java.util.UUID; public static String generateUniqueId() { return UUID.randomUUID().toString(); }",
    "public class MyClass {}",
    "MyClass myObject = new MyClass();",
    "public class MyClass { public void myMethod() {} }",
    "public class MyClass { public String myAttribute; }",
    "public class ChildClass extends MyClass {}",
    "public class ChildClass extends MyClass { @Override public void myMethod() {} }",
    "public class MyClass { public static void myClassMethod() {} }",
    "public class MyClass { public static void myStaticMethod() {} }",
    "public static boolean isInstanceOf(Object obj, Class<?> clazz) { return clazz.isInstance(obj); }",
    "public static Object getAttribute(Object obj, String attribute) throws Exception { return obj.getClass().getDeclaredField(attribute).get(obj); }",
    "public static void setAttribute(Object obj, String attribute, Object value) throws Exception { obj.getClass().getDeclaredField(attribute).set(obj, value); }",
    "public static void deleteAttribute(Object obj, String attribute) throws Exception { obj.getClass().getDeclaredField(attribute).set(obj, null); }",
]
# Global service state: False at import time, flipped to True by the
# initialization code once the model is loaded and the snippet embeddings
# are computed. The HTTP handlers read it to decide whether to answer 503.
service_ready: bool = False
# Graceful-shutdown handling
def handle_shutdown(signum, frame):
    """Log that a termination signal arrived and exit with status 0.

    Args:
        signum: Number of the signal that was delivered (not used).
        frame: Stack frame that was interrupted (not used).

    Raises:
        SystemExit: Always, with exit code 0.
    """
    app.logger.info("收到终止信号,开始关闭...")
    raise SystemExit(0)
# Route both SIGTERM and SIGINT (in that order) to the shutdown handler.
for _shutdown_signal in (signal.SIGTERM, signal.SIGINT):
    signal.signal(_shutdown_signal, handle_shutdown)
# Load the embedding model and pre-compute the snippet embeddings.
# Any failure is logged and re-raised, so the module never finishes importing
# with a half-initialized service.
try:
    app.logger.info("开始加载模型...")
    model = SentenceTransformer(
        "flax-sentence-embeddings/st-codesearch-distilroberta-base",
        cache_folder="/model-cache",
    )
    # Encode every snippet once up front (forced onto the CPU) so a request
    # only has to encode its own query.
    code_emb = model.encode(
        CODE_SNIPPETS,
        device="cpu",
        convert_to_tensor=True,
    )
except Exception as init_error:
    app.logger.error("初始化失败: %s", str(init_error))
    raise
else:
    # Reached only when both the model load and the encoding succeeded.
    service_ready = True
    app.logger.info("服务初始化完成")
# Hugging Face health-check endpoint; it has to answer on the root path.
@app.route('/')
def hf_health_check():
    """Report whether the service is ready, as HTML or as JSON.

    Returns:
        For clients whose Accept header admits HTML: a small rendered page
        showing the status and a hint on how to try the search endpoint.
        Otherwise: a JSON body ``{"status": ...}`` with HTTP 200 when the
        service is ready and HTTP 503 while it is still initializing.
    """
    if not request.accept_mimetypes.accept_html:
        # Machine-readable probe: the status code carries the readiness.
        if not service_ready:
            return jsonify({"status": "initializing"}), 503
        return jsonify({"status": "ready"}), 200

    # Browser-style request: render a minimal status page with a test hint.
    page_template = (
        "\n"
        "<h2>CodeSearch API</h2>\n"
        "<p>服务状态:{{ status }}</p>\n"
        "<p>你可以在地址栏输入 /search?query=你的查询 来测试接口</p>\n"
    )
    if service_ready:
        status = "ready"
    else:
        status = "initializing"
    return render_template_string(page_template, status=status)
# Search API endpoint; serves both GET and POST requests.
@app.route('/search', methods=['GET', 'POST'])
def handle_search():
    """Return the stored code snippet that best matches the caller's query.

    The query is read from the ``query`` URL parameter on GET requests and
    from the ``query`` field of the JSON body on POST requests.

    Returns:
        A JSON response, paired with a status code on the error paths:

        * ``{"code": ..., "score": ...}`` - the best-matching snippet and its
          score rounded to four decimals;
        * 503 while the service is not ready;
        * 400 when the query is missing, blank or not a string, or when the
          POST body is missing, malformed, or not a JSON object;
        * 500 when encoding or searching fails unexpectedly.
    """
    if not service_ready:
        app.logger.info("服务未就绪")
        return jsonify({"error": "服务正在初始化"}), 503

    # Pull the raw query out of the request according to the HTTP method.
    if request.method == 'GET':
        raw_query = request.args.get('query', '')
    else:
        # silent=True yields None for a missing, non-JSON or malformed body
        # instead of raising an HTTP error, which the catch-all handler below
        # would otherwise report as a 500 even though the client is at fault.
        payload = request.get_json(silent=True)
        # A valid JSON body can still be a list, string or number; only an
        # object can carry a "query" field.
        raw_query = payload.get('query', '') if isinstance(payload, dict) else ''

    # A non-string query (number, list, null, ...) is handled like an empty
    # one so it produces a 400 rather than an AttributeError on .strip().
    query = raw_query.strip() if isinstance(raw_query, str) else ''
    if not query:
        app.logger.info("收到空的查询请求")
        return jsonify({"error": "查询不能为空"}), 400

    # Log the incoming query.
    app.logger.info("收到查询请求: %s", query)

    try:
        # Encode the query and look up the single closest snippet embedding.
        query_emb = model.encode(query, convert_to_tensor=True, device="cpu")
        hits = util.semantic_search(query_emb, code_emb, top_k=1)[0]
        best = hits[0]
        result = {
            "code": CODE_SNIPPETS[best['corpus_id']],
            "score": round(float(best['score']), 4)
        }
        # Log the result being returned.
        app.logger.info("返回结果: %s", result)
        return jsonify(result)
    except Exception as exc:
        # Last-resort boundary: keep the traceback in the log and hide the
        # details from the client.
        app.logger.exception("请求处理失败: %s", exc)
        return jsonify({"error": "服务器内部错误"}), 500
if __name__ == "__main__":
    # For local testing only; per the original note, Hugging Face Spaces
    # usually starts the app through gunicorn instead of this entry point.
    app.run(port=7860, host='0.0.0.0')