Aranwer committed on
Commit
daf072c
·
verified ·
1 Parent(s): 84d4d13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -139
app.py CHANGED
@@ -7,12 +7,9 @@ import torch.nn.functional as F
7
  import re
8
  import numpy as np
9
  import networkx as nx
10
- from transformers import AutoTokenizer, AutoModel, AutoConfig
11
- from torch_geometric.data import Data
12
- from torch_geometric.nn import GCNConv
13
  import warnings
14
  import pandas as pd
15
- import zipfile
16
  from collections import defaultdict
17
 
18
  # Configuration
@@ -23,137 +20,66 @@ warnings.filterwarnings("ignore")
23
  MODEL_NAME = "microsoft/codebert-base"
24
  MAX_LENGTH = 512
25
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
26
- DATASET_PATH = "ijadataset2-1.zip"
27
- CACHE_DIR = "./model_cache"
28
 
29
  # Set up page config
30
  st.set_page_config(
31
- page_title="Advanced Java Code Clone Detector",
32
  page_icon="πŸ”",
33
  layout="wide"
34
  )
35
 
36
- # Model Definitions
37
- class RNNModel(nn.Module):
38
- def __init__(self, input_size, hidden_size, num_layers):
39
  super().__init__()
40
- self.hidden_size = hidden_size
41
- self.num_layers = num_layers
42
- self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
43
  self.fc = nn.Linear(hidden_size, 1)
44
 
45
  def forward(self, x):
46
- h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(DEVICE)
47
- out, _ = self.rnn(x, h0)
48
- return self.fc(out[:, -1, :])
49
 
50
- class GNNModel(nn.Module):
51
- def __init__(self, node_features):
52
- super().__init__()
53
- self.conv1 = GCNConv(node_features, 128)
54
- self.conv2 = GCNConv(128, 64)
55
- self.fc = nn.Linear(64, 1)
56
-
57
- def forward(self, data):
58
- x, edge_index = data.x, data.edge_index
59
- x = F.relu(self.conv1(x, edge_index))
60
- x = F.dropout(x, training=self.training)
61
- x = self.conv2(x, edge_index)
62
- return torch.sigmoid(self.fc(x).mean())
63
-
64
- # Model Loading with Cache
65
  @st.cache_resource(show_spinner=False)
66
  def load_models():
67
  try:
68
  with st.spinner('Loading models (first run may take a few minutes)...'):
69
- config = AutoConfig.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
70
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
71
- model = AutoModel.from_pretrained(MODEL_NAME, config=config, cache_dir=CACHE_DIR).to(DEVICE)
72
 
73
- rnn_model = RNNModel(input_size=768, hidden_size=256, num_layers=2).to(DEVICE)
74
- gnn_model = GNNModel(node_features=128).to(DEVICE)
75
 
76
- return tokenizer, model, rnn_model, gnn_model
77
  except Exception as e:
78
  st.error(f"Model loading failed: {str(e)}")
79
- return None, None, None, None
80
-
81
- # Dataset Loading
82
- @st.cache_resource
83
- def load_dataset():
84
- try:
85
- if not os.path.exists("Diverse_100K_Dataset"):
86
- with zipfile.ZipFile(DATASET_PATH, 'r') as zip_ref:
87
- zip_ref.extractall(".")
88
-
89
- clone_pairs = []
90
- base_path = "Diverse_100K_Dataset/Subject_CloneTypes_Directories"
91
-
92
- for clone_type in ["Clone_Type1", "Clone_Type2", "Clone_Type3 - ST", "Clone_Type4"]:
93
- type_path = os.path.join(base_path, clone_type)
94
- if os.path.exists(type_path):
95
- for root, _, files in os.walk(type_path):
96
- if files and len(files) >= 2:
97
- with open(os.path.join(root, files[0]), 'r', encoding='utf-8') as f1, \
98
- open(os.path.join(root, files[1]), 'r', encoding='utf-8') as f2:
99
- clone_pairs.append({
100
- "type": clone_type,
101
- "code1": f1.read(),
102
- "code2": f2.read()
103
- })
104
- break
105
-
106
- return clone_pairs[:10]
107
- except Exception as e:
108
- st.error(f"Dataset error: {str(e)}")
109
- return []
110
 
111
- # AST Processing
112
  def parse_ast(code):
