vijayvizag commited on
Commit
f360edc
·
verified ·
1 Parent(s): a02d9ee

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +78 -354
  2. codet5_summarizer.py +183 -0
app.py CHANGED
@@ -1,360 +1,84 @@
1
  import streamlit as st
2
- import torch
3
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
4
- import re
5
- import time
6
 
7
- # Model constants
8
- CODET5_MODEL = "Salesforce/codet5-base-multi-sum"
9
 
10
- class CodeT5Summarizer:
11
- def __init__(self, device=None):
12
- """Initialize CodeT5 summarization model."""
13
- self.device = device if device else ('cuda' if torch.cuda.is_available() else 'cpu')
14
-
15
- # Initialize model and tokenizer
16
- with st.spinner("Loading CodeT5 model... this may take a minute..."):
17
- self.tokenizer = AutoTokenizer.from_pretrained(CODET5_MODEL)
18
- self.model = AutoModelForSeq2SeqLM.from_pretrained(CODET5_MODEL).to(self.device)
19
-
20
- def preprocess_code(self, code):
21
- """Clean and preprocess the Python code."""
22
- # Remove empty lines
23
- code = re.sub(r'\n\s*\n', '\n', code)
24
-
25
- # Remove excessive comments (keeping docstrings)
26
- code_lines = []
27
- in_docstring = False
28
- docstring_delimiter = None
29
-
30
- for line in code.split('\n'):
31
- # Check for docstring delimiters
32
- if '"""' in line or "'''" in line:
33
- delimiter = '"""' if '"""' in line else "'''"
34
- if not in_docstring:
35
- in_docstring = True
36
- docstring_delimiter = delimiter
37
- elif docstring_delimiter == delimiter:
38
- in_docstring = False
39
- docstring_delimiter = None
40
-
41
- # Keep docstrings and non-comment lines
42
- if in_docstring or not line.strip().startswith('#'):
43
- code_lines.append(line)
44
-
45
- processed_code = '\n'.join(code_lines)
46
-
47
- # Normalize whitespace
48
- processed_code = re.sub(r' +', ' ', processed_code)
49
-
50
- return processed_code
51
-
52
- def extract_functions(self, code):
53
- """Extract individual functions for summarization"""
54
- # Simple regex to find function definitions
55
- function_pattern = r'def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(.*?\).*?:'
56
- function_matches = re.finditer(function_pattern, code, re.DOTALL)
57
-
58
- functions = []
59
- for match in function_matches:
60
- start_pos = match.start()
61
- # Find the function body
62
- function_name = match.group(1)
63
- lines = code[start_pos:].split('\n')
64
-
65
- # Skip the function definition line
66
- body_start = 1
67
- while body_start < len(lines) and not lines[body_start].strip():
68
- body_start += 1
69
-
70
- if body_start < len(lines):
71
- # Get the indentation of the function body
72
- body_indent = len(lines[body_start]) - len(lines[body_start].lstrip())
73
-
74
- # Gather all lines with at least this indentation
75
- function_body = [lines[0]] # The function definition
76
- i = 1
77
- while i < len(lines):
78
- line = lines[i]
79
- if line.strip() and (len(line) - len(line.lstrip())) < body_indent and not line.strip().startswith('#'):
80
- break
81
- function_body.append(line)
82
- i += 1
83
-
84
- function_code = '\n'.join(function_body)
85
- functions.append((function_name, function_code))
86
-
87
- # Simple regex to find class methods
88
- class_pattern = r'class\s+([a-zA-Z_][a-zA-Z0-9_]*)'
89
- class_matches = re.finditer(class_pattern, code, re.DOTALL)
90
-
91
- for match in class_matches:
92
- class_name = match.group(1)
93
- start_pos = match.start()
94
-
95
- # Find class methods using the function pattern
96
- class_code = code[start_pos:]
97
- method_matches = re.finditer(function_pattern, class_code, re.DOTALL)
98
-
99
- for method_match in method_matches:
100
- method_name = method_match.group(1)
101
- # Skip if this is not a method (i.e., it's a function outside the class)
102
- if method_match.start() > 200: # Simple heuristic to check if method is within class scope
103
- break
104
-
105
- # Get the full method code
106
- method_start = method_match.start()
107
- method_lines = class_code[method_start:].split('\n')
108
-
109
- # Skip the method definition line
110
- body_start = 1
111
- while body_start < len(method_lines) and not method_lines[body_start].strip():
112
- body_start += 1
113
-
114
- if body_start < len(method_lines):
115
- # Get the indentation of the method body
116
- body_indent = len(method_lines[body_start]) - len(method_lines[body_start].lstrip())
117
-
118
- # Gather all lines with at least this indentation
119
- method_body = [method_lines[0]] # The method definition
120
- i = 1
121
- while i < len(method_lines):
122
- line = method_lines[i]
123
- if line.strip() and (len(line) - len(line.lstrip())) < body_indent and not line.strip().startswith('#'):
124
- break
125
- method_body.append(line)
126
- i += 1
127
-
128
- method_code = '\n'.join(method_body)
129
- functions.append((f"{class_name}.{method_name}", method_code))
130
-
131
- return functions
132
-
133
- def extract_classes(self, code):
134
- """Extract class definitions for summarization"""
135
- class_pattern = r'class\s+([a-zA-Z_][a-zA-Z0-9_]*)'
136
- class_matches = re.finditer(class_pattern, code, re.DOTALL)
137
-
138
- classes = []
139
- for match in class_matches:
140
- class_name = match.group(1)
141
- start_pos = match.start()
142
-
143
- # Extract class body
144
- class_lines = code[start_pos:].split('\n')
145
-
146
- # Skip the class definition line
147
- body_start = 1
148
- while body_start < len(class_lines) and not class_lines[body_start].strip():
149
- body_start += 1
150
-
151
- if body_start < len(class_lines):
152
- # Get the indentation of the class body
153
- body_indent = len(class_lines[body_start]) - len(class_lines[body_start].lstrip())
154
-
155
- # Gather all lines with at least this indentation
156
- class_body = [class_lines[0]] # The class definition
157
- i = 1
158
- while i < len(class_lines):
159
- line = class_lines[i]
160
- if line.strip() and (len(line) - len(line.lstrip())) < body_indent:
161
- break
162
- class_body.append(line)
163
- i += 1
164
-
165
- class_code = '\n'.join(class_body)
166
- classes.append((class_name, class_code))
167
-
168
- return classes
169
-
170
- def summarize(self, code, max_length=50):
171
- """Generate summary using CodeT5."""
172
- # Truncate input if needed
173
- max_input_length = 512 # CodeT5 typically accepts up to 512 tokens
174
- tokenized_code = self.tokenizer(code, truncation=True, max_length=max_input_length, return_tensors="pt").to(self.device)
175
-
176
- with torch.no_grad():
177
- generated_ids = self.model.generate(
178
- tokenized_code["input_ids"],
179
- max_length=max_length,
180
- num_beams=4,
181
- early_stopping=True
182
- )
183
-
184
- summary = self.tokenizer.decode(generated_ids[0], skip_special_tokens=True)
185
- return summary
186
-
187
- def summarize_code(self, code, summarize_functions=True, summarize_classes=True):
188
- """
189
- Generate full file summary and optionally function/class level summaries.
190
- Returns a dictionary with summaries.
191
- """
192
- preprocessed_code = self.preprocess_code(code)
193
-
194
- results = {
195
- "file_summary": None,
196
- "function_summaries": {},
197
- "class_summaries": {}
198
- }
199
-
200
- # Generate file-level summary
201
- try:
202
- file_summary = self.summarize(preprocessed_code)
203
- results["file_summary"] = file_summary
204
- except Exception as e:
205
- results["file_summary"] = f"Error generating file summary: {str(e)}"
206
-
207
- # Generate function-level summaries if requested
208
- if summarize_functions:
209
- functions = self.extract_functions(preprocessed_code)
210
-
211
- for function_name, function_code in functions:
212
- try:
213
- summary = self.summarize(function_code)
214
- results["function_summaries"][function_name] = summary
215
- except Exception as e:
216
- results["function_summaries"][function_name] = f"Error: {str(e)}"
217
-
218
- # Generate class-level summaries if requested
219
- if summarize_classes:
220
- classes = self.extract_classes(preprocessed_code)
221
-
222
- for class_name, class_code in classes:
223
- try:
224
- summary = self.summarize(class_code)
225
- results["class_summaries"][class_name] = summary
226
- except Exception as e:
227
- results["class_summaries"][class_name] = f"Error: {str(e)}"
228
-
229
- return results
230
 
