|
import javalang |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import re |
|
import gradio as gr |
|
from transformers import AutoTokenizer, AutoModel |
|
from pathlib import Path |
|
|
|
|
|
# Declared limits for the app. Neither constant is referenced by the code
# visible in this file.
MAX_FILE_SIZE = 5000
EMBEDDING_DIM = 128

# Use the GPU when CUDA is available, otherwise stay on the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
DEVICE = torch.device(_device_name)

# Pretrained CodeBERT tokenizer and encoder, loaded once at import time and
# shared by every request.
_CODEBERT_CHECKPOINT = "microsoft/codebert-base"
tokenizer = AutoTokenizer.from_pretrained(_CODEBERT_CHECKPOINT)
code_model = AutoModel.from_pretrained(_CODEBERT_CHECKPOINT)
code_model = code_model.to(DEVICE)
|
|
|
|
|
class CloneDetector(nn.Module):
    """Small MLP head that scores a pair of code embeddings.

    The two embeddings are concatenated along their last dimension and passed
    through ``Linear -> ReLU -> Linear`` to produce two logits per pair.
    """

    def __init__(self, hidden_dim):
        """Build the classifier for embeddings of width ``hidden_dim``."""
        super().__init__()
        pair_width = hidden_dim * 2
        layers = [
            nn.Linear(pair_width, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, emb1, emb2):
        """Return the two-way logits for the pair ``(emb1, emb2)``."""
        pair = torch.cat((emb1, emb2), dim=-1)
        logits = self.classifier(pair)
        return logits
|
|
|
# Size the classifier head from the encoder's own hidden width (768 for
# codebert-base) so the two cannot drift apart if the checkpoint changes.
# The head is only called for inference in this file, so switch it to eval
# mode explicitly.
# NOTE(review): no trained weights are loaded in the visible code, so the
# head runs with its random initialisation; confirm a checkpoint is restored
# elsewhere before trusting the reported clone probability.
model = CloneDetector(code_model.config.hidden_size).to(DEVICE)
model.eval()
|
|
|
# Matches Java literals and comments in one left-to-right pass. Literals are
# listed first so that comment markers inside a string (e.g. "http://host")
# are never mistaken for the start of a comment, and a "//" inside a block
# comment cannot swallow the closing "*/".
_JAVA_LITERAL_OR_COMMENT = re.compile(
    r'''
      (?P<literal>
          """.*?"""                   # text block
        | "(?:\\.|[^"\\\n])*"         # string literal
        | '(?:\\.|[^'\\\n])*'         # character literal
      )
    | (?P<comment>
          //[^\n]*                    # line comment
        | /\*.*?\*/                   # block comment
      )
    ''',
    re.DOTALL | re.VERBOSE,
)


def _strip_java_comments(code):
    """Return ``code`` with every Java comment replaced by a single space."""

    def _keep_literal(match):
        # A comment separates tokens just like whitespace does, so substitute
        # a space rather than nothing ("int/**/x" must not become "intx").
        if match.group("literal") is not None:
            return match.group(0)
        return " "

    return _JAVA_LITERAL_OR_COMMENT.sub(_keep_literal, code)


def get_code_embedding(code):
    """Get the CodeBERT embedding for a single code snippet.

    Comments are removed, runs of whitespace are collapsed to single spaces,
    and the result is tokenized (truncated to 512 tokens) and encoded. The
    embedding is the mean of the encoder's last hidden state over the token
    dimension.

    Args:
        code: Java source text.

    Returns:
        A tensor on ``DEVICE`` with one row for the snippet. If anything
        fails (including a non-string ``code``), the failure is logged and a
        zero tensor of shape ``(1, hidden_size)`` is returned instead of
        raising.
    """
    try:
        normalized = " ".join(_strip_java_comments(code).split())

        inputs = tokenizer(
            normalized,
            return_tensors="pt",
            truncation=True,
            max_length=512,
        ).to(DEVICE)
        with torch.no_grad():
            outputs = code_model(**inputs)
        return outputs.last_hidden_state.mean(dim=1)
    except Exception:
        # Best-effort fallback kept for existing callers, but no longer
        # silent: record the traceback so a zero vector can be traced back to
        # its cause.
        import logging

        logging.getLogger(__name__).exception(
            "Failed to embed code snippet; returning a zero vector"
        )
        return torch.zeros(1, code_model.config.hidden_size, device=DEVICE)
|
|
|
def predict_clone(code1, code2):
    """Compare two code snippets and report whether they look like clones.

    Args:
        code1: First code snippet as text.
        code2: Second code snippet as text.

    Returns:
        On success, a dict with three entries:
        ``"Similarity Score"`` (cosine similarity of the two embeddings,
        formatted to three decimals), ``"Clone Probability"`` (softmax
        probability of class index 1, formatted to three decimals) and
        ``"Prediction"`` (``"Clone"`` when that probability exceeds 0.5,
        otherwise ``"Not Clone"``).
        On failure, or when either snippet is missing or blank, a dict with a
        single ``"Error"`` entry describing the problem.
    """
    # Reject missing or blank input up front. Embedding an empty snippet
    # would otherwise yield a score and prediction that mean nothing.
    for snippet in (code1, code2):
        if not isinstance(snippet, str) or not snippet.strip():
            return {"Error": "Both code snippets must be non-empty."}

    try:
        emb1 = get_code_embedding(code1)
        emb2 = get_code_embedding(code2)

        with torch.no_grad():
            sim_score = F.cosine_similarity(emb1, emb2).item()
            logits = model(emb1, emb2)
            # Index 1 of the two logits is treated as the "clone" class.
            prob = F.softmax(logits, dim=-1)[0, 1].item()
    except Exception as e:
        # UI boundary: surface the failure in the result instead of raising.
        return {"Error": str(e)}

    return {
        "Similarity Score": f"{sim_score:.3f}",
        "Clone Probability": f"{prob:.3f}",
        "Prediction": "Clone" if prob > 0.5 else "Not Clone",
    }
|
|
|
|
|
# The two text areas that receive the snippets to compare.
_code_inputs = [
    gr.Textbox(label="First Java Code", lines=10),
    gr.Textbox(label="Second Java Code", lines=10),
]

# Results are shown as the raw dict returned by predict_clone.
_results_view = gr.JSON(label="Results")

# Example pairs offered in the UI; each row is [first snippet, second snippet].
_example_pairs = [
    [
        "public class Hello {\n"
        "public static void main(String[] args) {\n"
        'System.out.println("Hello, World!");\n'
        "}\n"
        "}",
        "public class Greet {\n"
        "public static void main(String[] args) {\n"
        'System.out.println("Hello, World!");\n'
        "}\n"
        "}",
    ],
    [
        "public int add(int a, int b) {\n"
        "return a + b;\n"
        "}",
        "public int sum(int x, int y) {\n"
        "return x + y;\n"
        "}",
    ],
]

# Web UI wiring: two code boxes in, one JSON result out.
demo = gr.Interface(
    fn=predict_clone,
    inputs=_code_inputs,
    outputs=_results_view,
    examples=_example_pairs,
    title="Java Code Clone Detector",
    description="Compare two Java code snippets to detect potential clones",
)


# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()