113
  try:
114
  return javalang.parse.parse(code)
115
  except:
116
  return None
117
 
118
- def build_ast_graph(ast_tree):
119
- if not ast_tree: return None
120
 
121
- G = nx.DiGraph()
122
- node_id = 0
123
 
124
- def traverse(node, parent=None):
125
- nonlocal node_id
126
- current = node_id
127
- G.add_node(current, type=type(node).__name__)
128
- if parent is not None:
129
- G.add_edge(parent, current)
130
- node_id += 1
131
-
132
  for child in getattr(node, 'children', []):
133
  if isinstance(child, javalang.ast.Node):
134
- traverse(child, current)
135
  elif isinstance(child, (list, tuple)):
136
  for item in child:
137
  if isinstance(item, javalang.ast.Node):
138
- traverse(item, current)
139
 
140
  traverse(ast_tree)
141
- return G
142
-
143
- def ast_to_pyg_data(ast_graph):
144
- if not ast_graph: return None
145
-
146
- node_types = list(nx.get_node_attributes(ast_graph, 'type').values())
147
- unique_types = list(set(node_types))
148
- type_to_idx = {t: i for i, t in enumerate(unique_types)}
149
-
150
- x = torch.zeros(len(node_types), len(unique_types))
151
- for i, t in enumerate(node_types):
152
- x[i, type_to_idx[t]] = 1
153
-
154
- edge_index = torch.tensor(list(ast_graph.edges())).t().contiguous()
155
-
156
- return Data(x=x.to(DEVICE), edge_index=edge_index.to(DEVICE))
157
 
158
  # Feature Extraction
159
  def normalize_code(code):
@@ -176,66 +102,88 @@ def get_embedding(code, tokenizer, model):
176
  except:
177
  return None
178
 
179
- # Similarity Calculations
180
  def calculate_similarities(code1, code2, models):
181
- tokenizer, code_model, rnn_model, gnn_model = models
182
 
183
  # Get embeddings
184
  emb1 = get_embedding(code1, tokenizer, code_model)
185
  emb2 = get_embedding(code2, tokenizer, code_model)
186
 
187
- # Parse ASTs
188
- ast1 = build_ast_graph(parse_ast(code1))
189
- ast2 = build_ast_graph(parse_ast(code2))
 
 
190
 
191
  # Calculate similarities
192
- codebert_sim = F.cosine_similarity(emb1, emb2).item() if emb1 is not None and emb2 is not None else 0
 
 
193
 
194
  rnn_sim = 0
195
  if emb1 is not None and emb2 is not None:
196
  with torch.no_grad():
197
- rnn_input = torch.stack([emb1.squeeze(), emb2.squeeze()])
198
- rnn_sim = torch.sigmoid(rnn_model(rnn_input.unsqueeze(0))).item()
199
 
200
- gnn_sim = 0
201
- if ast1 and ast2:
202
- data1 = ast_to_pyg_data(ast1)
203
- data2 = ast_to_pyg_data(ast2)
204
- if data1 and data2:
205
- with torch.no_grad():
206
- gnn_sim = F.cosine_similarity(
207
- gnn_model(data1).unsqueeze(0),
208
- gnn_model(data2).unsqueeze(0)
209
- ).item()
210
 
211
  return {
212
  'codebert': codebert_sim,
213
  'rnn': rnn_sim,
214
- 'gnn': gnn_sim,
215
- 'combined': 0.4*codebert_sim + 0.3*rnn_sim + 0.3*gnn_sim
216
  }
217
 
218
- # UI Components
219
  def main():
220
- st.title("πŸ” Advanced Java Code Clone Detector")
221
- st.markdown("Detect all clone types (1-4) using hybrid analysis")
222
 
223
- # Load resources
224
  models = load_models()
225
- dataset_pairs = load_dataset()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
  # Code input
228
- selected_pair = None
229
- if dataset_pairs:
230
- pair_options = {f"{i+1}: {pair['type']}": pair for i, pair in enumerate(dataset_pairs)}
231
- selected_option = st.selectbox("Select example pair:", list(pair_options.keys()))
232
- selected_pair = pair_options[selected_option]
233
 
234
  col1, col2 = st.columns(2)
235
  with col1:
236
- code1 = st.text_area("Code 1", height=300, value=selected_pair["code1"] if selected_pair else "")
 
 
 
 
237
  with col2:
