# By Forge
#
# Fast unpacking of values quantized to 4 bits and packed two per byte. With
# PyTorch >= 2.3 the unpacking is done via a precomputed 65536-entry lookup
# table indexed by uint16 (i.e. two packed bytes at a time); older PyTorch
# falls back to plain bit arithmetic.


import torch


def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(x):
    # Reinterpret each uint16 as two bytes, split every byte into its low and
    # high nibble, shift the nibbles from unsigned [0, 15] to signed [-8, 7],
    # and pack each group of four int8 results back into one int32.
    x = x.view(torch.uint8).view(x.size(0), -1)
    unpacked = torch.stack([x & 15, x >> 4], dim=-1)
    reshaped = unpacked.view(x.size(0), -1)
    reshaped = reshaped.view(torch.int8) - 8
    return reshaped.view(torch.int32)


def native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(x):
    # Unsigned variant of the function above: same nibble split, but the
    # values stay in [0, 15] instead of being shifted to [-8, 7].
    x = x.view(torch.uint8).view(x.size(0), -1)
    unpacked = torch.stack([x & 15, x >> 4], dim=-1)
    reshaped = unpacked.view(x.size(0), -1)
    return reshaped.view(torch.int32)
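
# A minimal sanity sketch for the two helpers above (illustrative only; `demo`
# is a made-up input, byte order assumes a little-endian machine, and uint16
# needs PyTorch >= 2.3). The value 0x4321 is the byte pair [0x21, 0x43], i.e.
# the nibble sequence 1, 2, 3, 4; the signed variant shifts each nibble by -8:
#
#     demo = torch.tensor([[0x4321]], dtype=torch.long).to(torch.uint16)
#     native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(demo).view(torch.int8)
#     # -> tensor([[-7, -6, -5, -4]], dtype=torch.int8)
#     native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(demo).view(torch.uint8)
#     # -> tensor([[1, 2, 3, 4]], dtype=torch.uint8)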


# torch.uint16 was added in PyTorch 2.3; without it the lookup-table fast path
# below cannot be built, so fall back to plain bit arithmetic everywhere.
disable_all_optimizations = False

if not hasattr(torch, 'uint16'):
    disable_all_optimizations = True

if disable_all_optimizations:
    print('You are using PyTorch below version 2.3. Some optimizations will be disabled.')

if not disable_all_optimizations:
    # Precompute the unpacked result for every possible uint16 value (all
    # 65536 two-byte combinations), so that unpacking reduces to one gather.
    native_4bits_lookup_table = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits(torch.arange(start=0, end=256*256, dtype=torch.long).to(torch.uint16))[:, 0]
    native_4bits_lookup_table_u = native_unpack_4x4bits_in_1x16bits_to_4x8bits_in_1x32bits_u(torch.arange(start=0, end=256*256, dtype=torch.long).to(torch.uint16))[:, 0]
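
    # A quick spot check of the tables (illustrative; any index works the same
    # way). Entry 0x4321 holds the four nibbles of that uint16 value:
    #
    #     native_4bits_lookup_table[0x4321:0x4322].view(torch.int8)
    #     # -> tensor([-7, -6, -5, -4], dtype=torch.int8)
    #     native_4bits_lookup_table_u[0x4321:0x4322].view(torch.uint8)
    #     # -> tensor([1, 2, 3, 4], dtype=torch.uint8)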


def quick_unpack_4bits(x):
    # Unpack an (N, K) uint8 tensor of packed 4-bit pairs into an (N, 2K) int8
    # tensor of signed values in [-8, 7]. The fast path reads the bytes two at
    # a time as uint16 indices into the precomputed table, so K must be even.
    if disable_all_optimizations:
        return torch.stack([x & 15, x >> 4], dim=-1).view(x.size(0), -1).view(torch.int8) - 8

    global native_4bits_lookup_table

    s0 = x.size(0)
    x = x.view(torch.uint16)

    # Move the table to the input's device once and keep it cached there.
    if native_4bits_lookup_table.device != x.device:
        native_4bits_lookup_table = native_4bits_lookup_table.to(device=x.device)

    y = torch.index_select(input=native_4bits_lookup_table, dim=0, index=x.to(dtype=torch.int32).flatten())
    y = y.view(torch.int8)
    y = y.view(s0, -1)

    return y
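
# Usage sketch for quick_unpack_4bits (illustrative input; the fast path needs
# an even number of packed bytes per row so they can be read as uint16):
#
#     packed = torch.tensor([[0x21, 0x43]], dtype=torch.uint8)
#     quick_unpack_4bits(packed)
#     # -> tensor([[-7, -6, -5, -4]], dtype=torch.int8)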


def quick_unpack_4bits_u(x):
    # Unsigned counterpart of quick_unpack_4bits: unpacks an (N, K) uint8
    # tensor into an (N, 2K) uint8 tensor of values in [0, 15].
    if disable_all_optimizations:
        return torch.stack([x & 15, x >> 4], dim=-1).view(x.size(0), -1)

    global native_4bits_lookup_table_u

    s0 = x.size(0)
    x = x.view(torch.uint16)

    # Same device caching as in quick_unpack_4bits.
    if native_4bits_lookup_table_u.device != x.device:
        native_4bits_lookup_table_u = native_4bits_lookup_table_u.to(device=x.device)

    y = torch.index_select(input=native_4bits_lookup_table_u, dim=0, index=x.to(dtype=torch.int32).flatten())
    y = y.view(torch.uint8)
    y = y.view(s0, -1)

    return y
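
# Same sketch for the unsigned variant (values stay in [0, 15]):
#
#     packed = torch.tensor([[0x21, 0x43]], dtype=torch.uint8)
#     quick_unpack_4bits_u(packed)
#     # -> tensor([[1, 2, 3, 4]], dtype=torch.uint8)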


def change_4bits_order(x):
    # Deinterleave the low and high nibbles of each row (all low nibbles
    # first, then all high nibbles), then repack consecutive nibble pairs
    # into bytes.
    y = torch.stack([x & 15, x >> 4], dim=-2).view(x.size(0), -1)
    z = y[:, ::2] | (y[:, 1::2] << 4)
    return z
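
# Worked example (illustrative bytes). The row [0x21, 0x43] carries the nibble
# sequence 1, 2, 3, 4 (low nibble first within each byte); after deinterleaving
# and repacking, byte 0 holds the two low nibbles and byte 1 the two high ones:
#
#     packed = torch.tensor([[0x21, 0x43]], dtype=torch.uint8)
#     change_4bits_order(packed)
#     # -> tensor([[49, 66]], dtype=torch.uint8)  # i.e. [0x31, 0x42]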