import os
import streamlit as st
import javalang
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import numpy as np
import networkx as nx
from transformers import AutoTokenizer, AutoModel, AutoConfig
from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
import warnings
import pandas as pd
import zipfile
from collections import defaultdict
# Configuration
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")
# Constants
MODEL_NAME = "microsoft/codebert-base"
MAX_LENGTH = 512
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
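# DATASET_PATH is the bundled clone-pair archive; CACHE_DIR caches the
# downloaded CodeBERT weights so later runs skip the network fetch.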
DATASET_PATH = "ijadataset2-1.zip"
CACHE_DIR = "./model_cache"
# Set up page config
st.set_page_config(
page_title="Advanced Java Code Clone Detector",
page_icon="πŸ”",
layout="wide"
)
# Model Definitions
class RNNModel(nn.Module):
def __init__(self, input_size, hidden_size, num_layers):
super().__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)
self.fc = nn.Linear(hidden_size, 1)
def forward(self, x):
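        # x: (batch, seq_len, input_size); run the RNN from a zero initial
        # hidden state and score only the final time step.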
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(DEVICE)
out, _ = self.rnn(x, h0)
return self.fc(out[:, -1, :])
class GNNModel(nn.Module):
def __init__(self, node_features):
super().__init__()
self.conv1 = GCNConv(node_features, 128)
self.conv2 = GCNConv(128, 64)
self.fc = nn.Linear(64, 1)
def forward(self, data):
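        # Two GCN layers over the AST graph, then mean-pool per-node scores
        # into a single scalar squashed to (0, 1).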
x, edge_index = data.x, data.edge_index
x = F.relu(self.conv1(x, edge_index))
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
return torch.sigmoid(self.fc(x).mean())
# Model Loading with Cache
@st.cache_resource(show_spinner=False)
def load_models():
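    # Only CodeBERT comes with pretrained weights; the RNN and GNN heads below
    # are freshly initialized (no checkpoint is loaded), so their similarity
    # scores are illustrative rather than trained.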
try:
with st.spinner('Loading models (first run may take a few minutes)...'):
config = AutoConfig.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
model = AutoModel.from_pretrained(MODEL_NAME, config=config, cache_dir=CACHE_DIR).to(DEVICE)
rnn_model = RNNModel(input_size=768, hidden_size=256, num_layers=2).to(DEVICE)
gnn_model = GNNModel(node_features=128).to(DEVICE)
return tokenizer, model, rnn_model, gnn_model
except Exception as e:
st.error(f"Model loading failed: {str(e)}")
return None, None, None, None
# Dataset Loading
@st.cache_resource
def load_dataset():
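    # Extract the bundled archive on first run, then pick one example file
    # pair from each clone-type directory (capped at 10 pairs).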
try:
if not os.path.exists("Diverse_100K_Dataset"):
with zipfile.ZipFile(DATASET_PATH, 'r') as zip_ref:
zip_ref.extractall(".")
clone_pairs = []
base_path = "Diverse_100K_Dataset/Subject_CloneTypes_Directories"
for clone_type in ["Clone_Type1", "Clone_Type2", "Clone_Type3 - ST", "Clone_Type4"]:
type_path = os.path.join(base_path, clone_type)
if os.path.exists(type_path):
for root, _, files in os.walk(type_path):
                    if len(files) >= 2:
with open(os.path.join(root, files[0]), 'r', encoding='utf-8') as f1, \
open(os.path.join(root, files[1]), 'r', encoding='utf-8') as f2:
clone_pairs.append({
"type": clone_type,
"code1": f1.read(),
"code2": f2.read()
})
break
return clone_pairs[:10]
except Exception as e:
st.error(f"Dataset error: {str(e)}")
return []
# AST Processing
def parse_ast(code):
    try:
        return javalang.parse.parse(code)
    except Exception:
        # javalang raises JavaSyntaxError (and lexer errors) on malformed input
        return None
def build_ast_graph(ast_tree):
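    # Flatten the javalang AST into a directed NetworkX graph: one node per
    # AST node (labelled with its type name), one edge per parent-child link.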
if not ast_tree: return None
G = nx.DiGraph()
node_id = 0
def traverse(node, parent=None):
nonlocal node_id
current = node_id
G.add_node(current, type=type(node).__name__)
if parent is not None:
G.add_edge(parent, current)
node_id += 1
for child in getattr(node, 'children', []):
if isinstance(child, javalang.ast.Node):
traverse(child, current)
elif isinstance(child, (list, tuple)):
for item in child:
if isinstance(item, javalang.ast.Node):
traverse(item, current)
traverse(ast_tree)
return G
def ast_to_pyg_data(ast_graph, feature_dim=128):
    if not ast_graph or ast_graph.number_of_edges() == 0:
        return None
    node_types = list(nx.get_node_attributes(ast_graph, 'type').values())
    # Hash each AST node type into a fixed-width one-hot vector so the feature
    # size always matches the GNN's expected input (128), regardless of how
    # many distinct types a given tree contains. A simple deterministic hash
    # of the type name keeps encodings stable across runs.
    x = torch.zeros(len(node_types), feature_dim)
    for i, t in enumerate(node_types):
        x[i, sum(map(ord, t)) % feature_dim] = 1
    edge_index = torch.tensor(list(ast_graph.edges()), dtype=torch.long).t().contiguous()
    return Data(x=x.to(DEVICE), edge_index=edge_index.to(DEVICE))
# Feature Extraction
def normalize_code(code):
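    # Strip // line comments and /* ... */ block comments, then collapse all
    # whitespace so formatting differences don't affect similarity.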
code = re.sub(r'//.*?$', '', code, flags=re.MULTILINE)
code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
return re.sub(r'\s+', ' ', code).strip()
def get_embedding(code, tokenizer, model):
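    # Tokenize the normalized code (truncated/padded to MAX_LENGTH tokens) and
    # mean-pool CodeBERT's last hidden states into a single 768-dim vector.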
try:
inputs = tokenizer(
normalize_code(code),
return_tensors="pt",
truncation=True,
max_length=MAX_LENGTH,
padding='max_length'
).to(DEVICE)
with torch.no_grad():
return model(**inputs).last_hidden_state.mean(dim=1)
    except Exception:
        return None
# Similarity Calculations
def calculate_similarities(code1, code2, models):
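    # Fuse three signals: cosine similarity of CodeBERT embeddings, an RNN
    # score over the embedding pair, and cosine similarity of GNN outputs
    # computed on each program's AST.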
tokenizer, code_model, rnn_model, gnn_model = models
# Get embeddings
emb1 = get_embedding(code1, tokenizer, code_model)
emb2 = get_embedding(code2, tokenizer, code_model)
# Parse ASTs
ast1 = build_ast_graph(parse_ast(code1))
ast2 = build_ast_graph(parse_ast(code2))
# Calculate similarities
codebert_sim = F.cosine_similarity(emb1, emb2).item() if emb1 is not None and emb2 is not None else 0
rnn_sim = 0
if emb1 is not None and emb2 is not None:
with torch.no_grad():
rnn_input = torch.stack([emb1.squeeze(), emb2.squeeze()])
rnn_sim = torch.sigmoid(rnn_model(rnn_input.unsqueeze(0))).item()
gnn_sim = 0
if ast1 and ast2:
data1 = ast_to_pyg_data(ast1)
data2 = ast_to_pyg_data(ast2)
if data1 and data2:
with torch.no_grad():
gnn_sim = F.cosine_similarity(
gnn_model(data1).unsqueeze(0),
gnn_model(data2).unsqueeze(0)
).item()
return {
'codebert': codebert_sim,
'rnn': rnn_sim,
'gnn': gnn_sim,
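        # Fixed fusion weights (0.4 / 0.3 / 0.3); heuristic, not learned.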
'combined': 0.4*codebert_sim + 0.3*rnn_sim + 0.3*gnn_sim
}
# UI Components
def main():
st.title("πŸ” Advanced Java Code Clone Detector")
    st.markdown("Detect clone types 1-4 using hybrid CodeBERT, RNN, and AST-GNN analysis")
# Load resources
models = load_models()
dataset_pairs = load_dataset()
# Code input
selected_pair = None
if dataset_pairs:
pair_options = {f"{i+1}: {pair['type']}": pair for i, pair in enumerate(dataset_pairs)}
selected_option = st.selectbox("Select example pair:", list(pair_options.keys()))
selected_pair = pair_options[selected_option]
col1, col2 = st.columns(2)
with col1:
code1 = st.text_area("Code 1", height=300, value=selected_pair["code1"] if selected_pair else "")
with col2:
code2 = st.text_area("Code 2", height=300, value=selected_pair["code2"] if selected_pair else "")
# Thresholds
st.subheader("Detection Thresholds")
cols = st.columns(3)
with cols[0]:
t1 = st.slider("Type 1/2", 0.85, 1.0, 0.95)
with cols[1]:
t3 = st.slider("Type 3", 0.7, 0.9, 0.8)
with cols[2]:
t4 = st.slider("Type 4", 0.5, 0.8, 0.65)
# Analysis
if st.button("Analyze", type="primary") and models[0]:
with st.spinner("Analyzing..."):
sims = calculate_similarities(code1, code2, models)
# Determine clone type
clone_type = "No Clone"
if sims['combined'] >= t1:
clone_type = "Type 1/2 Clone"
elif sims['combined'] >= t3:
clone_type = "Type 3 Clone"
elif sims['combined'] >= t4:
clone_type = "Type 4 Clone"
# Display results
st.subheader("Results")
cols = st.columns(4)
cols[0].metric("Combined", f"{sims['combined']:.2f}")
cols[1].metric("CodeBERT", f"{sims['codebert']:.2f}")
cols[2].metric("RNN", f"{sims['rnn']:.2f}")
cols[3].metric("GNN", f"{sims['gnn']:.2f}")
        # Clamp to [0, 1]: cosine similarity can be negative, and st.progress
        # rejects out-of-range values.
        st.progress(min(max(sims['combined'], 0.0), 1.0))
st.metric("Detection Result", clone_type)
# Show details
with st.expander("Details"):
st.json(sims)
st.code(f"Normalized Code 1:\n{normalize_code(code1)}")
st.code(f"Normalized Code 2:\n{normalize_code(code2)}")
if __name__ == "__main__":
main()