# Hosting-page header (file-viewer residue, not part of the program):
# author: Aranwer | last commit: "Update app.py" (daf072c, verified) | size: 7.79 kB
import os
import streamlit as st
import javalang
import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import numpy as np
import networkx as nx
from transformers import AutoTokenizer, AutoModel
import warnings
import pandas as pd
from collections import defaultdict
# Runtime configuration: disable Hugging Face tokenizers parallelism and
# silence all Python warnings for a cleaner UI log.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")

# Encoder checkpoint passed to AutoTokenizer/AutoModel.from_pretrained.
MODEL_NAME = "microsoft/codebert-base"
# Token limit used for truncation when tokenising a snippet.
MAX_LENGTH = 512
# Run on the GPU when one is visible to torch, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Set up page config: browser-tab title/icon and a wide layout.
# The icon was previously mojibake ("πŸ”", i.e. the UTF-8 bytes of the
# magnifying-glass emoji decoded with the wrong codepage); restored here.
st.set_page_config(
    page_title="Java Code Clone Detector",
    page_icon="🔍",
    layout="wide",
)
# Lightweight RNN scoring head used on top of CodeBERT embeddings.
class SimpleRNN(nn.Module):
    """Single-layer RNN that turns a sequence of vectors into one score in (0, 1).

    Args:
        input_size: feature size of each time step (default 768).
        hidden_size: RNN hidden-state size (default 128).
    """

    def __init__(self, input_size=768, hidden_size=128):
        super().__init__()
        # Keep the registration order (rnn first, then fc) so that parameter
        # initialisation under a fixed torch seed is unchanged.
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Score ``x`` laid out as (batch, seq, input_size) (``batch_first=True``).

        Returns a (batch, 1) tensor: sigmoid of a linear projection of the
        RNN output at the last time step.
        """
        step_outputs, _final_hidden = self.rnn(x)
        last_step = step_outputs[:, -1]
        logit = self.fc(last_step)
        return torch.sigmoid(logit)
# Model Loading with caching
@st.cache_resource(show_spinner=False)
def _load_models_cached():
    """Load the CodeBERT tokenizer/encoder and build the RNN head (cached).

    Exceptions are deliberately allowed to propagate: Streamlit does not cache
    a call that raises, so a failed load is retried on the next rerun instead
    of being memoised as a permanent failure.
    """
    with st.spinner('Loading models (first run may take a few minutes)...'):
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        code_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
        code_model.eval()  # inference only: make sure dropout is off
        # NOTE(review): SimpleRNN is constructed fresh and no weights are loaded
        # here, so its outputs come from random initialisation — confirm intent.
        rnn_model = SimpleRNN().to(DEVICE)
        rnn_model.eval()
    return tokenizer, code_model, rnn_model


def load_models():
    """Return ``(tokenizer, code_model, rnn_model)``, or ``(None, None, None)`` on failure.

    Successful loads are cached by ``_load_models_cached``. Failures are shown
    with ``st.error`` and are not cached, so a later rerun tries again.
    """
    try:
        return _load_models_cached()
    except Exception as e:
        st.error(f"Model loading failed: {e}")
        return None, None, None
# AST Processing (simplified for Hugging Face)
def parse_ast(code):
    """Parse Java source with javalang; return the AST or None if it cannot be parsed.

    ``javalang.parse.parse`` only accepts a complete compilation unit, but
    clone-detection inputs are often bare member declarations (e.g. a single
    method). If the direct parse fails, the snippet is retried wrapped in a
    synthetic class body so such fragments still produce an AST.
    """
    try:
        return javalang.parse.parse(code)
    except Exception:
        # Not a full compilation unit (or invalid); fall through to the retry.
        pass
    try:
        return javalang.parse.parse(f"class __CloneDetectorWrapper__ {{\n{code}\n}}")
    except Exception:
        # Any parse/lex failure means "no AST"; ``Exception`` (unlike the old
        # bare ``except``) lets KeyboardInterrupt/SystemExit propagate.
        return None
def build_simple_ast_features(ast_tree):
    """Count javalang AST nodes by class name.

    Args:
        ast_tree: root node returned by ``parse_ast`` (falsy input yields ``{}``).

    Returns:
        dict mapping node class name (e.g. ``"MethodDeclaration"``) to the
        number of such nodes reachable from the root, in pre-order of first
        appearance.
    """
    if not ast_tree:
        return {}

    counts = defaultdict(int)

    def _child_nodes(node):
        # A node's ``children`` holds Nodes directly or inside a list/tuple
        # (one level deep); anything else (strings, sets, None) is skipped.
        for entry in getattr(node, 'children', []):
            if isinstance(entry, javalang.ast.Node):
                yield entry
            elif isinstance(entry, (list, tuple)):
                yield from (item for item in entry if isinstance(item, javalang.ast.Node))

    def _visit(node):
        counts[type(node).__name__] += 1
        for child in _child_nodes(node):
            _visit(child)

    _visit(ast_tree)
    return dict(counts)
# Feature Extraction
# One left-to-right scan over Java literals and comments. Literals are matched
# too so comment markers inside them (e.g. "http://...") are left untouched.
_JAVA_COMMENT_OR_LITERAL = re.compile(
    r'"""[\s\S]*?"""'          # text block (Java 15+)
    r'|"(?:\\.|[^"\\\n])*"'    # string literal
    r"|'(?:\\.|[^'\\\n])*'"    # char literal
    r'|//[^\n]*'               # line comment
    r'|/\*[\s\S]*?\*/'         # block comment
)


def _drop_comment(match):
    """Replace a matched comment with one space; return literals unchanged."""
    token = match.group(0)
    # Comments start with '/', literals with a quote character.
    return ' ' if token.startswith('/') else token


def normalize_code(code):
    """Strip Java comments and collapse all whitespace runs to single spaces.

    Comments are replaced by a space (as the Java compiler treats them), so
    ``a/*c*/b`` becomes ``a b`` rather than gluing the tokens together. String,
    char and text-block literals are preserved verbatim.

    Returns:
        The normalised code, stripped of leading/trailing whitespace.
    """
    code = _JAVA_COMMENT_OR_LITERAL.sub(_drop_comment, code)
    return re.sub(r'\s+', ' ', code).strip()
def get_embedding(code, tokenizer, model):
    """Return a mean-pooled encoder embedding of ``code``, or None on failure.

    The code is normalised with ``normalize_code``, tokenised (truncated to
    ``MAX_LENGTH`` tokens) and run through ``model``. Token vectors are
    averaged over real tokens only, weighted by the attention mask. The
    previous version padded to ``MAX_LENGTH`` and averaged every position,
    so pad-token vectors dominated the embedding of any short snippet.

    Returns:
        Tensor of shape (1, hidden_size) on ``DEVICE`` (one input string ->
        batch of 1), or None if tokenisation or the forward pass raised.
    """
    try:
        inputs = tokenizer(
            normalize_code(code),
            return_tensors="pt",
            truncation=True,
            max_length=MAX_LENGTH,
        ).to(DEVICE)
        with torch.no_grad():
            hidden = model(**inputs).last_hidden_state
        # Masked mean over the sequence axis; padding (if any) is excluded.
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
    except Exception:
        # Best-effort: callers treat None as "embedding unavailable".
        return None
# Similarity Calculations (optimized for Hugging Face)
def calculate_similarities(code1, code2, models, weights=(0.5, 0.3, 0.2)):
    """Score two Java snippets under three views and combine them.

    Args:
        code1, code2: Java source strings.
        models: ``(tokenizer, code_model, rnn_model)`` as returned by ``load_models``.
        weights: ``(codebert, rnn, ast)`` weights for the combined score. The
            default reproduces the original 0.5 / 0.3 / 0.2 mix.

    Returns:
        dict with keys ``'codebert'`` (cosine similarity of the embeddings),
        ``'rnn'`` (RNN-head score), ``'ast'`` (Jaccard similarity of the sets of
        AST node types) and ``'combined'`` (weighted sum). A component that
        cannot be computed (embedding or parse failure) is 0.
    """
    tokenizer, code_model, rnn_model = models

    emb1 = get_embedding(code1, tokenizer, code_model)
    emb2 = get_embedding(code2, tokenizer, code_model)

    codebert_sim = 0
    rnn_sim = 0
    if emb1 is not None and emb2 is not None:
        codebert_sim = F.cosine_similarity(emb1, emb2).item()
        # Stack the two embeddings (each presumably (1, hidden)) into a
        # length-2 sequence for the batch-first RNN: (1, 2, hidden).
        # NOTE(review): load_models builds rnn_model without loading weights,
        # so this score reflects random initialisation — confirm before use.
        with torch.no_grad():
            rnn_input = torch.cat([emb1, emb2]).unsqueeze(0)
            rnn_sim = rnn_model(rnn_input).item()

    ast_features1 = build_simple_ast_features(parse_ast(code1))
    ast_features2 = build_simple_ast_features(parse_ast(code2))

    # Jaccard similarity over which node types occur (counts are ignored).
    ast_sim = 0
    if ast_features1 and ast_features2:
        types1, types2 = set(ast_features1), set(ast_features2)
        # Both sets are non-empty here, so the union is never empty.
        ast_sim = len(types1 & types2) / len(types1 | types2)

    w_codebert, w_rnn, w_ast = weights
    return {
        'codebert': codebert_sim,
        'rnn': rnn_sim,
        'ast': ast_sim,
        'combined': w_codebert * codebert_sim + w_rnn * rnn_sim + w_ast * ast_sim,
    }
# Main UI
def _classify_clone(score, t1, t3, t4):
    """Return the clone-type label for a combined similarity ``score``.

    Thresholds are checked from strictest to loosest: ``t1`` (Type 1/2),
    then ``t3`` (Type 3), then ``t4`` (Type 4); anything below is "No Clone".
    """
    if score >= t1:
        return "Type 1/2 Clone (Exact/Near-Exact)"
    if score >= t3:
        return "Type 3 Clone (Near-Miss)"
    if score >= t4:
        return "Type 4 Clone (Semantic)"
    return "No Clone"


def main():
    """Render the clone-detector page and run the analysis when requested."""
    # Title emoji was mojibake ("πŸ”"); restored to the intended 🔍.
    st.title("🔍 Java Code Clone Detector (IJaDataset 2.1)")
    st.markdown("Detect Type 1-4 clones using hybrid analysis")

    # Load models (cached); load_models shows its own error details on failure.
    models = load_models()
    if None in models:
        st.error("Failed to load required models. Please check the logs.")
        return

    # Example code pairs
    example_pairs = {
        "Type 1 Example": {
            "code1": "public class Test { public static void main(String[] args) { System.out.println(\"Hello\"); } }",
            "code2": "public class Test { public static void main(String[] args) { System.out.println(\"Hello\"); } }"
        },
        "Type 2 Example": {
            "code1": "public class Test { public static void main(String[] args) { System.out.println(\"Hello\"); } }",
            "code2": "public class Example { public static void main(String[] args) { System.out.println(\"Hello\"); } }"
        },
        "Type 3 Example": {
            "code1": "public class Test { public static void main(String[] args) { for(int i=0;i<10;i++) System.out.println(i); } }",
            "code2": "public class Example { public static void run(String[] params) { for(int j=0;j<10;j++) System.out.println(j); } }"
        }
    }

    # Code input (text areas are pre-filled from the selected example)
    selected_example = st.selectbox("Select example pair:", list(example_pairs.keys()))
    col1, col2 = st.columns(2)
    with col1:
        code1 = st.text_area(
            "Code 1",
            height=300,
            value=example_pairs[selected_example]["code1"]
        )
    with col2:
        code2 = st.text_area(
            "Code 2",
            height=300,
            value=example_pairs[selected_example]["code2"]
        )

    # Thresholds
    st.subheader("Detection Thresholds")
    cols = st.columns(3)
    with cols[0]:
        t1 = st.slider("Type 1/2", 0.85, 1.0, 0.95)
    with cols[1]:
        t3 = st.slider("Type 3", 0.7, 0.9, 0.8)
    with cols[2]:
        t4 = st.slider("Type 4", 0.5, 0.8, 0.65)

    # Analysis button
    if not st.button("Analyze Code", type="primary"):
        return
    if not code1.strip() or not code2.strip():
        st.warning("Please enter Java code in both fields.")
        return

    with st.spinner("Analyzing code..."):
        sims = calculate_similarities(code1, code2, models)
    clone_type = _classify_clone(sims['combined'], t1, t3, t4)

    # Display results
    st.subheader("Results")
    cols = st.columns(4)
    cols[0].metric("Combined", f"{sims['combined']:.2f}")
    cols[1].metric("CodeBERT", f"{sims['codebert']:.2f}")
    cols[2].metric("RNN", f"{sims['rnn']:.2f}")
    cols[3].metric("AST", f"{sims['ast']:.2f}")

    # st.progress rejects floats outside [0.0, 1.0]; cosine similarity can be
    # negative and float rounding may push the weighted sum just above 1.
    st.progress(min(max(sims['combined'], 0.0), 1.0))

    # Final result
    st.metric("Detection Result", clone_type)

    # Show details
    with st.expander("Advanced Details"):
        st.json(sims)
        st.code(f"Normalized Code 1:\n{normalize_code(code1)}")
        st.code(f"Normalized Code 2:\n{normalize_code(code2)}")


if __name__ == "__main__":
    main()