Spaces:
Runtime error
Runtime error
File size: 12,388 Bytes
e202b16 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 |
#################################################################################################
#
# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################
"""
Definition of CuTe Layouts and functions to manipulate them
"""
from itertools import chain
from typing import Union
from .int_tuple import *
class LayoutBase:
    """Common base class for layout types; ``is_layout`` checks for instances of it."""
def is_layout(x):
    """Return True when *x* is an instance of LayoutBase (or a subclass), else False."""
    result = isinstance(x, LayoutBase)
    return result
class Layout(LayoutBase):
    """A layout: a (possibly nested) shape paired with a stride.

    The stride is expected to mirror the nesting of the shape: mode ``i`` of
    the layout is ``Layout(shape[i], stride[i])`` (see ``__getitem__``).
    """

    def __init__(self, _shape, _stride=None):
        self.shape = _shape
        # No stride given: derive the default strides from the shape via prefix_product.
        if _stride is None:
            self.stride = prefix_product(self.shape)
        else:
            self.stride = _stride

    # operator ==
    def __eq__(self, other):
        """Return True when *other* is a Layout with equal shape and stride.

        Comparing against a non-Layout returns NotImplemented (so ``==``
        evaluates to False) instead of raising AttributeError on
        ``other.shape``.
        """
        if not isinstance(other, Layout):
            return NotImplemented
        return self.shape == other.shape and self.stride == other.stride

    # operator len(L) (len [rank] like tuples)
    def __len__(self):
        """Return the rank: number of top-level modes (1 for a non-tuple shape)."""
        if is_tuple(self.shape):
            return len(self.shape)
        else:
            return 1

    # operator () (map coord to idx)
    def __call__(self, *args):
        """
        Map a logical coordinate to a linear index (Coord has no Underscore slice operators)
        OR
        Slice the layout and return the sublayout (Coord has an Underscore slice op)
        Follow the same behavior of `Layout::operator(Coord const&)` in cute C++

        Here any ``None`` in the coordinate (detected by ``has_none``) selects
        the slicing path. A single argument is used as the coordinate itself;
        several arguments are treated as one tuple coordinate.
        """
        if has_none(args):
            if len(args) == 1:
                return Layout(slice_(args[0], self.shape), slice_(args[0], self.stride))
            else:
                return Layout(slice_(args, self.shape), slice_(args, self.stride))
        else:
            if len(args) == 1:
                return crd2idx(args[0], self.shape, self.stride)
            else:
                return crd2idx(args, self.shape, self.stride)

    # operator [] (get-i like tuples)
    def __getitem__(self, i):
        """Return mode ``i`` as its own Layout.

        A layout with a non-tuple shape has exactly one mode, reachable as
        index 0 (or -1, matching tuple indexing). Any other index raises
        IndexError. This also terminates Python's legacy ``__getitem__``
        iteration protocol (``for mode in layout``). A bare ``assert`` is
        stripped under ``python -O``, which would make that iteration loop
        forever.
        """
        if is_tuple(self.shape):
            return Layout(self.shape[i], self.stride[i])
        if i not in (0, -1):
            raise IndexError(f"Layout index {i} out of range for rank-1 layout")
        return Layout(self.shape, self.stride)

    # size(layout) Size of the domain
    def size(self):
        """Return the size of the domain: the product of all shape extents."""
        return product(self.shape)

    # cosize(layout) Size of the codomain
    def cosize(self):
        """Return the size of the codomain: the index of the last coordinate plus one."""
        return self(self.size() - 1) + 1

    # print and str
    def __str__(self):
        return f"{self.shape}:{self.stride}"

    # error msgs and representation
    def __repr__(self):
        return f"Layout({self.shape},{self.stride})"
