rnwang's picture
infer demo
c8bce00
raw
history blame
5.53 kB
import gradio as gr
import numpy as np
import time
from data import write_image_tensor, PatchDataModule, prepare_data, image2tensor, tensor2image
import torch
from tqdm import tqdm
from bigdl.nano.pytorch.trainer import Trainer
from torch.utils.data import DataLoader
from pathlib import Path
from torch.utils.data import Dataset
import datetime
# All inference in this demo runs on the CPU in single precision.
device = 'cpu'
dtype = torch.float32

# Load the pickled generator module.
# map_location remaps any tensors that were saved from a GPU session onto
# `device`, so the checkpoint also loads on hosts without CUDA.
# NOTE: torch.load unpickles arbitrary objects; only load trusted checkpoints.
generator = torch.load("models/generator.pt", map_location=device)
generator.eval()
generator.to(device, dtype)

# Keyword arguments for the per-request DataLoader: one image per batch,
# loaded in the main process.
params = {
    'batch_size': 1,
    'num_workers': 0,
}
class ImageDataset(Dataset):
    """Dataset holding exactly one image, converted once with ``image2tensor``."""

    def __init__(self, img):
        # Convert eagerly so that indexing later is a plain list lookup.
        converted = image2tensor(img)
        self.imgs = [converted]

    def __len__(self) -> int:
        """Return the number of stored items."""
        return len(self.imgs)

    def __getitem__(self, idx: int) -> dict:
        """Return the stored item at position ``idx``."""
        # NOTE(review): the return annotation says ``dict``, but the stored
        # item is whatever ``image2tensor`` returns - confirm and adjust.
        return self.imgs[idx]
# Build a quantized copy of the generator with BigDL-Nano. The calibration
# data comes from the images under data/webcam, served through the project's
# PatchDataModule.
data_path = Path('data/webcam')
train_image_dd = prepare_data(data_path)
# patch_size=64, batch_size=8, patch_num=64.
# NOTE(review): presumably patch edge length, patches per batch and patches
# per image - confirm against PatchDataModule.
dm = PatchDataModule(train_image_dd, patch_size=2**6,
                     batch_size=2**3, patch_num=2**6)
train_loader = dm.train_dataloader()
# NOTE(review): this iterator is not referenced anywhere else in the visible
# code - confirm whether it is still needed.
train_loader_iter = iter(train_loader)
# NOTE(review): accelerator=None presumably selects the default quantization
# backend - confirm against the BigDL-Nano documentation.
quantized_model = Trainer.quantize(generator, accelerator=None,
                                   calib_dataloader=train_loader)
def _fit_to_model(image, max_side=3000, multiple=4):
    """Shrink and crop ``image`` so both spatial sides suit the model.

    The image is first scaled down (keeping its aspect ratio) so that neither
    side exceeds ``max_side``; each side is then trimmed down to the nearest
    multiple of ``multiple``. Pixels are picked with nearest-neighbour index
    sampling, so the result contains real source pixels in their original
    layout (``np.resize`` would only re-wrap the flat buffer and scramble the
    picture).

    Args:
        image: array whose first two axes are the spatial axes; any trailing
            axes (e.g. channels) are carried over unchanged.
        max_side: largest allowed length of either spatial side.
        multiple: both output sides are multiples of this value.

    Returns:
        ``image`` itself when no change is needed, otherwise a new array.

    Raises:
        ValueError: if a side would become zero after trimming.
    """
    rows, cols = image.shape[:2]

    scaled_rows, scaled_cols = rows, cols
    if rows > max_side or cols > max_side:
        ratio = min(max_side / rows, max_side / cols)
        scaled_rows = int(rows * ratio)
        scaled_cols = int(cols * ratio)

    out_rows = scaled_rows - scaled_rows % multiple
    out_cols = scaled_cols - scaled_cols % multiple
    if out_rows == 0 or out_cols == 0:
        raise ValueError(
            "image of size {}x{} is too small to process".format(rows, cols)
        )

    if (out_rows, out_cols) == (rows, cols):
        return image

    # Output pixel k reads source pixel floor(k * src / scaled). Without
    # downscaling (scaled == src) this is a plain crop of the leading pixels.
    row_idx = np.arange(out_rows) * rows // scaled_rows
    col_idx = np.arange(out_cols) * cols // scaled_cols
    return image[np.ix_(row_idx, col_idx)]