238
- code2 = st.text_area("Code 2", height=300, value=selected_pair["code2"] if selected_pair else "")
 
 
 
 
239
 
240
  # Thresholds
241
  st.subheader("Detection Thresholds")
@@ -247,33 +195,38 @@ def main():
247
  with cols[2]:
248
  t4 = st.slider("Type 4", 0.5, 0.8, 0.65)
249
 
250
- # Analysis
251
- if st.button("Analyze", type="primary") and models[0]:
252
- with st.spinner("Analyzing..."):
253
  sims = calculate_similarities(code1, code2, models)
254
 
255
  # Determine clone type
256
  clone_type = "No Clone"
257
  if sims['combined'] >= t1:
258
- clone_type = "Type 1/2 Clone"
259
  elif sims['combined'] >= t3:
260
- clone_type = "Type 3 Clone"
261
  elif sims['combined'] >= t4:
262
- clone_type = "Type 4 Clone"
263
 
264
  # Display results
265
  st.subheader("Results")
 
 
266
  cols = st.columns(4)
267
  cols[0].metric("Combined", f"{sims['combined']:.2f}")
268
  cols[1].metric("CodeBERT", f"{sims['codebert']:.2f}")
269
  cols[2].metric("RNN", f"{sims['rnn']:.2f}")
270
- cols[3].metric("GNN", f"{sims['gnn']:.2f}")
271
 
 
272
  st.progress(sims['combined'])
 
 
273
  st.metric("Detection Result", clone_type)
274
 
275
  # Show details
276
- with st.expander("Details"):
277
  st.json(sims)
278
  st.code(f"Normalized Code 1:\n{normalize_code(code1)}")
279
  st.code(f"Normalized Code 2:\n{normalize_code(code2)}")
 
7
  import re
8
  import numpy as np
9
  import networkx as nx
10
+ from transformers import AutoTokenizer, AutoModel
 
 
11
  import warnings
12
  import pandas as pd
 
13
  from collections import defaultdict
14
 
15
  # Configuration
 
20
  MODEL_NAME = "microsoft/codebert-base"
21
  MAX_LENGTH = 512
22
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
23
 
24
  # Set up page config
25
  st.set_page_config(
26
+ page_title="Java Code Clone Detector",
27
  page_icon="πŸ”",
28
  layout="wide"
29
  )
30
 
31
# Simplified RNN model (for Hugging Face compatibility)
class SimpleRNN(nn.Module):
    """Single-layer RNN followed by a sigmoid-activated linear head.

    Args:
        input_size: Number of features per time step (default 768).
        hidden_size: Width of the recurrent hidden state (default 128).
    """

    def __init__(self, input_size=768, hidden_size=128):
        super().__init__()
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=1)

    def forward(self, x):
        """Score a batch-first sequence tensor.

        The recurrent layer is run over ``x`` without an explicit initial
        hidden state; the output at the last time step is projected to a
        single value and passed through a sigmoid.

        Args:
            x: Input tensor laid out as (batch, sequence, input_size).

        Returns:
            Tensor of shape (batch, 1) with values produced by a sigmoid.
        """
        step_outputs, _final_hidden = self.rnn(x)
        last_step = step_outputs[:, -1]
        logit = self.fc(last_step)
        return torch.sigmoid(logit)
 
41
 
42
# Model loading with caching
@st.cache_resource(show_spinner=False)
def _load_models_cached():
    """Load the tokenizer, CodeBERT encoder and RNN head once per process.

    Any failure is allowed to propagate so that no result is stored in the
    resource cache; the next call will then retry the load instead of
    replaying a cached failure.

    Returns:
        Tuple ``(tokenizer, code_model, rnn_model)`` with both models moved
        to ``DEVICE`` and switched to evaluation mode.
    """
    with st.spinner('Loading models (first run may take a few minutes)...'):
        # Load CodeBERT
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        code_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
        code_model.eval()  # inference only

        # Initialize simple RNN.
        # NOTE(review): no checkpoint is loaded here, so the RNN runs with
        # freshly initialised weights and its score is not a learned
        # similarity. Confirm whether trained weights should be loaded.
        rnn_model = SimpleRNN().to(DEVICE)
        rnn_model.eval()

        return tokenizer, code_model, rnn_model


