import streamlit as st
import torch
import torch.nn.functional as F
import re
from transformers import AutoTokenizer, AutoModel
import warnings
# Set up page config
st.set_page_config(
page_title="Java Code Clone Detector",
    page_icon="🔍",
layout="wide"
)
# Suppress warnings
warnings.filterwarnings("ignore")
# Constants
MODEL_NAME = "microsoft/codebert-base"
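# CodeBERT is RoBERTa-based; inputs longer than 512 tokens are truncated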
MAX_LENGTH = 512
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Initialize models with caching
@st.cache_resource
def load_models():
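    """Load and cache the CodeBERT tokenizer and encoder (downloaded once per session)."""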
try:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
return tokenizer, model
except Exception as e:
st.error(f"Failed to load models: {str(e)}")
return None, None
tokenizer, code_model = load_models()
# UI Elements
st.title("🔍 Java Code Clone Detector")
st.markdown("""
Compare two Java code snippets to detect potential clones using CodeBERT embeddings.
The similarity score ranges from 0 (completely different) to 1 (identical).
""")
# Example code
EXAMPLE_1 = """public class Hello {
public static void main(String[] args) {
System.out.println("Hello, World!");
}
}"""
EXAMPLE_2 = """public class Greet {
public static void main(String[] args) {
System.out.println("Hello, World!");
}
}"""
# Layout
col1, col2 = st.columns(2)
with col1:
code1 = st.text_area(
"First Java Code",
height=300,
value=EXAMPLE_1,
help="Enter the first Java code snippet"
)
with col2:
code2 = st.text_area(
"Second Java Code",
height=300,
value=EXAMPLE_2,
help="Enter the second Java code snippet"
)
# Threshold slider
threshold = st.slider(
"Clone Detection Threshold",
min_value=0.5,
max_value=1.0,
value=0.85,
step=0.01,
help="Adjust the similarity threshold for clone detection"
)
# Normalization function
def normalize_code(code):
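    """Strip // and /* */ comments and collapse whitespace before embedding.

    Note: the regexes will also strip '//' inside string literals (e.g. URLs),
    a known limitation of this lightweight, non-parsing approach.
    """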
try:
code = re.sub(r'//.*', '', code) # Remove single-line comments
code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL) # Multi-line comments
code = re.sub(r'\s+', ' ', code).strip() # Normalize whitespace
return code
except Exception:
return code
# Embedding generation
def get_embedding(code):
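    """Return a pooled CodeBERT embedding (shape [1, 768]) for one code snippet, or None on error."""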
try:
code = normalize_code(code)
inputs = tokenizer(
code,
return_tensors="pt",
truncation=True,
max_length=MAX_LENGTH,
padding='max_length'
).to(DEVICE)
        with torch.no_grad():
            outputs = code_model(**inputs)
        # Mean-pool over real tokens only so [PAD] embeddings do not dilute the vector
        mask = inputs["attention_mask"].unsqueeze(-1)
        return (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)
except Exception as e:
st.error(f"Error processing code: {str(e)}")
return None
# Comparison function
def compare_code(code1, code2):
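    """Embed both snippets and return their cosine similarity, or None if either embedding fails."""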
if not code1 or not code2:
return None
with st.spinner('Analyzing code...'):
emb1 = get_embedding(code1)
emb2 = get_embedding(code2)
if emb1 is None or emb2 is None:
return None
with torch.no_grad():
similarity = F.cosine_similarity(emb1, emb2).item()
return similarity
# Compare button
if st.button("Compare Code", type="primary"):
if tokenizer is None or code_model is None:
st.error("Models failed to load. Please check the logs.")
else:
similarity = compare_code(code1, code2)
if similarity is not None:
# Display results
st.subheader("Results")
            # Progress bar for visualization (clamped, since st.progress expects a value in [0.0, 1.0])
            st.progress(max(0.0, min(1.0, similarity)))
# Metrics columns
col1, col2, col3 = st.columns(3)
with col1:
st.metric("Similarity Score", f"{similarity:.3f}")
with col2:
st.metric("Threshold", f"{threshold:.3f}")
with col3:
is_clone = similarity >= threshold
st.metric(
"Clone Detection",
"βœ… Clone" if is_clone else "❌ Not a Clone",
delta=f"{similarity-threshold:+.3f}"
)
# Interpretation
if similarity > 0.95:
st.success("The code snippets are nearly identical (potential Type-1 clone)")
elif similarity > 0.85:
st.success("The code snippets are very similar (potential Type-2 clone)")
elif similarity > 0.7:
st.warning("The code snippets show some similarity (potential Type-3 clone)")
else:
st.info("The code snippets are significantly different")
# Show normalized code for debugging
with st.expander("Show normalized code"):
tab1, tab2 = st.tabs(["First Code", "Second Code"])
with tab1:
st.code(normalize_code(code1))
with tab2:
st.code(normalize_code(code2))
# Footer
st.markdown("---")
st.markdown("""
**How it works**:
1. Code is normalized (comments removed, whitespace standardized)
2. CodeBERT generates embeddings for each snippet
3. Cosine similarity is calculated between embeddings
4. Results are compared against your threshold
""")