import argparse
import sys

import torch
import torch.nn as nn
import yaml

# These path tweaks must run before the project imports below so that the
# vendored `mast3r` package (under submodules/) and `src` can be resolved.
sys.path.append(".")
sys.path.append("submodules")
sys.path.append("submodules/mast3r")

from mast3r.model import AsymmetricMASt3R
from src.ptv3 import PTV3
from src.gaussian_head import GaussianHead
from src.utils.points_process import merge_points
from src.losses import GaussianLoss
from src.lseg import LSegFeatureExtractor


class LSM_MASt3R(nn.Module):
    """Combine a frozen MASt3R, a point transformer, a Gaussian head and frozen LSeg features.

    The MASt3R backbone and the LSeg feature extractor are loaded from
    pretrained weights, frozen, and excluded from this module's checkpoints
    (see ``state_dict`` / ``load_state_dict``).
    """

    # Key prefixes of the frozen sub-modules that are never saved to, nor
    # restored from, this model's own checkpoints.
    _FROZEN_PREFIXES = ('mast3r.', 'lseg_feature_extractor.')

    def __init__(self,
                 mast3r_config,
                 point_transformer_config,
                 gaussian_head_config,
                 lseg_config,
                 ):
        """Build all sub-modules from their configuration dicts.

        Args:
            mast3r_config: keyword arguments for ``AsymmetricMASt3R.from_pretrained``.
            point_transformer_config: keyword arguments for ``PTV3``.
            gaussian_head_config: keyword arguments for ``GaussianHead``; its
                optional ``'d_gs_feats'`` entry (default 32) also sets the
                channel width of the reduced LSeg features.
            lseg_config: keyword arguments for ``LSegFeatureExtractor.from_pretrained``.
        """
        super().__init__()

        # Keep the raw configuration around for later inspection.
        self.config = {
            'mast3r_config': mast3r_config,
            'point_transformer_config': point_transformer_config,
            'gaussian_head_config': gaussian_head_config,
            'lseg_config': lseg_config
        }

        # Pretrained MASt3R backbone, frozen and put in eval mode.
        # NOTE(review): a later `model.train()` call switches this sub-module
        # back to train mode; confirm that is acceptable for the backbone.
        self.mast3r = AsymmetricMASt3R.from_pretrained(**mast3r_config)
        for param in self.mast3r.parameters():
            param.requires_grad = False
        self.mast3r.eval()

        # Trainable point transformer and Gaussian head.
        self.point_transformer = PTV3(**point_transformer_config)
        self.gaussian_head = GaussianHead(**gaussian_head_config)

        # Pretrained LSeg feature extractor, frozen and put in eval mode.
        self.lseg_feature_extractor = LSegFeatureExtractor.from_pretrained(**lseg_config)
        for param in self.lseg_feature_extractor.parameters():
            param.requires_grad = False
        self.lseg_feature_extractor.eval()

        # Two 1x1-convolution + resampling stacks (not linear layers) that map
        # between the 512-channel feature space and the Gaussian feature width.
        d_gs_feats = gaussian_head_config.get('d_gs_feats', 32)
        self.feature_reduction = nn.Sequential(
            nn.Conv2d(512, d_gs_feats, kernel_size=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )  # (b, 512, h//2, w//2) -> (b, d_features, h, w)

        self.feature_expansion = nn.Sequential(
            nn.Conv2d(d_gs_feats, 512, kernel_size=1),
            nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=True)
        )  # (b, d_features, h, w) -> (b, 512, h//2, w//2)

    def forward(self, view1, view2):
        """Run the full pipeline on a pair of views.

        Args:
            view1: first view; must provide an ``'img'`` entry (other entries
                are consumed by ``self.mast3r`` / ``merge_points``).
            view2: second view, same structure as ``view1``.

        Returns:
            Whatever ``self.gaussian_head`` returns for the point-transformer
            output and the reduced LSeg features.
        """
        # AsymmetricMASt3R forward pass
        mast3r_output = self.mast3r(view1, view2)

        # Merge the points predicted for the two views
        data_dict = merge_points(mast3r_output, view1, view2)

        # PointTransformerV3 forward pass
        point_transformer_output = self.point_transformer(data_dict)

        # Semantic features for both views
        lseg_features = self.extract_lseg_features(view1, view2)

        # Gaussian head forward pass
        final_output = self.gaussian_head(point_transformer_output, lseg_features)

        return final_output

    def extract_lseg_features(self, view1, view2):
        """Extract LSeg features for both views and reduce their channel count.

        Args:
            view1: first view, providing an ``'img'`` tensor.
            view2: second view, providing an ``'img'`` tensor.

        Returns:
            The output of ``self.feature_reduction`` applied to the extracted
            features of the two images stacked along dim 0.
        """
        # Stack the two views along the batch dimension
        img = torch.cat([view1['img'], view2['img']], dim=0)  # (v*b, 3, h, w)
        # Extract features
        lseg_features = self.lseg_feature_extractor.extract_features(img)  # (v*b, 512, h//2, w//2)
        # Reduce dimensions
        lseg_features = self.feature_reduction(lseg_features)  # (v*b, d_features, h, w)

        return lseg_features

    @staticmethod
    def from_pretrained(checkpoint_path, device='cuda'):
        """Rebuild a model from a training checkpoint and load its weights.

        The checkpoint must contain ``'args'`` (an object whose ``model``
        attribute is a Python expression constructing the model) and
        ``'model'`` (the state dict to load).

        Args:
            checkpoint_path: path of the checkpoint file.
            device: device the model is moved to after loading.

        Returns:
            The constructed model, on ``device``.
        """
        # Load the checkpoint.
        # SECURITY: torch.load unpickles the file; only use trusted checkpoints.
        ckpt = torch.load(checkpoint_path, map_location='cpu')

        # Extract the configuration from the checkpoint
        config = ckpt['args']

        # Create a new instance from the constructor expression stored in the
        # checkpoint.
        # SECURITY: eval() executes arbitrary code taken from the checkpoint;
        # never call this on a checkpoint from an untrusted source.
        model = eval(config.model)

        # Load the state dict
        model.load_state_dict(ckpt['model'])

        # Move the model to the specified device
        model = model.to(device)

        return model

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return the state dict without the frozen backbones' entries.

        Args:
            destination, prefix, keep_vars: forwarded unchanged to
                ``nn.Module.state_dict``.

        Returns:
            A plain ``dict`` holding every entry of the full state dict whose
            key does not belong to ``mast3r`` or ``lseg_feature_extractor``.
        """
        # Collect the state dict of every parameter and buffer
        full_state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)

        # Keys are emitted as `prefix + name`, so the frozen prefixes have to
        # carry the caller's prefix too; otherwise a non-empty `prefix` would
        # let the frozen weights slip through the filter.
        frozen_prefixes = tuple(prefix + name for name in self._FROZEN_PREFIXES)

        # Keep only the entries that belong to trainable sub-modules
        trainable_state_dict = {
            k: v for k, v in full_state_dict.items()
            if not k.startswith(frozen_prefixes)
        }

        return trainable_state_dict

    def load_state_dict(self, state_dict, strict=True):
        """Load the trainable weights from ``state_dict``; frozen ones are kept.

        Entries for ``mast3r`` / ``lseg_feature_extractor`` and keys unknown to
        the model are ignored, and keys absent from ``state_dict`` keep their
        current values.

        Args:
            state_dict: mapping from parameter/buffer names to tensors.
            strict: accepted for signature compatibility with
                ``nn.Module.load_state_dict``; it is not enforced, because
                checkpoints intentionally lack the frozen weights.

        Returns:
            The result of ``nn.Module.load_state_dict`` on the merged state.
        """
        # Start from the model's current, complete state dict
        model_state = super().state_dict()

        # Overwrite only the trainable entries that the checkpoint provides
        for k, v in state_dict.items():
            if k in model_state and not k.startswith(self._FROZEN_PREFIXES):
                model_state[k] = v

        # Load the merged state dict and hand the result back to the caller
        return super().load_state_dict(model_state, strict=False)


if __name__ == "__main__":
    from torch.utils.data import DataLoader

    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str)
    args = parser.parse_args()

    # Load config
    with open("configs/model_config.yaml", "r") as f:
        config = yaml.safe_load(f)

    # Initialize model
    if args.checkpoint is not None:
        model = LSM_MASt3R.from_pretrained(args.checkpoint, device='cuda')
    else:
        model = LSM_MASt3R(**config).to('cuda')
    model.eval()

    # Print model
    print(model)

    # Load dataset
    from src.datasets.scannet import Scannet
    dataset = Scannet(split='train', ROOT="data/scannet_processed", resolution=[(512, 384)])

    # Print dataset
    print(dataset)

    # Test model
    data_loader = DataLoader(dataset, batch_size=3, shuffle=True)
    data = next(iter(data_loader))

    # Move data to cuda
    for view in data:
        view['img'] = view['img'].to('cuda')
        view['depthmap'] = view['depthmap'].to('cuda')
        view['camera_pose'] = view['camera_pose'].to('cuda')
        view['camera_intrinsics'] = view['camera_intrinsics'].to('cuda')

    # Forward pass
    output = model(*data[:2])

    # Loss
    loss = GaussianLoss()
    loss_value = loss(*data, *output, model)
    print(loss_value)