231
- def main():
232
- st.set_page_config(
233
- page_title="Python Code Summarizer",
234
- page_icon="📝",
235
- layout="wide"
236
- )
237
-
238
- st.title("📝 Python Code Summarizer using CodeT5")
239
- st.markdown("""
240
- Upload a Python file or paste code directly to generate summaries.
241
- This app uses CodeT5, a pretrained model for code understanding and generation.
242
- """)
243
-
244
- # Initialize session state
245
- if 'summarizer' not in st.session_state:
246
- st.session_state.summarizer = None
247
-
248
- # Load model if not already loaded
249
- if st.session_state.summarizer is None:
250
- st.session_state.summarizer = CodeT5Summarizer()
251
-
252
- # Create tabs for different input methods
253
- tab1, tab2 = st.tabs(["Upload Python File", "Paste Code"])
254
-
255
- with tab1:
256
- uploaded_file = st.file_uploader("Choose a Python file", type=['py'])
257
- if uploaded_file is not None:
258
- code = uploaded_file.getvalue().decode('utf-8')
259
- with st.expander("View Uploaded Code", expanded=False):
260
- st.code(code, language='python')
261
-
262
- # Add summarization options
263
- st.subheader("Summarization Options")
264
- col1, col2 = st.columns(2)
265
- with col1:
266
- summarize_functions = st.checkbox("Generate function summaries", value=True)
267
- with col2:
268
- summarize_classes = st.checkbox("Generate class summaries", value=True)
269
-
270
- if st.button("Summarize Code", key="summarize_file"):
271
- with st.spinner("Generating summaries..."):
272
- start_time = time.time()
273
- summaries = st.session_state.summarizer.summarize_code(
274
- code,
275
- summarize_functions=summarize_functions,
276
- summarize_classes=summarize_classes
277
- )
278
- end_time = time.time()
279
-
280
- # Display summaries
281
- st.success(f"Summarization completed in {end_time - start_time:.2f} seconds!")
282
-
283
- # File summary
284
- st.subheader("File Summary")
285
- st.write(summaries["file_summary"])
286
-
287
- # Function summaries
288
- if summarize_functions and summaries["function_summaries"]:
289
- st.subheader("Function Summaries")
290
- for func_name, summary in summaries["function_summaries"].items():
291
- with st.expander(f"Function: {func_name}"):
292
- st.write(summary)
293
-
294
- # Class summaries
295
- if summarize_classes and summaries["class_summaries"]:
296
- st.subheader("Class Summaries")
297
- for class_name, summary in summaries["class_summaries"].items():
298
- with st.expander(f"Class: {class_name}"):
299
- st.write(summary)
300
-
301
- with tab2:
302
- code = st.text_area("Paste Python code here", height=300)
303
- if code:
304
- # Add summarization options
305
- st.subheader("Summarization Options")
306
- col1, col2 = st.columns(2)
307
- with col1:
308
- summarize_functions = st.checkbox("Generate function summaries", value=True, key="func_paste")
309
- with col2:
310
- summarize_classes = st.checkbox("Generate class summaries", value=True, key="class_paste")
311
-
312
- if st.button("Summarize Code", key="summarize_paste"):
313
- with st.spinner("Generating summaries..."):
314
- start_time = time.time()
315
- summaries = st.session_state.summarizer.summarize_code(
316
- code,
317
- summarize_functions=summarize_functions,
318
- summarize_classes=summarize_classes
319
- )
320
- end_time = time.time()
321
-
322
- # Display summaries
323
- st.success(f"Summarization completed in {end_time - start_time:.2f} seconds!")
324
-
325
- # File summary
326
- st.subheader("File Summary")
327
- st.write(summaries["file_summary"])
328
-
329
- # Function summaries
330
- if summarize_functions and summaries["function_summaries"]:
331
- st.subheader("Function Summaries")
332
- for func_name, summary in summaries["function_summaries"].items():
333
- with st.expander(f"Function: {func_name}"):
334
- st.write(summary)
335
-
336
- # Class summaries
337
- if summarize_classes and summaries["class_summaries"]:
338
- st.subheader("Class Summaries")
339
- for class_name, summary in summaries["class_summaries"].items():
340
- with st.expander(f"Class: {class_name}"):
341
- st.write(summary)
342
-
343
  st.markdown("---")
