Aranwer committed on
Commit 84d4d13 · verified · 1 Parent(s): f2bf3d8

Update app.py

Files changed (1)
  1. app.py +175 -413
app.py CHANGED
@@ -1,3 +1,4 @@
 import streamlit as st
 import javalang
 import torch
@@ -6,241 +7,164 @@ import torch.nn.functional as F
 import re
 import numpy as np
 import networkx as nx
-from transformers import AutoTokenizer, AutoModel
 from torch_geometric.data import Data
 from torch_geometric.nn import GCNConv
 import warnings
 import pandas as pd
 import zipfile
-import os
 from collections import defaultdict

 # Set up page config
 st.set_page_config(
-    page_title="Advanced Java Code Clone Detector (IJaDataset 2.1)",
     page_icon="🔍",
     layout="wide"
 )

-# Suppress warnings
-warnings.filterwarnings("ignore")

-# Constants
-MODEL_NAME = "microsoft/codebert-base"
-MAX_LENGTH = 512
-DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-DATASET_PATH = "archive (1).zip"  # Update this path if needed

-# Initialize models with caching
-@st.cache_resource
 def load_models():
     try:
-        # Load CodeBERT for semantic analysis
-        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-        code_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
-
-        # Initialize RNN model
-        class RNNModel(nn.Module):
-            def __init__(self, input_size, hidden_size, num_layers):
-                super(RNNModel, self).__init__()
-                self.hidden_size = hidden_size
-                self.num_layers = num_layers
-                self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
-                self.fc = nn.Linear(hidden_size, 1)
-
-            def forward(self, x):
-                h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(DEVICE)
-                out, _ = self.rnn(x, h0)
-                out = self.fc(out[:, -1, :])
-                return out
-
-        rnn_model = RNNModel(input_size=768, hidden_size=256, num_layers=2).to(DEVICE)
-
-        # Initialize GNN model
-        class GNNModel(nn.Module):
-            def __init__(self, node_features):
-                super(GNNModel, self).__init__()
-                self.conv1 = GCNConv(node_features, 128)
-                self.conv2 = GCNConv(128, 64)
-                self.fc = nn.Linear(64, 1)
-
-            def forward(self, data):
-                x, edge_index = data.x, data.edge_index
-                x = F.relu(self.conv1(x, edge_index))
-                x = F.dropout(x, training=self.training)
-                x = self.conv2(x, edge_index)
-                x = self.fc(x)
-                return torch.sigmoid(x.mean())
-
-        gnn_model = GNNModel(node_features=128).to(DEVICE)
-
-        return tokenizer, code_model, rnn_model, gnn_model
     except Exception as e:
-        st.error(f"Failed to load models: {str(e)}")
         return None, None, None, None

 @st.cache_resource
 def load_dataset():
     try:
-        # Extract dataset if needed
         if not os.path.exists("Diverse_100K_Dataset"):
             with zipfile.ZipFile(DATASET_PATH, 'r') as zip_ref:
                 zip_ref.extractall(".")

-        # Load sample pairs (modify this based on your dataset structure)
         clone_pairs = []
-        base_path = "Subject_CloneTypes_Directories"

-        # Load pairs from all clone types
         for clone_type in ["Clone_Type1", "Clone_Type2", "Clone_Type3 - ST", "Clone_Type4"]:
             type_path = os.path.join(base_path, clone_type)
             if os.path.exists(type_path):
                 for root, _, files in os.walk(type_path):
-                    if files:
-                        # Take first two files as a pair
-                        if len(files) >= 2:
-                            with open(os.path.join(root, files[0]), 'r', encoding='utf-8') as f1:
-                                code1 = f1.read()
-                            with open(os.path.join(root, files[1]), 'r', encoding='utf-8') as f2:
-                                code2 = f2.read()
                             clone_pairs.append({
                                 "type": clone_type,
-                                "code1": code1,
-                                "code2": code2
                             })
-                            break  # Just take one pair per type for demo
-
-        return clone_pairs[:10]  # Return first 10 pairs for demo

     except Exception as e:
-        st.error(f"Error loading dataset: {str(e)}")
         return []

-tokenizer, code_model, rnn_model, gnn_model = load_models()
-dataset_pairs = load_dataset()
-
-# AST Processing Functions
 def parse_ast(code):
     try:
-        tokens = javalang.tokenizer.tokenize(code)
-        parser = javalang.parser.Parser(tokens)
-        tree = parser.parse()
-        return tree
-    except Exception as e:
-        st.warning(f"AST parsing error: {str(e)}")
         return None

 def build_ast_graph(ast_tree):
-    if not ast_tree:
-        return None

     G = nx.DiGraph()
     node_id = 0
-    node_map = {}

-    def traverse(node, parent_id=None):
         nonlocal node_id
-        current_id = node_id
-        node_label = str(type(node).__name__)
-        node_map[current_id] = {'type': node_label, 'node': node}
-        G.add_node(current_id, type=node_label)
-
-        if parent_id is not None:
-            G.add_edge(parent_id, current_id)
-
         node_id += 1

-        for child in node.children:
             if isinstance(child, javalang.ast.Node):
-                traverse(child, current_id)
             elif isinstance(child, (list, tuple)):
                 for item in child:
                     if isinstance(item, javalang.ast.Node):
-                        traverse(item, current_id)

     traverse(ast_tree)
-    return G, node_map

 def ast_to_pyg_data(ast_graph):
-    if not ast_graph:
-        return None
-
-    # Convert AST to PyTorch Geometric Data format
-    node_features = []
-    node_types = []
-
-    for node in ast_graph.nodes():
-        node_type = ast_graph.nodes[node]['type']
-        node_types.append(node_type)
-        # Simple one-hot encoding of node types (in practice, use better encoding)
-        feature = [0] * 50  # Assuming max 50 node types
-        feature[hash(node_type) % 50] = 1
-        node_features.append(feature)

-    # Convert networkx graph to edge_index format
-    edge_index = list(ast_graph.edges())
-    if not edge_index:
-        # Add self-loop if no edges
-        edge_index = [(0, 0)]

-    edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
-    x = torch.tensor(node_features, dtype=torch.float)

-    return Data(x=x, edge_index=edge_index)

-# Normalization function
 def normalize_code(code):
-    try:
-        code = re.sub(r'//.*', '', code)  # Remove single-line comments
-        code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)  # Multi-line comments
-        code = re.sub(r'\s+', ' ', code).strip()  # Normalize whitespace
-        return code
-    except Exception:
-        return code
-
-# Feature extraction functions
-def get_lexical_features(code):
-    """Extract lexical features (for Type-1 and Type-2 clones)"""
-    normalized = normalize_code(code)
-    tokens = re.findall(r'\b\w+\b', normalized)
-    return {
-        'token_count': len(tokens),
-        'unique_tokens': len(set(tokens)),
-        'avg_token_length': np.mean([len(t) for t in tokens]) if tokens else 0
-    }
-
-def get_syntactic_features(ast_tree):
-    """Extract syntactic features (for Type-3 clones)"""
-    if not ast_tree:
-        return {}
-
-    # Count different node types in AST
-    node_counts = defaultdict(int)
-
-    def count_nodes(node):
-        node_counts[type(node).__name__] += 1
-        for child in node.children:
-            if isinstance(child, javalang.ast.Node):
-                count_nodes(child)
-            elif isinstance(child, (list, tuple)):
-                for item in child:
-                    if isinstance(item, javalang.ast.Node):
-                        count_nodes(item)
-
-    count_nodes(ast_tree)
-    return dict(node_counts)
-
-def get_semantic_features(code):
-    """Extract semantic features (for Type-4 clones)"""
-    embedding = get_embedding(code)
-    return embedding.cpu().numpy().flatten() if embedding is not None else None

-# Embedding generation
-def get_embedding(code):
     try:
-        code = normalize_code(code)
         inputs = tokenizer(
-            code,
             return_tensors="pt",
             truncation=True,
             max_length=MAX_LENGTH,
@@ -248,273 +172,111 @@ def get_embedding(code):
         ).to(DEVICE)

         with torch.no_grad():
-            outputs = code_model(**inputs)
-
-        return outputs.last_hidden_state.mean(dim=1)  # Pooled embedding
-    except Exception as e:
-        st.error(f"Error processing code: {str(e)}")
-        return None
-
-# Clone detection models
-def rnn_similarity(emb1, emb2):
-    """Calculate similarity using RNN model"""
-    if emb1 is None or emb2 is None:
         return None
-
-    # Prepare input for RNN (sequence of embeddings)
-    combined = torch.cat([emb1.unsqueeze(0), emb2.unsqueeze(0)], dim=0)
-    with torch.no_grad():
-        similarity = rnn_model(combined.permute(1, 0, 2))
-    return torch.sigmoid(similarity).item()

-def gnn_similarity(ast1, ast2):
-    """Calculate similarity using GNN model"""
-    if ast1 is None or ast2 is None:
-        return None

-    data1 = ast_to_pyg_data(ast1)
-    data2 = ast_to_pyg_data(ast2)
-
-    if data1 is None or data2 is None:
-        return None
-
-    # Move data to device
-    data1 = data1.to(DEVICE)
-    data2 = data2.to(DEVICE)
-
-    with torch.no_grad():
-        sim1 = gnn_model(data1)
-        sim2 = gnn_model(data2)
-
-    return F.cosine_similarity(sim1, sim2).item()
-
-def hybrid_similarity(code1, code2):
-    """Combined similarity score using all models"""
     # Get embeddings
-    emb1 = get_embedding(code1)
-    emb2 = get_embedding(code2)

     # Parse ASTs
-    ast_tree1 = parse_ast(code1)
-    ast_tree2 = parse_ast(code2)

-    ast_graph1 = build_ast_graph(ast_tree1) if ast_tree1 else None
-    ast_graph2 = build_ast_graph(ast_tree2) if ast_tree2 else None
-
-    # Calculate individual similarities
     codebert_sim = F.cosine_similarity(emb1, emb2).item() if emb1 is not None and emb2 is not None else 0
-    rnn_sim = rnn_similarity(emb1, emb2) if emb1 is not None and emb2 is not None else 0
-    gnn_sim = gnn_similarity(ast_graph1[0] if ast_graph1 else None,
-                             ast_graph2[0] if ast_graph2 else None) or 0

-    # Combine with weights (can be tuned)
-    weights = {
-        'codebert': 0.4,
-        'rnn': 0.3,
-        'gnn': 0.3
-    }

-    combined = (weights['codebert'] * codebert_sim +
-                weights['rnn'] * rnn_sim +
-                weights['gnn'] * gnn_sim)

     return {
-        'combined': combined,
         'codebert': codebert_sim,
         'rnn': rnn_sim,
-        'gnn': gnn_sim
     }

-# Comparison function
-def compare_code(code1, code2):
-    if not code1 or not code2:
-        return None

-    with st.spinner('Analyzing code with multiple techniques...'):
-        # Get lexical features
-        lex1 = get_lexical_features(code1)
-        lex2 = get_lexical_features(code2)
-
-        # Get AST trees
-        ast_tree1 = parse_ast(code1)
-        ast_tree2 = parse_ast(code2)
-
-        # Get syntactic features
-        syn1 = get_syntactic_features(ast_tree1)
-        syn2 = get_syntactic_features(ast_tree2)
-
-        # Get semantic features
-        sem1 = get_semantic_features(code1)
-        sem2 = get_semantic_features(code2)
-
-        # Calculate hybrid similarity
-        similarities = hybrid_similarity(code1, code2)
-
-        return {
-            'similarities': similarities,
-            'lexical_features': (lex1, lex2),
-            'syntactic_features': (syn1, syn2),
-            'ast_trees': (ast_tree1, ast_tree2)
-        }
-
-# UI Elements
-st.title("🔍 Advanced Java Code Clone Detector (IJaDataset 2.1)")
-st.markdown("""
-Detect all types of code clones (Type 1-4) using hybrid approach with:
-- **CodeBERT** for semantic analysis
-- **RNN** for sequence modeling
-- **GNN** for AST structural analysis
-""")
-
-# Dataset selector
-selected_pair = None
-if dataset_pairs:
-    pair_options = {f"{i+1}: {pair['type']}": pair for i, pair in enumerate(dataset_pairs)}
-    selected_option = st.selectbox("Select a preloaded example pair:", list(pair_options.keys()))
-    selected_pair = pair_options[selected_option]
-
-# Layout
-col1, col2 = st.columns(2)
-
-with col1:
-    code1 = st.text_area(
-        "First Java Code",
-        height=300,
-        value=selected_pair["code1"] if selected_pair else "",
-        help="Enter the first Java code snippet"
-    )
-
-with col2:
-    code2 = st.text_area(
-        "Second Java Code",
-        height=300,
-        value=selected_pair["code2"] if selected_pair else "",
-        help="Enter the second Java code snippet"
-    )
-
-# Threshold sliders
-st.subheader("Detection Thresholds")
-col1, col2, col3 = st.columns(3)
-
-with col1:
-    threshold_type12 = st.slider(
-        "Type 1/2 Threshold",
-        min_value=0.5,
-        max_value=1.0,
-        value=0.9,
-        step=0.01,
-        help="Threshold for exact/syntactic clones"
-    )
-
-with col2:
-    threshold_type3 = st.slider(
-        "Type 3 Threshold",
-        min_value=0.5,
-        max_value=1.0,
-        value=0.8,
-        step=0.01,
-        help="Threshold for near-miss clones"
-    )
-
-with col3:
-    threshold_type4 = st.slider(
-        "Type 4 Threshold",
-        min_value=0.5,
-        max_value=1.0,
-        value=0.7,
-        step=0.01,
-        help="Threshold for semantic clones"
-    )
-
-# Compare button
-if st.button("Compare Code", type="primary"):
-    if tokenizer is None or code_model is None or rnn_model is None or gnn_model is None:
-        st.error("Models failed to load. Please check the logs.")
-    else:
-        result = compare_code(code1, code2)
-
-        if result is not None:
-            similarities = result['similarities']
-            lex1, lex2 = result['lexical_features']
-            syn1, syn2 = result['syntactic_features']
-            ast_tree1, ast_tree2 = result['ast_trees']
-
-            # Display results
-            st.subheader("Detection Results")

             # Determine clone type
-            combined_sim = similarities['combined']
             clone_type = "No Clone"

-            if combined_sim >= threshold_type12:
-                clone_type = "Type 1/2 Clone (Exact/Near-Exact)"
-            elif combined_sim >= threshold_type3:
-                clone_type = "Type 3 Clone (Near-Miss)"
-            elif combined_sim >= threshold_type4:
-                clone_type = "Type 4 Clone (Semantic)"
-
-            # Main metrics
-            col1, col2, col3 = st.columns(3)
-
-            with col1:
-                st.metric("Combined Similarity", f"{combined_sim:.3f}")
-
-            with col2:
-                st.metric("Detected Clone Type", clone_type)
-
-            with col3:
-                st.metric("CodeBERT Similarity", f"{similarities['codebert']:.3f}")
-
-            # Detailed metrics
-            with st.expander("Detailed Similarity Scores"):
-                cols = st.columns(3)
-                with cols[0]:
-                    st.metric("RNN Similarity", f"{similarities['rnn']:.3f}")
-                with cols[1]:
-                    st.metric("GNN Similarity", f"{similarities['gnn']:.3f}")
-                with cols[2]:
-                    st.metric("Lexical Similarity",
-                              f"{sum(lex1[k] == lex2[k] for k in lex1)/max(len(lex1),1):.2f}")
-
-            # Feature comparison
-            with st.expander("Feature Analysis"):
-                st.subheader("Lexical Features")
-                lex_df = pd.DataFrame([lex1, lex2], index=["Code 1", "Code 2"])
-                st.dataframe(lex_df)
-
-                st.subheader("Syntactic Features (AST Node Counts)")
-                syn_df = pd.DataFrame([syn1, syn2], index=["Code 1", "Code 2"]).fillna(0)
-                st.dataframe(syn_df)

-            # AST Visualization
-            if ast_tree1 and ast_tree2:
-                with st.expander("AST Visualization (First 20 nodes)"):
-                    st.write("AST visualization would be implemented here with graphviz")
-                    # In a real implementation, you would use graphviz to render the ASTs
-                    # st.graphviz_chart(ast_to_graphviz(ast_tree1))
-                    # st.graphviz_chart(ast_to_graphviz(ast_tree2))

-            # Normalized code view
-            with st.expander("Show normalized code"):
-                tab1, tab2 = st.tabs(["First Code", "Second Code"])
-
-                with tab1:
-                    st.code(normalize_code(code1))
-
-                with tab2:
-                    st.code(normalize_code(code2))
-
-# Footer
-st.markdown("---")
-st.markdown("""
-*Dataset Information*:
-- Using IJaDataset 2.1 from Kaggle
-- Contains 100K Java files with clone annotations
-- Clone types: Type-1, Type-2, Type-3, and Type-4 clones

-*Model Architecture*:
-- **CodeBERT**: Pre-trained model for semantic analysis
-- **RNN**: Processes token sequences for sequential patterns
-- **GNN**: Analyzes AST structure for syntactic patterns
-- **Hybrid Approach**: Combines all techniques for comprehensive detection
-""")
 
+import os
 import streamlit as st
 import javalang
 import torch

 import re
 import numpy as np
 import networkx as nx
+from transformers import AutoTokenizer, AutoModel, AutoConfig
 from torch_geometric.data import Data
 from torch_geometric.nn import GCNConv
 import warnings
 import pandas as pd
 import zipfile
 from collections import defaultdict

+# Configuration
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+warnings.filterwarnings("ignore")
+
+# Constants
+MODEL_NAME = "microsoft/codebert-base"
+MAX_LENGTH = 512
+DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+DATASET_PATH = "ijadataset2-1.zip"
+CACHE_DIR = "./model_cache"
+
 # Set up page config
 st.set_page_config(
+    page_title="Advanced Java Code Clone Detector",
     page_icon="🔍",
     layout="wide"
 )

+# Model Definitions
+class RNNModel(nn.Module):
+    def __init__(self, input_size, hidden_size, num_layers):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
+        self.fc = nn.Linear(hidden_size, 1)
+
+    def forward(self, x):
+        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(DEVICE)
+        out, _ = self.rnn(x, h0)
+        return self.fc(out[:, -1, :])
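+        # (Shape note: calculate_similarities() below feeds this model a
+        # (1, 2, 768) tensor, the two pooled CodeBERT embeddings stacked as
+        # a length-2 sequence, so the final hidden state mixes both snippets
+        # before fc emits one raw clone score.)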

+class GNNModel(nn.Module):
+    def __init__(self, node_features):
+        super().__init__()
+        self.conv1 = GCNConv(node_features, 128)
+        self.conv2 = GCNConv(128, 64)
+        self.fc = nn.Linear(64, 1)
+
+    def forward(self, data):
+        x, edge_index = data.x, data.edge_index
+        x = F.relu(self.conv1(x, edge_index))
+        x = F.dropout(x, training=self.training)
+        x = self.conv2(x, edge_index)
+        return torch.sigmoid(self.fc(x).mean())
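+        # (Readout note: fc scores each node, mean pools the scores into one
+        # scalar, and sigmoid maps it into (0, 1). Neither RNNModel nor
+        # GNNModel is trained or loaded from a checkpoint in this file, so
+        # their scores are essentially arbitrary until weights are provided.)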

+# Model Loading with Cache
+@st.cache_resource(show_spinner=False)
 def load_models():
     try:
+        with st.spinner('Loading models (first run may take a few minutes)...'):
+            config = AutoConfig.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
+            tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
+            model = AutoModel.from_pretrained(MODEL_NAME, config=config, cache_dir=CACHE_DIR).to(DEVICE)
+
+            rnn_model = RNNModel(input_size=768, hidden_size=256, num_layers=2).to(DEVICE)
+            gnn_model = GNNModel(node_features=128).to(DEVICE)
+
+            return tokenizer, model, rnn_model, gnn_model
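+            # (@st.cache_resource keeps a single copy of these models per
+            # server process, so Streamlit reruns reuse them instead of
+            # reloading CodeBERT on every widget interaction.)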
     except Exception as e:
+        st.error(f"Model loading failed: {str(e)}")
         return None, None, None, None

+# Dataset Loading
 @st.cache_resource
 def load_dataset():
     try:
         if not os.path.exists("Diverse_100K_Dataset"):
             with zipfile.ZipFile(DATASET_PATH, 'r') as zip_ref:
                 zip_ref.extractall(".")

         clone_pairs = []
+        base_path = "Diverse_100K_Dataset/Subject_CloneTypes_Directories"

         for clone_type in ["Clone_Type1", "Clone_Type2", "Clone_Type3 - ST", "Clone_Type4"]:
             type_path = os.path.join(base_path, clone_type)
             if os.path.exists(type_path):
                 for root, _, files in os.walk(type_path):
+                    if files and len(files) >= 2:
+                        with open(os.path.join(root, files[0]), 'r', encoding='utf-8') as f1, \
+                             open(os.path.join(root, files[1]), 'r', encoding='utf-8') as f2:
                             clone_pairs.append({
                                 "type": clone_type,
+                                "code1": f1.read(),
+                                "code2": f2.read()
                             })
+                        break
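+                        # (break leaves the os.walk loop after the first
+                        # directory that yields a pair: one example pair per
+                        # clone type, matching the old "one pair per type for
+                        # demo" behavior.)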

+        return clone_pairs[:10]
     except Exception as e:
+        st.error(f"Dataset error: {str(e)}")
         return []

+# AST Processing
 def parse_ast(code):
     try:
+        return javalang.parse.parse(code)
+    except Exception:  # narrowed from a bare except so KeyboardInterrupt/SystemExit still propagate
         return None

 def build_ast_graph(ast_tree):
+    if not ast_tree: return None

     G = nx.DiGraph()
     node_id = 0

+    def traverse(node, parent=None):
         nonlocal node_id
+        current = node_id
+        G.add_node(current, type=type(node).__name__)
+        if parent is not None:
+            G.add_edge(parent, current)
         node_id += 1

+        for child in getattr(node, 'children', []):
             if isinstance(child, javalang.ast.Node):
+                traverse(child, current)
             elif isinstance(child, (list, tuple)):
                 for item in child:
                     if isinstance(item, javalang.ast.Node):
+                        traverse(item, current)

     traverse(ast_tree)
+    return G
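+# (Example: javalang.parse.parse("class A { void m() {} }") yields roughly
+# CompilationUnit -> ClassDeclaration -> MethodDeclaration, so the DiGraph
+# gets one node per AST node, typed by class name, with parent -> child edges.)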

 def ast_to_pyg_data(ast_graph):
+    if not ast_graph: return None

+    node_types = list(nx.get_node_attributes(ast_graph, 'type').values())
+
+    # One-hot features hash-bucketed to a fixed width of 128 so x always
+    # matches GNNModel(node_features=128); a per-graph width of
+    # len(unique_types) would not match and would crash the first GCNConv.
+    x = torch.zeros(len(node_types), 128)
+    for i, t in enumerate(node_types):
+        x[i, hash(t) % 128] = 1
+
+    # Self-loop fallback keeps edge_index well-formed for single-node ASTs.
+    edges = list(ast_graph.edges()) or [(0, 0)]
+    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

+    return Data(x=x.to(DEVICE), edge_index=edge_index.to(DEVICE))

+# Feature Extraction
 def normalize_code(code):
+    code = re.sub(r'//.*?$', '', code, flags=re.MULTILINE)
+    code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
+    return re.sub(r'\s+', ' ', code).strip()
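+# (e.g. "int  x = 1; // init" normalizes to "int x = 1;": comment and
+# whitespace differences, the classic Type-1 clone noise, are removed
+# before embedding.)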

+def get_embedding(code, tokenizer, model):
     try:
         inputs = tokenizer(
+            normalize_code(code),
             return_tensors="pt",
             truncation=True,
             max_length=MAX_LENGTH,

         ).to(DEVICE)

         with torch.no_grad():
+            return model(**inputs).last_hidden_state.mean(dim=1)
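+            # (Mean over the token axis pools the (1, seq_len, 768) hidden
+            # states into one (1, 768) vector; pooling on the [CLS] token
+            # would be the usual alternative.)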
+    except Exception:  # narrowed from a bare except
         return None

+# Similarity Calculations
+def calculate_similarities(code1, code2, models):
+    tokenizer, code_model, rnn_model, gnn_model = models

     # Get embeddings
+    emb1 = get_embedding(code1, tokenizer, code_model)
+    emb2 = get_embedding(code2, tokenizer, code_model)

     # Parse ASTs
+    ast1 = build_ast_graph(parse_ast(code1))
+    ast2 = build_ast_graph(parse_ast(code2))

+    # Calculate similarities
     codebert_sim = F.cosine_similarity(emb1, emb2).item() if emb1 is not None and emb2 is not None else 0

+    rnn_sim = 0
+    if emb1 is not None and emb2 is not None:
+        with torch.no_grad():
+            rnn_input = torch.stack([emb1.squeeze(), emb2.squeeze()])
+            rnn_sim = torch.sigmoid(rnn_model(rnn_input.unsqueeze(0))).item()
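+            # (Shapes: each embedding squeezes to (768,), stack gives
+            # (2, 768), and unsqueeze(0) gives (1, 2, 768): one batch with a
+            # length-2 sequence, matching RNNModel's batch_first layout.)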

+    gnn_sim = 0
+    if ast1 is not None and ast2 is not None:
+        data1 = ast_to_pyg_data(ast1)
+        data2 = ast_to_pyg_data(ast2)
+        if data1 is not None and data2 is not None:
+            with torch.no_grad():
+                # gnn_model returns one scalar per graph, so cosine similarity
+                # is not defined here (F.cosine_similarity with its default
+                # dim=1 raises on 1-D inputs); score agreement 1 - |s1 - s2|
+                # stands in as a similarity in [0, 1].
+                s1 = gnn_model(data1)
+                s2 = gnn_model(data2)
+                gnn_sim = (1 - (s1 - s2).abs()).item()

     return {
         'codebert': codebert_sim,
         'rnn': rnn_sim,
+        'gnn': gnn_sim,
+        'combined': 0.4*codebert_sim + 0.3*rnn_sim + 0.3*gnn_sim
     }
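+# (The 0.4/0.3/0.3 weights carry over the weight dict from the previous
+# version of this file; they are untuned heuristics, not learned values.)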

+# UI Components
+def main():
+    st.title("🔍 Advanced Java Code Clone Detector")
+    st.markdown("Detect all clone types (1-4) using hybrid analysis")

+    # Load resources
+    models = load_models()
+    dataset_pairs = load_dataset()
+
+    # Code input
+    selected_pair = None
+    if dataset_pairs:
+        pair_options = {f"{i+1}: {pair['type']}": pair for i, pair in enumerate(dataset_pairs)}
+        selected_option = st.selectbox("Select example pair:", list(pair_options.keys()))
+        selected_pair = pair_options[selected_option]
+
+    col1, col2 = st.columns(2)
+    with col1:
+        code1 = st.text_area("Code 1", height=300, value=selected_pair["code1"] if selected_pair else "")
+    with col2:
+        code2 = st.text_area("Code 2", height=300, value=selected_pair["code2"] if selected_pair else "")
+
+    # Thresholds
+    st.subheader("Detection Thresholds")
+    cols = st.columns(3)
+    with cols[0]:
+        t1 = st.slider("Type 1/2", 0.85, 1.0, 0.95)
+    with cols[1]:
+        t3 = st.slider("Type 3", 0.7, 0.9, 0.8)
+    with cols[2]:
+        t4 = st.slider("Type 4", 0.5, 0.8, 0.65)
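+    # (st.slider arguments are positional: label, min, max, default. Nothing
+    # enforces t1 > t3 > t4, so the ranges allow overlapping thresholds.)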
+
+    # Analysis
+    if st.button("Analyze", type="primary") and models[0]:
+        with st.spinner("Analyzing..."):
+            sims = calculate_similarities(code1, code2, models)

         # Determine clone type
         clone_type = "No Clone"
+        if sims['combined'] >= t1:
+            clone_type = "Type 1/2 Clone"
+        elif sims['combined'] >= t3:
+            clone_type = "Type 3 Clone"
+        elif sims['combined'] >= t4:
+            clone_type = "Type 4 Clone"

+        # Display results
+        st.subheader("Results")
+        cols = st.columns(4)
+        cols[0].metric("Combined", f"{sims['combined']:.2f}")
+        cols[1].metric("CodeBERT", f"{sims['codebert']:.2f}")
+        cols[2].metric("RNN", f"{sims['rnn']:.2f}")
+        cols[3].metric("GNN", f"{sims['gnn']:.2f}")

+        # Clamped for st.progress, which only accepts values in [0.0, 1.0];
+        # the cosine term can push 'combined' slightly below zero.
+        st.progress(min(max(sims['combined'], 0.0), 1.0))
+        st.metric("Detection Result", clone_type)

+        # Show details
+        with st.expander("Details"):
+            st.json(sims)
+            st.code(f"Normalized Code 1:\n{normalize_code(code1)}")
+            st.code(f"Normalized Code 2:\n{normalize_code(code2)}")

+if __name__ == "__main__":
+    main()