# Make Layout from a list of layouts (each layout its own mode in the result)
def make_layout(*layouts):
    """Build a Layout whose modes are the given layouts, in order.

    Accepts either several layouts as positional arguments or a single
    iterable (list, tuple, generator) of layouts. Mode ``i`` of the result has
    shape ``layouts[i].shape`` and stride ``layouts[i].stride``.

    With no layouts at all (e.g. an empty generator), returns the empty
    layout ``Layout((), ())``. Previously this case failed while unpacking an
    empty ``zip``.
    """
    if len(layouts) == 1 and not is_layout(layouts[0]):
        layouts = layouts[0]
    # Materialize so a generator argument can be checked for emptiness.
    layouts = tuple(layouts)
    if not layouts:
        return Layout((), ())
    shape, stride = zip(*((a.shape, a.stride) for a in layouts))
    return Layout(shape, stride)
def size(layout):
    """Return the size of the domain.

    For a layout object this is ``layout.size()``; anything else (e.g. a
    plain shape int/tuple) is passed to ``product``.
    """
    return layout.size() if is_layout(layout) else product(layout)
def cosize(layout):
    """Return the size of the codomain of *layout* by delegating to ``layout.cosize()``."""
    codomain_size = layout.cosize()
    return codomain_size
def coalesce(layout, profile=None):
    """Flatten *layout* and merge as many adjacent modes as possible.

    Layout coalesce -- flatten and combine as many modes as possible while
    preserving the int-to-int function.

    If *profile* is a tuple, mode ``i`` of *layout* is coalesced against
    ``profile[i]``. Modes beyond ``len(profile)`` are kept unchanged, and the
    pieces are reassembled with ``make_layout``.

    Otherwise the flattened modes are scanned left to right:
      * shape-1 modes are dropped;
      * a pending shape-1 placeholder is replaced by the next mode;
      * a mode whose stride equals ``prev_shape * prev_stride`` is merged
        into the previous one;
      * anything else starts a new mode.
    A single surviving mode is returned as a non-tuple Layout.
    """
    if is_tuple(profile):
        assert len(layout) >= len(profile)
        n = len(profile)
        by_mode = (coalesce(layout[i], profile[i]) for i in range(n))
        untouched = (layout[i] for i in range(n, len(layout)))
        return make_layout(chain(by_mode, untouched))

    # Start from a 1:0 placeholder that the first real mode overwrites.
    shapes = [1]
    strides = [0]
    for s, d in zip(flatten(layout.shape), flatten(layout.stride)):
        if s == 1:
            continue
        if shapes[-1] == 1:
            shapes[-1] = s
            strides[-1] = d
        elif shapes[-1] * strides[-1] == d:
            # Contiguous with the previous mode: extend it instead of adding one.
            shapes[-1] = shapes[-1] * s
        else:
            shapes.append(s)
            strides.append(d)

    if len(shapes) == 1:
        return Layout(shapes[0], strides[0])
    return Layout(tuple(shapes), tuple(strides))
def filter(layout, profile=None):
    """Remove stride-0 and shape-1 modes from *layout*, then coalesce.

    NOTE: this module-level name shadows the builtin ``filter`` within this
    module.

    If *profile* is a tuple, mode ``i`` of *layout* is filtered against
    ``profile[i]``. Modes beyond ``len(profile)`` are kept unchanged.

    Otherwise every flattened mode with shape 1 or stride 0 is discarded. If
    nothing remains the result is ``Layout(1, 0)``; else the survivors are
    passed through ``coalesce``.
    """
    if is_tuple(profile):
        assert len(layout) >= len(profile)
        n = len(profile)
        return make_layout(chain(
            (filter(layout[i], profile[i]) for i in range(n)),
            (layout[i] for i in range(n, len(layout))),
        ))

    kept = [
        (s, d)
        for s, d in zip(flatten(layout.shape), flatten(layout.stride))
        if not (s == 1 or d == 0)
    ]
    if not kept:
        return Layout(1, 0)
    shapes, strides = zip(*kept)
    return coalesce(Layout(shapes, strides))
