from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
from pydantic import BaseModel
from typing import List
from clearml import Model
import torch
from configs import add_args
from models import build_or_load_gen_model
import argparse
from argparse import Namespace
import os
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig
MAX_SOURCE_LENGTH = 512

def pad_assert(tokenizer, source_ids):
    # Truncate, wrap with BOS/EOS, and right-pad to the fixed input length
    source_ids = source_ids[:MAX_SOURCE_LENGTH - 2]
    source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
    pad_len = MAX_SOURCE_LENGTH - len(source_ids)
    source_ids += [tokenizer.pad_id] * pad_len
    assert len(source_ids) == MAX_SOURCE_LENGTH, "Not equal length."
    return source_ids
# Encode code content and comment into model input
def encode_diff(tokenizer, code, comment):
    # Tokenize code file content
    code_ids = tokenizer.encode(code, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
    # Tokenize comment
    comment_ids = tokenizer.encode(comment, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
    # Concatenate: [BOS] + code + [EOS] + [msg_id] + comment
    source_ids = [tokenizer.bos_id] + code_ids + [tokenizer.eos_id]
    source_ids += [tokenizer.msg_id] + comment_ids
    # Truncate, re-wrap with BOS/EOS, and pad to the fixed length (same logic as pad_assert)
    return pad_assert(tokenizer, source_ids)
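# A minimal usage sketch for encode_diff; the code/comment strings below are made-up examples:
#   ids = encode_diff(tokenizer, "def add(a, b):\n    return a + b", "Please add type hints")
#   assert len(ids) == MAX_SOURCE_LENGTH  # always padded/truncated to 512 ids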
# Load base model architecture and tokenizer from HuggingFace
BASE_MODEL_NAME = "microsoft/codereviewer"

args = Namespace(
    model_name_or_path=BASE_MODEL_NAME,
    load_model_path=None,
    # Add other necessary default arguments if build_or_load_gen_model requires them
)
print(f"Loading base model architecture and tokenizer from: {BASE_MODEL_NAME}")
config, base_model, tokenizer = build_or_load_gen_model(args)
print("Base model architecture and tokenizer loaded.")
# Download the fine-tuned weights from ClearML
CLEARML_MODEL_ID = "34e25deb24c64b74b29c8519ed15fe3e"
model_obj = Model(model_id=CLEARML_MODEL_ID)
finetuned_weights_path = model_obj.get_local_copy()
adapter_dir = os.path.dirname(finetuned_weights_path)
print(f"Fine-tuned adapter weights downloaded to directory: {adapter_dir}")
# Create LoRA configuration matching the fine-tuned checkpoint
lora_cfg = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=["q", "wo", "wi", "v", "o", "k"],
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
)
# Wrap base model with PEFT LoRA
peft_model = get_peft_model(base_model, lora_cfg)

# Load adapter-only weights and merge into base
adapter_state = torch.load(finetuned_weights_path, map_location="cpu")
peft_model.load_state_dict(adapter_state, strict=False)
model = peft_model.merge_and_unload()
print("Merged base model with LoRA adapters.")

model.to("cpu")
model.eval()
print("Model ready for inference.")
app = FastAPI()

# In-memory state holding the most recent webhook payload and inference result
last_payload = {"comment": "", "files": []}
last_infer_result = {"generated_code": ""}

class FileContent(BaseModel):
    filename: str
    content: str

class PRPayload(BaseModel):
    comment: str
    files: List[FileContent]

class InferenceRequest(BaseModel):
    comment: str
    files: List[FileContent]
@app.get("/")
def root():
    return {"message": "FastAPI PR comment service is running"}
# Route decorator restored; the exact path was not recoverable, "/pr-comment" is an assumption
@app.post("/pr-comment")
async def receive_pr_comment(payload: PRPayload):
    global last_payload
    last_payload = payload.dict()
    # Return the received payload as JSON and also redirect to /show
    return JSONResponse(content={"status": "received", "payload": last_payload, "redirect": "/show"})
@app.get("/show", response_class=HTMLResponse)
def show_last_comment():
    html = f"<h2>Received Comment</h2><p>{last_payload['comment']}</p><hr>"
    for file in last_payload["files"]:
        html += f"<h3>{file['filename']}</h3><pre>{file['content']}</pre><hr>"
    return html
@app.post("/infer")
async def infer(request: InferenceRequest):
    global last_infer_result
    print("[DEBUG] Received /infer request with:", request.dict())
    code = request.files[0].content if request.files else ""
    source_ids = encode_diff(tokenizer, code, request.comment)
    # print("[DEBUG] source_ids:", source_ids)
    # tokens = [tokenizer.decode([sid], skip_special_tokens=False) for sid in source_ids]
    # print("[DEBUG] tokens:", tokens)
    inputs = torch.tensor([source_ids], dtype=torch.long)
    inputs_mask = inputs.ne(tokenizer.pad_id)
    preds = model.generate(
        inputs,
        attention_mask=inputs_mask,
        use_cache=True,
        num_beams=5,
        early_stopping=True,
        max_length=100,
        num_return_sequences=1,
    )
    pred = preds[0].cpu().numpy()
    # Drop the leading special ids before decoding the generated text
    pred_nl = tokenizer.decode(pred[2:], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    last_infer_result = {"generated_code": pred_nl}
    return last_infer_result
# Route decorator restored; the exact path was not recoverable, "/show-infer" is an assumption
@app.get("/show-infer", response_class=HTMLResponse)
def show_infer_result():
    html = f"<h2>Generated Message</h2><pre>{last_infer_result['generated_code']}</pre>"
    return html
if __name__ == "__main__":
    # Place any CLI/training logic here if needed
    # This block is NOT executed when running with uvicorn
    pass
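# How to run and exercise the service, as a sketch; the module name "app" and the port are assumptions:
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
# Example request to the /infer endpoint (field names match InferenceRequest; values are made up):
#   curl -X POST http://localhost:8000/infer \
#        -H "Content-Type: application/json" \
#        -d '{"comment": "Please refactor this loop", "files": [{"filename": "main.py", "content": "for i in range(10): print(i)"}]}'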