# codereviewer / fastapi_app.py
# Hugging Face Space by shekkari21 — commit 2a21e9f ("Deploy to HF Space")
from fastapi import FastAPI, Request, Form
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
from pydantic import BaseModel
from typing import List
from clearml import Model
import torch
from configs import add_args
from models import build_or_load_gen_model
import argparse
from argparse import Namespace
import os
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig
# Fixed encoder input length, in token ids. pad_assert and encode_diff clip
# every sequence to this length and right-pad shorter ones with the pad id.
MAX_SOURCE_LENGTH = 512
def pad_assert(tokenizer, source_ids):
    """Frame ``source_ids`` with BOS/EOS and pad it to ``MAX_SOURCE_LENGTH``.

    The content is clipped to ``MAX_SOURCE_LENGTH - 2`` ids so the two framing
    tokens always fit. Any remaining room is filled with ``tokenizer.pad_id``.
    A new list is returned and the input list is left untouched.

    Args:
        tokenizer: object exposing ``bos_id``, ``eos_id`` and ``pad_id``.
        source_ids: list of token ids to frame.

    Returns:
        list of exactly ``MAX_SOURCE_LENGTH`` token ids.
    """
    body_budget = MAX_SOURCE_LENGTH - 2  # leave room for BOS and EOS
    framed = [tokenizer.bos_id] + source_ids[:body_budget] + [tokenizer.eos_id]
    filler = [tokenizer.pad_id] * (MAX_SOURCE_LENGTH - len(framed))
    padded = framed + filler
    assert len(padded) == MAX_SOURCE_LENGTH, "Not equal length."
    return padded
def encode_diff(tokenizer, code, comment):
    """Encode a code snippet and a review comment into one fixed-length input.

    The sequence assembled here is::

        [BOS] code_ids [EOS] [msg_id] comment_ids

    It is then passed to ``pad_assert``. That helper clips it to
    ``MAX_SOURCE_LENGTH - 2`` ids, wraps it in another BOS/EOS pair and pads it
    with ``tokenizer.pad_id``.

    Args:
        tokenizer: tokenizer exposing ``encode`` plus the ``bos_id``,
            ``eos_id``, ``msg_id`` and ``pad_id`` attributes used below.
        code: source text the comment refers to.
        comment: reviewer comment text.

    Returns:
        list of exactly ``MAX_SOURCE_LENGTH`` token ids.
    """
    # ``encode`` is assumed to add its own special tokens at both ends; [1:-1]
    # strips them so the layout above decides where BOS/EOS go.
    code_ids = tokenizer.encode(code, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
    comment_ids = tokenizer.encode(comment, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]

    source_ids = [tokenizer.bos_id] + code_ids + [tokenizer.eos_id]
    source_ids += [tokenizer.msg_id] + comment_ids

    # NOTE(review): pad_assert frames the sequence a second time, producing
    # [BOS, BOS, code..., EOS, msg, comment..., EOS, PAD...]. Confirm this
    # matches the input format the fine-tuned checkpoint was trained on.
    # NOTE(review): code_ids alone can already fill the window, so for long
    # ``code`` the comment is clipped away entirely by pad_assert.
    return pad_assert(tokenizer, source_ids)
# Load the base model architecture and tokenizer from HuggingFace.
# This runs at import time, before the FastAPI app below is created.
BASE_MODEL_NAME = "microsoft/codereviewer"
# NOTE(review): build_or_load_gen_model is project code whose required fields are
# not visible here; only the two attributes below are set on the namespace.
args = Namespace(
    model_name_or_path=BASE_MODEL_NAME,
    # presumably None means "no extra checkpoint on top of the base" — TODO confirm
    load_model_path=None,
    # Add other necessary default arguments if build_or_load_gen_model requires them
)
print(f"Loading base model architecture and tokenizer from: {BASE_MODEL_NAME}")
# Unpacked as (config, model, tokenizer). The tokenizer must expose bos_id,
# eos_id, msg_id and pad_id, which the encoding helpers in this file read.
config, base_model, tokenizer = build_or_load_gen_model(args)
print("Base model architecture and tokenizer loaded.")
# Download the fine-tuned (LoRA adapter) weights from ClearML.
# Like the base-model load above, this runs at import time.
CLEARML_MODEL_ID = "34e25deb24c64b74b29c8519ed15fe3e"
model_obj = Model(model_id=CLEARML_MODEL_ID)
# Path to the local weights file; it is later passed straight to torch.load.
finetuned_weights_path = model_obj.get_local_copy()
# The containing directory is only used for the log line below.
adapter_dir = os.path.dirname(finetuned_weights_path)
print(f"Fine-tuned adapter weights downloaded to directory: {adapter_dir}")
# Create the LoRA configuration matching the fine-tuned checkpoint. The rank and
# target modules must agree with the saved adapter, otherwise its tensors will
# not line up with the LoRA layers injected below.
lora_cfg = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=["q", "wo", "wi", "v", "o", "k"],
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM",
)

# Wrap the base model with PEFT LoRA layers.
peft_model = get_peft_model(base_model, lora_cfg)

# Load the adapter-only weights.
# NOTE(review): torch.load unpickles the file. That is acceptable only because it
# comes from this project's own ClearML registry; never point it at untrusted data.
adapter_state = torch.load(finetuned_weights_path, map_location="cpu")

