"""Streamlit app that scores the similarity of two Java code snippets.

Three signals are combined into one score:

* cosine similarity of mean-pooled CodeBERT embeddings,
* the output of a small RNN head fed with both embeddings,
* Jaccard overlap of the AST node types found by ``javalang``.

The combined score is compared against user-adjustable thresholds to label
the pair as a Type 1/2, Type 3 or Type 4 clone.
"""
import logging
import os
import re
import warnings
from collections import defaultdict

import javalang
import networkx as nx
import numpy as np
import pandas as pd
import streamlit as st
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

# Configuration
os.environ["TOKENIZERS_PARALLELISM"] = "false"
warnings.filterwarnings("ignore")

logger = logging.getLogger(__name__)

# Constants
MODEL_NAME = "microsoft/codebert-base"
MAX_LENGTH = 512
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Weights of the individual signals in the combined score.
CODEBERT_WEIGHT = 0.5
RNN_WEIGHT = 0.3
AST_WEIGHT = 0.2

# Example code pairs offered in the UI.
EXAMPLE_PAIRS = {
    "Type 1 Example": {
        "code1": "public class Test { public static void main(String[] args) { System.out.println(\"Hello\"); } }",
        "code2": "public class Test { public static void main(String[] args) { System.out.println(\"Hello\"); } }"
    },
    "Type 2 Example": {
        "code1": "public class Test { public static void main(String[] args) { System.out.println(\"Hello\"); } }",
        "code2": "public class Example { public static void main(String[] args) { System.out.println(\"Hello\"); } }"
    },
    "Type 3 Example": {
        "code1": "public class Test { public static void main(String[] args) { for(int i=0;i<10;i++) System.out.println(i); } }",
        "code2": "public class Example { public static void run(String[] params) { for(int j=0;j<10;j++) System.out.println(j); } }"
    }
}

# Set up page config
st.set_page_config(
    page_title="Java Code Clone Detector",
    page_icon="🔍",
    layout="wide"
)


