Update app.py
Browse files
app.py
CHANGED
@@ -2,7 +2,6 @@ import javalang
|
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
import torch.nn.functional as F
|
5 |
-
from torch_geometric.data import Data
|
6 |
import re
|
7 |
import gradio as gr
|
8 |
from transformers import AutoTokenizer, AutoModel
|
@@ -10,186 +9,61 @@ from pathlib import Path
|
|
10 |
|
11 |
# Configuration
|
12 |
MAX_FILE_SIZE = 5000
|
13 |
-
MAX_AST_DEPTH = 50
|
14 |
EMBEDDING_DIM = 128
|
15 |
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
16 |
|
17 |
-
#
|
18 |
-
|
19 |
-
|
20 |
-
self.node_types = set()
|
21 |
-
self.type_to_idx = {}
|
22 |
-
|
23 |
-
def fit(self, ast_nodes):
|
24 |
-
self.node_types.update(ast_nodes)
|
25 |
-
self.type_to_idx = {t: i for i, t in enumerate(sorted(self.node_types))}
|
26 |
-
|
27 |
-
def encode(self, node_type):
|
28 |
-
if node_type not in self.type_to_idx:
|
29 |
-
return torch.zeros(EMBEDDING_DIM)
|
30 |
-
idx = self.type_to_idx[node_type]
|
31 |
-
embedding = torch.zeros(EMBEDDING_DIM)
|
32 |
-
embedding[idx % EMBEDDING_DIM] = 1
|
33 |
-
embedding += torch.randn(EMBEDDING_DIM) * 0.1
|
34 |
-
return embedding
|
35 |
-
|
36 |
-
# Code Normalization
|
37 |
-
def normalize_java_code(code):
|
38 |
-
code = re.sub(r'//.*', '', code)
|
39 |
-
code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
|
40 |
-
code = re.sub(r'"[^"]*"', '"<STRING>"', code)
|
41 |
-
code = re.sub(r"'[^']*'", "'<CHAR>'", code)
|
42 |
-
return ' '.join(code.split())
|
43 |
-
|
44 |
-
# AST Processing
|
45 |
-
def extract_ast_paths(node, encoder, current_path=None, paths=None, depth=0):
|
46 |
-
if current_path is None:
|
47 |
-
current_path = []
|
48 |
-
if paths is None:
|
49 |
-
paths = []
|
50 |
-
|
51 |
-
if depth > MAX_AST_DEPTH:
|
52 |
-
return paths
|
53 |
-
|
54 |
-
node_type = str(type(node).__name__)
|
55 |
-
node_embedding = encoder.encode(node_type)
|
56 |
-
current_path.append(node_embedding)
|
57 |
-
|
58 |
-
if not hasattr(node, 'children') or depth == MAX_AST_DEPTH:
|
59 |
-
paths.append(torch.stack(current_path))
|
60 |
-
current_path.pop()
|
61 |
-
return paths
|
62 |
-
|
63 |
-
for child in node.children:
|
64 |
-
if isinstance(child, (javalang.ast.Node, list, tuple)):
|
65 |
-
if isinstance(child, (list, tuple)):
|
66 |
-
for c in child:
|
67 |
-
if isinstance(c, javalang.ast.Node):
|
68 |
-
extract_ast_paths(c, encoder, current_path, paths, depth+1)
|
69 |
-
elif isinstance(child, javalang.ast.Node):
|
70 |
-
extract_ast_paths(child, encoder, current_path, paths, depth+1)
|
71 |
-
|
72 |
-
current_path.pop()
|
73 |
-
return paths
|
74 |
-
|
75 |
-
def ast_to_graph_data(ast, encoder):
|
76 |
-
paths = extract_ast_paths(ast, encoder)
|
77 |
-
if not paths:
|
78 |
-
return None
|
79 |
-
|
80 |
-
edge_index = []
|
81 |
-
node_features = []
|
82 |
-
node_counter = 0
|
83 |
-
node_mapping = {}
|
84 |
-
|
85 |
-
for path in paths:
|
86 |
-
for i in range(len(path) - 1):
|
87 |
-
for j in [i, i+1]:
|
88 |
-
node_key = tuple(path[j].tolist())
|
89 |
-
if node_key not in node_mapping:
|
90 |
-
node_mapping[node_key] = node_counter
|
91 |
-
node_features.append(path[j])
|
92 |
-
node_counter += 1
|
93 |
-
|
94 |
-
src = node_mapping[tuple(path[i].tolist())]
|
95 |
-
dst = node_mapping[tuple(path[i+1].tolist())]
|
96 |
-
edge_index.append([src, dst])
|
97 |
-
|
98 |
-
if not edge_index:
|
99 |
-
return None
|
100 |
-
|
101 |
-
node_features = torch.stack(node_features)
|
102 |
-
edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
|
103 |
-
return Data(x=node_features, edge_index=edge_index)
|
104 |
-
|
105 |
-
# Model Architecture
|
106 |
-
class ASTGNN(nn.Module):
|
107 |
-
def __init__(self, input_dim, hidden_dim):
|
108 |
-
super().__init__()
|
109 |
-
self.conv1 = nn.Sequential(
|
110 |
-
nn.Linear(input_dim, hidden_dim),
|
111 |
-
nn.ReLU(),
|
112 |
-
nn.Linear(hidden_dim, hidden_dim)
|
113 |
-
)
|
114 |
-
self.conv2 = nn.Sequential(
|
115 |
-
nn.Linear(hidden_dim, hidden_dim),
|
116 |
-
nn.ReLU(),
|
117 |
-
nn.Linear(hidden_dim, hidden_dim)
|
118 |
-
)
|
119 |
-
self.pool = nn.AdaptiveMaxPool1d(1)
|
120 |
-
|
121 |
-
def forward(self, data):
|
122 |
-
x, edge_index = data.x.to(DEVICE), data.edge_index.to(DEVICE)
|
123 |
-
x = self.conv1(x)
|
124 |
-
x = self.conv2(x)
|
125 |
-
x = x.t().unsqueeze(0)
|
126 |
-
x = self.pool(x)
|
127 |
-
return x.squeeze(0).squeeze(-1)
|
128 |
|
129 |
-
|
130 |
-
|
|
|
131 |
super().__init__()
|
132 |
-
self.ast_gnn = ASTGNN(ast_input_dim, hidden_dim)
|
133 |
self.classifier = nn.Sequential(
|
134 |
nn.Linear(hidden_dim * 2, hidden_dim),
|
135 |
nn.ReLU(),
|
136 |
nn.Linear(hidden_dim, 2))
|
137 |
|
138 |
-
def forward(self,
|
139 |
-
|
140 |
-
|
141 |
-
return self.classifier(combined.unsqueeze(0))
|
142 |
-
|
143 |
-
# Load Models
|
144 |
-
def load_models():
|
145 |
-
ast_encoder = ASTNodeEncoder()
|
146 |
-
ast_encoder.fit(['MethodDeclaration', 'VariableDeclaration', 'IfStatement'])
|
147 |
-
|
148 |
-
tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
|
149 |
-
code_model = AutoModel.from_pretrained("microsoft/codebert-base").to(DEVICE)
|
150 |
-
|
151 |
-
model = HybridCloneDetector(EMBEDDING_DIM, EMBEDDING_DIM).to(DEVICE)
|
152 |
-
if Path('model.pth').exists():
|
153 |
-
model.load_state_dict(torch.load('model.pth', map_location=DEVICE))
|
154 |
-
|
155 |
-
return ast_encoder, tokenizer, code_model, model
|
156 |
|
157 |
-
|
158 |
|
159 |
-
|
160 |
-
|
161 |
try:
|
162 |
-
#
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
ast1 = parser.parse()
|
167 |
-
ast_data1 = ast_to_graph_data(ast1, ast_encoder)
|
168 |
|
169 |
-
|
|
|
170 |
with torch.no_grad():
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
code_embed2 = code_model(**inputs2).last_hidden_state.mean(dim=1)
|
183 |
|
184 |
-
#
|
185 |
with torch.no_grad():
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
|
190 |
return {
|
191 |
-
"Similarity": f"{sim_score:.3f}",
|
192 |
-
"Clone": "
|
|
|
193 |
}
|
194 |
except Exception as e:
|
195 |
return {"Error": str(e)}
|
@@ -201,7 +75,7 @@ demo = gr.Interface(
|
|
201 |
gr.Textbox(label="First Java Code", lines=10),
|
202 |
gr.Textbox(label="Second Java Code", lines=10)
|
203 |
],
|
204 |
-
outputs=gr.JSON(label="
|
205 |
examples=[
|
206 |
["""public class Hello {
|
207 |
public static void main(String[] args) {
|
@@ -221,7 +95,7 @@ demo = gr.Interface(
|
|
221 |
}"""]
|
222 |
],
|
223 |
title="Java Code Clone Detector",
|
224 |
-
description="
|
225 |
)
|
226 |
|
227 |
if __name__ == "__main__":
|
|
|
2 |
import torch
|
3 |
import torch.nn as nn
|
4 |
import torch.nn.functional as F
|
|
|
5 |
import re
|
6 |
import gradio as gr
|
7 |
from transformers import AutoTokenizer, AutoModel
|
|
|
9 |
|
10 |
# Configuration
|
11 |
MAX_FILE_SIZE = 5000
|
|
|
12 |
EMBEDDING_DIM = 128
|
13 |
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
14 |
|
15 |
+
# Initialize models once at startup
|
16 |
+
tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
|
17 |
+
code_model = AutoModel.from_pretrained("microsoft/codebert-base").to(DEVICE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
+
# Simplified model architecture
|
20 |
+
class CloneDetector(nn.Module):
|
21 |
+
def __init__(self, hidden_dim):
|
22 |
super().__init__()
|
|
|
23 |
self.classifier = nn.Sequential(
|
24 |
nn.Linear(hidden_dim * 2, hidden_dim),
|
25 |
nn.ReLU(),
|
26 |
nn.Linear(hidden_dim, 2))
|
27 |
|
28 |
+
def forward(self, emb1, emb2):
|
29 |
+
combined = torch.cat([emb1, emb2], dim=-1)
|
30 |
+
return self.classifier(combined)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
|
32 |
+
model = CloneDetector(768).to(DEVICE) # 768 is CodeBERT's hidden size
|
33 |
|
34 |
+
def get_code_embedding(code):
|
35 |
+
"""Get embedding for a single code snippet"""
|
36 |
try:
|
37 |
+
# Normalize code
|
38 |
+
code = re.sub(r'//.*', '', code)
|
39 |
+
code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
|
40 |
+
code = ' '.join(code.split())
|
|
|
|
|
41 |
|
42 |
+
# Tokenize and get embedding
|
43 |
+
inputs = tokenizer(code, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
|
44 |
with torch.no_grad():
|
45 |
+
outputs = code_model(**inputs)
|
46 |
+
return outputs.last_hidden_state.mean(dim=1) # Pooled representation
|
47 |
+
except Exception:
|
48 |
+
return torch.zeros(1, 768).to(DEVICE)
|
49 |
+
|
50 |
+
def predict_clone(code1, code2):
|
51 |
+
"""Compare two code snippets"""
|
52 |
+
try:
|
53 |
+
# Get embeddings
|
54 |
+
emb1 = get_code_embedding(code1)
|
55 |
+
emb2 = get_code_embedding(code2)
|
|
|
56 |
|
57 |
+
# Calculate similarity
|
58 |
with torch.no_grad():
|
59 |
+
sim_score = F.cosine_similarity(emb1, emb2).item()
|
60 |
+
logits = model(emb1, emb2)
|
61 |
+
prob = F.softmax(logits, dim=-1)[0, 1].item()
|
62 |
|
63 |
return {
|
64 |
+
"Similarity Score": f"{sim_score:.3f}",
|
65 |
+
"Clone Probability": f"{prob:.3f}",
|
66 |
+
"Prediction": "Clone" if prob > 0.5 else "Not Clone"
|
67 |
}
|
68 |
except Exception as e:
|
69 |
return {"Error": str(e)}
|
|
|
75 |
gr.Textbox(label="First Java Code", lines=10),
|
76 |
gr.Textbox(label="Second Java Code", lines=10)
|
77 |
],
|
78 |
+
outputs=gr.JSON(label="Results"),
|
79 |
examples=[
|
80 |
["""public class Hello {
|
81 |
public static void main(String[] args) {
|
|
|
95 |
}"""]
|
96 |
],
|
97 |
title="Java Code Clone Detector",
|
98 |
+
description="Compare two Java code snippets to detect potential clones"
|
99 |
)
|
100 |
|
101 |
if __name__ == "__main__":
|