# =============================
# codet5_summarizer.py
# =============================
import os
import re

import torch
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer

MODEL_OPTIONS = {
    "CodeT5 Base (multi-sum)": "Salesforce/codet5-base-multi-sum",
    "CodeT5 Base": "Salesforce/codet5-base",
    "CodeT5 Small (Python-specific)": "stmnk/codet5-small-code-summarization-python",
    "Gemini (describeai)": "describeai/gemini",
    "Mistral 7B Instruct (v0.2)": "mistralai/Mistral-7B-Instruct-v0.2",
}


class CodeT5Summarizer:
    def __init__(self, model_name=None):
        model_name = model_name or MODEL_OPTIONS["CodeT5 Base (multi-sum)"]
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        hf_token = os.getenv('HF_TOKEN')
        if hf_token is None:
            raise ValueError("Hugging Face token must be set in the environment variable 'HF_TOKEN'.")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)

        # Load as a seq2seq model (CodeT5, Gemini); decoder-only checkpoints
        # such as Mistral raise a ValueError here, so fall back to a causal LM.
        try:
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=hf_token).to(self.device)
        except ValueError:
            self.model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token).to(self.device)

        self.is_encoder_decoder = getattr(self.model.config, "is_encoder_decoder", False)

    def preprocess_code(self, code):
        """Drop blank lines and comments (keeping docstrings) to shorten the model input."""
        code = re.sub(r'\n\s*\n', '\n', code)
        lines = code.split('\n')
        clean = []
        docstring = False
        for line in lines:
            # Heuristic: a line containing triple quotes toggles docstring mode.
            if '"""' in line or "'''" in line:
                docstring = not docstring
            if docstring or not line.strip().startswith('#'):
                clean.append(line)
        # Collapse internal runs of spaces only; leading indentation must survive
        # because extract_functions/extract_classes delimit blocks by indent.
        return re.sub(r'(?<=\S) +', ' ', '\n'.join(clean))

    def extract_functions(self, code):
        """Return (name, source) pairs for functions and class methods."""
        function_pattern = r'def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(.*?\).*?:'
        function_matches = re.finditer(function_pattern, code, re.DOTALL)
        functions = []

        for match in function_matches:
            start_pos = match.start()
            function_name = match.group(1)
            lines = code[start_pos:].split('\n')

            # Skip blank lines to locate the first body line, whose indent
            # delimits the function.
            body_start = 1
            while body_start < len(lines) and not lines[body_start].strip():
                body_start += 1

            if body_start < len(lines):
                body_indent = len(lines[body_start]) - len(lines[body_start].lstrip())
                function_body = [lines[0]]
                i = 1
                # Collect lines until indentation drops below the body level.
                while i < len(lines):
                    line = lines[i]
                    if line.strip() and (len(line) - len(line.lstrip())) < body_indent and not line.strip().startswith('#'):
                        break
                    function_body.append(line)
                    i += 1
                function_code = '\n'.join(function_body)
                functions.append((function_name, function_code))

        # Class method detection. Note that methods also match the scan above,
        # so a method can appear a second time under its qualified name.
        class_pattern = r'class\s+([a-zA-Z_][a-zA-Z0-9_]*)'
        class_matches = re.finditer(class_pattern, code, re.DOTALL)
        for match in class_matches:
            class_name = match.group(1)
            start_pos = match.start()
            class_code = code[start_pos:]
            method_matches = re.finditer(function_pattern, class_code, re.DOTALL)
            for method_match in method_matches:
                if method_match.start() > 200:  # Only methods near the top of the class
                    break
                method_name = method_match.group(1)
                method_start = method_match.start()
                method_lines = class_code[method_start:].split('\n')
                body_start = 1
                while body_start < len(method_lines) and not method_lines[body_start].strip():
                    body_start += 1
                if body_start < len(method_lines):
                    body_indent = len(method_lines[body_start]) - len(method_lines[body_start].lstrip())
                    method_body = [method_lines[0]]
                    i = 1
                    while i < len(method_lines):
                        line = method_lines[i]
                        if line.strip() and (len(line) - len(line.lstrip())) < body_indent and not line.strip().startswith('#'):
                            break
                        method_body.append(line)
                        i += 1
                    method_code = '\n'.join(method_body)
functions.append((f"{class_name}.{method_name}", method_code)) return functions def extract_classes(self, code): class_pattern = r'class\s+([a-zA-Z_][a-zA-Z0-9_]*)' class_matches = re.finditer(class_pattern, code, re.DOTALL) classes = [] for match in class_matches: class_name = match.group(1) start_pos = match.start() class_lines = code[start_pos:].split('\n') body_start = 1 while body_start < len(class_lines) and not class_lines[body_start].strip(): body_start += 1 if body_start < len(class_lines): body_indent = len(class_lines[body_start]) - len(class_lines[body_start].lstrip()) class_body = [class_lines[0]] i = 1 while i < len(class_lines): line = class_lines[i] if line.strip() and (len(line) - len(line.lstrip())) < body_indent: break class_body.append(line) i += 1 class_code = '\n'.join(class_body) classes.append((class_name, class_code)) return classes def summarize(self, code, max_length=512): inputs = self.tokenizer(code, return_tensors="pt", truncation=True, max_length=512).to(self.device) with torch.no_grad(): if self.is_encoder_decoder: output = self.model.generate( inputs["input_ids"], attention_mask=inputs["attention_mask"], # Optional but good to include max_new_tokens=max_length, num_beams=4, early_stopping=True ) return self.tokenizer.decode(output[0], skip_special_tokens=True) else: input_ids = inputs["input_ids"] attention_mask = inputs["attention_mask"] output = self.model.generate( input_ids=input_ids, attention_mask=attention_mask, # ✅ Add this line max_new_tokens=max_length, do_sample=False, num_beams=4, early_stopping=True, pad_token_id=self.tokenizer.eos_token_id ) return self.tokenizer.decode(output[0], skip_special_tokens=True) def summarize_code(self, code, summarize_functions=True, summarize_classes=True): preprocessed_code = self.preprocess_code(code) results = { "file_summary": None, "function_summaries": {}, "class_summaries": {} } try: results["file_summary"] = self.summarize(preprocessed_code) except Exception as e: results["file_summary"] = f"Error generating file summary: {str(e)}" if summarize_functions: for function_name, function_code in self.extract_functions(preprocessed_code): try: summary = self.summarize(function_code) results["function_summaries"][function_name] = summary except Exception as e: results["function_summaries"][function_name] = f"Error: {str(e)}" if summarize_classes: for class_name, class_code in self.extract_classes(preprocessed_code): try: summary = self.summarize(class_code) results["class_summaries"][class_name] = summary except Exception as e: results["class_summaries"][class_name] = f"Error: {str(e)}" return results