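"""Streamlit demo for detecting Java code clones (Types 1-4).

Three similarity signals are combined: CodeBERT token embeddings, an RNN
scoring embedding pairs, and a GCN over the parsed Java AST.
"""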
import os
import streamlit as st
import javalang
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import networkx as nx
from transformers import AutoTokenizer, AutoModel, AutoConfig
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import warnings
import zipfile
# Configuration
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")
# Constants
MODEL_NAME = "microsoft/codebert-base"
MAX_LENGTH = 512
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DATASET_PATH = "ijadataset2-1.zip"
CACHE_DIR = "./model_cache"
# Set up page config
st.set_page_config(
page_title="Advanced Java Code Clone Detector",
    page_icon="🔍",
layout="wide"
)
# Model Definitions
class RNNModel(nn.Module):
def __init__(self, input_size, hidden_size, num_layers):
super().__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_size, 1)
def forward(self, x):
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(DEVICE)
out, _ = self.rnn(x, h0)
return self.fc(out[:, -1, :])
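# Illustrative shape walkthrough (matches how calculate_similarities calls this):
#   pair = torch.stack([emb1.squeeze(), emb2.squeeze()])   # [2, 768]
#   logit = rnn_model(pair.unsqueeze(0))                    # [1, 2, 768] -> [1, 1]
#   score = torch.sigmoid(logit).item()                     # similarity in (0, 1)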
class GNNModel(nn.Module):
    def __init__(self, node_features):
        super().__init__()
        self.conv1 = GCNConv(node_features, 128)
        self.conv2 = GCNConv(128, 64)
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        # Mean-pool node embeddings into one 64-d graph vector. Returning a
        # vector (not a scalar score) makes the caller's cosine comparison
        # meaningful: cosine similarity between two positive scalars is always 1.
        return x.mean(dim=0)
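# Data-flow sketch for an AST with n nodes (dimensions from the layers above):
#   data.x [n, 128] --conv1+ReLU--> [n, 128] --dropout--> --conv2--> [n, 64]
#   --mean over nodes--> [64] graph embedding, compared via cosine similarity.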
# Model Loading with Cache
@st.cache_resource(show_spinner=False)
def load_models():
try:
with st.spinner('Loading models (first run may take a few minutes)...'):
config = AutoConfig.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
model = AutoModel.from_pretrained(MODEL_NAME, config=config, cache_dir=CACHE_DIR).to(DEVICE)
            # These heads are randomly initialised; eval() disables dropout so
            # repeated runs are deterministic (load trained weights for real use).
            rnn_model = RNNModel(input_size=768, hidden_size=256, num_layers=2).to(DEVICE)
            gnn_model = GNNModel(node_features=128).to(DEVICE)
            rnn_model.eval()
            gnn_model.eval()
return tokenizer, model, rnn_model, gnn_model
except Exception as e:
st.error(f"Model loading failed: {str(e)}")
return None, None, None, None
# Dataset Loading
@st.cache_resource
def load_dataset():
try:
if not os.path.exists("Diverse_100K_Dataset"):
with zipfile.ZipFile(DATASET_PATH, 'r') as zip_ref:
zip_ref.extractall(".")
clone_pairs = []
base_path = "Diverse_100K_Dataset/Subject_CloneTypes_Directories"
for clone_type in ["Clone_Type1", "Clone_Type2", "Clone_Type3 - ST", "Clone_Type4"]:
type_path = os.path.join(base_path, clone_type)
if os.path.exists(type_path):
                for root, _, files in os.walk(type_path):
                    if len(files) >= 2:
                        # Take the first pair found for this clone type.
                        with open(os.path.join(root, files[0]), 'r', encoding='utf-8', errors='ignore') as f1, \
                             open(os.path.join(root, files[1]), 'r', encoding='utf-8', errors='ignore') as f2:
                            clone_pairs.append({
                                "type": clone_type,
                                "code1": f1.read(),
                                "code2": f2.read()
                            })
                        break
return clone_pairs[:10]
except Exception as e:
st.error(f"Dataset error: {str(e)}")
return []
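# Only the first pair per clone-type directory (and at most 10 overall) is
# sampled so the demo stays responsive; the paths above assume the
# IJaDataset 2.1 zip unpacks into Diverse_100K_Dataset/.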
# AST Processing
def parse_ast(code):
    try:
        return javalang.parse.parse(code)
    except Exception:
        # javalang raises JavaSyntaxError / LexerError on malformed input.
        return None
def build_ast_graph(ast_tree):
if not ast_tree: return None
G = nx.DiGraph()
node_id = 0
def traverse(node, parent=None):
nonlocal node_id
current = node_id
G.add_node(current, type=type(node).__name__)
if parent is not None:
G.add_edge(parent, current)
node_id += 1
for child in getattr(node, 'children', []):
if isinstance(child, javalang.ast.Node):
traverse(child, current)
elif isinstance(child, (list, tuple)):
for item in child:
if isinstance(item, javalang.ast.Node):
traverse(item, current)
traverse(ast_tree)
return G
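# Illustrative result: for "class A { void f() {} }" the traversal produces
# nodes typed CompilationUnit, ClassDeclaration, MethodDeclaration, ..., with
# one directed edge per parent-child relationship in the AST.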
def ast_to_pyg_data(ast_graph, feature_dim=128):
    if not ast_graph or ast_graph.number_of_edges() == 0:
        return None
    node_types = list(nx.get_node_attributes(ast_graph, 'type').values())
    # Hash each AST node type into a fixed-width one-hot slot so every graph
    # matches the input dimension the GNN was built with (128); a per-graph
    # type vocabulary would produce mismatched feature sizes at inference.
    # hash() is stable within a process, which is all a pairwise comparison needs.
    x = torch.zeros(len(node_types), feature_dim)
    for i, t in enumerate(node_types):
        x[i, hash(t) % feature_dim] = 1
    edge_index = torch.tensor(list(ast_graph.edges())).t().contiguous()
    return Data(x=x.to(DEVICE), edge_index=edge_index.to(DEVICE))
# Feature Extraction
def normalize_code(code):
code = re.sub(r'//.*?$', '', code, flags=re.MULTILINE)
code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
return re.sub(r'\s+', ' ', code).strip()
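# Illustrative behaviour:
#   normalize_code("int x = 1; // counter\n/* doc */ int y;")
#   returns "int x = 1; int y;"  (comments stripped, whitespace collapsed)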
def get_embedding(code, tokenizer, model):
try:
inputs = tokenizer(
normalize_code(code),
return_tensors="pt",
truncation=True,
max_length=MAX_LENGTH,
padding='max_length'
).to(DEVICE)
        with torch.no_grad():
            # Mean-pool token embeddings into a single 768-d vector per snippet.
            return model(**inputs).last_hidden_state.mean(dim=1)
    except Exception:
        return None
# Similarity Calculations
def calculate_similarities(code1, code2, models):
tokenizer, code_model, rnn_model, gnn_model = models
# Get embeddings
emb1 = get_embedding(code1, tokenizer, code_model)
emb2 = get_embedding(code2, tokenizer, code_model)
# Parse ASTs
ast1 = build_ast_graph(parse_ast(code1))
ast2 = build_ast_graph(parse_ast(code2))
# Calculate similarities
codebert_sim = F.cosine_similarity(emb1, emb2).item() if emb1 is not None and emb2 is not None else 0
rnn_sim = 0
if emb1 is not None and emb2 is not None:
with torch.no_grad():
rnn_input = torch.stack([emb1.squeeze(), emb2.squeeze()])
rnn_sim = torch.sigmoid(rnn_model(rnn_input.unsqueeze(0))).item()
gnn_sim = 0
if ast1 and ast2:
data1 = ast_to_pyg_data(ast1)
data2 = ast_to_pyg_data(ast2)
if data1 and data2:
with torch.no_grad():
gnn_sim = F.cosine_similarity(
gnn_model(data1).unsqueeze(0),
gnn_model(data2).unsqueeze(0)
).item()
return {
'codebert': codebert_sim,
'rnn': rnn_sim,
'gnn': gnn_sim,
'combined': 0.4*codebert_sim + 0.3*rnn_sim + 0.3*gnn_sim
}
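# The 0.4/0.3/0.3 weighting above is heuristic: lexical similarity (CodeBERT)
# is trusted slightly more than the sequence (RNN) and structural (GNN)
# signals; tune the weights against labelled clone pairs if any are available.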
# UI Components
def main():
st.title("π Advanced Java Code Clone Detector")
st.markdown("Detect all clone types (1-4) using hybrid analysis")
# Load resources
models = load_models()
dataset_pairs = load_dataset()
# Code input
selected_pair = None
if dataset_pairs:
pair_options = {f"{i+1}: {pair['type']}": pair for i, pair in enumerate(dataset_pairs)}
selected_option = st.selectbox("Select example pair:", list(pair_options.keys()))
selected_pair = pair_options[selected_option]
col1, col2 = st.columns(2)
with col1:
code1 = st.text_area("Code 1", height=300, value=selected_pair["code1"] if selected_pair else "")
with col2:
code2 = st.text_area("Code 2", height=300, value=selected_pair["code2"] if selected_pair else "")
# Thresholds
st.subheader("Detection Thresholds")
cols = st.columns(3)
with cols[0]:
t1 = st.slider("Type 1/2", 0.85, 1.0, 0.95)
with cols[1]:
t3 = st.slider("Type 3", 0.7, 0.9, 0.8)
with cols[2]:
t4 = st.slider("Type 4", 0.5, 0.8, 0.65)
# Analysis
    if st.button("Analyze", type="primary") and models[0]:
        if not code1.strip() or not code2.strip():
            st.warning("Please paste Java code into both boxes before analyzing.")
            return
        with st.spinner("Analyzing..."):
            sims = calculate_similarities(code1, code2, models)
# Determine clone type
clone_type = "No Clone"
if sims['combined'] >= t1:
clone_type = "Type 1/2 Clone"
elif sims['combined'] >= t3:
clone_type = "Type 3 Clone"
elif sims['combined'] >= t4:
clone_type = "Type 4 Clone"
# Display results
st.subheader("Results")
cols = st.columns(4)
cols[0].metric("Combined", f"{sims['combined']:.2f}")
cols[1].metric("CodeBERT", f"{sims['codebert']:.2f}")
cols[2].metric("RNN", f"{sims['rnn']:.2f}")
cols[3].metric("GNN", f"{sims['gnn']:.2f}")
            # Cosine similarity can dip below zero; clamp to st.progress's [0, 1] range.
            st.progress(max(0.0, min(1.0, sims['combined'])))
st.metric("Detection Result", clone_type)
# Show details
with st.expander("Details"):
st.json(sims)
st.code(f"Normalized Code 1:\n{normalize_code(code1)}")
st.code(f"Normalized Code 2:\n{normalize_code(code2)}")
if __name__ == "__main__":
    main()
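# To run locally (assuming this file is saved as app.py):
#   pip install streamlit javalang torch torch_geometric transformers networkx
#   streamlit run app.py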