import os
import re
import time
import random
import zipfile
import zlib
import javalang
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.data import Dataset
from torch.utils.data import DataLoader  # samples are plain tensor tuples, so the standard loader is sufficient
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_recall_fscore_support
from tqdm import tqdm
import networkx as nx
# ---- Utility functions ----
def unzip_dataset(zip_path, extract_to):
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(extract_to)
def normalize_java_code(code):
# Remove single-line comments
code = re.sub(r'//.*?\n', '', code)
# Remove multi-line comments
code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
# Remove extra spaces and blank lines
code = re.sub(r'\s+', ' ', code)
return code.strip()
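# Example (illustrative only, not part of the pipeline):
#   normalize_java_code("int x = 1; // counter\nint y = 2;")
#   -> "int x = 1; int y = 2;"   (comments stripped, whitespace collapsed)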
def safe_parse_java(code):
try:
tokens = list(javalang.tokenizer.tokenize(code))
parser = javalang.parser.Parser(tokens)
tree = parser.parse()
return tree
except Exception:
return None
def ast_to_graph(ast):
graph = nx.DiGraph()
def dfs(node, parent_id=None):
node_id = len(graph)
graph.add_node(node_id, label=type(node).__name__)
if parent_id is not None:
graph.add_edge(parent_id, node_id)
for child in getattr(node, 'children', []):
if isinstance(child, (list, tuple)):
for item in child:
if isinstance(item, javalang.ast.Node):
dfs(item, node_id)
elif isinstance(child, javalang.ast.Node):
dfs(child, node_id)
dfs(ast)
return graph
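# Minimal usage sketch (an illustrative helper, not called anywhere in this script):
# parse a tiny class and inspect the graph that ast_to_graph produces.
def _demo_ast_graph():
    sample = "class A { int inc(int x) { return x + 1; } }"
    tree = safe_parse_java(sample)
    if tree is not None:
        graph = ast_to_graph(tree)
        print(f"demo graph: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges")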
def tokenize_java_code(code):
    try:
        tokens = list(javalang.tokenizer.tokenize(code))
        return [token.value for token in tokens]
    except Exception:
        return []
# ---- Data Preprocessing ----
class CloneDataset(Dataset):
def __init__(self, root_dir, transform=None):
super().__init__()
self.data_list = []
self.labels = []
self.skipped_files = 0
self.max_tokens = 5000
clone_dirs = {
"Clone_Type1": 1,
"Clone_Type2": 1,
"Clone_Type3 - ST": 1,
"Clone_Type3 - VST": 1,
"Clone_Type3 - MT": 0 # Assuming MT = Not Clone
}
for clone_type, label in clone_dirs.items():
clone_path = os.path.join(root_dir, 'Subject_CloneTypes_Directories', clone_type)
for root, _, files in os.walk(clone_path):
for file in files:
if file.endswith(".java"):
file_path = os.path.join(root, file)
with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
code = f.read()
code = normalize_java_code(code)
if len(code.split()) > self.max_tokens:
self.skipped_files += 1
continue
ast = safe_parse_java(code)
if ast is None:
self.skipped_files += 1
continue
graph = ast_to_graph(ast)
tokens = tokenize_java_code(code)
if not tokens:
self.skipped_files += 1
continue
data = {
'graph': graph,
'tokens': tokens,
'label': label
}
self.data_list.append(data)
def len(self):
return len(self.data_list)
def get(self, idx):
data_item = self.data_list[idx]
graph = data_item['graph']
tokens = data_item['tokens']
label = data_item['label']
        # Graph processing: edge_index has shape [2, num_edges]; node features are
        # simple positional indices used as 1-dimensional placeholders
        if graph.number_of_edges() > 0:
            edge_index = torch.tensor(list(graph.edges), dtype=torch.long).t().contiguous()
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)
        node_features = torch.arange(graph.number_of_nodes()).unsqueeze(1).float()
        # Token processing: bucket each token into [0, 5000) with a deterministic hash
        # (the built-in hash() is salted per process, so ids would differ between runs)
        token_indices = torch.tensor([zlib.crc32(t.encode('utf-8')) % 5000 for t in tokens], dtype=torch.long)
        return edge_index, node_features, token_indices, torch.tensor(label, dtype=torch.long)
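# Minimal usage sketch (illustrative; assumes the archive has already been extracted
# to the directory used in __main__ below):
#   ds = CloneDataset('archive (1)')
#   edge_index, node_feats, token_ids, label = ds.get(0)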
# ---- Models ----
class GNNEncoder(nn.Module):
def __init__(self, in_channels=1, hidden_dim=64):
super().__init__()
self.conv1 = torch_geometric.nn.GCNConv(in_channels, hidden_dim)
self.conv2 = torch_geometric.nn.GCNConv(hidden_dim, hidden_dim)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.conv2(x, edge_index)
x = F.relu(x)
return torch.mean(x, dim=0) # Graph-level embedding
class RNNEncoder(nn.Module):
def __init__(self, vocab_size=5000, embedding_dim=64, hidden_dim=64):
super().__init__()
self.embedding = nn.Embedding(vocab_size, embedding_dim)
self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
def forward(self, tokens):
embeds = self.embedding(tokens)
_, (hidden, _) = self.lstm(embeds)
return hidden.squeeze(0)
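# Note: RNNEncoder returns the LSTM's final hidden state, a fixed-size embedding of the
# token sequence regardless of its length; GNNEncoder's mean pooling plays the same role
# for the AST graph, so the two 64-dim vectors can be concatenated in HybridClassifier.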
class HybridClassifier(nn.Module):
def __init__(self):
super().__init__()
self.gnn = GNNEncoder()
self.rnn = RNNEncoder()
self.fc = nn.Linear(128, 2)
def forward(self, edge_index, node_features, tokens):
gnn_out = self.gnn(node_features, edge_index)
rnn_out = self.rnn(tokens)
combined = torch.cat([gnn_out, rnn_out], dim=-1)
out = self.fc(combined)
return out
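# Quick shape check (illustrative sketch, not invoked by the training code below;
# assumes a PyTorch version that accepts unbatched 2-D input to nn.LSTM):
def _smoke_test_classifier():
    model = HybridClassifier()
    edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]], dtype=torch.long)  # 3-node cycle
    node_features = torch.arange(3).unsqueeze(1).float()                 # shape [3, 1]
    tokens = torch.randint(0, 5000, (12,))                               # 12 token ids
    logits = model(edge_index, node_features, tokens)
    print(f"logits shape: {tuple(logits.shape)}")  # expected: (2,)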
# ---- Training and Evaluation ----
def train(model, optimizer, loader, device):
model.train()
total_loss = 0
for edge_index, node_features, tokens, labels in loader:
        # batch_size is 1, so drop the leading batch dimension the DataLoader adds
        edge_index = edge_index.squeeze(0).to(device)
        node_features = node_features.squeeze(0).to(device)
        tokens = tokens.squeeze(0).to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(edge_index, node_features, tokens)
        loss = F.cross_entropy(outputs.unsqueeze(0), labels)  # labels already has shape [1]
loss.backward()
optimizer.step()
total_loss += loss.item()
return total_loss / len(loader)
def evaluate(model, loader, device):
model.eval()
preds, labels_all = [], []
with torch.no_grad():
for edge_index, node_features, tokens, labels in loader:
            # batch_size is 1, so drop the leading batch dimension the DataLoader adds
            edge_index = edge_index.squeeze(0).to(device)
            node_features = node_features.squeeze(0).to(device)
            tokens = tokens.squeeze(0).to(device)
            labels = labels.to(device)
            outputs = model(edge_index, node_features, tokens)
            pred = outputs.argmax(dim=-1, keepdim=True)  # shape [1], to match labels
preds.append(pred.cpu().numpy())
labels_all.append(labels.cpu().numpy())
preds = np.concatenate(preds)
labels_all = np.concatenate(labels_all)
precision, recall, f1, _ = precision_recall_fscore_support(labels_all, preds, average='binary')
return precision, recall, f1
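# Note: with average='binary', precision/recall/F1 treat label 1 (clone) as the
# positive class, matching the labels assigned in CloneDataset.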
# ---- Main Execution ----
if __name__ == "__main__":
dataset_root = 'archive (1)'
unzip_dataset('archive (1).zip', dataset_root)
dataset = CloneDataset(dataset_root)
print(f"Total valid samples: {dataset.len()}")
print(f"Total skipped files: {dataset.skipped_files}")
indices = list(range(dataset.len()))
train_idx, temp_idx = train_test_split(indices, test_size=0.2, random_state=42)
val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42)
train_set = torch.utils.data.Subset(dataset, train_idx)
val_set = torch.utils.data.Subset(dataset, val_idx)
test_set = torch.utils.data.Subset(dataset, test_idx)
    batch_size = 1  # graphs and token sequences vary in size, so process one sample at a time
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size=batch_size)
test_loader = DataLoader(test_set, batch_size=batch_size)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HybridClassifier().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epochs = 5
start_time = time.time()
for epoch in range(epochs):
train_loss = train(model, optimizer, train_loader, device)
precision, recall, f1 = evaluate(model, val_loader, device)
print(f"Epoch {epoch+1}: Loss={train_loss:.4f}, Precision={precision:.4f}, Recall={recall:.4f}, F1={f1:.4f}")
precision, recall, f1 = evaluate(model, test_loader, device)
total_time = time.time() - start_time
print(f"Test Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
print(f"Total execution time: {total_time:.2f} seconds")