code-to-doc-streamlit / codet5_summarizer.py
vijayvizag's picture
Upload 2 files
f360edc verified
# =============================
# 📄 codet5_summarizer.py (Updated)
# =============================
import torch
import re
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
import os
# Maps a human-readable label to the model identifier that is handed to
# ``from_pretrained`` when a summarizer is constructed.
MODEL_OPTIONS = {
    # CodeT5 checkpoints
    "CodeT5 Base (multi-sum)": "Salesforce/codet5-base-multi-sum",
    "CodeT5 Base": "Salesforce/codet5-base",
    "CodeT5 Small (Python-specific)": "stmnk/codet5-small-code-summarization-python",
    # Other checkpoints
    "Gemini (describeai)": "describeai/gemini",
    "Mistral 7B Instruct (v0.2)": "mistralai/Mistral-7B-Instruct-v0.2",
}
class CodeT5Summarizer:
    """Summarize source code with a Hugging Face seq2seq or causal language model.

    The instance loads a tokenizer and a model once, then offers helpers to
    clean up Python source, split it into functions/classes with a lightweight
    line-based scan, and generate a summary for the file and for each piece.
    """

    # Headers are matched at the start of a line only, so the words "def" and
    # "class" inside prose, docstrings or longer identifiers are not mistaken
    # for definitions.
    _DEF_RE = re.compile(r'^[ \t]*(?:async[ \t]+)?def[ \t]+([a-zA-Z_][a-zA-Z0-9_]*)[ \t]*\(')
    _CLASS_RE = re.compile(r'^[ \t]*class[ \t]+([a-zA-Z_][a-zA-Z0-9_]*)')
    _TRIPLE_QUOTE_RE = re.compile(r'"""|\'\'\'')

    def __init__(self, model_name=None):
        """Load the tokenizer and model.

        Args:
            model_name: Model identifier passed to ``from_pretrained``. Falls
                back to the "CodeT5 Base (multi-sum)" entry of ``MODEL_OPTIONS``
                when empty or ``None``.

        Raises:
            ValueError: If the ``HF_TOKEN`` environment variable is not set.
        """
        model_name = model_name or MODEL_OPTIONS["CodeT5 Base (multi-sum)"]
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        hf_token = os.getenv('HF_TOKEN')
        if hf_token is None:
            raise ValueError("Hugging Face token must be set in the environment variable 'HF_TOKEN'.")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
        # Use causal model for decoder-only (e.g., Mistral), otherwise Seq2Seq.
        # Only the "configuration not supported by this auto class" error
        # (a ValueError) triggers the fallback; download, authentication and
        # device errors propagate instead of being masked by a second load.
        try:
            model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=hf_token)
        except ValueError:
            model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token)
        self.model = model.to(self.device)
        self.is_encoder_decoder = getattr(self.model.config, "is_encoder_decoder", False)

    def preprocess_code(self, code):
        """Return *code* with blank lines and ``#`` comment lines removed.

        Lines inside triple-quoted blocks are kept even when they start with
        ``#``. Runs of spaces after the leading indentation are collapsed to a
        single space; the indentation itself is preserved because the
        extractors below rely on it to find block boundaries.
        """
        code = re.sub(r'\n\s*\n', '\n', code)
        kept = []
        in_docstring = False
        for line in code.split('\n'):
            # An odd number of delimiters opens or closes a block; an even
            # number (a one-line docstring) leaves the state unchanged.
            if len(self._TRIPLE_QUOTE_RE.findall(line)) % 2 == 1:
                in_docstring = not in_docstring
            if in_docstring or not line.strip().startswith('#'):
                body = line.lstrip(' \t')
                indent = line[:len(line) - len(body)]
                kept.append(indent + re.sub(r' +', ' ', body))
        return '\n'.join(kept)

    @staticmethod
    def _indent_of(line):
        """Return the number of leading whitespace characters of *line*."""
        return len(line) - len(line.lstrip())

    @staticmethod
    def _signature_end(lines, start):
        """Return the index of the line on which the header at *start* closes.

        Brackets are counted textually (brackets inside string literals or
        comments are not recognised), so this is a heuristic. When the
        brackets never balance, the header is treated as a single line.
        """
        depth = 0
        for index in range(start, len(lines)):
            text = lines[index]
            depth += sum(text.count(ch) for ch in '([{')
            depth -= sum(text.count(ch) for ch in ')]}')
            if depth <= 0:
                return index
        return start

    def _find_blocks(self, lines, header_re):
        """Locate every block whose header line matches *header_re*.

        Returns a list of ``(name, indent, start, signature_end, end)`` tuples
        with inclusive line indices, in source order. A block ends before the
        first non-blank, non-comment line indented no deeper than its header.
        """
        spans = []
        total = len(lines)
        for start, header in enumerate(lines):
            match = header_re.match(header)
            if match is None:
                continue
            indent = self._indent_of(header)
            sig_end = self._signature_end(lines, start)
            stop = sig_end + 1
            while stop < total:
                text = lines[stop].strip()
                if text and not text.startswith('#') and self._indent_of(lines[stop]) <= indent:
                    break
                stop += 1
            # Trailing blank lines, and comments indented like the header or
            # shallower, belong to whatever follows rather than to this block.
            last = stop - 1
            while last > sig_end:
                text = lines[last].strip()
                is_outer_comment = text.startswith('#') and self._indent_of(lines[last]) <= indent
                if text and not is_outer_comment:
                    break
                last -= 1
            spans.append((match.group(1), indent, start, sig_end, last))
        return spans

    @staticmethod
    def _block_text(lines, start, end):
        """Join lines ``start..end``; the header line loses its leading indentation."""
        return '\n'.join([lines[start].lstrip()] + lines[start + 1:end + 1])

    def extract_functions(self, code):
        """Return ``(name, source)`` for every ``def`` in *code*, in source order.

        A function defined directly in a class body is reported once, named
        ``ClassName.method``; every other function keeps its bare name.
        Decorator lines are not part of the returned source.
        """
        lines = code.split('\n')

        # (class name, first line, last line, indentation of the class body)
        classes = []
        for class_name, _indent, start, sig_end, end in self._find_blocks(lines, self._CLASS_RE):
            body_indent = None
            for line in lines[sig_end + 1:end + 1]:
                text = line.strip()
                if text and not text.startswith('#'):
                    body_indent = self._indent_of(line)
                    break
            classes.append((class_name, start, end, body_indent))

        functions = []
        for name, indent, start, _sig_end, end in self._find_blocks(lines, self._DEF_RE):
            # Classes are in source order, so the last enclosing one is the innermost.
            enclosing = [info for info in classes if info[1] < start <= info[2]]
            if enclosing and enclosing[-1][3] == indent:
                name = f"{enclosing[-1][0]}.{name}"
            functions.append((name, self._block_text(lines, start, end)))
        return functions

    def extract_classes(self, code):
        """Return ``(class_name, source)`` for every ``class`` in *code*, in source order."""
        lines = code.split('\n')
        return [
            (name, self._block_text(lines, start, end))
            for name, _indent, start, _sig_end, end in self._find_blocks(lines, self._CLASS_RE)
        ]

    def summarize(self, code, max_length=512):
        """Generate a summary string for *code*.

        Args:
            code: Text to summarize; the tokenized input is truncated to 512 tokens.
            max_length: Maximum number of newly generated tokens.

        Returns:
            The decoded model output with special tokens removed. For
            decoder-only models only the continuation is returned, not the
            echoed input.
        """
        inputs = self.tokenizer(code, return_tensors="pt", truncation=True, max_length=512).to(self.device)
        generation_kwargs = {
            "input_ids": inputs["input_ids"],
            "attention_mask": inputs["attention_mask"],
            "max_new_tokens": max_length,
            "num_beams": 4,
            "early_stopping": True,
        }
        if not self.is_encoder_decoder:
            generation_kwargs["do_sample"] = False
            generation_kwargs["pad_token_id"] = self.tokenizer.eos_token_id
        with torch.no_grad():
            output = self.model.generate(**generation_kwargs)
        generated = output[0]
        if not self.is_encoder_decoder:
            # Decoder-only generation returns prompt + continuation; drop the
            # prompt so the result does not start with the input code.
            generated = generated[inputs["input_ids"].shape[1]:]
        return self.tokenizer.decode(generated, skip_special_tokens=True)

    def summarize_code(self, code, summarize_functions=True, summarize_classes=True):
        """Summarize a whole file and, optionally, each function and class in it.

        Returns:
            A dict with keys ``"file_summary"`` (str), ``"function_summaries"``
            and ``"class_summaries"`` (both name -> summary). A failure while
            summarizing one item is recorded as an error string for that item
            instead of aborting the whole run.
        """
        preprocessed_code = self.preprocess_code(code)
        results = {
            "file_summary": None,
            "function_summaries": {},
            "class_summaries": {}
        }
        try:
            results["file_summary"] = self.summarize(preprocessed_code)
        except Exception as e:
            results["file_summary"] = f"Error generating file summary: {str(e)}"
        if summarize_functions:
            for function_name, function_code in self.extract_functions(preprocessed_code):
                try:
                    results["function_summaries"][function_name] = self.summarize(function_code)
                except Exception as e:
                    results["function_summaries"][function_name] = f"Error: {str(e)}"
        if summarize_classes:
            for class_name, class_code in self.extract_classes(preprocessed_code):
                try:
                    results["class_summaries"][class_name] = self.summarize(class_code)
                except Exception as e:
                    results["class_summaries"][class_name] = f"Error: {str(e)}"
        return results