# Shuu12121's picture
# Update app.py
# e3dfc76 verified
# CodeSearch-ModernBERT-Owl Demo Space using the CodeXGLUE AdvTest dataset (code_x_glue_tc_nl_code_search_adv, built on CodeSearchNet)
import gradio as gr
import torch
import random
from sentence_transformers import SentenceTransformer, util
from datasets import load_dataset
from spaces import GPU
import re
# --- Load the CodeSearch-ModernBERT-Owl sentence-transformers encoder ---
model = SentenceTransformer(model_name_or_path="Shuu12121/CodeSearch-ModernBERT-Owl")
model.eval()  # the demo only runs inference (model.encode)

# --- Load the CodeXGLUE AdvTest NL->code search data (test split only) ---
dataset = load_dataset(
    path="code_x_glue_tc_nl_code_search_adv",
    split="test",
    trust_remote_code=True,
)
# At each position, try (in order): a triple-quoted string, a single-line
# '...' / "..." literal (backslash escapes and line continuations allowed),
# or a '#' comment running to end of line. Plain string literals are matched
# only so that a '#' inside them is not mistaken for a comment.
_STRING_OR_COMMENT_RE = re.compile(
    r'(?P<triple>"""[\s\S]*?"""|\'\'\'[\s\S]*?\'\'\')'
    r'|(?P<string>"(?:\\[\s\S]|[^"\\\n])*"|\'(?:\\[\s\S]|[^\'\\\n])*\')'
    r'|(?P<comment>#[^\n]*)'
)


def remove_comments_from_code(code: str) -> str:
    """Strip ``#`` comments and triple-quoted strings (docstrings) from Python source.

    The text is scanned left to right in a single pass, so a ``#`` that sits
    inside an ordinary string literal (e.g. ``"#fff"``) is preserved, and
    triple-quoted strings are paired correctly even when the other quote
    style appears inside them. Ordinary single-line string literals are
    returned unchanged; line breaks outside removed spans are kept.

    Args:
        code: Python source text.

    Returns:
        ``code`` with comments and all triple-quoted strings removed.
    """
    def _keep_plain_strings(match: "re.Match[str]") -> str:
        # Only ordinary string literals survive; triple strings and comments vanish.
        return match.group("string") or ""

    return _STRING_OR_COMMENT_RE.sub(_keep_plain_strings, code)
# --- Query & Candidate Generator ---
def get_query_and_candidates(seed: int = 8520, num_negatives: int = 99):
    """Sample one query docstring, its ground-truth code, and a shuffled candidate pool.

    Args:
        seed: Seed for a private ``random.Random`` instance. The same seed
            always yields the same query, negatives, and ordering, and the
            module-level ``random`` state is left untouched.
        num_negatives: Number of distractor examples drawn without
            replacement from the rest of the split. Clamped to the number of
            other examples available.

    Returns:
        ``(doc_str, correct_code, candidates)``. ``candidates`` holds the
        comment-stripped ground-truth code plus the comment-stripped
        negatives, in shuffled order.

    Raises:
        ValueError: if ``dataset`` is empty (from ``randint``).
    """
    rng = random.Random(seed)
    n = len(dataset)
    idx = rng.randint(0, n - 1)
    query = dataset[idx]
    correct_code = remove_comments_from_code(query["code"])
    doc_str = query["docstring"]

    # Draw positions among the n - 1 *other* rows without materialising the
    # split, then shift positions at/after ``idx`` by one to skip the query.
    # This selects exactly the rows that sampling from a list of "all rows
    # except idx" would.
    k = min(num_negatives, n - 1)
    negative_indices = [pos if pos < idx else pos + 1 for pos in rng.sample(range(n - 1), k=k)]

    candidates = [correct_code] + [
        remove_comments_from_code(dataset[i]["code"]) for i in negative_indices
    ]
    rng.shuffle(candidates)
    return doc_str, correct_code, candidates
@GPU
def code_search_demo(seed: int):
    """Rank candidate functions for one sampled docstring and render the result as Markdown.

    Args:
        seed: Seed forwarded to ``get_query_and_candidates`` to choose the
            query and distractors. Coerced with ``int()`` so a float-valued
            slider reading selects the same episode.

    Returns:
        Markdown with the query docstring, whether the ground-truth code is in
        the Top-k, MRR@k, and every candidate in descending cosine similarity
        (code truncated to 1000 characters, ground truth marked with a check).
    """
    top_k = 10
    doc_str, correct_code, candidates = get_query_and_candidates(int(seed))

    query_emb = model.encode(doc_str, convert_to_tensor=True)
    candidate_embeddings = model.encode(candidates, convert_to_tensor=True)
    # cos_sim yields a (num_queries x num_candidates) matrix; with one query,
    # row 0 holds every score. Convert to Python floats once so sorting and
    # formatting do not operate on tensors.
    cos_scores = util.cos_sim(query_emb, candidate_embeddings)[0].tolist()
    results = sorted(zip(candidates, cos_scores), key=lambda pair: pair[1], reverse=True)

    # The ground truth is recognised by its whitespace-stripped text.
    correct_key = correct_code.strip()
    correct_rank = next(
        (rank for rank, (code, _) in enumerate(results, start=1) if code.strip() == correct_key),
        None,
    )
    correct_in_top_k = correct_rank is not None and correct_rank <= top_k
    # MRR@k: a hit ranked below k contributes 0, matching the "MRR@{top_k}" label.
    mrr = 1.0 / correct_rank if correct_in_top_k else 0.0

    parts = [
        f"### 🔍 Query Docstring\n\n{doc_str}\n\n",
        f"**✅ 正解は Top-{top_k} に含まれているか?**: {'🟢 Yes' if correct_in_top_k else '🔴 No'}\n\n",
        f"**📈 MRR@{top_k}**: {mrr:.4f}\n\n",
        "## 🏆 Top Matches:\n",
    ]
    medals = ["🥇", "🥈", "🥉"]
    for i, (code, score) in enumerate(results):
        label = medals[i] if i < len(medals) else f"#{i+1}"
        is_correct = "✅" if code.strip() == correct_key else ""
        parts.append(
            f"\n**{label}** - Similarity: {score:.4f} {is_correct}\n\n```python\n{code.strip()[:1000]}\n```\n"
        )
    return "".join(parts)
# --- Gradio UI: one seed slider in, Markdown ranking report out ---
_seed_input = gr.Slider(minimum=0, maximum=100000, value=8520, step=1, label="Random Seed")
_result_output = gr.Markdown(label="Search Result")

demo = gr.Interface(
    fn=code_search_demo,
    inputs=_seed_input,
    outputs=_result_output,
    title="🔎 CodeSearch-ModernBERT-Owl🦉 Demo",
    description="docstring から類似 Python 関数を検索(CodeXGlue + ModernBERT-Owl)",
)

if __name__ == "__main__":
    demo.launch()