# strict=False is required because the checkpoint holds adapter tensors only, so
# every base-model weight is reported as "missing". Unexpected keys, however, are
# adapter tensors whose names matched no parameter and were silently discarded.
# If all of them are discarded, the service would run the plain base model.
load_result = peft_model.load_state_dict(adapter_state, strict=False)
if load_result.unexpected_keys:
    print(
        f"WARNING: {len(load_result.unexpected_keys)} of {len(adapter_state)} "
        "adapter tensors did not match the LoRA model and were ignored, e.g. "
        f"{list(load_result.unexpected_keys)[:3]}"
    )

# Fold the LoRA deltas into the base weights and drop the PEFT wrapper.
model = peft_model.merge_and_unload()
print("Merged base model with LoRA adapters.")
model.to("cpu")
model.eval()
print("Model ready for inference.")
app = FastAPI()

# Process-local "most recent" state. Handlers rebind these via `global` on each
# POST, and the /show pages read them back. Nothing is persisted.
last_payload = dict(comment="", files=[])
last_infer_result = dict(generated_code="")
class FileContent(BaseModel):
    """One file sent by the client: its name and its text content."""

    filename: str
    content: str
class PRPayload(BaseModel):
    """Request body of POST /pr-comments: a review comment and a list of files."""

    comment: str
    files: List[FileContent]
class InferenceRequest(BaseModel):
    """Request body of POST /infer: a review comment and a list of files.

    The /infer handler encodes only the first file's content; extra files are
    accepted but ignored.
    """

    comment: str
    files: List[FileContent]
@app.get("/")
def root():
    """Liveness probe: confirm the service is up."""
    status_message = "FastAPI PR comment service is running"
    return {"message": status_message}
@app.post("/pr-comments")
async def receive_pr_comment(payload: PRPayload):
    """Remember the latest PR comment payload and echo it back as JSON.

    No HTTP redirect happens. The ``"redirect"`` field merely tells the client
    that the stored payload can be viewed at ``/show``.
    """
    global last_payload
    stored = payload.dict()
    last_payload = stored
    body = {"status": "received", "payload": stored, "redirect": "/show"}
    return JSONResponse(content=body)
@app.get("/show", response_class=HTMLResponse)
def show_last_comment():
    """Render the most recently received PR comment and its files as HTML.

    The comment, filenames and file contents arrive verbatim from the client via
    POST /pr-comments. They are HTML-escaped before being embedded, because
    otherwise any markup they contain (e.g. ``<script>``) would execute in the
    viewer's browser, and code containing ``<`` would render incorrectly.
    """
    from html import escape

    parts = [f"<h2>Received Comment</h2><p>{escape(last_payload['comment'])}</p><hr>"]
    parts.extend(
        f"<h3>{escape(file['filename'])}</h3><pre>{escape(file['content'])}</pre><hr>"
        for file in last_payload["files"]
    )
    return "".join(parts)
def _generate_refinement(code, comment):
    """Run the merged model on (code, comment) and return the decoded output text.

    Synchronous and CPU-bound; ``infer`` runs it in a worker thread.
    """
    source_ids = encode_diff(tokenizer, code, comment)
    inputs = torch.tensor([source_ids], dtype=torch.long)
    # Mask out padding positions so the encoder ignores them.
    inputs_mask = inputs.ne(tokenizer.pad_id)
    preds = model.generate(
        inputs,
        attention_mask=inputs_mask,
        use_cache=True,
        num_beams=5,
        early_stopping=True,
        max_length=100,
        num_return_sequences=1,
    )
    pred = preds[0].cpu().numpy()
    # The first two generated ids are skipped before decoding, as in the original
    # code (presumably decoder-start/BOS tokens — TODO confirm).
    return tokenizer.decode(pred[2:], skip_special_tokens=True, clean_up_tokenization_spaces=False)


@app.post("/infer")
async def infer(request: InferenceRequest):
    """Generate model output for the request's comment and first file.

    Stores the result as the latest inference result (shown at /show-infer) and
    returns it as ``{"generated_code": <text>}``.
    """
    global last_infer_result
    from fastapi.concurrency import run_in_threadpool

    print("[DEBUG] Received /infer request with:", request.dict())
    # Only the first file is encoded; any additional files are ignored.
    code = request.files[0].content if request.files else ""
    # model.generate blocks for the whole beam search. Calling it directly in
    # this coroutine would freeze the event loop, and every other request, until
    # it finished, so it is offloaded to the thread pool instead.
    pred_nl = await run_in_threadpool(_generate_refinement, code, request.comment)
    last_infer_result = {"generated_code": pred_nl}
    return last_infer_result
@app.get("/show-infer", response_class=HTMLResponse)
def show_infer_result():
    """Render the most recent /infer output as HTML.

    The generated text is HTML-escaped. It is model output conditioned on
    client-supplied code and comments, so it may contain markup that must not be
    interpreted by the browser, and generated code routinely contains ``<``.
    """
    from html import escape

    generated = escape(last_infer_result["generated_code"])
    return f"<h2>Generated Message</h2><pre>{generated}</pre>"
if __name__ == "__main__":
    # Intentionally a no-op when the file is executed directly.
    # Place any CLI/training logic here if needed.
    # This block is NOT executed when running with uvicorn, which imports this
    # file as a module (so __name__ is the module name, not "__main__").
    pass