344
- st.markdown("### About")
345
- st.markdown("""
346
- This app uses the CodeT5 model to generate summaries of Python code. The model is trained on a large corpus of code and documentation.
347
-
348
- **Features:**
349
- - File-level summaries
350
- - Function-level summaries
351
- - Class-level summaries
352
-
353
- **Limitations:**
354
- - Summaries may not always be accurate
355
- - Long files may be truncated
356
- - Complex code structures might not be properly understood
357
- """)
 
 
 
 
 
 
 
 
 
358
 
359
- if __name__ == "__main__":
360
- main()
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from codet5_summarizer import CodeT5Summarizer, MODEL_OPTIONS
3
+ import textwrap
4
+ import os
5
+ import base64
6
 
7
+ st.set_page_config(page_title="Code Summarizer & Report Generator", layout="wide")
 
8
 
9
+ st.title("📄 Code Summarizer & Report Generator")
10
+ st.markdown("""
11
+ Upload a Python code file to get a high-level summary and a report structure with editable sections.
12
+ You can choose from various models including Mistral, CodeT5, and Gemini.
13
+ """)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ # Model selection
16
+ model_label = st.selectbox("Select Model", list(MODEL_OPTIONS.keys()), index=0)
17
+ summarizer = CodeT5Summarizer(model_name=MODEL_OPTIONS[model_label])
18
+
19
+ # Upload code file
20
+ uploaded_file = st.file_uploader("Upload a .py file", type="py")
21
+ if uploaded_file:
22
+ code = uploaded_file.read().decode("utf-8")
23
+ st.code(code, language="python")
24
+
25
+ st.markdown("---")
26
+ st.subheader("🔍 Generating Summary...")
27
+
28
+ if "Mistral" in model_label or "Gemini" in model_label:
29
+ summary = summarizer.summarize(code)
30
+ function_summaries = None
31
+ class_summaries = None
32
+ else:
33
+ results = summarizer.summarize_code(code)
34
+ summary = results["file_summary"]
35
+ function_summaries = results["function_summaries"]
36
+ class_summaries = results["class_summaries"]
37
+
38
+ st.text_area("Summary", summary, height=200)
39
+
40
+ if function_summaries:
41
+ st.subheader("🧩 Function Summaries")
42
+ for func, summ in function_summaries.items():
43
+ st.text_area(f"Function: {func}", summ, height=100)
44
+
45
+ if class_summaries:
46
+ st.subheader("🏗️ Class Summaries")
47
+ for cls, summ in class_summaries.items():
48
+ st.text_area(f"Class: {cls}", summ, height=100)
49
+
50
+ # Report generation section
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  st.markdown("---")
52
+ st.subheader("📘 Generate Report")
53
+
54
+ default_sections = [
55
+ "Abstract", "Introduction", "Literature Review", "Methodology",
56
+ "Modules", "Software & Hardware Requirements", "Architecture & UML Diagrams",
57
+ "References", "Conclusion"
58
+ ]
59
+
60
+ sections = st.multiselect("Select Sections", default_sections, default=default_sections)
61
+
62
+ report = ""
63
+ for section in sections:
64
+ content = st.text_area(f"✏️ {section} Content", value=f"{section} description goes here...", height=150)
65
+ report += f"\n## {section}\n\n{textwrap.dedent(content)}\n"
66
+
67
+ # Export format
68
+ st.markdown("---")
69
+ st.subheader("📤 Export Report")
70
+ export_format = st.radio("Select Export Format", ["Markdown", "Text", "HTML"])
71
+
72
+ def generate_download_link(content, filename):
73
+ b64 = base64.b64encode(content.encode()).decode()
74
+ return f'<a href="data:file/txt;base64,{b64}" download="{filename}">📥 Download {filename}</a>'
75
 
