import torch
import torch.nn as nn
import torch.nn.functional as F
import re
import gradio as gr
from transformers import AutoTokenizer, AutoModel
from pathlib import Path

# Configuration
MAX_FILE_SIZE = 5000  # cap on input length (characters) before any processing
HIDDEN_DIM = 768      # CodeBERT's hidden size
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize models once at startup
tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
code_model = AutoModel.from_pretrained("microsoft/codebert-base").to(DEVICE)
code_model.eval()  # inference only; disables dropout for deterministic embeddings

# Simplified model architecture
class CloneDetector(nn.Module):
    """Binary classifier over a pair of code embeddings."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),
        )

    def forward(self, emb1, emb2):
        # Concatenate the two embeddings and classify the pair
        combined = torch.cat([emb1, emb2], dim=-1)
        return self.classifier(combined)

model = CloneDetector(HIDDEN_DIM).to(DEVICE)
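# NOTE: the classifier above starts with random weights, so "Clone Probability"
# is only meaningful after training. A minimal sketch for restoring trained
# weights, assuming they were saved with torch.save(model.state_dict(), ...);
# the filename "clone_detector.pt" is a hypothetical placeholder.
WEIGHTS_PATH = Path("clone_detector.pt")
if WEIGHTS_PATH.exists():
    model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE))
model.eval()  # inference only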

def get_code_embedding(code):
    """Get a pooled CodeBERT embedding for a single code snippet."""
    try:
        # Cap very large inputs before any regex work
        code = code[:MAX_FILE_SIZE]

        # Normalize: strip // and /* */ comments, collapse whitespace
        code = re.sub(r'//.*', '', code)
        code = re.sub(r'/\*.*?\*/', '', code, flags=re.DOTALL)
        code = ' '.join(code.split())

        # Tokenize and embed; mean-pool over the token dimension
        inputs = tokenizer(code, return_tensors="pt", truncation=True, max_length=512).to(DEVICE)
        with torch.no_grad():
            outputs = code_model(**inputs)
        return outputs.last_hidden_state.mean(dim=1)  # shape: (1, HIDDEN_DIM)
    except Exception:
        # Fall back to a zero vector so one bad input cannot crash the app
        return torch.zeros(1, HIDDEN_DIM).to(DEVICE)
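# An alternative pooling strategy (an assumption, not part of the original):
# use the first ([CLS]-style) token instead of mean pooling:
#   outputs.last_hidden_state[:, 0, :]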

def predict_clone(code1, code2):
    """Compare two code snippets"""
    try:
        # Get embeddings
        emb1 = get_code_embedding(code1)
        emb2 = get_code_embedding(code2)
        
        # Calculate similarity
        with torch.no_grad():
            sim_score = F.cosine_similarity(emb1, emb2).item()
            logits = model(emb1, emb2)
            prob = F.softmax(logits, dim=-1)[0, 1].item()
        
        return {
            "Similarity Score": f"{sim_score:.3f}",
            "Clone Probability": f"{prob:.3f}",
            "Prediction": "Clone" if prob > 0.5 else "Not Clone"
        }
    except Exception as e:
        return {"Error": str(e)}

# Gradio Interface
demo = gr.Interface(
    fn=predict_clone,
    inputs=[
        gr.Textbox(label="First Java Code", lines=10),
        gr.Textbox(label="Second Java Code", lines=10)
    ],
    outputs=gr.JSON(label="Results"),
    examples=[
        ["""public class Hello {
    public static void main(String[] args) {
        System.out.println("Hello, World!");
    }
}""", 
"""public class Greet {
    public static void main(String[] args) {
        System.out.println("Hello, World!");
    }
}"""],
        ["""public int add(int a, int b) {
    return a + b;
}""", 
"""public int sum(int x, int y) {
    return x + y;
}"""]
    ],
    title="Java Code Clone Detector",
    description="Compare two Java code snippets to detect potential clones"
)

if __name__ == "__main__":
    demo.launch()