# =============================
# πŸ“„ codet5_summarizer.py (Updated)
# =============================
import os
import re

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM

MODEL_OPTIONS = {
    "CodeT5 Base (multi-sum)": "Salesforce/codet5-base-multi-sum",
    "CodeT5 Base": "Salesforce/codet5-base",
    "CodeT5 Small (Python-specific)": "stmnk/codet5-small-code-summarization-python",
    "Gemini (describeai)": "describeai/gemini",
    "Mistral 7B Instruct (v0.2)": "mistralai/Mistral-7B-Instruct-v0.2",
}
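
# Illustrative note (an addition, not part of the original module): a checkpoint
# can be selected by its display name, e.g.
#   CodeT5Summarizer(MODEL_OPTIONS["CodeT5 Small (Python-specific)"])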

class CodeT5Summarizer:
    def __init__(self, model_name=None):
        model_name = model_name or MODEL_OPTIONS["CodeT5 Base (multi-sum)"]
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        hf_token = os.getenv('HF_TOKEN')
        if hf_token is None:
            raise ValueError("Hugging Face token must be set in the environment variable 'HF_TOKEN'.")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)

        # Try the encoder-decoder class first (CodeT5 and friends); decoder-only
        # checkpoints such as Mistral raise a ValueError here, so fall back to a
        # causal LM head for those.
        try:
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=hf_token).to(self.device)
        except ValueError:
            self.model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token).to(self.device)

        self.is_encoder_decoder = getattr(self.model.config, "is_encoder_decoder", False)

    def preprocess_code(self, code):
        """Drop blank lines and standalone comments while keeping docstrings."""
        code = re.sub(r'\n\s*\n', '\n', code)
        lines = code.split('\n')
        clean = []
        in_docstring = False
        for line in lines:
            delimiters = line.count('"""') + line.count("'''")
            if in_docstring or not line.strip().startswith('#'):
                clean.append(line)
            # Toggle only on an odd number of delimiters so single-line
            # docstrings do not leave the flag stuck on.
            if delimiters % 2 == 1:
                in_docstring = not in_docstring
        # Collapse repeated spaces inside lines but preserve leading indentation,
        # which the extractors below rely on.
        return re.sub(r'(?<=\S) +', ' ', '\n'.join(clean))

    def extract_functions(self, code):
        """Heuristically extract (name, source) pairs for functions and class methods using regexes and indentation."""
        function_pattern = r'def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(.*?\).*?:'
        function_matches = re.finditer(function_pattern, code, re.DOTALL)
        functions = []
        for match in function_matches:
            start_pos = match.start()
            function_name = match.group(1)
            lines = code[start_pos:].split('\n')
            body_start = 1
            while body_start < len(lines) and not lines[body_start].strip():
                body_start += 1
            if body_start < len(lines):
                body_indent = len(lines[body_start]) - len(lines[body_start].lstrip())
                function_body = [lines[0]]
                i = 1
                while i < len(lines):
                    line = lines[i]
                    if line.strip() and (len(line) - len(line.lstrip())) < body_indent and not line.strip().startswith('#'):
                        break
                    function_body.append(line)
                    i += 1
                function_code = '\n'.join(function_body)
                functions.append((function_name, function_code))

        # Class method detection
        class_pattern = r'class\s+([a-zA-Z_][a-zA-Z0-9_]*)'
        class_matches = re.finditer(class_pattern, code, re.DOTALL)
        for match in class_matches:
            class_name = match.group(1)
            start_pos = match.start()
            class_code = code[start_pos:]
            method_matches = re.finditer(function_pattern, class_code, re.DOTALL)
            for method_match in method_matches:
                # Heuristic: only consider methods defined within the first ~200
                # characters after the class header.
                if method_match.start() > 200:
                    break
                method_name = method_match.group(1)
                method_start = method_match.start()
                method_lines = class_code[method_start:].split('\n')
                body_start = 1
                while body_start < len(method_lines) and not method_lines[body_start].strip():
                    body_start += 1
                if body_start < len(method_lines):
                    body_indent = len(method_lines[body_start]) - len(method_lines[body_start].lstrip())
                    method_body = [method_lines[0]]
                    i = 1
                    while i < len(method_lines):
                        line = method_lines[i]
                        if line.strip() and (len(line) - len(line.lstrip())) < body_indent and not line.strip().startswith('#'):
                            break
                        method_body.append(line)
                        i += 1
                    method_code = '\n'.join(method_body)
                    functions.append((f"{class_name}.{method_name}", method_code))
        return functions

    def extract_classes(self, code):
        """Heuristically extract (name, source) pairs for classes using a regex and indentation."""
        class_pattern = r'class\s+([a-zA-Z_][a-zA-Z0-9_]*)'
        class_matches = re.finditer(class_pattern, code, re.DOTALL)
        classes = []
        for match in class_matches:
            class_name = match.group(1)
            start_pos = match.start()
            class_lines = code[start_pos:].split('\n')
            body_start = 1
            while body_start < len(class_lines) and not class_lines[body_start].strip():
                body_start += 1
            if body_start < len(class_lines):
                body_indent = len(class_lines[body_start]) - len(class_lines[body_start].lstrip())
                class_body = [class_lines[0]]
                i = 1
                while i < len(class_lines):
                    line = class_lines[i]
                    if line.strip() and (len(line) - len(line.lstrip())) < body_indent:
                        break
                    class_body.append(line)
                    i += 1
                class_code = '\n'.join(class_body)
                classes.append((class_name, class_code))
        return classes
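
    # The regex-based extractors above can be confused by decorators, nested
    # defs, and unusual formatting. As an optional alternative (an addition, not
    # part of the original module), the sketch below uses Python's standard
    # `ast` module; it assumes the input is syntactically valid Python.
    def extract_functions_ast(self, code):
        """Sketch: extract (name, source) pairs with `ast` instead of regexes."""
        import ast
        results = []
        try:
            tree = ast.parse(code)
        except SyntaxError:
            return results  # unparsable input: return an empty list
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
                results.append((node.name, ast.get_source_segment(code, node) or ""))
        return results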

    def summarize(self, code, max_length=512):
        """Generate a natural-language summary for a single code snippet."""
        inputs = self.tokenizer(code, return_tensors="pt", truncation=True, max_length=512).to(self.device)
        with torch.no_grad():
            if self.is_encoder_decoder:
                output = self.model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=max_length,
                    num_beams=4,
                    early_stopping=True
                )
                return self.tokenizer.decode(output[0], skip_special_tokens=True)
            else:
                input_ids = inputs["input_ids"]
                attention_mask = inputs["attention_mask"]
                output = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=max_length,
                    do_sample=False,
                    num_beams=4,
                    early_stopping=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )
                # Decoder-only models echo the prompt, so decode only the newly
                # generated tokens.
                return self.tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)

    def summarize_code(self, code, summarize_functions=True, summarize_classes=True):
        """Summarize a whole file and, optionally, each extracted function and class."""
        preprocessed_code = self.preprocess_code(code)
        results = {
            "file_summary": None,
            "function_summaries": {},
            "class_summaries": {}
        }
        try:
            results["file_summary"] = self.summarize(preprocessed_code)
        except Exception as e:
            results["file_summary"] = f"Error generating file summary: {str(e)}"

        if summarize_functions:
            for function_name, function_code in self.extract_functions(preprocessed_code):
                try:
                    summary = self.summarize(function_code)
                    results["function_summaries"][function_name] = summary
                except Exception as e:
                    results["function_summaries"][function_name] = f"Error: {str(e)}"

        if summarize_classes:
            for class_name, class_code in self.extract_classes(preprocessed_code):
                try:
                    summary = self.summarize(class_code)
                    results["class_summaries"][class_name] = summary
                except Exception as e:
                    results["class_summaries"][class_name] = f"Error: {str(e)}"

        return results
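

# Minimal usage sketch (an addition for illustration, not part of the original
# module). It assumes HF_TOKEN is set in the environment and that the default
# CodeT5 checkpoint can be downloaded.
if __name__ == "__main__":
    sample = '''
def add(a, b):
    """Return the sum of two numbers."""
    return a + b
'''
    summarizer = CodeT5Summarizer()  # defaults to "Salesforce/codet5-base-multi-sum"
    results = summarizer.summarize_code(sample)
    print("File summary:", results["file_summary"])
    for name, summary in results["function_summaries"].items():
        print(f"{name}: {summary}")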