henry000 committed
Commit 7a28749 · 1 Parent(s): 200b5c1

📝 [Update] README and the TRT ipynb tutorial

Files changed (2)
  1. README.md +1 -1
  2. examples/notebook_TensorRT.ipynb +130 -0
README.md CHANGED
@@ -47,7 +47,7 @@ pip install -r requirements.txt
  | ------------------ | :---------: | :-------: | :-------: |
  | PyTorch | v1.12 | v2.3+ | v1.12 |
  | ONNX | ✅ | ✅ | - |
- | TensorRT | 🧪 | 🧪 | - |
+ | TensorRT | ✅ | - | - |
  | OpenVINO | - | 🧪 | ❔ |

  </td></tr> </table>
examples/notebook_TensorRT.ipynb ADDED
@@ -0,0 +1,130 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import os\n",
+     "import sys\n",
+     "from pathlib import Path\n",
+     "\n",
+     "import torch\n",
+     "from PIL import Image\n",
+     "from loguru import logger\n",
+     "from omegaconf import OmegaConf\n",
+     "\n",
+     "project_root = Path().resolve().parent\n",
+     "sys.path.append(str(project_root))\n",
+     "\n",
+     "from yolo import AugmentationComposer, bbox_nms, create_model, custom_logger, draw_bboxes, Vec2Box\n",
+     "from yolo.config.config import NMSConfig"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "MODEL = \"v9-c\"\n",
+     "DEVICE = \"cuda:0\"\n",
+     "\n",
+     "WEIGHT_PATH = f\"../weights/{MODEL}.pt\"\n",
+     "TRT_WEIGHT_PATH = f\"../weights/{MODEL}.trt\"\n",
+     "MODEL_CONFIG = f\"../yolo/config/model/{MODEL}.yaml\"\n",
+     "\n",
+     "IMAGE_PATH = \"../demo/images/inference/image.png\"\n",
+     "IMAGE_SIZE = (640, 640)\n",
+     "\n",
+     "custom_logger()\n",
+     "device = torch.device(DEVICE)\n",
+     "image = Image.open(IMAGE_PATH)"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "if os.path.exists(TRT_WEIGHT_PATH):\n",
+     "    from torch2trt import TRTModule\n",
+     "\n",
+     "    model_trt = TRTModule()\n",
+     "    model_trt.load_state_dict(torch.load(TRT_WEIGHT_PATH))\n",
+     "else:\n",
+     "    from torch2trt import torch2trt\n",
+     "\n",
+     "    with open(MODEL_CONFIG) as stream:\n",
+     "        cfg_model = OmegaConf.load(stream)\n",
+     "\n",
+     "    model = create_model(cfg_model, weight_path=WEIGHT_PATH)\n",
+     "    model = model.to(device).eval()\n",
+     "\n",
+     "    dummy_input = torch.ones((1, 3, 640, 640)).to(device)\n",
+     "    logger.info(f\"♻️ Creating TensorRT model\")\n",
+     "    model_trt = torch2trt(model, [dummy_input])\n",
+     "    torch.save(model_trt.state_dict(), TRT_WEIGHT_PATH)\n",
+     "    logger.info(f\"📥 TensorRT model saved to {TRT_WEIGHT_PATH}\")\n",
+     "\n",
+     "transform = AugmentationComposer([], IMAGE_SIZE)\n",
+     "vec2box = Vec2Box(model_trt, IMAGE_SIZE, device)\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "image, bbox = transform(image, torch.zeros(0, 5))\n",
+     "image = image.to(device)[None]"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "with torch.no_grad():\n",
+     "    predict = model_trt(image)\n",
+     "    predict = vec2box(predict[\"Main\"])\n",
+     "predict_box = bbox_nms(predict[0], predict[2], NMSConfig(0.5, 0.5))\n",
+     "draw_bboxes(image, predict_box)"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "metadata": {},
+    "source": [
+     "Sample Output:\n",
+     "\n",
+     "![image](../demo/images/output/visualize.png)"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "yolomit",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.1.undefined"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 2
+ }
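
For quick reference outside Jupyter, the notebook's cells condense into the standalone sketch below. It reuses the same imports, paths, and calls as the cells above and is only a sketch, not a replacement for the tutorial; it assumes it is run from the repository's examples/ directory, with torch2trt installed and the v9-c weights already present under ../weights/.

```python
# Condensed sketch of the notebook above: build (or load) a torch2trt engine,
# then run a single demo image through it.
import os
import sys
from pathlib import Path

import torch
from PIL import Image
from omegaconf import OmegaConf

sys.path.append(str(Path().resolve().parent))  # make the in-repo `yolo` package importable
from yolo import AugmentationComposer, bbox_nms, create_model, draw_bboxes, Vec2Box
from yolo.config.config import NMSConfig

MODEL = "v9-c"
IMAGE_SIZE = (640, 640)
TRT_WEIGHT_PATH = f"../weights/{MODEL}.trt"
device = torch.device("cuda:0")

if os.path.exists(TRT_WEIGHT_PATH):
    # Reuse a previously serialized torch2trt engine
    from torch2trt import TRTModule
    model_trt = TRTModule()
    model_trt.load_state_dict(torch.load(TRT_WEIGHT_PATH))
else:
    # Convert the PyTorch weights once and cache the engine next to them
    from torch2trt import torch2trt
    cfg_model = OmegaConf.load(f"../yolo/config/model/{MODEL}.yaml")
    model = create_model(cfg_model, weight_path=f"../weights/{MODEL}.pt").to(device).eval()
    model_trt = torch2trt(model, [torch.ones((1, 3, *IMAGE_SIZE)).to(device)])
    torch.save(model_trt.state_dict(), TRT_WEIGHT_PATH)

transform = AugmentationComposer([], IMAGE_SIZE)
vec2box = Vec2Box(model_trt, IMAGE_SIZE, device)

image, bbox = transform(Image.open("../demo/images/inference/image.png"), torch.zeros(0, 5))
image = image.to(device)[None]
with torch.no_grad():
    predict = vec2box(model_trt(image)["Main"])
predict_box = bbox_nms(predict[0], predict[2], NMSConfig(0.5, 0.5))
draw_bboxes(image, predict_box)
```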