# Simplified RNN Model (for Hugging Face compatibility)
class SimpleRNN(nn.Module):
    """Single-layer RNN followed by a linear layer and a sigmoid.

    NOTE(review): nothing in this file loads trained weights for this model,
    so its output comes from a random initialisation and differs between
    process starts. Confirm whether a checkpoint is supposed to be loaded.
    """

    def __init__(self, input_size=768, hidden_size=128):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return sigmoid(fc(h)) where h is the RNN output at the last step."""
        out, _ = self.rnn(x)
        return torch.sigmoid(self.fc(out[:, -1]))


# Model Loading with caching
@st.cache_resource(show_spinner=False)
def _load_models_cached():
    """Load tokenizer, CodeBERT and the RNN head once per process.

    Exceptions propagate on purpose: Streamlit does not cache a call that
    raised, so a transient failure is retried on the next rerun instead of
    being cached as a permanent "no models" result.
    """
    with st.spinner('Loading models (first run may take a few minutes)...'):
        # Load CodeBERT
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        code_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
        code_model.eval()
        # Initialize simple RNN
        rnn_model = SimpleRNN().to(DEVICE)
        rnn_model.eval()
    return tokenizer, code_model, rnn_model


def load_models():
    """Return ``(tokenizer, code_model, rnn_model)``.

    On failure an error is shown in the UI and ``(None, None, None)`` is
    returned. The failure itself is not cached, so the next rerun retries.
    """
    try:
        return _load_models_cached()
    except Exception as e:
        logger.exception("Model loading failed")
        st.error(f"Model loading failed: {str(e)}")
        return None, None, None


# AST Processing (simplified for Hugging Face)
def parse_ast(code):
    """Parse Java source with javalang; return the tree or None on failure."""
    try:
        return javalang.parse.parse(code)
    except Exception:
        # Best effort: unparsable input simply yields no AST features.
        # ``Exception`` (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate.
        logger.debug("Java parsing failed", exc_info=True)
        return None


def build_simple_ast_features(ast_tree):
    """Count how often each AST node type occurs in ``ast_tree``.

    Returns a plain dict mapping node class name to its count, or an empty
    dict when ``ast_tree`` is falsy (e.g. parsing failed).
    """
    if not ast_tree:
        return {}

    features = defaultdict(int)
    # Iterative walk: avoids Python's recursion limit on deep trees and
    # descends into arbitrarily nested child lists.
    pending = [ast_tree]
    while pending:
        item = pending.pop()
        if isinstance(item, javalang.ast.Node):
            features[type(item).__name__] += 1
            pending.extend(getattr(item, 'children', []))
        elif isinstance(item, (list, tuple)):
            pending.extend(item)
        # Anything else (strings, sets of modifiers, None, ...) is not a node.
    return dict(features)


# Feature Extraction
def normalize_code(code):
    """Strip comments and collapse all whitespace runs to single spaces.

    NOTE: the comment patterns are purely textual, so ``//`` or ``/* */``
    inside a Java string literal is removed as well.
    """
    code = re.sub(r'//.*?$', '', code, flags=re.MULTILINE)
    code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
    return re.sub(r'\s+', ' ', code).strip()


def get_embedding(code, tokenizer, model):
    """Return a mean-pooled CodeBERT embedding of ``code``, or None on error.

    The input is normalized, truncated to ``MAX_LENGTH`` tokens and pooled
    with the attention mask so that only real tokens contribute. Padding a
    single snippet to the maximum length and averaging over every position
    would let pad positions dominate short inputs.
    """
    try:
        inputs = tokenizer(
            normalize_code(code),
            return_tensors="pt",
            truncation=True,
            max_length=MAX_LENGTH,
        ).to(DEVICE)
        with torch.no_grad():
            hidden = model(**inputs).last_hidden_state
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        token_count = mask.sum(dim=1).clamp(min=1)
        return (hidden * mask).sum(dim=1) / token_count
    except Exception:
        # Best effort: callers treat None as "no embedding available".
        logger.exception("Embedding computation failed")
        return None


# Similarity Calculations (optimized for Hugging Face)
def calculate_similarities(code1, code2, models):
    """Compute the individual and combined similarity scores.

    Args:
        code1: First Java snippet.
        code2: Second Java snippet.
        models: ``(tokenizer, code_model, rnn_model)`` as returned by
            ``load_models``.

    Returns:
        Dict with keys ``'codebert'``, ``'rnn'``, ``'ast'`` and
        ``'combined'``. A signal that could not be computed is 0.
    """
    tokenizer, code_model, rnn_model = models

    # Get embeddings
    emb1 = get_embedding(code1, tokenizer, code_model)
    emb2 = get_embedding(code2, tokenizer, code_model)

    # Get AST features
    ast_features1 = build_simple_ast_features(parse_ast(code1))
    ast_features2 = build_simple_ast_features(parse_ast(code2))

    # Embedding-based similarities
    codebert_sim = 0
    rnn_sim = 0
    if emb1 is not None and emb2 is not None:
        codebert_sim = F.cosine_similarity(emb1, emb2).item()
        with torch.no_grad():
            # The two embeddings form a length-2 sequence in a batch of one.
            rnn_input = torch.cat([emb1, emb2]).unsqueeze(0)
            rnn_sim = rnn_model(rnn_input).item()

    # Simple AST similarity: Jaccard overlap of the node types present
    # (the per-type counts are not used).
    ast_sim = 0
    if ast_features1 and ast_features2:
        types1 = set(ast_features1)
        types2 = set(ast_features2)
        all_types = types1 | types2
        ast_sim = len(types1 & types2) / len(all_types) if all_types else 0

    return {
        'codebert': codebert_sim,
        'rnn': rnn_sim,
        'ast': ast_sim,
        'combined': (CODEBERT_WEIGHT * codebert_sim
                     + RNN_WEIGHT * rnn_sim
                     + AST_WEIGHT * ast_sim)
    }


def _classify_clone(score, t1, t3, t4):
    """Map a combined score to a clone label using the given thresholds."""
    if score >= t1:
        return "Type 1/2 Clone (Exact/Near-Exact)"
    if score >= t3:
        return "Type 3 Clone (Near-Miss)"
    if score >= t4:
        return "Type 4 Clone (Semantic)"
    return "No Clone"


# Main UI
def main():
    """Render the page: inputs, thresholds, and the analysis results."""
    st.title("🔍 Java Code Clone Detector (IJaDataset 2.1)")
    st.markdown("Detect Type 1-4 clones using hybrid analysis")

    # Load models
    models = load_models()
    if any(m is None for m in models):
        st.error("Failed to load required models. Please check the logs.")
        return

    # Code input
    selected_example = st.selectbox("Select example pair:", list(EXAMPLE_PAIRS.keys()))
    col1, col2 = st.columns(2)
    with col1:
        code1 = st.text_area(
            "Code 1",
            height=300,
            value=EXAMPLE_PAIRS[selected_example]["code1"]
        )
    with col2:
        code2 = st.text_area(
            "Code 2",
            height=300,
            value=EXAMPLE_PAIRS[selected_example]["code2"]
        )

    # Thresholds
    st.subheader("Detection Thresholds")
    cols = st.columns(3)
    with cols[0]:
        t1 = st.slider("Type 1/2", 0.85, 1.0, 0.95)
    with cols[1]:
        t3 = st.slider("Type 3", 0.7, 0.9, 0.8)
    with cols[2]:
        t4 = st.slider("Type 4", 0.5, 0.8, 0.65)

    # The slider ranges overlap, so the thresholds can be set out of order,
    # which makes a lower-priority clone type unreachable.
    if not (t1 >= t3 >= t4):
        st.warning(
            "Thresholds should satisfy Type 1/2 >= Type 3 >= Type 4; "
            "otherwise some clone types can never be reported."
        )

    # Analysis button
    if st.button("Analyze Code", type="primary"):
        with st.spinner("Analyzing code..."):
            sims = calculate_similarities(code1, code2, models)

        # Determine clone type
        clone_type = _classify_clone(sims['combined'], t1, t3, t4)

        # Display results
        st.subheader("Results")

        # Metrics
        cols = st.columns(4)
        cols[0].metric("Combined", f"{sims['combined']:.2f}")
        cols[1].metric("CodeBERT", f"{sims['codebert']:.2f}")
        cols[2].metric("RNN", f"{sims['rnn']:.2f}")
        cols[3].metric("AST", f"{sims['ast']:.2f}")

        # Progress bar: st.progress only accepts floats in [0.0, 1.0], but a
        # negative cosine similarity can push the combined score below zero.
        st.progress(min(max(float(sims['combined']), 0.0), 1.0))

        # Final result
        st.metric("Detection Result", clone_type)

        # Show details
        with st.expander("Advanced Details"):
            st.json(sims)
            st.code(f"Normalized Code 1:\n{normalize_code(code1)}")
            st.code(f"Normalized Code 2:\n{normalize_code(code2)}")


if __name__ == "__main__":
    main()