from fastapi import FastAPI
from fastapi.responses import HTMLResponse, JSONResponse
from pydantic import BaseModel
from typing import List
from clearml import Model
import torch
from models import build_or_load_gen_model
from argparse import Namespace
from html import escape
import os
from peft import get_peft_model, LoraConfig

MAX_SOURCE_LENGTH = 512

# Truncate, wrap with BOS/EOS, and right-pad the token ids to exactly MAX_SOURCE_LENGTH
def pad_assert(tokenizer, source_ids):
    source_ids = source_ids[:MAX_SOURCE_LENGTH - 2]
    source_ids = [tokenizer.bos_id] + source_ids + [tokenizer.eos_id]
    pad_len = MAX_SOURCE_LENGTH - len(source_ids)
    source_ids += [tokenizer.pad_id] * pad_len
    assert len(source_ids) == MAX_SOURCE_LENGTH, "Not equal length."
    return source_ids

# Encode code content and comment into model input
def encode_diff(tokenizer, code, comment):
    # Tokenize code file content
    code_ids = tokenizer.encode(code, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
    # Tokenize comment
    comment_ids = tokenizer.encode(comment, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
    # Concatenate: [BOS] + code + [EOS] + [msg_id] + comment
    source_ids = [tokenizer.bos_id] + code_ids + [tokenizer.eos_id]
    source_ids += [tokenizer.msg_id] + comment_ids
    # Pad/truncate to a fixed length (re-adds BOS/EOS, then pads to MAX_SOURCE_LENGTH)
    return pad_assert(tokenizer, source_ids)
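# Illustrative usage (assumes the `tokenizer` loaded below; the exact ids depend on its vocabulary):
#   ids = encode_diff(tokenizer, "def add(a, b):\n    return a + b", "Please add type hints")
#   assert len(ids) == MAX_SOURCE_LENGTH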

# Load base model architecture and tokenizer from HuggingFace
BASE_MODEL_NAME = "microsoft/codereviewer"
args = Namespace(
    model_name_or_path=BASE_MODEL_NAME,
    load_model_path=None,
    # Add other necessary default arguments if build_or_load_gen_model requires them
)
print(f"Loading base model architecture and tokenizer from: {BASE_MODEL_NAME}")
config, base_model, tokenizer = build_or_load_gen_model(args)
print("Base model architecture and tokenizer loaded.")

# Download the fine-tuned weights from ClearML
CLEARML_MODEL_ID = "34e25deb24c64b74b29c8519ed15fe3e"
model_obj = Model(model_id=CLEARML_MODEL_ID)
finetuned_weights_path = model_obj.get_local_copy()
adapter_dir = os.path.dirname(finetuned_weights_path)

print(f"Fine-tuned adapter weights downloaded to directory: {adapter_dir}")

# Create LoRA configuration matching the fine-tuned checkpoint
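# target_modules name the T5-style attention projections (q, k, v, o) and feed-forward
# layers (wi, wo) of the codereviewer backbone; r and target_modules must match the
# fine-tuning run so the adapter tensor shapes line up with the saved checkpoint.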
lora_cfg = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=["q", "wo", "wi", "v", "o", "k"],
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM"
)
# Wrap base model with PEFT LoRA
peft_model = get_peft_model(base_model, lora_cfg)
# Load adapter-only weights and merge into base
adapter_state = torch.load(finetuned_weights_path, map_location="cpu")
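# strict=False below: the checkpoint holds only the LoRA adapter tensors, not the full
# model state, so missing base-model keys are expected when loading.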
peft_model.load_state_dict(adapter_state, strict=False)
model = peft_model.merge_and_unload()
print("Merged base model with LoRA adapters.")

model.to("cpu")
model.eval()
print("Model ready for inference.")

app = FastAPI()

last_payload = {"comment": "", "files": []}
last_infer_result = {"generated_code": ""}

class FileContent(BaseModel):
    filename: str
    content: str

class PRPayload(BaseModel):
    comment: str
    files: List[FileContent]

class InferenceRequest(BaseModel):
    comment: str
    files: List[FileContent]


@app.get("/")
def root():
    return {"message": "FastAPI PR comment service is running"}

@app.post("/pr-comments")
async def receive_pr_comment(payload: PRPayload):
    global last_payload
    last_payload = payload.dict()
    # Echo the received payload as JSON; the "redirect" field points clients to /show
    return JSONResponse(content={"status": "received", "payload": last_payload, "redirect": "/show"})
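# Example request (illustrative payload; host/port depend on how the app is served):
#   curl -X POST http://localhost:8000/pr-comments \
#        -H "Content-Type: application/json" \
#        -d '{"comment": "Please rename this variable", "files": [{"filename": "app.py", "content": "x = 1"}]}'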

@app.get("/show", response_class=HTMLResponse)
def show_last_comment():
    html = f"<h2>Received Comment</h2><p>{last_payload['comment']}</p><hr>"
    for file in last_payload["files"]:
        html += f"<h3>{file['filename']}</h3><pre>{file['content']}</pre><hr>"
    return html

@app.post("/infer")
async def infer(request: InferenceRequest):
    global last_infer_result
    print("[DEBUG] Received /infer request with:", request.dict())
    
    code = request.files[0].content if request.files else ""
    source_ids = encode_diff(tokenizer, code, request.comment)
    # print("[DEBUG] source_ids:", source_ids)
    #tokens = [tokenizer.decode([sid], skip_special_tokens=False) for sid in source_ids]
    #print("[DEBUG] tokens:", tokens)
    inputs = torch.tensor([source_ids], dtype=torch.long)
    inputs_mask = inputs.ne(tokenizer.pad_id)
    
    preds = model.generate(
        inputs,
        attention_mask=inputs_mask,
        use_cache=True,
        num_beams=5,
        early_stopping=True,
        max_length=100,
        num_return_sequences=1
    )
    
    pred = preds[0].cpu().numpy()
    # Drop the first two ids (special tokens emitted by generation) before decoding
    pred_nl = tokenizer.decode(pred[2:], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    last_infer_result = {"generated_code": pred_nl}
    return last_infer_result
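# Example request (same payload shape as /pr-comments; the response is {"generated_code": ...}):
#   curl -X POST http://localhost:8000/infer \
#        -H "Content-Type: application/json" \
#        -d '{"comment": "Fix the off-by-one error", "files": [{"filename": "loop.py", "content": "for i in range(n - 1): ..."}]}'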

@app.get("/show-infer", response_class=HTMLResponse)
def show_infer_result():
    html = f"<h2>Generated Message</h2><pre>{last_infer_result['generated_code']}</pre>"
    return html

if __name__ == "__main__":
    # Place any CLI/training logic here if needed
    # This block is NOT executed when running with uvicorn
    pass
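    # Minimal sketch (assumes uvicorn is installed): run a local development server
    # when this file is executed directly; deployments typically launch the app via
    # `uvicorn <module>:app` instead.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)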