File size: 795 Bytes
abd2a81
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import numpy as np


def get_seg_display(seg):
    """Convert a segmentation array into a 4-channel float array for display.

    The output has shape ``(seg.shape[0], seg.shape[1], 4)`` and dtype
    ``float64``. The channel layout is presumably RGBA with channel 3 acting
    as alpha -- confirm against the plotting code that consumes it.

    Args:
        seg: Array with at least two dimensions. Either a 2-D map ``(H, W)``
            or a multi-channel map whose last axis indexes the channels.

    Returns:
        A new ``(H, W, 4)`` array. For a 2-D input, channels 0 and 3 are
        copies of ``seg`` and channels 1 and 2 stay zero. Otherwise input
        channel ``i`` is copied to output channel ``i`` and channel 3 is the
        sum over the last axis clipped to ``[0, 1]``.

    Raises:
        IndexError: If ``seg`` has more than 4 channels on its last axis
            (or fewer than two dimensions).
    """
    # ``np.float`` was only an alias for the builtin float and has been removed
    # from NumPy (1.24+); ``np.float64`` is the dtype it always resolved to.
    seg_display = np.zeros([seg.shape[0], seg.shape[1], 4], dtype=np.float64)
    if len(seg.shape) == 2:
        seg_display[..., 0] = seg
        seg_display[..., 3] = seg
    else:
        for i in range(seg.shape[-1]):
            seg_display[..., i] = seg[..., i]
        # Written last on purpose: with a 4-channel input this overwrites the
        # copy of input channel 3 made by the loop above.
        seg_display[..., 3] = np.clip(np.sum(seg, axis=-1), 0, 1)
    return seg_display


def batch_to_cuda(batch):
    """Replace every entry of ``batch`` that has a ``cuda`` attribute with the
    result of calling ``.cuda(non_blocking=True)`` on it.

    Entries without a ``cuda`` attribute are left untouched. The mapping is
    modified in place and the same object is returned.
    """
    for name, value in batch.items():
        if not hasattr(value, "cuda"):
            continue
        # Only existing keys are reassigned, so the mapping's size does not
        # change while it is being iterated.
        batch[name] = value.cuda(non_blocking=True)
    return batch


def batch_to_cpu(batch):
    """Replace every entry of ``batch`` that has a ``cpu`` attribute with the
    result of calling ``.cpu()`` on it.

    Entries without a ``cpu`` attribute are left untouched. The mapping is
    modified in place and the same object is returned.
    """
    for name, value in batch.items():
        if not hasattr(value, "cpu"):
            continue
        # Only existing keys are reassigned, so the mapping's size does not
        # change while it is being iterated.
        batch[name] = value.cpu()
    return batch