# Hugging Face Space - status at capture time: Runtime error
import os | |
import streamlit as st | |
import javalang | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import re | |
import numpy as np | |
import networkx as nx | |
from transformers import AutoTokenizer, AutoModel | |
import warnings | |
import pandas as pd | |
from collections import defaultdict | |
# ---- Process-wide runtime settings -------------------------------------
# Disable Hugging Face tokenizers' internal parallelism and silence every
# Python warning for the whole process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")

# ---- Constants ----------------------------------------------------------
MODEL_NAME = "microsoft/codebert-base"  # checkpoint for tokenizer + encoder
MAX_LENGTH = 512                        # tokenizer truncation length
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Streamlit page settings. This runs at import time, ahead of every other
# st.* call in this file.
_PAGE_SETTINGS = dict(
    page_title="Java Code Clone Detector",
    # NOTE(review): this icon looks like a mis-encoded emoji; confirm the
    # intended character before changing it.
    page_icon="π",
    layout="wide",
)
st.set_page_config(**_PAGE_SETTINGS)
# Simplified RNN model (for Hugging Face compatibility)
class SimpleRNN(nn.Module):
    """Single-layer ``nn.RNN`` followed by a one-unit sigmoid head.

    Expects 3-D, batch-first input ``(batch, seq_len, input_size)``. The RNN
    output at the final time step is projected to one logit and squashed
    into (0, 1), so ``forward`` returns a ``(batch, 1)`` tensor.
    """

    def __init__(self, input_size=768, hidden_size=128):
        super().__init__()
        # Attribute names (and their registration order) determine the
        # state_dict keys and the order parameters are initialised in.
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        step_outputs, _final_hidden = self.rnn(x)
        last_step = step_outputs[:, -1]
        logit = self.fc(last_step)
        return logit.sigmoid()