def load_models():
    """Return ``(tokenizer, code_model, rnn_model)`` or three ``None`` values.

    Successful loads are cached by ``_load_models_cached``. On failure the
    error is reported in the UI and ``(None, None, None)`` is returned, but
    that failure is deliberately NOT cached, so a transient problem (for
    example a network error while downloading the model) does not disable
    the app until the process is restarted.
    """
    try:
        return _load_models_cached()
    except Exception as e:
        st.error(f"Model loading failed: {str(e)}")
        return None, None, None


# Keep the cache-clearing hook that the decorated function used to expose.
load_models.clear = _load_models_cached.clear
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
# AST processing (simplified for Hugging Face)
def parse_ast(code):
    """Parse Java source text with javalang.

    Args:
        code: Java source to parse.

    Returns:
        The value returned by ``javalang.parse.parse(code)``, or ``None``
        when parsing raises.
    """
    try:
        return javalang.parse.parse(code)
    except Exception:
        # Best effort: unparsable input simply yields no AST. Catching
        # Exception (rather than a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        return None
65
 
66
def build_simple_ast_features(ast_tree):
    """Count how many times each node type occurs in an AST.

    The tree is walked with an explicit stack instead of recursion, so very
    deep trees (for example long chains of nested binary expressions) cannot
    exhaust Python's recursion limit. Lists and tuples found among a node's
    children are descended at any nesting depth; other child values are
    ignored.

    Args:
        ast_tree: Root of the tree, or a falsy value when no tree is
            available.

    Returns:
        A plain dict mapping node class name to occurrence count; empty
        when ``ast_tree`` is falsy.
    """
    if not ast_tree:
        return {}

    features = defaultdict(int)
    features[type(ast_tree).__name__] += 1

    # Children are pushed in reverse so that popping from the end visits
    # them left to right, i.e. in pre-order.
    pending = list(getattr(ast_tree, 'children', []))
    pending.reverse()
    while pending:
        item = pending.pop()
        if isinstance(item, javalang.ast.Node):
            features[type(item).__name__] += 1
            pending.extend(reversed(getattr(item, 'children', [])))
        elif isinstance(item, (list, tuple)):
            pending.extend(reversed(item))

    return dict(features)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  # Feature Extraction
85
  def normalize_code(code):
 
102
  except:
103
  return None
104
 
105
# Similarity calculations (optimized for Hugging Face)
def calculate_similarities(code1, code2, models):
    """Compute per-signal and combined similarity scores for two snippets.

    Args:
        code1: First source snippet.
        code2: Second source snippet.
        models: Tuple ``(tokenizer, code_model, rnn_model)``.

    Returns:
        Dict with float entries:
        ``'codebert'`` - cosine similarity of the two embeddings (0.0 when
        either embedding is unavailable);
        ``'rnn'`` - output of ``rnn_model`` on the two embeddings (0.0 when
        either embedding is unavailable);
        ``'ast'`` - Jaccard overlap of the sets of AST node types (0.0 when
        either snippet yields no AST features);
        ``'combined'`` - ``0.5*codebert + 0.3*rnn + 0.2*ast`` clamped to
        [0.0, 1.0].
    """
    tokenizer, code_model, rnn_model = models

    # Get embeddings
    emb1 = get_embedding(code1, tokenizer, code_model)
    emb2 = get_embedding(code2, tokenizer, code_model)

    # Get AST features
    ast_features1 = build_simple_ast_features(parse_ast(code1))
    ast_features2 = build_simple_ast_features(parse_ast(code2))

    # Embedding-based similarities
    codebert_sim = 0.0
    rnn_sim = 0.0
    if emb1 is not None and emb2 is not None:
        codebert_sim = F.cosine_similarity(emb1, emb2).item()
        with torch.no_grad():
            # presumably each embedding is (1, hidden), making this a
            # length-2 sequence in a batch of one - TODO confirm against
            # get_embedding.
            # NOTE(review): the embeddings are fed as an ordered sequence,
            # so swapping code1 and code2 may change this score.
            rnn_input = torch.cat([emb1, emb2]).unsqueeze(0)
            rnn_sim = rnn_model(rnn_input).item()

    # AST similarity: Jaccard overlap of the node-type sets. Only which
    # node types occur is compared; the per-type counts are not used.
    ast_sim = 0.0
    if ast_features1 and ast_features2:
        types1 = set(ast_features1)
        types2 = set(ast_features2)
        ast_sim = len(types1 & types2) / len(types1 | types2)

    # Cosine similarity can be negative (or exceed 1 by rounding), which
    # would push the weighted sum outside the range a progress bar or a
    # threshold comparison expects, so the combined score is clamped.
    combined = 0.5*codebert_sim + 0.3*rnn_sim + 0.2*ast_sim
    combined = min(1.0, max(0.0, combined))

    return {
        'codebert': codebert_sim,
        'rnn': rnn_sim,
        'ast': ast_sim,
        'combined': combined
    }
