import os
import re
import warnings
import zipfile
import zlib

import streamlit as st
import javalang
import torch
import torch.nn as nn
import torch.nn.functional as F
import networkx as nx
from transformers import AutoTokenizer, AutoModel, AutoConfig
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
# Configuration
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")

# Constants
MODEL_NAME = "microsoft/codebert-base"
MAX_LENGTH = 512
NODE_FEATURE_DIM = 128  # fixed AST node-feature width; must match GNNModel's input size
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DATASET_PATH = "ijadataset2-1.zip"
CACHE_DIR = "./model_cache"
# Set up page config
st.set_page_config(
    page_title="Advanced Java Code Clone Detector",
    page_icon="🔍",
    layout="wide"
)
# Model Definitions
class RNNModel(nn.Module):
    """Scores a pair of code embeddings fed in as a length-2 sequence."""

    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        # Zero-initialized hidden state: (num_layers, batch, hidden_size)
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(DEVICE)
        out, _ = self.rnn(x, h0)
        # The last time step's output represents the whole pair
        return self.fc(out[:, -1, :])
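
# Shape sketch for RNNModel (illustrative; mirrors the call in
# calculate_similarities below):
#   emb1, emb2: (1, 768) CodeBERT embeddings
#   torch.stack([emb1.squeeze(), emb2.squeeze()])   # (2, 768)
#   .unsqueeze(0)                                   # (1, 2, 768) batch input
#   RNNModel(768, 256, 2)(...)                      # (1, 1) logit -> sigmoid
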
class GNNModel(nn.Module):
    """Two-layer GCN over an AST graph, pooled into one graph embedding."""

    def __init__(self, node_features):
        super().__init__()
        self.conv1 = GCNConv(node_features, 128)
        self.conv2 = GCNConv(128, 64)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        # Mean-pool node embeddings into a single 64-dim graph vector so two
        # ASTs can be compared with cosine similarity downstream
        return x.mean(dim=0)
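
# GNN pipeline sketch (illustrative): javalang AST -> networkx DiGraph ->
# PyG Data with one-hot node-type features of width NODE_FEATURE_DIM ->
# two GCNConv layers -> mean-pooled 64-dim graph embedding. Two such
# embeddings are then compared with cosine similarity.
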
# Model Loading with Cache
@st.cache_resource
def load_models():
    try:
        with st.spinner('Loading models (first run may take a few minutes)...'):
            config = AutoConfig.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
            model = AutoModel.from_pretrained(MODEL_NAME, config=config, cache_dir=CACHE_DIR).to(DEVICE)
            rnn_model = RNNModel(input_size=768, hidden_size=256, num_layers=2).to(DEVICE)
            gnn_model = GNNModel(node_features=NODE_FEATURE_DIM).to(DEVICE)
            # Inference only: eval() disables dropout in the GNN forward pass
            model.eval()
            rnn_model.eval()
            gnn_model.eval()
        return tokenizer, model, rnn_model, gnn_model
    except Exception as e:
        st.error(f"Model loading failed: {str(e)}")
        return None, None, None, None
# Dataset Loading
@st.cache_data
def load_dataset():
    try:
        if not os.path.exists("Diverse_100K_Dataset"):
            with zipfile.ZipFile(DATASET_PATH, 'r') as zip_ref:
                zip_ref.extractall(".")
        clone_pairs = []
        base_path = "Diverse_100K_Dataset/Subject_CloneTypes_Directories"
        for clone_type in ["Clone_Type1", "Clone_Type2", "Clone_Type3 - ST", "Clone_Type4"]:
            type_path = os.path.join(base_path, clone_type)
            if not os.path.exists(type_path):
                continue
            # Take the first directory holding at least two files: one
            # example pair per clone type
            for root, _, files in os.walk(type_path):
                if len(files) >= 2:
                    with open(os.path.join(root, files[0]), 'r', encoding='utf-8', errors='ignore') as f1, \
                         open(os.path.join(root, files[1]), 'r', encoding='utf-8', errors='ignore') as f2:
                        clone_pairs.append({
                            "type": clone_type,
                            "code1": f1.read(),
                            "code2": f2.read()
                        })
                    break
        return clone_pairs[:10]
    except Exception as e:
        st.error(f"Dataset error: {str(e)}")
        return []
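
# Expected archive layout (inferred from the paths above; one pair is read
# from the first qualifying directory under each clone-type folder):
#   Diverse_100K_Dataset/
#     Subject_CloneTypes_Directories/
#       Clone_Type1/ ... Clone_Type2/ ... "Clone_Type3 - ST"/ ... Clone_Type4/
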
# AST Processing
def parse_ast(code):
    try:
        return javalang.parse.parse(code)
    except Exception:
        # javalang raises JavaSyntaxError/LexerError on malformed input
        return None
def build_ast_graph(ast_tree):
    if not ast_tree:
        return None
    G = nx.DiGraph()
    node_id = 0

    def traverse(node, parent=None):
        nonlocal node_id
        current = node_id
        # Label each graph node with its javalang AST node type
        G.add_node(current, type=type(node).__name__)
        if parent is not None:
            G.add_edge(parent, current)
        node_id += 1
        # Children may be nodes, lists of nodes, or non-node attributes
        for child in getattr(node, 'children', []):
            if isinstance(child, javalang.ast.Node):
                traverse(child, current)
            elif isinstance(child, (list, tuple)):
                for item in child:
                    if isinstance(item, javalang.ast.Node):
                        traverse(item, current)

    traverse(ast_tree)
    return G
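
# Illustrative example: for a method body containing `int x = 1;`, javalang
# produces (roughly) MethodDeclaration -> LocalVariableDeclaration ->
# VariableDeclarator -> Literal, and build_ast_graph records one graph node
# per AST node with parent -> child edges.
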
def ast_to_pyg_data(ast_graph):
    if not ast_graph:
        return None
    node_types = list(nx.get_node_attributes(ast_graph, 'type').values())
    # Hash each node type into a fixed-width one-hot vector so the feature
    # dimension always matches the GNN's input size, whatever the AST
    # (a per-graph type vocabulary would give a different width every time)
    x = torch.zeros(len(node_types), NODE_FEATURE_DIM)
    for i, t in enumerate(node_types):
        x[i, zlib.crc32(t.encode()) % NODE_FEATURE_DIM] = 1
    if ast_graph.number_of_edges() > 0:
        edge_index = torch.tensor(list(ast_graph.edges()), dtype=torch.long).t().contiguous()
    else:
        # Single-node ASTs produce no edges; keep the expected (2, 0) shape
        edge_index = torch.zeros((2, 0), dtype=torch.long)
    return Data(x=x.to(DEVICE), edge_index=edge_index.to(DEVICE))
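
# Resulting PyG Data (shapes): x is [num_nodes, NODE_FEATURE_DIM] one-hot
# node features; edge_index is the standard [2, num_edges] COO edge list.
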
# Feature Extraction
def normalize_code(code):
    code = re.sub(r'//.*?$', '', code, flags=re.MULTILINE)   # strip line comments
    code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)   # strip block comments
    return re.sub(r'\s+', ' ', code).strip()                 # collapse whitespace
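
# Example:
#   normalize_code("int x = 1; // count\n/* temp */ int y = 2;")
#   -> "int x = 1; int y = 2;"
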
def get_embedding(code, tokenizer, model):
    try:
        inputs = tokenizer(
            normalize_code(code),
            return_tensors="pt",
            truncation=True,
            max_length=MAX_LENGTH,
            padding='max_length'
        ).to(DEVICE)
        with torch.no_grad():
            # Mean-pool the token embeddings into one vector per input
            return model(**inputs).last_hidden_state.mean(dim=1)
    except Exception:
        return None
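
# Shape walk-through: CodeBERT returns last_hidden_state of shape
# (1, MAX_LENGTH, 768); averaging over the token dimension yields a
# (1, 768) embedding for the whole snippet.
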
# Similarity Calculations
def calculate_similarities(code1, code2, models):
    tokenizer, code_model, rnn_model, gnn_model = models
    # Get embeddings
    emb1 = get_embedding(code1, tokenizer, code_model)
    emb2 = get_embedding(code2, tokenizer, code_model)
    # Parse ASTs
    ast1 = build_ast_graph(parse_ast(code1))
    ast2 = build_ast_graph(parse_ast(code2))
    # Token-level similarity: cosine between the two CodeBERT embeddings
    codebert_sim = F.cosine_similarity(emb1, emb2).item() if emb1 is not None and emb2 is not None else 0
    # Sequence similarity: the RNN scores the embedding pair as a 2-step sequence
    rnn_sim = 0
    if emb1 is not None and emb2 is not None:
        with torch.no_grad():
            rnn_input = torch.stack([emb1.squeeze(), emb2.squeeze()])
            rnn_sim = torch.sigmoid(rnn_model(rnn_input.unsqueeze(0))).item()
    # Structural similarity: cosine between the two GNN graph embeddings
    gnn_sim = 0
    if ast1 and ast2:
        data1 = ast_to_pyg_data(ast1)
        data2 = ast_to_pyg_data(ast2)
        if data1 is not None and data2 is not None:
            with torch.no_grad():
                gnn_sim = F.cosine_similarity(
                    gnn_model(data1).unsqueeze(0),
                    gnn_model(data2).unsqueeze(0)
                ).item()
    return {
        'codebert': codebert_sim,
        'rnn': rnn_sim,
        'gnn': gnn_sim,
        'combined': 0.4 * codebert_sim + 0.3 * rnn_sim + 0.3 * gnn_sim
    }
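
# Worked example of the weighting (illustrative numbers): codebert=0.90,
# rnn=0.80, gnn=0.70 -> combined = 0.4*0.90 + 0.3*0.80 + 0.3*0.70 = 0.81.
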
# UI Components
def main():
    st.title("🔍 Advanced Java Code Clone Detector")
    st.markdown("Detect all clone types (1-4) using hybrid analysis")
    # Load resources
    models = load_models()
    dataset_pairs = load_dataset()
    # Code input
    selected_pair = None
    if dataset_pairs:
        pair_options = {f"{i + 1}: {pair['type']}": pair for i, pair in enumerate(dataset_pairs)}
        selected_option = st.selectbox("Select example pair:", list(pair_options.keys()))
        selected_pair = pair_options[selected_option]
    col1, col2 = st.columns(2)
    with col1:
        code1 = st.text_area("Code 1", height=300, value=selected_pair["code1"] if selected_pair else "")
    with col2:
        code2 = st.text_area("Code 2", height=300, value=selected_pair["code2"] if selected_pair else "")
    # Thresholds
    st.subheader("Detection Thresholds")
    cols = st.columns(3)
    with cols[0]:
        t1 = st.slider("Type 1/2", 0.85, 1.0, 0.95)
    with cols[1]:
        t3 = st.slider("Type 3", 0.7, 0.9, 0.8)
    with cols[2]:
        t4 = st.slider("Type 4", 0.5, 0.8, 0.65)
    # Analysis
    if st.button("Analyze", type="primary") and models[0]:
        if not code1.strip() or not code2.strip():
            st.warning("Please provide both code snippets.")
            return
        with st.spinner("Analyzing..."):
            sims = calculate_similarities(code1, code2, models)
            # Determine clone type: thresholds cascade from strictest
            # (Type 1/2, near-identical) down to Type 4 (semantic clones)
            clone_type = "No Clone"
            if sims['combined'] >= t1:
                clone_type = "Type 1/2 Clone"
            elif sims['combined'] >= t3:
                clone_type = "Type 3 Clone"
            elif sims['combined'] >= t4:
                clone_type = "Type 4 Clone"
            # Display results
            st.subheader("Results")
            cols = st.columns(4)
            cols[0].metric("Combined", f"{sims['combined']:.2f}")
            cols[1].metric("CodeBERT", f"{sims['codebert']:.2f}")
            cols[2].metric("RNN", f"{sims['rnn']:.2f}")
            cols[3].metric("GNN", f"{sims['gnn']:.2f}")
            # st.progress expects a value in [0, 1]; cosine terms can be negative
            st.progress(min(max(sims['combined'], 0.0), 1.0))
            st.metric("Detection Result", clone_type)
            # Show details
            with st.expander("Details"):
                st.json(sims)
                st.code(f"Normalized Code 1:\n{normalize_code(code1)}")
                st.code(f"Normalized Code 2:\n{normalize_code(code2)}")
if __name__ == "__main__":
    main()