import streamlit as st
import javalang
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import numpy as np
import networkx as nx
from transformers import AutoTokenizer, AutoModel
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import warnings
import pandas as pd
import zipfile
import os
from collections import defaultdict
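# Assumed runtime dependencies for this app, inferred from the imports above
# (a requirements.txt sketch only; the original file does not pin versions):
#   streamlit
#   javalang
#   torch
#   torch-geometric
#   transformers
#   networkx
#   numpy
#   pandas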
# Set up page config
st.set_page_config(
    page_title="Advanced Java Code Clone Detector (IJaDataset 2.1)",
    page_icon="🔍",
    layout="wide"
)
# Suppress warnings
warnings.filterwarnings("ignore")

# Constants
MODEL_NAME = "microsoft/codebert-base"
MAX_LENGTH = 512
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DATASET_PATH = "archive (1).zip"  # Update this path if needed
# Initialize models with caching (cache the loaded models across Streamlit reruns)
@st.cache_resource
def load_models():
    try:
        # Load CodeBERT for semantic analysis
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        code_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
        code_model.eval()

        # Initialize RNN model
        class RNNModel(nn.Module):
            def __init__(self, input_size, hidden_size, num_layers):
                super(RNNModel, self).__init__()
                self.hidden_size = hidden_size
                self.num_layers = num_layers
                self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
                self.fc = nn.Linear(hidden_size, 1)

            def forward(self, x):
                h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(DEVICE)
                out, _ = self.rnn(x, h0)
                out = self.fc(out[:, -1, :])
                return out

        rnn_model = RNNModel(input_size=768, hidden_size=256, num_layers=2).to(DEVICE)

        # Initialize GNN model
        class GNNModel(nn.Module):
            def __init__(self, node_features):
                super(GNNModel, self).__init__()
                self.conv1 = GCNConv(node_features, 128)
                self.conv2 = GCNConv(128, 64)
                self.fc = nn.Linear(64, 1)

            def forward(self, data):
                x, edge_index = data.x, data.edge_index
                x = F.relu(self.conv1(x, edge_index))
                x = F.dropout(x, training=self.training)
                x = self.conv2(x, edge_index)
                x = self.fc(x)
                return torch.sigmoid(x.mean())

        # node_features must match the 50-dim one-hot features built in ast_to_pyg_data
        gnn_model = GNNModel(node_features=50).to(DEVICE)

        # Note: the RNN and GNN models are randomly initialized (untrained);
        # eval() disables dropout so their outputs are at least deterministic.
        rnn_model.eval()
        gnn_model.eval()

        return tokenizer, code_model, rnn_model, gnn_model
    except Exception as e:
        st.error(f"Failed to load models: {str(e)}")
        return None, None, None, None
def load_dataset():
    try:
        # Extract dataset if needed
        if not os.path.exists("Diverse_100K_Dataset"):
            with zipfile.ZipFile(DATASET_PATH, 'r') as zip_ref:
                zip_ref.extractall(".")

        # Load sample pairs (modify this based on your dataset structure)
        clone_pairs = []
        base_path = "Subject_CloneTypes_Directories"

        # Load pairs from all clone types
        for clone_type in ["Clone_Type1", "Clone_Type2", "Clone_Type3 - ST", "Clone_Type4"]:
            type_path = os.path.join(base_path, clone_type)
            if os.path.exists(type_path):
                for root, _, files in os.walk(type_path):
                    if files:
                        # Take first two files as a pair
                        if len(files) >= 2:
                            with open(os.path.join(root, files[0]), 'r', encoding='utf-8') as f1:
                                code1 = f1.read()
                            with open(os.path.join(root, files[1]), 'r', encoding='utf-8') as f2:
                                code2 = f2.read()
                            clone_pairs.append({
                                "type": clone_type,
                                "code1": code1,
                                "code2": code2
                            })
                            break  # Just take one pair per type for demo

        return clone_pairs[:10]  # Return first 10 pairs for demo
    except Exception as e:
        st.error(f"Error loading dataset: {str(e)}")
        return []
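# Expected on-disk layout after extraction, sketched from the paths used in
# load_dataset() above (the actual IJaDataset 2.1 archive may differ):
#
#   Diverse_100K_Dataset/                  (presence check only)
#   Subject_CloneTypes_Directories/
#       Clone_Type1/<group>/<File1>.java, <File2>.java, ...
#       Clone_Type2/...
#       Clone_Type3 - ST/...
#       Clone_Type4/...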
tokenizer, code_model, rnn_model, gnn_model = load_models()
dataset_pairs = load_dataset()
# AST Processing Functions
def parse_ast(code):
    try:
        tokens = javalang.tokenizer.tokenize(code)
        parser = javalang.parser.Parser(tokens)
        tree = parser.parse()
        return tree
    except Exception as e:
        st.warning(f"AST parsing error: {str(e)}")
        return None
def build_ast_graph(ast_tree):
    if not ast_tree:
        return None

    G = nx.DiGraph()
    node_id = 0
    node_map = {}

    def traverse(node, parent_id=None):
        nonlocal node_id
        current_id = node_id
        node_label = str(type(node).__name__)
        node_map[current_id] = {'type': node_label, 'node': node}
        G.add_node(current_id, type=node_label)
        if parent_id is not None:
            G.add_edge(parent_id, current_id)
        node_id += 1
        for child in node.children:
            if isinstance(child, javalang.ast.Node):
                traverse(child, current_id)
            elif isinstance(child, (list, tuple)):
                for item in child:
                    if isinstance(item, javalang.ast.Node):
                        traverse(item, current_id)

    traverse(ast_tree)
    return G, node_map
def ast_to_pyg_data(ast_graph):
    if not ast_graph:
        return None

    # Convert AST to PyTorch Geometric Data format
    node_features = []
    node_types = []
    for node in ast_graph.nodes():
        node_type = ast_graph.nodes[node]['type']
        node_types.append(node_type)
        # Simple one-hot encoding of node types (in practice, use a better encoding;
        # note that Python's hash() is salted per process, so this mapping is only
        # stable within a single run)
        feature = [0] * 50  # Assuming at most 50 node types
        feature[hash(node_type) % 50] = 1
        node_features.append(feature)

    # Convert networkx graph to edge_index format
    edge_index = list(ast_graph.edges())
    if not edge_index:
        # Add a self-loop if there are no edges
        edge_index = [(0, 0)]
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    x = torch.tensor(node_features, dtype=torch.float)

    return Data(x=x, edge_index=edge_index)
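# Sketch of the graphviz helper referenced (commented out) in the AST visualization
# section near the bottom of this file. It is not part of the original app: it
# assumes the `graphviz` package is installed and limits the rendering to the first
# `max_nodes` AST nodes, matching the "First 20 nodes" expander title.
def ast_to_graphviz(ast_tree, max_nodes=20):
    try:
        import graphviz
    except ImportError:
        return None
    result = build_ast_graph(ast_tree)
    if not result:
        return None
    ast_graph, _ = result
    dot = graphviz.Digraph()
    kept = list(ast_graph.nodes())[:max_nodes]
    kept_set = set(kept)
    # Label each rendered node with its AST node type
    for node in kept:
        dot.node(str(node), ast_graph.nodes[node]['type'])
    # Keep only edges whose endpoints are both rendered
    for src, dst in ast_graph.edges():
        if src in kept_set and dst in kept_set:
            dot.edge(str(src), str(dst))
    return dot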
# Normalization function
def normalize_code(code):
    try:
        code = re.sub(r'//.*', '', code)  # Remove single-line comments
        code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)  # Remove multi-line comments
        code = re.sub(r'\s+', ' ', code).strip()  # Normalize whitespace
        return code
    except Exception:
        return code
# Feature extraction functions
def get_lexical_features(code):
    """Extract lexical features (for Type-1 and Type-2 clones)"""
    normalized = normalize_code(code)
    tokens = re.findall(r'\b\w+\b', normalized)
    return {
        'token_count': len(tokens),
        'unique_tokens': len(set(tokens)),
        'avg_token_length': np.mean([len(t) for t in tokens]) if tokens else 0
    }

def get_syntactic_features(ast_tree):
    """Extract syntactic features (for Type-3 clones)"""
    if not ast_tree:
        return {}

    # Count the different node types in the AST
    node_counts = defaultdict(int)

    def count_nodes(node):
        node_counts[type(node).__name__] += 1
        for child in node.children:
            if isinstance(child, javalang.ast.Node):
                count_nodes(child)
            elif isinstance(child, (list, tuple)):
                for item in child:
                    if isinstance(item, javalang.ast.Node):
                        count_nodes(item)

    count_nodes(ast_tree)
    return dict(node_counts)

def get_semantic_features(code):
    """Extract semantic features (for Type-4 clones)"""
    embedding = get_embedding(code)
    return embedding.cpu().numpy().flatten() if embedding is not None else None
# Embedding generation
def get_embedding(code):
    try:
        code = normalize_code(code)
        inputs = tokenizer(
            code,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_LENGTH,
            padding='max_length'
        ).to(DEVICE)
        with torch.no_grad():
            outputs = code_model(**inputs)
        return outputs.last_hidden_state.mean(dim=1)  # Pooled embedding
    except Exception as e:
        st.error(f"Error processing code: {str(e)}")
        return None
# Clone detection models
def rnn_similarity(emb1, emb2):
    """Calculate similarity using the RNN model"""
    if emb1 is None or emb2 is None:
        return None
    # Prepare input for the RNN: a length-2 sequence of embeddings
    combined = torch.cat([emb1.unsqueeze(0), emb2.unsqueeze(0)], dim=0)
    with torch.no_grad():
        similarity = rnn_model(combined.permute(1, 0, 2))
    return torch.sigmoid(similarity).item()
def gnn_similarity(ast1, ast2):
    """Calculate similarity using the GNN model"""
    if ast1 is None or ast2 is None:
        return None
    data1 = ast_to_pyg_data(ast1)
    data2 = ast_to_pyg_data(ast2)
    if data1 is None or data2 is None:
        return None
    # Move data to device
    data1 = data1.to(DEVICE)
    data2 = data2.to(DEVICE)
    with torch.no_grad():
        sim1 = gnn_model(data1)
        sim2 = gnn_model(data2)
    # The GNN produces one scalar score per graph, so cosine similarity is not
    # meaningful here; compare the two scores directly (1.0 means identical scores).
    return 1.0 - abs(sim1.item() - sim2.item())
def hybrid_similarity(code1, code2):
    """Combined similarity score using all models"""
    # Get embeddings
    emb1 = get_embedding(code1)
    emb2 = get_embedding(code2)

    # Parse ASTs
    ast_tree1 = parse_ast(code1)
    ast_tree2 = parse_ast(code2)
    ast_graph1 = build_ast_graph(ast_tree1) if ast_tree1 else None
    ast_graph2 = build_ast_graph(ast_tree2) if ast_tree2 else None

    # Calculate individual similarities
    codebert_sim = F.cosine_similarity(emb1, emb2).item() if emb1 is not None and emb2 is not None else 0
    rnn_sim = rnn_similarity(emb1, emb2) if emb1 is not None and emb2 is not None else 0
    gnn_sim = gnn_similarity(ast_graph1[0] if ast_graph1 else None,
                             ast_graph2[0] if ast_graph2 else None) or 0

    # Combine with weights (can be tuned)
    weights = {
        'codebert': 0.4,
        'rnn': 0.3,
        'gnn': 0.3
    }
    combined = (weights['codebert'] * codebert_sim +
                weights['rnn'] * rnn_sim +
                weights['gnn'] * gnn_sim)

    return {
        'combined': combined,
        'codebert': codebert_sim,
        'rnn': rnn_sim,
        'gnn': gnn_sim
    }
# Comparison function
def compare_code(code1, code2):
    if not code1 or not code2:
        return None

    with st.spinner('Analyzing code with multiple techniques...'):
        # Get lexical features
        lex1 = get_lexical_features(code1)
        lex2 = get_lexical_features(code2)

        # Get AST trees
        ast_tree1 = parse_ast(code1)
        ast_tree2 = parse_ast(code2)

        # Get syntactic features
        syn1 = get_syntactic_features(ast_tree1)
        syn2 = get_syntactic_features(ast_tree2)

        # Get semantic features
        sem1 = get_semantic_features(code1)
        sem2 = get_semantic_features(code2)

        # Calculate hybrid similarity
        similarities = hybrid_similarity(code1, code2)

        return {
            'similarities': similarities,
            'lexical_features': (lex1, lex2),
            'syntactic_features': (syn1, syn2),
            'ast_trees': (ast_tree1, ast_tree2)
        }
# UI Elements
st.title("🔍 Advanced Java Code Clone Detector (IJaDataset 2.1)")
st.markdown("""
Detect all types of code clones (Type 1-4) using a hybrid approach with:
- **CodeBERT** for semantic analysis
- **RNN** for sequence modeling
- **GNN** for AST structural analysis
""")
# Dataset selector
selected_pair = None
if dataset_pairs:
    pair_options = {f"{i+1}: {pair['type']}": pair for i, pair in enumerate(dataset_pairs)}
    selected_option = st.selectbox("Select a preloaded example pair:", list(pair_options.keys()))
    selected_pair = pair_options[selected_option]
# Layout
col1, col2 = st.columns(2)
with col1:
    code1 = st.text_area(
        "First Java Code",
        height=300,
        value=selected_pair["code1"] if selected_pair else "",
        help="Enter the first Java code snippet"
    )
with col2:
    code2 = st.text_area(
        "Second Java Code",
        height=300,
        value=selected_pair["code2"] if selected_pair else "",
        help="Enter the second Java code snippet"
    )
# Threshold sliders
st.subheader("Detection Thresholds")
col1, col2, col3 = st.columns(3)
with col1:
    threshold_type12 = st.slider(
        "Type 1/2 Threshold",
        min_value=0.5,
        max_value=1.0,
        value=0.9,
        step=0.01,
        help="Threshold for exact/syntactic clones"
    )
with col2:
    threshold_type3 = st.slider(
        "Type 3 Threshold",
        min_value=0.5,
        max_value=1.0,
        value=0.8,
        step=0.01,
        help="Threshold for near-miss clones"
    )
with col3:
    threshold_type4 = st.slider(
        "Type 4 Threshold",
        min_value=0.5,
        max_value=1.0,
        value=0.7,
        step=0.01,
        help="Threshold for semantic clones"
    )
# Compare button
if st.button("Compare Code", type="primary"):
    if tokenizer is None or code_model is None or rnn_model is None or gnn_model is None:
        st.error("Models failed to load. Please check the logs.")
    else:
        result = compare_code(code1, code2)
        if result is not None:
            similarities = result['similarities']
            lex1, lex2 = result['lexical_features']
            syn1, syn2 = result['syntactic_features']
            ast_tree1, ast_tree2 = result['ast_trees']

            # Display results
            st.subheader("Detection Results")

            # Determine clone type
            combined_sim = similarities['combined']
            clone_type = "No Clone"
            if combined_sim >= threshold_type12:
                clone_type = "Type 1/2 Clone (Exact/Near-Exact)"
            elif combined_sim >= threshold_type3:
                clone_type = "Type 3 Clone (Near-Miss)"
            elif combined_sim >= threshold_type4:
                clone_type = "Type 4 Clone (Semantic)"

            # Main metrics
            col1, col2, col3 = st.columns(3)
            with col1:
                st.metric("Combined Similarity", f"{combined_sim:.3f}")
            with col2:
                st.metric("Detected Clone Type", clone_type)
            with col3:
                st.metric("CodeBERT Similarity", f"{similarities['codebert']:.3f}")

            # Detailed metrics
            with st.expander("Detailed Similarity Scores"):
                cols = st.columns(3)
                with cols[0]:
                    st.metric("RNN Similarity", f"{similarities['rnn']:.3f}")
                with cols[1]:
                    st.metric("GNN Similarity", f"{similarities['gnn']:.3f}")
                with cols[2]:
                    st.metric("Lexical Similarity",
                              f"{sum(lex1[k] == lex2[k] for k in lex1) / max(len(lex1), 1):.2f}")

            # Feature comparison
            with st.expander("Feature Analysis"):
                st.subheader("Lexical Features")
                lex_df = pd.DataFrame([lex1, lex2], index=["Code 1", "Code 2"])
                st.dataframe(lex_df)

                st.subheader("Syntactic Features (AST Node Counts)")
                syn_df = pd.DataFrame([syn1, syn2], index=["Code 1", "Code 2"]).fillna(0)
                st.dataframe(syn_df)

            # AST Visualization
            if ast_tree1 and ast_tree2:
                with st.expander("AST Visualization (First 20 nodes)"):
                    st.write("AST visualization would be implemented here with graphviz")
                    # In a real implementation, you would use graphviz to render the ASTs:
                    # st.graphviz_chart(ast_to_graphviz(ast_tree1))
                    # st.graphviz_chart(ast_to_graphviz(ast_tree2))

            # Normalized code view
            with st.expander("Show normalized code"):
                tab1, tab2 = st.tabs(["First Code", "Second Code"])
                with tab1:
                    st.code(normalize_code(code1))
                with tab2:
                    st.code(normalize_code(code2))
# Footer
st.markdown("---")
st.markdown("""
*Dataset Information*:
- Using IJaDataset 2.1 from Kaggle
- Contains 100K Java files with clone annotations
- Clone types: Type-1, Type-2, Type-3, and Type-4 clones

*Model Architecture*:
- **CodeBERT**: Pre-trained model for semantic analysis
- **RNN**: Processes token sequences for sequential patterns
- **GNN**: Analyzes AST structure for syntactic patterns
- **Hybrid Approach**: Combines all techniques for comprehensive detection
""")