import streamlit as st import javalang import torch import torch.nn as nn import torch.nn.functional as F import re import numpy as np import networkx as nx from transformers import AutoTokenizer, AutoModel from torch_geometric.data import Data from torch_geometric.nn import GCNConv import warnings import pandas as pd import zipfile import os from collections import defaultdict # Set up page config st.set_page_config( page_title="Advanced Java Code Clone Detector (IJaDataset 2.1)", page_icon="🔍", layout="wide" ) # Suppress warnings warnings.filterwarnings("ignore") # Constants MODEL_NAME = "microsoft/codebert-base" MAX_LENGTH = 512 DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') DATASET_PATH = "archive (1).zip" # Update this path if needed # Initialize models with caching @st.cache_resource def load_models(): try: # Load CodeBERT for semantic analysis tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) code_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE) # Initialize RNN model class RNNModel(nn.Module): def __init__(self, input_size, hidden_size, num_layers): super(RNNModel, self).__init__() self.hidden_size = hidden_size self.num_layers = num_layers self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True) self.fc = nn.Linear(hidden_size, 1) def forward(self, x): h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(DEVICE) out, _ = self.rnn(x, h0) out = self.fc(out[:, -1, :]) return out rnn_model = RNNModel(input_size=768, hidden_size=256, num_layers=2).to(DEVICE) # Initialize GNN model class GNNModel(nn.Module): def __init__(self, node_features): super(GNNModel, self).__init__() self.conv1 = GCNConv(node_features, 128) self.conv2 = GCNConv(128, 64) self.fc = nn.Linear(64, 1) def forward(self, data): x, edge_index = data.x, data.edge_index x = F.relu(self.conv1(x, edge_index)) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) x = self.fc(x) return torch.sigmoid(x.mean()) gnn_model = GNNModel(node_features=128).to(DEVICE) return tokenizer, code_model, rnn_model, gnn_model except Exception as e: st.error(f"Failed to load models: {str(e)}") return None, None, None, None @st.cache_resource def load_dataset(): try: # Extract dataset if needed if not os.path.exists("Diverse_100K_Dataset"): with zipfile.ZipFile(DATASET_PATH, 'r') as zip_ref: zip_ref.extractall(".") # Load sample pairs (modify this based on your dataset structure) clone_pairs = [] base_path = "Subject_CloneTypes_Directories" # Load pairs from all clone types for clone_type in ["Clone_Type1", "Clone_Type2", "Clone_Type3 - ST", "Clone_Type4"]: type_path = os.path.join(base_path, clone_type) if os.path.exists(type_path): for root, _, files in os.walk(type_path): if files: # Take first two files as a pair if len(files) >= 2: with open(os.path.join(root, files[0]), 'r', encoding='utf-8') as f1: code1 = f1.read() with open(os.path.join(root, files[1]), 'r', encoding='utf-8') as f2: code2 = f2.read() clone_pairs.append({ "type": clone_type, "code1": code1, "code2": code2 }) break # Just take one pair per type for demo return clone_pairs[:10] # Return first 10 pairs for demo except Exception as e: st.error(f"Error loading dataset: {str(e)}") return [] tokenizer, code_model, rnn_model, gnn_model = load_models() dataset_pairs = load_dataset() # AST Processing Functions def parse_ast(code): try: tokens = javalang.tokenizer.tokenize(code) parser = javalang.parser.Parser(tokens) tree = parser.parse() return tree except Exception as e: st.warning(f"AST parsing error: {str(e)}") return None def build_ast_graph(ast_tree): if not ast_tree: return None G = nx.DiGraph() node_id = 0 node_map = {} def traverse(node, parent_id=None): nonlocal node_id current_id = node_id node_label = str(type(node).__name__) node_map[current_id] = {'type': node_label, 'node': node} G.add_node(current_id, type=node_label) if parent_id is not None: G.add_edge(parent_id, current_id) node_id += 1 for child in node.children: if isinstance(child, javalang.ast.Node): traverse(child, current_id) elif isinstance(child, (list, tuple)): for item in child: if isinstance(item, javalang.ast.Node): traverse(item, current_id) traverse(ast_tree) return G, node_map def ast_to_pyg_data(ast_graph): if not ast_graph: return None # Convert AST to PyTorch Geometric Data format node_features = [] node_types = [] for node in ast_graph.nodes(): node_type = ast_graph.nodes[node]['type'] node_types.append(node_type) # Simple one-hot encoding of node types (in practice, use better encoding) feature = [0] * 50 # Assuming max 50 node types feature[hash(node_type) % 50] = 1 node_features.append(feature) # Convert networkx graph to edge_index format edge_index = list(ast_graph.edges()) if not edge_index: # Add self-loop if no edges edge_index = [(0, 0)] edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous() x = torch.tensor(node_features, dtype=torch.float) return Data(x=x, edge_index=edge_index) # Normalization function def normalize_code(code): try: code = re.sub(r'//.*', '', code) # Remove single-line comments code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL) # Multi-line comments code = re.sub(r'\s+', ' ', code).strip() # Normalize whitespace return code except Exception: return code # Feature extraction functions def get_lexical_features(code): """Extract lexical features (for Type-1 and Type-2 clones)""" normalized = normalize_code(code) tokens = re.findall(r'\b\w+\b', normalized) return { 'token_count': len(tokens), 'unique_tokens': len(set(tokens)), 'avg_token_length': np.mean([len(t) for t in tokens]) if tokens else 0 } def get_syntactic_features(ast_tree): """Extract syntactic features (for Type-3 clones)""" if not ast_tree: return {} # Count different node types in AST node_counts = defaultdict(int) def count_nodes(node): node_counts[type(node).__name__] += 1 for child in node.children: if isinstance(child, javalang.ast.Node): count_nodes(child) elif isinstance(child, (list, tuple)): for item in child: if isinstance(item, javalang.ast.Node): count_nodes(item) count_nodes(ast_tree) return dict(node_counts) def get_semantic_features(code): """Extract semantic features (for Type-4 clones)""" embedding = get_embedding(code) return embedding.cpu().numpy().flatten() if embedding is not None else None # Embedding generation def get_embedding(code): try: code = normalize_code(code) inputs = tokenizer( code, return_tensors="pt", truncation=True, max_length=MAX_LENGTH, padding='max_length' ).to(DEVICE) with torch.no_grad(): outputs = code_model(**inputs) return outputs.last_hidden_state.mean(dim=1) # Pooled embedding except Exception as e: st.error(f"Error processing code: {str(e)}") return None # Clone detection models def rnn_similarity(emb1, emb2): """Calculate similarity using RNN model""" if emb1 is None or emb2 is None: return None # Prepare input for RNN (sequence of embeddings) combined = torch.cat([emb1.unsqueeze(0), emb2.unsqueeze(0)], dim=0) with torch.no_grad(): similarity = rnn_model(combined.permute(1, 0, 2)) return torch.sigmoid(similarity).item() def gnn_similarity(ast1, ast2): """Calculate similarity using GNN model""" if ast1 is None or ast2 is None: return None data1 = ast_to_pyg_data(ast1) data2 = ast_to_pyg_data(ast2) if data1 is None or data2 is None: return None # Move data to device data1 = data1.to(DEVICE) data2 = data2.to(DEVICE) with torch.no_grad(): sim1 = gnn_model(data1) sim2 = gnn_model(data2) return F.cosine_similarity(sim1, sim2).item() def hybrid_similarity(code1, code2): """Combined similarity score using all models""" # Get embeddings emb1 = get_embedding(code1) emb2 = get_embedding(code2) # Parse ASTs ast_tree1 = parse_ast(code1) ast_tree2 = parse_ast(code2) ast_graph1 = build_ast_graph(ast_tree1) if ast_tree1 else None ast_graph2 = build_ast_graph(ast_tree2) if ast_tree2 else None # Calculate individual similarities codebert_sim = F.cosine_similarity(emb1, emb2).item() if emb1 is not None and emb2 is not None else 0 rnn_sim = rnn_similarity(emb1, emb2) if emb1 is not None and emb2 is not None else 0 gnn_sim = gnn_similarity(ast_graph1[0] if ast_graph1 else None, ast_graph2[0] if ast_graph2 else None) or 0 # Combine with weights (can be tuned) weights = { 'codebert': 0.4, 'rnn': 0.3, 'gnn': 0.3 } combined = (weights['codebert'] * codebert_sim + weights['rnn'] * rnn_sim + weights['gnn'] * gnn_sim) return { 'combined': combined, 'codebert': codebert_sim, 'rnn': rnn_sim, 'gnn': gnn_sim } # Comparison function def compare_code(code1, code2): if not code1 or not code2: return None with st.spinner('Analyzing code with multiple techniques...'): # Get lexical features lex1 = get_lexical_features(code1) lex2 = get_lexical_features(code2) # Get AST trees ast_tree1 = parse_ast(code1) ast_tree2 = parse_ast(code2) # Get syntactic features syn1 = get_syntactic_features(ast_tree1) syn2 = get_syntactic_features(ast_tree2) # Get semantic features sem1 = get_semantic_features(code1) sem2 = get_semantic_features(code2) # Calculate hybrid similarity similarities = hybrid_similarity(code1, code2) return { 'similarities': similarities, 'lexical_features': (lex1, lex2), 'syntactic_features': (syn1, syn2), 'ast_trees': (ast_tree1, ast_tree2) } # UI Elements st.title("🔍 Advanced Java Code Clone Detector (IJaDataset 2.1)") st.markdown(""" Detect all types of code clones (Type 1-4) using hybrid approach with: - **CodeBERT** for semantic analysis - **RNN** for sequence modeling - **GNN** for AST structural analysis """) # Dataset selector selected_pair = None if dataset_pairs: pair_options = {f"{i+1}: {pair['type']}": pair for i, pair in enumerate(dataset_pairs)} selected_option = st.selectbox("Select a preloaded example pair:", list(pair_options.keys())) selected_pair = pair_options[selected_option] # Layout col1, col2 = st.columns(2) with col1: code1 = st.text_area( "First Java Code", height=300, value=selected_pair["code1"] if selected_pair else "", help="Enter the first Java code snippet" ) with col2: code2 = st.text_area( "Second Java Code", height=300, value=selected_pair["code2"] if selected_pair else "", help="Enter the second Java code snippet" ) # Threshold sliders st.subheader("Detection Thresholds") col1, col2, col3 = st.columns(3) with col1: threshold_type12 = st.slider( "Type 1/2 Threshold", min_value=0.5, max_value=1.0, value=0.9, step=0.01, help="Threshold for exact/syntactic clones" ) with col2: threshold_type3 = st.slider( "Type 3 Threshold", min_value=0.5, max_value=1.0, value=0.8, step=0.01, help="Threshold for near-miss clones" ) with col3: threshold_type4 = st.slider( "Type 4 Threshold", min_value=0.5, max_value=1.0, value=0.7, step=0.01, help="Threshold for semantic clones" ) # Compare button if st.button("Compare Code", type="primary"): if tokenizer is None or code_model is None or rnn_model is None or gnn_model is None: st.error("Models failed to load. Please check the logs.") else: result = compare_code(code1, code2) if result is not None: similarities = result['similarities'] lex1, lex2 = result['lexical_features'] syn1, syn2 = result['syntactic_features'] ast_tree1, ast_tree2 = result['ast_trees'] # Display results st.subheader("Detection Results") # Determine clone type combined_sim = similarities['combined'] clone_type = "No Clone" if combined_sim >= threshold_type12: clone_type = "Type 1/2 Clone (Exact/Near-Exact)" elif combined_sim >= threshold_type3: clone_type = "Type 3 Clone (Near-Miss)" elif combined_sim >= threshold_type4: clone_type = "Type 4 Clone (Semantic)" # Main metrics col1, col2, col3 = st.columns(3) with col1: st.metric("Combined Similarity", f"{combined_sim:.3f}") with col2: st.metric("Detected Clone Type", clone_type) with col3: st.metric("CodeBERT Similarity", f"{similarities['codebert']:.3f}") # Detailed metrics with st.expander("Detailed Similarity Scores"): cols = st.columns(3) with cols[0]: st.metric("RNN Similarity", f"{similarities['rnn']:.3f}") with cols[1]: st.metric("GNN Similarity", f"{similarities['gnn']:.3f}") with cols[2]: st.metric("Lexical Similarity", f"{sum(lex1[k] == lex2[k] for k in lex1)/max(len(lex1),1):.2f}") # Feature comparison with st.expander("Feature Analysis"): st.subheader("Lexical Features") lex_df = pd.DataFrame([lex1, lex2], index=["Code 1", "Code 2"]) st.dataframe(lex_df) st.subheader("Syntactic Features (AST Node Counts)") syn_df = pd.DataFrame([syn1, syn2], index=["Code 1", "Code 2"]).fillna(0) st.dataframe(syn_df) # AST Visualization if ast_tree1 and ast_tree2: with st.expander("AST Visualization (First 20 nodes)"): st.write("AST visualization would be implemented here with graphviz") # In a real implementation, you would use graphviz to render the ASTs # st.graphviz_chart(ast_to_graphviz(ast_tree1)) # st.graphviz_chart(ast_to_graphviz(ast_tree2)) # Normalized code view with st.expander("Show normalized code"): tab1, tab2 = st.tabs(["First Code", "Second Code"]) with tab1: st.code(normalize_code(code1)) with tab2: st.code(normalize_code(code2)) # Footer st.markdown("---") st.markdown(""" *Dataset Information*: - Using IJaDataset 2.1 from Kaggle - Contains 100K Java files with clone annotations - Clone types: Type-1, Type-2, Type-3, and Type-4 clones *Model Architecture*: - **CodeBERT**: Pre-trained model for semantic analysis - **RNN**: Processes token sequences for sequential patterns - **GNN**: Analyzes AST structure for syntactic patterns - **Hybrid Approach**: Combines all techniques for comprehensive detection """)