76
+ if st.button("Generate Export File"):
77
+ filename = uploaded_file.name.replace(".py", "")
78
+ if export_format == "Markdown":
79
+ st.markdown(generate_download_link(report, f"{filename}_report.md"), unsafe_allow_html=True)
80
+ elif export_format == "Text":
81
+ st.markdown(generate_download_link(report, f"{filename}_report.txt"), unsafe_allow_html=True)
82
+ else:
83
+ html_content = f"<html><body>{report.replace('\n', '<br>')}</body></html>"
84
+ st.markdown(generate_download_link(html_content, f"{filename}_report.html"), unsafe_allow_html=True)
codet5_summarizer.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # =============================
2
+ # 📄 codet5_summarizer.py (Updated)
3
+ # =============================
4
+ import torch
5
+ import re
6
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
7
+ import os
8
+ MODEL_OPTIONS = {
9
+ "CodeT5 Base (multi-sum)": "Salesforce/codet5-base-multi-sum",
10
+ "CodeT5 Base": "Salesforce/codet5-base",
11
+ "CodeT5 Small (Python-specific)": "stmnk/codet5-small-code-summarization-python",
12
+ "Gemini (describeai)": "describeai/gemini",
13
+ "Mistral 7B Instruct (v0.2)": "mistralai/Mistral-7B-Instruct-v0.2",
14
+ }
15
+
16
+ class CodeT5Summarizer:
17
+ def __init__(self, model_name=None):
18
+ model_name = model_name or MODEL_OPTIONS["CodeT5 Base (multi-sum)"]
19
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
20
+ hf_token = os.getenv('HF_TOKEN')
21
+ if hf_token is None:
22
+ raise ValueError("Hugging Face token must be set in the environment variable 'HF_TOKEN'.")
23
+
24
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
25
+
26
+ # Use causal model for decoder-only (e.g., Mistral), otherwise Seq2Seq
27
+ try:
28
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=hf_token).to(self.device)
29
+ except:
30
+ self.model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token).to(self.device)
31
+
32
+ self.is_encoder_decoder = self.model.config.is_encoder_decoder if hasattr(self.model.config, "is_encoder_decoder") else False
33
+
34
+ def preprocess_code(self, code):
35
+ code = re.sub(r'\n\s*\n', '\n', code)
36
+ lines = code.split('\n')
37
+ clean = []
38
+ docstring = False
39
+ for line in lines:
40
+ if '"""' in line or "'''" in line:
41
+ docstring = not docstring
42
+ if docstring or not line.strip().startswith('#'):
43
+ clean.append(line)
44
+ return re.sub(r' +', ' ', '\n'.join(clean))
45
+
46
+ def extract_functions(self, code):
47
+ function_pattern = r'def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(.*?\).*?:'
48
+ function_matches = re.finditer(function_pattern, code, re.DOTALL)
49
+ functions = []
50
+ for match in function_matches:
51
+ start_pos = match.start()
52
+ function_name = match.group(1)
53
+ lines = code[start_pos:].split('\n')
54
+ body_start = 1
55
+ while body_start < len(lines) and not lines[body_start].strip():
56
+ body_start += 1
57
+ if body_start < len(lines):
58
+ body_indent = len(lines[body_start]) - len(lines[body_start].lstrip())
59
+ function_body = [lines[0]]
60
+ i = 1
61
+ while i < len(lines):
62
+ line = lines[i]
63
+ if line.strip() and (len(line) - len(line.lstrip())) < body_indent and not line.strip().startswith('#'):
64
+ break
65
+ function_body.append(line)
66
+ i += 1
67
+ function_code = '\n'.join(function_body)
68
+ functions.append((function_name, function_code))
69
+
70
+ # Class method detection
71
+ class_pattern = r'class\s+([a-zA-Z_][a-zA-Z0-9_]*)'
72
+ class_matches = re.finditer(class_pattern, code, re.DOTALL)
73
+ for match in class_matches:
74
+ class_name = match.group(1)
75
+ start_pos = match.start()
76
+ class_code = code[start_pos:]
77
+ method_matches = re.finditer(function_pattern, class_code, re.DOTALL)
78
+ for method_match in method_matches:
79
+ if method_match.start() > 200: # Only near the top of the class
80
+ break
81
+ method_name = method_match.group(1)
82
+ method_start = method_match.start()
83
+ method_lines = class_code[method_start:].split('\n')
84
+ body_start = 1
85
+ while body_start < len(method_lines) and not method_lines[body_start].strip():
86
+ body_start += 1
87
+ if body_start < len(method_lines):
88
+ body_indent = len(method_lines[body_start]) - len(method_lines[body_start].lstrip())
89
+ method_body = [method_lines[0]]
90
+ i = 1
91
+ while i < len(method_lines):
92
+ line = method_lines[i]
93
+ if line.strip() and (len(line) - len(line.lstrip())) < body_indent and not line.strip().startswith('#'):
94
+ break
95
+ method_body.append(line)
96
+ i += 1
97
+ method_code = '\n'.join(method_body)
98
+ functions.append((f"{class_name}.{method_name}", method_code))
99
+ return functions
100
+
101
+ def extract_classes(self, code):
102
+ class_pattern = r'class\s+([a-zA-Z_][a-zA-Z0-9_]*)'
103
+ class_matches = re.finditer(class_pattern, code, re.DOTALL)
104
+ classes = []
105
+ for match in class_matches:
106
+ class_name = match.group(1)
107
+ start_pos = match.start()
108
+ class_lines = code[start_pos:].split('\n')
109
+ body_start = 1
110
+ while body_start < len(class_lines) and not class_lines[body_start].strip():
111
+ body_start += 1
112
+ if body_start < len(class_lines):
113
+ body_indent = len(class_lines[body_start]) - len(class_lines[body_start].lstrip())
114
+ class_body = [class_lines[0]]
115
+ i = 1
116
+ while i < len(class_lines):
117
+ line = class_lines[i]
118
+ if line.strip() and (len(line) - len(line.lstrip())) < body_indent:
119
+ break
120
+ class_body.append(line)
121
+ i += 1
122
+ class_code = '\n'.join(class_body)
123
+ classes.append((class_name, class_code))
124
+ return classes
125
+
126
+ def summarize(self, code, max_length=512):
127
+ inputs = self.tokenizer(code, return_tensors="pt", truncation=True, max_length=512).to(self.device)
128
+ with torch.no_grad():
129
+ if self.is_encoder_decoder:
130
+ output = self.model.generate(
131
+ inputs["input_ids"],
132
+ attention_mask=inputs["attention_mask"], # Optional but good to include
133
+
134
+ max_new_tokens=max_length,
135
+ num_beams=4,
136
+ early_stopping=True
137
+ )
138
+ return self.tokenizer.decode(output[0], skip_special_tokens=True)
139
+ else:
140
+ input_ids = inputs["input_ids"]
141
+ attention_mask = inputs["attention_mask"]
142
+
143
+ output = self.model.generate(
144
+ input_ids=input_ids,
145
+ attention_mask=attention_mask, # ✅ Add this line
146
+
147
+ max_new_tokens=max_length,
148
+ do_sample=False,
149
+ num_beams=4,
150
+ early_stopping=True,
151
+ pad_token_id=self.tokenizer.eos_token_id
152
+ )
153
+ return self.tokenizer.decode(output[0], skip_special_tokens=True)
154
+
155
+ def summarize_code(self, code, summarize_functions=True, summarize_classes=True):
156
+ preprocessed_code = self.preprocess_code(code)
157
+ results = {
158
+ "file_summary": None,
159
+ "function_summaries": {},
160
+ "class_summaries": {}
161
+ }
162
+ try:
163
+ results["file_summary"] = self.summarize(preprocessed_code)
164
+ except Exception as e:
165
+ results["file_summary"] = f"Error generating file summary: {str(e)}"
166
+
167
+ if summarize_functions:
168
+ for function_name, function_code in self.extract_functions(preprocessed_code):
169
+ try:
170
+ summary = self.summarize(function_code)
171
+ results["function_summaries"][function_name] = summary
172
+ except Exception as e:
173
+ results["function_summaries"][function_name] = f"Error: {str(e)}"
174
+
175
+ if summarize_classes:
176
+ for class_name, class_code in self.extract_classes(preprocessed_code):
177
+ try:
178
+ summary = self.summarize(class_code)
179
+ results["class_summaries"][class_name] = summary
180
+ except Exception as e:
181
+ results["class_summaries"][class_name] = f"Error: {str(e)}"
182
+
183
+ return results