import streamlit as st
import javalang
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import numpy as np
import networkx as nx
from transformers import AutoTokenizer, AutoModel
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import warnings
import pandas as pd
import zipfile
import os
from collections import defaultdict
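
# Pipeline overview (as implemented below):
#   1. Normalize the source (strip comments, collapse whitespace) and embed it with CodeBERT.
#   2. Parse the Java source with javalang and flatten the AST into a graph.
#   3. Score a pair with CodeBERT cosine similarity, an RNN head over the pooled embeddings,
#      and a GCN over the AST graph, then combine the three scores with fixed weights
#      in hybrid_similarity().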

# Set up page config
st.set_page_config(
    page_title="Advanced Java Code Clone Detector (IJaDataset 2.1)",
    page_icon="πŸ”",
    layout="wide"
)

# Suppress warnings
warnings.filterwarnings("ignore")

# Constants
MODEL_NAME = "microsoft/codebert-base"
MAX_LENGTH = 512
AST_NODE_FEATURE_DIM = 128  # one-hot bucket size for AST node types; must match the GNN input size
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DATASET_PATH = "archive (1).zip"  # Update this path if needed

# Initialize models with caching
@st.cache_resource
def load_models():
    try:
        # Load CodeBERT for semantic analysis
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        code_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
        
        # Initialize RNN model
        class RNNModel(nn.Module):
            def __init__(self, input_size, hidden_size, num_layers):
                super(RNNModel, self).__init__()
                self.hidden_size = hidden_size
                self.num_layers = num_layers
                self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
                self.fc = nn.Linear(hidden_size, 1)
                
            def forward(self, x):
                h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(DEVICE)
                out, _ = self.rnn(x, h0)
                out = self.fc(out[:, -1, :])
                return out
        
        rnn_model = RNNModel(input_size=768, hidden_size=256, num_layers=2).to(DEVICE)
        
        # Initialize GNN model
        class GNNModel(nn.Module):
            def __init__(self, node_features):
                super(GNNModel, self).__init__()
                self.conv1 = GCNConv(node_features, 128)
                self.conv2 = GCNConv(128, 64)
                self.fc = nn.Linear(64, 1)
                
            def forward(self, data):
                x, edge_index = data.x, data.edge_index
                x = F.relu(self.conv1(x, edge_index))
                x = F.dropout(x, training=self.training)
                x = self.conv2(x, edge_index)
                x = self.fc(x)
                return torch.sigmoid(x.mean())
        
        gnn_model = GNNModel(node_features=AST_NODE_FEATURE_DIM).to(DEVICE)

        # The RNN and GNN heads are randomly initialized in this demo (no fine-tuning),
        # so their scores are illustrative only; eval() disables dropout at inference time.
        rnn_model.eval()
        gnn_model.eval()

        return tokenizer, code_model, rnn_model, gnn_model
    except Exception as e:
        st.error(f"Failed to load models: {str(e)}")
        return None, None, None, None

@st.cache_resource
def load_dataset():
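    """Load one example pair per clone type from the extracted IJaDataset 2.1 archive.

    Assumes the zip extracts to a "Subject_CloneTypes_Directories" folder with one
    subdirectory per clone type; adjust the paths below if your copy of the archive
    is laid out differently.
    """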
    try:
        # Extract dataset if needed
        if not os.path.exists("Diverse_100K_Dataset"):
            with zipfile.ZipFile(DATASET_PATH, 'r') as zip_ref:
                zip_ref.extractall(".")
        
        # Load sample pairs (modify this based on your dataset structure)
        clone_pairs = []
        base_path = "Subject_CloneTypes_Directories"
        
        # Load pairs from all clone types
        for clone_type in ["Clone_Type1", "Clone_Type2", "Clone_Type3 - ST", "Clone_Type4"]:
            type_path = os.path.join(base_path, clone_type)
            if os.path.exists(type_path):
                for root, _, files in os.walk(type_path):
                    if files:
                        # Take first two files as a pair
                        if len(files) >= 2:
                            with open(os.path.join(root, files[0]), 'r', encoding='utf-8', errors='ignore') as f1:
                                code1 = f1.read()
                            with open(os.path.join(root, files[1]), 'r', encoding='utf-8', errors='ignore') as f2:
                                code2 = f2.read()
                            clone_pairs.append({
                                "type": clone_type,
                                "code1": code1,
                                "code2": code2
                            })
                        break  # Just take one pair per type for demo
        
        return clone_pairs[:10]  # Return first 10 pairs for demo
        
    except Exception as e:
        st.error(f"Error loading dataset: {str(e)}")
        return []

tokenizer, code_model, rnn_model, gnn_model = load_models()
dataset_pairs = load_dataset()

# AST Processing Functions
def parse_ast(code):
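    """Parse Java source with javalang, returning the AST or None (with a warning) on failure.

    javalang parses complete compilation units, so isolated statements or bare method
    bodies will typically fail here.
    """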
    try:
        tokens = javalang.tokenizer.tokenize(code)
        parser = javalang.parser.Parser(tokens)
        tree = parser.parse()
        return tree
    except Exception as e:
        st.warning(f"AST parsing error: {str(e)}")
        return None

def build_ast_graph(ast_tree):
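    """Flatten a javalang AST into a networkx DiGraph.

    Returns a (graph, node_map) tuple: graph nodes are sequential integer ids labelled
    with the javalang node class name; node_map maps each id to its type and AST node.
    """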
    if not ast_tree:
        return None
    
    G = nx.DiGraph()
    node_id = 0
    node_map = {}
    
    def traverse(node, parent_id=None):
        nonlocal node_id
        current_id = node_id
        node_label = str(type(node).__name__)
        node_map[current_id] = {'type': node_label, 'node': node}
        G.add_node(current_id, type=node_label)
        
        if parent_id is not None:
            G.add_edge(parent_id, current_id)
        
        node_id += 1
        
        for child in node.children:
            if isinstance(child, javalang.ast.Node):
                traverse(child, current_id)
            elif isinstance(child, (list, tuple)):
                for item in child:
                    if isinstance(item, javalang.ast.Node):
                        traverse(item, current_id)
    
    traverse(ast_tree)
    return G, node_map

def ast_to_pyg_data(ast_graph):
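    """Convert a networkx AST graph into a torch_geometric Data object.

    Node features are hashed one-hot vectors over node-type names; edge_index is the
    [2, num_edges] tensor expected by GCNConv (a self-loop is added if the graph has
    no edges).
    """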
    if not ast_graph:
        return None
    
    # Convert AST to PyTorch Geometric Data format
    node_features = []
    node_types = []
    
    for node in ast_graph.nodes():
        node_type = ast_graph.nodes[node]['type']
        node_types.append(node_type)
        # Hashed one-hot encoding of node types; dimension must match the GNN input size
        # (in practice, use a learned embedding or a fixed vocabulary of node types)
        feature = [0] * AST_NODE_FEATURE_DIM
        feature[hash(node_type) % AST_NODE_FEATURE_DIM] = 1
        node_features.append(feature)
    
    # Convert networkx graph to edge_index format
    edge_index = list(ast_graph.edges())
    if not edge_index:
        # Add self-loop if no edges
        edge_index = [(0, 0)]
    
    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    x = torch.tensor(node_features, dtype=torch.float)
    
    return Data(x=x, edge_index=edge_index)

# Normalization function
def normalize_code(code):
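    """Strip comments and collapse whitespace so purely textual (Type-1) differences are ignored."""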
    try:
        code = re.sub(r'//.*', '', code)  # Remove single-line comments
        code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)  # Multi-line comments
        code = re.sub(r'\s+', ' ', code).strip()  # Normalize whitespace
        return code
    except Exception:
        return code

# Feature extraction functions
def get_lexical_features(code):
    """Extract lexical features (for Type-1 and Type-2 clones)"""
    normalized = normalize_code(code)
    tokens = re.findall(r'\b\w+\b', normalized)
    return {
        'token_count': len(tokens),
        'unique_tokens': len(set(tokens)),
        'avg_token_length': np.mean([len(t) for t in tokens]) if tokens else 0
    }

def get_syntactic_features(ast_tree):
    """Extract syntactic features (for Type-3 clones)"""
    if not ast_tree:
        return {}
    
    # Count different node types in AST
    node_counts = defaultdict(int)
    
    def count_nodes(node):
        node_counts[type(node).__name__] += 1
        for child in node.children:
            if isinstance(child, javalang.ast.Node):
                count_nodes(child)
            elif isinstance(child, (list, tuple)):
                for item in child:
                    if isinstance(item, javalang.ast.Node):
                        count_nodes(item)
    
    count_nodes(ast_tree)
    return dict(node_counts)

def get_semantic_features(code):
    """Extract semantic features (for Type-4 clones)"""
    embedding = get_embedding(code)
    return embedding.cpu().numpy().flatten() if embedding is not None else None

# Embedding generation
def get_embedding(code):
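    """Embed a snippet with CodeBERT: tokenize, run the encoder, and mean-pool over tokens.

    Returns a tensor of shape (1, 768) on DEVICE, or None if tokenization/encoding fails.
    """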
    try:
        code = normalize_code(code)
        inputs = tokenizer(
            code,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_LENGTH,
            padding='max_length'
        ).to(DEVICE)
        
        with torch.no_grad():
            outputs = code_model(**inputs)
        
        return outputs.last_hidden_state.mean(dim=1)  # Pooled embedding
    except Exception as e:
        st.error(f"Error processing code: {str(e)}")
        return None

# Clone detection models
def rnn_similarity(emb1, emb2):
    """Calculate similarity using RNN model"""
    if emb1 is None or emb2 is None:
        return None
    
    # Stack the two pooled embeddings into a length-2 sequence:
    # (2, 1, 768) -> permute(1, 0, 2) -> (batch=1, seq_len=2, input_size=768) for the batch_first RNN
    combined = torch.cat([emb1.unsqueeze(0), emb2.unsqueeze(0)], dim=0)
    with torch.no_grad():
        similarity = rnn_model(combined.permute(1, 0, 2))
    return torch.sigmoid(similarity).item()

def gnn_similarity(ast1, ast2):
    """Calculate similarity using GNN model"""
    if ast1 is None or ast2 is None:
        return None
    
    data1 = ast_to_pyg_data(ast1)
    data2 = ast_to_pyg_data(ast2)
    
    if data1 is None or data2 is None:
        return None
    
    # Move data to device
    data1 = data1.to(DEVICE)
    data2 = data2.to(DEVICE)
    
    with torch.no_grad():
        score1 = gnn_model(data1)
        score2 = gnn_model(data2)
    
    # gnn_model returns a single scalar per graph, so cosine similarity is not defined here;
    # compare the two graph scores directly instead.
    return 1.0 - abs(score1.item() - score2.item())

def hybrid_similarity(code1, code2):
    """Combined similarity score using all models"""
    # Get embeddings
    emb1 = get_embedding(code1)
    emb2 = get_embedding(code2)
    
    # Parse ASTs
    ast_tree1 = parse_ast(code1)
    ast_tree2 = parse_ast(code2)
    
    ast_graph1 = build_ast_graph(ast_tree1) if ast_tree1 else None
    ast_graph2 = build_ast_graph(ast_tree2) if ast_tree2 else None
    
    # Calculate individual similarities
    codebert_sim = F.cosine_similarity(emb1, emb2).item() if emb1 is not None and emb2 is not None else 0
    rnn_sim = rnn_similarity(emb1, emb2) if emb1 is not None and emb2 is not None else 0
    gnn_sim = gnn_similarity(ast_graph1[0] if ast_graph1 else None, 
                            ast_graph2[0] if ast_graph2 else None) or 0
    
    # Combine with weights (can be tuned)
    weights = {
        'codebert': 0.4,
        'rnn': 0.3,
        'gnn': 0.3
    }
    
    combined = (weights['codebert'] * codebert_sim + 
               weights['rnn'] * rnn_sim + 
               weights['gnn'] * gnn_sim)
    
    return {
        'combined': combined,
        'codebert': codebert_sim,
        'rnn': rnn_sim,
        'gnn': gnn_sim
    }
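
# Example usage outside the UI (illustrative; java_src_a / java_src_b stand for Java source strings):
#   scores = hybrid_similarity(java_src_a, java_src_b)
#   scores['combined'], scores['codebert'], scores['rnn'], scores['gnn']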

# Comparison function
def compare_code(code1, code2):
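    """Run the full pipeline (lexical, syntactic, semantic, hybrid) on a pair of snippets
    and bundle the intermediate features and similarity scores for display in the UI."""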
    if not code1 or not code2:
        return None
    
    with st.spinner('Analyzing code with multiple techniques...'):
        # Get lexical features
        lex1 = get_lexical_features(code1)
        lex2 = get_lexical_features(code2)
        
        # Get AST trees
        ast_tree1 = parse_ast(code1)
        ast_tree2 = parse_ast(code2)
        
        # Get syntactic features
        syn1 = get_syntactic_features(ast_tree1)
        syn2 = get_syntactic_features(ast_tree2)
        
        # Get semantic features
        sem1 = get_semantic_features(code1)
        sem2 = get_semantic_features(code2)
        
        # Calculate hybrid similarity
        similarities = hybrid_similarity(code1, code2)
        
        return {
            'similarities': similarities,
            'lexical_features': (lex1, lex2),
            'syntactic_features': (syn1, syn2),
            'ast_trees': (ast_tree1, ast_tree2)
        }

# UI Elements
st.title("πŸ” Advanced Java Code Clone Detector (IJaDataset 2.1)")
st.markdown("""
Detect all clone types (Type 1-4) using a hybrid approach that combines:
- **CodeBERT** for semantic analysis
- **RNN** for sequence modeling
- **GNN** for AST structural analysis
""")

# Dataset selector
selected_pair = None
if dataset_pairs:
    pair_options = {f"{i+1}: {pair['type']}": pair for i, pair in enumerate(dataset_pairs)}
    selected_option = st.selectbox("Select a preloaded example pair:", list(pair_options.keys()))
    selected_pair = pair_options[selected_option]

# Layout
col1, col2 = st.columns(2)

with col1:
    code1 = st.text_area(
        "First Java Code", 
        height=300,
        value=selected_pair["code1"] if selected_pair else "",
        help="Enter the first Java code snippet"
    )

with col2:
    code2 = st.text_area(
        "Second Java Code", 
        height=300,
        value=selected_pair["code2"] if selected_pair else "",
        help="Enter the second Java code snippet"
    )

# Threshold sliders
st.subheader("Detection Thresholds")
col1, col2, col3 = st.columns(3)

with col1:
    threshold_type12 = st.slider(
        "Type 1/2 Threshold",
        min_value=0.5,
        max_value=1.0,
        value=0.9,
        step=0.01,
        help="Threshold for exact/syntactic clones"
    )

with col2:
    threshold_type3 = st.slider(
        "Type 3 Threshold",
        min_value=0.5,
        max_value=1.0,
        value=0.8,
        step=0.01,
        help="Threshold for near-miss clones"
    )

with col3:
    threshold_type4 = st.slider(
        "Type 4 Threshold",
        min_value=0.5,
        max_value=1.0,
        value=0.7,
        step=0.01,
        help="Threshold for semantic clones"
    )

# Compare button
if st.button("Compare Code", type="primary"):
    if tokenizer is None or code_model is None or rnn_model is None or gnn_model is None:
        st.error("Models failed to load. Please check the logs.")
    else:
        result = compare_code(code1, code2)
        
        if result is not None:
            similarities = result['similarities']
            lex1, lex2 = result['lexical_features']
            syn1, syn2 = result['syntactic_features']
            ast_tree1, ast_tree2 = result['ast_trees']
            
            # Display results
            st.subheader("Detection Results")
            
            # Determine clone type
            combined_sim = similarities['combined']
            clone_type = "No Clone"
            
            if combined_sim >= threshold_type12:
                clone_type = "Type 1/2 Clone (Exact/Near-Exact)"
            elif combined_sim >= threshold_type3:
                clone_type = "Type 3 Clone (Near-Miss)"
            elif combined_sim >= threshold_type4:
                clone_type = "Type 4 Clone (Semantic)"
            
            # Main metrics
            col1, col2, col3 = st.columns(3)
            
            with col1:
                st.metric("Combined Similarity", f"{combined_sim:.3f}")
            
            with col2:
                st.metric("Detected Clone Type", clone_type)
            
            with col3:
                st.metric("CodeBERT Similarity", f"{similarities['codebert']:.3f}")
            
            # Detailed metrics
            with st.expander("Detailed Similarity Scores"):
                cols = st.columns(3)
                with cols[0]:
                    st.metric("RNN Similarity", f"{similarities['rnn']:.3f}")
                with cols[1]:
                    st.metric("GNN Similarity", f"{similarities['gnn']:.3f}")
                with cols[2]:
                    # Crude score: fraction of lexical features that match exactly between the snippets
                    st.metric("Lexical Similarity", 
                            f"{sum(lex1[k] == lex2[k] for k in lex1)/max(len(lex1),1):.2f}")
            
            # Feature comparison
            with st.expander("Feature Analysis"):
                st.subheader("Lexical Features")
                lex_df = pd.DataFrame([lex1, lex2], index=["Code 1", "Code 2"])
                st.dataframe(lex_df)
                
                st.subheader("Syntactic Features (AST Node Counts)")
                syn_df = pd.DataFrame([syn1, syn2], index=["Code 1", "Code 2"]).fillna(0)
                st.dataframe(syn_df)
            
            # AST Visualization
            if ast_tree1 and ast_tree2:
                with st.expander("AST Visualization (First 20 nodes)"):
                    st.write("AST visualization would be implemented here with graphviz")
                    # In a real implementation, you would use graphviz to render the ASTs
                    # st.graphviz_chart(ast_to_graphviz(ast_tree1))
                    # st.graphviz_chart(ast_to_graphviz(ast_tree2))
            
            # Normalized code view
            with st.expander("Show normalized code"):
                tab1, tab2 = st.tabs(["First Code", "Second Code"])
                
                with tab1:
                    st.code(normalize_code(code1))
                
                with tab2:
                    st.code(normalize_code(code2))

# Footer
st.markdown("---")
st.markdown("""
*Dataset Information*:
- Using IJaDataset 2.1 from Kaggle
- Contains 100K Java files with clone annotations
- Clone types: Type-1, Type-2, Type-3, and Type-4 clones

*Model Architecture*:
- **CodeBERT**: Pre-trained model for semantic analysis
- **RNN**: Processes token sequences for sequential patterns
- **GNN**: Analyzes AST structure for syntactic patterns
- **Hybrid Approach**: Combines all techniques for comprehensive detection
""")