143
 
144
+ # Main UI
145
  def main():
146
+ st.title("πŸ” Java Code Clone Detector (IJaDataset 2.1)")
147
+ st.markdown("Detect Type 1-4 clones using hybrid analysis")
148
 
149
+ # Load models
150
  models = load_models()
151
+ if None in models:
152
+ st.error("Failed to load required models. Please check the logs.")
153
+ return
154
+
155
+ # Example code pairs
156
+ example_pairs = {
157
+ "Type 1 Example": {
158
+ "code1": "public class Test { public static void main(String[] args) { System.out.println(\"Hello\"); } }",
159
+ "code2": "public class Test { public static void main(String[] args) { System.out.println(\"Hello\"); } }"
160
+ },
161
+ "Type 2 Example": {
162
+ "code1": "public class Test { public static void main(String[] args) { System.out.println(\"Hello\"); } }",
163
+ "code2": "public class Example { public static void main(String[] args) { System.out.println(\"Hello\"); } }"
164
+ },
165
+ "Type 3 Example": {
166
+ "code1": "public class Test { public static void main(String[] args) { for(int i=0;i<10;i++) System.out.println(i); } }",
167
+ "code2": "public class Example { public static void run(String[] params) { for(int j=0;j<10;j++) System.out.println(j); } }"
168
+ }
169
+ }
170
 
171
  # Code input
172
+ selected_example = st.selectbox("Select example pair:", list(example_pairs.keys()))
 
 
 
 
173
 
174
  col1, col2 = st.columns(2)
175
  with col1:
176
+ code1 = st.text_area(
177
+ "Code 1",
178
+ height=300,
179
+ value=example_pairs[selected_example]["code1"]
180
+ )
181
  with col2:
182
+ code2 = st.text_area(
183
+ "Code 2",
184
+ height=300,
185
+ value=example_pairs[selected_example]["code2"]
186
+ )
187
 
188
  # Thresholds
189
  st.subheader("Detection Thresholds")
 
195
  with cols[2]:
196
  t4 = st.slider("Type 4", 0.5, 0.8, 0.65)
197
 
198
+ # Analysis button
199
+ if st.button("Analyze Code", type="primary"):
200
+ with st.spinner("Analyzing code..."):
201
  sims = calculate_similarities(code1, code2, models)
202
 
203
  # Determine clone type
204
  clone_type = "No Clone"
205
  if sims['combined'] >= t1:
206
+ clone_type = "Type 1/2 Clone (Exact/Near-Exact)"
207
  elif sims['combined'] >= t3:
208
+ clone_type = "Type 3 Clone (Near-Miss)"
209
  elif sims['combined'] >= t4:
210
+ clone_type = "Type 4 Clone (Semantic)"
211
 
212
  # Display results
213
  st.subheader("Results")
214
+
215
+ # Metrics
216
  cols = st.columns(4)
217
  cols[0].metric("Combined", f"{sims['combined']:.2f}")
218
  cols[1].metric("CodeBERT", f"{sims['codebert']:.2f}")
219
  cols[2].metric("RNN", f"{sims['rnn']:.2f}")
220
+ cols[3].metric("AST", f"{sims['ast']:.2f}")
221
 
222
+ # Progress bar
223
  st.progress(sims['combined'])
224
+
225
+ # Final result
226
  st.metric("Detection Result", clone_type)
227
 
228
  # Show details
229
+ with st.expander("Advanced Details"):
230
  st.json(sims)
231
  st.code(f"Normalized Code 1:\n{normalize_code(code1)}")
232
  st.code(f"Normalized Code 2:\n{normalize_code(code2)}")