Spaces:
Runtime error
Runtime error
File size: 12,388 Bytes
e202b16 |
|
#################################################################################################
#
# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Definition of CuTe Layouts and functions to manipulate them
"""
from itertools import chain
from typing import Union
from .int_tuple import *
class LayoutBase:
    """Marker base class for layout types (``Layout`` below derives from it).

    It carries no behavior of its own; ``is_layout`` uses it so that any
    layout object can be recognized with a single ``isinstance`` check.
    """


def is_layout(x):
    """Return True when ``x`` is a ``LayoutBase`` instance (including subclasses)."""
    return isinstance(x, LayoutBase)
class Layout(LayoutBase):
    """A CuTe layout: a (possibly nested) ``shape`` paired with a ``stride``.

    Calling a layout maps a coordinate to a linear index through ``crd2idx``
    (or slices it when the coordinate contains ``None``); indexing it returns
    one of its top-level modes as a new Layout, mirroring tuple indexing.
    """

    def __init__(self, _shape, _stride=None):
        """Store ``_shape`` and ``_stride``.

        When ``_stride`` is omitted it is derived from the shape with
        ``prefix_product``.
        """
        self.shape = _shape
        if _stride is None:
            self.stride = prefix_product(self.shape)
        else:
            self.stride = _stride

    # operator ==
    def __eq__(self, other):
        """Return True if ``other`` is a Layout with equal shape and stride.

        Non-Layout operands return ``NotImplemented`` so Python can try the
        reflected comparison (ultimately falling back to identity) instead of
        raising ``AttributeError`` on a missing ``shape``/``stride`` attribute.
        """
        if not isinstance(other, Layout):
            return NotImplemented
        return self.shape == other.shape and self.stride == other.stride

    # operator len(L) (len [rank] like tuples)
    def __len__(self):
        """Return the rank: the number of top-level modes (1 for a non-tuple shape)."""
        if is_tuple(self.shape):
            return len(self.shape)
        else:
            return 1

    # operator () (map coord to idx)
    def __call__(self, *args):
        """
        Map a logical coordinate to a linear index (Coord has no Underscore slice operators)
        OR
        Slice the layout and return the sublayout (Coord has an Underscore slice op)
        Follow the same behavior of `Layout::operator(Coord const&)` in cute C++

        In this Python port the slice marker is ``None`` (detected with
        ``has_none``). A single argument is used as the coordinate itself;
        several arguments are treated together as one tuple coordinate.
        """
        crd = args[0] if len(args) == 1 else args
        if has_none(args):
            # Slicing: keep only the modes selected by the None entries.
            return Layout(slice_(crd, self.shape), slice_(crd, self.stride))
        return crd2idx(crd, self.shape, self.stride)

    # operator [] (get-i like tuples)
    def __getitem__(self, i):
        """Return mode ``i`` as a new Layout, like indexing a tuple.

        A layout with a non-tuple shape has exactly one mode, reachable as
        index ``0`` (or ``-1``). Any other index raises ``IndexError``. This
        matters because ``Layout`` defines no ``__iter__``: Python iterates it
        through ``__getitem__`` and stops only on ``IndexError``.
        """
        if is_tuple(self.shape):
            return Layout(self.shape[i], self.stride[i])
        if i != 0 and i != -1:
            raise IndexError(f"Layout index {i} out of range for a rank-1 layout")
        return Layout(self.shape, self.stride)

    # size(layout) Size of the domain
    def size(self):
        """Return the size of the domain: ``product`` of the shape."""
        return product(self.shape)

    # cosize(layout) Size of the codomain
    def cosize(self):
        """Return the size of the codomain: the index of the last coordinate, plus one."""
        return self(self.size() - 1) + 1

    # print and str
    def __str__(self):
        return f"{self.shape}:{self.stride}"

    # error msgs and representation
    def __repr__(self):
        return f"Layout({self.shape},{self.stride})"
# Make Layout from a list of layouts (each layout its own mode in the result)
def make_layout(*layouts):
    """Concatenate layouts so that each input becomes one top-level mode of the result.

    Accepts the layouts either as separate positional arguments or as a single
    non-Layout iterable (list, generator, ...) of layouts.
    """
    if len(layouts) == 1 and not is_layout(layouts[0]):
        # A lone non-Layout argument is the iterable holding the layouts.
        layouts = layouts[0]
    modes = [(layout.shape, layout.stride) for layout in layouts]
    shape, stride = zip(*modes)
    return Layout(shape, stride)
# Size of the domain
def size(layout):
    """Return the size of the domain of ``layout``.

    A Layout delegates to ``Layout.size``; anything else (a plain shape) is
    passed to ``product``.
    """
    return layout.size() if is_layout(layout) else product(layout)
# Size of the codomain
def cosize(layout):
    """Return the size of the codomain of ``layout`` by delegating to ``layout.cosize()``."""
    codomain_size = layout.cosize()
    return codomain_size
# Layout coalesce -- flatten and combine as many modes as possible while preserving the int-to-int function
def coalesce(layout, profile=None):
    """Flatten ``layout`` and merge adjacent modes wherever the mapping allows.

    Size-1 modes are dropped. A mode is folded into the previous one when its
    stride equals ``previous_shape * previous_stride``. A single surviving mode
    is returned with scalar shape/stride; an all-size-1 input yields ``Layout(1, 0)``.

    When ``profile`` is a tuple, coalescing is applied mode-by-mode: mode ``i``
    is coalesced against ``profile[i]`` and modes of ``layout`` beyond the
    profile are passed through unchanged.
    """
    if is_tuple(profile):
        assert len(layout) >= len(profile)
        n_prof = len(profile)
        head = (coalesce(layout[i], profile[i]) for i in range(n_prof))
        tail = (layout[i] for i in range(n_prof, len(layout)))
        return make_layout(chain(head, tail))

    # Start from a size-1 placeholder so the first real mode simply overwrites it.
    shapes = [1]
    strides = [0]
    for extent, step in zip(flatten(layout.shape), flatten(layout.stride)):
        if extent == 1:
            # Size-1 modes never contribute to the mapping.
            continue
        if shapes[-1] == 1:
            shapes[-1] = extent
            strides[-1] = step
        elif shapes[-1] * strides[-1] == step:
            # This mode starts exactly where the previous one ends: merge them.
            shapes[-1] = shapes[-1] * extent
        else:
            shapes.append(extent)
            strides.append(step)

    if len(shapes) == 1:
        return Layout(shapes[0], strides[0])
    return Layout(tuple(shapes), tuple(strides))
# Layout filter -- replace all stride-0 modes with size-1 and then coalesce to remove them
def filter(layout, profile=None):
    """Drop every stride-0 and size-1 mode of ``layout``, then coalesce the rest.

    If nothing survives, ``Layout(1, 0)`` is returned. When ``profile`` is a
    tuple, filtering is applied mode-by-mode: mode ``i`` is filtered against
    ``profile[i]`` and modes beyond the profile are passed through unchanged.

    Note: this function shadows the builtin ``filter`` inside this module.
    """
    if is_tuple(profile):
        assert len(layout) >= len(profile)
        n_prof = len(profile)
        head = (filter(layout[i], profile[i]) for i in range(n_prof))
        tail = (layout[i] for i in range(n_prof, len(layout)))
        return make_layout(chain(head, tail))

    kept = [
        (extent, step)
        for extent, step in zip(flatten(layout.shape), flatten(layout.stride))
        if not (extent == 1 or step == 0)
    ]
    if not kept:
        return Layout(1, 0)
    kept_shapes, kept_strides = zip(*kept)
    return coalesce(Layout(kept_shapes, kept_strides))
# Layout composition
# Use tuples-of-layouts to perform this operation by-mode and None as no-op
def composition(layoutA, layoutB):
    """Compose ``layoutA`` with ``layoutB`` (CuTe ``A o B``).

    Presumably intended to yield a layout R with R(c) == layoutA(layoutB(c))
    for coordinates c of layoutB (the CuTe definition of composition); this
    function does not check that property itself.

    ``layoutB`` may be:
      * ``None``  -- no-op, ``layoutA`` is returned unchanged;
      * an int    -- wrapped as ``Layout(layoutB)`` and composed;
      * a tuple   -- composed mode-by-mode with ``layoutA``'s modes; modes of
        ``layoutA`` beyond ``len(layoutB)`` are passed through;
      * a Layout with a tuple shape -- each top-level mode of B is composed
        with the whole of ``layoutA`` and the results become the output modes;
      * a Layout with a non-tuple shape -- handled by the stride walk below,
        whose result is coalesced before being returned.
    """
    if layoutB is None:
        return layoutA
    elif is_int(layoutB):
        return composition(layoutA, Layout(layoutB))
    elif is_tuple(layoutB):
        assert len(layoutA) >= len(layoutB)
        return make_layout(chain((composition(layoutA[i], layoutB[i]) for i in range( 0,len(layoutB))),
                                 (layoutA[i] for i in range(len(layoutB),len(layoutA)))))
    elif is_tuple(layoutB.shape):
        # Distribute over B's top-level modes: (B0, B1, ...) -> (A o B0, A o B1, ...)
        return make_layout(composition(layoutA, layoutB_i) for layoutB_i in layoutB)

    # From here on layoutB is a Layout with a single, non-tuple mode.
    if layoutB.stride == 0:
        # Stride-0 B: the result keeps B's shape with stride 0; layoutA is not consulted.
        return Layout(layoutB.shape, 0)
    else:
        result_shape = []
        result_stride = []
        # The part of B's single mode not yet placed into one of A's modes.
        rest_shape = layoutB.shape
        rest_stride = layoutB.stride
        # Walk every flattened mode of A except the last one; the last mode is
        # handled after the loop and absorbs whatever remains of B.
        for (s, d) in zip(flatten(layoutA.shape)[:-1], flatten(layoutA.stride)[:-1]):
            # Extent of this A mode measured in units of B's remaining stride.
            # shape_div comes from int_tuple; its behavior for non-divisible
            # operands is defined there, not here.
            s1 = shape_div(s, rest_stride)
            result_shape.append(min(s1,rest_shape))
            result_stride.append(rest_stride * d)
            # Consume what this A mode took from B's remaining shape and stride.
            rest_shape = shape_div(rest_shape, abs(s1))
            rest_stride = shape_div(rest_stride, s)
        # The last mode of A receives all of B that is left.
        result_shape.append(rest_shape)
        result_stride.append(rest_stride * flatten(layoutA.stride)[-1])
        return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
# Layout complement
def complement(layout, max_idx=1):
    """Return the complement of ``layout`` with respect to ``max_idx``.

    The flattened modes of ``layout`` are visited in increasing stride order,
    skipping stride-0 and size-1 modes. Each one contributes a mode that fills
    the gap before it (shape ``stride // current_idx``, stride ``current_idx``).
    A final mode of extent ``ceil(max_idx / current_idx)`` extends the result
    toward ``max_idx``. The result is coalesced.

    Args:
        layout: A Layout, or an int which is first wrapped as ``Layout(layout)``.
        max_idx: Codomain size the complement should extend to (default 1).
    """
    if is_int(layout):
        # Forward max_idx: previously it was dropped here, so an int input was
        # silently complemented against the default of 1.
        return complement(Layout(layout), max_idx)

    result_shape = []
    result_stride = []
    current_idx = 1
    sorted_DS = sorted(zip(flatten(layout.stride), flatten(layout.shape)))
    for (stride, shape) in sorted_DS:
        if stride == 0 or shape == 1:
            continue
        in_bound = current_idx <= shape * stride
        # To support symbolic value which can't be evaluated now
        assert (type(in_bound) is not bool) or in_bound
        result_shape.append(stride // current_idx)
        result_stride.append(current_idx)
        current_idx = shape * stride
    result_shape.append((max_idx + current_idx - 1) // current_idx)  # ceil_div
    result_stride.append(current_idx)
    return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
# Layout right inverse
def right_inverse(layout):
    """Return a right inverse of ``layout``.

    ``None`` passes through unchanged and an int ``n`` becomes ``Layout(n)``.
    Otherwise the flattened modes are visited in increasing stride order,
    skipping size-1 modes. Modes are collected while each stride equals the
    running value ``shape * stride`` of the previously collected mode (starting
    at 1); the first mismatch ends the scan. Each collected mode keeps its shape
    and takes as stride the ``prefix_product`` entry at its original flattened
    position. The result is coalesced.
    """
    if layout is None:
        return None
    if is_int(layout):
        return Layout(layout)

    shapes = flatten(layout.shape)
    strides = flatten(layout.stride)
    candidates = sorted(zip(strides, shapes, prefix_product(shapes)))

    inv_shape = []
    inv_stride = []
    expected = 1
    for stride, shape, rstride in candidates:
        if shape == 1:
            continue
        if expected != stride:
            # Gap in the codomain: nothing beyond this point is invertible.
            break
        inv_shape.append(shape)
        inv_stride.append(rstride)
        expected = shape * stride
    return coalesce(Layout(tuple(inv_shape), tuple(inv_stride)))
# Layout left inverse
def left_inverse(layout):
    """Return a left inverse of ``layout``.

    ``None`` passes through and an int ``n`` becomes ``Layout(n)``. Otherwise
    ``layout`` is paired with its ``complement`` as a second mode, and the right
    inverse of that combined layout is returned.
    """
    if layout is None:
        return None
    if is_int(layout):
        return Layout(layout)
    extended = make_layout(layout, complement(layout))
    return right_inverse(extended)
# Split a layout by the composition of B and the "rest"
# Use tuples-of-layouts to perform this operation by-mode and None as no-op
def logical_divide(layoutA, layoutB):
    """Split ``layoutA`` into the part selected by ``layoutB`` and the "rest".

    ``layoutB`` may be:
      * ``None`` -- no-op, ``layoutA`` is returned unchanged;
      * an int   -- treated as ``Layout(layoutB)``;
      * a tuple  -- applied mode-by-mode (modes of ``layoutA`` beyond
        ``len(layoutB)`` pass through unchanged);
      * a Layout -- ``layoutA`` is composed with the rank-2 layout
        ``(layoutB, complement(layoutB, size(layoutA)))``.
    """
    if layoutB is None:
        return layoutA
    if is_int(layoutB):
        return logical_divide(layoutA, Layout(layoutB))
    if is_tuple(layoutB):
        assert len(layoutA) >= len(layoutB)
        n_tiled = len(layoutB)
        divided = (logical_divide(layoutA[i], layoutB[i]) for i in range(n_tiled))
        untouched = (layoutA[i] for i in range(n_tiled, len(layoutA)))
        return make_layout(chain(divided, untouched))

    rest = complement(layoutB, size(layoutA))
    return composition(layoutA, make_layout(layoutB, rest))
# Reproduce a layoutA over a layoutB
# Use tuples-of-layouts to perform this operation by-mode and None as no-op
def logical_product(layoutA, layoutB):
    """Reproduce ``layoutA`` over ``layoutB``.

    ``layoutB`` may be:
      * ``None`` -- no-op, ``layoutA`` is returned unchanged;
      * an int   -- treated as ``Layout(layoutB)``;
      * a tuple  -- applied mode-by-mode (modes of ``layoutA`` beyond
        ``len(layoutB)`` pass through unchanged);
      * a Layout -- the result is the rank-2 layout
        ``(layoutA, composition(complement(layoutA, size(layoutA) * cosize(layoutB)), layoutB))``.
    """
    if layoutB is None:
        return layoutA
    elif is_int(layoutB):
        # Recurse into logical_product: this branch used to call
        # logical_divide, computing a division for integer tilers.
        return logical_product(layoutA, Layout(layoutB))
    elif is_tuple(layoutB):
        assert len(layoutA) >= len(layoutB)
        return make_layout(chain((logical_product(layoutA[i], layoutB[i]) for i in range(0, len(layoutB))),
                                 (layoutA[i] for i in range(len(layoutB), len(layoutA)))))

    return make_layout(layoutA, composition(complement(layoutA, size(layoutA) * cosize(layoutB)), layoutB))
# Gather the modes from a hierarchical logical_divide or logical_product
def hier_unzip(splitter, layoutA, layoutB):
    """Apply ``splitter`` hierarchically and gather the results into two modes.

    * ``layoutB is None``: returns ``(Layout(1, 0), layoutA)``.
    * ``layoutB`` not a tuple: returns ``splitter(layoutA, layoutB)``, which
      must be a rank-2 layout.
    * ``layoutB`` a tuple: recurses per mode, then regroups so that every
      first sub-mode lands in mode 0 and every second sub-mode (followed by
      the modes of ``layoutA`` beyond ``len(layoutB)``) lands in mode 1.
    """
    if layoutB is None:
        return make_layout(Layout(1, 0), layoutA)
    if not is_tuple(layoutB):
        # splitter must return a rank-2 layout
        return splitter(layoutA, layoutB)

    assert len(layoutA) >= len(layoutB)
    n_split = len(layoutB)
    # A layout with shape ((A,a),(B,b),(C,c))
    split = make_layout(hier_unzip(splitter, layoutA[i], layoutB[i]) for i in range(n_split))
    # Gather to shape ((A,B,C,...),(a,b,c,...,y,z))
    firsts = make_layout(split[i][0] for i in range(n_split))
    seconds = make_layout(chain((split[i][1] for i in range(n_split)),
                                (layoutA[i] for i in range(n_split, len(layoutA)))))
    return make_layout(firsts, seconds)
# Apply logical divide hierarchically and gather the split modes into two modes
def zipped_divide(layoutA, layoutB):
    """Return ``hier_unzip(logical_divide, layoutA, layoutB)``: a rank-2 layout
    of (gathered tile modes, gathered rest modes)."""
    return hier_unzip(logical_divide, layoutA, layoutB)


def _promote_rest_modes(zipped):
    """Turn a zipped ``(tile, rest)`` layout into ``(tile, rest[0], rest[1], ...)``."""
    tile = zipped[0]
    rest = zipped[1]
    return make_layout([tile, *(rest[i] for i in range(len(rest)))])


# Perform logical divide hierarchically and gather tiles (B-layouts) into a new mode
def tiled_divide(layoutA, layoutB):
    """Like ``zipped_divide``, but each "rest" mode becomes its own top-level mode."""
    return _promote_rest_modes(zipped_divide(layoutA, layoutB))


# Apply logical product hierarchically and gather the split modes into two modes
def zipped_product(layoutA, layoutB):
    """Return ``hier_unzip(logical_product, layoutA, layoutB)``: a rank-2 layout
    of (gathered first modes, gathered second modes)."""
    return hier_unzip(logical_product, layoutA, layoutB)


# Perform logical product hierarchically and gather tiles (B-layouts) into a new mode
def tiled_product(layoutA, layoutB):
    """Like ``zipped_product``, but each second-mode entry becomes its own top-level mode."""
    return _promote_rest_modes(zipped_product(layoutA, layoutB))
def slice_and_offset(crd: tuple,
                     layout: Layout):
    """Slice ``layout`` by ``crd`` and compute the index ``crd`` maps to.

    Returns:
        A pair ``(sublayout, offset)``. ``sublayout`` is built by applying
        ``slice_(crd, ...)`` to the layout's shape and to its stride;
        ``offset`` is ``crd2idx(crd, layout.shape, layout.stride)``.
    """
    sublayout = Layout(slice_(crd, layout.shape), slice_(crd, layout.stride))
    offset = crd2idx(crd, layout.shape, layout.stride)
    return sublayout, offset
|