rahideer committed on
Commit
75f22f0
·
verified ·
1 Parent(s): b264710

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -164
app.py CHANGED
@@ -2,7 +2,6 @@ import javalang
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
5
- from torch_geometric.data import Data
6
  import re
7
  import gradio as gr
8
  from transformers import AutoTokenizer, AutoModel
@@ -10,186 +9,61 @@ from pathlib import Path
10
 
11
  # Configuration
12
  MAX_FILE_SIZE = 5000
13
- MAX_AST_DEPTH = 50
14
  EMBEDDING_DIM = 128
15
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
16
 
17
- # AST Encoder
18
- class ASTNodeEncoder:
19
- def __init__(self):
20
- self.node_types = set()
21
- self.type_to_idx = {}
22
-
23
- def fit(self, ast_nodes):
24
- self.node_types.update(ast_nodes)
25
- self.type_to_idx = {t: i for i, t in enumerate(sorted(self.node_types))}
26
-
27
- def encode(self, node_type):
28
- if node_type not in self.type_to_idx:
29
- return torch.zeros(EMBEDDING_DIM)
30
- idx = self.type_to_idx[node_type]
31
- embedding = torch.zeros(EMBEDDING_DIM)
32
- embedding[idx % EMBEDDING_DIM] = 1
33
- embedding += torch.randn(EMBEDDING_DIM) * 0.1
34
- return embedding
35
-
36
- # Code Normalization
37
- def normalize_java_code(code):
38
- code = re.sub(r'//.*', '', code)
39
- code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
40
- code = re.sub(r'"[^"]*"', '"<STRING>"', code)
41
- code = re.sub(r"'[^']*'", "'<CHAR>'", code)
42
- return ' '.join(code.split())
43
-
44
- # AST Processing
45
- def extract_ast_paths(node, encoder, current_path=None, paths=None, depth=0):
46
- if current_path is None:
47
- current_path = []
48
- if paths is None:
49
- paths = []
50
-
51
- if depth > MAX_AST_DEPTH:
52
- return paths
53
-
54
- node_type = str(type(node).__name__)
55
- node_embedding = encoder.encode(node_type)
56
- current_path.append(node_embedding)
57
-
58
- if not hasattr(node, 'children') or depth == MAX_AST_DEPTH:
59
- paths.append(torch.stack(current_path))
60
- current_path.pop()
61
- return paths
62
-
63
- for child in node.children:
64
- if isinstance(child, (javalang.ast.Node, list, tuple)):
65
- if isinstance(child, (list, tuple)):
66
- for c in child:
67
- if isinstance(c, javalang.ast.Node):
68
- extract_ast_paths(c, encoder, current_path, paths, depth+1)
69
- elif isinstance(child, javalang.ast.Node):
70
- extract_ast_paths(child, encoder, current_path, paths, depth+1)
71
-
72
- current_path.pop()
73
- return paths
74
-
75
- def ast_to_graph_data(ast, encoder):
76
- paths = extract_ast_paths(ast, encoder)
77
- if not paths:
78
- return None
79
-
80
- edge_index = []
81
- node_features = []
82
- node_counter = 0
83
- node_mapping = {}
84
-
85
- for path in paths:
86
- for i in range(len(path) - 1):
87
- for j in [i, i+1]:
88
- node_key = tuple(path[j].tolist())
89
- if node_key not in node_mapping:
90
- node_mapping[node_key] = node_counter
91
- node_features.append(path[j])
92
- node_counter += 1
93
-
94
- src = node_mapping[tuple(path[i].tolist())]
95
- dst = node_mapping[tuple(path[i+1].tolist())]
96
- edge_index.append([src, dst])
97
-
98
- if not edge_index:
99
- return None
100
-
101
- node_features = torch.stack(node_features)
102
- edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
103
- return Data(x=node_features, edge_index=edge_index)
104
-
105
- # Model Architecture
106
- class ASTGNN(nn.Module):
107
- def __init__(self, input_dim, hidden_dim):
108
- super().__init__()
109
- self.conv1 = nn.Sequential(
110
- nn.Linear(input_dim, hidden_dim),
111
- nn.ReLU(),
112
- nn.Linear(hidden_dim, hidden_dim)
113
- )
114
- self.conv2 = nn.Sequential(
115
- nn.Linear(hidden_dim, hidden_dim),
116
- nn.ReLU(),
117
- nn.Linear(hidden_dim, hidden_dim)
118
- )
119
- self.pool = nn.AdaptiveMaxPool1d(1)
120
-
121
- def forward(self, data):
122
- x, edge_index = data.x.to(DEVICE), data.edge_index.to(DEVICE)
123
- x = self.conv1(x)
124
- x = self.conv2(x)
125
- x = x.t().unsqueeze(0)
126
- x = self.pool(x)
127
- return x.squeeze(0).squeeze(-1)
128
 