# Layout composition
# Use tuples-of-layouts to perform this operation by-mode and None as no-op
def composition(layoutA, layoutB):
    """Compose two layouts (CuTe ``composition``).

    The result is meant to satisfy R(c) == layoutA(layoutB(c)).
    NOTE(review): the arithmetic below relies on shape_div's semantics
    from int_tuple; confirm there.

    layoutB may be:
      * None -> layoutA is returned unchanged (no-op);
      * an int -> treated as Layout(layoutB);
      * a tuple -> composed mode-by-mode with layoutA's leading modes;
        layoutA's remaining modes are appended unchanged;
      * a Layout with a tuple shape -> layoutA is composed with each mode of
        layoutB separately, and the results become the modes of the output;
      * a Layout with a non-tuple shape -> handled by the flat walk below,
        whose result is coalesced.
    """
    if layoutB is None:
        return layoutA
    elif is_int(layoutB):
        return composition(layoutA, Layout(layoutB))
    elif is_tuple(layoutB):
        assert len(layoutA) >= len(layoutB)
        return make_layout(chain((composition(layoutA[i], layoutB[i]) for i in range( 0,len(layoutB))),
                                 (layoutA[i] for i in range(len(layoutB),len(layoutA)))))
    elif is_tuple(layoutB.shape):
        # Distribute over B's modes: A o (B0, B1, ...) -> (A o B0, A o B1, ...)
        return make_layout(composition(layoutA, layoutB_i) for layoutB_i in layoutB)
    # From here on layoutB has a non-tuple shape and stride.
    if layoutB.stride == 0:
        # B sends every coordinate to index 0: the result is a stride-0 layout of B's shape.
        return Layout(layoutB.shape, 0)
    else:
        result_shape = []
        result_stride = []
        # The part of B still to be placed, expressed relative to the current A mode.
        rest_shape = layoutB.shape
        rest_stride = layoutB.stride
        # Walk A's flattened modes except the last; the last mode absorbs the remainder below.
        for (s, d) in zip(flatten(layoutA.shape)[:-1], flatten(layoutA.stride)[:-1]):
            # Extent of this A mode left after stepping by rest_stride
            # (presumably s / rest_stride with a divisibility check -- confirm in int_tuple).
            s1 = shape_div(s, rest_stride)
            # Take at most what B still needs from this mode ...
            result_shape.append(min(s1,rest_shape))
            # ... stepping by B's stride scaled by this mode's stride in A.
            result_stride.append(rest_stride * d)
            # Consume what this mode supplied. abs() presumably guards a signed
            # result from shape_div -- NOTE(review): confirm in int_tuple.
            rest_shape = shape_div(rest_shape, abs(s1))
            # Re-express B's stride in units of the next A mode.
            rest_stride = shape_div(rest_stride, s)
        # Whatever remains of B is placed in A's last mode.
        result_shape.append(rest_shape)
        result_stride.append(rest_stride * flatten(layoutA.stride)[-1])
        return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
