"""Validation and performance test of the HPC ScatterConnection against the origin ScatterConnection."""
import time
from typing import Tuple

import torch

from hpc_rll.origin.scatter_connection import ScatterConnection
from hpc_rll.torch_utils.network.scatter_connection import ScatterConnection as HPCScatterConnection
from testbase import mean_relative_error, times

# This script needs a CUDA device: without one it stops here, at import time.
assert torch.cuda.is_available()
use_cuda = True

# Problem size shared by the validation and the performance test.
# B, M, N are the three dimensions of the random input tensor.
# H and W are the exclusive upper bounds of the random location coordinates and
# are also handed to both scatter modules (presumably as the spatial size of
# the scattered output -- confirm against the ScatterConnection implementation).
B = 256
M = 256
N = 256
H = 16
W = 16
# Note: the origin GPU version of cover mode is not deterministic, so the validation test uses the origin CPU version instead.
def scatter_val():
    """Check the HPC scatter connection against the origin implementation.

    For each scatter type ('add' and 'cover') a random input of shape
    (B, M, N) and random integer locations of shape (B, M, 2) are generated.
    Both modules run a forward pass followed by the backward pass of a
    mean-of-squares loss, and the mean relative error of the loss ("fp") and
    of the input gradient ("bp") is printed.

    The origin module and its tensors stay on the CPU even when ``use_cuda``
    is set, because the origin GPU version of cover mode is not
    deterministic; only the HPC module and its tensors are moved to the GPU.
    """
    for scatter_type in ['add', 'cover']:
        ori_input = torch.randn(B, M, N)
        # One (h, w) pair per (B, M) entry: h in [0, H), w in [0, W).
        h = torch.randint(low=0, high=H, size=(B, M)).unsqueeze(dim=2)
        w = torch.randint(low=0, high=W, size=(B, M)).unsqueeze(dim=2)
        ori_location = torch.cat([h, w], dim=2)
        ori_scatter = ScatterConnection(scatter_type)

        # The HPC side works on independent copies of exactly the same data.
        hpc_input = ori_input.clone().detach()
        hpc_location = ori_location.clone().detach()
        hpc_scatter = HPCScatterConnection(B, M, N, H, W, scatter_type)

        if use_cuda:
            # ori_input / ori_location / ori_scatter are intentionally left on
            # the CPU so that the reference result is deterministic.
            hpc_input = hpc_input.cuda()
            hpc_location = hpc_location.cuda()
            hpc_scatter = hpc_scatter.cuda()

        # Reference forward + backward.
        ori_input.requires_grad_(True)
        ori_output = ori_scatter(ori_input, (H, W), ori_location)
        ori_loss = ori_output * ori_output
        ori_loss = ori_loss.mean()
        ori_loss.backward()
        if use_cuda:
            torch.cuda.synchronize()

        # HPC forward + backward.
        hpc_input.requires_grad_(True)
        hpc_output = hpc_scatter(hpc_input, hpc_location)
        hpc_loss = hpc_output * hpc_output
        hpc_loss = hpc_loss.mean()
        hpc_loss.backward()
        if use_cuda:
            # Wait for the GPU work to finish before reading the results back.
            torch.cuda.synchronize()

        # Forward check: compare the scalar losses.
        mre = mean_relative_error(
            torch.flatten(ori_loss).cpu().detach().numpy(),
            torch.flatten(hpc_loss).cpu().detach().numpy()
        )
        print("scatter type {} fp mean_relative_error: {}".format(scatter_type, str(mre)))

        # Backward check: compare the gradients with respect to the input.
        mre = mean_relative_error(
            torch.flatten(ori_input.grad).cpu().detach().numpy(),
            torch.flatten(hpc_input.grad).cpu().detach().numpy()
        )
        print("scatter type {} bp mean_relative_error: {}".format(scatter_type, str(mre)))
# Note: the performance test uses the origin GPU version.
def scatter_perf():
    """Time forward + backward passes of the origin and HPC scatter modules.

    For each scatter type ('add' and 'cover') a random input of shape
    (B, M, N) and random integer locations of shape (B, M, 2) are generated
    and, when ``use_cuda`` is set, moved to the GPU together with both
    modules. Each module then runs ``times`` iterations of a forward pass
    followed by the backward pass of a mean-of-squares loss, and the
    wall-clock duration of every iteration is printed.

    Timing uses ``time.perf_counter()`` (monotonic, high resolution). The GPU
    is synchronized before the first timer starts and at the end of every
    iteration so that each sample covers exactly one iteration's work.
    """
    for scatter_type in ['add', 'cover']:
        ori_input = torch.randn(B, M, N)
        # One (h, w) pair per (B, M) entry: h in [0, H), w in [0, W).
        h = torch.randint(low=0, high=H, size=(B, M)).unsqueeze(dim=2)
        w = torch.randint(low=0, high=W, size=(B, M)).unsqueeze(dim=2)
        ori_location = torch.cat([h, w], dim=2)
        ori_scatter = ScatterConnection(scatter_type)

        # The HPC side works on independent copies of exactly the same data.
        hpc_input = ori_input.clone().detach()
        hpc_location = ori_location.clone().detach()
        hpc_scatter = HPCScatterConnection(B, M, N, H, W, scatter_type)

        if use_cuda:
            ori_input = ori_input.cuda()
            ori_location = ori_location.cuda()
            ori_scatter = ori_scatter.cuda()
            hpc_input = hpc_input.cuda()
            hpc_location = hpc_location.cuda()
            hpc_scatter = hpc_scatter.cuda()

        # Loop-invariant: enabling gradients once is enough and keeps the call
        # out of the timed region.
        ori_input.requires_grad_(True)
        hpc_input.requires_grad_(True)

        if use_cuda:
            # Drain pending GPU work so it is not charged to the first sample.
            torch.cuda.synchronize()
        for i in range(times):
            t = time.perf_counter()
            ori_output = ori_scatter(ori_input, (H, W), ori_location)
            ori_loss = ori_output * ori_output
            ori_loss = ori_loss.mean()
            ori_loss.backward()
            if use_cuda:
                # CUDA kernels run asynchronously; wait before stopping the timer.
                torch.cuda.synchronize()
            print(
                'epoch: {}, original scatter type {} cost time: {}'.format(
                    i, scatter_type, time.perf_counter() - t
                )
            )

        if use_cuda:
            torch.cuda.synchronize()
        for i in range(times):
            t = time.perf_counter()
            hpc_output = hpc_scatter(hpc_input, hpc_location)
            hpc_loss = hpc_output * hpc_output
            hpc_loss = hpc_loss.mean()
            hpc_loss.backward()
            if use_cuda:
                torch.cuda.synchronize()
            print(
                'epoch: {}, hpc scatter type {} cost time: {}'.format(
                    i, scatter_type, time.perf_counter() - t
                )
            )
if __name__ == '__main__':
    # Print the problem size first so the error and timing numbers that follow
    # can be interpreted.
    print("target problem: B = {}, M = {}, N = {}, H = {}, W = {}".format(B, M, N, H, W))
    print("================run scatter validation test================")
    scatter_val()
    print("================run scatter performance test================")
    scatter_perf()