129
- class HybridCloneDetector(nn.Module):
130
- def __init__(self, ast_input_dim, hidden_dim):
 
131
  super().__init__()
132
- self.ast_gnn = ASTGNN(ast_input_dim, hidden_dim)
133
  self.classifier = nn.Sequential(
134
  nn.Linear(hidden_dim * 2, hidden_dim),
135
  nn.ReLU(),
136
  nn.Linear(hidden_dim, 2))
137
 
138
- def forward(self, ast_data, code_embedding):
139
- ast_embed = self.ast_gnn(ast_data)
140
- combined = torch.cat([ast_embed, code_embedding.squeeze(0)], dim=0)
141
- return self.classifier(combined.unsqueeze(0))
142
-
143
- # Load Models
144
- def load_models():
145
- ast_encoder = ASTNodeEncoder()
146
- ast_encoder.fit(['MethodDeclaration', 'VariableDeclaration', 'IfStatement'])
147
-
148
- tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
149
- code_model = AutoModel.from_pretrained("microsoft/codebert-base").to(DEVICE)
150
-
151
- model = HybridCloneDetector(EMBEDDING_DIM, EMBEDDING_DIM).to(DEVICE)
152
- if Path('model.pth').exists():
153
- model.load_state_dict(torch.load('model.pth', map_location=DEVICE))
154
-
155
- return ast_encoder, tokenizer, code_model, model
156
 
157
- ast_encoder, tokenizer, code_model, model = load_models()
158
 
159
- # Prediction Function
160
- def predict_clone(code1, code2):
161
  try:
162
- # Process first code
163
- norm_code1 = normalize_java_code(code1)
164
- tokens1 = list(javalang.tokenizer.tokenize(norm_code1))
165
- parser = javalang.parser.Parser(tokens1)
166
- ast1 = parser.parse()
167
- ast_data1 = ast_to_graph_data(ast1, ast_encoder)
168
 
169
- inputs1 = tokenizer(norm_code1, return_tensors="pt", truncation=True).to(DEVICE)
 
170
  with torch.no_grad():
171
- code_embed1 = code_model(**inputs1).last_hidden_state.mean(dim=1)
172
-
173
- # Process second code
174
- norm_code2 = normalize_java_code(code2)
175
- tokens2 = list(javalang.tokenizer.tokenize(norm_code2))
176
- parser = javalang.parser.Parser(tokens2)
177
- ast2 = parser.parse()
178
- ast_data2 = ast_to_graph_data(ast2, ast_encoder)
179
-
180
- inputs2 = tokenizer(norm_code2, return_tensors="pt", truncation=True).to(DEVICE)
181
- with torch.no_grad():
182
- code_embed2 = code_model(**inputs2).last_hidden_state.mean(dim=1)
183
 
184
- # Predict
185
  with torch.no_grad():
186
- logits1 = model(ast_data1.to(DEVICE), code_embed1)
187
- logits2 = model(ast_data2.to(DEVICE), code_embed2)
188
- sim_score = F.cosine_similarity(logits1, logits2).item()
189
 
190
  return {
191
- "Similarity": f"{sim_score:.3f}",
192
- "Clone": "Yes" if sim_score > 0.7 else "No"
 
193
  }
194
  except Exception as e:
195
  return {"Error": str(e)}
@@ -201,7 +75,7 @@ demo = gr.Interface(
201
  gr.Textbox(label="First Java Code", lines=10),
202
  gr.Textbox(label="Second Java Code", lines=10)
203
  ],
204
- outputs=gr.JSON(label="Prediction"),
205
  examples=[
206
  ["""public class Hello {
207
  public static void main(String[] args) {
@@ -221,7 +95,7 @@ demo = gr.Interface(
221
  }"""]
222
  ],
223
  title="Java Code Clone Detector",
224
- description="Detect code clones between two Java code snippets using AST and neural embeddings"
225
  )
226
 
227
  if __name__ == "__main__":
 
2
  import torch
3
  import torch.nn as nn
4
  import torch.nn.functional as F
 
5
  import re
6
  import gradio as gr
7
  from transformers import AutoTokenizer, AutoModel
 
9
 
10
  # Configuration
11
  MAX_FILE_SIZE = 5000
 
12
  EMBEDDING_DIM = 128
13
  DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
 
