Aranwer committed
Commit c4ca5b1 · verified · 1 Parent(s): 66b43be

Create app.py

Files changed (1)
  1. app.py +520 -0
app.py ADDED
@@ -0,0 +1,520 @@
+ import streamlit as st
+ import javalang
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import re
+ import numpy as np
+ import networkx as nx
+ from transformers import AutoTokenizer, AutoModel
+ from torch_geometric.data import Data
+ from torch_geometric.nn import GCNConv
+ import warnings
+ import pandas as pd
+ import zipfile
+ import os
+ from collections import defaultdict
+
+ # Set up page config
+ st.set_page_config(
+     page_title="Advanced Java Code Clone Detector (IJaDataset 2.1)",
+     page_icon="🔍",
+     layout="wide"
+ )
+
+ # Suppress warnings
+ warnings.filterwarnings("ignore")
+
+ # Constants
+ MODEL_NAME = "microsoft/codebert-base"
+ MAX_LENGTH = 512
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ DATASET_PATH = "archive (1).zip"  # Update this path if needed
+
+ # Initialize models with caching
+ @st.cache_resource
+ def load_models():
+     try:
+         # Load CodeBERT for semantic analysis
+         tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+         code_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
+
+         # Initialize RNN model
+         class RNNModel(nn.Module):
+             def __init__(self, input_size, hidden_size, num_layers):
+                 super(RNNModel, self).__init__()
+                 self.hidden_size = hidden_size
+                 self.num_layers = num_layers
+                 self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
+                 self.fc = nn.Linear(hidden_size, 1)
+
+             def forward(self, x):
+                 h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(DEVICE)
+                 out, _ = self.rnn(x, h0)
+                 out = self.fc(out[:, -1, :])
+                 return out
+
+         rnn_model = RNNModel(input_size=768, hidden_size=256, num_layers=2).to(DEVICE)
+
+         # Initialize GNN model
+         class GNNModel(nn.Module):
+             def __init__(self, node_features):
+                 super(GNNModel, self).__init__()
+                 self.conv1 = GCNConv(node_features, 128)
+                 self.conv2 = GCNConv(128, 64)
+                 self.fc = nn.Linear(64, 1)
+
+             def forward(self, data):
+                 x, edge_index = data.x, data.edge_index
+                 x = F.relu(self.conv1(x, edge_index))
+                 x = F.dropout(x, training=self.training)
+                 x = self.conv2(x, edge_index)
+                 x = self.fc(x)
+                 return torch.sigmoid(x.mean())
+
+         gnn_model = GNNModel(node_features=128).to(DEVICE)
+
+         return tokenizer, code_model, rnn_model, gnn_model
+     except Exception as e:
+         st.error(f"Failed to load models: {str(e)}")
+         return None, None, None, None
+
+ @st.cache_resource
+ def load_dataset():
+     try:
+         # Extract dataset if needed
+         if not os.path.exists("Diverse_100K_Dataset"):
+             with zipfile.ZipFile(DATASET_PATH, 'r') as zip_ref:
+                 zip_ref.extractall(".")
+
+         # Load sample pairs (modify this based on your dataset structure)
+         clone_pairs = []
+         base_path = "Subject_CloneTypes_Directories"
+
+         # Load pairs from all clone types
+         for clone_type in ["Clone_Type1", "Clone_Type2", "Clone_Type3 - ST", "Clone_Type4"]:
+             type_path = os.path.join(base_path, clone_type)
+             if os.path.exists(type_path):
+                 for root, _, files in os.walk(type_path):
+                     if files:
+                         # Take first two files as a pair
+                         if len(files) >= 2:
+                             with open(os.path.join(root, files[0]), 'r', encoding='utf-8') as f1:
+                                 code1 = f1.read()
+                             with open(os.path.join(root, files[1]), 'r', encoding='utf-8') as f2:
+                                 code2 = f2.read()
+                             clone_pairs.append({
+                                 "type": clone_type,
+                                 "code1": code1,
+                                 "code2": code2
+                             })
+                             break  # Just take one pair per type for demo
+
+         return clone_pairs[:10]  # Return first 10 pairs for demo
+
+     except Exception as e:
+         st.error(f"Error loading dataset: {str(e)}")
+         return []
+
+ tokenizer, code_model, rnn_model, gnn_model = load_models()
+ dataset_pairs = load_dataset()
+
+ # AST Processing Functions
+ def parse_ast(code):
+     try:
+         tokens = javalang.tokenizer.tokenize(code)
+         parser = javalang.parser.Parser(tokens)
+         tree = parser.parse()
+         return tree
+     except Exception as e:
+         st.warning(f"AST parsing error: {str(e)}")
+         return None
+
+ def build_ast_graph(ast_tree):
+     if not ast_tree:
+         return None
+
+     G = nx.DiGraph()
+     node_id = 0
+     node_map = {}
+
+     def traverse(node, parent_id=None):
+         nonlocal node_id
+         current_id = node_id
+         node_label = str(type(node).__name__)
+         node_map[current_id] = {'type': node_label, 'node': node}
+         G.add_node(current_id, type=node_label)
+
+         if parent_id is not None:
+             G.add_edge(parent_id, current_id)
+
+         node_id += 1
+
+         for child in node.children:
+             if isinstance(child, javalang.ast.Node):
+                 traverse(child, current_id)
+             elif isinstance(child, (list, tuple)):
+                 for item in child:
+                     if isinstance(item, javalang.ast.Node):
+                         traverse(item, current_id)
+
+     traverse(ast_tree)
+     return G, node_map
+
+ def ast_to_pyg_data(ast_graph):
+     if not ast_graph:
+         return None
+
+     # Convert AST to PyTorch Geometric Data format
+     node_features = []
+     node_types = []
+
+     for node in ast_graph.nodes():
+         node_type = ast_graph.nodes[node]['type']
+         node_types.append(node_type)
+         # Simple hashed one-hot encoding of node types (in practice, use a better encoding)
+         feature = [0] * 128  # Feature size must match GNNModel's node_features (128)
+         feature[hash(node_type) % 128] = 1
+         node_features.append(feature)
+
+     # Convert networkx graph to edge_index format
+     edge_index = list(ast_graph.edges())
+     if not edge_index:
+         # Add self-loop if no edges
+         edge_index = [(0, 0)]
+
+     edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
+     x = torch.tensor(node_features, dtype=torch.float)
+
+     return Data(x=x, edge_index=edge_index)
+
+ # Normalization function
+ def normalize_code(code):
+     try:
+         code = re.sub(r'//.*', '', code)  # Remove single-line comments
+         code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)  # Multi-line comments
+         code = re.sub(r'\s+', ' ', code).strip()  # Normalize whitespace
+         return code
+     except Exception:
+         return code
+
+ # Feature extraction functions
+ def get_lexical_features(code):
+     """Extract lexical features (for Type-1 and Type-2 clones)"""
+     normalized = normalize_code(code)
+     tokens = re.findall(r'\b\w+\b', normalized)
+     return {
+         'token_count': len(tokens),
+         'unique_tokens': len(set(tokens)),
+         'avg_token_length': np.mean([len(t) for t in tokens]) if tokens else 0
+     }
+
+ def get_syntactic_features(ast_tree):
+     """Extract syntactic features (for Type-3 clones)"""
+     if not ast_tree:
+         return {}
+
+     # Count different node types in AST
+     node_counts = defaultdict(int)
+
+     def count_nodes(node):
+         node_counts[type(node).__name__] += 1
+         for child in node.children:
+             if isinstance(child, javalang.ast.Node):
+                 count_nodes(child)
+             elif isinstance(child, (list, tuple)):
+                 for item in child:
+                     if isinstance(item, javalang.ast.Node):
+                         count_nodes(item)
+
+     count_nodes(ast_tree)
+     return dict(node_counts)
+
+ def get_semantic_features(code):
+     """Extract semantic features (for Type-4 clones)"""
+     embedding = get_embedding(code)
+     return embedding.cpu().numpy().flatten() if embedding is not None else None
+
+ # Embedding generation
+ def get_embedding(code):
+     try:
+         code = normalize_code(code)
+         inputs = tokenizer(
+             code,
+             return_tensors="pt",
+             truncation=True,
+             max_length=MAX_LENGTH,
+             padding='max_length'
+         ).to(DEVICE)
+
+         with torch.no_grad():
+             outputs = code_model(**inputs)
+
+         return outputs.last_hidden_state.mean(dim=1)  # Pooled embedding
+     except Exception as e:
+         st.error(f"Error processing code: {str(e)}")
+         return None
+
+ # Clone detection models
+ def rnn_similarity(emb1, emb2):
+     """Calculate similarity using RNN model"""
+     if emb1 is None or emb2 is None:
+         return None
+
+     # Prepare input for RNN (sequence of embeddings)
+     combined = torch.cat([emb1.unsqueeze(0), emb2.unsqueeze(0)], dim=0)
+     with torch.no_grad():
+         similarity = rnn_model(combined.permute(1, 0, 2))
+     return torch.sigmoid(similarity).item()
+
+ def gnn_similarity(ast1, ast2):
+     """Calculate similarity using GNN model"""
+     if ast1 is None or ast2 is None:
+         return None
+
+     data1 = ast_to_pyg_data(ast1)
+     data2 = ast_to_pyg_data(ast2)
+
+     if data1 is None or data2 is None:
+         return None
+
+     # Move data to device
+     data1 = data1.to(DEVICE)
+     data2 = data2.to(DEVICE)
+
+     with torch.no_grad():
+         sim1 = gnn_model(data1)
+         sim2 = gnn_model(data2)
+
+     # gnn_model returns a single scalar score per graph, so cosine similarity is
+     # not defined on these 0-dim tensors; compare the two scores directly instead
+     return 1.0 - abs(sim1.item() - sim2.item())
+
+ def hybrid_similarity(code1, code2):
+     """Combined similarity score using all models"""
+     # Get embeddings
+     emb1 = get_embedding(code1)
+     emb2 = get_embedding(code2)
+
+     # Parse ASTs
+     ast_tree1 = parse_ast(code1)
+     ast_tree2 = parse_ast(code2)
+
+     ast_graph1 = build_ast_graph(ast_tree1) if ast_tree1 else None
+     ast_graph2 = build_ast_graph(ast_tree2) if ast_tree2 else None
+
+     # Calculate individual similarities
+     codebert_sim = F.cosine_similarity(emb1, emb2).item() if emb1 is not None and emb2 is not None else 0
+     rnn_sim = rnn_similarity(emb1, emb2) if emb1 is not None and emb2 is not None else 0
+     gnn_sim = gnn_similarity(ast_graph1[0] if ast_graph1 else None,
+                              ast_graph2[0] if ast_graph2 else None) or 0
+
+     # Combine with weights (can be tuned)
+     weights = {
+         'codebert': 0.4,
+         'rnn': 0.3,
+         'gnn': 0.3
+     }
+
+     combined = (weights['codebert'] * codebert_sim +
+                 weights['rnn'] * rnn_sim +
+                 weights['gnn'] * gnn_sim)
+
+     return {
+         'combined': combined,
+         'codebert': codebert_sim,
+         'rnn': rnn_sim,
+         'gnn': gnn_sim
+     }
+
+ # Comparison function
+ def compare_code(code1, code2):
+     if not code1 or not code2:
+         return None
+
+     with st.spinner('Analyzing code with multiple techniques...'):
+         # Get lexical features
+         lex1 = get_lexical_features(code1)
+         lex2 = get_lexical_features(code2)
+
+         # Get AST trees
+         ast_tree1 = parse_ast(code1)
+         ast_tree2 = parse_ast(code2)
+
+         # Get syntactic features
+         syn1 = get_syntactic_features(ast_tree1)
+         syn2 = get_syntactic_features(ast_tree2)
+
+         # Get semantic features
+         sem1 = get_semantic_features(code1)
+         sem2 = get_semantic_features(code2)
+
+         # Calculate hybrid similarity
+         similarities = hybrid_similarity(code1, code2)
+
+         return {
+             'similarities': similarities,
+             'lexical_features': (lex1, lex2),
+             'syntactic_features': (syn1, syn2),
+             'ast_trees': (ast_tree1, ast_tree2)
+         }
+
+ # UI Elements
+ st.title("🔍 Advanced Java Code Clone Detector (IJaDataset 2.1)")
+ st.markdown("""
+ Detect all types of code clones (Type 1-4) using a hybrid approach with:
+ - **CodeBERT** for semantic analysis
+ - **RNN** for sequence modeling
+ - **GNN** for AST structural analysis
+ """)
+
+ # Dataset selector
+ selected_pair = None
+ if dataset_pairs:
+     pair_options = {f"{i+1}: {pair['type']}": pair for i, pair in enumerate(dataset_pairs)}
+     selected_option = st.selectbox("Select a preloaded example pair:", list(pair_options.keys()))
+     selected_pair = pair_options[selected_option]
+
+ # Layout
+ col1, col2 = st.columns(2)
+
+ with col1:
+     code1 = st.text_area(
+         "First Java Code",
+         height=300,
+         value=selected_pair["code1"] if selected_pair else "",
+         help="Enter the first Java code snippet"
+     )
+
+ with col2:
+     code2 = st.text_area(
+         "Second Java Code",
+         height=300,
+         value=selected_pair["code2"] if selected_pair else "",
+         help="Enter the second Java code snippet"
+     )
+
+ # Threshold sliders
+ st.subheader("Detection Thresholds")
+ col1, col2, col3 = st.columns(3)
+
+ with col1:
+     threshold_type12 = st.slider(
+         "Type 1/2 Threshold",
+         min_value=0.5,
+         max_value=1.0,
+         value=0.9,
+         step=0.01,
+         help="Threshold for exact/syntactic clones"
+     )
+
+ with col2:
+     threshold_type3 = st.slider(
+         "Type 3 Threshold",
+         min_value=0.5,
+         max_value=1.0,
+         value=0.8,
+         step=0.01,
+         help="Threshold for near-miss clones"
+     )
+
+ with col3:
+     threshold_type4 = st.slider(
+         "Type 4 Threshold",
+         min_value=0.5,
+         max_value=1.0,
+         value=0.7,
+         step=0.01,
+         help="Threshold for semantic clones"
+     )
+
+ # Compare button
+ if st.button("Compare Code", type="primary"):
+     if tokenizer is None or code_model is None or rnn_model is None or gnn_model is None:
+         st.error("Models failed to load. Please check the logs.")
+     else:
+         result = compare_code(code1, code2)
+
+         if result is not None:
+             similarities = result['similarities']
+             lex1, lex2 = result['lexical_features']
+             syn1, syn2 = result['syntactic_features']
+             ast_tree1, ast_tree2 = result['ast_trees']
+
+             # Display results
+             st.subheader("Detection Results")
+
+             # Determine clone type
+             combined_sim = similarities['combined']
+             clone_type = "No Clone"
+
+             if combined_sim >= threshold_type12:
+                 clone_type = "Type 1/2 Clone (Exact/Near-Exact)"
+             elif combined_sim >= threshold_type3:
+                 clone_type = "Type 3 Clone (Near-Miss)"
+             elif combined_sim >= threshold_type4:
+                 clone_type = "Type 4 Clone (Semantic)"
+
+             # Main metrics
+             col1, col2, col3 = st.columns(3)
+
+             with col1:
+                 st.metric("Combined Similarity", f"{combined_sim:.3f}")
+
+             with col2:
+                 st.metric("Detected Clone Type", clone_type)
+
+             with col3:
+                 st.metric("CodeBERT Similarity", f"{similarities['codebert']:.3f}")
+
+             # Detailed metrics
+             with st.expander("Detailed Similarity Scores"):
+                 cols = st.columns(3)
+                 with cols[0]:
+                     st.metric("RNN Similarity", f"{similarities['rnn']:.3f}")
+                 with cols[1]:
+                     st.metric("GNN Similarity", f"{similarities['gnn']:.3f}")
+                 with cols[2]:
+                     st.metric("Lexical Similarity",
+                               f"{sum(lex1[k] == lex2[k] for k in lex1)/max(len(lex1),1):.2f}")
+
+             # Feature comparison
+             with st.expander("Feature Analysis"):
+                 st.subheader("Lexical Features")
+                 lex_df = pd.DataFrame([lex1, lex2], index=["Code 1", "Code 2"])
+                 st.dataframe(lex_df)
+
+                 st.subheader("Syntactic Features (AST Node Counts)")
+                 syn_df = pd.DataFrame([syn1, syn2], index=["Code 1", "Code 2"]).fillna(0)
+                 st.dataframe(syn_df)
+
+             # AST Visualization
+             if ast_tree1 and ast_tree2:
+                 with st.expander("AST Visualization (First 20 nodes)"):
+                     st.write("AST visualization would be implemented here with graphviz")
+                     # In a real implementation, you would use graphviz to render the ASTs
+                     # st.graphviz_chart(ast_to_graphviz(ast_tree1))
+                     # st.graphviz_chart(ast_to_graphviz(ast_tree2))
+
+             # Normalized code view
+             with st.expander("Show normalized code"):
+                 tab1, tab2 = st.tabs(["First Code", "Second Code"])
+
+                 with tab1:
+                     st.code(normalize_code(code1))
+
+                 with tab2:
+                     st.code(normalize_code(code2))
+
+ # Footer
+ st.markdown("---")
+ st.markdown("""
+ *Dataset Information*:
+ - Using IJaDataset 2.1 from Kaggle
+ - Contains 100K Java files with clone annotations
+ - Clone types: Type-1, Type-2, Type-3, and Type-4 clones
+
+ *Model Architecture*:
+ - **CodeBERT**: Pre-trained model for semantic analysis
+ - **RNN**: Processes token sequences for sequential patterns
+ - **GNN**: Analyzes AST structure for syntactic patterns
+ - **Hybrid Approach**: Combines all techniques for comprehensive detection
+ """)