CCD / app.py
rahideer's picture
Update app.py
75f22f0 verified
raw
history blame
3.09 kB
import logging
import re
from pathlib import Path

import gradio as gr
import javalang
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
# Configuration
# Intended cap on the size of an input code snippet.
MAX_FILE_SIZE = 5000
# NOTE(review): not referenced anywhere else in this file; the classifier head
# is sized from CodeBERT's hidden width instead. Candidate for removal.
EMBEDDING_DIM = 128
# Run on CUDA when it is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the pretrained CodeBERT tokenizer and encoder once, at import time, so
# every request reuses the same instances.
_CODEBERT_CHECKPOINT = "microsoft/codebert-base"
tokenizer = AutoTokenizer.from_pretrained(_CODEBERT_CHECKPOINT)
code_model = AutoModel.from_pretrained(_CODEBERT_CHECKPOINT).to(DEVICE)
# Simplified model architecture
class CloneDetector(nn.Module):
    """Two-layer MLP that scores a pair of code embeddings.

    The two input embeddings are concatenated along their last dimension and
    mapped to two logits; the caller in this file reads index 1 as "clone".

    Args:
        hidden_dim: width of each input embedding; also used as the width of
            the single hidden layer.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        pair_width = 2 * hidden_dim  # width of [emb1 ; emb2] after concatenation
        self.classifier = nn.Sequential(
            nn.Linear(pair_width, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, emb1, emb2):
        """Return unnormalised logits with a trailing dimension of size 2."""
        pair = torch.cat((emb1, emb2), dim=-1)
        return self.classifier(pair)
# Classification head sized to the encoder's output width (768 for
# microsoft/codebert-base), read from the loaded config instead of hard-coded.
# NOTE(review): no weights are loaded into this head anywhere in this file, so
# its "Clone Probability" comes from a randomly initialised layer and will
# change between restarts until trained weights are loaded here.
model = CloneDetector(code_model.config.hidden_size).to(DEVICE)
model.eval()  # inference-only usage; keeps results correct if dropout/norm layers are added
# Matches, in one left-to-right scan, the Java constructs that matter for
# comment stripping. Scanning literals and comments together means a "//" inside
# a string literal or inside a block comment is never mistaken for the start of
# a line comment (the previous two-regex approach got both cases wrong).
_JAVA_LITERAL_OR_COMMENT = re.compile(
    r'"(?:\\.|[^"\\\n])*"'   # string literal (kept verbatim)
    r"|'(?:\\.|[^'\\\n])*'"  # char literal, e.g. '"' or '/' (kept verbatim)
    r"|/\*.*?\*/"            # block comment, shortest match
    r"|//[^\n]*",            # line comment, up to end of line
    re.DOTALL,
)


def _strip_java_comments(code):
    """Replace every Java comment in *code* with a space, leaving literals intact."""
    # Literals start with a quote and comments with "/", so the first character
    # tells the two kinds of match apart.
    return _JAVA_LITERAL_OR_COMMENT.sub(
        lambda m: " " if m.group(0).startswith("/") else m.group(0), code
    )


def get_code_embedding(code):
    """Get embedding for a single code snippet.

    The snippet is normalised first: comments are removed and every run of
    whitespace is collapsed to a single space. It is then tokenized with
    CodeBERT (truncated to 512 tokens), and the final hidden states are
    mean-pooled over the token axis.

    Args:
        code: Java source text.

    Returns:
        A tensor with one row (the pooled embedding) on DEVICE. On any failure
        the error is logged and a zero row of the encoder's hidden width is
        returned instead, so the caller can still produce a result.
    """
    try:
        normalized = " ".join(_strip_java_comments(code).split())
        inputs = tokenizer(
            normalized, return_tensors="pt", truncation=True, max_length=512
        ).to(DEVICE)
        with torch.no_grad():
            outputs = code_model(**inputs)
        return outputs.last_hidden_state.mean(dim=1)  # pool over tokens
    except Exception:
        # Deliberate best-effort fallback, but no longer silent.
        logging.getLogger(__name__).exception(
            "Failed to embed code snippet; using a zero vector"
        )
        return torch.zeros(1, code_model.config.hidden_size, device=DEVICE)
def predict_clone(code1, code2):
    """Compare two code snippets.

    Input is validated before any model runs:
    - both snippets must be non-blank strings;
    - each must be at most MAX_FILE_SIZE characters.

    A failed check returns ``{"Error": message}``.

    Args:
        code1: first Java snippet.
        code2: second Java snippet.

    Returns:
        A dict of display strings:
          "Similarity Score"  - cosine similarity of the two pooled embeddings;
          "Clone Probability" - softmax probability of class 1 from ``model``;
          "Prediction"        - "Clone" if that probability exceeds 0.5,
                                otherwise "Not Clone".
        If validation fails or inference raises, it returns ``{"Error": message}`` instead.
    """
    snippets = (("First Java Code", code1), ("Second Java Code", code2))
    # Reject blank input up front rather than embedding an empty string.
    for label, snippet in snippets:
        if not isinstance(snippet, str) or not snippet.strip():
            return {"Error": f"{label} is empty"}
    # Enforce the configured size cap before running the tokenizer/encoder.
    for label, snippet in snippets:
        if len(snippet) > MAX_FILE_SIZE:
            return {"Error": f"{label} exceeds {MAX_FILE_SIZE} characters"}
    try:
        emb1 = get_code_embedding(code1)
        emb2 = get_code_embedding(code2)
        with torch.no_grad():
            sim_score = F.cosine_similarity(emb1, emb2).item()
            logits = model(emb1, emb2)
            prob = F.softmax(logits, dim=-1)[0, 1].item()
        return {
            "Similarity Score": f"{sim_score:.3f}",
            "Clone Probability": f"{prob:.3f}",
            "Prediction": "Clone" if prob > 0.5 else "Not Clone",
        }
    except Exception as e:
        # Top-level UI boundary: keep the traceback in the logs and surface the
        # message in the JSON output.
        logging.getLogger(__name__).exception("Clone prediction failed")
        return {"Error": str(e)}
# Gradio Interface
# Two side-by-side text areas, one per snippet, passed to predict_clone in order.
_code_inputs = [
    gr.Textbox(label="First Java Code", lines=10),
    gr.Textbox(label="Second Java Code", lines=10),
]

# Each entry is a [first_snippet, second_snippet] pair offered as a one-click example.
_example_pairs = [
    [
        """public class Hello {
public static void main(String[] args) {
System.out.println("Hello, World!");
}
}""",
        """public class Greet {
public static void main(String[] args) {
System.out.println("Hello, World!");
}
}""",
    ],
    [
        """public int add(int a, int b) {
return a + b;
}""",
        """public int sum(int x, int y) {
return x + y;
}""",
    ],
]

demo = gr.Interface(
    fn=predict_clone,
    inputs=_code_inputs,
    outputs=gr.JSON(label="Results"),
    examples=_example_pairs,
    title="Java Code Clone Detector",
    description="Compare two Java code snippets to detect potential clones",
)

# Start the web UI only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()