from fastapi import FastAPI, Request, Form from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse from pydantic import BaseModel from typing import List from clearml import Model import torch from configs import add_args from models import build_or_load_gen_model import argparse from argparse import Namespace import os from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig MAX_SOURCE_LENGTH = 512 def pad_assert(tokenizer, source_ids): source_ids = source_ids[:MAX_SOURCE_LENGTH - 2] source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id] pad_len = MAX_SOURCE_LENGTH - len(source_ids) source_ids += [tokenizer.pad_id] * pad_len assert len(source_ids) == MAX_SOURCE_LENGTH, "Not equal length." return source_ids # Encode code content and comment into model input def encode_diff(tokenizer, code, comment): # Tokenize code file content code_ids = tokenizer.encode(code, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1] # Tokenize comment comment_ids = tokenizer.encode(comment, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1] # Concatenate: [BOS] + code + [EOS] + [msg_id] + comment source_ids = [tokenizer.bos_id] + code_ids + [tokenizer.eos_id] source_ids += [tokenizer.msg_id] + comment_ids # Pad/truncate to fixed length source_ids = source_ids[:MAX_SOURCE_LENGTH - 2] source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id] pad_len = MAX_SOURCE_LENGTH - len(source_ids) source_ids += [tokenizer.pad_id] * pad_len assert len(source_ids) == MAX_SOURCE_LENGTH, "Not equal length." return source_ids # Load base model architecture and tokenizer from HuggingFace BASE_MODEL_NAME = "microsoft/codereviewer" args = Namespace( model_name_or_path=BASE_MODEL_NAME, load_model_path=None, # Add other necessary default arguments if build_or_load_gen_model requires them ) print(f"Loading base model architecture and tokenizer from: {BASE_MODEL_NAME}") config, base_model, tokenizer = build_or_load_gen_model(args) print("Base model architecture and tokenizer loaded.") # Download the fine-tuned weights from ClearML CLEARML_MODEL_ID = "34e25deb24c64b74b29c8519ed15fe3e" model_obj = Model(model_id=CLEARML_MODEL_ID) finetuned_weights_path = model_obj.get_local_copy() adapter_dir = os.path.dirname(finetuned_weights_path) print(f"Fine-tuned adapter weights downloaded to directory: {adapter_dir}") # Create LoRA configuration matching the fine-tuned checkpoint lora_cfg = LoraConfig( r=64, lora_alpha=128, target_modules=["q", "wo", "wi", "v", "o", "k"], lora_dropout=0.05, bias="none", task_type="SEQ_2_SEQ_LM" ) # Wrap base model with PEFT LoRA peft_model = get_peft_model(base_model, lora_cfg) # Load adapter-only weights and merge into base adapter_state = torch.load(finetuned_weights_path, map_location="cpu") peft_model.load_state_dict(adapter_state, strict=False) model = peft_model.merge_and_unload() print("Merged base model with LoRA adapters.") model.to("cpu") model.eval() print("Model ready for inference.") app = FastAPI() last_payload = {"comment": "", "files": []} last_infer_result = {"generated_code": ""} class FileContent(BaseModel): filename: str content: str class PRPayload(BaseModel): comment: str files: List[FileContent] class InferenceRequest(BaseModel): comment: str files: List[FileContent] @app.get("/") def root(): return {"message": "FastAPI PR comment service is running"} @app.post("/pr-comments") async def receive_pr_comment(payload: PRPayload): global last_payload last_payload = payload.dict() # Return the received payload as JSON and also redirect to /show return JSONResponse(content={"status": "received", "payload": last_payload, "redirect": "/show"}) @app.get("/show", response_class=HTMLResponse) def show_last_comment(): html = f"

Received Comment

{last_payload['comment']}


" for file in last_payload["files"]: html += f"

{file['filename']}

{file['content']}

" return html @app.post("/infer") async def infer(request: InferenceRequest): global last_infer_result print("[DEBUG] Received /infer request with:", request.dict()) code = request.files[0].content if request.files else "" source_ids = encode_diff(tokenizer, code, request.comment) # print("[DEBUG] source_ids:", source_ids) #tokens = [tokenizer.decode([sid], skip_special_tokens=False) for sid in source_ids] #print("[DEBUG] tokens:", tokens) inputs = torch.tensor([source_ids], dtype=torch.long) inputs_mask = inputs.ne(tokenizer.pad_id) preds = model.generate( inputs, attention_mask=inputs_mask, use_cache=True, num_beams=5, early_stopping=True, max_length=100, num_return_sequences=1 ) pred = preds[0].cpu().numpy() pred_nl = tokenizer.decode(pred[2:], skip_special_tokens=True, clean_up_tokenization_spaces=False) last_infer_result = {"generated_code": pred_nl} return last_infer_result @app.get("/show-infer", response_class=HTMLResponse) def show_infer_result(): html = f"

Generated Message

{last_infer_result['generated_code']}
" return html if __name__ == "__main__": # Place any CLI/training logic here if needed # This block is NOT executed when running with uvicorn pass