// NOTE(review): the leading integers on these lines ("3", "16", "19", ...)
// look like original-file line numbers fused in by an extraction tool, and
// many source lines are missing between them; this text does not compile
// as-is. Compare against the complete header before changing any logic.
3 #if defined(__cplusplus)
16 #if defined(__cplusplus)
// Forward declarations of the scalar float <-> __half conversion helpers.
// The __half constructors, assignment operators and `operator float` call
// them before their definitions appear later in the file.
19 __half __float2half(
float);
20 float __half2float(__half);
// Implicit converting constructors. They are compiled only when
// __HIP_NO_HALF_CONVERSIONS__ is not defined.
30 #if !defined(__HIP_NO_HALF_CONVERSIONS__)
// From float: round to nearest (ties to even) via __float2half, keeping only
// the resulting raw bits.
31 __half(
float x) : __x{__float2half(x).__x} {}
// From double. NOTE(review): the only visible __float2half overload takes
// float, so x is first narrowed double -> float and then rounded
// float -> half. This double rounding can differ from a single correctly
// rounded double -> half conversion. Example: 1 + 2^-11 + 2^-40 first becomes
// the exact half midpoint 1 + 2^-11, then ties-to-even rounds it down to 1.0.
// Confirm whether that is acceptable.
32 __half(
double x) : __x{__float2half(x).__x} {}
// Compiler-generated copy/move construction and assignment. These are plain
// member-wise copies of the stored representation (the visible member is the
// raw storage `__x`), with no conversion or rounding.
34 __half(
const __half&) =
default;
35 __half(__half&&) =
default;
39 __half& operator=(
const __half&) =
default;
40 __half& operator=(__half&&) =
default;
// Assign from the raw representation: copies x.x into __x unchanged (no
// conversion or rounding) and returns *this to allow chaining.
41 __half& operator=(
const __half_raw& x) { __x = x.x;
return *
this; }
// Converting assignments. They are compiled only when
// __HIP_NO_HALF_CONVERSIONS__ is not defined.
// NOTE(review): this excerpt is missing the function braces and the float
// overload's return statement (original lines 44, 46-47 and 49 are absent).
42 #if !defined(__HIP_NO_HALF_CONVERSIONS__)
// Assign from float: round to nearest (ties to even) via __float2half and
// store the resulting raw bits.
43 __half& operator=(
float x)
45 __x = __float2half(x).__x;
// Assign from double: narrow to float, then reuse the float overload. This
// rounds twice (double -> float -> half), which can differ from a single
// correctly rounded double -> half conversion in rare near-tie cases.
48 __half& operator=(
double x)
50 return *
this =
static_cast<float>(x);
// Conversion to float, delegating to __half2float. It is not marked
// `explicit`, so a __half converts to float implicitly wherever a float is
// expected.
55 operator float()
const {
return __half2float(*
this); }
// Fragment of a __half2 constructor's member-initializer list. Its signature
// is not in this excerpt; presumably `ix` is a __half2_raw (TODO confirm).
// Each field of ix is reinterpreted in place as a __half. This assumes
// __half's object representation is exactly that field.
// NOTE(review): reinterpret_cast-based type punning relies on the types being
// layout-compatible. Consider copying the bits (std::memcpy) instead.
70 x{
reinterpret_cast<const __half&
>(ix.x)},
71 y{
reinterpret_cast<const __half&
>(ix.y)}
// Build a pair from two halves: ix is stored in member x and iy in member y.
73 __half2(
const __half& ix,
const __half& iy) : x{ix}, y{iy} {}
// Compiler-generated copy/move construction and assignment. They copy the
// x and y members as-is.
74 __half2(
const __half2&) =
default;
75 __half2(__half2&&) =
default;
79 __half2& operator=(
const __half2&) =
default;
80 __half2& operator=(__half2&&) =
default;
// Fragment whose start is not in this excerpt: a braced initializer that
// reinterprets the x and y members as their unsigned short bit patterns.
// Presumably it builds the __half2_raw returned by a conversion operator
// (TODO confirm against the full header).
92 reinterpret_cast<const unsigned short&
>(x),
93 reinterpret_cast<const unsigned short&
>(y)};
// Truncating float -> binary16 conversion shared by the __float2half*
// rounding variants.
//   flt : value to convert.
//   sgn : out - the float's sign bit moved to the half sign position
//         (0 or 0x8000).
//   rem : out - the bits discarded by truncation, left-aligned in 32 bits.
//         0x80000000 therefore means "exactly half an ULP", which is what
//         the rounding callers compare against.
// Returns the half bit pattern rounded toward zero. Callers apply any
// rounding increment.
// NOTE(review): many original lines (e.g. 103-104, 106, 123-124, and the
// closing brace) are missing from this excerpt. Only the subnormal path
// visibly assigns `rem`. Callers read `rem` on every path, so the missing
// lines presumably assign it in the other branches too. Confirm this.
101 unsigned short __internal_float2half(
102 float flt,
unsigned int& sgn,
unsigned int& rem)
// Read the float's bit pattern via memcpy (well-defined type punning).
105 std::memcpy(&x, &flt,
sizeof(flt));
// Magnitude bits (sign cleared) and the sign moved from bit 31 to bit 15.
107 unsigned int u = (x & 0x7fffffffU);
108 sgn = ((x >> 16) & 0x8000U);
// Inf/NaN: infinity keeps its sign (0x7c00 | sgn). Every NaN becomes the
// fixed pattern 0x7fff, with the sign dropped.
111 if (u >= 0x7f800000U) {
113 return static_cast<unsigned short>(
114 (u == 0x7f800000U) ? (sgn | 0x7c00U) : 0x7fffU);
// Overflow: |flt| >= 65520.0f (0x477ff000). That is at or beyond the midpoint
// between the largest finite half (65504, 0x7bff) and 2^16. Return the
// largest finite magnitude. Rounding callers can step it to infinity
// (0x7c00) through `rem`, whose assignment is not shown here.
117 if (u > 0x477fefffU) {
119 return static_cast<unsigned short>(sgn | 0x7bffU);
// Normal half range: |flt| >= 2^-14 (0x38800000), the smallest normal half.
// NOTE(review): as shown, `u >> 13` is not re-biased from the float exponent
// bias (127) to the half bias (15), so it would not fit in 16 bits. The
// missing original lines 123-124 presumably adjust `u` and set `rem`.
// Confirm against the full file.
122 if (u >= 0x38800000U) {
125 return static_cast<unsigned short>(sgn | (u >> 13));
// |flt| <= 2^-25 (0x33000000), half the smallest subnormal half (2^-24):
// truncates to a signed zero.
128 if (u < 0x33000001U) {
130 return static_cast<unsigned short>(sgn);
// Subnormal half result: restore the implicit leading 1 of the float
// significand, then shift right so that one unit equals 2^-24. For u in this
// branch the biased float exponent is 102..112, so `shift` is 14..24 and
// both shift counts below are in range.
133 unsigned int exponent = u >> 23;
134 unsigned int mantissa = (u & 0x7fffffU);
135 unsigned int shift = 0x7eU - exponent;
136 mantissa |= 0x800000U;
// Keep the shifted-out bits, left-aligned, for the caller's rounding.
137 rem = mantissa << (32 - shift);
138 return static_cast<unsigned short>(sgn | (mantissa >> shift));
// float -> __half with IEEE round-to-nearest, ties-to-even.
// NOTE(review): the declarations of r, sgn and rem, the braces and the return
// statement are not in this excerpt.
142 __half __float2half(
float x)
// Truncate toward zero. Then round the magnitude up if the discarded part is
// more than half an ULP, or exactly half and the truncated result is odd.
// Incrementing the raw pattern carries from the mantissa into the exponent,
// so for example 0x7bff + 1 yields infinity (0x7c00).
147 r.x = __internal_float2half(x, sgn, rem);
148 if (rem > 0x80000000U || (rem == 0x80000000U && (r.x & 0x1))) ++r.x;
// float -> __half, round to nearest (ties to even). This is an alias for
// __float2half.
154 __half __float2half_rn(
float x) {
return __float2half(x); }
// float -> __half, round toward zero. It keeps the truncated result of
// __internal_float2half. The visible lines apply no rounding increment; the
// surrounding lines (declarations, return) are missing from this excerpt.
157 __half __float2half_rz(
float x)
162 r.x = __internal_float2half(x, sgn, rem);
// float -> __half, round toward negative infinity.
168 __half __float2half_rd(
float x)
173 r.x = __internal_float2half(x, sgn, rem);
// The half encoding is sign-magnitude, so ++r.x moves away from zero. For a
// negative input with a non-zero discarded part, that moves the value toward
// -infinity. Positive inputs keep the truncated (smaller) value.
174 if (rem && sgn) ++r.x;
// float -> __half, round toward positive infinity.
180 __half __float2half_ru(
float x)
185 r.x = __internal_float2half(x, sgn, rem);
// For a positive input with a non-zero discarded part, increase the magnitude
// (toward +infinity). Negative inputs keep the truncated value, which is
// already the larger of the two candidates.
186 if (rem && !sgn) ++r.x;
// Broadcast: round x to nearest (ties to even) and store the same half in
// both members of the pair.
192 __half2 __float2half2_rn(
float x)
194 return __half2{__float2half_rn(x), __float2half_rn(x)};
// Pack two floats, each rounded to nearest (ties to even). x becomes the
// pair's first member and y the second.
198 __half2 __floats2half2_rn(
float x,
float y)
200 return __half2{__float2half_rn(x), __float2half_rn(y)};
// Decode a binary16 bit pattern into a float.
// NOTE(review): many original lines (e.g. 212, 214-217, 219-221, 223-226,
// 228, and 230 onward) are missing from this excerpt, including the exponent
// handling and the subnormal normalization. The comments below describe only
// the visible lines.
// NOTE(review): this uses unqualified `memcpy`, whereas __internal_float2half
// uses std::memcpy. ::memcpy is only guaranteed to exist if <string.h> is
// included.
204 float __internal_half2float(
unsigned short x)
// Split the half into sign (1 bit), exponent (5 bits) and mantissa
// (10 bits, pre-shifted left by 13 into the float mantissa position).
206 unsigned int sign = ((x >> 15) & 1);
207 unsigned int exponent = ((x >> 10) & 0x1f);
208 unsigned int mantissa = ((x & 0x3ff) << 13);
// Inf/NaN (exponent all ones): infinity gets a zero mantissa. Any NaN gets
// mantissa 0x7fffff and sign 0. The comma expression clears `sign` before
// yielding 0x7fffff.
210 if (exponent == 0x1fU) {
211 mantissa = (mantissa ? (sign = 0, 0x7fffffU) : 0);
213 }
// Zero/subnormal half. The visible lines test bit 0x400000 (the half's top
// mantissa bit after the << 13 shift) and finally mask with 0x7fffff,
// presumably to drop the implicit bit produced by normalization.
else if (!exponent) {
218 msb = (mantissa & 0x400000U);
222 mantissa &= 0x7fffffU;
// Reassemble the float bit pattern and copy it out with memcpy
// (well-defined type punning).
227 unsigned int u = ((sign << 31) | (exponent << 23) | mantissa);
229 memcpy(&f, &u,
sizeof(u));
// __half -> float: get the raw bits through a conversion to __half_raw
// (defined outside this excerpt), then decode them with
// __internal_half2float.
235 float __half2float(__half x)
237 return __internal_half2float(
static_cast<__half_raw>(x).x);
// Convert the low half of the pair (the raw `x` field) to float.
241 float __low2float(__half2 x)
243 return __internal_half2float(
static_cast<__half2_raw>(x).x);
// Convert the high half of the pair (the raw `y` field) to float.
247 float __high2float(__half2 x)
249 return __internal_half2float(
static_cast<__half2_raw>(x).y);
// Unprefixed `half2` alias for __half2, suppressed when HIP_NO_HALF is
// defined.
// NOTE(review): original lines 254 and 256 are missing here. They presumably
// hold a matching `half` alias and this block's own #endif; the #endif shown
// closes the __cplusplus guard.
253 #if !defined(HIP_NO_HALF)
255 using half2 = __half2;
257 #endif // defined(__cplusplus)