|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <stdint.h>
|
|
#include <cuda_fp16.h>
|
|
#include <iosfwd>
|
|
|
|
#include <cub/util_type.cuh>
|
|
|
|
// The half_t conversions in this header reinterpret the storage of one type as
// another through reinterpret_cast, which GCC reports under -Wstrict-aliasing.
// Save the current diagnostic state and silence that warning; the saved state
// is restored by the `#pragma GCC diagnostic pop` at the end of this file.
#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Software half-precision value usable from both host and device code.
 *
 * The value is stored as a raw 16-bit pattern (1 sign bit, 5 exponent bits,
 * 10 mantissa bits).  Conversions to and from float are done with integer bit
 * manipulation, and arithmetic / ordering operators work by converting the
 * operands to float and (for arithmetic) rounding the result back.
 */
struct half_t
{
    /// Raw 16-bit storage word
    uint16_t __x;

    /// Constructor from __half: copies the 16 storage bits unchanged
    __host__ __device__ __forceinline__
    half_t(const __half &other)
    {
        __x = reinterpret_cast<const uint16_t&>(other);
    }

    /// Constructor from integer: converts to float first, then rounds to half
    __host__ __device__ __forceinline__
    half_t(int a)
    {
        *this = half_t(float(a));
    }

    /// Default constructor: all storage bits zero
    __host__ __device__ __forceinline__
    half_t() : __x(0)
    {}

    /// Constructor from float: rounds to nearest, ties to even
    __host__ __device__ __forceinline__
    half_t(float a)
    {
        // Work on the float's raw 32-bit pattern
        uint32_t ia = *reinterpret_cast<uint32_t*>(&a);
        uint16_t ir;

        // Move the sign from bit 31 down to bit 15
        ir = (ia >> 16) & 0x8000;

        if ((ia & 0x7f800000) == 0x7f800000)
        {
            // Float exponent is all ones: infinity or NaN
            if ((ia & 0x7fffffff) == 0x7f800000)
            {
                // Infinity keeps its sign
                ir |= 0x7c00;
            }
            else
            {
                // Any NaN becomes the single pattern 0x7fff (sign is dropped)
                ir = 0x7fff;
            }
        }
        else if ((ia & 0x7f800000) >= 0x33000000)
        {
            // Biased float exponent >= 102, i.e. magnitude >= 2^-25.  Anything
            // smaller skips this branch and stays a signed zero.
            int32_t shift = (int32_t) ((ia >> 23) & 0xff) - 127;
            if (shift > 15)
            {
                // Exponent too large for half: overflow to signed infinity
                ir |= 0x7c00;
            }
            else
            {
                // Keep the 23 mantissa bits and make the implicit leading one explicit
                ia = (ia & 0x007fffff) | 0x00800000;
                if (shift < -14)
                {
                    // Result is a half subnormal: shift the significand into
                    // place and left-align the discarded bits in ia for rounding
                    ir |= ia >> (-1 - shift);
                    ia = ia << (32 - (-1 - shift));
                }
                else
                {
                    // Result is a half normal: keep the top 11 significand bits
                    // and left-align the 13 discarded bits in ia for rounding
                    ir |= ia >> (24 - 11);
                    ia = ia << (32 - (24 - 11));
                    // The explicit leading one sits in bit 10 and supplies +1 to
                    // the exponent field, hence 14 + shift rather than 15 + shift
                    ir = ir + ((14 + shift) << 10);
                }

                // Round to nearest; on an exact tie round to the even result.
                // A carry out of the mantissa propagates into the exponent.
                if ((ia > 0x80000000) || ((ia == 0x80000000) && (ir & 1)))
                {
                    ir++;
                }
            }
        }

        this->__x = ir;
    }

    /// Cast to __half: reinterprets the 16 storage bits
    __host__ __device__ __forceinline__
    operator __half() const
    {
        return reinterpret_cast<const __half&>(__x);
    }

    /// Cast to float
    __host__ __device__ __forceinline__
    operator float() const
    {
        // Unsigned fields so that moving the sign into bit 31 never shifts
        // into the sign bit of a signed int
        uint32_t sign     = (this->__x >> 15) & 1;
        uint32_t exp      = (this->__x >> 10) & 0x1f;
        uint32_t mantissa = (this->__x & 0x3ff);
        uint32_t f        = 0;

        if (exp > 0 && exp < 31)
        {
            // Normal number: re-bias the exponent (127 - 15 = 112) and widen
            // the mantissa from 10 to 23 bits
            exp += 112;
            f = (sign << 31) | (exp << 23) | (mantissa << 13);
        }
        else if (exp == 0)
        {
            if (mantissa)
            {
                // Subnormal half: normalize by shifting the mantissa up until
                // its leading one reaches bit 10, lowering the exponent each step
                exp += 113;
                while ((mantissa & (1u << 10)) == 0)
                {
                    mantissa <<= 1;
                    exp--;
                }
                // Drop the now-implicit leading one
                mantissa &= 0x3ff;
                f = (sign << 31) | (exp << 23) | (mantissa << 13);
            }
            else if (sign)
            {
                // Negative zero
                f = 0x80000000;
            }
            else
            {
                // Positive zero
                f = 0x0;
            }
        }
        else if (exp == 31)
        {
            if (mantissa)
            {
                // Any NaN becomes the single float pattern 0x7fffffff
                f = 0x7fffffff;
            }
            else
            {
                // Signed infinity
                f = (0xffu << 23) | (sign << 31);
            }
        }
        return *reinterpret_cast<float const *>(&f);
    }

