diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..fea7e72b5ee96b7254a67fa1f4fa4accf07b6c0c
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+.idea
+temp
+temp.py
+weight
diff --git a/README.md b/README.md
index 7687583e0bd7016eb51e87be3a6317b5491ece8c..4782cf6b514b3346a38c6301fcf2e1f9216eee18 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,256 @@
----
-title: MapLocNet
-emoji: 📈
-colorFrom: gray
-colorTo: green
-sdk: gradio
-sdk_version: 3.40.1
-app_file: app.py
-pinned: false
-license: bsd
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+
+
+OrienterNet
+Visual Localization in 2D Public Maps
+with Neural Matching
+
+ Paul-Edouard Sarlin
+ ·
+ Daniel DeTone
+ ·
+ Tsun-Yi Yang
+ ·
+ Armen Avetisyan
+ ·
+ Julian Straub
+
+ Tomasz Malisiewicz
+ ·
+ Samuel Rota Bulo
+ ·
+ Richard Newcombe
+ ·
+ Peter Kontschieder
+ ·
+ Vasileios Balntas
+
+ CVPR 2023
+
+
+
+
+
+
+ OrienterNet is a deep neural network that can accurately localize an image
+using the same 2D semantic maps that humans use to orient themselves.
+
+
+##
+
+This repository hosts the source code for OrienterNet, a research project by Meta Reality Labs. OrienterNet leverages the power of deep learning to provide accurate positioning of images using free and globally-available maps from OpenStreetMap. As opposed to complex existing algorithms that rely on 3D point clouds, OrienterNet estimates a position and orientation by matching a neural Bird's-Eye-View with 2D maps.
+
+## Installation
+
+OrienterNet requires Python >= 3.8 and [PyTorch](https://pytorch.org/). To run the demo, clone this repo and install the minimal requirements:
+
+```bash
+git clone https://github.com/facebookresearch/OrienterNet
+python -m pip install -r requirements/requirements.txt
+```
+
+To run the evaluation and training, install the full requirements:
+
+```bash
+python -m pip install -r requirements/full.txt
+```
+
+## Demo ➡️ [Hugging Face Space](https://sarlinpe-orienternet.hf.space) [Google Colab](https://colab.research.google.com/drive/1zH_2mzdB18BnJVq48ZvJhMorcRjrWAXI?usp=sharing)
+
+Try our minimal demo - take a picture with your phone in any city and find its exact location in a few seconds!
+- [Web demo with Gradio and Huggingface Spaces](https://sarlinpe-orienternet.hf.space)
+- [Cloud demo with Google Colab](https://colab.research.google.com/drive/1zH_2mzdB18BnJVq48ZvJhMorcRjrWAXI?usp=sharing)
+- Local demo with the Jupyter notebook [`demo.ipynb`](./demo.ipynb)
+
+
+
+
+ OrienterNet positions any image within a large area - try it with your own images!
+
+
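+The same pipeline can also be driven from Python. Below is a minimal sketch mirroring the `__main__` block of [`demo.py`](./demo.py) in this repository; the checkpoint path, image path, and GPS prior are placeholders to adapt to your own data:
+
+```python
+from demo import Demo, read_input_image_test
+from osm.tiling import TileManager
+from utils.io import read_image
+
+# load a checkpoint (placeholder path) and build the model
+demo = Demo(experiment_or_path="weight/last-step-checkpointing.ckpt",
+            num_rotations=128, device="cpu")
+
+# read the query image and build an OSM tile around a rough location prior
+image = read_image("images/00000.jpg")
+image, camera, gravity, proj, bbox, prior_latlon = read_input_image_test(
+    image, prior_latlon=(37.757, -122.436), tile_size_meters=128)
+tiler = TileManager.from_bbox(projection=proj, bbox=bbox + 10, ppm=1, tile_size=128)
+canvas = tiler.query(bbox)
+
+# localize and convert the best estimate back to latitude/longitude
+uv, yaw, prob, neural_map, image_rect, data_, pred = demo.localize(image, camera, canvas)
+latlon_pred = proj.unproject(canvas.to_xy(uv))
+```
+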
+## Evaluation
+
+#### Mapillary Geo-Localization dataset
+
+
+
+To obtain the dataset:
+
+1. Create a developer account at [mapillary.com](https://www.mapillary.com/dashboard/developers) and obtain a free access token.
+2. Run the following script to download the data from Mapillary and prepare it:
+
+```bash
+python -m maploc.data.mapillary.prepare --token $YOUR_ACCESS_TOKEN
+```
+
+By default the data is written to the directory `./datasets/MGL/`. Then run the evaluation with the pre-trained model:
+
+```bash
+python -m maploc.evaluation.mapillary --experiment OrienterNet_MGL model.num_rotations=256
+```
+
+This downloads the pre-trained models if necessary. The results should be close to the following:
+
+```
+Recall xy_max_error: [14.37, 48.69, 61.7] at (1, 3, 5) m/°
+Recall yaw_max_error: [20.95, 54.96, 70.17] at (1, 3, 5) m/°
+```
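+
+Here each recall value is the percentage of queries whose position (`xy`) or heading (`yaw`) error is below the corresponding threshold. A minimal sketch of the computation, assuming the per-query errors are available as a plain array (the evaluation code itself uses `torchmetrics` collections):
+
+```python
+import numpy as np
+
+def recall_at(errors, thresholds=(1, 3, 5)):
+    """Percentage of queries whose error is below each threshold."""
+    errors = np.asarray(errors, dtype=float)
+    return [round(100.0 * np.mean(errors <= t), 2) for t in thresholds]
+```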
+
+This requires a GPU with 11GB of memory. If you run into OOM issues, consider reducing the number of rotations (the default is 256):
+
+```bash
+python -m maploc.evaluation.mapillary [...] model.num_rotations=128
+```
+
+To export visualizations for the first 100 examples:
+
+```bash
+python -m maploc.evaluation.mapillary [...] --output_dir ./viz_MGL/ --num 100
+```
+
+To run the evaluation in sequential mode:
+
+```bash
+python -m maploc.evaluation.mapillary --experiment OrienterNet_MGL --sequential model.num_rotations=256
+```
+The results should be close to the following:
+```
+Recall xy_seq_error: [29.73, 73.25, 91.17] at (1, 3, 5) m/°
+Recall yaw_seq_error: [46.55, 88.3, 96.45] at (1, 3, 5) m/°
+```
+The sequential evaluation uses 10 frames by default. To increase this number, add:
+```bash
+python -m maploc.evaluation.mapillary [...] chunking.max_length=20
+```
+
+
+
+
+#### KITTI dataset
+
+
+
+1. Download and prepare the dataset to `./datasets/kitti/`:
+
+```bash
+python -m maploc.data.kitti.prepare
+```
+
+2. Run the evaluation with the model trained on MGL:
+
+```bash
+python -m maploc.evaluation.kitti --experiment OrienterNet_MGL model.num_rotations=256
+```
+
+You should expect the following results:
+
+```
+Recall directional_error: [[50.33, 85.18, 92.73], [24.38, 56.13, 67.98]] at (1, 3, 5) m/°
+Recall yaw_max_error: [29.22, 68.2, 84.49] at (1, 3, 5) m/°
+```
+
+You can similarly export some visual examples:
+
+```bash
+python -m maploc.evaluation.kitti [...] --output_dir ./viz_KITTI/ --num 100
+```
+
+To run in sequential mode:
+```bash
+python -m maploc.evaluation.kitti --experiment OrienterNet_MGL --sequential model.num_rotations=256
+```
+with results:
+```
+Recall directional_seq_error: [[81.94, 97.35, 98.67], [52.57, 95.6, 97.35]] at (1, 3, 5) m/°
+Recall yaw_seq_error: [82.7, 98.63, 99.06] at (1, 3, 5) m/°
+```
+
+
+
+#### Aria Detroit & Seattle
+
+We are currently unable to release the dataset used to evaluate OrienterNet in the CVPR 2023 paper.
+
+## Training
+
+#### MGL dataset
+
+We trained the model on the MGL dataset using 3x 3090 GPUs (24GB VRAM each) and a total batch size of 12 for 340k iterations (about 3-4 days) with the following command:
+
+```bash
+python -m maploc.train experiment.name=OrienterNet_MGL_reproduce
+```
+
+Feel free to use any other experiment name. Configurations are managed by [Hydra](https://hydra.cc/) and [OmegaConf](https://omegaconf.readthedocs.io) so any entry can be overridden from the command line. You may thus reduce the number of GPUs and the batch size via:
+
+```bash
+python -m maploc.train experiment.name=OrienterNet_MGL_reproduce \
+ experiment.gpus=1 data.loading.train.batch_size=4
+```
+
+Be aware that this can reduce the overall performance. The checkpoints are written to `./experiments/experiment_name/`. Then run the evaluation:
+
+```bash
+# the best checkpoint:
+python -m maploc.evaluation.mapillary --experiment OrienterNet_MGL_reproduce
+# a specific checkpoint:
+python -m maploc.evaluation.mapillary \
+ --experiment OrienterNet_MGL_reproduce/checkpoint-step=340000.ckpt
+```
+
+#### KITTI
+
+To fine-tune a trained model on the KITTI dataset:
+
+```bash
+python -m maploc.train experiment.name=OrienterNet_MGL_kitti data=kitti \
+ training.finetune_from_checkpoint='"experiments/OrienterNet_MGL_reproduce/checkpoint-step=340000.ckpt"'
+```
+
+## Interactive development
+
+We provide several visualization notebooks:
+
+- [Visualize predictions on the MGL dataset](./notebooks/visualize_predictions_mgl.ipynb)
+- [Visualize predictions on the KITTI dataset](./notebooks/visualize_predictions_kitti.ipynb)
+- [Visualize sequential predictions](./notebooks/visualize_predictions_sequences.ipynb)
+
+## OpenStreetMap data
+
+
+
+To make sure that the results are consistent over time, we used OSM data downloaded from [Geofabrik](https://download.geofabrik.de/) in November 2021. By default, the dataset scripts `maploc.data.[mapillary,kitti].prepare` download pre-generated raster tiles. If you wish to use different OSM classes, you can pass `--generate_tiles`, which will download and use our prepared raw `.osm` XML files.
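+
+For example, to regenerate the MGL tiles from the raw OSM data instead of downloading the pre-generated rasters (a sketch; use your own Mapillary access token):
+
+```bash
+python -m maploc.data.mapillary.prepare --token $YOUR_ACCESS_TOKEN --generate_tiles
+```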
+
+You may alternatively download more recent files from [Geofabrik](https://download.geofabrik.de/). Download either compressed XML files (`.osm.bz2`) or binary files (`.osm.pbf`); the latter need to be converted to XML files (`.osm`), for example with Osmium: `osmium cat xx.osm.pbf -o xx.osm`.
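+
+A sketch of this workflow with a small Geofabrik extract (the region and file names are placeholders):
+
+```bash
+wget https://download.geofabrik.de/europe/monaco-latest.osm.pbf
+osmium cat monaco-latest.osm.pbf -o monaco-latest.osm
+```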
+
+
+
+## License
+
+The MGL dataset is made available under the [CC-BY-SA](https://creativecommons.org/licenses/by-sa/4.0/) license following the data available on the Mapillary platform. The model implementation and the pre-trained weights follow a [CC-BY-NC](https://creativecommons.org/licenses/by-nc/2.0/) license. [OpenStreetMap data](https://www.openstreetmap.org/copyright) is licensed under the [Open Data Commons Open Database License](https://opendatacommons.org/licenses/odbl/).
+
+## BibTex citation
+
+Please consider citing our work if you use any code from this repo or ideas presented in the paper:
+```
+@inproceedings{sarlin2023orienternet,
+ author = {Paul-Edouard Sarlin and
+ Daniel DeTone and
+ Tsun-Yi Yang and
+ Armen Avetisyan and
+ Julian Straub and
+ Tomasz Malisiewicz and
+ Samuel Rota Bulo and
+ Richard Newcombe and
+ Peter Kontschieder and
+ Vasileios Balntas},
+ title = {{OrienterNet: Visual Localization in 2D Public Maps with Neural Matching}},
+ booktitle = {CVPR},
+ year = {2023},
+}
+```
+
diff --git a/conf/maplocnet.yaml b/conf/maplocnet.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..df0be9a09848e49803bb2049c2cffeffde771084
--- /dev/null
+++ b/conf/maplocnet.yaml
@@ -0,0 +1,100 @@
+data:
+ root: '/root/DATASET/UAV2MAP/UAV/'
+ train_citys:
+ - Paris
+ - Berlin
+ - London
+ - Tokyo
+ - NewYork
+ val_citys:
+ - Toronto
+ image_size: 256
+ train:
+ batch_size: 12
+ num_workers: 4
+ val:
+ batch_size: ${..train.batch_size}
+ num_workers: ${.batch_size}
+ num_classes:
+ areas: 7
+ ways: 10
+ nodes: 33
+ pixel_per_meter: 1
+ crop_size_meters: 64
+ max_init_error: 48
+ add_map_mask: true
+ resize_image: 512
+ pad_to_square: true
+ rectify_pitch: true
+ augmentation:
+ rot90: true
+ flip: true
+ image:
+ apply: true
+ brightness: 0.5
+ contrast: 0.4
+ saturation: 0.4
+ hue: 0.159 # approximately 0.5/pi; torchvision ColorJitter expects a number
+model:
+ image_size: ${data.image_size}
+ latent_dim: 128
+ val_citys: ${data.val_citys}
+ image_encoder:
+ name: feature_extractor_v2
+ backbone:
+ encoder: resnet50
+ pretrained: true
+ output_dim: 8
+ num_downsample: null
+ remove_stride_from_first_conv: false
+ name: orienternet
+ matching_dim: 8
+ z_max: 32
+ x_max: 32
+ pixel_per_meter: 1
+ num_scale_bins: 33
+ num_rotations: 64
+ map_encoder:
+ embedding_dim: 16
+ output_dim: 8
+ num_classes:
+ areas: 7
+ ways: 10
+ nodes: 33
+ backbone:
+ encoder: vgg19
+ pretrained: false
+ output_scales:
+ - 0
+ num_downsample: 3
+ decoder:
+ - 128
+ - 64
+ - 64
+ padding: replicate
+ unary_prior: false
+ bev_net:
+ num_blocks: 4
+ latent_dim: 128
+ output_dim: 8
+ confidence: true
+experiment:
+ name: maplocanet_0906_diffhight
+ gpus: 6
+ seed: 0
+training:
+ lr: 0.0001
+ lr_scheduler: null
+ finetune_from_checkpoint: null
+ trainer:
+ val_check_interval: 1000
+ log_every_n_steps: 100
+# limit_val_batches: 1000
+ max_steps: 200000
+ devices: ${experiment.gpus}
+ checkpointing:
+ monitor: "loss/total/val"
+ save_top_k: 10
+ mode: min
+
+# filename: '{epoch}-{step}-{loss_SanFrancisco:.2f}'
\ No newline at end of file
diff --git a/dataset/UAV/dataset.py b/dataset/UAV/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..1dbc8121bb355dcbee2ed7f2bc68e85e25b2b808
--- /dev/null
+++ b/dataset/UAV/dataset.py
@@ -0,0 +1,116 @@
+import torch
+from torch.utils.data import Dataset
+import os
+import cv2
+# @Time : 2023-02-13 22:56
+# @Author : Wang Zhen
+# @Email : frozenzhencola@163.com
+# @File : SatelliteTool.py
+# @Project : TGRS_seqmatch_2023_1
+import numpy as np
+import random
+from utils.geo import BoundaryBox, Projection
+from osm.tiling import TileManager,MapTileManager
+from pathlib import Path
+from torchvision import transforms
+from torch.utils.data import DataLoader
+
+class UavMapPair(Dataset):
+ def __init__(
+ self,
+ root: Path,
+ city:str,
+ training:bool,
+ transform
+ ):
+ super().__init__()
+
+ # self.root = root
+
+ # city = 'Manhattan'
+ # root = '/root/DATASET/CrossModel/'
+ # root=Path(root)
+ self.uav_image_path = root/city/'uav'
+ self.map_path = root/city/'map'
+ self.map_vis = root / city / 'map_vis'
+ info_path = root / city / 'info.csv'
+
+ self.info = np.loadtxt(str(info_path), dtype=str, delimiter=",", skiprows=1)
+
+ self.transform=transform
+ self.training=training
+
+ def random_center_crop(self,image):
+ height, width = image.shape[:2]
+
+ # Randomly pick a crop size between half and the full shorter edge
+ crop_size = random.randint(min(height, width) // 2, min(height, width))
+
+ # Compute the top-left corner of the centered crop
+ start_x = (width - crop_size) // 2
+ start_y = (height - crop_size) // 2
+
+ # Crop the image
+ cropped_image = image[start_y:start_y + crop_size, start_x:start_x + crop_size]
+
+ return cropped_image
+ def __getitem__(self, index: int):
+ id, uav_name, map_name, \
+ uav_long, uav_lat, \
+ map_long, map_lat, \
+ tile_size_meters, pixel_per_meter, \
+ u, v, yaw,dis=self.info[index]
+
+
+ uav_image=cv2.imread(str(self.uav_image_path/uav_name))
+ if self.training:
+ uav_image =self.random_center_crop(uav_image)
+ uav_image=cv2.cvtColor(uav_image,cv2.COLOR_BGR2RGB)
+ if self.transform:
+ uav_image=self.transform(uav_image)
+ map=np.load(str(self.map_path/map_name))
+
+ return {
+ 'map':torch.from_numpy(np.ascontiguousarray(map)).long(),
+ 'image':torch.tensor(uav_image),
+ 'roll_pitch_yaw':torch.tensor((0, 0, float(yaw))).float(),
+ 'pixels_per_meter':torch.tensor(float(pixel_per_meter)).float(),
+ "uv":torch.tensor([float(u), float(v)]).float(),
+ }
+ def __len__(self):
+ return len(self.info)
+if __name__ == '__main__':
+
+ root=Path('/root/DATASET/OrienterNet/UavMap/')
+ city='NewYork'
+
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Resize(256),
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+ ])
+
+ dataset=UavMapPair(
+ root=root,
+ city=city,
+ training=True,  # UavMapPair requires the training flag
+ transform=transform
+ )
+ datasetloder = DataLoader(dataset, batch_size=3)
+ for batch, i in enumerate(datasetloder):
+ pass
+ # Convert the PyTorch tensor to a PIL image
+ # pil_image = Image.fromarray(i['uav_image'][0].permute(1, 2, 0).byte().numpy())
+
+ # Display the image
+ # Convert the PyTorch tensor to a NumPy array
+ # numpy_array = i['uav_image'][0].numpy()
+ #
+ # # Display the image
+ # plt.imshow(numpy_array.transpose(1, 2, 0))
+ # plt.axis('off')
+ # plt.show()
+ #
+ # map_viz, label = Colormap.apply(i['map'][0])
+ # map_viz = map_viz * 255
+ # map_viz = map_viz.astype(np.uint8)
+ # plot_images([map_viz], titles=["OpenStreetMap raster"])
diff --git a/dataset/UAV/prepara_dataset.py b/dataset/UAV/prepara_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ed607eafd6e8678688570db999755d136f381db
--- /dev/null
+++ b/dataset/UAV/prepara_dataset.py
@@ -0,0 +1,270 @@
+import torch
+from torch.utils.data import Dataset
+import os
+import cv2
+# @Time : 2023-02-13 22:56
+# @Author : Wang Zhen
+# @Email : frozenzhencola@163.com
+# @File : SatelliteTool.py
+# @Project : TGRS_seqmatch_2023_1
+import numpy as np
+import random
+from utils.geo import BoundaryBox, Projection
+from osm.tiling import TileManager,MapTileManager
+from pathlib import Path
+from torchvision import transforms
+from tqdm import tqdm
+import time
+import math
+import random
+from geopy import Point, distance
+from osm.viz import Colormap, plot_nodes
+
+def generate_random_coordinate(latitude, longitude, dis):
+ # Draw a random bearing in degrees
+ random_angle = random.uniform(0, 360)
+ # print("random_angle",random_angle)
+ # Compute the latitude/longitude of the destination point
+ start_point = Point(latitude, longitude)
+ destination = distance.distance(kilometers=dis/1000).destination(start_point, random_angle)
+
+ return destination.latitude, destination.longitude
+
+def rotate_corp(src,angle):
+ # Height, width and number of channels of the source image
+ rows, cols, channel = src.shape
+
+ # Rotate around the image center
+ # Arguments: rotation center, rotation angle in degrees, scale
+ M = cv2.getRotationMatrix2D((cols / 2, rows / 2), angle, 1)
+ # rows, cols=700,700
+ # Enlarge the output canvas so the rotated image is not clipped
+ cos = np.abs(M[0, 0])
+ sin = np.abs(M[0, 1])
+ new_w = rows * sin + cols * cos
+ new_h = rows * cos + cols * sin
+ M[0, 2] += (new_w - cols) * 0.5
+ M[1, 2] += (new_h - rows) * 0.5
+ w = int(np.round(new_w))
+ h = int(np.round(new_h))
+ rotated = cv2.warpAffine(src, M, (w, h))
+
+ # rotated = cv2.warpAffine(src, M, (cols, rows))
+
+ c=int(w / 2)
+ w=int(rows*math.sqrt(2)/4)
+ rotated2=rotated[c-w:c+w,c-w:c+w,:]
+ return rotated2
+
+class SatelliteGeoTools:
+ """
+ Reads a satellite image .tfw world file and converts between pixel, Mercator, and GPS coordinates.
+ """
+ def __init__(self, tfw_path):
+ self.SatelliteParameter=self.Parsetfw(tfw_path)
+ def Parsetfw(self, tfw_path):
+ info = []
+ f = open(tfw_path)
+ for _ in range(6):
+ line = f.readline()
+ line = line.strip('\n')
+ info.append(float(line))
+ f.close()
+ return info
+ def Pix2Geo(self, x, y):
+ A, D, B, E, C, F = self.SatelliteParameter
+ x1 = A * x + B * y + C
+ y1 = D * x + E * y + F
+ # print(x1,y1)
+ s_long, s_lat = self.MercatorTolonlat(x1, y1)
+ return s_long, s_lat
+
+ def Geo2Pix(self, lon, lat):
+ """
+ https://baike.baidu.com/item/TFW%E6%A0%BC%E5%BC%8F/6273151?fr=aladdin
+ x'=Ax+By+C
+ y'=Dx+Ey+F
+ :return:
+ """
+ x1, y1 = self.LonlatToMercator(lon, lat)
+ A, D, B, E, C, F = self.SatelliteParameter
+ M = np.array([[A, B, C],
+ [D, E, F],
+ [0, 0, 1]])
+ M_INV = np.linalg.inv(M)
+ XY = np.matmul(M_INV, np.array([x1, y1, 1]).T)
+ return int(XY[0]), int(XY[1])
+ def MercatorTolonlat(self,mx,my):
+ x = mx/20037508.3427892*180
+ y = my/20037508.3427892*180
+ # y= 180/math.pi*(2*math.atan(math.exp(y*math.pi/180))-math.pi/2)
+ y = 180.0 / np.pi * (2.0 * np.arctan(np.exp(y * np.pi / 180.0)) - np.pi / 2.0)
+ return x,y
+ def LonlatToMercator(self,lon, lat):
+ x = lon * 20037508.342789 / 180
+ y = np.log(np.tan((90 + lat) * np.pi / 360)) / (np.pi / 180)
+ y = y * 20037508.34789 / 180
+ return x, y
+
+def geodistance(lng1, lat1, lng2, lat2):
+ lng1, lat1, lng2, lat2 = map(np.radians, [lng1, lat1, lng2, lat2])
+ dlon = lng2 - lng1
+ dlat = lat2 - lat1
+ a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2
+ distance = 2 * np.arcsin(np.sqrt(a)) * 6371 * 1000 # mean Earth radius: 6371 km
+ return distance
+
+class PreparaDataset:
+ def __init__(
+ self,
+ root: Path,
+ city:str,
+ patch_size:int,
+ tile_size_meters:float
+ ):
+ super().__init__()
+
+ # self.root = root
+
+ # city = 'Manhattan'
+ # root = '/root/DATASET/CrossModel/'
+ imagepath = root/city/ '{}.tif'.format(city)
+ tfwpath = root/city/'{}.tfw'.format(city)
+
+ self.osmpath = root/city/'{}.osm'.format(city)
+
+ self.TileManager=MapTileManager(self.osmpath)
+ image = cv2.imread(str(imagepath))
+ self.image=cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
+
+ self.ST = SatelliteGeoTools(str(tfwpath))
+
+ self.patch_size=patch_size
+ self.tile_size_meters=tile_size_meters
+
+
+
+ def get_osm(self,prior_latlon,uav_latlon):
+ latlon = np.array(prior_latlon)
+ proj = Projection(*latlon)
+ center = proj.project(latlon)
+
+ uav_latlon=np.array(uav_latlon)
+
+ XY=proj.project(uav_latlon)
+ # tile_size_meters = 128
+ bbox = BoundaryBox(center, center) + self.tile_size_meters
+ # bbox= BoundaryBox(center, center)
+ # Query OpenStreetMap for this area
+ self.pixel_per_meter = 1
+ start_time = time.time()
+ canvas = self.TileManager.from_bbox(proj, bbox, self.pixel_per_meter)
+ end_time = time.time()
+ execution_time = end_time - start_time
+ # print("execution time:", execution_time, "s")
+ # canvas = tiler.query(bbox)
+ XY=[XY[0]+self.tile_size_meters,-XY[1]+self.tile_size_meters]
+ return canvas,XY
+ def random_corp(self):
+
+ # Randomly pick the top-left corner of the crop region
+ x = random.randint(1000, self.image.shape[1] - self.patch_size-1000)
+ y = random.randint(1000, self.image.shape[0] - self.patch_size-1000)
+ x1 = x + self.patch_size
+ y1 = y + self.patch_size
+ return x,x1,y,y1
+
+ def generate(self):
+ x,x1,y,y1 = self.random_corp()
+ uav_center_x,uav_center_y=int((x+x1)//2),int((y+y1)//2)
+ uav_center_long,uav_center_lat=self.ST.Pix2Geo(uav_center_x,uav_center_y)
+ # print(uav_center_long,uav_center_lat)
+ self.image_patch = self.image[y:y1, x:x1]
+
+ map_center_lat, map_center_long = generate_random_coordinate(uav_center_lat, uav_center_long, self.tile_size_meters)
+ map,XY=self.get_osm([map_center_lat,map_center_long],[uav_center_lat, uav_center_long])
+
+
+ yaw=np.random.random()*360
+ self.image_patch=rotate_corp(self.image_patch,yaw)
+ # return self.image_patch,self.osm_patch
+ # XY=[X+self.tile_size_meters
+ return {
+ 'uav_image':self.image_patch,
+ 'uav_long_lat':[uav_center_long,uav_center_lat],
+ 'map_long_lat': [map_center_long,map_center_lat],
+ 'tile_size_meters': map.raster.shape[1],
+ 'pixel_per_meter':self.pixel_per_meter,
+ 'yaw':yaw,
+ 'map':map.raster,
+ "uv":XY
+ }
+if __name__ == '__main__':
+
+ import argparse
+
+ parser = argparse.ArgumentParser(description='manual to this script')
+ parser.add_argument('--city', type=str, default=None,required=True)
+ parser.add_argument('--num', type=int, default=10000)
+ args = parser.parse_args()
+
+
+ root=Path('/root/DATASET/OrienterNet/UavMap/')
+ city=args.city
+ dataset = PreparaDataset(
+ root=root,
+ city=city,
+ patch_size=512,
+ tile_size_meters=128,
+ )
+
+ uav_path=root/city/'uav'
+ if not uav_path.exists():
+ uav_path.mkdir(parents=True)
+
+ map_path = root / city / 'map'
+ if not map_path.exists():
+ map_path.mkdir(parents=True)
+
+ map_vis_path = root / city / 'map_vis'
+ if not map_vis_path.exists():
+ map_vis_path.mkdir(parents=True)
+
+ info_path = root / city / 'info.csv'
+
+ # num=1000
+ num = args.num
+ info=[['id','uav_name','map_name','uav_long','uav_lat','map_long','map_lat','tile_size_meters','pixel_per_meter','u','v','yaw']]
+ # info =[]
+ for i in tqdm(range(num)):
+ data=dataset.generate()
+ # print(str(uav_path/"{:05d}.jpg".format(i)))
+
+ cv2.imwrite(str(uav_path/"{:05d}.jpg".format(i)),cv2.cvtColor(data['uav_image'],cv2.COLOR_RGB2BGR))
+
+ np.save(str(map_path/"{:05d}.npy".format(i)),data['map'])
+
+ map_viz, label = Colormap.apply(data['map'])
+ map_viz = map_viz * 255
+ map_viz = map_viz.astype(np.uint8)
+ cv2.imwrite(str(map_vis_path / "{:05d}.jpg".format(i)), cv2.cvtColor(map_viz, cv2.COLOR_RGB2BGR))
+
+
+ uav_center_long, uav_center_lat=data['uav_long_lat']
+ map_center_long, map_center_lat = data['map_long_lat']
+ info.append([
+ i,
+ "{:05d}.jpg".format(i),
+ "{:05d}.npy".format(i),
+ uav_center_long,
+ uav_center_lat,
+ map_center_long,
+ map_center_lat,
+ data["tile_size_meters"],
+ data["pixel_per_meter"],
+ data['uv'][0],
+ data['uv'][1],
+ data['yaw']
+ ])
+ # print(info)
+ np.savetxt(info_path,info,delimiter=',',fmt="%s")
\ No newline at end of file
diff --git a/dataset/__init__.py b/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b51bfd2ed3eaf38719d8ee102df779d53d1ffa4
--- /dev/null
+++ b/dataset/__init__.py
@@ -0,0 +1,4 @@
+# from .UAV.dataset import UavMapPair
+from .dataset import UavMapDatasetModule
+
+# modules = {"UAV": UavMapPair}
diff --git a/dataset/dataset.py b/dataset/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec79057f559178a4c7508ab2970bb963dcf4e09a
--- /dev/null
+++ b/dataset/dataset.py
@@ -0,0 +1,93 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+from copy import deepcopy
+from pathlib import Path
+from typing import Any, Dict, List
+# from logger import logger
+import numpy as np
+# import torch
+# import torch.utils.data as torchdata
+# import torchvision.transforms as tvf
+from omegaconf import DictConfig, OmegaConf
+import pytorch_lightning as pl
+from dataset.UAV.dataset import UavMapPair
+# from torch.utils.data import Dataset, DataLoader
+# from torchvision import transforms
+from torch.utils.data import Dataset, ConcatDataset
+from torch.utils.data import Dataset, DataLoader, random_split
+import torchvision.transforms as tvf
+
+# Custom data module class, inheriting from pl.LightningDataModule
+class UavMapDatasetModule(pl.LightningDataModule):
+
+
+ def __init__(self, cfg: Dict[str, Any]):
+ super().__init__()
+
+ # default_cfg = OmegaConf.create(self.default_cfg)
+ # OmegaConf.set_struct(default_cfg, True) # cannot add new keys
+ # self.cfg = OmegaConf.merge(default_cfg, cfg)
+ self.cfg=cfg
+ # self.transform = tvf.Compose([
+ # tvf.ToTensor(),
+ # tvf.Resize(self.cfg.image_size),
+ # tvf.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
+ # ])
+
+ tfs = []
+ tfs.append(tvf.ToTensor())
+ tfs.append(tvf.Resize(self.cfg.image_size))
+ self.val_tfs = tvf.Compose(tfs)
+
+ # transforms.Resize(self.cfg.image_size),
+ # Build the training transforms from a copy of tfs, so that appending
+ # ColorJitter does not also modify the list held by self.val_tfs.
+ train_tfs = list(tfs)
+ if cfg.augmentation.image.apply:
+ args = OmegaConf.masked_copy(
+ cfg.augmentation.image, ["brightness", "contrast", "saturation", "hue"]
+ )
+ train_tfs.append(tvf.ColorJitter(**args))
+ self.train_tfs = tvf.Compose(train_tfs)
+
+ # self.train_tfs=self.transform
+ # self.val_tfs = self.transform
+ self.init()
+ def init(self):
+ self.train_dataset = ConcatDataset([
+ UavMapPair(root=Path(self.cfg.root),city=city,training=True,transform=self.train_tfs)
+ for city in self.cfg.train_citys
+ ])
+
+ self.val_dataset = ConcatDataset([
+ UavMapPair(root=Path(self.cfg.root),city=city,training=False,transform=self.val_tfs)
+ for city in self.cfg.val_citys
+ ])
+
+ # self.val_datasets = {
+ # city:UavMapPair(root=Path(self.cfg.root),city=city,transform=self.val_tfs)
+ # for city in self.cfg.val_citys
+ # }
+ # logger.info("train data len:{},val data len:{}".format(len(self.train_dataset),len(self.val_dataset)))
+ # # Train/val split ratio
+ # train_ratio = 0.8 # fraction of samples used for training
+ # # Compute the number of samples in each split
+ # train_size = int(len(self.dataset) * train_ratio)
+ # val_size = len(self.dataset) - train_size
+ # self.train_dataset, self.val_dataset = random_split(self.dataset, [train_size, val_size])
+ def train_dataloader(self):
+ train_loader = DataLoader(self.train_dataset,
+ batch_size=self.cfg.train.batch_size,
+ num_workers=self.cfg.train.num_workers,
+ shuffle=True,pin_memory = True)
+ return train_loader
+
+ def val_dataloader(self):
+ val_loader = DataLoader(self.val_dataset,
+ batch_size=self.cfg.val.batch_size,
+ num_workers=self.cfg.val.num_workers,
+ shuffle=True,pin_memory = True)
+ #
+ # my_dict = {k: v for k, v in self.val_datasets}
+ # val_loaders={city: DataLoader(dataset,
+ # batch_size=self.cfg.val.batch_size,
+ # num_workers=self.cfg.val.num_workers,
+ # shuffle=False,pin_memory = True) for city, dataset in self.val_datasets.items()}
+ return val_loader
diff --git a/dataset/image.py b/dataset/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..75b3dc68cc2481150c5ff938483ae640956bcf0d
--- /dev/null
+++ b/dataset/image.py
@@ -0,0 +1,140 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+from typing import Callable, Optional, Union, Sequence
+
+import numpy as np
+import torch
+import torchvision.transforms.functional as tvf
+import collections
+from scipy.spatial.transform import Rotation
+
+from utils.geometry import from_homogeneous, to_homogeneous
+from utils.wrappers import Camera
+
+
+def rectify_image(
+ image: torch.Tensor,
+ cam: Camera,
+ roll: float,
+ pitch: Optional[float] = None,
+ valid: Optional[torch.Tensor] = None,
+):
+ *_, h, w = image.shape
+ grid = torch.meshgrid(
+ [torch.arange(w, device=image.device), torch.arange(h, device=image.device)],
+ indexing="xy",
+ )
+ grid = torch.stack(grid, -1).to(image.dtype)
+
+ if pitch is not None:
+ args = ("ZX", (roll, pitch))
+ else:
+ args = ("Z", roll)
+ R = Rotation.from_euler(*args, degrees=True).as_matrix()
+ R = torch.from_numpy(R).to(image)
+
+ grid_rect = to_homogeneous(cam.normalize(grid)) @ R.T
+ grid_rect = cam.denormalize(from_homogeneous(grid_rect))
+ grid_norm = (grid_rect + 0.5) / grid.new_tensor([w, h]) * 2 - 1
+ rectified = torch.nn.functional.grid_sample(
+ image[None],
+ grid_norm[None],
+ align_corners=False,
+ mode="bilinear",
+ ).squeeze(0)
+ if valid is None:
+ valid = torch.all((grid_norm >= -1) & (grid_norm <= 1), -1)
+ else:
+ valid = (
+ torch.nn.functional.grid_sample(
+ valid[None, None].float(),
+ grid_norm[None],
+ align_corners=False,
+ mode="nearest",
+ )[0, 0]
+ > 0
+ )
+ return rectified, valid
+
+
+def resize_image(
+ image: torch.Tensor,
+ size: Union[int, Sequence, np.ndarray],
+ fn: Optional[Callable] = None,
+ camera: Optional[Camera] = None,
+ valid: np.ndarray = None,
+):
+ """Resize an image to a fixed size, or according to max or min edge."""
+ *_, h, w = image.shape
+ if fn is not None:
+ assert isinstance(size, int)
+ scale = size / fn(h, w)
+ h_new, w_new = int(round(h * scale)), int(round(w * scale))
+ scale = (scale, scale)
+ else:
+ if isinstance(size, (collections.abc.Sequence, np.ndarray)):
+ w_new, h_new = size
+ elif isinstance(size, int):
+ w_new = h_new = size
+ else:
+ raise ValueError(f"Incorrect new size: {size}")
+ scale = (w_new / w, h_new / h)
+ if (w, h) != (w_new, h_new):
+ mode = tvf.InterpolationMode.BILINEAR
+ image = tvf.resize(image, (h_new, w_new), interpolation=mode, antialias=True)
+ image.clip_(0, 1)
+ if camera is not None:
+ camera = camera.scale(scale)
+ if valid is not None:
+ valid = tvf.resize(
+ valid.unsqueeze(0),
+ (h_new, w_new),
+ interpolation=tvf.InterpolationMode.NEAREST,
+ ).squeeze(0)
+ ret = [image, scale]
+ if camera is not None:
+ ret.append(camera)
+ if valid is not None:
+ ret.append(valid)
+ return ret
+
+
+def pad_image(
+ image: torch.Tensor,
+ size: Union[int, Sequence, np.ndarray],
+ camera: Optional[Camera] = None,
+ valid: torch.Tensor = None,
+ crop_and_center: bool = False,
+):
+ if isinstance(size, int):
+ w_new = h_new = size
+ elif isinstance(size, (collections.abc.Sequence, np.ndarray)):
+ w_new, h_new = size
+ else:
+ raise ValueError(f"Incorrect new size: {size}")
+ *c, h, w = image.shape
+ if crop_and_center:
+ diff = np.array([w - w_new, h - h_new])
+ left, top = left_top = np.round(diff / 2).astype(int)
+ right, bottom = diff - left_top
+ else:
+ assert h <= h_new
+ assert w <= w_new
+ top = bottom = left = right = 0
+ slice_out = np.s_[..., : min(h, h_new), : min(w, w_new)]
+ slice_in = np.s_[
+ ..., max(top, 0) : h - max(bottom, 0), max(left, 0) : w - max(right, 0)
+ ]
+ if (w, h) == (w_new, h_new):
+ out = image
+ else:
+ out = torch.zeros((*c, h_new, w_new), dtype=image.dtype)
+ out[slice_out] = image[slice_in]
+ if camera is not None:
+ camera = camera.crop((max(left, 0), max(top, 0)), (w_new, h_new))
+ out_valid = torch.zeros((h_new, w_new), dtype=torch.bool)
+ out_valid[slice_out] = True if valid is None else valid[slice_in]
+ if camera is not None:
+ return out, out_valid, camera
+ else:
+ return out, out_valid
diff --git a/dataset/torch.py b/dataset/torch.py
new file mode 100644
index 0000000000000000000000000000000000000000..9547ca149c606a345e5b8916591e43c26031022c
--- /dev/null
+++ b/dataset/torch.py
@@ -0,0 +1,111 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import collections
+import os
+
+import torch
+from torch.utils.data import get_worker_info
+from torch.utils.data._utils.collate import (
+ default_collate_err_msg_format,
+ np_str_obj_array_pattern,
+)
+from lightning_fabric.utilities.seed import pl_worker_init_function
+from lightning_utilities.core.apply_func import apply_to_collection
+from lightning_fabric.utilities.apply_func import move_data_to_device
+
+
+def collate(batch):
+ """Difference with PyTorch default_collate: it can stack other tensor-like objects.
+ Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
+ https://github.com/cvg/pixloc
+ Released under the Apache License 2.0
+ """
+ if not isinstance(batch, list): # no batching
+ return batch
+ elem = batch[0]
+ elem_type = type(elem)
+ if isinstance(elem, torch.Tensor):
+ out = None
+ if torch.utils.data.get_worker_info() is not None:
+ # If we're in a background process, concatenate directly into a
+ # shared memory tensor to avoid an extra copy
+ numel = sum(x.numel() for x in batch)
+ storage = elem.storage()._new_shared(numel, device=elem.device)
+ out = elem.new(storage).resize_(len(batch), *list(elem.size()))
+ return torch.stack(batch, 0, out=out)
+ elif (
+ elem_type.__module__ == "numpy"
+ and elem_type.__name__ != "str_"
+ and elem_type.__name__ != "string_"
+ ):
+ if elem_type.__name__ == "ndarray" or elem_type.__name__ == "memmap":
+ # array of string classes and object
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+ raise TypeError(default_collate_err_msg_format.format(elem.dtype))
+
+ return collate([torch.as_tensor(b) for b in batch])
+ elif elem.shape == (): # scalars
+ return torch.as_tensor(batch)
+ elif isinstance(elem, float):
+ return torch.tensor(batch, dtype=torch.float64)
+ elif isinstance(elem, int):
+ return torch.tensor(batch)
+ elif isinstance(elem, (str, bytes)):
+ return batch
+ elif isinstance(elem, collections.abc.Mapping):
+ return {key: collate([d[key] for d in batch]) for key in elem}
+ elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
+ return elem_type(*(collate(samples) for samples in zip(*batch)))
+ elif isinstance(elem, collections.abc.Sequence):
+ # check to make sure that the elements in batch have consistent size
+ it = iter(batch)
+ elem_size = len(next(it))
+ if not all(len(elem) == elem_size for elem in it):
+ raise RuntimeError("each element in list of batch should be of equal size")
+ transposed = zip(*batch)
+ return [collate(samples) for samples in transposed]
+ else:
+ # try to stack anyway in case the object implements stacking.
+ try:
+ return torch.stack(batch, 0)
+ except TypeError as e:
+ if "expected Tensor as element" in str(e):
+ return batch
+ else:
+ raise e
+
+
+def set_num_threads(nt):
+ """Force numpy and other libraries to use a limited number of threads."""
+ try:
+ import mkl
+ except ImportError:
+ pass
+ else:
+ mkl.set_num_threads(nt)
+ torch.set_num_threads(1)
+ os.environ["IPC_ENABLE"] = "1"
+ for o in [
+ "OPENBLAS_NUM_THREADS",
+ "NUMEXPR_NUM_THREADS",
+ "OMP_NUM_THREADS",
+ "MKL_NUM_THREADS",
+ ]:
+ os.environ[o] = str(nt)
+
+
+def worker_init_fn(i):
+ info = get_worker_info()
+ pl_worker_init_function(info.id)
+ num_threads = info.dataset.cfg.get("num_threads")
+ if num_threads is not None:
+ set_num_threads(num_threads)
+
+
+def unbatch_to_device(data, device="cpu"):
+ data = move_data_to_device(data, device)
+ data = apply_to_collection(data, torch.Tensor, lambda x: x.squeeze(0))
+ data = apply_to_collection(
+ data, list, lambda x: x[0] if len(x) == 1 and isinstance(x[0], str) else x
+ )
+ return data
diff --git a/demo.py b/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..32b54b259de97aa0253b5a81379cd080897ace29
--- /dev/null
+++ b/demo.py
@@ -0,0 +1,354 @@
+
+import matplotlib.pyplot as plt
+# from demo import Demo, read_input_image,read_input_image_test
+from evaluation.viz import plot_example_single
+from dataset.torch import unbatch_to_device
+import matplotlib.pyplot as plt
+from typing import Optional, Tuple
+import cv2
+import torch
+import numpy as np
+import time
+from logger import logger
+from evaluation.run import resolve_checkpoint_path, pretrained_models
+from models.maplocnet import MapLocNet
+from models.voting import fuse_gps, argmax_xyr
+# from data.image import resize_image, pad_image, rectify_image
+from osm.raster import Canvas
+from utils.wrappers import Camera
+from utils.io import read_image
+from utils.geo import BoundaryBox, Projection
+from utils.exif import EXIF
+import requests
+from pathlib import Path
+from utils.exif import EXIF
+from dataset.image import resize_image, pad_image, rectify_image
+# from maploc.demo import Demo, read_input_image
+from dataset import UavMapDatasetModule
+import torchvision.transforms as tvf
+import matplotlib.pyplot as plt
+import numpy as np
+from sklearn.decomposition import PCA
+from PIL import Image
+# import pyproj
+# Query OpenStreetMap for this area
+from osm.tiling import TileManager
+from utils.viz_localization import (
+ likelihood_overlay,
+ plot_dense_rotations,
+ add_circle_inset,
+)
+# Show the inputs to the model: image and raster map
+from osm.viz import Colormap, plot_nodes
+from utils.viz_2d import plot_images
+
+from utils.viz_2d import features_to_RGB
+import random
+from geopy.distance import geodesic
+
+
+def vis_image_feature(F):
+ def normalize(x):
+ return x / np.linalg.norm(x, axis=-1, keepdims=True)
+
+ # F=neural_map.numpy()
+ F = F[:, 0:180, 0:180]
+ flatten = []
+ c, h, w = F.shape
+ print(F.shape)
+ F = np.rollaxis(F, 0, 3)
+ F_flat = F.reshape(-1, c)
+ flatten.append(F_flat)
+ flatten = normalize(flatten)[0]
+
+ flatten = np.nan_to_num(flatten, nan=0)
+ pca = PCA(n_components=3)
+
+ print(flatten.shape)
+ flatten = pca.fit_transform(flatten)
+ flatten = (normalize(flatten) + 1) / 2
+
+ # h, w = F.shape[-2:]
+ F_rgb, flatten = np.split(flatten, [h * w], axis=0)
+ F_rgb = F_rgb.reshape((h, w, 3))
+ return F_rgb
+def distance(lat1, lon1, lat2, lon2):
+ point1 = (lat1, lon1)
+ point2 = (lat2, lon2)
+ distance_km = geodesic(point1, point2).meters
+ return distance_km
+
+# # Example
+# lat1, lon1 = 39.9, 116.4 # latitude/longitude of Beijing
+# lat2, lon2 = 31.2, 121.5 # latitude/longitude of Shanghai
+
+# distance_km = distance(lat1, lon1, lat2, lon2)
+# print(distance_km)
+def show_result(map_vis_image, pre_uv, pre_yaw):
+ # Create a gray mask image with the same size as the original image
+ gray_mask = np.zeros_like(map_vis_image)
+ gray_mask.fill(128) # fill with gray
+
+ # Blend the gray mask with the original image
+ image = cv2.addWeighted(map_vis_image, 1, gray_mask, 0, 0)
+ # Draw the ground truth
+
+ # Draw the prediction
+ u, v = pre_uv
+ x1, y1 = int(u), int(v) # replace with the actual start coordinates
+ angle = pre_yaw - 90 # replace with the actual arrow angle
+ # Compute the end point of the arrow
+ length = 20
+ x2 = int(x1 + length * np.cos(np.radians(angle)))
+ y2 = int(y1 + length * np.sin(np.radians(angle)))
+ # Draw the arrow on the image
+ cv2.arrowedLine(image, (x1, y1), (x2, y2), (0, 0, 0), 2, 5, 0, 0.3)
+ # cv2.circle(image, (x1, y1), radius=2, color=(255, 0, 255), thickness=-1)
+ return image
+
+
+def xyz_to_latlon(x, y, z):
+ # Define the WGS84 projection
+ wgs84 = pyproj.CRS('EPSG:4326')
+
+ # Define the geocentric XYZ projection
+ xyz = pyproj.CRS('+proj=geocent +datum=WGS84 +units=m +no_defs')
+
+ # Build the coordinate transformer
+ transformer = pyproj.Transformer.from_crs(xyz, wgs84)
+
+ # Convert the coordinates
+ lon, lat, _ = transformer.transform(x, y, z)
+
+ return lat, lon
+
+
+class Demo:
+ def __init__(
+ self,
+ experiment_or_path: Optional[str] = "OrienterNet_MGL",
+ device=None,
+ **kwargs
+ ):
+ if experiment_or_path in pretrained_models:
+ experiment_or_path, _ = pretrained_models[experiment_or_path]
+ path = resolve_checkpoint_path(experiment_or_path)
+ ckpt = torch.load(path, map_location=(lambda storage, loc: storage))
+ config = ckpt["hyper_parameters"]
+ config.model.update(kwargs)
+ config.model.image_encoder.backbone.pretrained = False
+
+ model = MapLocNet(config.model).eval()
+ state = {k[len("model."):]: v for k, v in ckpt["state_dict"].items()}
+ model.load_state_dict(state, strict=True)
+ if device is None:
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = model.to(device)
+
+ self.model = model
+ self.config = config
+ self.device = device
+
+ def prepare_data(
+ self,
+ image: np.ndarray,
+ camera: Camera,
+ canvas: Canvas,
+ roll_pitch: Optional[Tuple[float]] = None,
+ ):
+
+ image = torch.from_numpy(image).permute(2, 0, 1).float().div_(255)
+
+ return {
+ 'map': torch.from_numpy(canvas.raster).long(),
+ 'image': image,
+ # 'roll_pitch_yaw':torch.tensor((0, 0, float(yaw))).float().unsqueeze(0),
+ # 'pixels_per_meter':torch.tensor(float(pixel_per_meter)).float().unsqueeze(0),
+ # "uv":torch.tensor([float(u), float(v)]).float().unsqueeze(0),
+ }
+ # return dict(
+ # image=image,
+ # map=torch.from_numpy(canvas.raster).long(),
+ # camera=camera.float(),
+ # valid=valid,
+ # )
+
+ def localize(self, image: np.ndarray, camera: Camera, canvas: Canvas, **kwargs):
+
+ data = self.prepare_data(image, camera, canvas, **kwargs)
+ data_ = {k: v.to(self.device)[None] for k, v in data.items()}
+ # data_np = {k: v.cpu().numpy()[None] for k, v in data.items()}
+ # logger.info(data_)
+ # np.save(data_np, 'data_.npy')
+ start = time.time()
+ with torch.no_grad():
+ pred = self.model(data_)
+
+ end = time.time()
+ xy_gps = canvas.bbox.center
+ uv_gps = torch.from_numpy(canvas.to_uv(xy_gps))
+
+ lp_xyr = pred["log_probs"].squeeze(0)
+ # tile_size = canvas.bbox.size.min() / 2
+ # sigma = tile_size - 20 # 20 meters margin
+ # lp_xyr = fuse_gps(
+ # lp_xyr,
+ # uv_gps.to(lp_xyr),
+ # self.config.model.pixel_per_meter,
+ # sigma=sigma,
+ # )
+ xyr = argmax_xyr(lp_xyr).cpu()
+
+ prob = lp_xyr.exp().cpu()
+ neural_map = pred["map"]["map_features"][0].squeeze(0).cpu()
+ print('total time:', end - start)
+ return xyr[:2], xyr[2], prob, neural_map, data["image"], data_, pred
+
+
+def load_test_data(
+ root: Path,
+ city: str,
+ index: int,
+):
+ uav_image_path = root / city / 'uav'
+ map_path = root / city / 'map'
+ map_vis = root / city / 'map_vis'
+ info_path = root / city / 'info.csv'
+ osm_path = root / city / '{}.osm'.format(city)
+
+ info = np.loadtxt(str(info_path), dtype=str, delimiter=",", skiprows=1)
+
+ id, uav_name, map_name, \
+ uav_long, uav_lat, \
+ map_long, map_lat, \
+ tile_size_meters, pixel_per_meter, \
+ u, v, yaw, dis = info[index]
+ print(info[index])
+ uav_image_rgb = cv2.imread(str(uav_image_path / uav_name))
+ uav_image_rgb = cv2.cvtColor(uav_image_rgb, cv2.COLOR_BGR2RGB)
+
+ # w,h,c=uav_image_rgb.shape
+ # # Coordinates of the crop region
+ # x = w//2 # start x coordinate
+ # y = h//2 # start y coordinate
+ # w = 150 # width
+ # h = 150 # height
+
+ # # Crop the image
+ # uav_image_rgb = uav_image_rgb[y-h:y+h, x-w:x+w]
+
+ map_vis_image = cv2.imread(str(map_vis / uav_name))
+ map_vis_image = cv2.cvtColor(map_vis_image, cv2.COLOR_BGR2RGB)
+
+ map = np.load(str(map_path / map_name))
+
+ tfs = []
+ tfs.append(tvf.ToTensor())
+ tfs.append(tvf.Resize(256))
+ val_tfs = tvf.Compose(tfs)
+
+ uav_image = val_tfs(uav_image_rgb)
+ # print(id, uav_name, map_name, \
+ # uav_long, uav_lat, \
+ # map_long, map_lat, \
+ # tile_size_meters, pixel_per_meter, \
+ # u, v, yaw,dis)
+ uav_path = str(uav_image_path / uav_name)
+ return {
+ 'map': torch.from_numpy(np.ascontiguousarray(map)).long().unsqueeze(0),
+ 'image': torch.tensor(uav_image).unsqueeze(0),
+ 'roll_pitch_yaw': torch.tensor((0, 0, float(yaw))).float().unsqueeze(0),
+ 'pixels_per_meter': torch.tensor(float(pixel_per_meter)).float().unsqueeze(0),
+ "uv": torch.tensor([float(u), float(v)]).float().unsqueeze(0),
+ }, uav_image_rgb, map_vis_image, uav_path, [float(map_lat), float(map_long)]
+
+
+def crop_image(image, width, height):
+ # Compute the top-left corner of the crop region
+ x = int((image.shape[1] - width) / 2)
+ y = int((image.shape[0] - height) / 2)
+
+ # Crop the image
+ cropped_image = image[y:y + height, x:x + width]
+ return cropped_image
+
+
+def crop_square(image):
+ # Get the height and width of the image
+ height, width = image.shape[:2]
+
+ # Length of the shortest side
+ min_length = min(height, width)
+
+ # Compute the coordinates of the crop region
+ top = (height - min_length) // 2
+ bottom = top + min_length
+ left = (width - min_length) // 2
+ right = left + min_length
+
+ # Crop the image to a square
+ cropped_image = image[top:bottom, left:right]
+
+ return cropped_image
+def read_input_image_test(
+ image,
+ prior_latlon,
+ tile_size_meters,
+):
+ # image = read_image(image_path)
+ # # Crop the image
+ # # Specify the crop width and height
+ # width = 1080*2
+ # height =1080*2
+ # image = crop_square(image)
+ # # print("input image:",image.shape)
+ # image = crop_image(image, width, height)
+ # # print("crop_image:",image.shape)
+ image = cv2.resize(image,(256,256))
+ roll_pitch = None
+
+
+ latlon = None
+ if prior_latlon is not None:
+ latlon = prior_latlon
+ logger.info("Using prior latlon %s.", prior_latlon)
+
+ if latlon is None:
+ with open(image_path, "rb") as fid:
+ exif = EXIF(fid, lambda: image.shape[:2])
+ geo = exif.extract_geo()
+ if geo:
+ alt = geo.get("altitude", 0) # read if available
+ latlon = (geo["latitude"], geo["longitude"], alt)
+ logger.info("Using prior location from EXIF.")
+ # print(latlon)
+ else:
+ logger.info("Could not find any prior location in the image EXIF metadata.")
+
+ latlon = np.array(latlon)
+
+ proj = Projection(*latlon)
+ center = proj.project(latlon)
+ bbox = BoundaryBox(center, center) + float(tile_size_meters)
+ camera=None
+ image=cv2.resize(image,(256,256))
+ return image, camera, roll_pitch, proj, bbox, latlon
+if __name__ == '__main__':
+ experiment_or_path = "weight/last-step-checkpointing.ckpt"
+ # experiment_or_path="experiments/maplocanet_0906_diffhight/last-step-checkpointing.ckpt"
+ image_path='images/00000.jpg'
+ prior_latlon=(37.75704325989902,-122.435941445631)
+ tile_size_meters=128
+ demo = Demo(experiment_or_path=experiment_or_path, num_rotations=128, device='cpu')
+ # read_input_image_test expects an image array, so load the image first
+ image = read_image(image_path)
+ image, camera, gravity, proj, bbox, true_prior_latlon = read_input_image_test(
+ image,
+ prior_latlon=prior_latlon,
+ tile_size_meters=tile_size_meters, # try 64, 256, etc.
+ )
+ tiler = TileManager.from_bbox(projection=proj, bbox=bbox + 10,ppm=1, tile_size=tile_size_meters)
+ # tiler = TileManager.from_bbox(projection=proj, bbox=bbox + 10,ppm=1,path=root/city/'{}.osm'.format(city), tile_size=1)
+ canvas = tiler.query(bbox)
+ uv, yaw, prob, neural_map, image_rectified, data_, pred = demo.localize(
+ image, camera, canvas)
+ prior_latlon_pred = proj.unproject(canvas.to_xy(uv))
+ pass
\ No newline at end of file
diff --git a/evaluation/kitti.py b/evaluation/kitti.py
new file mode 100644
index 0000000000000000000000000000000000000000..e91da069f307a533b0471a3fb43f8622cadc60db
--- /dev/null
+++ b/evaluation/kitti.py
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import argparse
+from pathlib import Path
+from typing import Optional, Tuple
+
+from omegaconf import OmegaConf, DictConfig
+
+from .. import logger
+from ..data import KittiDataModule
+from .run import evaluate
+
+
+default_cfg_single = OmegaConf.create({})
+# For the sequential evaluation, we need to center the map around the GT location,
+# since random offsets would accumulate and leave only the GT location with a valid mask.
+# This should not have much impact on the results.
+default_cfg_sequential = OmegaConf.create(
+ {
+ "data": {
+ "mask_radius": KittiDataModule.default_cfg["max_init_error"],
+ "prior_range_rotation": KittiDataModule.default_cfg[
+ "max_init_error_rotation"
+ ]
+ + 1,
+ "max_init_error": 0,
+ "max_init_error_rotation": 0,
+ },
+ "chunking": {
+ "max_length": 100, # about 10s?
+ },
+ }
+)
+
+
+def run(
+ split: str,
+ experiment: str,
+ cfg: Optional[DictConfig] = None,
+ sequential: bool = False,
+ thresholds: Tuple[int] = (1, 3, 5),
+ **kwargs,
+):
+ cfg = cfg or {}
+ if isinstance(cfg, dict):
+ cfg = OmegaConf.create(cfg)
+ default = default_cfg_sequential if sequential else default_cfg_single
+ cfg = OmegaConf.merge(default, cfg)
+ dataset = KittiDataModule(cfg.get("data", {}))
+
+ metrics = evaluate(
+ experiment,
+ cfg,
+ dataset,
+ split=split,
+ sequential=sequential,
+ viz_kwargs=dict(show_dir_error=True, show_masked_prob=False),
+ **kwargs,
+ )
+
+ keys = ["directional_error", "yaw_max_error"]
+ if sequential:
+ keys += ["directional_seq_error", "yaw_seq_error"]
+ for k in keys:
+ rec = metrics[k].recall(thresholds).double().numpy().round(2).tolist()
+ logger.info("Recall %s: %s at %s m/°", k, rec, thresholds)
+ return metrics
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--experiment", type=str, required=True)
+ parser.add_argument(
+ "--split", type=str, default="test", choices=["test", "val", "train"]
+ )
+ parser.add_argument("--sequential", action="store_true")
+ parser.add_argument("--output_dir", type=Path)
+ parser.add_argument("--num", type=int)
+ parser.add_argument("dotlist", nargs="*")
+ args = parser.parse_args()
+ cfg = OmegaConf.from_cli(args.dotlist)
+ run(
+ args.split,
+ args.experiment,
+ cfg,
+ args.sequential,
+ output_dir=args.output_dir,
+ num=args.num,
+ )
diff --git a/evaluation/mapillary.py b/evaluation/mapillary.py
new file mode 100644
index 0000000000000000000000000000000000000000..c45b845bacd6d6a9c995d3b8d7ee9cd9ec2a9f78
--- /dev/null
+++ b/evaluation/mapillary.py
@@ -0,0 +1,111 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import argparse
+from pathlib import Path
+from typing import Optional, Tuple
+
+from omegaconf import OmegaConf, DictConfig
+
+from .. import logger
+from ..conf import data as conf_data_dir
+from ..data import MapillaryDataModule
+from .run import evaluate
+
+
+split_overrides = {
+ "val": {
+ "scenes": [
+ "sanfrancisco_soma",
+ "sanfrancisco_hayes",
+ "amsterdam",
+ "berlin",
+ "lemans",
+ "montrouge",
+ "toulouse",
+ "nantes",
+ "vilnius",
+ "avignon",
+ "helsinki",
+ "milan",
+ "paris",
+ ],
+ },
+}
+data_cfg_train = OmegaConf.load(Path(conf_data_dir.__file__).parent / "mapillary.yaml")
+data_cfg = OmegaConf.merge(
+ data_cfg_train,
+ {
+ "return_gps": True,
+ "add_map_mask": True,
+ "max_init_error": 32,
+ "loading": {"val": {"batch_size": 1, "num_workers": 0}},
+ },
+)
+default_cfg_single = OmegaConf.create({"data": data_cfg})
+default_cfg_sequential = OmegaConf.create(
+ {
+ **default_cfg_single,
+ "chunking": {
+ "max_length": 10,
+ },
+ }
+)
+
+
+def run(
+ split: str,
+ experiment: str,
+ cfg: Optional[DictConfig] = None,
+ sequential: bool = False,
+ thresholds: Tuple[int] = (1, 3, 5),
+ **kwargs,
+):
+ cfg = cfg or {}
+ if isinstance(cfg, dict):
+ cfg = OmegaConf.create(cfg)
+ default = default_cfg_sequential if sequential else default_cfg_single
+ default = OmegaConf.merge(default, split_overrides[split])
+ cfg = OmegaConf.merge(default, cfg)
+ dataset = MapillaryDataModule(cfg.get("data", {}))
+
+ metrics = evaluate(experiment, cfg, dataset, split, sequential=sequential, **kwargs)
+
+ keys = [
+ "xy_max_error",
+ "xy_gps_error",
+ "yaw_max_error",
+ ]
+ if sequential:
+ keys += [
+ "xy_seq_error",
+ "xy_gps_seq_error",
+ "yaw_seq_error",
+ "yaw_gps_seq_error",
+ ]
+ for k in keys:
+ if k not in metrics:
+ logger.warning("Key %s not in metrics.", k)
+ continue
+ rec = metrics[k].recall(thresholds).double().numpy().round(2).tolist()
+ logger.info("Recall %s: %s at %s m/°", k, rec, thresholds)
+ return metrics
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--experiment", type=str, required=True)
+ parser.add_argument("--split", type=str, default="val", choices=["val"])
+ parser.add_argument("--sequential", action="store_true")
+ parser.add_argument("--output_dir", type=Path)
+ parser.add_argument("--num", type=int)
+ parser.add_argument("dotlist", nargs="*")
+ args = parser.parse_args()
+ cfg = OmegaConf.from_cli(args.dotlist)
+ run(
+ args.split,
+ args.experiment,
+ cfg,
+ args.sequential,
+ output_dir=args.output_dir,
+ num=args.num,
+ )
diff --git a/evaluation/run.py b/evaluation/run.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff29688454358303643f3a26f7900486c19dbf22
--- /dev/null
+++ b/evaluation/run.py
@@ -0,0 +1,252 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import functools
+from itertools import islice
+from typing import Callable, Dict, Optional, Tuple
+from pathlib import Path
+
+import numpy as np
+import torch
+from omegaconf import DictConfig, OmegaConf
+from torchmetrics import MetricCollection
+from pytorch_lightning import seed_everything
+from tqdm import tqdm
+
+from logger import logger, EXPERIMENTS_PATH
+from dataset.torch import collate, unbatch_to_device
+from models.voting import argmax_xyr, fuse_gps
+from models.metrics import AngleError, LateralLongitudinalError, Location2DError
+# from models.sequential import GPSAligner, RigidAligner
+from module import GenericModule
+from utils.io import download_file, DATA_URL
+from evaluation.viz import plot_example_single, plot_example_sequential
+from evaluation.utils import write_dump
+
+
+pretrained_models = dict(
+ OrienterNet_MGL=("orienternet_mgl.ckpt", dict(num_rotations=256)),
+)
+
+
+def resolve_checkpoint_path(experiment_or_path: str) -> Path:
+ path = Path(experiment_or_path)
+ if not path.exists():
+ # provided name of experiment
+ path = Path(EXPERIMENTS_PATH, *experiment_or_path.split("/"))
+ if not path.exists():
+ if experiment_or_path in set(p for p, _ in pretrained_models.values()):
+ download_file(f"{DATA_URL}/{experiment_or_path}", path)
+ else:
+ raise FileNotFoundError(path)
+ if path.is_file():
+ return path
+ # provided only the experiment name
+ maybe_path = path / "last-step-v1.ckpt"
+ if not maybe_path.exists():
+ maybe_path = path / "last.ckpt"
+ if not maybe_path.exists():
+ raise FileNotFoundError(f"Could not find any checkpoint in {path}.")
+ return maybe_path
+
+
+@torch.no_grad()
+def evaluate_single_image(
+ dataloader: torch.utils.data.DataLoader,
+ model: GenericModule,
+ num: Optional[int] = None,
+ callback: Optional[Callable] = None,
+ progress: bool = True,
+ mask_index: Optional[Tuple[int]] = None,
+ has_gps: bool = False,
+):
+ ppm = model.model.conf.pixel_per_meter
+ metrics = MetricCollection(model.model.metrics())
+ metrics["directional_error"] = LateralLongitudinalError(ppm)
+ if has_gps:
+ metrics["xy_gps_error"] = Location2DError("uv_gps", ppm)
+ metrics["xy_fused_error"] = Location2DError("uv_fused", ppm)
+ metrics["yaw_fused_error"] = AngleError("yaw_fused")
+ metrics = metrics.to(model.device)
+
+ for i, batch_ in enumerate(
+ islice(tqdm(dataloader, total=num, disable=not progress), num)
+ ):
+ batch = model.transfer_batch_to_device(batch_, model.device, i)
+ # Ablation: mask semantic classes
+ if mask_index is not None:
+ mask = batch["map"][0, mask_index[0]] == (mask_index[1] + 1)
+ batch["map"][0, mask_index[0]][mask] = 0
+ pred = model(batch)
+
+ if has_gps:
+ (uv_gps,) = pred["uv_gps"] = batch["uv_gps"]
+ pred["log_probs_fused"] = fuse_gps(
+ pred["log_probs"], uv_gps, ppm, sigma=batch["accuracy_gps"]
+ )
+ uvt_fused = argmax_xyr(pred["log_probs_fused"])
+ pred["uv_fused"] = uvt_fused[..., :2]
+ pred["yaw_fused"] = uvt_fused[..., -1]
+ del uv_gps, uvt_fused
+
+ results = metrics(pred, batch)
+ if callback is not None:
+ callback(
+ i, model, unbatch_to_device(pred), unbatch_to_device(batch_), results
+ )
+ del batch_, batch, pred, results
+
+ return metrics.cpu()
+
+
+@torch.no_grad()
+def evaluate_sequential(
+ dataset: torch.utils.data.Dataset,
+ chunk2idx: Dict,
+ model: GenericModule,
+ num: Optional[int] = None,
+ shuffle: bool = False,
+ callback: Optional[Callable] = None,
+ progress: bool = True,
+ num_rotations: int = 512,
+ mask_index: Optional[Tuple[int]] = None,
+ has_gps: bool = True,
+):
+ chunk_keys = list(chunk2idx)
+ if shuffle:
+ chunk_keys = [chunk_keys[i] for i in torch.randperm(len(chunk_keys))]
+ if num is not None:
+ chunk_keys = chunk_keys[:num]
+ lengths = [len(chunk2idx[k]) for k in chunk_keys]
+ logger.info(
+ "Min/max/med lengths: %d/%d/%d, total number of images: %d",
+ min(lengths),
+ np.median(lengths),
+ max(lengths),
+ sum(lengths),
+ )
+ viz = callback is not None
+
+ metrics = MetricCollection(model.model.metrics())
+ ppm = model.model.conf.pixel_per_meter
+ metrics["directional_error"] = LateralLongitudinalError(ppm)
+ metrics["xy_seq_error"] = Location2DError("uv_seq", ppm)
+ metrics["yaw_seq_error"] = AngleError("yaw_seq")
+ metrics["directional_seq_error"] = LateralLongitudinalError(ppm, key="uv_seq")
+ if has_gps:
+ metrics["xy_gps_error"] = Location2DError("uv_gps", ppm)
+ metrics["xy_gps_seq_error"] = Location2DError("uv_gps_seq", ppm)
+ metrics["yaw_gps_seq_error"] = AngleError("yaw_gps_seq")
+ metrics = metrics.to(model.device)
+
+ keys_save = ["uvr_max", "uv_max", "yaw_max", "uv_expectation"]
+ if has_gps:
+ keys_save.append("uv_gps")
+ if viz:
+ keys_save.append("log_probs")
+
+ for chunk_index, key in enumerate(tqdm(chunk_keys, disable=not progress)):
+ indices = chunk2idx[key]
+ aligner = RigidAligner(track_priors=viz, num_rotations=num_rotations)
+ if has_gps:
+ aligner_gps = GPSAligner(track_priors=viz, num_rotations=num_rotations)
+ batches = []
+ preds = []
+ for i in indices:
+ data = dataset[i]
+ data = model.transfer_batch_to_device(data, model.device, 0)
+ pred = model(collate([data]))
+
+ canvas = data["canvas"]
+ data["xy_geo"] = xy = canvas.to_xy(data["uv"].double())
+ data["yaw"] = yaw = data["roll_pitch_yaw"][-1].double()
+ aligner.update(pred["log_probs"][0], canvas, xy, yaw)
+
+ if has_gps:
+ (uv_gps) = pred["uv_gps"] = data["uv_gps"][None]
+ xy_gps = canvas.to_xy(uv_gps.double())
+ aligner_gps.update(xy_gps, data["accuracy_gps"], canvas, xy, yaw)
+
+ if not viz:
+ data.pop("image")
+ data.pop("map")
+ batches.append(data)
+ preds.append({k: pred[k][0] for k in keys_save})
+ del pred
+
+ xy_gt = torch.stack([b["xy_geo"] for b in batches])
+ yaw_gt = torch.stack([b["yaw"] for b in batches])
+ aligner.compute()
+ xy_seq, yaw_seq = aligner.transform(xy_gt, yaw_gt)
+ if has_gps:
+ aligner_gps.compute()
+ xy_gps_seq, yaw_gps_seq = aligner_gps.transform(xy_gt, yaw_gt)
+ results = []
+ for i in range(len(indices)):
+ preds[i]["uv_seq"] = batches[i]["canvas"].to_uv(xy_seq[i]).float()
+ preds[i]["yaw_seq"] = yaw_seq[i].float()
+ if has_gps:
+ preds[i]["uv_gps_seq"] = (
+ batches[i]["canvas"].to_uv(xy_gps_seq[i]).float()
+ )
+ preds[i]["yaw_gps_seq"] = yaw_gps_seq[i].float()
+ results.append(metrics(preds[i], batches[i]))
+ if viz:
+ callback(chunk_index, model, batches, preds, results, aligner)
+ del aligner, preds, batches, results
+ return metrics.cpu()
+
+
+def evaluate(
+ experiment: str,
+ cfg: DictConfig,
+ dataset,
+ split: str,
+ sequential: bool = False,
+ output_dir: Optional[Path] = None,
+ callback: Optional[Callable] = None,
+ num_workers: int = 1,
+ viz_kwargs=None,
+ **kwargs,
+):
+ if experiment in pretrained_models:
+ experiment, cfg_override = pretrained_models[experiment]
+ cfg = OmegaConf.merge(OmegaConf.create(dict(model=cfg_override)), cfg)
+
+ logger.info("Evaluating model %s with config %s", experiment, cfg)
+ checkpoint_path = resolve_checkpoint_path(experiment)
+ model = GenericModule.load_from_checkpoint(
+ checkpoint_path, cfg=cfg, find_best=not experiment.endswith(".ckpt")
+ )
+ model = model.eval()
+ if torch.cuda.is_available():
+ model = model.cuda()
+
+ dataset.prepare_data()
+ dataset.setup()
+
+ if output_dir is not None:
+ output_dir.mkdir(exist_ok=True, parents=True)
+ if callback is None:
+ if sequential:
+ callback = plot_example_sequential
+ else:
+ callback = plot_example_single
+ callback = functools.partial(
+ callback, out_dir=output_dir, **(viz_kwargs or {})
+ )
+ kwargs = {**kwargs, "callback": callback}
+
+ seed_everything(dataset.cfg.seed)
+ if sequential:
+ dset, chunk2idx = dataset.sequence_dataset(split, **cfg.chunking)
+ metrics = evaluate_sequential(dset, chunk2idx, model, **kwargs)
+ else:
+ loader = dataset.dataloader(split, shuffle=True, num_workers=num_workers)
+ metrics = evaluate_single_image(loader, model, **kwargs)
+
+ results = metrics.compute()
+ logger.info("All results: %s", results)
+ if output_dir is not None:
+ write_dump(output_dir, experiment, cfg, results, metrics)
+ logger.info("Outputs have been written to %s.", output_dir)
+ return metrics
diff --git a/evaluation/utils.py b/evaluation/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cc954ed20557c351965cad89aa2e249986986ee
--- /dev/null
+++ b/evaluation/utils.py
@@ -0,0 +1,40 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import numpy as np
+from omegaconf import OmegaConf
+
+from utils.io import write_json
+
+
+def compute_recall(errors):
+ num_elements = len(errors)
+ sort_idx = np.argsort(errors)
+ errors = np.array(errors.copy())[sort_idx]
+ recall = (np.arange(num_elements) + 1) / num_elements
+ recall = np.r_[0, recall]
+ errors = np.r_[0, errors]
+ return errors, recall
+
+
+def compute_auc(errors, recall, thresholds):
+ aucs = []
+ for t in thresholds:
+ last_index = np.searchsorted(errors, t, side="right")
+ r = np.r_[recall[:last_index], recall[last_index - 1]]
+ e = np.r_[errors[:last_index], t]
+ auc = np.trapz(r, x=e) / t
+ aucs.append(auc * 100)
+ return aucs
+
+
+def write_dump(output_dir, experiment, cfg, results, metrics):
+ dump = {
+ "experiment": experiment,
+ "cfg": OmegaConf.to_container(cfg),
+ "results": results,
+ "errors": {},
+ }
+ for k, m in metrics.items():
+ if hasattr(m, "get_errors"):
+ dump["errors"][k] = m.get_errors().numpy()
+ write_json(output_dir / "log.json", dump)
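+
+
+if __name__ == "__main__":
+    # Minimal sanity check of the recall/AUC helpers above (a sketch; the error
+    # values are made up for illustration only).
+    errs, rec = compute_recall(np.array([0.4, 2.5, 4.0, 7.0]))
+    print(compute_auc(errs, rec, thresholds=[1.0, 3.0, 5.0]))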
diff --git a/evaluation/viz.py b/evaluation/viz.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cd9f7dfd0f2103f2ebdda8cfe8022ad5a2e719b
--- /dev/null
+++ b/evaluation/viz.py
@@ -0,0 +1,178 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import numpy as np
+import torch
+import matplotlib.pyplot as plt
+
+from utils.io import write_torch_image
+from utils.viz_2d import plot_images, features_to_RGB, save_plot
+from utils.viz_localization import (
+ likelihood_overlay,
+ plot_pose,
+ plot_dense_rotations,
+ add_circle_inset,
+)
+from osm.viz import Colormap, plot_nodes
+
+
+def plot_example_single(
+ idx,
+ model,
+ pred,
+ data,
+ results,
+ plot_bev=True,
+ out_dir=None,
+ fig_for_paper=False,
+ show_gps=False,
+ show_fused=False,
+ show_dir_error=False,
+ show_masked_prob=False,
+):
+ scene, name, rasters, uv_gt = (data[k] for k in ("scene", "name", "map", "uv"))
+ uv_gps = data.get("uv_gps")
+ yaw_gt = data["roll_pitch_yaw"][-1].numpy()
+ image = data["image"].permute(1, 2, 0)
+ if "valid" in data:
+ image = image.masked_fill(~data["valid"].unsqueeze(-1), 0.3)
+
+ lp_uvt = lp_uv = pred["log_probs"]
+ if show_fused and "log_probs_fused" in pred:
+ lp_uvt = lp_uv = pred["log_probs_fused"]
+ elif not show_masked_prob and "scores_unmasked" in pred:
+ lp_uvt = lp_uv = pred["scores_unmasked"]
+ has_rotation = lp_uvt.ndim == 3
+ if has_rotation:
+ lp_uv = lp_uvt.max(-1).values
+ if lp_uv.min() > -np.inf:
+ lp_uv = lp_uv.clip(min=np.percentile(lp_uv, 1))
+ prob = lp_uv.exp()
+ uv_p, yaw_p = pred["uv_max"], pred.get("yaw_max")
+ if show_fused and "uv_fused" in pred:
+ uv_p, yaw_p = pred["uv_fused"], pred.get("yaw_fused")
+ feats_map = pred["map"]["map_features"][0]
+ (feats_map_rgb,) = features_to_RGB(feats_map.numpy())
+
+ text1 = rf'$\Delta xy$: {results["xy_max_error"]:.1f}m'
+ if has_rotation:
+ text1 += rf', $\Delta\theta$: {results["yaw_max_error"]:.1f}°'
+ if show_fused and "xy_fused_error" in results:
+ text1 += rf', $\Delta xy_{{fused}}$: {results["xy_fused_error"]:.1f}m'
+ text1 += rf', $\Delta\theta_{{fused}}$: {results["yaw_fused_error"]:.1f}°'
+ if show_dir_error and "directional_error" in results:
+ err_lat, err_lon = results["directional_error"]
+ text1 += rf", $\Delta$lateral/longitundinal={err_lat:.1f}m/{err_lon:.1f}m"
+ if "xy_gps_error" in results:
+ text1 += rf', $\Delta xy_{{GPS}}$: {results["xy_gps_error"]:.1f}m'
+
+ map_viz = Colormap.apply(rasters)
+ overlay = likelihood_overlay(prob.numpy(), map_viz.mean(-1, keepdims=True))
+ plot_images(
+ [image, map_viz, overlay, feats_map_rgb],
+ titles=[text1, "map", "likelihood", "neural map"],
+ dpi=75,
+ cmaps="jet",
+ )
+ fig = plt.gcf()
+ axes = fig.axes
+ axes[1].images[0].set_interpolation("none")
+ axes[2].images[0].set_interpolation("none")
+ Colormap.add_colorbar()
+ plot_nodes(1, rasters[2])
+
+ if show_gps and uv_gps is not None:
+ plot_pose([1], uv_gps, c="blue")
+ plot_pose([1], uv_gt, yaw_gt, c="red")
+ plot_pose([1], uv_p, yaw_p, c="k")
+ plot_dense_rotations(2, lp_uvt.exp())
+ inset_center = pred["uv_max"] if results["xy_max_error"] < 5 else uv_gt
+ axins = add_circle_inset(axes[2], inset_center)
+ axins.scatter(*uv_gt, lw=1, c="red", ec="k", s=50, zorder=15)
+ axes[0].text(
+ 0.003,
+ 0.003,
+ f"{scene}/{name}",
+ transform=axes[0].transAxes,
+ fontsize=3,
+ va="bottom",
+ ha="left",
+ color="w",
+ )
+ plt.show()
+ if out_dir is not None:
+ name_ = name.replace("/", "_")
+ p = str(out_dir / f"{scene}_{name_}_{{}}.pdf")
+ save_plot(p.format("pred"))
+ plt.close()
+
+ if fig_for_paper:
+ # !cp ../datasets/MGL/{scene}/images/{name}.jpg {out_dir}/{scene}_{name}.jpg
+ plot_images([map_viz])
+ plt.gca().images[0].set_interpolation("none")
+ plot_nodes(0, rasters[2])
+ plot_pose([0], uv_gt, yaw_gt, c="red")
+ plot_pose([0], pred["uv_max"], pred["yaw_max"], c="k")
+ save_plot(p.format("map"))
+ plt.close()
+ plot_images([lp_uv], cmaps="jet")
+ plot_dense_rotations(0, lp_uvt.exp())
+ save_plot(p.format("loglikelihood"), dpi=100)
+ plt.close()
+ plot_images([overlay])
+ plt.gca().images[0].set_interpolation("none")
+ axins = add_circle_inset(plt.gca(), inset_center)
+ axins.scatter(*uv_gt, lw=1, c="red", ec="k", s=50)
+ save_plot(p.format("likelihood"))
+ plt.close()
+ write_torch_image(
+ p.format("neuralmap").replace("pdf", "jpg"), feats_map_rgb
+ )
+ write_torch_image(p.format("image").replace("pdf", "jpg"), image.numpy())
+
+ if not plot_bev:
+ return
+
+ feats_q = pred["features_bev"]
+ mask_bev = pred["valid_bev"]
+ prior = None
+ if "log_prior" in pred["map"]:
+ prior = pred["map"]["log_prior"][0].sigmoid()
+ if "bev" in pred and "confidence" in pred["bev"]:
+ conf_q = pred["bev"]["confidence"]
+ else:
+ conf_q = torch.norm(feats_q, dim=0)
+ conf_q = conf_q.masked_fill(~mask_bev, np.nan)
+ (feats_q_rgb,) = features_to_RGB(feats_q.numpy(), masks=[mask_bev.numpy()])
+ # feats_map_rgb, feats_q_rgb, = features_to_RGB(
+ # feats_map.numpy(), feats_q.numpy(), masks=[None, mask_bev])
+ norm_map = torch.norm(feats_map, dim=0)
+
+ plot_images(
+ [conf_q, feats_q_rgb, norm_map] + ([] if prior is None else [prior]),
+ titles=["BEV confidence", "BEV features", "map norm"]
+ + ([] if prior is None else ["map prior"]),
+ dpi=50,
+ cmaps="jet",
+ )
+ plt.show()
+
+ if out_dir is not None:
+ save_plot(p.format("bev"))
+ plt.close()
+
+
+def plot_example_sequential(
+ idx,
+ model,
+ pred,
+ data,
+ results,
+ plot_bev=True,
+ out_dir=None,
+ fig_for_paper=False,
+ show_gps=False,
+ show_fused=False,
+ show_dir_error=False,
+ show_masked_prob=False,
+):
+ return
diff --git a/flagged/inp/10d2e4a8712491181c2f48b61f5003b216d2b9f9/tmp48n9eoyh.png b/flagged/inp/10d2e4a8712491181c2f48b61f5003b216d2b9f9/tmp48n9eoyh.png
new file mode 100644
index 0000000000000000000000000000000000000000..61bdcaeabd4ec399fa036511592c0c9e3f8628b7
Binary files /dev/null and b/flagged/inp/10d2e4a8712491181c2f48b61f5003b216d2b9f9/tmp48n9eoyh.png differ
diff --git a/flagged/inp/e1b18d44d9e381d586209f73a015fed7f688822b/tmp86ith_2q.png b/flagged/inp/e1b18d44d9e381d586209f73a015fed7f688822b/tmp86ith_2q.png
new file mode 100644
index 0000000000000000000000000000000000000000..61bdcaeabd4ec399fa036511592c0c9e3f8628b7
Binary files /dev/null and b/flagged/inp/e1b18d44d9e381d586209f73a015fed7f688822b/tmp86ith_2q.png differ
diff --git a/flagged/log.csv b/flagged/log.csv
new file mode 100644
index 0000000000000000000000000000000000000000..61a8dd134c6a6e7a5cd88ccf2ef430e489e8d4b4
--- /dev/null
+++ b/flagged/log.csv
@@ -0,0 +1,3 @@
+inp,longitude,latitude,Area,output,flag,username,timestamp
+E:\MapLocNetDemo\Demo\flagged\inp\10d2e4a8712491181c2f48b61f5003b216d2b9f9\tmp48n9eoyh.png,70.1,40,256,E:\MapLocNetDemo\Demo\flagged\output\tmp59657zop.json,,,2023-09-22 10:07:17.488625
+E:\MapLocNetDemo\Demo\flagged\inp\e1b18d44d9e381d586209f73a015fed7f688822b\tmp86ith_2q.png,70.1,40,256,E:\MapLocNetDemo\Demo\flagged\output\tmpbs17s28d.json,,,2023-09-22 10:07:21.485967
diff --git a/flagged/output/tmp59657zop.json b/flagged/output/tmp59657zop.json
new file mode 100644
index 0000000000000000000000000000000000000000..6da7282f99d84615c8e174ce435ecd85765184f8
--- /dev/null
+++ b/flagged/output/tmp59657zop.json
@@ -0,0 +1 @@
+{"label": "bull mastiff\n", "confidences": [{"label": "bull mastiff\n", "confidence": 0.24759389460086823}, {"label": "pug\n", "confidence": 0.0916372761130333}, {"label": "Great Dane\n", "confidence": 0.08652031421661377}]}
\ No newline at end of file
diff --git a/flagged/output/tmpbs17s28d.json b/flagged/output/tmpbs17s28d.json
new file mode 100644
index 0000000000000000000000000000000000000000..6da7282f99d84615c8e174ce435ecd85765184f8
--- /dev/null
+++ b/flagged/output/tmpbs17s28d.json
@@ -0,0 +1 @@
+{"label": "bull mastiff\n", "confidences": [{"label": "bull mastiff\n", "confidence": 0.24759389460086823}, {"label": "pug\n", "confidence": 0.0916372761130333}, {"label": "Great Dane\n", "confidence": 0.08652031421661377}]}
\ No newline at end of file
diff --git a/images/00000.jpg b/images/00000.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c3340fd6671c91ec138a9cef129df4f9ce5adbd6
Binary files /dev/null and b/images/00000.jpg differ
diff --git a/images/00011.jpg b/images/00011.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..3c9201b98433b985bafd85793bd7992c8e7f55c6
Binary files /dev/null and b/images/00011.jpg differ
diff --git a/images/00022.jpg b/images/00022.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1d23d874d54e1f2608d6907fd1ef3416ac6e0716
Binary files /dev/null and b/images/00022.jpg differ
diff --git a/images/00033.jpg b/images/00033.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..fe8b2a1dac78db32ab979df3ca482a37108d5698
Binary files /dev/null and b/images/00033.jpg differ
diff --git a/images/cat_dog.png b/images/cat_dog.png
new file mode 100644
index 0000000000000000000000000000000000000000..61bdcaeabd4ec399fa036511592c0c9e3f8628b7
Binary files /dev/null and b/images/cat_dog.png differ
diff --git a/logger.py b/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..56b4d718091ad22d84d59387ca76628aa242555e
--- /dev/null
+++ b/logger.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+from pathlib import Path
+import logging
+
+import pytorch_lightning # noqa: F401
+
+
+formatter = logging.Formatter(
+ fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s",
+ datefmt="%Y-%m-%d %H:%M:%S",
+)
+handler = logging.StreamHandler()
+handler.setFormatter(formatter)
+handler.setLevel(logging.INFO)
+
+logger = logging.getLogger("maploc")
+logger.setLevel(logging.INFO)
+logger.addHandler(handler)
+logger.propagate = False
+
+pl_logger = logging.getLogger("pytorch_lightning")
+if len(pl_logger.handlers):
+ pl_logger.handlers[0].setFormatter(formatter)
+
+repo_dir = Path(__file__).parent
+EXPERIMENTS_PATH = repo_dir / "experiments/"
+DATASETS_PATH = repo_dir / "datasets/"
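+
+# Usage sketch (import path assumed): `from logger import logger, EXPERIMENTS_PATH`
+# yields the configured "maploc" logger and the default experiment/dataset roots.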
diff --git a/main.py b/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..311d7687bba3cb830615cfb9e89644e25df3d2ee
--- /dev/null
+++ b/main.py
@@ -0,0 +1,98 @@
+import cv2
+import gradio as gr
+import torch
+from torchvision import transforms
+import requests
+from PIL import Image
+from demo import Demo,read_input_image_test,show_result,vis_image_feature
+from osm.tiling import TileManager
+from osm.viz import Colormap, plot_nodes
+from utils.viz_2d import plot_images
+import numpy as np
+from utils.viz_2d import features_to_RGB
+from utils.viz_localization import (
+ likelihood_overlay,
+ plot_dense_rotations,
+ add_circle_inset,
+)
+from osm.viz import GeoPlotter
+import matplotlib.pyplot as plt
+import random
+from geopy.distance import geodesic
+
+experiment_or_path = "weight/last-step-checkpointing.ckpt"
+# experiment_or_path="experiments/maplocanet_0906_diffhight/last-step-checkpointing.ckpt"
+image_path = 'images/00000.jpg'
+
+# prior_latlon = (37.75704325989902, -122.435941445631)
+# tile_size_meters = 128
+model = Demo(experiment_or_path=experiment_or_path, num_rotations=128, device='cpu')
+
+def demo_localize(image,long,lat,tile_size_meters):
+ # inp = Image.fromarray(inp.astype('uint8'), 'RGB')
+ # inp = transforms.ToTensor()(inp).unsqueeze(0)
+ prior_latlon=(lat,long)
+ image, camera, gravity, proj, bbox, true_prior_latlon = read_input_image_test(
+ image,
+ prior_latlon=prior_latlon,
+ tile_size_meters=tile_size_meters, # try 64, 256, etc.
+ )
+ tiler = TileManager.from_bbox(projection=proj, bbox=bbox, ppm=1, tile_size=tile_size_meters)
+ # tiler = TileManager.from_bbox(projection=proj, bbox=bbox + 10,ppm=1,path=root/city/'{}.osm'.format(city), tile_size=1)
+ canvas = tiler.query(bbox)
+ uv, yaw, prob, neural_map, image_rectified, data_, pred = model.localize(
+ image, camera, canvas)
+ prior_latlon_pred = proj.unproject(canvas.to_xy(uv))
+
+ map_viz = Colormap.apply(canvas.raster)
+ map_vis_image_result = map_viz * 255
+ map_vis_image_result =show_result(map_vis_image_result.astype(np.uint8), uv, yaw)
+ # map_vis_image_result = show_result(map_vis_image_result.astype(np.uint8), True_uv,
+ # uv,
+ # 90.0 - yaw_T,
+ # yaw)
+ # return prior_latlon_pred
+ uab_feature_rgb = vis_image_feature(pred['features_image'][0].cpu().numpy())
+ map_viz = cv2.resize(map_viz, (prob.numpy().shape[0], prob.numpy().shape[1]))
+ overlay = likelihood_overlay(prob.numpy().max(-1), map_viz.mean(-1, keepdims=True))
+ (neural_map_rgb,) = features_to_RGB(neural_map.numpy())
+ fig=plot_images([image, map_vis_image_result / 255, overlay, uab_feature_rgb, neural_map_rgb],
+ titles=["UAV image", "map","likelihood","UAV feature","map feature"])
+ # plot_images([overlay, neural_map_rgb], titles=["prediction", "neural map"])
+ # ax = plt.gcf().axes[2]
+ # ax.scatter(*canvas.to_uv(bbox.center), s=5, c="red")
+ # plot_dense_rotations(ax, prob, w=0.005, s=1 / 25)
+ # add_circle_inset(ax, uv)
+
+ # Plot as interactive figure
+ bbox_latlon = proj.unproject(canvas.bbox)
+ plot2 = GeoPlotter(zoom=16.5)
+ plot2.raster(map_viz, bbox_latlon, opacity=0.5)
+ plot2.raster(likelihood_overlay(prob.numpy().max(-1)), proj.unproject(bbox))
+ plot2.points(prior_latlon[:2], "red", name="location prior", size=10)
+ plot2.points(proj.unproject(canvas.to_xy(uv)), "black", name="argmax", size=10)
+ plot2.bbox(bbox_latlon, "blue", name="map tile")
+ # plot2.fig.show()
+ return fig,plot2.fig,str(prior_latlon_pred)
+# model = torch.hub.load('pytorch/vision:v0.6.0', 'resnet18', pretrained=True).eval()
+# Title of the demo
+title = "MapLocNet"
+# Description shown under the title (Markdown is supported)
+description = "UAV Vision-based Geo-Localization Using Vectorized Maps"
+
+# outputs = gr.outputs.Label(num_top_classes=3)
+outputs = gr.Plot()
+interface = gr.Interface(fn=demo_localize,
+ inputs=["image",
+ gr.Number(label="Prior location-longitude)"),
+ gr.Number(label="Prior location-longitude)"),
+ gr.Radio([64, 128, 256], label="Search radius (meters)", info="vectorized map size"),
+ # gr.inputs.RadioGroup(label="Search radius (meters)",["English", "French", "Spanish"]),
+ # gr.Slider(64, 512,label='Search radius (meters)')
+ ],
+ outputs=["plot","plot","text"],
+ title=title,
+ description=description,
+ examples=[['images/00000.jpg',-122.435941445631,37.75704325989902,128]])
+interface.launch(share=True)
\ No newline at end of file
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..02c1f950d5d3f84b18ba4178e2549fc328479d3f
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
+# https://github.com/cvg/pixloc
+# Released under the Apache License 2.0
+
+import inspect
+
+from .base import BaseModel
+
+
+def get_class(mod_name, base_path, BaseClass):
+ """Get the class object which inherits from BaseClass and is defined in
+ the module named mod_name, child of base_path.
+ """
+ mod_path = "{}.{}".format(base_path, mod_name)
+ mod = __import__(mod_path, fromlist=[""])
+ classes = inspect.getmembers(mod, inspect.isclass)
+ # Filter classes defined in the module
+ classes = [c for c in classes if c[1].__module__ == mod_path]
+ # Filter classes inherited from BaseModel
+ classes = [c for c in classes if issubclass(c[1], BaseClass)]
+ assert len(classes) == 1, classes
+ return classes[0][1]
+
+
+def get_model(name):
+ if name == "localizer":
+ name = "localizer_basic"
+ elif name == "rotation_localizer":
+ name = "localizer_basic_rotation"
+ elif name == "bev_localizer":
+ name = "localizer_bev_plane"
+ return get_class(name, __name__, BaseModel)
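+
+
+# Usage sketch (module name is an example from this repo): `get_model("map_encoder")`
+# imports `models.map_encoder` and returns the single BaseModel subclass defined
+# there (MapEncoder), which can then be instantiated with a configuration dict.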
diff --git a/models/base.py b/models/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9978eabb3c32fcd98f12399347f4c864e463494
--- /dev/null
+++ b/models/base.py
@@ -0,0 +1,123 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
+# https://github.com/cvg/pixloc
+# Released under the Apache License 2.0
+
+"""
+Base class for trainable models.
+"""
+
+from abc import ABCMeta, abstractmethod
+from copy import copy
+
+import omegaconf
+from omegaconf import OmegaConf
+from torch import nn
+
+
+class BaseModel(nn.Module, metaclass=ABCMeta):
+ """
+    What the child model is expected to declare:
+ default_conf: dictionary of the default configuration of the model.
+ It recursively updates the default_conf of all parent classes, and
+ it is updated by the user-provided configuration passed to __init__.
+ Configurations can be nested.
+
+ required_data_keys: list of expected keys in the input data dictionary.
+
+ strict_conf (optional): boolean. If false, BaseModel does not raise
+ an error when the user provides an unknown configuration entry.
+
+ _init(self, conf): initialization method, where conf is the final
+ configuration object (also accessible with `self.conf`). Accessing
+ unknown configuration entries will raise an error.
+
+ _forward(self, data): method that returns a dictionary of batched
+ prediction tensors based on a dictionary of batched input data tensors.
+
+ loss(self, pred, data): method that returns a dictionary of losses,
+ computed from model predictions and input data. Each loss is a batch
+ of scalars, i.e. a torch.Tensor of shape (B,).
+ The total loss to be optimized has the key `'total'`.
+
+ metrics(self, pred, data): method that returns a dictionary of metrics,
+ each as a batch of scalars.
+ """
+
+ base_default_conf = {
+ "name": None,
+ "trainable": True, # if false: do not optimize this model parameters
+ "freeze_batch_normalization": False, # use test-time statistics
+ }
+ default_conf = {}
+ required_data_keys = []
+ strict_conf = True
+
+ def __init__(self, conf):
+ """Perform some logic and call the _init method of the child model."""
+ super().__init__()
+ default_conf = OmegaConf.merge(
+ self.base_default_conf, OmegaConf.create(self.default_conf)
+ )
+ if self.strict_conf:
+ OmegaConf.set_struct(default_conf, True)
+
+ # fixme: backward compatibility
+ if "pad" in conf and "pad" not in default_conf: # backward compat.
+ with omegaconf.read_write(conf):
+ with omegaconf.open_dict(conf):
+ conf["interpolation"] = {"pad": conf.pop("pad")}
+
+ if isinstance(conf, dict):
+ conf = OmegaConf.create(conf)
+ self.conf = conf = OmegaConf.merge(default_conf, conf)
+ OmegaConf.set_readonly(conf, True)
+ OmegaConf.set_struct(conf, True)
+ self.required_data_keys = copy(self.required_data_keys)
+ self._init(conf)
+
+ if not conf.trainable:
+ for p in self.parameters():
+ p.requires_grad = False
+
+ def train(self, mode=True):
+ super().train(mode)
+
+ def freeze_bn(module):
+ if isinstance(module, nn.modules.batchnorm._BatchNorm):
+ module.eval()
+
+ if self.conf.freeze_batch_normalization:
+ self.apply(freeze_bn)
+
+ return self
+
+ def forward(self, data):
+ """Check the data and call the _forward method of the child model."""
+
+ def recursive_key_check(expected, given):
+ for key in expected:
+ assert key in given, f"Missing key {key} in data"
+ if isinstance(expected, dict):
+ recursive_key_check(expected[key], given[key])
+
+ recursive_key_check(self.required_data_keys, data)
+ return self._forward(data)
+
+ @abstractmethod
+ def _init(self, conf):
+ """To be implemented by the child class."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def _forward(self, data):
+ """To be implemented by the child class."""
+ raise NotImplementedError
+
+ def loss(self, pred, data):
+ """To be implemented by the child class."""
+ raise NotImplementedError
+
+ def metrics(self):
+ return {} # no metrics
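+
+
+# A minimal child model, given as an illustrative sketch of the contract described
+# in the class docstring above; it is not used anywhere else in the repo.
+class TinyExampleModel(BaseModel):
+    default_conf = {"output_dim": 8}
+    required_data_keys = ["image"]
+
+    def _init(self, conf):
+        # A single 1x1 convolution, just to show where layers are created.
+        self.layer = nn.Conv2d(3, conf.output_dim, 1)
+
+    def _forward(self, data):
+        return {"features": self.layer(data["image"])}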
diff --git a/models/feature_extractor.py b/models/feature_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a06066c8ebce96b859d20fa444833d2b884a7ed
--- /dev/null
+++ b/models/feature_extractor.py
@@ -0,0 +1,231 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
+# https://github.com/cvg/pixloc
+# Released under the Apache License 2.0
+
+"""
+Flexible UNet model which takes any Torchvision backbone as encoder.
+Predicts multi-level feature and makes sure that they are well aligned.
+"""
+
+import torch
+import torch.nn as nn
+import torchvision
+
+from .base import BaseModel
+from .utils import checkpointed
+
+
+class DecoderBlock(nn.Module):
+ def __init__(
+ self, previous, skip, out, num_convs=1, norm=nn.BatchNorm2d, padding="zeros"
+ ):
+ super().__init__()
+
+ self.upsample = nn.Upsample(
+ scale_factor=2, mode="bilinear", align_corners=False
+ )
+
+ layers = []
+ for i in range(num_convs):
+ conv = nn.Conv2d(
+ previous + skip if i == 0 else out,
+ out,
+ kernel_size=3,
+ padding=1,
+ bias=norm is None,
+ padding_mode=padding,
+ )
+ layers.append(conv)
+ if norm is not None:
+ layers.append(norm(out))
+ layers.append(nn.ReLU(inplace=True))
+ self.layers = nn.Sequential(*layers)
+
+ def forward(self, previous, skip):
+ upsampled = self.upsample(previous)
+ # If the shape of the input map `skip` is not a multiple of 2,
+ # it will not match the shape of the upsampled map `upsampled`.
+        # If the downsampling uses ceil_mode=False, we need to crop `skip`.
+ # If it uses ceil_mode=True (not supported here), we should pad it.
+ _, _, hu, wu = upsampled.shape
+ _, _, hs, ws = skip.shape
+ assert (hu <= hs) and (wu <= ws), "Using ceil_mode=True in pooling?"
+ # assert (hu == hs) and (wu == ws), 'Careful about padding'
+ skip = skip[:, :, :hu, :wu]
+ return self.layers(torch.cat([upsampled, skip], dim=1))
+
+
+class AdaptationBlock(nn.Sequential):
+ def __init__(self, inp, out):
+ conv = nn.Conv2d(inp, out, kernel_size=1, padding=0, bias=True)
+ super().__init__(conv)
+
+
+class FeatureExtractor(BaseModel):
+ default_conf = {
+ "pretrained": True,
+ "input_dim": 3,
+ "output_scales": [0, 2, 4], # what scales to adapt and output
+ "output_dim": 128, # # of channels in output feature maps
+ "encoder": "vgg16", # string (torchvision net) or list of channels
+ "num_downsample": 4, # how many downsample block (if VGG-style net)
+ "decoder": [64, 64, 64, 64], # list of channels of decoder
+ "decoder_norm": "nn.BatchNorm2d", # normalization ind decoder blocks
+ "do_average_pooling": False,
+ "checkpointed": False, # whether to use gradient checkpointing
+ "padding": "zeros",
+ }
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
+
+ def build_encoder(self, conf):
+ assert isinstance(conf.encoder, str)
+ if conf.pretrained:
+ assert conf.input_dim == 3
+ Encoder = getattr(torchvision.models, conf.encoder)
+ encoder = Encoder(weights="DEFAULT" if conf.pretrained else None)
+ Block = checkpointed(torch.nn.Sequential, do=conf.checkpointed)
+ assert max(conf.output_scales) <= conf.num_downsample
+
+ if conf.encoder.startswith("vgg"):
+ # Parse the layers and pack them into downsampling blocks
+ # It's easy for VGG-style nets because of their linear structure.
+ # This does not handle strided convs and residual connections
+ skip_dims = []
+ previous_dim = None
+ blocks = [[]]
+ for i, layer in enumerate(encoder.features):
+ if isinstance(layer, torch.nn.Conv2d):
+ # Change the first conv layer if the input dim mismatches
+ if i == 0 and conf.input_dim != layer.in_channels:
+ args = {k: getattr(layer, k) for k in layer.__constants__}
+ args.pop("output_padding")
+ layer = torch.nn.Conv2d(
+ **{**args, "in_channels": conf.input_dim}
+ )
+ previous_dim = layer.out_channels
+ elif isinstance(layer, torch.nn.MaxPool2d):
+ assert previous_dim is not None
+ skip_dims.append(previous_dim)
+ if (conf.num_downsample + 1) == len(blocks):
+ break
+ blocks.append([]) # start a new block
+ if conf.do_average_pooling:
+ assert layer.dilation == 1
+ layer = torch.nn.AvgPool2d(
+ kernel_size=layer.kernel_size,
+ stride=layer.stride,
+ padding=layer.padding,
+ ceil_mode=layer.ceil_mode,
+ count_include_pad=False,
+ )
+ blocks[-1].append(layer)
+ encoder = [Block(*b) for b in blocks]
+ elif conf.encoder.startswith("resnet"):
+ # Manually define the ResNet blocks such that the downsampling comes first
+ assert conf.encoder[len("resnet") :] in ["18", "34", "50", "101"]
+ assert conf.input_dim == 3, "Unsupported for now."
+ block1 = torch.nn.Sequential(encoder.conv1, encoder.bn1, encoder.relu)
+ block2 = torch.nn.Sequential(encoder.maxpool, encoder.layer1)
+ block3 = encoder.layer2
+ block4 = encoder.layer3
+ block5 = encoder.layer4
+ blocks = [block1, block2, block3, block4, block5]
+ # Extract the output dimension of each block
+ skip_dims = [encoder.conv1.out_channels]
+ for i in range(1, 5):
+ modules = getattr(encoder, f"layer{i}")[-1]._modules
+ conv = sorted(k for k in modules if k.startswith("conv"))[-1]
+ skip_dims.append(modules[conv].out_channels)
+ # Add a dummy block such that the first one does not downsample
+ encoder = [torch.nn.Identity()] + [Block(b) for b in blocks]
+ skip_dims = [3] + skip_dims
+ # Trim based on the requested encoder size
+ encoder = encoder[: conf.num_downsample + 1]
+ skip_dims = skip_dims[: conf.num_downsample + 1]
+ else:
+ raise NotImplementedError(conf.encoder)
+
+ assert (conf.num_downsample + 1) == len(encoder)
+ encoder = nn.ModuleList(encoder)
+
+ return encoder, skip_dims
+
+ def _init(self, conf):
+ # Encoder
+ self.encoder, skip_dims = self.build_encoder(conf)
+ self.skip_dims = skip_dims
+
+ def update_padding(module):
+ if isinstance(module, nn.Conv2d):
+ module.padding_mode = conf.padding
+
+ if conf.padding != "zeros":
+ self.encoder.apply(update_padding)
+
+ # Decoder
+ if conf.decoder is not None:
+ assert len(conf.decoder) == (len(skip_dims) - 1)
+ Block = checkpointed(DecoderBlock, do=conf.checkpointed)
+ norm = eval(conf.decoder_norm) if conf.decoder_norm else None # noqa
+
+ previous = skip_dims[-1]
+ decoder = []
+ for out, skip in zip(conf.decoder, skip_dims[:-1][::-1]):
+ decoder.append(
+ Block(previous, skip, out, norm=norm, padding=conf.padding)
+ )
+ previous = out
+ self.decoder = nn.ModuleList(decoder)
+
+ # Adaptation layers
+ adaptation = []
+ for idx, i in enumerate(conf.output_scales):
+ if conf.decoder is None or i == (len(self.encoder) - 1):
+ input_ = skip_dims[i]
+ else:
+ input_ = conf.decoder[-1 - i]
+
+ # out_dim can be an int (same for all scales) or a list (per scale)
+ dim = conf.output_dim
+ if not isinstance(dim, int):
+ dim = dim[idx]
+
+ block = AdaptationBlock(input_, dim)
+ adaptation.append(block)
+ self.adaptation = nn.ModuleList(adaptation)
+ self.scales = [2**s for s in conf.output_scales]
+
+ def _forward(self, data):
+ image = data["image"]
+ if self.conf.pretrained:
+ mean, std = image.new_tensor(self.mean), image.new_tensor(self.std)
+ image = (image - mean[:, None, None]) / std[:, None, None]
+
+ skip_features = []
+ features = image
+ for block in self.encoder:
+ features = block(features)
+ skip_features.append(features)
+
+ if self.conf.decoder:
+ pre_features = [skip_features[-1]]
+ for block, skip in zip(self.decoder, skip_features[:-1][::-1]):
+ pre_features.append(block(pre_features[-1], skip))
+ pre_features = pre_features[::-1] # fine to coarse
+ else:
+ pre_features = skip_features
+
+ out_features = []
+ for adapt, i in zip(self.adaptation, self.conf.output_scales):
+ out_features.append(adapt(pre_features[i]))
+ pred = {"feature_maps": out_features, "skip_features": skip_features}
+ return pred
+
+ def loss(self, pred, data):
+ raise NotImplementedError
+
+ def metrics(self, pred, data):
+ raise NotImplementedError
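+
+
+if __name__ == "__main__":
+    # Tiny smoke test (a sketch, not part of the training pipeline): build the
+    # default VGG16 configuration with random weights to avoid any download and
+    # print the shapes of the multi-scale output feature maps.
+    net = FeatureExtractor({"pretrained": False})
+    out = net({"image": torch.zeros(1, 3, 64, 64)})
+    print([tuple(f.shape) for f in out["feature_maps"]])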
diff --git a/models/feature_extractor_v2.py b/models/feature_extractor_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..38c910651f63e5394214ea0e2b1909537948da54
--- /dev/null
+++ b/models/feature_extractor_v2.py
@@ -0,0 +1,192 @@
+import logging
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torchvision
+from torchvision.models.feature_extraction import create_feature_extractor
+
+from .base import BaseModel
+
+logger = logging.getLogger(__name__)
+
+
+class DecoderBlock(nn.Module):
+ def __init__(
+ self, previous, out, ksize=3, num_convs=1, norm=nn.BatchNorm2d, padding="zeros"
+ ):
+ super().__init__()
+ layers = []
+ for i in range(num_convs):
+ conv = nn.Conv2d(
+ previous if i == 0 else out,
+ out,
+ kernel_size=ksize,
+ padding=ksize // 2,
+ bias=norm is None,
+ padding_mode=padding,
+ )
+ layers.append(conv)
+ if norm is not None:
+ layers.append(norm(out))
+ layers.append(nn.ReLU(inplace=True))
+ self.layers = nn.Sequential(*layers)
+
+ def forward(self, previous, skip):
+ _, _, hp, wp = previous.shape
+ _, _, hs, ws = skip.shape
+ scale = 2 ** np.round(np.log2(np.array([hs / hp, ws / wp])))
+ upsampled = nn.functional.interpolate(
+ previous, scale_factor=scale.tolist(), mode="bilinear", align_corners=False
+ )
+ # If the shape of the input map `skip` is not a multiple of 2,
+ # it will not match the shape of the upsampled map `upsampled`.
+        # If the downsampling uses ceil_mode=False, we need to crop `skip`.
+ # If it uses ceil_mode=True (not supported here), we should pad it.
+ _, _, hu, wu = upsampled.shape
+ _, _, hs, ws = skip.shape
+ if (hu <= hs) and (wu <= ws):
+ skip = skip[:, :, :hu, :wu]
+ elif (hu >= hs) and (wu >= ws):
+ skip = nn.functional.pad(skip, [0, wu - ws, 0, hu - hs])
+ else:
+ raise ValueError(
+ f"Inconsistent skip vs upsampled shapes: {(hs, ws)}, {(hu, wu)}"
+ )
+
+ return self.layers(skip) + upsampled
+
+
+class FPN(nn.Module):
+ def __init__(self, in_channels_list, out_channels, **kw):
+ super().__init__()
+ self.first = nn.Conv2d(
+ in_channels_list[-1], out_channels, 1, padding=0, bias=True
+ )
+ self.blocks = nn.ModuleList(
+ [
+ DecoderBlock(c, out_channels, ksize=1, **kw)
+ for c in in_channels_list[::-1][1:]
+ ]
+ )
+ self.out = nn.Sequential(
+ nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
+ nn.BatchNorm2d(out_channels),
+ nn.ReLU(inplace=True),
+ )
+
+ def forward(self, layers):
+ feats = None
+ for idx, x in enumerate(reversed(layers.values())):
+ if feats is None:
+ feats = self.first(x)
+ else:
+ feats = self.blocks[idx - 1](feats, x)
+ out = self.out(feats)
+ return out
+
+
+def remove_conv_stride(conv):
+ conv_new = nn.Conv2d(
+ conv.in_channels,
+ conv.out_channels,
+ conv.kernel_size,
+ bias=conv.bias is not None,
+ stride=1,
+ padding=conv.padding,
+ )
+ conv_new.weight = conv.weight
+ conv_new.bias = conv.bias
+ return conv_new
+
+
+class FeatureExtractor(BaseModel):
+ default_conf = {
+ "pretrained": True,
+ "input_dim": 3,
+ "output_dim": 128, # # of channels in output feature maps
+ "encoder": "resnet50", # torchvision net as string
+ "remove_stride_from_first_conv": False,
+ "num_downsample": None, # how many downsample block
+ "decoder_norm": "nn.BatchNorm2d", # normalization ind decoder blocks
+ "do_average_pooling": False,
+ "checkpointed": False, # whether to use gradient checkpointing
+ }
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
+
+ def build_encoder(self, conf):
+ assert isinstance(conf.encoder, str)
+ if conf.pretrained:
+ assert conf.input_dim == 3
+ Encoder = getattr(torchvision.models, conf.encoder)
+
+ kw = {}
+ if conf.encoder.startswith("resnet"):
+ layers = ["relu", "layer1", "layer2", "layer3", "layer4"]
+ kw["replace_stride_with_dilation"] = [False, False, False]
+ elif conf.encoder == "vgg13":
+ layers = [
+ "features.3",
+ "features.8",
+ "features.13",
+ "features.18",
+ "features.23",
+ ]
+ elif conf.encoder == "vgg16":
+ layers = [
+ "features.3",
+ "features.8",
+ "features.15",
+ "features.22",
+ "features.29",
+ ]
+ else:
+ raise NotImplementedError(conf.encoder)
+
+ if conf.num_downsample is not None:
+ layers = layers[: conf.num_downsample]
+ encoder = Encoder(weights="DEFAULT" if conf.pretrained else None, **kw)
+ encoder = create_feature_extractor(encoder, return_nodes=layers)
+ if conf.encoder.startswith("resnet") and conf.remove_stride_from_first_conv:
+ encoder.conv1 = remove_conv_stride(encoder.conv1)
+
+ if conf.do_average_pooling:
+ raise NotImplementedError
+ if conf.checkpointed:
+ raise NotImplementedError
+
+ return encoder, layers
+
+ def _init(self, conf):
+ # Preprocessing
+ self.register_buffer("mean_", torch.tensor(self.mean), persistent=False)
+ self.register_buffer("std_", torch.tensor(self.std), persistent=False)
+
+ # Encoder
+ self.encoder, self.layers = self.build_encoder(conf)
+ s = 128
+ inp = torch.zeros(1, 3, s, s)
+ features = list(self.encoder(inp).values())
+ self.skip_dims = [x.shape[1] for x in features]
+ self.layer_strides = [s / f.shape[-1] for f in features]
+ self.scales = [self.layer_strides[0]]
+
+ # Decoder
+ norm = eval(conf.decoder_norm) if conf.decoder_norm else None # noqa
+ self.decoder = FPN(self.skip_dims, out_channels=conf.output_dim, norm=norm)
+
+ logger.debug(
+ "Built feature extractor with layers {name:dim:stride}:\n"
+ f"{list(zip(self.layers, self.skip_dims, self.layer_strides))}\n"
+ f"and output scales {self.scales}."
+ )
+
+ def _forward(self, data):
+ image = data["image"]
+ image = (image - self.mean_[:, None, None]) / self.std_[:, None, None]
+
+ skip_features = self.encoder(image)
+ output = self.decoder(skip_features)
+ pred = {"feature_maps": [output], "skip_features": skip_features}
+ return pred
diff --git a/models/map_encoder.py b/models/map_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..e54db926df6cc16d9082826ffee2a8b838dbed21
--- /dev/null
+++ b/models/map_encoder.py
@@ -0,0 +1,67 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import torch
+import torch.nn as nn
+
+from .base import BaseModel
+from .feature_extractor import FeatureExtractor
+
+
+class MapEncoder(BaseModel):
+ default_conf = {
+ "embedding_dim": "???",
+ "output_dim": None,
+ "num_classes": "???",
+ "backbone": "???",
+ "unary_prior": False,
+ }
+
+ def _init(self, conf):
+ self.embeddings = torch.nn.ModuleDict(
+ {
+ k: torch.nn.Embedding(n + 1, conf.embedding_dim)
+ for k, n in conf.num_classes.items()
+ }
+ )
+        # num_classes example: {'areas': 7, 'ways': 10, 'nodes': 33}
+ input_dim = len(conf.num_classes) * conf.embedding_dim
+ output_dim = conf.output_dim
+ if output_dim is None:
+ output_dim = conf.backbone.output_dim
+ if conf.unary_prior:
+ output_dim += 1
+ if conf.backbone is None:
+ self.encoder = nn.Conv2d(input_dim, output_dim, 1)
+ elif conf.backbone == "simple":
+ self.encoder = nn.Sequential(
+ nn.Conv2d(input_dim, 128, 3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, 128, 3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(128, output_dim, 3, padding=1),
+ )
+ else:
+ self.encoder = FeatureExtractor(
+ {
+ **conf.backbone,
+ "input_dim": input_dim,
+ "output_dim": output_dim,
+ }
+ )
+
+ def _forward(self, data):
+ embeddings = [
+ self.embeddings[k](data["map"][:, i])
+ for i, k in enumerate(("areas", "ways", "nodes"))
+ ]
+ embeddings = torch.cat(embeddings, dim=-1).permute(0, 3, 1, 2)
+ if isinstance(self.encoder, BaseModel):
+ features = self.encoder({"image": embeddings})["feature_maps"]
+ else:
+ features = [self.encoder(embeddings)]
+ pred = {}
+ if self.conf.unary_prior:
+ pred["log_prior"] = [f[:, -1] for f in features]
+ features = [f[:, :-1] for f in features]
+ pred["map_features"] = features
+ return pred
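+
+
+if __name__ == "__main__":
+    # Tiny smoke test (an illustrative sketch): the class counts follow the
+    # comment above, the remaining values are arbitrary.
+    encoder = MapEncoder(
+        {
+            "embedding_dim": 16,
+            "output_dim": 8,
+            "num_classes": {"areas": 7, "ways": 10, "nodes": 33},
+            "backbone": "simple",
+            "unary_prior": False,
+        }
+    )
+    rasters = torch.randint(0, 7, (1, 3, 64, 64))
+    print(encoder({"map": rasters})["map_features"][0].shape)  # (1, 8, 64, 64)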
diff --git a/models/maplocnet.py b/models/maplocnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d32dc4b78bac1d0c1eb23827be875598489b447
--- /dev/null
+++ b/models/maplocnet.py
@@ -0,0 +1,204 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import numpy as np
+import torch
+from torch.nn.functional import normalize
+
+from . import get_model
+from models.base import BaseModel
+# from models.bev_net import BEVNet
+# from models.bev_projection import CartesianProjection, PolarProjectionDepth
+from models.voting import (
+ argmax_xyr,
+ conv2d_fft_batchwise,
+ expectation_xyr,
+ log_softmax_spatial,
+ mask_yaw_prior,
+ nll_loss_xyr,
+ nll_loss_xyr_smoothed,
+ TemplateSampler,
+ UAVTemplateSampler,
+ UAVTemplateSamplerFast
+)
+from .map_encoder import MapEncoder
+from .metrics import AngleError, AngleRecall, Location2DError, Location2DRecall
+
+
+class MapLocNet(BaseModel):
+ default_conf = {
+ "image_size": "???",
+ "val_citys":"???",
+ "image_encoder": "???",
+ "map_encoder": "???",
+ "bev_net": "???",
+ "latent_dim": "???",
+ "matching_dim": "???",
+ "scale_range": [0, 9],
+ "num_scale_bins": "???",
+ "z_min": None,
+ "z_max": "???",
+ "x_max": "???",
+ "pixel_per_meter": "???",
+ "num_rotations": "???",
+ "add_temperature": False,
+ "normalize_features": False,
+ "padding_matching": "replicate",
+ "apply_map_prior": True,
+ "do_label_smoothing": False,
+ "sigma_xy": 1,
+ "sigma_r": 2,
+        # deprecated
+ "depth_parameterization": "scale",
+ "norm_depth_scores": False,
+ "normalize_scores_by_dim": False,
+ "normalize_scores_by_num_valid": True,
+ "prior_renorm": True,
+ "retrieval_dim": None,
+ }
+
+ def _init(self, conf):
+ assert not self.conf.norm_depth_scores
+ assert self.conf.depth_parameterization == "scale"
+ assert not self.conf.normalize_scores_by_dim
+ assert self.conf.normalize_scores_by_num_valid
+ assert self.conf.prior_renorm
+
+ Encoder = get_model(conf.image_encoder.get("name", "feature_extractor_v2"))
+ self.image_encoder = Encoder(conf.image_encoder.backbone)
+ self.map_encoder = MapEncoder(conf.map_encoder)
+ # self.bev_net = None if conf.bev_net is None else BEVNet(conf.bev_net)
+
+ ppm = conf.pixel_per_meter
+ # self.projection_polar = PolarProjectionDepth(
+ # conf.z_max,
+ # ppm,
+ # conf.scale_range,
+ # conf.z_min,
+ # )
+ # self.projection_bev = CartesianProjection(
+ # conf.z_max, conf.x_max, ppm, conf.z_min
+ # )
+ # self.template_sampler = TemplateSampler(
+ # self.projection_bev.grid_xz, ppm, conf.num_rotations
+ # )
+ # self.template_sampler = UAVTemplateSamplerFast(conf.num_rotations,w=conf.image_size//2)
+ self.template_sampler = UAVTemplateSampler(conf.num_rotations)
+ # self.scale_classifier = torch.nn.Linear(conf.latent_dim, conf.num_scale_bins)
+ # if conf.bev_net is None:
+ # self.feature_projection = torch.nn.Linear(
+ # conf.latent_dim, conf.matching_dim
+ # )
+ if conf.add_temperature:
+ temperature = torch.nn.Parameter(torch.tensor(0.0))
+ self.register_parameter("temperature", temperature)
+
+ def exhaustive_voting(self, f_bev, f_map):
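+        # Correlate rotated templates of the image features against the map
+        # features; the scores returned here have shape (B, num_rotations, H, W)
+        # and are later moved to (B, H, W, num_rotations) by the caller.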
+ if self.conf.normalize_features:
+ f_bev = normalize(f_bev, dim=1)
+ f_map = normalize(f_map, dim=1)
+
+ # Build the templates and exhaustively match against the map.
+ # if confidence_bev is not None:
+ # f_bev = f_bev * confidence_bev.unsqueeze(1)
+ # f_bev = f_bev.masked_fill(~valid_bev.unsqueeze(1), 0.0)
+ # torch.save(f_bev, 'f_bev.pt')
+ # torch.save(f_map, 'f_map.pt')
+
+        templates = self.template_sampler(f_bev)  # [batch, 256, 8, 129, 129]
+ # torch.save(templates, 'templates.pt')
+ with torch.autocast("cuda", enabled=False):
+ scores = conv2d_fft_batchwise(
+ f_map.float(),
+ templates.float(),
+ padding_mode=self.conf.padding_matching,
+ )
+ if self.conf.add_temperature:
+ scores = scores * torch.exp(self.temperature)
+
+ # Reweight the different rotations based on the number of valid pixels
+ # in each template. Axis-aligned rotation have the maximum number of valid pixels.
+ # valid_templates = self.template_sampler(valid_bev.float()[None]) > (1 - 1e-4)
+ # num_valid = valid_templates.float().sum((-3, -2, -1))
+ # scores = scores / num_valid[..., None, None]
+ return scores
+
+ def _forward(self, data):
+ pred = {}
+ pred_map = pred["map"] = self.map_encoder(data)
+ f_map = pred_map["map_features"][0]#[batch,8,256,256]
+
+ # Extract image features.
+ level = 0
+ f_image = self.image_encoder(data)["feature_maps"][level]#[batch,128,128,176]
+ # print("f_map:",f_map.shape)
+
+        scores = self.exhaustive_voting(f_image, f_map)  # f_bev: [batch, 8, 64, 129], f_map: [batch, 8, 256, 256], confidence: [1, 64, 129]
+ scores = scores.moveaxis(1, -1) # B,H,W,N
+ if "log_prior" in pred_map and self.conf.apply_map_prior:
+ scores = scores + pred_map["log_prior"][0].unsqueeze(-1)
+ # pred["scores_unmasked"] = scores.clone()
+ if "map_mask" in data:
+ scores.masked_fill_(~data["map_mask"][..., None], -np.inf)
+ if "yaw_prior" in data:
+ mask_yaw_prior(scores, data["yaw_prior"], self.conf.num_rotations)
+ log_probs = log_softmax_spatial(scores)
+ # torch.save(scores, 'scores.pt')
+ with torch.no_grad():
+ uvr_max = argmax_xyr(scores).to(scores)
+ uvr_avg, _ = expectation_xyr(log_probs.exp())
+
+ return {
+ **pred,
+ "scores": scores,
+ "log_probs": log_probs,
+ "uvr_max": uvr_max,
+ "uv_max": uvr_max[..., :2],
+ "yaw_max": uvr_max[..., 2],
+ "uvr_expectation": uvr_avg,
+ "uv_expectation": uvr_avg[..., :2],
+ "yaw_expectation": uvr_avg[..., 2],
+ "features_image": f_image,
+ }
+
+ def loss(self, pred, data):
+ xy_gt = data["uv"]
+ yaw_gt = data["roll_pitch_yaw"][..., -1]
+ if self.conf.do_label_smoothing:
+ nll = nll_loss_xyr_smoothed(
+ pred["log_probs"],
+ xy_gt,
+ yaw_gt,
+ self.conf.sigma_xy / self.conf.pixel_per_meter,
+ self.conf.sigma_r,
+ mask=data.get("map_mask"),
+ )
+ else:
+ nll = nll_loss_xyr(pred["log_probs"], xy_gt, yaw_gt)
+ loss = {"total": nll, "nll": nll}
+ if self.training and self.conf.add_temperature:
+ loss["temperature"] = self.temperature.expand(len(nll))
+ return loss
+
+ def metrics(self):
+ return {
+ "xy_max_error": Location2DError("uv_max", self.conf.pixel_per_meter),
+ "xy_expectation_error": Location2DError(
+ "uv_expectation", self.conf.pixel_per_meter
+ ),
+ "yaw_max_error": AngleError("yaw_max"),
+ "xy_recall_1m": Location2DRecall(1.0, self.conf.pixel_per_meter, "uv_max"),
+ "xy_recall_3m": Location2DRecall(3.0, self.conf.pixel_per_meter, "uv_max"),
+ "xy_recall_5m": Location2DRecall(5.0, self.conf.pixel_per_meter, "uv_max"),
+
+ # "x_recall_1m": Location2DRecall(1.0, self.conf.pixel_per_meter, "uv_max"),
+ # "x_recall_3m": Location2DRecall(3.0, self.conf.pixel_per_meter, "uv_max"),
+ # "x_recall_5m": Location2DRecall(5.0, self.conf.pixel_per_meter, "uv_max"),
+ #
+ # "y_recall_1m": Location2DRecall(1.0, self.conf.pixel_per_meter, "uv_max"),
+ # "y_recall_3m": Location2DRecall(3.0, self.conf.pixel_per_meter, "uv_max"),
+ # "y_recall_5m": Location2DRecall(5.0, self.conf.pixel_per_meter, "uv_max"),
+
+ "yaw_recall_1°": AngleRecall(1.0, "yaw_max"),
+ "yaw_recall_3°": AngleRecall(3.0, "yaw_max"),
+ "yaw_recall_5°": AngleRecall(5.0, "yaw_max"),
+ }
diff --git a/models/metrics.py b/models/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..b50a724f4719853b476d693db25ddbba562a3a51
--- /dev/null
+++ b/models/metrics.py
@@ -0,0 +1,118 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import torch
+import torchmetrics
+from torchmetrics.utilities.data import dim_zero_cat
+
+from .utils import deg2rad, rotmat2d
+
+
+def location_error(uv, uv_gt, ppm=1):
+ return torch.norm(uv - uv_gt.to(uv), dim=-1) / ppm
+
+def location_error_single(uv, uv_gt, ppm=1):
+ return torch.norm(uv - uv_gt.to(uv), dim=-1) / ppm
+
+def angle_error(t, t_gt):
+ error = torch.abs(t % 360 - t_gt.to(t) % 360)
+ error = torch.minimum(error, 360 - error)
+ return error
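+
+
+# Worked example (illustrative): angle_error(torch.tensor(359.0), torch.tensor(1.0))
+# evaluates to 2.0, since the difference is wrapped around 360 degrees.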
+
+
+class Location2DRecall(torchmetrics.MeanMetric):
+ def __init__(self, threshold, pixel_per_meter, key="uv_max", *args, **kwargs):
+ self.threshold = threshold
+ self.ppm = pixel_per_meter
+ self.key = key
+ super().__init__(*args, **kwargs)
+
+ def update(self, pred, data):
+ self.cuda()
+ error = location_error(pred[self.key], data["uv"], self.ppm)
+ # print(error,self.threshold)
+ super().update((error <= torch.tensor(self.threshold,device=error.device)).float())
+
+class Location1DRecall(torchmetrics.MeanMetric):
+ def __init__(self, threshold, pixel_per_meter, key="uv_max", *args, **kwargs):
+ self.threshold = threshold
+ self.ppm = pixel_per_meter
+ self.key = key
+ super().__init__(*args, **kwargs)
+
+ def update(self, pred, data):
+ self.cuda()
+ error = location_error(pred[self.key], data["uv"], self.ppm)
+ # print(error,self.threshold)
+ super().update((error <= torch.tensor(self.threshold,device=error.device)).float())
+class AngleRecall(torchmetrics.MeanMetric):
+ def __init__(self, threshold, key="yaw_max", *args, **kwargs):
+ self.threshold = threshold
+ self.key = key
+
+ super().__init__(*args, **kwargs)
+
+ def update(self, pred, data):
+ self.cuda()
+ error = angle_error(pred[self.key], data["roll_pitch_yaw"][..., -1])
+ super().update((error <= self.threshold).float())
+
+
+class MeanMetricWithRecall(torchmetrics.Metric):
+ full_state_update = True
+
+ def __init__(self):
+ super().__init__()
+ self.add_state("value", default=[], dist_reduce_fx="cat")
+
+    def compute(self):
+ return dim_zero_cat(self.value).mean(0)
+
+ def get_errors(self):
+ return dim_zero_cat(self.value)
+
+ def recall(self, thresholds):
+ self.cuda()
+ error = self.get_errors()
+ thresholds = error.new_tensor(thresholds)
+ return (error.unsqueeze(-1) < thresholds).float().mean(0) * 100
+
+
+class AngleError(MeanMetricWithRecall):
+ def __init__(self, key):
+ super().__init__()
+ self.key = key
+
+ def update(self, pred, data):
+ self.cuda()
+ value = angle_error(pred[self.key], data["roll_pitch_yaw"][..., -1])
+ if value.numel():
+ self.value.append(value)
+
+
+class Location2DError(MeanMetricWithRecall):
+ def __init__(self, key, pixel_per_meter):
+ super().__init__()
+ self.key = key
+ self.ppm = pixel_per_meter
+
+ def update(self, pred, data):
+ self.cuda()
+ value = location_error(pred[self.key], data["uv"], self.ppm)
+ if value.numel():
+ self.value.append(value)
+
+
+class LateralLongitudinalError(MeanMetricWithRecall):
+ def __init__(self, pixel_per_meter, key="uv_max"):
+ super().__init__()
+ self.ppm = pixel_per_meter
+ self.key = key
+
+ def update(self, pred, data):
+ self.cuda()
+ yaw = deg2rad(data["roll_pitch_yaw"][..., -1])
+ shift = (pred[self.key] - data["uv"]) * yaw.new_tensor([-1, 1])
+ shift = (rotmat2d(yaw) @ shift.unsqueeze(-1)).squeeze(-1)
+ error = torch.abs(shift) / self.ppm
+ value = error.view(-1, 2)
+ if value.numel():
+ self.value.append(value)
diff --git a/models/utils.py b/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec246b7d0adf04cc9307475867650523f67a5063
--- /dev/null
+++ b/models/utils.py
@@ -0,0 +1,87 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import math
+from typing import Optional
+
+import torch
+
+
+def checkpointed(cls, do=True):
+ """Adapted from the DISK implementation of Michał Tyszkiewicz."""
+ assert issubclass(cls, torch.nn.Module)
+
+ class Checkpointed(cls):
+ def forward(self, *args, **kwargs):
+ super_fwd = super(Checkpointed, self).forward
+ if any((torch.is_tensor(a) and a.requires_grad) for a in args):
+ return torch.utils.checkpoint.checkpoint(super_fwd, *args, **kwargs)
+ else:
+ return super_fwd(*args, **kwargs)
+
+ return Checkpointed if do else cls
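+
+
+# Usage sketch (mirrors how the feature extractors use it): wrap a module class so
+# that its forward pass is gradient-checkpointed whenever an input requires grad,
+# e.g. `Block = checkpointed(torch.nn.Sequential, do=True)` and then `Block(...)`.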
+
+
+class GlobalPooling(torch.nn.Module):
+ def __init__(self, kind):
+ super().__init__()
+ if kind == "mean":
+ self.fn = torch.nn.Sequential(
+ torch.nn.Flatten(2), torch.nn.AdaptiveAvgPool1d(1), torch.nn.Flatten()
+ )
+ elif kind == "max":
+ self.fn = torch.nn.Sequential(
+ torch.nn.Flatten(2), torch.nn.AdaptiveMaxPool1d(1), torch.nn.Flatten()
+ )
+ else:
+ raise ValueError(f"Unknown pooling type {kind}.")
+
+ def forward(self, x):
+ return self.fn(x)
+
+
+@torch.jit.script
+def make_grid(
+ w: float,
+ h: float,
+ step_x: float = 1.0,
+ step_y: float = 1.0,
+ orig_x: float = 0,
+ orig_y: float = 0,
+ y_up: bool = False,
+ device: Optional[torch.device] = None,
+) -> torch.Tensor:
+ x, y = torch.meshgrid(
+ [
+ torch.arange(orig_x, w + orig_x, step_x, device=device),
+ torch.arange(orig_y, h + orig_y, step_y, device=device),
+ ],
+ indexing="xy",
+ )
+ if y_up:
+ y = y.flip(-2)
+ grid = torch.stack((x, y), -1)
+ return grid
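+
+
+# Example: make_grid(4.0, 3.0) returns a (3, 4, 2) tensor of (x, y) coordinates
+# with x in [0, 3] and y in [0, 2]; y_up=True flips the grid vertically.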
+
+
+@torch.jit.script
+def rotmat2d(angle: torch.Tensor) -> torch.Tensor:
+ c = torch.cos(angle)
+ s = torch.sin(angle)
+ R = torch.stack([c, -s, s, c], -1).reshape(angle.shape + (2, 2))
+ return R
+
+
+@torch.jit.script
+def rotmat2d_grad(angle: torch.Tensor) -> torch.Tensor:
+ c = torch.cos(angle)
+ s = torch.sin(angle)
+ R = torch.stack([-s, -c, c, -s], -1).reshape(angle.shape + (2, 2))
+ return R
+
+
+def deg2rad(x):
+ return x * math.pi / 180
+
+
+def rad2deg(x):
+ return x * 180 / math.pi
diff --git a/models/voting.py b/models/voting.py
new file mode 100644
index 0000000000000000000000000000000000000000..b57bc1e86d6f738c060f7ef0fea3698f4fc13dd6
--- /dev/null
+++ b/models/voting.py
@@ -0,0 +1,365 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+from torch.fft import irfftn, rfftn
+from torch.nn.functional import grid_sample, log_softmax, pad
+
+from .metrics import angle_error
+from .utils import make_grid, rotmat2d
+from torchvision.transforms.functional import rotate
+
+class UAVTemplateSamplerFast(torch.nn.Module):
+ def __init__(self, num_rotations,w=128,optimize=True):
+ super().__init__()
+
+ h, w = w,w
+ grid_xy = make_grid(
+ w=w,
+ h=h,
+ step_x=1,
+ step_y=1,
+ orig_y=-h//2,
+ orig_x=-h//2,
+ y_up=True,
+ ).cuda()
+
+ if optimize:
+ assert (num_rotations % 4) == 0
+ angles = torch.arange(
+ 0, 90, 90 / (num_rotations // 4)
+ ).cuda()
+ else:
+ angles = torch.arange(
+                0, 360, 360 / num_rotations, device=grid_xy.device
+ )
+ rotmats = rotmat2d(angles / 180 * np.pi)
+ grid_xy_rot = torch.einsum("...nij,...hwj->...nhwi", rotmats, grid_xy)
+
+ grid_ij_rot = (grid_xy_rot - grid_xy[..., :1, :1, :]) * grid_xy.new_tensor(
+ [1, -1]
+ )
+ grid_ij_rot = grid_ij_rot
+ grid_norm = (grid_ij_rot + 0.5) / grid_ij_rot.new_tensor([w, h]) * 2 - 1
+
+ self.optimize = optimize
+ self.num_rots = num_rotations
+ self.register_buffer("angles", angles, persistent=False)
+ self.register_buffer("grid_norm", grid_norm, persistent=False)
+
+ def forward(self, image_bev):
+ grid = self.grid_norm
+ b, c = image_bev.shape[:2]
+ n, h, w = grid.shape[:3]
+ grid = grid[None].repeat_interleave(b, 0).reshape(b * n, h, w, 2)
+ image = (
+ image_bev[:, None]
+ .repeat_interleave(n, 1)
+ .reshape(b * n, *image_bev.shape[1:])
+ )
+ # print(image.shape,grid.shape,self.grid_norm.shape)
+ kernels = grid_sample(image, grid.to(image.dtype), align_corners=False).reshape(
+ b, n, c, h, w
+ )
+
+ if self.optimize: # we have computed only the first quadrant
+ kernels_quad234 = [torch.rot90(kernels, -i, (-2, -1)) for i in (1, 2, 3)]
+ kernels = torch.cat([kernels] + kernels_quad234, 1)
+
+ return kernels
+
+
+class UAVTemplateSampler(torch.nn.Module):
+ def __init__(self, num_rotations):
+ super().__init__()
+
+ self.num_rotations = num_rotations
+
+    def Template(self, input_features):
+        # Number of rotation angles
+        num_angles = self.num_rotations
+        # Add a second dimension holding one slice per rotation angle
+        input_shape = torch.tensor(input_features.shape)
+        output_shape = torch.cat((input_shape[:1], torch.tensor([num_angles]), input_shape[1:])).tolist()
+        expanded_features = torch.zeros(output_shape, device=input_features.device)
+
+        # Generate the sequence of rotation angles
+        rotation_angles = torch.linspace(360, 0, num_angles + 1)[:-1]
+        # rotation_angles = torch.flip(rotation_angles, dims=[0])
+        # Apply a different rotation angle to each slice of the expanded features
+        for i in range(len(rotation_angles)):
+            rotated_feature = rotate(input_features, rotation_angles[i].item(), fill=0)
+            expanded_features[:, i, :, :, :] = rotated_feature
+
+        # Final output shape: [B, num_angles, C, H, W]
+        return expanded_features
+
+    def forward(self, image_bev):
+        kernels = self.Template(image_bev)
+        return kernels
+
+
+class TemplateSampler(torch.nn.Module):
+ def __init__(self, grid_xz_bev, ppm, num_rotations, optimize=True):
+ super().__init__()
+
+ Δ = 1 / ppm
+ h, w = grid_xz_bev.shape[:2]
+ ksize = max(w, h * 2 + 1)
+ radius = ksize * Δ
+ grid_xy = make_grid(
+ radius,
+ radius,
+ step_x=Δ,
+ step_y=Δ,
+ orig_y=(Δ - radius) / 2,
+ orig_x=(Δ - radius) / 2,
+ y_up=True,
+ )
+
+ if optimize:
+ assert (num_rotations % 4) == 0
+ angles = torch.arange(
+ 0, 90, 90 / (num_rotations // 4), device=grid_xz_bev.device
+ )
+ else:
+ angles = torch.arange(
+ 0, 360, 360 / num_rotations, device=grid_xz_bev.device
+ )
+ rotmats = rotmat2d(angles / 180 * np.pi)
+ grid_xy_rot = torch.einsum("...nij,...hwj->...nhwi", rotmats, grid_xy)
+
+ grid_ij_rot = (grid_xy_rot - grid_xz_bev[..., :1, :1, :]) * grid_xy.new_tensor(
+ [1, -1]
+ )
+ grid_ij_rot = grid_ij_rot / Δ
+ grid_norm = (grid_ij_rot + 0.5) / grid_ij_rot.new_tensor([w, h]) * 2 - 1
+
+ self.optimize = optimize
+ self.num_rots = num_rotations
+ self.register_buffer("angles", angles, persistent=False)
+ self.register_buffer("grid_norm", grid_norm, persistent=False)
+
+ def forward(self, image_bev):
+ grid = self.grid_norm
+ b, c = image_bev.shape[:2]
+ n, h, w = grid.shape[:3]
+ grid = grid[None].repeat_interleave(b, 0).reshape(b * n, h, w, 2)
+ image = (
+ image_bev[:, None]
+ .repeat_interleave(n, 1)
+ .reshape(b * n, *image_bev.shape[1:])
+ )
+ kernels = grid_sample(image, grid.to(image.dtype), align_corners=False).reshape(
+ b, n, c, h, w
+ )
+
+ if self.optimize: # we have computed only the first quadrant
+ kernels_quad234 = [torch.rot90(kernels, -i, (-2, -1)) for i in (1, 2, 3)]
+ kernels = torch.cat([kernels] + kernels_quad234, 1)
+
+ return kernels
+
+
+def conv2d_fft_batchwise(signal, kernel, padding="same", padding_mode="constant"):
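+    # Batched cross-correlation computed in the Fourier domain. Expected shapes
+    # (as used by the matching code; noted here as an assumption for readability):
+    #   signal: (B, C, H, W) map features, kernel: (B, N, C, h, w) templates,
+    #   output: (B, N, H', W') correlation scores.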
+ if padding == "same":
+ padding = [i // 2 for i in kernel.shape[-2:]]
+ padding_signal = [p for p in padding[::-1] for _ in range(2)]
+ signal = pad(signal, padding_signal, mode=padding_mode)
+ assert signal.size(-1) % 2 == 0
+
+ padding_kernel = [
+ pad for i in [1, 2] for pad in [0, signal.size(-i) - kernel.size(-i)]
+ ]
+ kernel_padded = pad(kernel, padding_kernel)
+
+ signal_fr = rfftn(signal, dim=(-1, -2))
+ kernel_fr = rfftn(kernel_padded, dim=(-1, -2))
+
+ kernel_fr.imag *= -1 # flip the kernel
+ output_fr = torch.einsum("bc...,bdc...->bd...", signal_fr, kernel_fr)
+ output = irfftn(output_fr, dim=(-1, -2))
+
+ crop_slices = [slice(0, output.size(0)), slice(0, output.size(1))] + [
+ slice(0, (signal.size(i) - kernel.size(i) + 1)) for i in [-2, -1]
+ ]
+ output = output[crop_slices].contiguous()
+
+ return output
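+
+# Note (sketch): this is a batched cross-correlation computed in the Fourier
+# domain. Conjugating the kernel spectrum (imag *= -1) turns the pointwise
+# product into correlation rather than convolution, and the einsum contracts
+# the channel dimension: for `signal` of shape (B, C, H, W) and `kernel` of
+# shape (B, N, C, h, w), the result is one score map per template, in roughly
+# O(HW log HW) instead of O(HW * hw) per template.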
+
+
+class SparseMapSampler(torch.nn.Module):
+ def __init__(self, num_rotations):
+ super().__init__()
+        angles = torch.arange(0, 360, 360 / num_rotations)
+ rotmats = rotmat2d(angles / 180 * np.pi)
+ self.num_rotations = num_rotations
+ self.register_buffer("rotmats", rotmats, persistent=False)
+
+ def forward(self, image_map, p2d_bev):
+ h, w = image_map.shape[-2:]
+ locations = make_grid(w, h, device=p2d_bev.device)
+        p2d_candidates = torch.einsum(
+            "kji,...i->...kj", self.rotmats.to(p2d_bev), p2d_bev
+        )
+ p2d_candidates = p2d_candidates[..., None, None, :, :] + locations.unsqueeze(-1)
+ # ... x N x W x H x K x 2
+
+ p2d_norm = (p2d_candidates / (image_map.new_tensor([w, h]) - 1)) * 2 - 1
+ valid = torch.all((p2d_norm >= -1) & (p2d_norm <= 1), -1)
+ value = grid_sample(
+ image_map, p2d_norm.flatten(-4, -2), align_corners=True, mode="bilinear"
+ )
+        value = value.reshape(image_map.shape[:2] + valid.shape[-4:])
+ return valid, value
+
+
+def sample_xyr(volume, xy_grid, angle_grid, nearest_for_inf=False):
+ # (B, C, H, W, N) to (B, C, H, W, N+1)
+ volume_padded = pad(volume, [0, 1, 0, 0, 0, 0], mode="circular")
+
+ size = xy_grid.new_tensor(volume.shape[-3:-1][::-1])
+ xy_norm = xy_grid / (size - 1) # align_corners=True
+ angle_norm = (angle_grid / 360) % 1
+ grid = torch.concat([angle_norm.unsqueeze(-1), xy_norm], -1)
+ grid_norm = grid * 2 - 1
+
+ valid = torch.all((grid_norm >= -1) & (grid_norm <= 1), -1)
+ value = grid_sample(volume_padded, grid_norm, align_corners=True, mode="bilinear")
+
+ # if one of the values used for linear interpolation is infinite,
+ # we fallback to nearest to avoid propagating inf
+ if nearest_for_inf:
+ value_nearest = grid_sample(
+ volume_padded, grid_norm, align_corners=True, mode="nearest"
+ )
+ value = torch.where(~torch.isfinite(value) & valid, value_nearest, value)
+
+ return value, valid
+
+
+def nll_loss_xyr(log_probs, xy, angle):
+ log_prob, _ = sample_xyr(
+ log_probs.unsqueeze(1), xy[:, None, None, None], angle[:, None, None, None]
+ )
+ nll = -log_prob.reshape(-1) # remove C,H,W,N
+ return nll
+
+
+def nll_loss_xyr_smoothed(log_probs, xy, angle, sigma_xy, sigma_r, mask=None):
+ *_, nx, ny, nr = log_probs.shape
+ grid_x = torch.arange(nx, device=log_probs.device, dtype=torch.float)
+ dx = (grid_x - xy[..., None, 0]) / sigma_xy
+ grid_y = torch.arange(ny, device=log_probs.device, dtype=torch.float)
+ dy = (grid_y - xy[..., None, 1]) / sigma_xy
+ dr = (
+ torch.arange(0, 360, 360 / nr, device=log_probs.device, dtype=torch.float)
+ - angle[..., None]
+ ) % 360
+ dr = torch.minimum(dr, 360 - dr) / sigma_r
+ diff = (
+ dx[..., None, :, None] ** 2
+ + dy[..., :, None, None] ** 2
+ + dr[..., None, None, :] ** 2
+ )
+ pdf = torch.exp(-diff / 2)
+ if mask is not None:
+ pdf.masked_fill_(~mask[..., None], 0)
+ log_probs = log_probs.masked_fill(~mask[..., None], 0)
+ pdf /= pdf.sum((-1, -2, -3), keepdim=True)
+ return -torch.sum(pdf * log_probs.to(torch.float), dim=(-1, -2, -3))
+
+
+def log_softmax_spatial(x, dims=3):
+ return log_softmax(x.flatten(-dims), dim=-1).reshape(x.shape)
+
+
+@torch.jit.script
+def argmax_xy(scores: torch.Tensor) -> torch.Tensor:
+ indices = scores.flatten(-2).max(-1).indices
+ width = scores.shape[-1]
+ x = indices % width
+ y = torch.div(indices, width, rounding_mode="floor")
+ return torch.stack((x, y), -1)
+
+
+@torch.jit.script
+def expectation_xy(prob: torch.Tensor) -> torch.Tensor:
+ h, w = prob.shape[-2:]
+ grid = make_grid(float(w), float(h), device=prob.device).to(prob)
+ return torch.einsum("...hw,hwd->...d", prob, grid)
+
+
+@torch.jit.script
+def expectation_xyr(
+ prob: torch.Tensor, covariance: bool = False
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ h, w, num_rotations = prob.shape[-3:]
+ x, y = torch.meshgrid(
+ [
+ torch.arange(w, device=prob.device, dtype=prob.dtype),
+ torch.arange(h, device=prob.device, dtype=prob.dtype),
+ ],
+ indexing="xy",
+ )
+ grid_xy = torch.stack((x, y), -1)
+ xy_mean = torch.einsum("...hwn,hwd->...d", prob, grid_xy)
+
+ angles = torch.arange(0, 1, 1 / num_rotations, device=prob.device, dtype=prob.dtype)
+ angles = angles * 2 * np.pi
+ grid_cs = torch.stack([torch.cos(angles), torch.sin(angles)], -1)
+ cs_mean = torch.einsum("...hwn,nd->...d", prob, grid_cs)
+ angle = torch.atan2(cs_mean[..., 1], cs_mean[..., 0])
+ angle = (angle * 180 / np.pi) % 360
+
+ if covariance:
+ xy_cov = torch.einsum("...hwn,...hwd,...hwk->...dk", prob, grid_xy, grid_xy)
+ xy_cov = xy_cov - torch.einsum("...d,...k->...dk", xy_mean, xy_mean)
+ else:
+ xy_cov = None
+
+ xyr_mean = torch.cat((xy_mean, angle.unsqueeze(-1)), -1)
+ return xyr_mean, xy_cov
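+
+# Note: the angle above is a circular mean: the probability mass is averaged
+# over unit vectors (cos, sin) and mapped back with atan2, which avoids the
+# wrap-around bias of directly averaging angles in degrees.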
+
+
+@torch.jit.script
+def argmax_xyr(scores: torch.Tensor) -> torch.Tensor:
+ indices = scores.flatten(-3).max(-1).indices
+ width, num_rotations = scores.shape[-2:]
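+    # the flat argmax index over (H, W, N) decomposes as index = (y * W + x) * N + angle_index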
+ wr = width * num_rotations
+ y = torch.div(indices, wr, rounding_mode="floor")
+ x = torch.div(indices % wr, num_rotations, rounding_mode="floor")
+ angle_index = indices % num_rotations
+ angle = angle_index * 360 / num_rotations
+ xyr = torch.stack((x, y, angle), -1)
+ return xyr
+
+
+@torch.jit.script
+def mask_yaw_prior(
+ scores: torch.Tensor, yaw_prior: torch.Tensor, num_rotations: int
+) -> torch.Tensor:
+ step = 360 / num_rotations
+ step_2 = step / 2
+ angles = torch.arange(step_2, 360 + step_2, step, device=scores.device)
+ yaw_init, yaw_range = yaw_prior.chunk(2, dim=-1)
+ rot_mask = angle_error(angles, yaw_init) < yaw_range
+ return scores.masked_fill_(~rot_mask[:, None, None], -np.inf)
+
+
+def fuse_gps(log_prob, uv_gps, ppm, sigma=10, gaussian=False):
+ grid = make_grid(*log_prob.shape[-3:-1][::-1]).to(log_prob)
+ dist = torch.sum((grid - uv_gps) ** 2, -1)
+ sigma_pixel = sigma * ppm
+ if gaussian:
+ gps_log_prob = -1 / 2 * dist / sigma_pixel**2
+ else:
+ gps_log_prob = torch.where(dist < sigma_pixel**2, 1, -np.inf)
+ log_prob_fused = log_softmax_spatial(log_prob + gps_log_prob.unsqueeze(-1))
+ return log_prob_fused
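+
+# Example (sketch; `uv_gps` is the GPS prior expressed in map pixels, values
+# hypothetical):
+#   log_prob_fused = fuse_gps(
+#       log_prob, uv_gps=log_prob.new_tensor([120.0, 80.0]), ppm=2, sigma=10
+#   )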
diff --git a/module.py b/module.py
new file mode 100644
index 0000000000000000000000000000000000000000..47bc341f83aba111638a5de1fb8c6d88ed900df7
--- /dev/null
+++ b/module.py
@@ -0,0 +1,171 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+from pathlib import Path
+
+import pytorch_lightning as pl
+import torch
+from omegaconf import DictConfig, OmegaConf, open_dict
+from torchmetrics import MeanMetric, MetricCollection
+
+import logger
+from models import get_model
+
+
+class AverageKeyMeter(MeanMetric):
+ def __init__(self, key, *args, **kwargs):
+ self.key = key
+ super().__init__(*args, **kwargs)
+
+ def update(self, dict):
+ value = dict[self.key]
+ value = value[torch.isfinite(value)]
+ return super().update(value)
+
+
+class GenericModule(pl.LightningModule):
+ def __init__(self, cfg):
+ super().__init__()
+ name = cfg.model.get("name")
+ name = "orienternet" if name in ("localizer_bev_depth", None) else name
+ self.model = get_model(name)(cfg.model)
+ self.cfg = cfg
+ self.save_hyperparameters(cfg)
+
+ self.metrics_val = MetricCollection(self.model.metrics(), prefix="val/")
+ self.losses_val = None # we do not know the loss keys in advance
+
+ # self.citys = self.cfg.data.val_citys
+ # for i in range(len(self.citys)):
+ # city=self.citys[i]
+ # setattr(self, "metric_vals_{}".format(i), MetricCollection(self.model.metrics(), prefix="val_{}/".format(city)))
+ # self.losse_vals = [None for city in self.cfg.data.val_citys]
+
+
+ def forward(self, batch):
+ return self.model(batch)
+
+ def training_step(self, batch):
+ pred = self(batch)
+ losses = self.model.loss(pred, batch)
+ self.log_dict(
+ {f"loss/{k}/train": v.mean() for k, v in losses.items()},
+ prog_bar=True,
+ rank_zero_only=True,
+ )
+ return losses["total"].mean()
+
+ # def validation_step(self, batch, batch_idx,dataloader_idx):
+ # city=self.citys[dataloader_idx]
+ #
+ # pred = self(batch)
+ # losses = self.model.loss(pred, batch)
+ #
+ # if hasattr(self,"losse_val_{}".format(dataloader_idx)) is False:
+ # setattr(self,"losse_val_{}".format(dataloader_idx),MetricCollection(
+ # {k: AverageKeyMeter(k).to(self.device) for k in losses},
+ # prefix="loss_{}/".format(city),
+ # postfix="/val_{}".format(city),
+ # ))
+ #
+ # # print(pred, batch)
+ # getattr(self,"metric_vals_{}".format(dataloader_idx))(pred, batch)
+ # self.log_dict(getattr(self,"metric_vals_{}".format(dataloader_idx))(pred, batch), sync_dist=True)
+ #
+ # getattr(self,"losse_val_{}".format(dataloader_idx)).update(losses)
+ # # print(getattr(self,"losse_val_{}".format(dataloader_idx)))
+ # self.log_dict(getattr(self,"losse_val_{}".format(dataloader_idx)).compute(), sync_dist=True)
+ def validation_step(self, batch, batch_idx):
+ pred = self(batch)
+ losses = self.model.loss(pred, batch)
+ if self.losses_val is None:
+ self.losses_val = MetricCollection(
+ {k: AverageKeyMeter(k).to(self.device) for k in losses},
+ prefix="loss/",
+ postfix="/val",
+ )
+ self.metrics_val(pred, batch)
+ self.log_dict(self.metrics_val, sync_dist=True)
+ self.losses_val.update(losses)
+ self.log_dict(self.losses_val, sync_dist=True)
+
+ def validation_epoch_start(self, batch):
+ self.losses_val = None
+ # self.losse_val = [None for city in self.cfg.data.val_citys]
+
+ def configure_optimizers(self):
+ optimizer = torch.optim.Adam(self.parameters(), lr=self.cfg.training.lr)
+ ret = {"optimizer": optimizer}
+ cfg_scheduler = self.cfg.training.get("lr_scheduler")
+ if cfg_scheduler is not None:
+ scheduler = getattr(torch.optim.lr_scheduler, cfg_scheduler.name)(
+ optimizer=optimizer, **cfg_scheduler.get("args", {})
+ )
+ ret["lr_scheduler"] = {
+ "scheduler": scheduler,
+ "interval": "epoch",
+ "frequency": 1,
+ "monitor": "loss/total/val",
+ "strict": True,
+ "name": "learning_rate",
+ }
+ return ret
+
+ @classmethod
+ def load_from_checkpoint(
+ cls,
+ checkpoint_path,
+ map_location=None,
+ hparams_file=None,
+ strict=True,
+ cfg=None,
+ find_best=False,
+ ):
+ assert hparams_file is None, "hparams are not supported."
+
+ checkpoint = torch.load(
+ checkpoint_path, map_location=map_location or (lambda storage, loc: storage)
+ )
+ if find_best:
+ best_score, best_name = None, None
+ modes = {"min": torch.lt, "max": torch.gt}
+ for key, state in checkpoint["callbacks"].items():
+ if not key.startswith("ModelCheckpoint"):
+ continue
+ mode = eval(key.replace("ModelCheckpoint", ""))["mode"]
+ if best_score is None or modes[mode](
+ state["best_model_score"], best_score
+ ):
+ best_score = state["best_model_score"]
+ best_name = Path(state["best_model_path"]).name
+ logger.info("Loading best checkpoint %s", best_name)
+ if best_name != checkpoint_path:
+ return cls.load_from_checkpoint(
+ Path(checkpoint_path).parent / best_name,
+ map_location,
+ hparams_file,
+ strict,
+ cfg,
+ find_best=False,
+ )
+
+ logger.info(
+ "Using checkpoint %s from epoch %d and step %d.",
+ checkpoint_path.name,
+ checkpoint["epoch"],
+ checkpoint["global_step"],
+ )
+ cfg_ckpt = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY]
+ if list(cfg_ckpt.keys()) == ["cfg"]: # backward compatibility
+ cfg_ckpt = cfg_ckpt["cfg"]
+ cfg_ckpt = OmegaConf.create(cfg_ckpt)
+
+ if cfg is None:
+ cfg = {}
+ if not isinstance(cfg, DictConfig):
+ cfg = OmegaConf.create(cfg)
+ with open_dict(cfg_ckpt):
+ cfg = OmegaConf.merge(cfg_ckpt, cfg)
+
+ return pl.core.saving._load_state(cls, checkpoint, strict=strict, cfg=cfg)
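+
+
+# Example usage (sketch; the checkpoint path is hypothetical):
+#   model = GenericModule.load_from_checkpoint(
+#       Path("experiments/my_run/last.ckpt"), find_best=True
+#   )
+#   pred = model(batch)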
diff --git a/osm/analysis.py b/osm/analysis.py
new file mode 100644
index 0000000000000000000000000000000000000000..a667c21373a31482f7bbcfb41d4fa14681741260
--- /dev/null
+++ b/osm/analysis.py
@@ -0,0 +1,182 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+from collections import Counter, defaultdict
+from typing import Dict
+
+import matplotlib.pyplot as plt
+import numpy as np
+import plotly.graph_objects as go
+
+from .parser import (
+ filter_area,
+ filter_node,
+ filter_way,
+ match_to_group,
+ parse_area,
+ parse_node,
+ parse_way,
+ Patterns,
+)
+from .reader import OSMData
+
+
+def recover_hierarchy(counter: Counter) -> Dict:
+ """Recover a two-level hierarchy from the flat group labels."""
+ groups = defaultdict(dict)
+ for k, v in sorted(counter.items(), key=lambda x: -x[1]):
+ if ":" in k:
+ prefix, group = k.split(":")
+            if prefix in groups and isinstance(groups[prefix], int):
+                # the prefix was counted on its own: keep that count as a sub-entry
+                groups[prefix] = {prefix: groups[prefix]}
+            groups[prefix][group] = v
+ else:
+ groups[k] = v
+ return dict(groups)
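+
+# Example (sketch): Counter({"amenity": 10, "amenity:parking": 5, "building": 3})
+# becomes {"amenity": {"amenity": 10, "parking": 5}, "building": 3}.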
+
+
+def bar_autolabel(rects, fontsize):
+ """Attach a text label above each bar in *rects*, displaying its height."""
+ for rect in rects:
+ width = rect.get_width()
+ plt.gca().annotate(
+ f"{width}",
+ xy=(width, rect.get_y() + rect.get_height() / 2),
+            xytext=(3, 0),  # 3 points horizontal offset
+ textcoords="offset points",
+ ha="left",
+ va="center",
+ fontsize=fontsize,
+ )
+
+
+def plot_histogram(counts, fontsize, dpi):
+ fig, ax = plt.subplots(dpi=dpi, figsize=(8, 20))
+
+ labels = []
+ for k, v in counts.items():
+ if isinstance(v, dict):
+ labels += list(v.keys())
+ v = list(v.values())
+ else:
+ labels.append(k)
+ v = [v]
+        bars = plt.barh(
+            len(labels) - len(v) + np.arange(len(v)), v, height=0.9, label=k
+        )
+ bar_autolabel(bars, fontsize)
+
+ ax.set_yticklabels(labels, fontsize=fontsize)
+ ax.axes.xaxis.set_ticklabels([])
+ ax.xaxis.tick_top()
+ ax.invert_yaxis()
+ plt.yticks(np.arange(len(labels)))
+ plt.xscale("log")
+ plt.legend(ncol=len(counts), loc="upper center")
+
+
+def count_elements(elems: Dict[int, str], filter_fn, parse_fn) -> Dict:
+ """Count the number of elements in each group."""
+ counts = Counter()
+ for elem in filter(filter_fn, elems.values()):
+ group = parse_fn(elem.tags)
+ if group is None:
+ continue
+ counts[group] += 1
+ counts = recover_hierarchy(counts)
+ return counts
+
+
+def plot_osm_histograms(osm: OSMData, fontsize=8, dpi=150):
+ counts = count_elements(osm.nodes, filter_node, parse_node)
+ plot_histogram(counts, fontsize, dpi)
+ plt.title("nodes")
+
+ counts = count_elements(osm.ways, filter_way, parse_way)
+ plot_histogram(counts, fontsize, dpi)
+ plt.title("ways")
+
+ counts = count_elements(osm.ways, filter_area, parse_area)
+ plot_histogram(counts, fontsize, dpi)
+ plt.title("areas")
+
+
+def plot_sankey_hierarchy(osm: OSMData):
+ triplets = []
+ for node in filter(filter_node, osm.nodes.values()):
+ label = parse_node(node.tags)
+ if label is None:
+ continue
+ group = match_to_group(label, Patterns.nodes)
+ if group is None:
+ group = match_to_group(label, Patterns.ways)
+ if group is None:
+ group = "null"
+ if ":" in label:
+ key, tag = label.split(":")
+ if tag == "yes":
+ tag = key
+ else:
+ key = tag = label
+ triplets.append((key, tag, group))
+ keys, tags, groups = list(zip(*triplets))
+ counts_key_tag = Counter(zip(keys, tags))
+ counts_key_tag_group = Counter(triplets)
+
+ key2tags = defaultdict(set)
+ for k, t in zip(keys, tags):
+ key2tags[k].add(t)
+ key2tags = {k: sorted(t) for k, t in key2tags.items()}
+ keytag2group = dict(zip(zip(keys, tags), groups))
+ key_names = sorted(set(keys))
+ tag_names = [(k, t) for k in key_names for t in key2tags[k]]
+
+ group_names = []
+ for k in key_names:
+ for t in key2tags[k]:
+ g = keytag2group[k, t]
+ if g not in group_names and g != "null":
+ group_names.append(g)
+ group_names += ["null"]
+
+ key2idx = dict(zip(key_names, range(len(key_names))))
+ tag2idx = {kt: i + len(key2idx) for i, kt in enumerate(tag_names)}
+ group2idx = {n: i + len(key2idx) + len(tag2idx) for i, n in enumerate(group_names)}
+
+ key_counts = Counter(keys)
+ key_text = [f"{k} {key_counts[k]}" for k in key_names]
+ tag_counts = Counter(list(zip(keys, tags)))
+ tag_text = [f"{t} {tag_counts[k, t]}" for k, t in tag_names]
+ group_counts = Counter(groups)
+ group_text = [f"{k} {group_counts[k]}" for k in group_names]
+
+ fig = go.Figure(
+ data=[
+ go.Sankey(
+ orientation="h",
+ node=dict(
+ pad=15,
+ thickness=20,
+ line=dict(color="black", width=0.5),
+ label=key_text + tag_text + group_text,
+ x=[0] * len(key_names)
+ + [1] * len(tag_names)
+ + [2] * len(group_names),
+ color="blue",
+ ),
+ arrangement="fixed",
+ link=dict(
+ source=[key2idx[k] for k, _ in counts_key_tag]
+ + [tag2idx[k, t] for k, t, _ in counts_key_tag_group],
+ target=[tag2idx[k, t] for k, t in counts_key_tag]
+ + [group2idx[g] for _, _, g in counts_key_tag_group],
+ value=list(counts_key_tag.values())
+ + list(counts_key_tag_group.values()),
+ ),
+ )
+ ]
+ )
+ fig.update_layout(autosize=False, width=800, height=2000, font_size=10)
+ fig.show()
+ return fig
diff --git a/osm/data.py b/osm/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..dafc568f8ad5ac8c72ea9ffbd096838e0693ad84
--- /dev/null
+++ b/osm/data.py
@@ -0,0 +1,230 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import logging
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Set, Tuple
+
+import numpy as np
+
+from .parser import (
+ filter_area,
+ filter_node,
+ filter_way,
+ match_to_group,
+ parse_area,
+ parse_node,
+ parse_way,
+ Patterns,
+)
+from .reader import OSMData, OSMNode, OSMRelation, OSMWay
+
+
+logger = logging.getLogger(__name__)
+
+
+def glue(ways: List[OSMWay]) -> List[List[OSMNode]]:
+ result: List[List[OSMNode]] = []
+ to_process: Set[Tuple[OSMNode]] = set()
+
+ for way in ways:
+ if way.is_cycle():
+ result.append(way.nodes)
+ else:
+ to_process.add(tuple(way.nodes))
+
+ while to_process:
+ nodes: List[OSMNode] = list(to_process.pop())
+ glued: Optional[List[OSMNode]] = None
+ other_nodes: Optional[Tuple[OSMNode]] = None
+
+ for other_nodes in to_process:
+ glued = try_to_glue(nodes, list(other_nodes))
+ if glued is not None:
+ break
+
+ if glued is not None:
+ to_process.remove(other_nodes)
+ if is_cycle(glued):
+ result.append(glued)
+ else:
+ to_process.add(tuple(glued))
+ else:
+ result.append(nodes)
+
+ return result
+
+
+def is_cycle(nodes: List[OSMNode]) -> bool:
+ """Is way a cycle way or an area boundary."""
+ return nodes[0] == nodes[-1]
+
+
+def try_to_glue(nodes: List[OSMNode], other: List[OSMNode]) -> Optional[List[OSMNode]]:
+ """Create new combined way if ways share endpoints."""
+ if nodes[0] == other[0]:
+ return list(reversed(other[1:])) + nodes
+ if nodes[0] == other[-1]:
+ return other[:-1] + nodes
+ if nodes[-1] == other[-1]:
+ return nodes + list(reversed(other[:-1]))
+ if nodes[-1] == other[0]:
+ return nodes + other[1:]
+ return None
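+
+# Example (sketch): two open ways sharing endpoints, A-B-C and C-D-A, are merged
+# by `glue` into the closed ring A-B-C-D-A; ways that cannot be closed are kept
+# as open node sequences.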
+
+
+def multipolygon_from_relation(rel: OSMRelation, osm: OSMData):
+ inner_ways = []
+ outer_ways = []
+ for member in rel.members:
+ if member.type_ == "way":
+ if member.role == "inner":
+ if member.ref in osm.ways:
+ inner_ways.append(osm.ways[member.ref])
+ elif member.role == "outer":
+ if member.ref in osm.ways:
+ outer_ways.append(osm.ways[member.ref])
+ else:
+ logger.warning(f'Unknown member role "{member.role}".')
+ if outer_ways:
+ inners_path = glue(inner_ways)
+ outers_path = glue(outer_ways)
+ return inners_path, outers_path
+
+
+@dataclass
+class MapElement:
+ id_: int
+ label: str
+ group: str
+ tags: Optional[Dict[str, str]]
+
+
+@dataclass
+class MapNode(MapElement):
+ xy: np.ndarray
+
+ @classmethod
+ def from_osm(cls, node: OSMNode, label: str, group: str):
+ return cls(
+ node.id_,
+ label,
+ group,
+ node.tags,
+ xy=node.xy,
+ )
+
+
+@dataclass
+class MapLine(MapElement):
+ xy: np.ndarray
+
+ @classmethod
+ def from_osm(cls, way: OSMWay, label: str, group: str):
+ xy = np.stack([n.xy for n in way.nodes])
+ return cls(
+ way.id_,
+ label,
+ group,
+ way.tags,
+ xy=xy,
+ )
+
+
+@dataclass
+class MapArea(MapElement):
+ outers: List[np.ndarray]
+ inners: List[np.ndarray] = field(default_factory=list)
+
+ @classmethod
+ def from_relation(cls, rel: OSMRelation, label: str, group: str, osm: OSMData):
+ outers_inners = multipolygon_from_relation(rel, osm)
+ if outers_inners is None:
+ return None
+ outers, inners = outers_inners
+ outers = [np.stack([n.xy for n in way]) for way in outers]
+ inners = [np.stack([n.xy for n in way]) for way in inners]
+ return cls(
+ rel.id_,
+ label,
+ group,
+ rel.tags,
+ outers=outers,
+ inners=inners,
+ )
+
+ @classmethod
+ def from_way(cls, way: OSMWay, label: str, group: str):
+ xy = np.stack([n.xy for n in way.nodes])
+ return cls(
+ way.id_,
+ label,
+ group,
+ way.tags,
+ outers=[xy],
+ )
+
+
+class MapData:
+ def __init__(self):
+ self.nodes: Dict[int, MapNode] = {}
+ self.lines: Dict[int, MapLine] = {}
+ self.areas: Dict[int, MapArea] = {}
+
+ @classmethod
+ def from_osm(cls, osm: OSMData):
+ self = cls()
+
+ for node in filter(filter_node, osm.nodes.values()):
+ label = parse_node(node.tags)
+ if label is None:
+ continue
+ group = match_to_group(label, Patterns.nodes)
+ if group is None:
+ group = match_to_group(label, Patterns.ways)
+ if group is None:
+ continue # missing
+ self.nodes[node.id_] = MapNode.from_osm(node, label, group)
+
+ for way in filter(filter_way, osm.ways.values()):
+ label = parse_way(way.tags)
+ if label is None:
+ continue
+ group = match_to_group(label, Patterns.ways)
+ if group is None:
+ group = match_to_group(label, Patterns.nodes)
+ if group is None:
+ continue # missing
+ self.lines[way.id_] = MapLine.from_osm(way, label, group)
+
+ for area in filter(filter_area, osm.ways.values()):
+ label = parse_area(area.tags)
+ if label is None:
+ continue
+ group = match_to_group(label, Patterns.areas)
+ if group is None:
+ group = match_to_group(label, Patterns.ways)
+ if group is None:
+ group = match_to_group(label, Patterns.nodes)
+ if group is None:
+ continue # missing
+ self.areas[area.id_] = MapArea.from_way(area, label, group)
+
+ for rel in osm.relations.values():
+ if rel.tags.get("type") != "multipolygon":
+ continue
+ label = parse_area(rel.tags)
+ if label is None:
+ continue
+ group = match_to_group(label, Patterns.areas)
+ if group is None:
+ group = match_to_group(label, Patterns.ways)
+ if group is None:
+ group = match_to_group(label, Patterns.nodes)
+ if group is None:
+ continue # missing
+ area = MapArea.from_relation(rel, label, group, osm)
+ assert rel.id_ not in self.areas # not sure if there can be collision
+ if area is not None:
+ self.areas[rel.id_] = area
+
+ return self
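+
+
+# Example usage (sketch; the path and projection are hypothetical):
+#   osm = OSMData.from_file(Path("map.osm"))
+#   osm.add_xy_to_nodes(projection)
+#   map_data = MapData.from_osm(osm)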
diff --git a/osm/download.py b/osm/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a188e513aaaff8e73d6ca60a4caf51ff60bc0fa
--- /dev/null
+++ b/osm/download.py
@@ -0,0 +1,118 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import json
+from pathlib import Path
+from typing import Dict, Optional
+
+import urllib3
+
+
+from utils.geo import BoundaryBox
+import urllib.request
+import requests
+
+def get_osm(
+ boundary_box: BoundaryBox,
+ cache_path: Optional[Path] = None,
+ overwrite: bool = False,
+) -> dict:
+ if not overwrite and cache_path is not None and cache_path.is_file():
+ with cache_path.open() as fp:
+ return json.load(fp)
+
+ (bottom, left), (top, right) = boundary_box.min_, boundary_box.max_
+ content: bytes = get_web_data(
+ # "https://api.openstreetmap.org/api/0.6/map.json",
+ "https://openstreetmap.erniubot.live/api/0.6/map.json",
+ # 'https://overpass-api.de/api/map',
+ # 'http://localhost:29505/api/map',
+ # "https://lz4.overpass-api.de/api/interpreter",
+ {"bbox": f"{left},{bottom},{right},{top}"},
+ )
+
+ content_str = content.decode("utf-8")
+ if content_str.startswith("You requested too many nodes"):
+ raise ValueError(content_str)
+
+ if cache_path is not None:
+ with cache_path.open("bw+") as fp:
+ fp.write(content)
+    return json.loads(content_str)
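+
+# Example (sketch; the BoundaryBox holds (lat, lon) corners, values hypothetical):
+#   bbox = BoundaryBox([48.855, 2.349], [48.858, 2.354])
+#   osm_dict = get_osm(bbox, cache_path=Path("osm_cache.json"))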
+
+
+def get_web_data(address: str, parameters: Dict[str, str]) -> bytes:
+    # logger.info("Getting %s...", address)
+    # Requests can optionally be routed through an HTTP proxy, e.g.:
+    # proxy_address = "http://107.173.122.186:3128"
+    # proxies = {"http": proxy_address, "https": proxy_address}
+    # response = requests.get(address, params=parameters, timeout=100, proxies=proxies)
+    print("url:", address)
+    # Send the GET request, retrying on failure, and return the raw payload.
+    while True:
+        try:
+            response = requests.get(address, params=parameters, timeout=100)
+            break
+        except Exception as e:
+            print(f"Request failed: {e}")
+            print("Retrying...")
+    return response.content
+# def get_web_data_2(address: str, parameters: Dict[str, str]) -> bytes:
+#     # logger.info("Getting %s...", address)
+#     proxy_address = "http://107.173.122.186:3128"
+#     http = urllib3.PoolManager(proxy_url=proxy_address)
+#     result = http.request("GET", address, parameters, timeout=100)
+#     return result.data
+#
+#
+# def get_web_data_1(address: str, parameters: Dict[str, str]) -> bytes:
+#
+#     # Proxy server address and port
+#     proxy_address = "http://107.173.122.186:3128"
+#
+#     # Create the ProxyHandler object
+#     proxy_handler = urllib.request.ProxyHandler({'http': proxy_address})
+#
+#     # Build the query string
+#     query_string = urllib.parse.urlencode(parameters)
+#
+#     # Build the full URL
+#     url = address + '?' + query_string
+#     print(url)
+#     # Create an OpenerDirector that sends requests through the proxy
+#     opener = urllib.request.build_opener(proxy_handler)
+#
+#     # Send the request via the opener
+#     response = opener.open(url)
+#
+#     # Or send the GET request directly:
+#     # response = urllib.request.urlopen(url, timeout=100)
+#
+#     # Read the response body
+#     data = response.read()
+#     return data
\ No newline at end of file
diff --git a/osm/parser.py b/osm/parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..d235c71bdbdd22d280d60015b5941d25ba0345e1
--- /dev/null
+++ b/osm/parser.py
@@ -0,0 +1,255 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import logging
+import re
+from typing import List
+
+from .reader import OSMData, OSMElement, OSMNode, OSMWay
+
+IGNORE_TAGS = {"source", "phone", "entrance", "inscription", "note", "name"}
+
+
+def parse_levels(string: str) -> List[float]:
+ """Parse string representation of level sequence value."""
+ try:
+ cleaned = string.replace(",", ";").replace(" ", "")
+ return list(map(float, cleaned.split(";")))
+ except ValueError:
+ logging.debug("Cannot parse level description from `%s`.", string)
+ return []
+
+
+def filter_level(elem: OSMElement):
+ level = elem.tags.get("level")
+ if level is not None:
+ levels = parse_levels(level)
+ # In the US, ground floor levels are sometimes marked as level=1
+ # so let's be conservative and include it.
+ if not (0 in levels or 1 in levels):
+ return False
+ layer = elem.tags.get("layer")
+ if layer is not None:
+ layer = parse_levels(layer)
+ if len(layer) > 0 and max(layer) < 0:
+ return False
+ return (
+ elem.tags.get("location") != "underground"
+ and elem.tags.get("parking") != "underground"
+ )
+
+
+def filter_node(node: OSMNode):
+ return len(node.tags.keys() - IGNORE_TAGS) > 0 and filter_level(node)
+
+
+def is_area(way: OSMWay):
+ if way.nodes[0] != way.nodes[-1]:
+ return False
+ if way.tags.get("area") == "no":
+ return False
+ filters = [
+ "area",
+ "building",
+ "amenity",
+ "indoor",
+ "landuse",
+ "landcover",
+ "leisure",
+ "public_transport",
+ "shop",
+ ]
+ for f in filters:
+ if f in way.tags and way.tags.get(f) != "no":
+ return True
+ if way.tags.get("natural") in {"wood", "grassland", "water"}:
+ return True
+ return False
+
+
+def filter_area(way: OSMWay):
+ return len(way.tags.keys() - IGNORE_TAGS) > 0 and is_area(way) and filter_level(way)
+
+
+def filter_way(way: OSMWay):
+ return not filter_area(way) and way.tags != {} and filter_level(way)
+
+
+def parse_node(tags):
+ keys = tags.keys()
+ for key in [
+ "amenity",
+ "natural",
+ "highway",
+ "barrier",
+ "shop",
+ "tourism",
+ "public_transport",
+ "emergency",
+ "man_made",
+ ]:
+ if key in keys:
+ if "disused" in tags[key]:
+ continue
+ return f"{key}:{tags[key]}"
+ return None
+
+
+def parse_area(tags):
+ if "building" in tags:
+ group = "building"
+ kind = tags["building"]
+ if kind == "yes":
+ for key in ["amenity", "tourism"]:
+ if key in tags:
+ kind = tags[key]
+ break
+ if kind != "yes":
+ group += f":{kind}"
+ return group
+ if "area:highway" in tags:
+ return f'highway:{tags["area:highway"]}'
+ for key in [
+ "amenity",
+ "landcover",
+ "leisure",
+ "shop",
+ "highway",
+ "tourism",
+ "natural",
+ "waterway",
+ "landuse",
+ ]:
+ if key in tags:
+ return f"{key}:{tags[key]}"
+ return None
+
+
+def parse_way(tags):
+ keys = tags.keys()
+ for key in ["highway", "barrier", "natural"]:
+ if key in keys:
+ return f"{key}:{tags[key]}"
+ return None
+
+
+def match_to_group(label, patterns):
+ for group, pattern in patterns.items():
+ if re.match(pattern, label):
+ return group
+ return None
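+
+# Example (sketch): match_to_group("amenity:parking", Patterns.areas) returns
+# "parking", while a label that matches no pattern returns None.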
+
+
+class Patterns:
+ areas = dict(
+ building="building($|:.*?)*",
+ parking="amenity:parking",
+ playground="leisure:(playground|pitch)",
+ grass="(landuse:grass|landcover:grass|landuse:meadow|landuse:flowerbed|natural:grassland)",
+ park="leisure:(park|garden|dog_park)",
+ forest="(landuse:forest|natural:wood)",
+ water="(natural:water|waterway:*)",
+ )
+ # + ways: road, path
+ # + node: fountain, bicycle_parking
+
+ ways = dict(
+ fence="barrier:(fence|yes)",
+ wall="barrier:(wall|retaining_wall)",
+ hedge="barrier:hedge",
+ kerb="barrier:kerb",
+ building_outline="building($|:.*?)*",
+ cycleway="highway:cycleway",
+ path="highway:(pedestrian|footway|steps|path|corridor)",
+ road="highway:(motorway|trunk|primary|secondary|tertiary|service|construction|track|unclassified|residential|.*_link)",
+ busway="highway:busway",
+ tree_row="natural:tree_row", # maybe merge with node?
+ )
+ # + nodes: bollard
+
+ nodes = dict(
+ tree="natural:tree",
+ stone="(natural:stone|barrier:block)",
+ crossing="highway:crossing",
+ lamp="highway:street_lamp",
+ traffic_signal="highway:traffic_signals",
+ bus_stop="highway:bus_stop",
+ stop_sign="highway:stop",
+ junction="highway:motorway_junction",
+ bus_stop_position="public_transport:stop_position",
+ gate="barrier:(gate|lift_gate|swing_gate|cycle_barrier)",
+ bollard="barrier:bollard",
+ shop="(shop.*?|amenity:(bank|post_office))",
+ restaurant="amenity:(restaurant|fast_food)",
+ bar="amenity:(cafe|bar|pub|biergarten)",
+ pharmacy="amenity:pharmacy",
+ fuel="amenity:fuel",
+ bicycle_parking="amenity:(bicycle_parking|bicycle_rental)",
+ charging_station="amenity:charging_station",
+ parking_entrance="amenity:parking_entrance",
+ atm="amenity:atm",
+ toilets="amenity:toilets",
+ vending_machine="amenity:vending_machine",
+ fountain="amenity:fountain",
+ waste_basket="amenity:(waste_basket|waste_disposal)",
+ bench="amenity:bench",
+ post_box="amenity:post_box",
+ artwork="tourism:artwork",
+ recycling="amenity:recycling",
+ give_way="highway:give_way",
+ clock="amenity:clock",
+ fire_hydrant="emergency:fire_hydrant",
+ pole="man_made:(flagpole|utility_pole)",
+ street_cabinet="man_made:street_cabinet",
+ )
+ # + ways: kerb
+
+
+class Groups:
+ areas = list(Patterns.areas)
+ ways = list(Patterns.ways)
+ nodes = list(Patterns.nodes)
+
+
+def group_elements(osm: OSMData):
+ elem2group = {
+ "area": {},
+ "way": {},
+ "node": {},
+ }
+
+ for node in filter(filter_node, osm.nodes.values()):
+ label = parse_node(node.tags)
+ if label is None:
+ continue
+ group = match_to_group(label, Patterns.nodes)
+ if group is None:
+ group = match_to_group(label, Patterns.ways)
+ if group is None:
+ continue # missing
+ elem2group["node"][node.id_] = group
+
+ for way in filter(filter_way, osm.ways.values()):
+ label = parse_way(way.tags)
+ if label is None:
+ continue
+ group = match_to_group(label, Patterns.ways)
+ if group is None:
+ group = match_to_group(label, Patterns.nodes)
+ if group is None:
+ continue # missing
+ elem2group["way"][way.id_] = group
+
+ for area in filter(filter_area, osm.ways.values()):
+ label = parse_area(area.tags)
+ if label is None:
+ continue
+ group = match_to_group(label, Patterns.areas)
+ if group is None:
+ group = match_to_group(label, Patterns.ways)
+ if group is None:
+ group = match_to_group(label, Patterns.nodes)
+ if group is None:
+ continue # missing
+ elem2group["area"][area.id_] = group
+
+ return elem2group
diff --git a/osm/raster.py b/osm/raster.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e203bc50db0d873a9c8544ed0a1533c84a38df9
--- /dev/null
+++ b/osm/raster.py
@@ -0,0 +1,103 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+from typing import Dict, List
+
+import cv2
+import numpy as np
+import torch
+
+from utils.geo import BoundaryBox
+from .data import MapArea, MapLine, MapNode
+from .parser import Groups
+
+
+class Canvas:
+ def __init__(self, bbox: BoundaryBox, ppm: float):
+ self.bbox = bbox
+ self.ppm = ppm
+ self.scaling = bbox.size * ppm
+ self.w, self.h = np.ceil(self.scaling).astype(int)
+ self.clear()
+
+ def clear(self):
+ self.raster = np.zeros((self.h, self.w), np.uint8)
+
+ def to_uv(self, xy: np.ndarray):
+ xy = self.bbox.normalize(xy)
+ xy[..., 1] = 1 - xy[..., 1]
+ s = self.scaling
+ if isinstance(xy, torch.Tensor):
+ s = torch.from_numpy(s).to(xy)
+ return xy * s - 0.5
+
+ def to_xy(self, uv: np.ndarray):
+ s = self.scaling
+ if isinstance(uv, torch.Tensor):
+ s = torch.from_numpy(s).to(uv)
+ xy = (uv + 0.5) / s
+ xy[..., 1] = 1 - xy[..., 1]
+ return self.bbox.unnormalize(xy)
+
+ def draw_polygon(self, xy: np.ndarray):
+ uv = self.to_uv(xy)
+ cv2.fillPoly(self.raster, uv[None].astype(np.int32), 255)
+
+ def draw_multipolygon(self, xys: List[np.ndarray]):
+ uvs = [self.to_uv(xy).round().astype(np.int32) for xy in xys]
+ cv2.fillPoly(self.raster, uvs, 255)
+
+ def draw_line(self, xy: np.ndarray, width: float = 1):
+ uv = self.to_uv(xy)
+ cv2.polylines(
+ self.raster, uv[None].round().astype(np.int32), False, 255, thickness=width
+ )
+
+ def draw_cell(self, xy: np.ndarray):
+ if not self.bbox.contains(xy):
+ return
+ uv = self.to_uv(xy)
+ self.raster[tuple(uv.round().astype(int).T[::-1])] = 255
+
+
+def render_raster_masks(
+ nodes: List[MapNode],
+ lines: List[MapLine],
+ areas: List[MapArea],
+ canvas: Canvas,
+) -> Dict[str, np.ndarray]:
+ all_groups = Groups.areas + Groups.ways + Groups.nodes
+ masks = {k: np.zeros((canvas.h, canvas.w), np.uint8) for k in all_groups}
+
+ for area in areas:
+ canvas.raster = masks[area.group]
+ outlines = area.outers + area.inners
+ canvas.draw_multipolygon(outlines)
+ if area.group == "building":
+ canvas.raster = masks["building_outline"]
+ for line in outlines:
+ canvas.draw_line(line)
+
+ for line in lines:
+ canvas.raster = masks[line.group]
+ canvas.draw_line(line.xy)
+
+ for node in nodes:
+ canvas.raster = masks[node.group]
+ canvas.draw_cell(node.xy)
+
+ return masks
+
+
+def mask_to_idx(group2mask: Dict[str, np.ndarray], groups: List[str]) -> np.ndarray:
+ masks = np.stack([group2mask[k] for k in groups]) > 0
+ void = ~np.any(masks, 0)
+ idx = np.argmax(masks, 0)
+ idx = np.where(void, np.zeros_like(idx), idx + 1) # add background
+ return idx
+
+
+def render_raster_map(masks: Dict[str, np.ndarray]) -> np.ndarray:
+ areas = mask_to_idx(masks, Groups.areas)
+ ways = mask_to_idx(masks, Groups.ways)
+ nodes = mask_to_idx(masks, Groups.nodes)
+ return np.stack([areas, ways, nodes])
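+
+
+# Example usage (sketch; `bbox`, `nodes`, `lines`, and `areas` are hypothetical):
+#   canvas = Canvas(bbox, ppm=2)
+#   masks = render_raster_masks(nodes, lines, areas, canvas)
+#   raster = render_raster_map(masks)  # (3, H, W) class indices: areas, ways, nodes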
diff --git a/osm/reader.py b/osm/reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..793ad1879f8b2068cd5265bed408be14719e9680
--- /dev/null
+++ b/osm/reader.py
@@ -0,0 +1,310 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import json
+import re
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+from lxml import etree
+import numpy as np
+
+from utils.geo import BoundaryBox, Projection
+
+METERS_PATTERN: re.Pattern = re.compile("^(?P<value>\\d*\\.?\\d*)\\s*m$")
+KILOMETERS_PATTERN: re.Pattern = re.compile("^(?P<value>\\d*\\.?\\d*)\\s*km$")
+MILES_PATTERN: re.Pattern = re.compile("^(?P<value>\\d*\\.?\\d*)\\s*mi$")
+
+
+def parse_float(string: str) -> Optional[float]:
+ """Parse string representation of a float or integer value."""
+ try:
+ return float(string)
+ except (TypeError, ValueError):
+ return None
+
+
+@dataclass(eq=False)
+class OSMElement:
+ """
+ Something with tags (string to string mapping).
+ """
+
+ id_: int
+ tags: Dict[str, str]
+
+ def get_float(self, key: str) -> Optional[float]:
+ """Parse float from tag value."""
+ if key in self.tags:
+ return parse_float(self.tags[key])
+ return None
+
+ def get_length(self, key: str) -> Optional[float]:
+ """Get length in meters."""
+ if key not in self.tags:
+ return None
+
+ value: str = self.tags[key]
+
+ float_value: float = parse_float(value)
+ if float_value is not None:
+ return float_value
+
+ for pattern, ratio in [
+ (METERS_PATTERN, 1.0),
+ (KILOMETERS_PATTERN, 1000.0),
+ (MILES_PATTERN, 1609.344),
+ ]:
+ matcher: re.Match = pattern.match(value)
+ if matcher:
+ float_value: float = parse_float(matcher.group("value"))
+ if float_value is not None:
+ return float_value * ratio
+
+ return None
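+
+        # Example (sketch): "12.5" -> 12.5, "2 km" -> 2000.0, "1 mi" -> 1609.344.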
+
+ def __hash__(self) -> int:
+ return self.id_
+
+
+@dataclass(eq=False)
+class OSMNode(OSMElement):
+ """
+ OpenStreetMap node.
+
+ See https://wiki.openstreetmap.org/wiki/Node
+ """
+
+ geo: np.ndarray
+ visible: Optional[str] = None
+ xy: Optional[np.ndarray] = None
+
+ @classmethod
+ def from_dict(cls, structure: Dict[str, Any]) -> "OSMNode":
+ """
+ Parse node from Overpass-like structure.
+
+ :param structure: input structure
+ """
+ return cls(
+ structure["id"],
+ structure.get("tags", {}),
+ geo=np.array((structure["lat"], structure["lon"])),
+ visible=structure.get("visible"),
+ )
+
+
+@dataclass(eq=False)
+class OSMWay(OSMElement):
+ """
+ OpenStreetMap way.
+
+ See https://wiki.openstreetmap.org/wiki/Way
+ """
+
+ nodes: Optional[List[OSMNode]] = field(default_factory=list)
+ visible: Optional[str] = None
+
+ @classmethod
+ def from_dict(
+ cls, structure: Dict[str, Any], nodes: Dict[int, OSMNode]
+ ) -> "OSMWay":
+ """
+ Parse way from Overpass-like structure.
+
+ :param structure: input structure
+ :param nodes: node structure
+ """
+ return cls(
+ structure["id"],
+ structure.get("tags", {}),
+ [nodes[x] for x in structure["nodes"]],
+ visible=structure.get("visible"),
+ )
+
+ def is_cycle(self) -> bool:
+ """Is way a cycle way or an area boundary."""
+ return self.nodes[0] == self.nodes[-1]
+
+ def __repr__(self) -> str:
+ return f"Way <{self.id_}> {self.nodes}"
+
+
+@dataclass
+class OSMMember:
+ """
+ Member of OpenStreetMap relation.
+ """
+
+ type_: str
+ ref: int
+ role: str
+
+
+@dataclass(eq=False)
+class OSMRelation(OSMElement):
+ """
+ OpenStreetMap relation.
+
+ See https://wiki.openstreetmap.org/wiki/Relation
+ """
+
+ members: Optional[List[OSMMember]]
+ visible: Optional[str] = None
+
+ @classmethod
+ def from_dict(cls, structure: Dict[str, Any]) -> "OSMRelation":
+ """
+ Parse relation from Overpass-like structure.
+
+ :param structure: input structure
+ """
+ return cls(
+ structure["id"],
+ structure["tags"],
+ [OSMMember(x["type"], x["ref"], x["role"]) for x in structure["members"]],
+ visible=structure.get("visible"),
+ )
+
+
+class OSMData:
+ """
+ The whole OpenStreetMap information about nodes, ways, and relations.
+ """
+
+ def __init__(self) -> None:
+ self.nodes: Dict[int, OSMNode] = {}
+ self.ways: Dict[int, OSMWay] = {}
+ self.relations: Dict[int, OSMRelation] = {}
+ self.box: BoundaryBox = None
+
+ @classmethod
+ def from_dict(cls, structure: Dict[str, Any]):
+ data = cls()
+ bounds = structure.get("bounds")
+ if bounds is not None:
+ data.box = BoundaryBox(
+ np.array([bounds["minlat"], bounds["minlon"]]),
+ np.array([bounds["maxlat"], bounds["maxlon"]]),
+ )
+
+ for element in structure["elements"]:
+ if element["type"] == "node":
+ node = OSMNode.from_dict(element)
+ data.add_node(node)
+ for element in structure["elements"]:
+ if element["type"] == "way":
+ way = OSMWay.from_dict(element, data.nodes)
+ data.add_way(way)
+ for element in structure["elements"]:
+ if element["type"] == "relation":
+ relation = OSMRelation.from_dict(element)
+ data.add_relation(relation)
+
+ return data
+
+ @classmethod
+ def from_json(cls, path: Path):
+ with path.open(encoding='utf-8') as fid:
+ structure = json.load(fid)
+ return cls.from_dict(structure)
+
+ @classmethod
+ def from_xml(cls, path: Path):
+ root = etree.parse(str(path)).getroot()
+ structure = {"elements": []}
+ from tqdm import tqdm
+
+ for elem in tqdm(root):
+ if elem.tag == "bounds":
+ structure["bounds"] = {
+ k: float(elem.attrib[k])
+ for k in ("minlon", "minlat", "maxlon", "maxlat")
+ }
+ elif elem.tag in {"node", "way", "relation"}:
+ if elem.tag == "node":
+ item = {
+ "id": int(elem.attrib["id"]),
+ "lat": float(elem.attrib["lat"]),
+ "lon": float(elem.attrib["lon"]),
+ "visible": elem.attrib.get("visible"),
+ "tags": {
+ x.attrib["k"]: x.attrib["v"] for x in elem if x.tag == "tag"
+ },
+ }
+ elif elem.tag == "way":
+ item = {
+ "id": int(elem.attrib["id"]),
+ "visible": elem.attrib.get("visible"),
+ "tags": {
+ x.attrib["k"]: x.attrib["v"] for x in elem if x.tag == "tag"
+ },
+ "nodes": [int(x.attrib["ref"]) for x in elem if x.tag == "nd"],
+ }
+ elif elem.tag == "relation":
+ item = {
+ "id": int(elem.attrib["id"]),
+ "visible": elem.attrib.get("visible"),
+ "tags": {
+ x.attrib["k"]: x.attrib["v"] for x in elem if x.tag == "tag"
+ },
+ "members": [
+ {
+ "type": x.attrib["type"],
+ "ref": int(x.attrib["ref"]),
+ "role": x.attrib["role"],
+ }
+ for x in elem
+ if x.tag == "member"
+ ],
+ }
+ item["type"] = elem.tag
+ structure["elements"].append(item)
+ elem.clear()
+ del root
+ return cls.from_dict(structure)
+
+ @classmethod
+ def from_file(cls, path: Path):
+ ext = path.suffix
+ if ext == ".json":
+ return cls.from_json(path)
+ elif ext in {".osm", ".xml"}:
+ return cls.from_xml(path)
+ else:
+ raise ValueError(f"Unknown extension for {path}")
+
+ def add_node(self, node: OSMNode):
+ """Add node and update map parameters."""
+ if node.id_ in self.nodes:
+ raise ValueError(f"Node with duplicate id {node.id_}.")
+ self.nodes[node.id_] = node
+
+ def add_way(self, way: OSMWay):
+ """Add way and update map parameters."""
+ if way.id_ in self.ways:
+ raise ValueError(f"Way with duplicate id {way.id_}.")
+ self.ways[way.id_] = way
+
+ def add_relation(self, relation: OSMRelation):
+ """Add relation and update map parameters."""
+ if relation.id_ in self.relations:
+ raise ValueError(f"Relation with duplicate id {relation.id_}.")
+ self.relations[relation.id_] = relation
+
+ def add_xy_to_nodes(self, proj: Projection):
+ nodes = list(self.nodes.values())
+ if len(nodes) == 0:
+ return
+ geos = np.stack([n.geo for n in nodes], 0)
+ if proj.bounds is not None:
+            # For some reason, a few nodes are sometimes very far outside the initial bbox.
+ valid = proj.bounds.contains(geos)
+ if valid.mean() < 0.9:
+ print("Many nodes are out of the projection bounds.")
+ xys = np.zeros_like(geos)
+ xys[valid] = proj.project(geos[valid])
+ else:
+ xys = proj.project(geos)
+ for xy, node in zip(xys, nodes):
+ node.xy = xy
diff --git a/osm/tiling.py b/osm/tiling.py
new file mode 100644
index 0000000000000000000000000000000000000000..b610f0dc1eb31376e516deaafa1d33dc496aec2c
--- /dev/null
+++ b/osm/tiling.py
@@ -0,0 +1,310 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import io
+import pickle
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+from PIL import Image
+import rtree
+
+from utils.geo import BoundaryBox, Projection
+from .data import MapData
+from .download import get_osm
+from .parser import Groups
+from .raster import Canvas, render_raster_map, render_raster_masks
+from .reader import OSMData, OSMNode, OSMWay
+
+
+class MapIndex:
+ def __init__(
+ self,
+ data: MapData,
+ ):
+ self.index_nodes = rtree.index.Index()
+ for i, node in data.nodes.items():
+ self.index_nodes.insert(i, tuple(node.xy) * 2)
+
+ self.index_lines = rtree.index.Index()
+ for i, line in data.lines.items():
+ bbox = tuple(np.r_[line.xy.min(0), line.xy.max(0)])
+ self.index_lines.insert(i, bbox)
+
+ self.index_areas = rtree.index.Index()
+ for i, area in data.areas.items():
+ xy = np.concatenate(area.outers + area.inners)
+ bbox = tuple(np.r_[xy.min(0), xy.max(0)])
+ self.index_areas.insert(i, bbox)
+
+ self.data = data
+
+ def query(self, bbox: BoundaryBox) -> Tuple[List[OSMNode], List[OSMWay]]:
+ query = tuple(np.r_[bbox.min_, bbox.max_])
+ ret = []
+ for x in ["nodes", "lines", "areas"]:
+ ids = getattr(self, "index_" + x).intersection(query)
+ ret.append([getattr(self.data, x)[i] for i in ids])
+ return tuple(ret)
+
+
+def bbox_to_slice(bbox: BoundaryBox, canvas: Canvas):
+ uv_min = np.ceil(canvas.to_uv(bbox.min_)).astype(int)
+ uv_max = np.ceil(canvas.to_uv(bbox.max_)).astype(int)
+ slice_ = (slice(uv_max[1], uv_min[1]), slice(uv_min[0], uv_max[0]))
+ return slice_
+
+
+def round_bbox(bbox: BoundaryBox, origin: np.ndarray, ppm: int):
+ bbox = bbox.translate(-origin)
+ bbox = BoundaryBox(np.round(bbox.min_ * ppm) / ppm, np.round(bbox.max_ * ppm) / ppm)
+ return bbox.translate(origin)
+
+
+
+class MapTileManager:
+    def __init__(
+        self,
+        osmpath: Path,
+    ):
+        self.osm = OSMData.from_file(osmpath)
+
+ # @classmethod
+ def from_bbox(
+ self,
+ projection: Projection,
+ bbox: BoundaryBox,
+ ppm: int,
+ tile_size: int = 128,
+ ):
+ # bbox_osm = projection.unproject(bbox)
+ # if path is not None and path.is_file():
+ # print(OSMData.from_file)
+ # osm = OSMData.from_file(path)
+ # if osm.box is not None:
+ # assert osm.box.contains(bbox_osm)
+ # else:
+ # osm = OSMData.from_dict(get_osm(bbox_osm, path))
+
+ self.osm.add_xy_to_nodes(projection)
+ map_data = MapData.from_osm(self.osm)
+ map_index = MapIndex(map_data)
+
+ bounds_x, bounds_y = [
+ np.r_[np.arange(min_, max_, tile_size), max_]
+ for min_, max_ in zip(bbox.min_, bbox.max_)
+ ]
+ bbox_tiles = {}
+ for i, xmin in enumerate(bounds_x[:-1]):
+ for j, ymin in enumerate(bounds_y[:-1]):
+ bbox_tiles[i, j] = BoundaryBox(
+ [xmin, ymin], [bounds_x[i + 1], bounds_y[j + 1]]
+ )
+
+ tiles = {}
+ for ij, bbox_tile in bbox_tiles.items():
+ canvas = Canvas(bbox_tile, ppm)
+ nodes, lines, areas = map_index.query(bbox_tile)
+ masks = render_raster_masks(nodes, lines, areas, canvas)
+ canvas.raster = render_raster_map(masks)
+ tiles[ij] = canvas
+
+ groups = {k: v for k, v in vars(Groups).items() if not k.startswith("__")}
+
+ self.origin = bbox.min_
+ self.bbox = bbox
+ self.tiles = tiles
+ self.tile_size = tile_size
+ self.ppm = ppm
+ self.projection = projection
+ self.groups = groups
+ self.map_data = map_data
+
+ return self.query(bbox)
+ # return cls(tiles, bbox, tile_size, ppm, projection, groups, map_data)
+
+ def query(self, bbox: BoundaryBox) -> Canvas:
+ bbox = round_bbox(bbox, self.bbox.min_, self.ppm)
+ canvas = Canvas(bbox, self.ppm)
+ raster = np.zeros((3, canvas.h, canvas.w), np.uint8)
+
+ bbox_all = bbox & self.bbox
+ ij_min = np.floor((bbox_all.min_ - self.origin) / self.tile_size).astype(int)
+ ij_max = np.ceil((bbox_all.max_ - self.origin) / self.tile_size).astype(int) - 1
+ for i in range(ij_min[0], ij_max[0] + 1):
+ for j in range(ij_min[1], ij_max[1] + 1):
+ tile = self.tiles[i, j]
+ bbox_select = tile.bbox & bbox
+ slice_query = bbox_to_slice(bbox_select, canvas)
+ slice_tile = bbox_to_slice(bbox_select, tile)
+ raster[(slice(None),) + slice_query] = tile.raster[
+ (slice(None),) + slice_tile
+ ]
+ canvas.raster = raster
+ return canvas
+
+ def save(self, path: Path):
+ dump = {
+ "bbox": self.bbox.format(),
+ "tile_size": self.tile_size,
+ "ppm": self.ppm,
+ "groups": self.groups,
+ "tiles_bbox": {},
+ "tiles_raster": {},
+ }
+ if self.projection is not None:
+ dump["ref_latlonalt"] = self.projection.latlonalt
+ for ij, canvas in self.tiles.items():
+ dump["tiles_bbox"][ij] = canvas.bbox.format()
+ raster_bytes = io.BytesIO()
+ raster = Image.fromarray(canvas.raster.transpose(1, 2, 0).astype(np.uint8))
+ raster.save(raster_bytes, format="PNG")
+ dump["tiles_raster"][ij] = raster_bytes
+ with open(path, "wb") as fp:
+ pickle.dump(dump, fp)
+
+ @classmethod
+ def load(cls, path: Path):
+ with path.open("rb") as fp:
+ dump = pickle.load(fp)
+ tiles = {}
+ for ij, bbox in dump["tiles_bbox"].items():
+ tiles[ij] = Canvas(BoundaryBox.from_string(bbox), dump["ppm"])
+ raster = np.asarray(Image.open(dump["tiles_raster"][ij]))
+ tiles[ij].raster = raster.transpose(2, 0, 1).copy()
+ projection = Projection(*dump["ref_latlonalt"])
+ return cls(
+ tiles,
+ BoundaryBox.from_string(dump["bbox"]),
+ dump["tile_size"],
+ dump["ppm"],
+ projection,
+ dump["groups"],
+ )
+
+class TileManager:
+ def __init__(
+ self,
+ tiles: Dict,
+ bbox: BoundaryBox,
+ tile_size: int,
+ ppm: int,
+ projection: Projection,
+ groups: Dict[str, List[str]],
+ map_data: Optional[MapData] = None,
+ ):
+ self.origin = bbox.min_
+ self.bbox = bbox
+ self.tiles = tiles
+ self.tile_size = tile_size
+ self.ppm = ppm
+ self.projection = projection
+ self.groups = groups
+ self.map_data = map_data
+ assert np.all(tiles[0, 0].bbox.min_ == self.origin)
+ for tile in tiles.values():
+ assert bbox.contains(tile.bbox)
+
+ @classmethod
+ def from_bbox(
+ cls,
+ projection: Projection,
+ bbox: BoundaryBox,
+ ppm: int,
+ path: Optional[Path] = None,
+ tile_size: int = 128,
+ ):
+ bbox_osm = projection.unproject(bbox)
+ if path is not None and path.is_file():
+ print(OSMData.from_file)
+ osm = OSMData.from_file(path)
+ if osm.box is not None:
+ assert osm.box.contains(bbox_osm)
+ else:
+ osm = OSMData.from_dict(get_osm(bbox_osm, path))
+
+ osm.add_xy_to_nodes(projection)
+ map_data = MapData.from_osm(osm)
+ map_index = MapIndex(map_data)
+
+ bounds_x, bounds_y = [
+ np.r_[np.arange(min_, max_, tile_size), max_]
+ for min_, max_ in zip(bbox.min_, bbox.max_)
+ ]
+ bbox_tiles = {}
+ for i, xmin in enumerate(bounds_x[:-1]):
+ for j, ymin in enumerate(bounds_y[:-1]):
+ bbox_tiles[i, j] = BoundaryBox(
+ [xmin, ymin], [bounds_x[i + 1], bounds_y[j + 1]]
+ )
+
+ tiles = {}
+ for ij, bbox_tile in bbox_tiles.items():
+ canvas = Canvas(bbox_tile, ppm)
+ nodes, lines, areas = map_index.query(bbox_tile)
+ masks = render_raster_masks(nodes, lines, areas, canvas)
+ canvas.raster = render_raster_map(masks)
+ tiles[ij] = canvas
+
+ groups = {k: v for k, v in vars(Groups).items() if not k.startswith("__")}
+
+ return cls(tiles, bbox, tile_size, ppm, projection, groups, map_data)
+
+ def query(self, bbox: BoundaryBox) -> Canvas:
+ bbox = round_bbox(bbox, self.bbox.min_, self.ppm)
+ canvas = Canvas(bbox, self.ppm)
+ raster = np.zeros((3, canvas.h, canvas.w), np.uint8)
+
+ bbox_all = bbox & self.bbox
+ ij_min = np.floor((bbox_all.min_ - self.origin) / self.tile_size).astype(int)
+ ij_max = np.ceil((bbox_all.max_ - self.origin) / self.tile_size).astype(int) - 1
+ for i in range(ij_min[0], ij_max[0] + 1):
+ for j in range(ij_min[1], ij_max[1] + 1):
+ tile = self.tiles[i, j]
+ bbox_select = tile.bbox & bbox
+ slice_query = bbox_to_slice(bbox_select, canvas)
+ slice_tile = bbox_to_slice(bbox_select, tile)
+ raster[(slice(None),) + slice_query] = tile.raster[
+ (slice(None),) + slice_tile
+ ]
+ canvas.raster = raster
+ return canvas
+
+ def save(self, path: Path):
+ dump = {
+ "bbox": self.bbox.format(),
+ "tile_size": self.tile_size,
+ "ppm": self.ppm,
+ "groups": self.groups,
+ "tiles_bbox": {},
+ "tiles_raster": {},
+ }
+ if self.projection is not None:
+ dump["ref_latlonalt"] = self.projection.latlonalt
+ for ij, canvas in self.tiles.items():
+ dump["tiles_bbox"][ij] = canvas.bbox.format()
+ raster_bytes = io.BytesIO()
+ raster = Image.fromarray(canvas.raster.transpose(1, 2, 0).astype(np.uint8))
+ raster.save(raster_bytes, format="PNG")
+ dump["tiles_raster"][ij] = raster_bytes
+ with open(path, "wb") as fp:
+ pickle.dump(dump, fp)
+
+ @classmethod
+ def load(cls, path: Path):
+ with path.open("rb") as fp:
+ dump = pickle.load(fp)
+ tiles = {}
+ for ij, bbox in dump["tiles_bbox"].items():
+ tiles[ij] = Canvas(BoundaryBox.from_string(bbox), dump["ppm"])
+ raster = np.asarray(Image.open(dump["tiles_raster"][ij]))
+ tiles[ij].raster = raster.transpose(2, 0, 1).copy()
+ projection = Projection(*dump["ref_latlonalt"])
+ return cls(
+ tiles,
+ BoundaryBox.from_string(dump["bbox"]),
+ dump["tile_size"],
+ dump["ppm"],
+ projection,
+ dump["groups"],
+ )
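+
+
+# Example usage (sketch; `projection` and `bbox` are hypothetical):
+#   tiler = TileManager.from_bbox(projection, bbox, ppm=2, path=Path("map.json"))
+#   canvas = tiler.query(bbox)  # rasterized (3, H, W) crop of the map
+#   tiler.save(Path("tiles.pkl"))
+#   tiler = TileManager.load(Path("tiles.pkl"))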
diff --git a/osm/viz.py b/osm/viz.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd99e3eda0049c2aae35d397018db73b2eb661ae
--- /dev/null
+++ b/osm/viz.py
@@ -0,0 +1,159 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+import plotly.graph_objects as go
+import PIL.Image
+
+from utils.viz_2d import add_text
+from .parser import Groups
+
+
+class GeoPlotter:
+ def __init__(self, zoom=12, **kwargs):
+ self.fig = go.Figure()
+ self.fig.update_layout(
+ mapbox_style="open-street-map",
+ autosize=True,
+ mapbox_zoom=zoom,
+ margin={"r": 0, "t": 0, "l": 0, "b": 0},
+ showlegend=True,
+ **kwargs,
+ )
+
+ def points(self, latlons, color, text=None, name=None, size=5, **kwargs):
+ latlons = np.asarray(latlons)
+ self.fig.add_trace(
+ go.Scattermapbox(
+ lat=latlons[..., 0],
+ lon=latlons[..., 1],
+ mode="markers",
+ text=text,
+ marker_color=color,
+ marker_size=size,
+ name=name,
+ **kwargs,
+ )
+ )
+ center = latlons.reshape(-1, 2).mean(0)
+ self.fig.update_layout(
+ mapbox_center=dict(zip(("lat", "lon"), center)),
+ )
+
+ def bbox(self, bbox, color, name=None, **kwargs):
+ corners = np.stack(
+ [bbox.min_, bbox.left_top, bbox.max_, bbox.right_bottom, bbox.min_]
+ )
+ self.fig.add_trace(
+ go.Scattermapbox(
+ lat=corners[:, 0],
+ lon=corners[:, 1],
+ mode="lines",
+ marker_color=color,
+ name=name,
+ **kwargs,
+ )
+ )
+ self.fig.update_layout(
+ mapbox_center=dict(zip(("lat", "lon"), bbox.center)),
+ )
+
+ def raster(self, raster, bbox, below="traces", **kwargs):
+ if not np.issubdtype(raster.dtype, np.integer):
+ raster = (raster * 255).astype(np.uint8)
+ raster = PIL.Image.fromarray(raster)
+ corners = np.stack(
+ [
+ bbox.min_,
+ bbox.left_top,
+ bbox.max_,
+ bbox.right_bottom,
+ ]
+ )[::-1, ::-1]
+ layers = [*self.fig.layout.mapbox.layers]
+ layers.append(
+ dict(
+ sourcetype="image",
+ source=raster,
+ coordinates=corners,
+ below=below,
+ **kwargs,
+ )
+ )
+ self.fig.layout.mapbox.layers = layers
+
+
+map_colors = {
+ "building": (84, 155, 255),
+ "parking": (255, 229, 145),
+ "playground": (150, 133, 125),
+ "grass": (188, 255, 143),
+ "park": (0, 158, 16),
+ "forest": (0, 92, 9),
+ "water": (184, 213, 255),
+ "fence": (238, 0, 255),
+ "wall": (0, 0, 0),
+ "hedge": (107, 68, 48),
+ "kerb": (255, 234, 0),
+ "building_outline": (0, 0, 255),
+ "cycleway": (0, 251, 255),
+ "path": (8, 237, 0),
+ "road": (255, 0, 0),
+ "tree_row": (0, 92, 9),
+ "busway": (255, 128, 0),
+ "void": [int(255 * 0.9)] * 3,
+}
+
+
+class Colormap:
+ colors_areas = np.stack([map_colors[k] for k in ["void"] + Groups.areas])
+ colors_ways = np.stack([map_colors[k] for k in ["void"] + Groups.ways])
+
+ @classmethod
+ def apply(cls, rasters):
+ return (
+ np.where(
+ rasters[1, ..., None] > 0,
+ cls.colors_ways[rasters[1]],
+ cls.colors_areas[rasters[0]],
+ )
+ / 255.0
+ )
+
+ @classmethod
+ def add_colorbar(cls):
+ ax2 = plt.gcf().add_axes([1, 0.1, 0.02, 0.8])
+ color_list = np.r_[cls.colors_areas[1:], cls.colors_ways[1:]] / 255.0
+ cmap = mpl.colors.ListedColormap(color_list[::-1])
+ ticks = np.linspace(0, 1, len(color_list), endpoint=False)
+ ticks += 1 / len(color_list) / 2
+ cb = mpl.colorbar.ColorbarBase(
+ ax2,
+ cmap=cmap,
+ orientation="vertical",
+ ticks=ticks,
+ )
+ cb.set_ticklabels((Groups.areas + Groups.ways)[::-1])
+ ax2.tick_params(labelsize=15)
+
+
+def plot_nodes(idx, raster, fontsize=8, size=15):
+ ax = plt.gcf().axes[idx]
+ ax.autoscale(enable=False)
+ nodes_xy = np.stack(np.where(raster > 0)[::-1], -1)
+ nodes_val = raster[tuple(nodes_xy.T[::-1])] - 1
+ ax.scatter(*nodes_xy.T, c="k", s=size)
+ for xy, val in zip(nodes_xy, nodes_val):
+ group = Groups.nodes[val]
+ add_text(
+ idx,
+ group,
+ xy + 2,
+ lcolor=None,
+ fs=fontsize,
+ color="k",
+ normalized=False,
+ ha="center",
+ )
+ plt.show()
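+
+
+# Example usage (sketch; `raster` is a (3, H, W) map from osm.raster, hypothetical):
+#   rgb = Colormap.apply(raster)  # float RGB image in [0, 1], shape (H, W, 3)
+#   plt.imshow(rgb); Colormap.add_colorbar()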
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..b95c8795faf68414f860ef87dc0f6acb6cadebf2
--- /dev/null
+++ b/train.py
@@ -0,0 +1,217 @@
+import os.path as osp
+import warnings
+warnings.filterwarnings('ignore')
+from typing import Optional
+from pathlib import Path
+from models.maplocnet import MapLocNet
+import hydra
+import pytorch_lightning as pl
+import torch
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning.utilities import rank_zero_only
+from module import GenericModule
+from logger import logger, pl_logger, EXPERIMENTS_PATH
+from dataset import UavMapDatasetModule
+from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+# print(osp.join(osp.dirname(__file__), "conf"))
+
+
+class CleanProgressBar(pl.callbacks.TQDMProgressBar):
+ def get_metrics(self, trainer, model):
+ items = super().get_metrics(trainer, model)
+ items.pop("v_num", None) # don't show the version number
+ items.pop("loss", None)
+ return items
+
+
+class SeedingCallback(pl.callbacks.Callback):
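+    """Re-seed all RNGs at the start of every train/val/test epoch. During training
+    the seed is offset by the epoch index so augmentations vary across epochs,
+    unless a fixed subset of batches is being overfit."""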
+ def on_epoch_start_(self, trainer, module):
+ seed = module.cfg.experiment.seed
+ is_overfit = module.cfg.training.trainer.get("overfit_batches", 0) > 0
+ if trainer.training and not is_overfit:
+ seed = seed + trainer.current_epoch
+
+ # Temporarily disable the logging (does not seem to work?)
+ pl_logger.disabled = True
+ try:
+ pl.seed_everything(seed, workers=True)
+ finally:
+ pl_logger.disabled = False
+
+ def on_train_epoch_start(self, *args, **kwargs):
+ self.on_epoch_start_(*args, **kwargs)
+
+ def on_validation_epoch_start(self, *args, **kwargs):
+ self.on_epoch_start_(*args, **kwargs)
+
+ def on_test_epoch_start(self, *args, **kwargs):
+ self.on_epoch_start_(*args, **kwargs)
+
+
+class ConsoleLogger(pl.callbacks.Callback):
+ @rank_zero_only
+ def on_train_epoch_start(self, trainer, module):
+ logger.info(
+ "New training epoch %d for experiment '%s'.",
+ module.current_epoch,
+ module.cfg.experiment.name,
+ )
+
+ # @rank_zero_only
+ # def on_validation_epoch_end(self, trainer, module):
+ # results = {
+ # **dict(module.metrics_val.items()),
+ # **dict(module.losses_val.items()),
+ # }
+ # results = [f"{k} {v.compute():.3E}" for k, v in results.items()]
+ # logger.info(f'[Validation] {{{", ".join(results)}}}')
+
+
+def find_last_checkpoint_path(experiment_dir):
+ cls = pl.callbacks.ModelCheckpoint
+ path = osp.join(experiment_dir, cls.CHECKPOINT_NAME_LAST + cls.FILE_EXTENSION)
+ if osp.exists(path):
+ return path
+ else:
+ return None
+
+
+def prepare_experiment_dir(experiment_dir, cfg, rank):
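+    """Create the experiment directory, persist the config, and return the last
+    checkpoint path when resuming. Resuming with a modified experiment/data/model/
+    training config raises an error instead of silently diverging."""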
+ config_path = osp.join(experiment_dir, "config.yaml")
+ last_checkpoint_path = find_last_checkpoint_path(experiment_dir)
+ if last_checkpoint_path is not None:
+ if rank == 0:
+ logger.info(
+ "Resuming the training from checkpoint %s", last_checkpoint_path
+ )
+ if osp.exists(config_path):
+ with open(config_path, "r") as fp:
+ cfg_prev = OmegaConf.create(fp.read())
+ compare_keys = ["experiment", "data", "model", "training"]
+ if OmegaConf.masked_copy(cfg, compare_keys) != OmegaConf.masked_copy(
+ cfg_prev, compare_keys
+ ):
+ raise ValueError(
+ "Attempting to resume training with a different config: "
+ f"{OmegaConf.masked_copy(cfg, compare_keys)} vs "
+ f"{OmegaConf.masked_copy(cfg_prev, compare_keys)}"
+ )
+ if rank == 0:
+ Path(experiment_dir).mkdir(exist_ok=True, parents=True)
+ with open(config_path, "w") as fp:
+ OmegaConf.save(cfg, fp)
+ return last_checkpoint_path
+
+
+def train(cfg: DictConfig) -> None:
+ torch.set_float32_matmul_precision("medium")
+ OmegaConf.resolve(cfg)
+ rank = rank_zero_only.rank
+
+ if rank == 0:
+ logger.info("Starting training with config:\n%s", OmegaConf.to_yaml(cfg))
+ if cfg.experiment.gpus in (None, 0):
+ logger.warning("Will train on CPU...")
+ cfg.experiment.gpus = 0
+ elif not torch.cuda.is_available():
+ raise ValueError("Requested GPU but no NVIDIA drivers found.")
+ pl.seed_everything(cfg.experiment.seed, workers=True)
+
+ init_checkpoint_path = cfg.training.get("finetune_from_checkpoint")
+ if init_checkpoint_path is not None:
+ logger.info("Initializing the model from checkpoint %s.", init_checkpoint_path)
+ model = GenericModule.load_from_checkpoint(
+ init_checkpoint_path, strict=True, find_best=False, cfg=cfg
+ )
+ else:
+ model = GenericModule(cfg)
+ if rank == 0:
+ logger.info("Network:\n%s", model.model)
+
+ experiment_dir = osp.join(EXPERIMENTS_PATH, cfg.experiment.name)
+ last_checkpoint_path = prepare_experiment_dir(experiment_dir, cfg, rank)
+ checkpointing_epoch = pl.callbacks.ModelCheckpoint(
+ dirpath=experiment_dir,
+ filename="checkpoint-epoch-{epoch:02d}-loss-{loss/total/val:02f}",
+ auto_insert_metric_name=False,
+ save_last=True,
+ every_n_epochs=1,
+ save_on_train_epoch_end=True,
+ verbose=True,
+ **cfg.training.checkpointing,
+ )
+ checkpointing_step = pl.callbacks.ModelCheckpoint(
+ dirpath=experiment_dir,
+ filename="checkpoint-step-{step}-{loss/total/val:02f}",
+ auto_insert_metric_name=False,
+ save_last=True,
+ every_n_train_steps=1000,
+ verbose=True,
+ **cfg.training.checkpointing,
+ )
+ checkpointing_step.CHECKPOINT_NAME_LAST = "last-step-checkpointing"
+
+    # Create the EarlyStopping callback
+ early_stopping_callback = EarlyStopping(monitor=cfg.training.checkpointing.monitor, patience=5)
+
+ strategy = None
+ if cfg.experiment.gpus > 1:
+ strategy = pl.strategies.DDPStrategy(find_unused_parameters=False)
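+        # Split the configured (global) batch size and worker count across GPUs
+        # so that the effective totals stay unchanged under DDP.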
+ for split in ["train", "val"]:
+ cfg.data[split].batch_size = (
+ cfg.data[split].batch_size // cfg.experiment.gpus
+ )
+ cfg.data[split].num_workers = int(
+ (cfg.data[split].num_workers + cfg.experiment.gpus - 1)
+ / cfg.experiment.gpus
+ )
+
+ # data = data_modules[cfg.data.get("name", "mapillary")](cfg.data)
+
+    datamodule = UavMapDatasetModule(cfg.data)
+
+ tb_args = {"name": cfg.experiment.name, "version": ""}
+ tb = pl.loggers.TensorBoardLogger(EXPERIMENTS_PATH, **tb_args)
+
+ callbacks = [
+ checkpointing_epoch,
+ checkpointing_step,
+ # early_stopping_callback,
+ pl.callbacks.LearningRateMonitor(),
+ SeedingCallback(),
+ CleanProgressBar(),
+ ConsoleLogger(),
+ ]
+ if cfg.experiment.gpus > 0:
+ callbacks.append(pl.callbacks.DeviceStatsMonitor())
+
+ trainer = pl.Trainer(
+ default_root_dir=experiment_dir,
+ detect_anomaly=False,
+ # strategy=ddp_find_unused_parameters_true,
+ enable_model_summary=True,
+ sync_batchnorm=True,
+ enable_checkpointing=True,
+ logger=tb,
+ callbacks=callbacks,
+ strategy=strategy,
+ check_val_every_n_epoch=1,
+ accelerator="gpu",
+ num_nodes=1,
+ **cfg.training.trainer,
+ )
+ trainer.fit(model=model, datamodule=datamodule, ckpt_path=last_checkpoint_path)
+
+
+@hydra.main(
+ config_path=osp.join(osp.dirname(__file__), "conf"), config_name="maplocnet.yaml"
+)
+def main(cfg: DictConfig) -> None:
+ OmegaConf.save(config=cfg, f='maplocnet.yaml')
+ train(cfg)
+
+
+if __name__ == "__main__":
+ main()
+
diff --git a/train.sh b/train.sh
new file mode 100644
index 0000000000000000000000000000000000000000..3683adf5eb906698baec50d39e1e8d86f0d84bf7
--- /dev/null
+++ b/train.sh
@@ -0,0 +1 @@
+nohup python train.py > logs/train0907.log 2>&1 &
\ No newline at end of file
diff --git a/utils/exif.py b/utils/exif.py
new file mode 100644
index 0000000000000000000000000000000000000000..c272019b5673dc1e3aab8a3e0e21b630cc629154
--- /dev/null
+++ b/utils/exif.py
@@ -0,0 +1,356 @@
+"""Copied from opensfm.exif to minimize hard dependencies."""
+from pathlib import Path
+import json
+import datetime
+import logging
+from codecs import encode, decode
+from typing import Any, Dict, Optional, Tuple
+
+import exifread
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+inch_in_mm = 25.4
+cm_in_mm = 10
+um_in_mm = 0.001
+default_projection = "perspective"
+maximum_altitude = 1e4
+
+
+def sensor_data():
+ with (Path(__file__).parent / "sensor_data.json").open() as fid:
+ data = json.load(fid)
+ return {k.lower(): v for k, v in data.items()}
+
+
+def eval_frac(value) -> Optional[float]:
+ try:
+ return float(value.num) / float(value.den)
+ except ZeroDivisionError:
+ return None
+
+
+def gps_to_decimal(values, reference) -> Optional[float]:
+ sign = 1 if reference in "NE" else -1
+ degrees = eval_frac(values[0])
+ minutes = eval_frac(values[1])
+ seconds = eval_frac(values[2])
+ if degrees is not None and minutes is not None and seconds is not None:
+ return sign * (degrees + minutes / 60 + seconds / 3600)
+ return None
+
+
+def get_tag_as_float(tags, key, index: int = 0) -> Optional[float]:
+ if key in tags:
+ val = tags[key].values[index]
+ if isinstance(val, exifread.utils.Ratio):
+ ret_val = eval_frac(val)
+ if ret_val is None:
+ logger.error(
+                    'The rational "{2}" of tag "{0:s}" at index {1:d} '
+                    "caused a division by zero error".format(key, index, val)
+ )
+ return ret_val
+ else:
+ return float(val)
+ else:
+ return None
+
+
+def compute_focal(
+ focal_35: Optional[float], focal: Optional[float], sensor_width, sensor_string
+) -> Tuple[float, float]:
+ if focal_35 is not None and focal_35 > 0:
+ focal_ratio = focal_35 / 36.0 # 35mm film produces 36x24mm pictures.
+ else:
+ if not sensor_width:
+ sensor_width = sensor_data().get(sensor_string, None)
+ if sensor_width and focal:
+ focal_ratio = focal / sensor_width
+ focal_35 = 36.0 * focal_ratio
+ else:
+ focal_35 = 0.0
+ focal_ratio = 0.0
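+    # Rough example with hypothetical numbers: a 4.0 mm lens on a 4.8 mm wide
+    # sensor gives focal_ratio = 4.0 / 4.8 ≈ 0.83, i.e. focal_35 ≈ 30 mm.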
+ return focal_35, focal_ratio
+
+
+def sensor_string(make: str, model: str) -> str:
+ if make != "unknown":
+ # remove duplicate 'make' information in 'model'
+ model = model.replace(make, "")
+ return (make.strip() + " " + model.strip()).strip().lower()
+
+
+def unescape_string(s) -> str:
+ return decode(encode(s, "latin-1", "backslashreplace"), "unicode-escape")
+
+
+class EXIF:
+ def __init__(
+ self, fileobj, image_size_loader, use_exif_size=True, name=None
+ ) -> None:
+ self.image_size_loader = image_size_loader
+ self.use_exif_size = use_exif_size
+ self.fileobj = fileobj
+ self.tags = exifread.process_file(fileobj, details=False)
+ fileobj.seek(0)
+ self.fileobj_name = self.fileobj.name if name is None else name
+
+ def extract_image_size(self) -> Tuple[int, int]:
+ if (
+ self.use_exif_size
+ and "EXIF ExifImageWidth" in self.tags
+ and "EXIF ExifImageLength" in self.tags
+ ):
+ width, height = (
+ int(self.tags["EXIF ExifImageWidth"].values[0]),
+ int(self.tags["EXIF ExifImageLength"].values[0]),
+ )
+ elif (
+ self.use_exif_size
+ and "Image ImageWidth" in self.tags
+ and "Image ImageLength" in self.tags
+ ):
+ width, height = (
+ int(self.tags["Image ImageWidth"].values[0]),
+ int(self.tags["Image ImageLength"].values[0]),
+ )
+ else:
+ height, width = self.image_size_loader()
+ return width, height
+
+ def _decode_make_model(self, value) -> str:
+ """Python 2/3 compatible decoding of make/model field."""
+ if hasattr(value, "decode"):
+ try:
+ return value.decode("utf-8")
+ except UnicodeDecodeError:
+ return "unknown"
+ else:
+ return value
+
+ def extract_make(self) -> str:
+ # Camera make and model
+ if "EXIF LensMake" in self.tags:
+ make = self.tags["EXIF LensMake"].values
+ elif "Image Make" in self.tags:
+ make = self.tags["Image Make"].values
+ else:
+ make = "unknown"
+ return self._decode_make_model(make)
+
+ def extract_model(self) -> str:
+ if "EXIF LensModel" in self.tags:
+ model = self.tags["EXIF LensModel"].values
+ elif "Image Model" in self.tags:
+ model = self.tags["Image Model"].values
+ else:
+ model = "unknown"
+ return self._decode_make_model(model)
+
+ def extract_focal(self) -> Tuple[float, float]:
+ make, model = self.extract_make(), self.extract_model()
+ focal_35, focal_ratio = compute_focal(
+ get_tag_as_float(self.tags, "EXIF FocalLengthIn35mmFilm"),
+ get_tag_as_float(self.tags, "EXIF FocalLength"),
+ self.extract_sensor_width(),
+ sensor_string(make, model),
+ )
+ return focal_35, focal_ratio
+
+ def extract_sensor_width(self) -> Optional[float]:
+ """Compute sensor with from width and resolution."""
+ if (
+ "EXIF FocalPlaneResolutionUnit" not in self.tags
+ or "EXIF FocalPlaneXResolution" not in self.tags
+ ):
+ return None
+ resolution_unit = self.tags["EXIF FocalPlaneResolutionUnit"].values[0]
+ mm_per_unit = self.get_mm_per_unit(resolution_unit)
+ if not mm_per_unit:
+ return None
+ pixels_per_unit = get_tag_as_float(self.tags, "EXIF FocalPlaneXResolution")
+ if pixels_per_unit is None:
+ return None
+ if pixels_per_unit <= 0.0:
+ pixels_per_unit = get_tag_as_float(self.tags, "EXIF FocalPlaneYResolution")
+ if pixels_per_unit is None or pixels_per_unit <= 0.0:
+ return None
+ units_per_pixel = 1 / pixels_per_unit
+ width_in_pixels = self.extract_image_size()[0]
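+        # sensor width [mm] = image width [px] * units per pixel * mm per unit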
+ return width_in_pixels * units_per_pixel * mm_per_unit
+
+ def get_mm_per_unit(self, resolution_unit) -> Optional[float]:
+ """Length of a resolution unit in millimeters.
+
+ Uses the values from the EXIF specs in
+ https://www.sno.phy.queensu.ca/~phil/exiftool/TagNames/EXIF.html
+
+ Args:
+ resolution_unit: the resolution unit value given in the EXIF
+ """
+ if resolution_unit == 2: # inch
+ return inch_in_mm
+ elif resolution_unit == 3: # cm
+ return cm_in_mm
+ elif resolution_unit == 4: # mm
+ return 1
+ elif resolution_unit == 5: # um
+ return um_in_mm
+ else:
+ logger.warning(
+ "Unknown EXIF resolution unit value: {}".format(resolution_unit)
+ )
+ return None
+
+ def extract_orientation(self) -> int:
+ orientation = 1
+ if "Image Orientation" in self.tags:
+ value = self.tags.get("Image Orientation").values[0]
+            if isinstance(value, int) and value != 0:
+ orientation = value
+ return orientation
+
+ def extract_ref_lon_lat(self) -> Tuple[str, str]:
+ if "GPS GPSLatitudeRef" in self.tags:
+ reflat = self.tags["GPS GPSLatitudeRef"].values
+ else:
+ reflat = "N"
+ if "GPS GPSLongitudeRef" in self.tags:
+ reflon = self.tags["GPS GPSLongitudeRef"].values
+ else:
+ reflon = "E"
+ return reflon, reflat
+
+ def extract_lon_lat(self) -> Tuple[Optional[float], Optional[float]]:
+ if "GPS GPSLatitude" in self.tags:
+ reflon, reflat = self.extract_ref_lon_lat()
+ lat = gps_to_decimal(self.tags["GPS GPSLatitude"].values, reflat)
+ lon = gps_to_decimal(self.tags["GPS GPSLongitude"].values, reflon)
+ else:
+ lon, lat = None, None
+ return lon, lat
+
+ def extract_altitude(self) -> Optional[float]:
+ if "GPS GPSAltitude" in self.tags:
+ alt_value = self.tags["GPS GPSAltitude"].values[0]
+ if isinstance(alt_value, exifread.utils.Ratio):
+ altitude = eval_frac(alt_value)
+ elif isinstance(alt_value, int):
+ altitude = float(alt_value)
+ else:
+ altitude = None
+
+ # Check if GPSAltitudeRef is equal to 1, which means GPSAltitude should be negative, reference: http://www.exif.org/Exif2-2.PDF#page=53
+ if (
+ "GPS GPSAltitudeRef" in self.tags
+ and self.tags["GPS GPSAltitudeRef"].values[0] == 1
+ and altitude is not None
+ ):
+ altitude = -altitude
+ else:
+ altitude = None
+ return altitude
+
+ def extract_dop(self) -> Optional[float]:
+ if "GPS GPSDOP" in self.tags:
+ return eval_frac(self.tags["GPS GPSDOP"].values[0])
+ return None
+
+ def extract_geo(self) -> Dict[str, Any]:
+ altitude = self.extract_altitude()
+ dop = self.extract_dop()
+ lon, lat = self.extract_lon_lat()
+ d = {}
+
+ if lon is not None and lat is not None:
+ d["latitude"] = lat
+ d["longitude"] = lon
+ if altitude is not None:
+ d["altitude"] = min([maximum_altitude, altitude])
+ if dop is not None:
+ d["dop"] = dop
+ return d
+
+ def extract_capture_time(self) -> float:
+ if (
+ "GPS GPSDate" in self.tags
+ and "GPS GPSTimeStamp" in self.tags # Actually GPSDateStamp
+ ):
+ try:
+ hours_f = get_tag_as_float(self.tags, "GPS GPSTimeStamp", 0)
+ minutes_f = get_tag_as_float(self.tags, "GPS GPSTimeStamp", 1)
+ if hours_f is None or minutes_f is None:
+ raise TypeError
+ hours = int(hours_f)
+ minutes = int(minutes_f)
+ seconds = get_tag_as_float(self.tags, "GPS GPSTimeStamp", 2)
+ gps_timestamp_string = "{0:s} {1:02d}:{2:02d}:{3:02f}".format(
+ self.tags["GPS GPSDate"].values, hours, minutes, seconds
+ )
+ return (
+ datetime.datetime.strptime(
+ gps_timestamp_string, "%Y:%m:%d %H:%M:%S.%f"
+ )
+ - datetime.datetime(1970, 1, 1)
+ ).total_seconds()
+ except (TypeError, ValueError):
+ logger.info(
+ 'The GPS time stamp in image file "{0:s}" is invalid. '
+ "Falling back to DateTime*".format(self.fileobj_name)
+ )
+
+ time_strings = [
+ ("EXIF DateTimeOriginal", "EXIF SubSecTimeOriginal", "EXIF Tag 0x9011"),
+ ("EXIF DateTimeDigitized", "EXIF SubSecTimeDigitized", "EXIF Tag 0x9012"),
+ ("Image DateTime", "Image SubSecTime", "Image Tag 0x9010"),
+ ]
+ for datetime_tag, subsec_tag, offset_tag in time_strings:
+ if datetime_tag in self.tags:
+ date_time = self.tags[datetime_tag].values
+ if subsec_tag in self.tags:
+ subsec_time = self.tags[subsec_tag].values
+ else:
+ subsec_time = "0"
+ try:
+ s = "{0:s}.{1:s}".format(date_time, subsec_time)
+ d = datetime.datetime.strptime(s, "%Y:%m:%d %H:%M:%S.%f")
+ except ValueError:
+ logger.debug(
+ 'The "{1:s}" time stamp or "{2:s}" tag is invalid in '
+ 'image file "{0:s}"'.format(
+ self.fileobj_name, datetime_tag, subsec_tag
+ )
+ )
+ continue
+ # Test for OffsetTimeOriginal | OffsetTimeDigitized | OffsetTime
+ if offset_tag in self.tags:
+ offset_time = self.tags[offset_tag].values
+ try:
+ d += datetime.timedelta(
+ hours=-int(offset_time[0:3]), minutes=int(offset_time[4:6])
+ )
+ except (TypeError, ValueError):
+ logger.debug(
+ 'The "{0:s}" time zone offset in image file "{1:s}"'
+ " is invalid".format(offset_tag, self.fileobj_name)
+ )
+ logger.debug(
+ 'Naively assuming UTC on "{0:s}" in image file '
+ '"{1:s}"'.format(datetime_tag, self.fileobj_name)
+ )
+ else:
+ logger.debug(
+ "No GPS time stamp and no time zone offset in image "
+ 'file "{0:s}"'.format(self.fileobj_name)
+ )
+ logger.debug(
+ 'Naively assuming UTC on "{0:s}" in image file "{1:s}"'.format(
+ datetime_tag, self.fileobj_name
+ )
+ )
+ return (d - datetime.datetime(1970, 1, 1)).total_seconds()
+ logger.info(
+ 'Image file "{0:s}" has no valid time stamp'.format(self.fileobj_name)
+ )
+ return 0.0
diff --git a/utils/geo.py b/utils/geo.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f97e501f81092b57fc3ad043aa520779b44faeb
--- /dev/null
+++ b/utils/geo.py
@@ -0,0 +1,130 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+from typing import Union
+
+import numpy as np
+import torch
+
+from .geo_opensfm import TopocentricConverter
+
+
+class BoundaryBox:
+ def __init__(self, min_: np.ndarray, max_: np.ndarray):
+ self.min_ = np.asarray(min_)
+ self.max_ = np.asarray(max_)
+ assert np.all(self.min_ <= self.max_)
+
+ @classmethod
+ def from_string(cls, string: str):
+ return cls(*np.split(np.array(string.split(","), float), 2))
+
+ @property
+ def left_top(self):
+ return np.stack([self.min_[..., 0], self.max_[..., 1]], -1)
+
+ @property
+    def right_bottom(self) -> np.ndarray:
+ return np.stack([self.max_[..., 0], self.min_[..., 1]], -1)
+
+ @property
+ def center(self) -> np.ndarray:
+ return (self.min_ + self.max_) / 2
+
+ @property
+ def size(self) -> np.ndarray:
+ return self.max_ - self.min_
+
+ def translate(self, t: float):
+ return self.__class__(self.min_ + t, self.max_ + t)
+
+ def contains(self, xy: Union[np.ndarray, "BoundaryBox"]):
+ if isinstance(xy, self.__class__):
+ return self.contains(xy.min_) and self.contains(xy.max_)
+ return np.all((xy >= self.min_) & (xy <= self.max_), -1)
+
+ def normalize(self, xy):
+ min_, max_ = self.min_, self.max_
+ if isinstance(xy, torch.Tensor):
+ min_ = torch.from_numpy(min_).to(xy)
+ max_ = torch.from_numpy(max_).to(xy)
+ return (xy - min_) / (max_ - min_)
+
+ def unnormalize(self, xy):
+ min_, max_ = self.min_, self.max_
+ if isinstance(xy, torch.Tensor):
+ min_ = torch.from_numpy(min_).to(xy)
+ max_ = torch.from_numpy(max_).to(xy)
+ return xy * (max_ - min_) + min_
+
+ def format(self) -> str:
+ return ",".join(np.r_[self.min_, self.max_].astype(str))
+
+ def __add__(self, x):
+ if isinstance(x, (int, float)):
+ return self.__class__(self.min_ - x, self.max_ + x)
+ else:
+ raise TypeError(f"Cannot add {self.__class__.__name__} to {type(x)}.")
+
+ def __and__(self, other):
+ return self.__class__(
+ np.maximum(self.min_, other.min_), np.minimum(self.max_, other.max_)
+ )
+
+ def __repr__(self):
+ return self.format()
+
+
+class Projection:
+ def __init__(self, lat, lon, alt=0, max_extent=25e3):
+ # The approximation error is |L - radius * tan(L / radius)|
+ # and is around 13cm for L=25km.
+ self.latlonalt = (lat, lon, alt)
+ self.converter = TopocentricConverter(lat, lon, alt)
+ min_ = self.converter.to_lla(*(-max_extent,) * 2, 0)[:2]
+ max_ = self.converter.to_lla(*(max_extent,) * 2, 0)[:2]
+ self.bounds = BoundaryBox(min_, max_)
+
+ @classmethod
+ def from_points(cls, all_latlon):
+ assert all_latlon.shape[-1] == 2
+ all_latlon = all_latlon.reshape(-1, 2)
+ latlon_mid = (all_latlon.min(0) + all_latlon.max(0)) / 2
+ return cls(*latlon_mid)
+
+ def check_bbox(self, bbox: BoundaryBox):
+ if self.bounds is not None and not self.bounds.contains(bbox):
+ raise ValueError(
+ f"Bbox {bbox.format()} is not contained in "
+ f"projection with bounds {self.bounds.format()}."
+ )
+
+ def project(self, geo, return_z=False):
+ if isinstance(geo, BoundaryBox):
+ return BoundaryBox(*self.project(np.stack([geo.min_, geo.max_])))
+ geo = np.asarray(geo)
+ assert geo.shape[-1] in (2, 3)
+ if self.bounds is not None:
+ if not np.all(self.bounds.contains(geo[..., :2])):
+ raise ValueError(
+ f"Points {geo} are out of the valid bounds "
+ f"{self.bounds.format()}."
+ )
+ lat, lon = geo[..., 0], geo[..., 1]
+ if geo.shape[-1] == 3:
+ alt = geo[..., -1]
+ else:
+ alt = np.zeros_like(lat)
+ x, y, z = self.converter.to_topocentric(lat, lon, alt)
+ return np.stack([x, y] + ([z] if return_z else []), -1)
+
+ def unproject(self, xy, return_z=False):
+ if isinstance(xy, BoundaryBox):
+ return BoundaryBox(*self.unproject(np.stack([xy.min_, xy.max_])))
+ xy = np.asarray(xy)
+ x, y = xy[..., 0], xy[..., 1]
+ if xy.shape[-1] == 3:
+ z = xy[..., -1]
+ else:
+ z = np.zeros_like(x)
+ lat, lon, alt = self.converter.to_lla(x, y, z)
+ return np.stack([lat, lon] + ([alt] if return_z else []), -1)
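+
+
+# Minimal usage sketch (hypothetical coordinates): build a local metric frame
+# around a reference location, project lat/lon points into it, and map back.
+#   proj = Projection(lat=47.37, lon=8.54)
+#   xy = proj.project(np.array([47.3705, 8.5405]))  # metres east / north
+#   latlon = proj.unproject(xy)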
diff --git a/utils/geo_opensfm.py b/utils/geo_opensfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..d42145236dd4c65a8764cbe37fa25c527814017e
--- /dev/null
+++ b/utils/geo_opensfm.py
@@ -0,0 +1,180 @@
+"""Copied from opensfm.geo to minimize hard dependencies."""
+import numpy as np
+from numpy import ndarray
+from typing import Tuple
+
+WGS84_a = 6378137.0
+WGS84_b = 6356752.314245
+
+
+def ecef_from_lla(lat, lon, alt: float) -> Tuple[float, ...]:
+ """
+ Compute ECEF XYZ from latitude, longitude and altitude.
+
+ All using the WGS84 model.
+ Altitude is the distance to the WGS84 ellipsoid.
+ Check results here http://www.oc.nps.edu/oc2902w/coord/llhxyz.htm
+
+ >>> lat, lon, alt = 10, 20, 30
+ >>> x, y, z = ecef_from_lla(lat, lon, alt)
+ >>> np.allclose(lla_from_ecef(x,y,z), [lat, lon, alt])
+ True
+ """
+ a2 = WGS84_a**2
+ b2 = WGS84_b**2
+ lat = np.radians(lat)
+ lon = np.radians(lon)
+ L = 1.0 / np.sqrt(a2 * np.cos(lat) ** 2 + b2 * np.sin(lat) ** 2)
+ x = (a2 * L + alt) * np.cos(lat) * np.cos(lon)
+ y = (a2 * L + alt) * np.cos(lat) * np.sin(lon)
+ z = (b2 * L + alt) * np.sin(lat)
+ return x, y, z
+
+
+def lla_from_ecef(x, y, z):
+ """
+ Compute latitude, longitude and altitude from ECEF XYZ.
+
+ All using the WGS84 model.
+ Altitude is the distance to the WGS84 ellipsoid.
+ """
+ a = WGS84_a
+ b = WGS84_b
+ ea = np.sqrt((a**2 - b**2) / a**2)
+ eb = np.sqrt((a**2 - b**2) / b**2)
+ p = np.sqrt(x**2 + y**2)
+ theta = np.arctan2(z * a, p * b)
+ lon = np.arctan2(y, x)
+ lat = np.arctan2(
+ z + eb**2 * b * np.sin(theta) ** 3, p - ea**2 * a * np.cos(theta) ** 3
+ )
+ N = a / np.sqrt(1 - ea**2 * np.sin(lat) ** 2)
+ alt = p / np.cos(lat) - N
+ return np.degrees(lat), np.degrees(lon), alt
+
+
+def ecef_from_topocentric_transform(lat, lon, alt: float) -> ndarray:
+ """
+ Transformation from a topocentric frame at reference position to ECEF.
+
+ The topocentric reference frame is a metric one with the origin
+ at the given (lat, lon, alt) position, with the X axis heading east,
+ the Y axis heading north and the Z axis vertical to the ellipsoid.
+ >>> a = ecef_from_topocentric_transform(30, 20, 10)
+ >>> b = ecef_from_topocentric_transform_finite_diff(30, 20, 10)
+ >>> np.allclose(a, b)
+ True
+ """
+ x, y, z = ecef_from_lla(lat, lon, alt)
+ sa = np.sin(np.radians(lat))
+ ca = np.cos(np.radians(lat))
+ so = np.sin(np.radians(lon))
+ co = np.cos(np.radians(lon))
+ return np.array(
+ [
+ [-so, -sa * co, ca * co, x],
+ [co, -sa * so, ca * so, y],
+ [0, ca, sa, z],
+ [0, 0, 0, 1],
+ ]
+ )
+
+
+def ecef_from_topocentric_transform_finite_diff(lat, lon, alt: float) -> ndarray:
+ """
+ Transformation from a topocentric frame at reference position to ECEF.
+
+ The topocentric reference frame is a metric one with the origin
+ at the given (lat, lon, alt) position, with the X axis heading east,
+ the Y axis heading north and the Z axis vertical to the ellipsoid.
+ """
+ eps = 1e-2
+ x, y, z = ecef_from_lla(lat, lon, alt)
+ v1 = (
+ (
+ np.array(ecef_from_lla(lat, lon + eps, alt))
+ - np.array(ecef_from_lla(lat, lon - eps, alt))
+ )
+ / 2
+ / eps
+ )
+ v2 = (
+ (
+ np.array(ecef_from_lla(lat + eps, lon, alt))
+ - np.array(ecef_from_lla(lat - eps, lon, alt))
+ )
+ / 2
+ / eps
+ )
+ v3 = (
+ (
+ np.array(ecef_from_lla(lat, lon, alt + eps))
+ - np.array(ecef_from_lla(lat, lon, alt - eps))
+ )
+ / 2
+ / eps
+ )
+ v1 /= np.linalg.norm(v1)
+ v2 /= np.linalg.norm(v2)
+ v3 /= np.linalg.norm(v3)
+ return np.array(
+ [
+ [v1[0], v2[0], v3[0], x],
+ [v1[1], v2[1], v3[1], y],
+ [v1[2], v2[2], v3[2], z],
+ [0, 0, 0, 1],
+ ]
+ )
+
+
+def topocentric_from_lla(lat, lon, alt: float, reflat, reflon, refalt: float):
+ """
+ Transform from lat, lon, alt to topocentric XYZ.
+
+ >>> lat, lon, alt = -10, 20, 100
+ >>> np.allclose(topocentric_from_lla(lat, lon, alt, lat, lon, alt),
+ ... [0,0,0])
+ True
+ >>> x, y, z = topocentric_from_lla(lat, lon, alt, 0, 0, 0)
+ >>> np.allclose(lla_from_topocentric(x, y, z, 0, 0, 0),
+ ... [lat, lon, alt])
+ True
+ """
+ T = np.linalg.inv(ecef_from_topocentric_transform(reflat, reflon, refalt))
+ x, y, z = ecef_from_lla(lat, lon, alt)
+ tx = T[0, 0] * x + T[0, 1] * y + T[0, 2] * z + T[0, 3]
+ ty = T[1, 0] * x + T[1, 1] * y + T[1, 2] * z + T[1, 3]
+ tz = T[2, 0] * x + T[2, 1] * y + T[2, 2] * z + T[2, 3]
+ return tx, ty, tz
+
+
+def lla_from_topocentric(x, y, z, reflat, reflon, refalt: float):
+ """
+ Transform from topocentric XYZ to lat, lon, alt.
+ """
+ T = ecef_from_topocentric_transform(reflat, reflon, refalt)
+ ex = T[0, 0] * x + T[0, 1] * y + T[0, 2] * z + T[0, 3]
+ ey = T[1, 0] * x + T[1, 1] * y + T[1, 2] * z + T[1, 3]
+ ez = T[2, 0] * x + T[2, 1] * y + T[2, 2] * z + T[2, 3]
+ return lla_from_ecef(ex, ey, ez)
+
+
+class TopocentricConverter(object):
+ """Convert to and from a topocentric reference frame."""
+
+ def __init__(self, reflat, reflon, refalt):
+ """Init the converter given the reference origin."""
+ self.lat = reflat
+ self.lon = reflon
+ self.alt = refalt
+
+ def to_topocentric(self, lat, lon, alt):
+ """Convert lat, lon, alt to topocentric x, y, z."""
+ return topocentric_from_lla(lat, lon, alt, self.lat, self.lon, self.alt)
+
+ def to_lla(self, x, y, z):
+ """Convert topocentric x, y, z to lat, lon, alt."""
+ return lla_from_topocentric(x, y, z, self.lat, self.lon, self.alt)
+
+ def __eq__(self, o):
+ return np.allclose([self.lat, self.lon, self.alt], (o.lat, o.lon, o.alt))
diff --git a/utils/geometry.py b/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bcbcba7c41e689e9dd9e35fe33e7787fdd13b03
--- /dev/null
+++ b/utils/geometry.py
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import numpy as np
+import torch
+
+
+def from_homogeneous(points, eps: float = 1e-8):
+ """Remove the homogeneous dimension of N-dimensional points.
+ Args:
+ points: torch.Tensor or numpy.ndarray with size (..., N+1).
+ Returns:
+ A torch.Tensor or numpy ndarray with size (..., N).
+ """
+ return points[..., :-1] / (points[..., -1:] + eps)
+
+
+def to_homogeneous(points):
+ """Convert N-dimensional points to homogeneous coordinates.
+ Args:
+ points: torch.Tensor or numpy.ndarray with size (..., N).
+ Returns:
+ A torch.Tensor or numpy.ndarray with size (..., N+1).
+ """
+ if isinstance(points, torch.Tensor):
+ pad = points.new_ones(points.shape[:-1] + (1,))
+ return torch.cat([points, pad], dim=-1)
+ elif isinstance(points, np.ndarray):
+ pad = np.ones((points.shape[:-1] + (1,)), dtype=points.dtype)
+ return np.concatenate([points, pad], axis=-1)
+ else:
+ raise ValueError
+
+
+@torch.jit.script
+def undistort_points(pts, dist):
+ dist = dist.unsqueeze(-2) # add point dimension
+ ndist = dist.shape[-1]
+ undist = pts
+ valid = torch.ones(pts.shape[:-1], device=pts.device, dtype=torch.bool)
+ if ndist > 0:
+ k1, k2 = dist[..., :2].split(1, -1)
+ r2 = torch.sum(pts**2, -1, keepdim=True)
+ radial = k1 * r2 + k2 * r2**2
+ undist = undist + pts * radial
+
+ # The distortion model is supposedly only valid within the image
+ # boundaries. Because of the negative radial distortion, points that
+ # are far outside of the boundaries might actually be mapped back
+ # within the image. To account for this, we discard points that are
+ # beyond the inflection point of the distortion model,
+ # e.g. such that d(r + k_1 r^3 + k2 r^5)/dr = 0
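+        # i.e. 1 + 3 k1 r^2 + 5 k2 r^4 = 0, solved below as a quadratic in r^2.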
+ limited = ((k2 > 0) & ((9 * k1**2 - 20 * k2) > 0)) | ((k2 <= 0) & (k1 > 0))
+ limit = torch.abs(
+ torch.where(
+ k2 > 0,
+ (torch.sqrt(9 * k1**2 - 20 * k2) - 3 * k1) / (10 * k2),
+ 1 / (3 * k1),
+ )
+ )
+ valid = valid & torch.squeeze(~limited | (r2 < limit), -1)
+
+ if ndist > 2:
+ p12 = dist[..., 2:]
+ p21 = p12.flip(-1)
+ uv = torch.prod(pts, -1, keepdim=True)
+ undist = undist + 2 * p12 * uv + p21 * (r2 + 2 * pts**2)
+
+ return undist, valid
diff --git a/utils/io.py b/utils/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..e092ef8695a41f2fae6da4e03d7cad74eab5cadf
--- /dev/null
+++ b/utils/io.py
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import json
+import requests
+import shutil
+from pathlib import Path
+
+import cv2
+import numpy as np
+import torch
+from tqdm.auto import tqdm
+
+from logger import logger
+
+DATA_URL = "https://cvg-data.inf.ethz.ch/OrienterNet_CVPR2023"
+
+
+def read_image(path, grayscale=False):
+ if grayscale:
+ mode = cv2.IMREAD_GRAYSCALE
+ else:
+ mode = cv2.IMREAD_COLOR
+ image = cv2.imread(str(path), mode)
+ if image is None:
+ raise ValueError(f"Cannot read image {path}.")
+ if not grayscale and len(image.shape) == 3:
+ image = np.ascontiguousarray(image[:, :, ::-1]) # BGR to RGB
+ return image
+
+
+def write_torch_image(path, image):
+ image_cv2 = np.round(image.clip(0, 1) * 255).astype(int)[..., ::-1]
+ cv2.imwrite(str(path), image_cv2)
+
+
+class JSONEncoder(json.JSONEncoder):
+ def default(self, obj):
+ if isinstance(obj, (np.ndarray, torch.Tensor)):
+ return obj.tolist()
+ elif isinstance(obj, np.generic):
+ return obj.item()
+ return json.JSONEncoder.default(self, obj)
+
+
+def write_json(path, data):
+ with open(path, "w") as f:
+ json.dump(data, f, cls=JSONEncoder)
+
+
+def download_file(url, path):
+ path = Path(path)
+ if path.is_dir():
+ path = path / Path(url).name
+ path.parent.mkdir(exist_ok=True, parents=True)
+ logger.info("Downloading %s to %s.", url, path)
+ with requests.get(url, stream=True) as r:
+ total_length = int(r.headers.get("Content-Length"))
+ with tqdm.wrapattr(r.raw, "read", total=total_length, desc="") as raw:
+ with open(path, "wb") as output:
+ shutil.copyfileobj(raw, output)
+ return path
diff --git a/utils/tools.py b/utils/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..adbc3eae97c599d9ce606426ae5628484b9e2499
--- /dev/null
+++ b/utils/tools.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import time
+
+
+class Timer:
+ def __init__(self, name=None):
+ self.name = name
+
+ def __enter__(self):
+ self.tstart = time.time()
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.duration = time.time() - self.tstart
+ if self.name is not None:
+ print("[%s] Elapsed: %s" % (self.name, self.duration))
diff --git a/utils/viz_2d.py b/utils/viz_2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac93d002e2666559f3d77dd54c4f52d465e96869
--- /dev/null
+++ b/utils/viz_2d.py
@@ -0,0 +1,195 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# Adapted from Hierarchical-Localization, Paul-Edouard Sarlin, ETH Zurich
+# https://github.com/cvg/Hierarchical-Localization/blob/master/hloc/utils/viz.py
+# Released under the Apache License 2.0
+
+import matplotlib
+import matplotlib.patheffects as path_effects
+import matplotlib.pyplot as plt
+import numpy as np
+
+
+def plot_images(imgs, titles=None, cmaps="gray", dpi=100, pad=0.5, adaptive=True):
+ """Plot a set of images horizontally.
+ Args:
+ imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
+ titles: a list of strings, as titles for each image.
+ cmaps: colormaps for monochrome images.
+ adaptive: whether the figure size should fit the image aspect ratios.
+ """
+ n = len(imgs)
+ if not isinstance(cmaps, (list, tuple)):
+ cmaps = [cmaps] * n
+
+ if adaptive:
+ ratios = [i.shape[1] / i.shape[0] for i in imgs] # W / H
+ else:
+ ratios = [4 / 3] * n
+ figsize = [sum(ratios) * 4.5, 4.5]
+ fig, ax = plt.subplots(
+ 1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
+ )
+ if n == 1:
+ ax = [ax]
+ for i in range(n):
+ ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
+ ax[i].get_yaxis().set_ticks([])
+ ax[i].get_xaxis().set_ticks([])
+ ax[i].set_axis_off()
+ for spine in ax[i].spines.values(): # remove frame
+ spine.set_visible(False)
+ if titles:
+ ax[i].set_title(titles[i])
+ fig.tight_layout(pad=pad)
+ return fig
+
+
+def plot_keypoints(kpts, colors="lime", ps=4):
+ """Plot keypoints for existing images.
+ Args:
+ kpts: list of ndarrays of size (N, 2).
+        colors: string, or list of lists of tuples (one per set of keypoints).
+ ps: size of the keypoints as float.
+ """
+ if not isinstance(colors, list):
+ colors = [colors] * len(kpts)
+ axes = plt.gcf().axes
+ for a, k, c in zip(axes, kpts, colors):
+ a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0)
+
+
+def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0):
+ """Plot matches for a pair of existing images.
+ Args:
+ kpts0, kpts1: corresponding keypoints of size (N, 2).
+ color: color of each match, string or RGB tuple. Random if not given.
+ lw: width of the lines.
+ ps: size of the end points (no endpoint if ps=0)
+ indices: indices of the images to draw the matches on.
+ a: alpha opacity of the match lines.
+ """
+ fig = plt.gcf()
+ ax = fig.axes
+ assert len(ax) > max(indices)
+ ax0, ax1 = ax[indices[0]], ax[indices[1]]
+ fig.canvas.draw()
+
+ assert len(kpts0) == len(kpts1)
+ if color is None:
+ color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
+ elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
+ color = [color] * len(kpts0)
+
+ if lw > 0:
+ # transform the points into the figure coordinate system
+ transFigure = fig.transFigure.inverted()
+ fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
+ fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
+ fig.lines += [
+ matplotlib.lines.Line2D(
+ (fkpts0[i, 0], fkpts1[i, 0]),
+ (fkpts0[i, 1], fkpts1[i, 1]),
+ zorder=1,
+ transform=fig.transFigure,
+ c=color[i],
+ linewidth=lw,
+ alpha=a,
+ )
+ for i in range(len(kpts0))
+ ]
+
+ # freeze the axes to prevent the transform to change
+ ax0.autoscale(enable=False)
+ ax1.autoscale(enable=False)
+
+ if ps > 0:
+ ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
+ ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
+
+
+def add_text(
+ idx,
+ text,
+ pos=(0.01, 0.99),
+ fs=15,
+ color="w",
+ lcolor="k",
+ lwidth=2,
+ ha="left",
+ va="top",
+ normalized=True,
+ zorder=3,
+):
+ ax = plt.gcf().axes[idx]
+ tfm = ax.transAxes if normalized else ax.transData
+ t = ax.text(
+ *pos,
+ text,
+ fontsize=fs,
+ ha=ha,
+ va=va,
+ color=color,
+ transform=tfm,
+ clip_on=True,
+ zorder=zorder,
+ )
+ if lcolor is not None:
+ t.set_path_effects(
+ [
+ path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
+ path_effects.Normal(),
+ ]
+ )
+
+
+def save_plot(path, **kw):
+ """Save the current figure without any white margin."""
+ plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw)
+
+
+def features_to_RGB(*Fs, masks=None, skip=1):
+ """Project a list of d-dimensional feature maps to RGB colors using PCA."""
+ from sklearn.decomposition import PCA
+
+ def normalize(x):
+ return x / np.linalg.norm(x, axis=-1, keepdims=True)
+
+ if masks is not None:
+ assert len(Fs) == len(masks)
+
+ flatten = []
+ for i, F in enumerate(Fs):
+ c, h, w = F.shape
+ F = np.rollaxis(F, 0, 3)
+ F_flat = F.reshape(-1, c)
+ if masks is not None and masks[i] is not None:
+ mask = masks[i]
+ assert mask.shape == F.shape[:2]
+ F_flat = F_flat[mask.reshape(-1)]
+ flatten.append(F_flat)
+ flatten = np.concatenate(flatten, axis=0)
+ flatten = normalize(flatten)
+
+ pca = PCA(n_components=3)
+ if skip > 1:
+ pca.fit(flatten[::skip])
+ flatten = pca.transform(flatten)
+ else:
+ flatten = pca.fit_transform(flatten)
+ flatten = (normalize(flatten) + 1) / 2
+
+ Fs_rgb = []
+ for i, F in enumerate(Fs):
+ h, w = F.shape[-2:]
+ if masks is None or masks[i] is None:
+ F_rgb, flatten = np.split(flatten, [h * w], axis=0)
+ F_rgb = F_rgb.reshape((h, w, 3))
+ else:
+ F_rgb = np.zeros((h, w, 3))
+ indices = np.where(masks[i])
+ F_rgb[indices], flatten = np.split(flatten, [len(indices[0])], axis=0)
+ F_rgb = np.concatenate([F_rgb, masks[i][..., None]], axis=-1)
+ Fs_rgb.append(F_rgb)
+ assert flatten.shape[0] == 0, flatten.shape
+ return Fs_rgb
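+
+
+# Usage sketch (assumed shapes): given a feature map F of shape (C, H, W),
+# features_to_RGB(F) returns a list with one (H, W, 3) array in [0, 1] that can
+# be displayed with plot_images.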
diff --git a/utils/viz_localization.py b/utils/viz_localization.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e15da2329ae2956c4eb663dfcd3a0cf841f2b84
--- /dev/null
+++ b/utils/viz_localization.py
@@ -0,0 +1,160 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+import copy
+
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+
+def likelihood_overlay(
+ prob, map_viz=None, p_rgb=0.2, p_alpha=1 / 15, thresh=None, cmap="jet"
+):
+ prob = prob / prob.max()
+ cmap = plt.get_cmap(cmap)
+ rgb = cmap(prob**p_rgb)
+ alpha = prob[..., None] ** p_alpha
+ if thresh is not None:
+ alpha[prob <= thresh] = 0
+ if map_viz is not None:
+ faded = map_viz + (1 - map_viz) * 0.5
+ rgb = rgb[..., :3] * alpha + faded * (1 - alpha)
+ rgb = np.clip(rgb, 0, 1)
+ else:
+ rgb[..., -1] = alpha.squeeze(-1)
+ return rgb
+
+
+def heatmap2rgb(scores, mask=None, clip_min=0.05, alpha=0.8, cmap="jet"):
+ min_, max_ = np.quantile(scores, [clip_min, 1])
+ scores = scores.clip(min=min_)
+ rgb = plt.get_cmap(cmap)((scores - min_) / (max_ - min_))
+ if mask is not None:
+ if alpha == 0:
+ rgb[mask] = np.nan
+ else:
+ rgb[..., -1] = 1 - (1 - 1.0 * mask) * (1 - alpha)
+ return rgb
+
+
+def plot_pose(axs, xy, yaw=None, s=1 / 35, c="r", a=1, w=0.015, dot=True, zorder=10):
+ if yaw is not None:
+ yaw = np.deg2rad(yaw)
+ uv = np.array([np.sin(yaw), -np.cos(yaw)])
+ xy = np.array(xy) + 0.5
+ if not isinstance(axs, list):
+ axs = [axs]
+ for ax in axs:
+ if isinstance(ax, int):
+ ax = plt.gcf().axes[ax]
+ if dot:
+ ax.scatter(*xy, c=c, s=70, zorder=zorder, linewidths=0, alpha=a)
+ if yaw is not None:
+ ax.quiver(
+ *xy,
+ *uv,
+ scale=s,
+ scale_units="xy",
+ angles="xy",
+ color=c,
+ zorder=zorder,
+ alpha=a,
+ width=w,
+ )
+
+
+def plot_dense_rotations(
+ ax, prob, thresh=0.01, skip=10, s=1 / 15, k=3, c="k", w=None, **kwargs
+):
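+    # Keep only local maxima of the per-pixel rotation likelihood (the k x k
+    # max-pool acts as non-maximum suppression) and draw an arrow for the most
+    # likely yaw at each surviving location.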
+ t = torch.argmax(prob, -1)
+ yaws = t.numpy() / prob.shape[-1] * 360
+ prob = prob.max(-1).values / prob.max()
+ mask = prob > thresh
+ masked = prob.masked_fill(~mask, 0)
+ max_ = torch.nn.functional.max_pool2d(
+ masked.float()[None, None], k, stride=1, padding=k // 2
+ )
+ mask = (max_[0, 0] == masked.float()) & mask
+ indices = np.where(mask.numpy() > 0)
+ plot_pose(
+ ax,
+ indices[::-1],
+ yaws[indices],
+ s=s,
+ c=c,
+ dot=False,
+ zorder=0.1,
+ w=w,
+ **kwargs,
+ )
+
+
+def copy_image(im, ax):
+ prop = im.properties()
+ prop.pop("children")
+ prop.pop("size")
+ prop.pop("tightbbox")
+ prop.pop("transformed_clip_path_and_affine")
+ prop.pop("window_extent")
+ prop.pop("figure")
+ prop.pop("transform")
+ return ax.imshow(im.get_array(), **prop)
+
+
+def add_circle_inset(
+ ax,
+ center,
+ corner=None,
+ radius_px=10,
+ inset_size=0.4,
+ inset_offset=0.005,
+ color="red",
+):
+ data_t_axes = ax.transAxes + ax.transData.inverted()
+ if corner is None:
+ center_axes = np.array(data_t_axes.inverted().transform(center))
+ corner = 1 - np.round(center_axes).astype(int)
+ corner = np.array(corner)
+ bottom_left = corner * (1 - inset_size - inset_offset) + (1 - corner) * inset_offset
+ axins = ax.inset_axes([*bottom_left, inset_size, inset_size])
+ if ax.yaxis_inverted():
+ axins.invert_yaxis()
+ axins.set_axis_off()
+
+ c = mpl.patches.Circle(center, radius_px, fill=False, color=color)
+ c1 = mpl.patches.Circle(center, radius_px, fill=False, color=color)
+ # ax.add_patch(c)
+ ax.add_patch(c1)
+ # ax.add_patch(c.frozen())
+ axins.add_patch(c)
+
+ radius_inset = radius_px + 1
+ axins.set_xlim([center[0] - radius_inset, center[0] + radius_inset])
+ ylim = center[1] - radius_inset, center[1] + radius_inset
+ if axins.yaxis_inverted():
+ ylim = ylim[::-1]
+ axins.set_ylim(ylim)
+
+ for im in ax.images:
+ im2 = copy_image(im, axins)
+ im2.set_clip_path(c)
+ return axins
+
+
+def plot_bev(bev, uv, yaw, ax=None, zorder=10, **kwargs):
+ if ax is None:
+ ax = plt.gca()
+ h, w = bev.shape[:2]
+ tfm = mpl.transforms.Affine2D().translate(-w / 2, -h)
+ tfm = tfm.rotate_deg(yaw).translate(*uv + 0.5)
+    tfm += ax.transData
+ ax.imshow(bev, transform=tfm, zorder=zorder, **kwargs)
+ ax.plot(
+ [0, w - 1, w / 2, 0],
+ [0, 0, h - 0.5, 0],
+ transform=tfm,
+ c="k",
+ lw=1,
+ zorder=zorder + 1,
+ )
diff --git a/utils/wrappers.py b/utils/wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d8b35a35fdc83569e943735e7433ed72da1343e
--- /dev/null
+++ b/utils/wrappers.py
@@ -0,0 +1,342 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+
+# Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
+# https://github.com/cvg/pixloc
+# Released under the Apache License 2.0
+
+"""
+Convenience classes for an SE3 pose and a pinhole Camera with lens distortion.
+Based on PyTorch tensors: differentiable, batched, with GPU support.
+"""
+
+import functools
+import inspect
+import math
+from typing import Dict, List, NamedTuple, Tuple, Union
+
+import numpy as np
+import torch
+
+from .geometry import undistort_points
+
+
+def autocast(func):
+ """Cast the inputs of a TensorWrapper method to PyTorch tensors
+ if they are numpy arrays. Use the device and dtype of the wrapper.
+ """
+
+ @functools.wraps(func)
+ def wrap(self, *args):
+ device = torch.device("cpu")
+ dtype = None
+ if isinstance(self, TensorWrapper):
+ if self._data is not None:
+ device = self.device
+ dtype = self.dtype
+ elif not inspect.isclass(self) or not issubclass(self, TensorWrapper):
+ raise ValueError(self)
+
+ cast_args = []
+ for arg in args:
+ if isinstance(arg, np.ndarray):
+ arg = torch.from_numpy(arg)
+ arg = arg.to(device=device, dtype=dtype)
+ cast_args.append(arg)
+ return func(self, *cast_args)
+
+ return wrap
+
+
+class TensorWrapper:
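+    """Thin wrapper around a single tensor whose last dimension packs the object's
+    parameters; indexing, dtype/device casts, and torch.stack are forwarded to it."""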
+ _data = None
+
+ @autocast
+ def __init__(self, data: torch.Tensor):
+ self._data = data
+
+ @property
+ def shape(self):
+ return self._data.shape[:-1]
+
+ @property
+ def device(self):
+ return self._data.device
+
+ @property
+ def dtype(self):
+ return self._data.dtype
+
+ def __getitem__(self, index):
+ return self.__class__(self._data[index])
+
+ def __setitem__(self, index, item):
+ self._data[index] = item.data
+
+ def to(self, *args, **kwargs):
+ return self.__class__(self._data.to(*args, **kwargs))
+
+ def cpu(self):
+ return self.__class__(self._data.cpu())
+
+ def cuda(self):
+ return self.__class__(self._data.cuda())
+
+ def pin_memory(self):
+ return self.__class__(self._data.pin_memory())
+
+ def float(self):
+ return self.__class__(self._data.float())
+
+ def double(self):
+ return self.__class__(self._data.double())
+
+ def detach(self):
+ return self.__class__(self._data.detach())
+
+ @classmethod
+ def stack(cls, objects: List, dim=0, *, out=None):
+ data = torch.stack([obj._data for obj in objects], dim=dim, out=out)
+ return cls(data)
+
+ @classmethod
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
+ if kwargs is None:
+ kwargs = {}
+ if func is torch.stack:
+ return cls.stack(*args, **kwargs)
+ else:
+ return NotImplemented
+
+
+class Pose(TensorWrapper):
+ def __init__(self, data: torch.Tensor):
+ assert data.shape[-1] == 12
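+        # 12 values per pose: the row-major flattened 3x3 rotation followed by the translation.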
+ super().__init__(data)
+
+ @classmethod
+ @autocast
+ def from_Rt(cls, R: torch.Tensor, t: torch.Tensor):
+ """Pose from a rotation matrix and translation vector.
+ Accepts numpy arrays or PyTorch tensors.
+
+ Args:
+ R: rotation matrix with shape (..., 3, 3).
+ t: translation vector with shape (..., 3).
+ """
+ assert R.shape[-2:] == (3, 3)
+ assert t.shape[-1] == 3
+ assert R.shape[:-2] == t.shape[:-1]
+ data = torch.cat([R.flatten(start_dim=-2), t], -1)
+ return cls(data)
+
+ @classmethod
+ def from_4x4mat(cls, T: torch.Tensor):
+ """Pose from an SE(3) transformation matrix.
+ Args:
+ T: transformation matrix with shape (..., 4, 4).
+ """
+ assert T.shape[-2:] == (4, 4)
+ R, t = T[..., :3, :3], T[..., :3, 3]
+ return cls.from_Rt(R, t)
+
+ @classmethod
+ def from_colmap(cls, image: NamedTuple):
+ """Pose from a COLMAP Image."""
+ return cls.from_Rt(image.qvec2rotmat(), image.tvec)
+
+ @property
+ def R(self) -> torch.Tensor:
+ """Underlying rotation matrix with shape (..., 3, 3)."""
+ rvec = self._data[..., :9]
+ return rvec.reshape(rvec.shape[:-1] + (3, 3))
+
+ @property
+ def t(self) -> torch.Tensor:
+ """Underlying translation vector with shape (..., 3)."""
+ return self._data[..., -3:]
+
+ def inv(self) -> "Pose":
+ """Invert an SE(3) pose."""
+ R = self.R.transpose(-1, -2)
+ t = -(R @ self.t.unsqueeze(-1)).squeeze(-1)
+ return self.__class__.from_Rt(R, t)
+
+ def compose(self, other: "Pose") -> "Pose":
+ """Chain two SE(3) poses: T_B2C.compose(T_A2B) -> T_A2C."""
+ R = self.R @ other.R
+ t = self.t + (self.R @ other.t.unsqueeze(-1)).squeeze(-1)
+ return self.__class__.from_Rt(R, t)
+
+ @autocast
+ def transform(self, p3d: torch.Tensor) -> torch.Tensor:
+ """Transform a set of 3D points.
+ Args:
+ p3d: 3D points, numpy array or PyTorch tensor with shape (..., 3).
+ """
+ assert p3d.shape[-1] == 3
+ # assert p3d.shape[:-2] == self.shape # allow broadcasting
+ return p3d @ self.R.transpose(-1, -2) + self.t.unsqueeze(-2)
+
+ def __matmul__(
+ self, other: Union["Pose", torch.Tensor]
+ ) -> Union["Pose", torch.Tensor]:
+ """Transform a set of 3D points: T_A2B * p3D_A -> p3D_B.
+ or chain two SE(3) poses: T_B2C @ T_A2B -> T_A2C."""
+ if isinstance(other, self.__class__):
+ return self.compose(other)
+ else:
+ return self.transform(other)
+
+ def numpy(self) -> Tuple[np.ndarray]:
+ return self.R.numpy(), self.t.numpy()
+
+ def magnitude(self) -> Tuple[torch.Tensor]:
+ """Magnitude of the SE(3) transformation.
+ Returns:
+            dr: rotation angle in degrees.
+ dt: translation distance in meters.
+ """
+ trace = torch.diagonal(self.R, dim1=-1, dim2=-2).sum(-1)
+ cos = torch.clamp((trace - 1) / 2, -1, 1)
+ dr = torch.acos(cos).abs() / math.pi * 180
+ dt = torch.norm(self.t, dim=-1)
+ return dr, dt
+
+ def __repr__(self):
+ return f"Pose: {self.shape} {self.dtype} {self.device}"
+
+
+class Camera(TensorWrapper):
+ eps = 1e-4
+
+ def __init__(self, data: torch.Tensor):
+ assert data.shape[-1] in {6, 8, 10}
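+        # Packed as (w, h, fx, fy, cx, cy, *dist) with 0, 2, or 4 distortion parameters.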
+ super().__init__(data)
+
+ @classmethod
+ def from_dict(cls, camera: Union[Dict, NamedTuple]):
+ """Camera from a COLMAP Camera tuple or dictionary.
+ We assume that the origin (0, 0) is the center of the top-left pixel.
+ This is different from COLMAP.
+ """
+ if isinstance(camera, tuple):
+ camera = camera._asdict()
+
+ model = camera["model"]
+ params = camera["params"]
+
+ if model in ["OPENCV", "PINHOLE"]:
+ (fx, fy, cx, cy), params = np.split(params, [4])
+ elif model in ["SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"]:
+ (f, cx, cy), params = np.split(params, [3])
+ fx = fy = f
+ if model == "SIMPLE_RADIAL":
+ params = np.r_[params, 0.0]
+ else:
+ raise NotImplementedError(model)
+
+ data = np.r_[
+ camera["width"], camera["height"], fx, fy, cx - 0.5, cy - 0.5, params
+ ]
+ return cls(data)
+
+ @property
+ def size(self) -> torch.Tensor:
+ """Size (width height) of the images, with shape (..., 2)."""
+ return self._data[..., :2]
+
+ @property
+ def f(self) -> torch.Tensor:
+ """Focal lengths (fx, fy) with shape (..., 2)."""
+ return self._data[..., 2:4]
+
+ @property
+ def c(self) -> torch.Tensor:
+ """Principal points (cx, cy) with shape (..., 2)."""
+ return self._data[..., 4:6]
+
+ @property
+ def dist(self) -> torch.Tensor:
+ """Distortion parameters, with shape (..., {0, 2, 4})."""
+ return self._data[..., 6:]
+
+ def scale(self, scales: Union[float, int, Tuple[Union[float, int]]]):
+ """Update the camera parameters after resizing an image."""
+ if isinstance(scales, (int, float)):
+ scales = (scales, scales)
+ s = self._data.new_tensor(scales)
+ data = torch.cat(
+ [self.size * s, self.f * s, (self.c + 0.5) * s - 0.5, self.dist], -1
+ )
+ return self.__class__(data)
+
+ def crop(self, left_top: Tuple[float], size: Tuple[int]):
+ """Update the camera parameters after cropping an image."""
+ left_top = self._data.new_tensor(left_top)
+ size = self._data.new_tensor(size)
+ data = torch.cat([size, self.f, self.c - left_top, self.dist], -1)
+ return self.__class__(data)
+
+ @autocast
+ def in_image(self, p2d: torch.Tensor):
+ """Check if 2D points are within the image boundaries."""
+ assert p2d.shape[-1] == 2
+ # assert p2d.shape[:-2] == self.shape # allow broadcasting
+ size = self.size.unsqueeze(-2)
+ valid = torch.all((p2d >= 0) & (p2d <= (size - 1)), -1)
+ return valid
+
+ @autocast
+ def project(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
+ """Project 3D points into the camera plane and check for visibility."""
+ z = p3d[..., -1]
+ valid = z > self.eps
+ z = z.clamp(min=self.eps)
+ p2d = p3d[..., :-1] / z.unsqueeze(-1)
+ return p2d, valid
+
+ def J_project(self, p3d: torch.Tensor):
+ x, y, z = p3d[..., 0], p3d[..., 1], p3d[..., 2]
+ zero = torch.zeros_like(z)
+ J = torch.stack([1 / z, zero, -x / z**2, zero, 1 / z, -y / z**2], dim=-1)
+ J = J.reshape(p3d.shape[:-1] + (2, 3))
+ return J # N x 2 x 3
+
+ @autocast
+ def undistort(self, pts: torch.Tensor) -> Tuple[torch.Tensor]:
+ """Undistort normalized 2D coordinates
+ and check for validity of the distortion model.
+ """
+ assert pts.shape[-1] == 2
+ # assert pts.shape[:-2] == self.shape # allow broadcasting
+ return undistort_points(pts, self.dist)
+
+ @autocast
+ def denormalize(self, p2d: torch.Tensor) -> torch.Tensor:
+ """Convert normalized 2D coordinates into pixel coordinates."""
+ return p2d * self.f.unsqueeze(-2) + self.c.unsqueeze(-2)
+
+ @autocast
+ def normalize(self, p2d: torch.Tensor) -> torch.Tensor:
+ """Convert pixel coordinates into normalized 2D coordinates."""
+ return (p2d - self.c.unsqueeze(-2)) / self.f.unsqueeze(-2)
+
+ def J_denormalize(self):
+ return torch.diag_embed(self.f).unsqueeze(-3) # 1 x 2 x 2
+
+ @autocast
+ def world2image(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
+ """Transform 3D points into 2D pixel coordinates."""
+ p2d, visible = self.project(p3d)
+ p2d, mask = self.undistort(p2d)
+ p2d = self.denormalize(p2d)
+ valid = visible & mask & self.in_image(p2d)
+ return p2d, valid
+
+ def J_world2image(self, p3d: torch.Tensor):
+ p2d_dist, valid = self.project(p3d)
+ J = self.J_denormalize() @ self.J_undistort(p2d_dist) @ self.J_project(p3d)
+ return J, valid
+
+ def __repr__(self):
+ return f"Camera {self.shape} {self.dtype} {self.device}"