# =============================
# codet5_summarizer.py (Updated)
# =============================
import os
import re

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM

MODEL_OPTIONS = {
    "CodeT5 Base (multi-sum)": "Salesforce/codet5-base-multi-sum",
    "CodeT5 Base": "Salesforce/codet5-base",
    "CodeT5 Small (Python-specific)": "stmnk/codet5-small-code-summarization-python",
    "Gemini (describeai)": "describeai/gemini",
    "Mistral 7B Instruct (v0.2)": "mistralai/Mistral-7B-Instruct-v0.2",
}
class CodeT5Summarizer:
    def __init__(self, model_name=None):
        model_name = model_name or MODEL_OPTIONS["CodeT5 Base (multi-sum)"]
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        hf_token = os.getenv('HF_TOKEN')
        if hf_token is None:
            raise ValueError("Hugging Face token must be set in the environment variable 'HF_TOKEN'.")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=hf_token)
        # Try an encoder-decoder head first; fall back to a causal LM for
        # decoder-only checkpoints (e.g., Mistral).
        try:
            self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name, token=hf_token).to(self.device)
        except Exception:
            self.model = AutoModelForCausalLM.from_pretrained(model_name, token=hf_token).to(self.device)
        self.is_encoder_decoder = getattr(self.model.config, "is_encoder_decoder", False)
    def preprocess_code(self, code):
        """Drop blank lines and comments, keep docstrings, and collapse spaces."""
        code = re.sub(r'\n\s*\n', '\n', code)
        lines = code.split('\n')
        clean = []
        docstring = False
        for line in lines:
            # An odd number of triple-quote delimiters on a line opens or
            # closes a docstring (an even count is a single-line docstring).
            if line.count('"""') % 2 == 1 or line.count("'''") % 2 == 1:
                docstring = not docstring
            if docstring or not line.strip().startswith('#'):
                clean.append(line)
        # Collapse internal runs of spaces but keep leading indentation,
        # which extract_functions/extract_classes rely on.
        return re.sub(r'(?<=\S) +', ' ', '\n'.join(clean))
    def extract_functions(self, code):
        """Return (name, source) pairs for functions and class methods."""
        function_pattern = r'def\s+([a-zA-Z_][a-zA-Z0-9_]*)\s*\(.*?\).*?:'
        function_matches = re.finditer(function_pattern, code, re.DOTALL)
        functions = []
        for match in function_matches:
            start_pos = match.start()
            function_name = match.group(1)
            lines = code[start_pos:].split('\n')
            # Skip blank lines to find the first body line and its indentation.
            body_start = 1
            while body_start < len(lines) and not lines[body_start].strip():
                body_start += 1
            if body_start < len(lines):
                body_indent = len(lines[body_start]) - len(lines[body_start].lstrip())
                function_body = [lines[0]]
                i = 1
                while i < len(lines):
                    line = lines[i]
                    # A non-comment line dedented below the body ends the function.
                    if line.strip() and (len(line) - len(line.lstrip())) < body_indent and not line.strip().startswith('#'):
                        break
                    function_body.append(line)
                    i += 1
                function_code = '\n'.join(function_body)
                functions.append((function_name, function_code))
        # Class method detection: qualify method names with their class.
        class_pattern = r'class\s+([a-zA-Z_][a-zA-Z0-9_]*)'
        class_matches = re.finditer(class_pattern, code, re.DOTALL)
        for match in class_matches:
            class_name = match.group(1)
            start_pos = match.start()
            class_code = code[start_pos:]
            method_matches = re.finditer(function_pattern, class_code, re.DOTALL)
            for method_match in method_matches:
                if method_match.start() > 200:  # Heuristic: only methods within the first 200 chars of the class
                    break
                method_name = method_match.group(1)
                method_start = method_match.start()
                method_lines = class_code[method_start:].split('\n')
                body_start = 1
                while body_start < len(method_lines) and not method_lines[body_start].strip():
                    body_start += 1
                if body_start < len(method_lines):
                    body_indent = len(method_lines[body_start]) - len(method_lines[body_start].lstrip())
                    method_body = [method_lines[0]]
                    i = 1
                    while i < len(method_lines):
                        line = method_lines[i]
                        if line.strip() and (len(line) - len(line.lstrip())) < body_indent and not line.strip().startswith('#'):
                            break
                        method_body.append(line)
                        i += 1
                    method_code = '\n'.join(method_body)
                    functions.append((f"{class_name}.{method_name}", method_code))
        return functions
    def extract_classes(self, code):
        """Return (name, source) pairs for class definitions."""
        class_pattern = r'class\s+([a-zA-Z_][a-zA-Z0-9_]*)'
        class_matches = re.finditer(class_pattern, code, re.DOTALL)
        classes = []
        for match in class_matches:
            class_name = match.group(1)
            start_pos = match.start()
            class_lines = code[start_pos:].split('\n')
            body_start = 1
            while body_start < len(class_lines) and not class_lines[body_start].strip():
                body_start += 1
            if body_start < len(class_lines):
                body_indent = len(class_lines[body_start]) - len(class_lines[body_start].lstrip())
                class_body = [class_lines[0]]
                i = 1
                while i < len(class_lines):
                    line = class_lines[i]
                    # A dedented non-blank line ends the class body.
                    if line.strip() and (len(line) - len(line.lstrip())) < body_indent:
                        break
                    class_body.append(line)
                    i += 1
                class_code = '\n'.join(class_body)
                classes.append((class_name, class_code))
        return classes
    def summarize(self, code, max_length=512):
        inputs = self.tokenizer(code, return_tensors="pt", truncation=True, max_length=512).to(self.device)
        with torch.no_grad():
            if self.is_encoder_decoder:
                output = self.model.generate(
                    inputs["input_ids"],
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=max_length,
                    num_beams=4,
                    early_stopping=True
                )
            else:
                output = self.model.generate(
                    input_ids=inputs["input_ids"],
                    # Pass the attention mask explicitly; decoder-only models
                    # otherwise warn when pad_token_id == eos_token_id.
                    attention_mask=inputs["attention_mask"],
                    max_new_tokens=max_length,
                    do_sample=False,
                    num_beams=4,
                    early_stopping=True,
                    pad_token_id=self.tokenizer.eos_token_id
                )
        return self.tokenizer.decode(output[0], skip_special_tokens=True)
    def summarize_code(self, code, summarize_functions=True, summarize_classes=True):
        """Summarize the whole file, then each extracted function and class."""
        preprocessed_code = self.preprocess_code(code)
        results = {
            "file_summary": None,
            "function_summaries": {},
            "class_summaries": {}
        }
        try:
            results["file_summary"] = self.summarize(preprocessed_code)
        except Exception as e:
            results["file_summary"] = f"Error generating file summary: {str(e)}"
        if summarize_functions:
            for function_name, function_code in self.extract_functions(preprocessed_code):
                try:
                    results["function_summaries"][function_name] = self.summarize(function_code)
                except Exception as e:
                    results["function_summaries"][function_name] = f"Error: {str(e)}"
        if summarize_classes:
            for class_name, class_code in self.extract_classes(preprocessed_code):
                try:
                    results["class_summaries"][class_name] = self.summarize(class_code)
                except Exception as e:
                    results["class_summaries"][class_name] = f"Error: {str(e)}"
        return results
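
# =============================
# Minimal usage sketch: assumes HF_TOKEN is exported and the selected
# checkpoint can be downloaded from the Hugging Face Hub.
# =============================
if __name__ == "__main__":
    sample = '''
def add(a, b):
    """Return the sum of a and b."""
    return a + b
'''
    summarizer = CodeT5Summarizer(MODEL_OPTIONS["CodeT5 Base (multi-sum)"])
    results = summarizer.summarize_code(sample)
    print("File summary:", results["file_summary"])
    for name, summary in results["function_summaries"].items():
        print(f"{name}: {summary}")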