Aranwer committed
Commit c031615 · verified · 1 Parent(s): daf072c

Update app.py

Files changed (1): app.py (+234 −214)
app.py CHANGED
@@ -1,235 +1,255 @@
  import os
- import streamlit as st
  import javalang
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
- import re
- import numpy as np
  import networkx as nx
- from transformers import AutoTokenizer, AutoModel
- import warnings
- import pandas as pd
- from collections import defaultdict
-
- # Configuration
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
- warnings.filterwarnings("ignore")
-
- # Constants
- MODEL_NAME = "microsoft/codebert-base"
- MAX_LENGTH = 512
- DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
- # Set up page config
- st.set_page_config(
-     page_title="Java Code Clone Detector",
-     page_icon="🔍",
-     layout="wide"
- )
-
- # Simplified RNN Model (for Hugging Face compatibility)
- class SimpleRNN(nn.Module):
-     def __init__(self, input_size=768, hidden_size=128):
-         super().__init__()
-         self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
-         self.fc = nn.Linear(hidden_size, 1)
-
-     def forward(self, x):
-         out, _ = self.rnn(x)
-         return torch.sigmoid(self.fc(out[:, -1]))
-
- # Model Loading with caching
- @st.cache_resource(show_spinner=False)
- def load_models():
-     try:
-         with st.spinner('Loading models (first run may take a few minutes)...'):
-             # Load CodeBERT
-             tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-             code_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
-
-             # Initialize simple RNN
-             rnn_model = SimpleRNN().to(DEVICE)
-
-             return tokenizer, code_model, rnn_model
-     except Exception as e:
-         st.error(f"Model loading failed: {str(e)}")
-         return None, None, None
-
- # AST Processing (simplified for Hugging Face)
- def parse_ast(code):
      try:
-         return javalang.parse.parse(code)
-     except:
          return None

- def build_simple_ast_features(ast_tree):
-     if not ast_tree: return {}
-
-     features = defaultdict(int)

-     def traverse(node):
-         features[type(node).__name__] += 1
          for child in getattr(node, 'children', []):
-             if isinstance(child, javalang.ast.Node):
-                 traverse(child)
-             elif isinstance(child, (list, tuple)):
                  for item in child:
                      if isinstance(item, javalang.ast.Node):
-                         traverse(item)

-     traverse(ast_tree)
-     return dict(features)
-
- # Feature Extraction
- def normalize_code(code):
-     code = re.sub(r'//.*?$', '', code, flags=re.MULTILINE)
-     code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
-     return re.sub(r'\s+', ' ', code).strip()

- def get_embedding(code, tokenizer, model):
      try:
-         inputs = tokenizer(
-             normalize_code(code),
-             return_tensors="pt",
-             truncation=True,
-             max_length=MAX_LENGTH,
-             padding='max_length'
-         ).to(DEVICE)
-
-         with torch.no_grad():
-             return model(**inputs).last_hidden_state.mean(dim=1)
      except:
-         return None

- # Similarity Calculations (optimized for Hugging Face)
- def calculate_similarities(code1, code2, models):
-     tokenizer, code_model, rnn_model = models
-
-     # Get embeddings
-     emb1 = get_embedding(code1, tokenizer, code_model)
-     emb2 = get_embedding(code2, tokenizer, code_model)
-
-     # Get AST features
-     ast1 = parse_ast(code1)
-     ast2 = parse_ast(code2)
-     ast_features1 = build_simple_ast_features(ast1)
-     ast_features2 = build_simple_ast_features(ast2)
-
-     # Calculate similarities
-     codebert_sim = 0
-     if emb1 is not None and emb2 is not None:
-         codebert_sim = F.cosine_similarity(emb1, emb2).item()
-
-     rnn_sim = 0
-     if emb1 is not None and emb2 is not None:
-         with torch.no_grad():
-             rnn_input = torch.cat([emb1, emb2]).unsqueeze(0)
-             rnn_sim = rnn_model(rnn_input).item()
-
-     # Simple AST similarity (count matching node types)
-     ast_sim = 0
-     if ast_features1 and ast_features2:
-         common_keys = set(ast_features1.keys()) & set(ast_features2.keys())
-         total_keys = set(ast_features1.keys()) | set(ast_features2.keys())
-         ast_sim = len(common_keys) / len(total_keys) if total_keys else 0
-
-     return {
-         'codebert': codebert_sim,
-         'rnn': rnn_sim,
-         'ast': ast_sim,
-         'combined': 0.5*codebert_sim + 0.3*rnn_sim + 0.2*ast_sim
-     }
-
- # Main UI
- def main():
-     st.title("🔍 Java Code Clone Detector (IJaDataset 2.1)")
-     st.markdown("Detect Type 1-4 clones using hybrid analysis")
-
-     # Load models
-     models = load_models()
-     if None in models:
-         st.error("Failed to load required models. Please check the logs.")
-         return
-
-     # Example code pairs
-     example_pairs = {
-         "Type 1 Example": {
-             "code1": "public class Test { public static void main(String[] args) { System.out.println(\"Hello\"); } }",
-             "code2": "public class Test { public static void main(String[] args) { System.out.println(\"Hello\"); } }"
-         },
-         "Type 2 Example": {
-             "code1": "public class Test { public static void main(String[] args) { System.out.println(\"Hello\"); } }",
-             "code2": "public class Example { public static void main(String[] args) { System.out.println(\"Hello\"); } }"
-         },
-         "Type 3 Example": {
-             "code1": "public class Test { public static void main(String[] args) { for(int i=0;i<10;i++) System.out.println(i); } }",
-             "code2": "public class Example { public static void run(String[] params) { for(int j=0;j<10;j++) System.out.println(j); } }"
          }
-     }
-
-     # Code input
-     selected_example = st.selectbox("Select example pair:", list(example_pairs.keys()))
-
-     col1, col2 = st.columns(2)
-     with col1:
-         code1 = st.text_area(
-             "Code 1",
-             height=300,
-             value=example_pairs[selected_example]["code1"]
-         )
-     with col2:
-         code2 = st.text_area(
-             "Code 2",
-             height=300,
-             value=example_pairs[selected_example]["code2"]
-         )
-
-     # Thresholds
-     st.subheader("Detection Thresholds")
-     cols = st.columns(3)
-     with cols[0]:
-         t1 = st.slider("Type 1/2", 0.85, 1.0, 0.95)
-     with cols[1]:
-         t3 = st.slider("Type 3", 0.7, 0.9, 0.8)
-     with cols[2]:
-         t4 = st.slider("Type 4", 0.5, 0.8, 0.65)
-
-     # Analysis button
-     if st.button("Analyze Code", type="primary"):
-         with st.spinner("Analyzing code..."):
-             sims = calculate_similarities(code1, code2, models)
-
-         # Determine clone type
-         clone_type = "No Clone"
-         if sims['combined'] >= t1:
-             clone_type = "Type 1/2 Clone (Exact/Near-Exact)"
-         elif sims['combined'] >= t3:
-             clone_type = "Type 3 Clone (Near-Miss)"
-         elif sims['combined'] >= t4:
-             clone_type = "Type 4 Clone (Semantic)"
-
-         # Display results
-         st.subheader("Results")
-
-         # Metrics
-         cols = st.columns(4)
-         cols[0].metric("Combined", f"{sims['combined']:.2f}")
-         cols[1].metric("CodeBERT", f"{sims['codebert']:.2f}")
-         cols[2].metric("RNN", f"{sims['rnn']:.2f}")
-         cols[3].metric("AST", f"{sims['ast']:.2f}")
-
-         # Progress bar
-         st.progress(sims['combined'])
-
-         # Final result
-         st.metric("Detection Result", clone_type)
-
-         # Show details
-         with st.expander("Advanced Details"):
-             st.json(sims)
-             st.code(f"Normalized Code 1:\n{normalize_code(code1)}")
-             st.code(f"Normalized Code 2:\n{normalize_code(code2)}")

  if __name__ == "__main__":
-     main()
  import os
+ import re
+ import time
+ import random
+ import zipfile
  import javalang
  import torch
  import torch.nn as nn
  import torch.nn.functional as F
+ import torch_geometric
+ from torch_geometric.data import Data, Dataset, DataLoader
+ from sklearn.model_selection import train_test_split
+ from sklearn.metrics import precision_recall_fscore_support
+ from tqdm import tqdm
  import networkx as nx
+
+ # ---- Utility functions ----
+
+ def unzip_dataset(zip_path, extract_to):
+     with zipfile.ZipFile(zip_path, 'r') as zip_ref:
+         zip_ref.extractall(extract_to)
+
+ def normalize_java_code(code):
+     # Remove single-line comments
+     code = re.sub(r'//.*?\n', '', code)
+     # Remove multi-line comments
+     code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
+     # Remove extra spaces and blank lines
+     code = re.sub(r'\s+', ' ', code)
+     return code.strip()
+
+ def safe_parse_java(code):
      try:
+         tokens = list(javalang.tokenizer.tokenize(code))
+         parser = javalang.parser.Parser(tokens)
+         tree = parser.parse()
+         return tree
+     except Exception:
          return None

+ def ast_to_graph(ast):
+     graph = nx.DiGraph()
+
+     def dfs(node, parent_id=None):
+         node_id = len(graph)
+         graph.add_node(node_id, label=type(node).__name__)
+         if parent_id is not None:
+             graph.add_edge(parent_id, node_id)
          for child in getattr(node, 'children', []):
+             if isinstance(child, (list, tuple)):
                  for item in child:
                      if isinstance(item, javalang.ast.Node):
+                         dfs(item, node_id)
+             elif isinstance(child, javalang.ast.Node):
+                 dfs(child, node_id)
+
+     dfs(ast)
+     return graph

+ def tokenize_java_code(code):
      try:
+         tokens = list(javalang.tokenizer.tokenize(code))
+         token_list = [token.value for token in tokens]
+         return token_list
      except:
+         return []

+ # ---- Data Preprocessing ----
+
+ class CloneDataset(Dataset):
+     def __init__(self, root_dir, transform=None):
+         super().__init__()
+         self.data_list = []
+         self.labels = []
+         self.skipped_files = 0
+         self.max_tokens = 5000
+
+         clone_dirs = {
+             "Clone_Type1": 1,
+             "Clone_Type2": 1,
+             "Clone_Type3 - ST": 1,
+             "Clone_Type3 - VST": 1,
+             "Clone_Type3 - MT": 0  # Assuming MT = Not Clone
          }
+
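+         # Walk each clone-type directory, parse every .java file, and keep only
+         # samples that are not too long, parse to a valid AST, and yield tokens.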
+         for clone_type, label in clone_dirs.items():
+             clone_path = os.path.join(root_dir, 'Subject_CloneTypes_Directories', clone_type)
+             for root, _, files in os.walk(clone_path):
+                 for file in files:
+                     if file.endswith(".java"):
+                         file_path = os.path.join(root, file)
+                         with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
+                             code = f.read()
+                         code = normalize_java_code(code)
+                         if len(code.split()) > self.max_tokens:
+                             self.skipped_files += 1
+                             continue
+                         ast = safe_parse_java(code)
+                         if ast is None:
+                             self.skipped_files += 1
+                             continue
+                         graph = ast_to_graph(ast)
+                         tokens = tokenize_java_code(code)
+                         if not tokens:
+                             self.skipped_files += 1
+                             continue
+                         data = {
+                             'graph': graph,
+                             'tokens': tokens,
+                             'label': label
+                         }
+                         self.data_list.append(data)
+
+     def len(self):
+         return len(self.data_list)
+
+     def get(self, idx):
+         data_item = self.data_list[idx]
+         graph = data_item['graph']
+         tokens = data_item['tokens']
+         label = data_item['label']
+
+         # Graph processing
+         edge_index = torch.tensor(list(graph.edges)).t().contiguous()
+
+         node_features = torch.arange(graph.number_of_nodes()).unsqueeze(1).float()
+
+         # Token processing
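+         # Feature hashing: tokens are mapped into a fixed 5000-slot vocabulary instead of an
+         # explicit vocab (collisions possible; Python's str hash varies per process unless
+         # PYTHONHASHSEED is fixed).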
+         token_indices = torch.tensor([hash(t) % 5000 for t in tokens], dtype=torch.long)
+
+         return edge_index, node_features, token_indices, torch.tensor(label, dtype=torch.long)
+
+ # ---- Models ----
+
+ class GNNEncoder(nn.Module):
+     def __init__(self, in_channels=1, hidden_dim=64):
+         super().__init__()
+         self.conv1 = torch_geometric.nn.GCNConv(in_channels, hidden_dim)
+         self.conv2 = torch_geometric.nn.GCNConv(hidden_dim, hidden_dim)
+
+     def forward(self, x, edge_index):
+         x = self.conv1(x, edge_index)
+         x = F.relu(x)
+         x = self.conv2(x, edge_index)
+         x = F.relu(x)
+         return torch.mean(x, dim=0)  # Graph-level embedding
+
+ class RNNEncoder(nn.Module):
+     def __init__(self, vocab_size=5000, embedding_dim=64, hidden_dim=64):
+         super().__init__()
+         self.embedding = nn.Embedding(vocab_size, embedding_dim)
+         self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
+
+     def forward(self, tokens):
+         embeds = self.embedding(tokens)
+         _, (hidden, _) = self.lstm(embeds)
+         return hidden.squeeze(0)
+
+ class HybridClassifier(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.gnn = GNNEncoder()
+         self.rnn = RNNEncoder()
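+         # 128 = 64-dim GNN graph embedding concatenated with the 64-dim LSTM hidden state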
+         self.fc = nn.Linear(128, 2)
+
+     def forward(self, edge_index, node_features, tokens):
+         gnn_out = self.gnn(node_features, edge_index)
+         rnn_out = self.rnn(tokens)
+         combined = torch.cat([gnn_out, rnn_out], dim=-1)
+         out = self.fc(combined)
+         return out
+
+ # ---- Training and Evaluation ----
+
+ def train(model, optimizer, loader, device):
+     model.train()
+     total_loss = 0
+     for edge_index, node_features, tokens, labels in loader:
+         edge_index = edge_index.to(device)
+         node_features = node_features.to(device)
+         tokens = tokens.to(device)
+         labels = labels.to(device)
+
+         optimizer.zero_grad()
+         outputs = model(edge_index, node_features, tokens)
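+         # One sample per step; unsqueeze(0) gives the loss a batch dimension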
+         loss = F.cross_entropy(outputs.unsqueeze(0), labels.unsqueeze(0))
+         loss.backward()
+         optimizer.step()
+         total_loss += loss.item()
+     return total_loss / len(loader)
+
+ def evaluate(model, loader, device):
+     model.eval()
+     preds, labels_all = [], []
+     with torch.no_grad():
+         for edge_index, node_features, tokens, labels in loader:
+             edge_index = edge_index.to(device)
+             node_features = node_features.to(device)
+             tokens = tokens.to(device)
+             labels = labels.to(device)
+
+             outputs = model(edge_index, node_features, tokens)
+             pred = outputs.argmax(dim=-1)
+             preds.append(pred.cpu().numpy())
+             labels_all.append(labels.cpu().numpy())
+
+     preds = np.concatenate(preds)
+     labels_all = np.concatenate(labels_all)
+
+     precision, recall, f1, _ = precision_recall_fscore_support(labels_all, preds, average='binary')
+     return precision, recall, f1
+
+ # ---- Main Execution ----

  if __name__ == "__main__":
+     import numpy as np
+
+     dataset_root = '/content/dataset/archive (1)'
+     unzip_dataset('/content/dataset/archive (1).zip', dataset_root)
+
+     dataset = CloneDataset(dataset_root)
+     print(f"Total valid samples: {dataset.len()}")
+     print(f"Total skipped files: {dataset.skipped_files}")
+
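+     # 80/10/10 train/validation/test split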
+     indices = list(range(dataset.len()))
+     train_idx, temp_idx = train_test_split(indices, test_size=0.2, random_state=42)
+     val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42)
+
+     train_set = torch.utils.data.Subset(dataset, train_idx)
+     val_set = torch.utils.data.Subset(dataset, val_idx)
+     test_set = torch.utils.data.Subset(dataset, test_idx)
+
+     batch_size = 1  # small because of variable graph sizes
+     train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
+     val_loader = DataLoader(val_set, batch_size=batch_size)
+     test_loader = DataLoader(test_set, batch_size=batch_size)
+
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+     model = HybridClassifier().to(device)
+     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+
+     epochs = 5
+
+     start_time = time.time()
+     for epoch in range(epochs):
+         train_loss = train(model, optimizer, train_loader, device)
+         precision, recall, f1 = evaluate(model, val_loader, device)
+         print(f"Epoch {epoch+1}: Loss={train_loss:.4f}, Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}")
+
+     precision, recall, f1 = evaluate(model, test_loader, device)
+     total_time = time.time() - start_time
+
+     print(f"Test Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
+     print(f"Total execution time: {total_time:.2f} seconds")
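
A minimal sketch of how the helpers introduced in this commit could score a single Java file after training. The predict_file wrapper is illustrative only: it assumes app.py's normalize_java_code, safe_parse_java, ast_to_graph, tokenize_java_code and a trained HybridClassifier are in scope, that the file parses to a graph with at least one edge, and that the installed PyTorch accepts unbatched LSTM input (>= 1.11).

import torch

def predict_file(model, java_path, device):
    # Same preprocessing pipeline as CloneDataset: normalize, parse, graph-ify, tokenize.
    with open(java_path, 'r', encoding='utf-8', errors='ignore') as f:
        code = normalize_java_code(f.read())
    ast = safe_parse_java(code)
    if ast is None:
        return None  # unparsable file, mirroring the dataset's skip logic
    graph = ast_to_graph(ast)
    edge_index = torch.tensor(list(graph.edges)).t().contiguous().to(device)
    node_features = torch.arange(graph.number_of_nodes()).unsqueeze(1).float().to(device)
    # Same feature hashing as CloneDataset.get(); tokens stay 1-D (unbatched).
    tokens = torch.tensor([hash(t) % 5000 for t in tokenize_java_code(code)],
                          dtype=torch.long).to(device)
    model.eval()
    with torch.no_grad():
        logits = model(edge_index, node_features, tokens)
    return int(logits.argmax(dim=-1).item())  # 1 = clone class, 0 = non-clone class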