# Layout complement
def complement(layout, max_idx=1):
    """Return the complement of *layout* extended toward *max_idx*.

    The flattened modes of *layout* are visited in ascending stride order,
    skipping stride-0 and shape-1 modes. For each remaining mode:
      1. a mode of shape ``stride // current_idx`` and stride ``current_idx``
         is emitted, covering the indices below it;
      2. ``current_idx`` then jumps to ``shape * stride``.
    A final mode of shape ``ceil_div(max_idx, current_idx)`` and stride
    ``current_idx`` is appended. The result is coalesced.

    Args:
        layout: a Layout, or an int treated as ``Layout(layout)``.
        max_idx: index bound the complement should reach; defaults to 1.

    Raises:
        AssertionError: if ``current_idx`` exceeds a mode's ``shape * stride``
            (only checked when that comparison yields a real bool, so
            symbolic values pass through).
    """
    if is_int(layout):
        # Forward max_idx. It used to be dropped here, so complement(n, m)
        # silently computed complement(n, 1).
        return complement(Layout(layout), max_idx)
    result_shape = []
    result_stride = []
    current_idx = 1
    sorted_DS = sorted(zip(flatten(layout.stride), flatten(layout.shape)))
    for (stride, shape) in sorted_DS:
        if stride == 0 or shape == 1:
            continue
        in_bound = current_idx <= shape * stride
        # To support symbolic value which can't be evaluated now
        assert (type(in_bound) is not bool) or in_bound
        result_shape.append(stride // current_idx)
        result_stride.append(current_idx)
        current_idx = shape * stride
    result_shape.append((max_idx + current_idx - 1) // current_idx)  # ceil_div
    result_stride.append(current_idx)
    return coalesce(Layout(tuple(result_shape), tuple(result_stride)))
def right_inverse(layout):
    """Build a right inverse of *layout*.

    ``None`` maps to ``None``, and an int ``n`` maps to ``Layout(n)``.
    Otherwise the flattened modes are sorted by (stride, shape, default
    stride), where the default strides come from ``prefix_product`` of the
    flattened shape, as ``Layout`` uses when no stride is given. Shape-1
    modes are skipped. Modes are then collected while each stride equals the
    running index, which starts at 1 and advances to ``shape * stride``. Each
    collected mode contributes (shape, default stride) to the result. The
    walk stops at the first gap, and the collected modes are coalesced.
    """
    if layout is None:
        return None
    if is_int(layout):
        return Layout(layout)

    flat_shape = flatten(layout.shape)
    flat_stride = flatten(layout.stride)
    ordered = sorted(zip(flat_stride, flat_shape, prefix_product(flat_shape)))

    inv_shape, inv_stride = [], []
    expected = 1
    for stride, shape, default_stride in ordered:
        if shape == 1:
            continue
        if expected != stride:
            # Gap in the index space: no further modes can be inverted.
            break
        inv_shape.append(shape)
        inv_stride.append(default_stride)
        expected = shape * stride
    return coalesce(Layout(tuple(inv_shape), tuple(inv_stride)))
def left_inverse(layout):
    """Build a left inverse of *layout*.

    ``None`` maps to ``None``, and an int ``n`` maps to ``Layout(n)``.
    Otherwise *layout* is paired with its complement (as two modes) and the
    right inverse of that combined layout is returned.
    """
    if layout is None:
        return None
    if is_int(layout):
        return Layout(layout)
    extended = make_layout(layout, complement(layout))
    return right_inverse(extended)
def logical_divide(layoutA, layoutB):
    """Split *layoutA* into the composition with *layoutB* and the "rest".

    *layoutB* may be None (no-op, returns layoutA), an int (treated as
    ``Layout(layoutB)``), or a tuple (applied by-mode to layoutA's leading
    modes; layoutA's remaining modes are kept unchanged). For a Layout B the
    result is ``composition(layoutA, (layoutB, complement(layoutB, size(layoutA))))``.
    """
    if layoutB is None:
        return layoutA
    if is_int(layoutB):
        return logical_divide(layoutA, Layout(layoutB))
    if is_tuple(layoutB):
        assert len(layoutA) >= len(layoutB)
        n = len(layoutB)
        divided = (logical_divide(layoutA[i], layoutB[i]) for i in range(n))
        remaining = (layoutA[i] for i in range(n, len(layoutA)))
        return make_layout(chain(divided, remaining))
    # Mode 0 selects the tile described by B; mode 1 (B's complement within size(A)) is the rest.
    tiler = make_layout(layoutB, complement(layoutB, size(layoutA)))
    return composition(layoutA, tiler)
# Reproduce a layoutA over a layoutB
# Use tuples-of-layouts to perform this operation by-mode and None as no-op
def logical_product(layoutA, layoutB):
    """Reproduce *layoutA* over *layoutB*.

    *layoutB* may be:
      * None -> layoutA is returned unchanged (no-op);
      * an int -> treated as ``Layout(layoutB)``. This previously dispatched
        to logical_divide by mistake;
      * a tuple -> applied by-mode to layoutA's leading modes, with
        layoutA's remaining modes kept unchanged.

    For a Layout B the result has two modes: layoutA itself, and
    ``composition(complement(layoutA, size(layoutA) * cosize(layoutB)), layoutB)``.
    """
    if layoutB is None:
        return layoutA
    elif is_int(layoutB):
        # Wrap the int and recurse into logical_product (not logical_divide).
        return logical_product(layoutA, Layout(layoutB))
    elif is_tuple(layoutB):
        assert len(layoutA) >= len(layoutB)
        return make_layout(chain((logical_product(layoutA[i], layoutB[i]) for i in range(0, len(layoutB))),
                                 (layoutA[i] for i in range(len(layoutB), len(layoutA)))))
    # Mode 1: layoutB composed onto the part of the index space (up to
    # size(A) * cosize(B)) that layoutA does not cover.
    rest = composition(complement(layoutA, size(layoutA) * cosize(layoutB)), layoutB)
    return make_layout(layoutA, rest)
def hier_unzip(splitter, layoutA, layoutB):
    """Apply *splitter* hierarchically and gather the results into two modes.

    *splitter* (e.g. logical_divide / logical_product) must return a rank-2
    layout (tile, rest).
      * layoutB None  -> ``(Layout(1, 0), layoutA)``: trivial tile, all of A as rest.
      * layoutB tuple -> recurse per mode, then regroup
        ``((A,a),(B,b),(C,c))`` into ``((A,B,C,...),(a,b,c,...,y,z))``,
        where y, z are layoutA's modes beyond ``len(layoutB)``.
      * otherwise     -> ``splitter(layoutA, layoutB)``.
    """
    if layoutB is None:
        return make_layout(Layout(1, 0), layoutA)
    if not is_tuple(layoutB):
        return splitter(layoutA, layoutB)

    assert len(layoutA) >= len(layoutB)
    n = len(layoutB)
    # split[i] is the rank-2 (tile_i, rest_i) result for mode i.
    split = make_layout([hier_unzip(splitter, layoutA[i], layoutB[i]) for i in range(n)])
    tiles = make_layout([split[i][0] for i in range(n)])
    rests = [split[i][1] for i in range(n)]
    rests.extend(layoutA[i] for i in range(n, len(layoutA)))
    return make_layout(tiles, make_layout(rests))
def zipped_divide(layoutA, layoutB):
    """Apply logical_divide hierarchically and gather the pieces into two modes (tiles, rest)."""
    splitter = logical_divide
    return hier_unzip(splitter, layoutA, layoutB)
def tiled_divide(layoutA, layoutB):
    """Like zipped_divide, but unpack the rest mode: returns (tile, rest_0, rest_1, ...)."""
    zipped = zipped_divide(layoutA, layoutB)
    tile, rest = zipped[0], zipped[1]
    modes = [tile]
    modes.extend(rest[i] for i in range(len(rest)))
    return make_layout(modes)
def zipped_product(layoutA, layoutB):
    """Apply logical_product hierarchically and gather the pieces into two modes."""
    splitter = logical_product
    return hier_unzip(splitter, layoutA, layoutB)
def tiled_product(layoutA, layoutB):
    """Like zipped_product, but unpack the second mode: returns (mode0, rest_0, rest_1, ...)."""
    zipped = zipped_product(layoutA, layoutB)
    first, rest = zipped[0], zipped[1]
    modes = [first]
    modes.extend(rest[i] for i in range(len(rest)))
    return make_layout(modes)
def slice_and_offset(crd: tuple,
                     layout: Layout):
    """Split *layout* at coordinate *crd* into a sub-layout and an offset.

    Returns ``(sublayout, offset)``, where the sub-layout is
    ``Layout(slice_(crd, shape), slice_(crd, stride))`` and the offset is
    ``crd2idx(crd, shape, stride)``.
    """
    shape, stride = layout.shape, layout.stride
    sublayout = Layout(slice_(crd, shape), slice_(crd, stride))
    offset = crd2idx(crd, shape, stride)
    return sublayout, offset
|