from clearml import Model
import torch
import os
# Import needed classes for local loading and LoRA construction
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from peft import LoraConfig, get_peft_model

# 1. Download the LoRA checkpoint artifact from ClearML
CLEARML_MODEL_ID = "34e25deb24c64b74b29c8519ed15fe3e"
model_obj = Model(model_id=CLEARML_MODEL_ID)
checkpoint_path = model_obj.get_local_copy()
adapter_dir = os.path.dirname(checkpoint_path)
print(f"LoRA checkpoint downloaded to: {checkpoint_path}")

# 2. Load the base pretrained CodeReviewer (CodeT5-based) model and tokenizer from the Hugging Face Hub
BASE_MODEL_PATH = "microsoft/codereviewer"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)
base_model = AutoModelForSeq2SeqLM.from_pretrained(BASE_MODEL_PATH)

# Print all base model parameters and their shapes
print("\nBase model parameters:")
for name, param in base_model.named_parameters():
    print(f"{name}: {tuple(param.shape)}")

# 3. Reconstruct and attach LoRA adapters
lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    target_modules=["q", "k", "v", "o", "wi", "wo"],
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM"
)
model = get_peft_model(base_model, lora_config)

# 4. Load LoRA adapter weights from ClearML checkpoint
adapter_state = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(adapter_state, strict=False)
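
# (Optional sketch, not part of the original script) strict=False silently ignores
# any keys that do not match, so it is possible to "load" an adapter without applying
# any weights. One cheap sanity check: PEFT initialises lora_B matrices to zeros, so
# a nonzero lora_B tensor implies the checkpoint weights were actually applied.
lora_b_loaded = any(
    param.abs().sum().item() > 0
    for name, param in model.named_parameters()
    if "lora_B" in name
)
print(f"LoRA adapter weights appear loaded: {lora_b_loaded}")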

# 5. Move to CPU and set evaluation mode
model.to("cpu").eval()
print("Model with LoRA adapters loaded and ready for inference.")

# Print out all LoRA adapter parameter names and shapes
print("\nFinetuned (LoRA adapter) parameters:")
for name, param in model.named_parameters():
    if "lora_" in name:
        print(f"{name}: {tuple(param.shape)}")