# Model loading with caching
@st.cache_resource(show_spinner=False)
def _load_models_cached():
    """Build the CodeBERT tokenizer/encoder and the scoring RNN (cached).

    ``st.cache_resource`` shares the returned objects across reruns and
    sessions, so widget interactions no longer reload the checkpoint.
    Exceptions propagate instead of being returned, which means a failed
    load (e.g. a transient download error) is not cached and is retried on
    the next run.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    code_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE).eval()
    # NOTE(review): SimpleRNN is freshly initialised here and no weights are
    # loaded for it anywhere in this file, so its score comes from random
    # (untrained) weights. Confirm whether a trained checkpoint should be loaded.
    rnn_model = SimpleRNN().to(DEVICE).eval()
    return tokenizer, code_model, rnn_model


def load_models():
    """Return ``(tokenizer, code_model, rnn_model)``.

    Returns:
        The three loaded models. If loading raises, an error is shown on the
        page and ``(None, None, None)`` is returned instead.
    """
    try:
        with st.spinner('Loading models (first run may take a few minutes)...'):
            return _load_models_cached()
    except Exception as e:
        st.error(f"Model loading failed: {str(e)}")
        return None, None, None
# AST processing (simplified for Hugging Face)
def parse_ast(code):
    """Parse Java source with ``javalang.parse.parse``.

    Args:
        code: Java source text.

    Returns:
        The javalang tree on success, or ``None`` if parsing raised. For
        example, ``javalang.parse.parse`` expects a complete compilation
        unit, so a bare method snippet is likely to fail.
    """
    try:
        return javalang.parse.parse(code)
    except Exception:
        # Unparseable user input is expected, and callers treat None as
        # "no AST". Catch Exception (not everything) so KeyboardInterrupt
        # and SystemExit still propagate.
        return None
def build_simple_ast_features(ast_tree):
    """Count how often each AST node class occurs in a javalang tree.

    Args:
        ast_tree: Root returned by ``parse_ast``, or a falsy value (e.g.
            ``None``) when parsing failed.

    Returns:
        dict mapping node class name (``type(node).__name__``) to its
        occurrence count. Returns an empty dict for a falsy ``ast_tree``.
    """
    if not ast_tree:
        return {}
    features = defaultdict(int)
    # Explicit stack instead of recursion, so deeply nested expressions
    # cannot exceed Python's recursion limit.
    nodes = [ast_tree]
    while nodes:
        node = nodes.pop()
        features[type(node).__name__] += 1
        # A node's children mix Nodes, scalars (str/set/None) and lists.
        # Those lists can themselves contain lists (javalang returns an
        # initializer block inside a class body as a list of statements),
        # so flatten every list/tuple level rather than just the first.
        pending = list(getattr(node, 'children', []))
        while pending:
            child = pending.pop()
            if isinstance(child, javalang.ast.Node):
                nodes.append(child)
            elif isinstance(child, (list, tuple)):
                pending.extend(child)
    return dict(features)
# Feature extraction

# One scanner for string/char literals and comments, so comment markers
# inside literals (e.g. "http://...") are left alone and a "//" inside a
# block comment cannot swallow the code after it.
_JAVA_COMMENT_OR_LITERAL = re.compile(
    r'"(?:\\.|[^"\\])*"'   # string literal (also matches text-block pieces)
    r"|'(?:\\.|[^'\\])*'"  # char literal
    r"|//[^\n]*"           # line comment
    r"|/\*.*?\*/",         # block comment
    re.DOTALL,
)


def _keep_literals_blank_comments(match):
    """Return literals unchanged and replace comments with one space."""
    token = match.group(0)
    return token if token[0] in "\"'" else " "


def normalize_code(code):
    """Strip Java comments and collapse all whitespace runs to single spaces.

    A comment is replaced by a space rather than deleted, because in Java a
    comment separates tokens (``int/*c*/x`` is ``int x``). Text inside
    string and char literals is preserved.

    Args:
        code: Java source text.

    Returns:
        The comment-free source on a single line, with leading and trailing
        whitespace removed.
    """
    without_comments = _JAVA_COMMENT_OR_LITERAL.sub(_keep_literals_blank_comments, code)
    return re.sub(r'\s+', ' ', without_comments).strip()
def get_embedding(code, tokenizer, model):
    """Embed a Java snippet as the masked mean of the encoder's last hidden states.

    Args:
        code: Java source text. It is passed through ``normalize_code`` first.
        tokenizer: Hugging Face tokenizer paired with ``model``.
        model: Encoder whose output exposes ``last_hidden_state``.

    Returns:
        A ``(1, hidden_size)`` tensor for the single input string, or
        ``None`` if tokenization or the forward pass raised. In that case a
        Streamlit warning is shown.
    """
    try:
        # No padding: a single sequence needs none, and padding to
        # MAX_LENGTH only adds compute.
        inputs = tokenizer(
            normalize_code(code),
            return_tensors="pt",
            truncation=True,
            max_length=MAX_LENGTH,
        ).to(DEVICE)
        with torch.no_grad():
            hidden = model(**inputs).last_hidden_state
            mask = inputs.get("attention_mask")
            if mask is None:
                return hidden.mean(dim=1)
            # Average over real tokens only. Padded positions would
            # otherwise dominate short snippets and inflate every similarity.
            weights = mask.unsqueeze(-1).to(hidden.dtype)
            return (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
    except Exception as e:
        st.warning(f"Embedding failed: {e}")
        return None
# Similarity calculations (optimized for Hugging Face)
def calculate_similarities(code1, code2, models):
    """Score two Java snippets with three signals and a weighted blend.

    Args:
        code1, code2: Java source strings to compare.
        models: ``(tokenizer, code_model, rnn_model)`` triple.

    Returns:
        dict with these keys, in this order:
        - ``'codebert'``: cosine similarity of the two embeddings.
        - ``'rnn'``: ``rnn_model`` score for the two-step sequence [emb1, emb2].
        - ``'ast'``: Jaccard index of the two sets of AST node-type names.
          Only the names are compared; the per-type counts are ignored.
        - ``'combined'``: ``0.5*codebert + 0.3*rnn + 0.2*ast``.
        A component stays 0 when its inputs are unavailable (an embedding
        failed, or a snippet did not parse).
    """
    tokenizer, code_model, rnn_model = models

    emb_a = get_embedding(code1, tokenizer, code_model)
    emb_b = get_embedding(code2, tokenizer, code_model)
    tree_a, tree_b = parse_ast(code1), parse_ast(code2)
    kinds_a = build_simple_ast_features(tree_a)
    kinds_b = build_simple_ast_features(tree_b)

    codebert_sim = rnn_sim = 0
    if emb_a is not None and emb_b is not None:
        codebert_sim = F.cosine_similarity(emb_a, emb_b).item()
        with torch.no_grad():
            # Feed the two embeddings as a batch of one two-step sequence.
            pair_sequence = torch.cat([emb_a, emb_b]).unsqueeze(0)
            rnn_sim = rnn_model(pair_sequence).item()

    ast_sim = 0
    if kinds_a and kinds_b:
        names_a, names_b = set(kinds_a.keys()), set(kinds_b.keys())
        all_names = names_a | names_b
        ast_sim = len(names_a & names_b) / len(all_names) if all_names else 0

    blended = 0.5*codebert_sim + 0.3*rnn_sim + 0.2*ast_sim
    return {'codebert': codebert_sim, 'rnn': rnn_sim, 'ast': ast_sim, 'combined': blended}
# Main UI

# Example snippet pairs offered in the selectbox, keyed by the label shown.
_EXAMPLE_PAIRS = {
    "Type 1 Example": {
        "code1": "public class Test { public static void main(String[] args) { System.out.println(\"Hello\"); } }",
        "code2": "public class Test { public static void main(String[] args) { System.out.println(\"Hello\"); } }"
    },
    "Type 2 Example": {
        "code1": "public class Test { public static void main(String[] args) { System.out.println(\"Hello\"); } }",
        "code2": "public class Example { public static void main(String[] args) { System.out.println(\"Hello\"); } }"
    },
    "Type 3 Example": {
        "code1": "public class Test { public static void main(String[] args) { for(int i=0;i<10;i++) System.out.println(i); } }",
        "code2": "public class Example { public static void run(String[] params) { for(int j=0;j<10;j++) System.out.println(j); } }"
    },
}


def _classify_clone(score, t1, t3, t4):
    """Map a combined score to a clone label (thresholds checked strictest first)."""
    if score >= t1:
        return "Type 1/2 Clone (Exact/Near-Exact)"
    if score >= t3:
        return "Type 3 Clone (Near-Miss)"
    if score >= t4:
        return "Type 4 Clone (Semantic)"
    return "No Clone"


def _render_results(sims, clone_type, code1, code2):
    """Draw the score metrics, progress bar, verdict and details expander."""
    st.subheader("Results")
    cols = st.columns(4)
    cols[0].metric("Combined", f"{sims['combined']:.2f}")
    cols[1].metric("CodeBERT", f"{sims['codebert']:.2f}")
    cols[2].metric("RNN", f"{sims['rnn']:.2f}")
    cols[3].metric("AST", f"{sims['ast']:.2f}")
    # st.progress raises for floats outside [0.0, 1.0]. A negative cosine
    # similarity can push the blended score below zero, so clamp it.
    st.progress(min(max(float(sims['combined']), 0.0), 1.0))
    st.metric("Detection Result", clone_type)
    with st.expander("Advanced Details"):
        st.json(sims)
        st.code(f"Normalized Code 1:\n{normalize_code(code1)}")
        st.code(f"Normalized Code 2:\n{normalize_code(code2)}")


def main():
    """Render the clone-detector page and run an analysis when requested."""
    # NOTE(review): the leading character of the title looks mis-encoded;
    # confirm the intended emoji.
    st.title("π Java Code Clone Detector (IJaDataset 2.1)")
    st.markdown("Detect Type 1-4 clones using hybrid analysis")

    models = load_models()
    if any(part is None for part in models):
        st.error("Failed to load required models. Please check the logs.")
        return

    # Code input
    selected_example = st.selectbox("Select example pair:", list(_EXAMPLE_PAIRS.keys()))
    example = _EXAMPLE_PAIRS[selected_example]
    col1, col2 = st.columns(2)
    with col1:
        code1 = st.text_area("Code 1", height=300, value=example["code1"])
    with col2:
        code2 = st.text_area("Code 2", height=300, value=example["code2"])

    # Thresholds
    st.subheader("Detection Thresholds")
    cols = st.columns(3)
    with cols[0]:
        t1 = st.slider("Type 1/2", 0.85, 1.0, 0.95)
    with cols[1]:
        t3 = st.slider("Type 3", 0.7, 0.9, 0.8)
    with cols[2]:
        t4 = st.slider("Type 4", 0.5, 0.8, 0.65)

    if st.button("Analyze Code", type="primary"):
        if not code1.strip() or not code2.strip():
            st.warning("Please enter Java code in both boxes before analyzing.")
            return
        with st.spinner("Analyzing code..."):
            sims = calculate_similarities(code1, code2, models)
        clone_type = _classify_clone(sims['combined'], t1, t3, t4)
        _render_results(sims, clone_type, code1, code2)


if __name__ == "__main__":
    main()