15
+ # Initialize models once at startup
16
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
17
+ code_model = AutoModel.from_pretrained("microsoft/codebert-base").to(DEVICE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
+ # Simplified model architecture
20
class CloneDetector(nn.Module):
    """Two-layer MLP head that classifies a pair of snippet embeddings.

    The two embeddings are concatenated along the last dimension and passed
    through ``Linear(2 * hidden_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, 2)``.
    The output holds two raw logits per pair; the caller reads index 1 as
    the "clone" class.
    """

    def __init__(self, hidden_dim):
        """Build the classifier head.

        Args:
            hidden_dim: Size of the last dimension of each input embedding.
        """
        super().__init__()
        layers = [
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, emb1, emb2):
        """Return logits of shape ``(..., 2)`` for the pair ``(emb1, emb2)``.

        Both inputs must share every dimension except possibly the last,
        which is ``hidden_dim`` for each of them.
        """
        pair = torch.cat((emb1, emb2), dim=-1)
        return self.classifier(pair)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
+ model = CloneDetector(768).to(DEVICE) # 768 is CodeBERT's hidden size
33
 
34
def get_code_embedding(code):
    """Return a mean-pooled CodeBERT embedding for one Java snippet.

    The snippet is stripped of comments, whitespace-normalised, tokenised
    (truncated to 512 tokens) and run through ``code_model``. The token
    states of ``last_hidden_state`` are averaged over the sequence axis.

    Args:
        code: Java source text.

    Returns:
        A tensor on ``DEVICE``. For a single input string this is expected to
        be of shape ``(1, 768)``. If anything fails (for example ``code`` is
        not a string), the failure is logged and a ``(1, 768)`` zero tensor is
        returned so callers always receive a tensor.
    """
    import logging

    def _blank_comment(match):
        # String/char literals are matched only so that comment markers
        # inside them are skipped; they are put back unchanged.
        text = match.group(0)
        return text if text[0] in "\"'" else " "

    try:
        # Single left-to-right pass: literals are consumed first, so "//" or
        # "/*" inside a string (e.g. "http://host") is not mistaken for a
        # comment, and "//" inside a block comment cannot eat its closing "*/".
        java_token = re.compile(
            r'"(?:\\.|[^"\\])*"'       # string literal
            r"|'(?:\\.|[^'\\])*'"      # char literal
            r"|/\*.*?\*/"              # block comment
            r"|//[^\n]*",              # line comment
            re.DOTALL,
        )
        # Comments become a space (not nothing) so neighbouring tokens such
        # as ``int/**/x`` do not fuse; the join/split below collapses runs.
        code = java_token.sub(_blank_comment, code)
        code = ' '.join(code.split())

        # Tokenize and get embedding
        inputs = tokenizer(code, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
        with torch.no_grad():
            outputs = code_model(**inputs)
        return outputs.last_hidden_state.mean(dim=1)  # Pooled representation
    except Exception:
        # Best-effort fallback is kept for callers, but the failure is no
        # longer silent: a zero vector otherwise looks like a real result.
        logging.getLogger(__name__).exception(
            "Failed to embed code snippet; returning zero vector"
        )
        return torch.zeros(1, 768).to(DEVICE)
49
+
50
def predict_clone(code1, code2):
    """Score how likely two Java snippets are clones of each other.

    Each snippet is embedded with ``get_code_embedding``. The result reports
    the cosine similarity of the two embeddings and the softmax probability
    of class 1 from ``model``, all formatted to three decimals.

    Returns:
        A dict with keys ``"Similarity Score"``, ``"Clone Probability"`` and
        ``"Prediction"``; or ``{"Error": <message>}`` if any step raises.
    """
    try:
        first, second = (get_code_embedding(snippet) for snippet in (code1, code2))

        # Inference only: no gradients needed for either score.
        with torch.no_grad():
            similarity = F.cosine_similarity(first, second).item()
            class_probs = F.softmax(model(first, second), dim=-1)
            clone_prob = class_probs[0, 1].item()
    except Exception as e:
        return {"Error": str(e)}

    if clone_prob > 0.5:
        verdict = "Clone"
    else:
        verdict = "Not Clone"

    return {
        "Similarity Score": f"{similarity:.3f}",
        "Clone Probability": f"{clone_prob:.3f}",
        "Prediction": verdict,
    }
 
75
  gr.Textbox(label="First Java Code", lines=10),
76
  gr.Textbox(label="Second Java Code", lines=10)
77
  ],
78
+ outputs=gr.JSON(label="Results"),
79
  examples=[
80
  ["""public class Hello {
81
  public static void main(String[] args) {
 
95
  }"""]
96
  ],
97
  title="Java Code Clone Detector",
98
+ description="Compare two Java code snippets to detect potential clones"
99
  )
100
 
101
  if __name__ == "__main__":