    /// Get raw storage
    __host__ __device__ __forceinline__
    uint16_t raw() const
    {
        return this->__x;
    }

    /// Equality: compares the raw storage words (so +0 and -0 differ, and two
    /// identical NaN patterns compare equal)
    __host__ __device__ __forceinline__
    bool operator ==(const half_t &other) const
    {
        return (this->__x == other.__x);
    }

    /// Inequality: compares the raw storage words
    __host__ __device__ __forceinline__
    bool operator !=(const half_t &other) const
    {
        return (this->__x != other.__x);
    }

    /// Assignment by sum: adds in float, then rounds back to half
    __host__ __device__ __forceinline__
    half_t& operator +=(const half_t &rhs)
    {
        *this = half_t(float(*this) + float(rhs));
        return *this;
    }

    /// Multiply: multiplies in float, then rounds back to half
    __host__ __device__ __forceinline__
    half_t operator*(const half_t &other) const
    {
        return half_t(float(*this) * float(other));
    }

    /// Add: adds in float, then rounds back to half
    __host__ __device__ __forceinline__
    half_t operator+(const half_t &other) const
    {
        return half_t(float(*this) + float(other));
    }

    /// Less-than, evaluated on the float values
    __host__ __device__ __forceinline__
    bool operator<(const half_t &other) const
    {
        return float(*this) < float(other);
    }

    /// Less-than-equal, evaluated on the float values
    __host__ __device__ __forceinline__
    bool operator<=(const half_t &other) const
    {
        return float(*this) <= float(other);
    }

    /// Greater-than, evaluated on the float values
    __host__ __device__ __forceinline__
    bool operator>(const half_t &other) const
    {
        return float(*this) > float(other);
    }

    /// Greater-than-equal, evaluated on the float values
    __host__ __device__ __forceinline__
    bool operator>=(const half_t &other) const
    {
        return float(*this) >= float(other);
    }

    /// Value with storage word 0x7BFF (sign 0, exponent 30, mantissa all ones).
    /// The name is parenthesized so a function-like `max` macro cannot expand it.
    __host__ __device__ __forceinline__
    static half_t (max)()
    {
        half_t result;
        result.__x = 0x7BFF;
        return result;
    }

    /// Value with storage word 0xFBFF (sign 1, exponent 30, mantissa all ones)
    __host__ __device__ __forceinline__
    static half_t lowest()
    {
        half_t result;
        result.__x = 0xFBFF;
        return result;
    }
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// Insert a \p half_t into the output stream, formatted as its float value.
/// Declared inline because this definition lives in a header: without it,
/// every translation unit including the header would emit its own external
/// definition and the program would fail to link.
inline std::ostream& operator<<(std::ostream &out, const half_t &x)
{
    out << (float)x;
    return out;
}
|
|
|
|
|
|
|
|
/// Insert a \p __half into the output stream by wrapping it in a half_t and
/// printing that.  Declared inline because this definition lives in a header:
/// without it, every translation unit including the header would emit its own
/// external definition and the program would fail to link.
inline std::ostream& operator<<(std::ostream &out, const __half &x)
{
    return out << half_t(x);
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/// cub::FpLimits specialization for half_t; both queries forward to the
/// static members of half_t.
template <>
struct cub::FpLimits<half_t>
{
    /// Returns half_t::max()
    static __host__ __device__ __forceinline__ half_t Max()
    {
        const half_t upper_bound = half_t::max();
        return upper_bound;
    }

    /// Returns half_t::lowest()
    static __host__ __device__ __forceinline__ half_t Lowest()
    {
        const half_t lower_bound = half_t::lowest();
        return lower_bound;
    }
};
|
|
|
|
/// cub::NumericTraits specialization for half_t.  All behavior is inherited
/// from cub::BaseTraits instantiated with (FLOATING_POINT, true, false,
/// unsigned short, half_t).
/// NOTE(review): presumably these are the type category, a "primitive" flag,
/// a "null type" flag, the unsigned bit-pattern type, and the value type --
/// confirm against the BaseTraits declaration in cub/util_type.cuh.
template <>
struct cub::NumericTraits<half_t>
    : cub::BaseTraits<FLOATING_POINT, true, false, unsigned short, half_t>
{
};
|
|
|
|
|
|
// Restore the GCC diagnostic state that was saved by the
// `#pragma GCC diagnostic push` near the top of this file.
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif
|
|