def _run_transfer(model, input_img):
    """Run ``model`` on one image and time the forward pass.

    Args:
        model: callable applied to the batched input.
        input_img: image array as delivered by the Gradio image component.

    Returns:
        A ``(image, latency)`` tuple: the model output converted with
        ``tensor2image`` and wrapped in a NumPy array, and the forward-pass
        time formatted like ``"0.123s"``.
    """
    rows, cols = input_img.shape[:2]
    print(datetime.datetime.now())
    print("input size: ", rows, cols)

    prepared = _fit_to_model(input_img)
    loader = DataLoader(ImageDataset(prepared), **params)

    with torch.no_grad():
        # The dataset holds exactly one image, so this loop runs once.
        for inputs in tqdm(loader):
            inputs = inputs.to(device, dtype)
            # Time only the forward pass, not data preparation.
            start = time.perf_counter()
            outputs = model(inputs)
            elapsed = time.perf_counter() - start
            result_image = np.array(tensor2image(outputs[0]))

    return result_image, "{:.3f}s".format(elapsed)


def original_transfer(input_img):
    """Stylize ``input_img`` with the unquantized generator.

    Returns:
        ``(output image as a NumPy array, latency string)``.
    """
    return _run_transfer(generator, input_img)


def nano_transfer(input_img):
    """Stylize ``input_img`` with the BigDL-Nano quantized model.

    Returns:
        ``(output image as a NumPy array, latency string)``.
    """
    return _run_transfer(quantized_model, input_img)
def clear():
    """Return four ``None`` values, one per output component being reset."""
    return (None,) * 4
# Gradio UI: an overview section (text + video), then the interactive demo
# comparing the original generator against the quantized one.
demo = gr.Blocks()

with demo:
    gr.Markdown("<h1><center>BigDL-Nano inference demo</center></h1>")

    # Overview: description on the left, API walkthrough video on the right.
    with gr.Row().style(equal_height=False):
        with gr.Column():
            # The Markdown text is kept flush-left so its content is not
            # affected by the surrounding code indentation.
            gr.Markdown('''
<h2>Overview</h2>
BigDL-Nano is a library in [BigDL 2.0](https://github.com/intel-analytics/BigDL) that allows the users to transparently accelerate their deep learning pipelines (including data processing, training and inference) by automatically integrating optimized libraries, best-known configurations, and software optimizations. </p>
The video on the right shows how the user can easily enable quantization using BigDL-Nano (with just a couple of lines of code); you may refer to our [CVPR 2022 demo paper](https://arxiv.org/abs/2204.01715) for more details.
''')
        with gr.Column():
            gr.Video(value="nano_quantize_api.mp4")

    gr.Markdown('''
<h2>Demo</h2>
This section uses an image stylization example to demonstrate the speedup of the above code when using quantization in BigDL-Nano (about 2~3x inference time speedup). The demo is adapted from the original [FSPBT-Image-Translation code](https://github.com/rnwzd/FSPBT-Image-Translation/blob/master/eval.py).
''')

    # Input image next to the action buttons.
    with gr.Row().style(equal_height=False):
        input_img = gr.Image(label="input image", value="Marvelous_Maisel.jpg", source="upload")
        with gr.Column():
            ori_but = gr.Button("Standard PyTorch Lightning")
            nano_but = gr.Button("BigDL-Nano")
            clear_but = gr.Button("Clear Output")

    # Side-by-side results: latency text above the output image for each model.
    with gr.Row().style(equal_height=False):
        with gr.Column():
            ori_time = gr.Text(label="Standard PyTorch Lightning latency")
            ori_image = gr.Image(label="Standard PyTorch Lightning output image")
        with gr.Column():
            nano_time = gr.Text(label="BigDL-Nano latency")
            nano_image = gr.Image(label="BigDL-Nano output image")

    # Each button feeds the input image to its handler; the handler's
    # (image, latency) result fills the matching output components.
    ori_but.click(original_transfer, inputs=input_img, outputs=[ori_image, ori_time])
    nano_but.click(nano_transfer, inputs=input_img, outputs=[nano_image, nano_time])
    clear_but.click(clear, inputs=None, outputs=[ori_image, ori_time, nano_image, nano_time])

demo.launch(share=True, enable_queue=True)