24 #ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP16_H
25 #define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP16_H
27 #include <hip/amd_detail/hip_common.h>
31 #if defined(__cplusplus)
33 #include <type_traits>
37 #if __HIP_CLANG_ONLY__
// Two-element vector of _Float16, declared with the ext_vector_type attribute.
// NOTE(review): kept as a typedef rather than a `using` alias because the
// `#if defined(__cplusplus)` re-check further down suggests this region may
// also be seen by non-C++ translation units — confirm before modernizing.
37 typedef _Float16 _Float16_2 __attribute__((ext_vector_type(2)));
42 static_assert(
sizeof(_Float16) ==
sizeof(
unsigned short),
"");
51 static_assert(
sizeof(_Float16_2) ==
sizeof(
unsigned short[2]),
"");
61 #if defined(__cplusplus)
62 #include "hip_fp16_math_fwd.h"
63 #include "hip_vector_types.h"
// Explicit specialization that makes is_floating_point<_Float16> derive from
// std::true_type. std::is_floating_point<T> is the trait the
// Enable_if_t-constrained __half members in this file test.
// NOTE(review): presumably this sits inside `namespace std` (the enclosing
// lines are not visible here) — confirm.
68 template<>
struct is_floating_point<_Float16> : std::true_type {};
// SFINAE shorthand around std::enable_if: names the second argument
// (default `void`) when the condition is true, and is ill-formed otherwise so
// the declaration using it drops out of overload resolution.
template <bool Condition, typename Result = void>
using Enable_if_t = typename std::enable_if<Condition, Result>::type;
78 static_assert(
sizeof(_Float16) ==
sizeof(
unsigned short),
"");
89 #if !defined(__HIP_NO_HALF_CONVERSIONS__)
// Constructs a __half directly from a value of the underlying storage type
// (decltype(data)); the value is stored as-is, with no cast.
91 __half(decltype(data) x) : data{x} {}
94 Enable_if_t<std::is_floating_point<T>{}>* =
nullptr>
96 __half(T x) : data{
static_cast<_Float16
>(x)} {}
// Copy and move construction use the compiler-generated implementations.
99 __half(
const __half&) =
default;
101 __half(__half&&) =
default;
106 #if !defined(__HIP_NO_HALF_CONVERSIONS__)
108 typename T, Enable_if_t<std::is_integral<T>{}>* =
nullptr>
110 __half(T x) : data{
static_cast<_Float16
>(x)} {}
// Copy and move assignment use the compiler-generated implementations.
115 __half& operator=(
const __half&) =
default;
117 __half& operator=(__half&&) =
default;
125 volatile __half& operator=(
const __half_raw& x)
volatile
130 volatile __half& operator=(
const volatile __half_raw& x)
volatile
140 volatile __half& operator=(
__half_raw&& x)
volatile
145 volatile __half& operator=(
volatile __half_raw&& x)
volatile
150 #if !defined(__HIP_NO_HALF_CONVERSIONS__)
153 Enable_if_t<std::is_floating_point<T>{}>* =
nullptr>
155 __half& operator=(T x)
157 data =
static_cast<_Float16
>(x);
163 #if !defined(__HIP_NO_HALF_CONVERSIONS__)
165 typename T, Enable_if_t<std::is_integral<T>{}>* =
nullptr>
167 __half& operator=(T x)
169 data =
static_cast<_Float16
>(x);
174 #if !defined(__HIP_NO_HALF_OPERATORS__)
176 __half& operator+=(
const __half& x)
182 __half& operator-=(
const __half& x)
188 __half& operator*=(
const __half& x)
194 __half& operator/=(
const __half& x)
200 __half& operator++() { ++data;
return *
this; }
202 __half operator++(
int)
209 __half& operator--() { --data;
return *
this; }
211 __half operator--(
int)
220 #if !defined(__HIP_NO_HALF_CONVERSIONS__)
223 Enable_if_t<std::is_floating_point<T>{}>* =
nullptr>
225 operator T()
const {
return data; }
235 #if !defined(__HIP_NO_HALF_CONVERSIONS__)
237 typename T, Enable_if_t<std::is_integral<T>{}>* =
nullptr>
239 operator T()
const {
return data; }
242 #if !defined(__HIP_NO_HALF_OPERATORS__)
244 __half operator+()
const {
return *
this; }
246 __half operator-()
const
249 tmp.data = -tmp.data;
255 #if !defined(__HIP_NO_HALF_OPERATORS__)
259 __half operator+(
const __half& x,
const __half& y)
261 return __half{x} += y;
266 __half operator-(
const __half& x,
const __half& y)
268 return __half{x} -= y;
273 __half operator*(
const __half& x,
const __half& y)
275 return __half{x} *= y;
280 __half operator/(
const __half& x,
const __half& y)
282 return __half{x} /= y;
287 bool operator==(
const __half& x,
const __half& y)
289 return x.data == y.data;
294 bool operator!=(
const __half& x,
const __half& y)
301 bool operator<(
const __half& x,
const __half& y)
303 return x.data < y.data;
308 bool operator>(
const __half& x,
const __half& y)
310 return y.data < x.data;
315 bool operator<=(
const __half& x,
const __half& y)
322 bool operator>=(
const __half& x,
const __half& y)
326 #endif // !defined(__HIP_NO_HALF_OPERATORS__)
335 sizeof(_Float16_2) ==
sizeof(
unsigned short[2]),
"");
// Constructs a __half2 directly from a value of the underlying storage type
// (decltype(data)); the value is stored as-is.
350 __half2(decltype(data) x) : data{x} {}
352 __half2(
const __half& x,
const __half& y)
// Copy/move construction, destruction and copy/move assignment all use the
// compiler-generated implementations.
359 __half2(
const __half2&) =
default;
361 __half2(__half2&&) =
default;
363 ~__half2() =
default;
367 __half2& operator=(
const __half2&) =
default;
369 __half2& operator=(__half2&&) =
default;
378 #if !defined(__HIP_NO_HALF_OPERATORS__)
380 __half2& operator+=(
const __half2& x)
386 __half2& operator-=(
const __half2& x)
392 __half2& operator*=(
const __half2& x)
398 __half2& operator/=(
const __half2& x)
404 __half2& operator++() {
return *
this += _Float16_2{1, 1}; }
406 __half2 operator++(
int)
413 __half2& operator--() {
return *
this -= _Float16_2{1, 1}; }
415 __half2 operator--(
int)
425 operator decltype(data)()
const {
return data; }
430 #if !defined(__HIP_NO_HALF_OPERATORS__)
432 __half2 operator+()
const {
return *
this; }
434 __half2 operator-()
const
437 tmp.data = -tmp.data;
443 #if !defined(__HIP_NO_HALF_OPERATORS__)
447 __half2 operator+(
const __half2& x,
const __half2& y)
449 return __half2{x} += y;
454 __half2 operator-(
const __half2& x,
const __half2& y)
456 return __half2{x} -= y;
461 __half2 operator*(
const __half2& x,
const __half2& y)
463 return __half2{x} *= y;
468 __half2 operator/(
const __half2& x,
const __half2& y)
470 return __half2{x} /= y;
475 bool operator==(
const __half2& x,
const __half2& y)
477 auto r = x.data == y.data;
478 return r.x != 0 && r.y != 0;
483 bool operator!=(
const __half2& x,
const __half2& y)
490 bool operator<(
const __half2& x,
const __half2& y)
492 auto r = x.data < y.data;
493 return r.x != 0 && r.y != 0;
498 bool operator>(
const __half2& x,
const __half2& y)
505 bool operator<=(
const __half2& x,
const __half2& y)
512 bool operator>=(
const __half2& x,
const __half2& y)
516 #endif // !defined(__HIP_NO_HALF_OPERATORS__)
524 __half2 make_half2(__half x, __half y)
526 return __half2{x, y};
531 __half __low2half(__half2 x)
538 __half __high2half(__half2 x)
545 __half2 __half2half2(__half x)
547 return __half2{x, x};
552 __half2 __halves2half2(__half x, __half y)
554 return __half2{x, y};
559 __half2 __low2half2(__half2 x)
569 __half2 __high2half2(__half2 x)
579 __half2 __lows2half2(__half2 x, __half2 y)
589 __half2 __highs2half2(__half2 x, __half2 y)
599 __half2 __lowhigh2highlow(__half2 x)
610 short __half_as_short(__half x)
617 unsigned short __half_as_ushort(__half x)
624 __half __short_as_half(
short x)
632 __half __ushort_as_half(
unsigned short x)
642 __half __float2half(
float x)
648 __half __float2half_rn(
float x)
654 __half __float2half_rz(
float x)
660 __half __float2half_rd(
float x)
666 __half __float2half_ru(
float x)
672 __half2 __float2half2_rn(
float x)
676 static_cast<_Float16
>(x),
static_cast<_Float16
>(x)}};
680 __half2 __floats2half2_rn(
float x,
float y)
683 static_cast<_Float16
>(x),
static_cast<_Float16
>(y)}};
687 __half2 __float22half2_rn(
float2 x)
689 return __floats2half2_rn(x.x, x.y);
695 float __half2float(__half x)
701 float __low2float(__half2 x)
707 float __high2float(__half2 x)
713 float2 __half22float2(__half2 x)
723 int __half2int_rn(__half x)
729 int __half2int_rz(__half x)
735 int __half2int_rd(__half x)
741 int __half2int_ru(__half x)
749 __half __int2half_rn(
int x)
755 __half __int2half_rz(
int x)
761 __half __int2half_rd(
int x)
767 __half __int2half_ru(
int x)
775 short __half2short_rn(__half x)
781 short __half2short_rz(__half x)
787 short __half2short_rd(__half x)
793 short __half2short_ru(__half x)
801 __half __short2half_rn(
short x)
807 __half __short2half_rz(
short x)
813 __half __short2half_rd(
short x)
819 __half __short2half_ru(
short x)
827 long long __half2ll_rn(__half x)
833 long long __half2ll_rz(__half x)
839 long long __half2ll_rd(__half x)
845 long long __half2ll_ru(__half x)
853 __half __ll2half_rn(
long long x)
859 __half __ll2half_rz(
long long x)
865 __half __ll2half_rd(
long long x)
871 __half __ll2half_ru(
long long x)
879 unsigned int __half2uint_rn(__half x)
885 unsigned int __half2uint_rz(__half x)
891 unsigned int __half2uint_rd(__half x)
897 unsigned int __half2uint_ru(__half x)
905 __half __uint2half_rn(
unsigned int x)
911 __half __uint2half_rz(
unsigned int x)
917 __half __uint2half_rd(
unsigned int x)
923 __half __uint2half_ru(
unsigned int x)
931 unsigned short __half2ushort_rn(__half x)
937 unsigned short __half2ushort_rz(__half x)
943 unsigned short __half2ushort_rd(__half x)
949 unsigned short __half2ushort_ru(__half x)
957 __half __ushort2half_rn(
unsigned short x)
963 __half __ushort2half_rz(
unsigned short x)
969 __half __ushort2half_rd(
unsigned short x)
975 __half __ushort2half_ru(
unsigned short x)
983 unsigned long long __half2ull_rn(__half x)
989 unsigned long long __half2ull_rz(__half x)
995 unsigned long long __half2ull_rd(__half x)
1001 unsigned long long __half2ull_ru(__half x)
1009 __half __ull2half_rn(
unsigned long long x)
1015 __half __ull2half_rz(
unsigned long long x)
1021 __half __ull2half_rd(
unsigned long long x)
1027 __half __ull2half_ru(
unsigned long long x)
1035 __half __ldg(
const __half* ptr) {
return *ptr; }
1038 __half __ldcg(
const __half* ptr) {
return *ptr; }
1041 __half __ldca(
const __half* ptr) {
return *ptr; }
1044 __half __ldcs(
const __half* ptr) {
return *ptr; }
1048 __half2 __ldg(
const __half2* ptr) {
return *ptr; }
1051 __half2 __ldcg(
const __half2* ptr) {
return *ptr; }
1054 __half2 __ldca(
const __half2* ptr) {
return *ptr; }
1057 __half2 __ldcs(
const __half2* ptr) {
return *ptr; }
1062 bool __heq(__half x, __half y)
1069 bool __hne(__half x, __half y)
1076 bool __hle(__half x, __half y)
1083 bool __hge(__half x, __half y)
1090 bool __hlt(__half x, __half y)
1097 bool __hgt(__half x, __half y)
1104 bool __hequ(__half x, __half y) {
return __heq(x, y); }
1107 bool __hneu(__half x, __half y) {
return __hne(x, y); }
1110 bool __hleu(__half x, __half y) {
return __hle(x, y); }
1113 bool __hgeu(__half x, __half y) {
return __hge(x, y); }
1116 bool __hltu(__half x, __half y) {
return __hlt(x, y); }
1119 bool __hgtu(__half x, __half y) {
return __hgt(x, y); }
1123 __half2 __heq2(__half2 x, __half2 y)
1127 return __builtin_convertvector(-r, _Float16_2);
1131 __half2 __hne2(__half2 x, __half2 y)
1135 return __builtin_convertvector(-r, _Float16_2);
1139 __half2 __hle2(__half2 x, __half2 y)
1143 return __builtin_convertvector(-r, _Float16_2);
1147 __half2 __hge2(__half2 x, __half2 y)
1151 return __builtin_convertvector(-r, _Float16_2);
1155 __half2 __hlt2(__half2 x, __half2 y)
1159 return __builtin_convertvector(-r, _Float16_2);
1163 __half2 __hgt2(__half2 x, __half2 y)
1167 return __builtin_convertvector(-r, _Float16_2);
1171 __half2 __hequ2(__half2 x, __half2 y) {
return __heq2(x, y); }
1174 __half2 __hneu2(__half2 x, __half2 y) {
return __hne2(x, y); }
1177 __half2 __hleu2(__half2 x, __half2 y) {
return __hle2(x, y); }
1180 __half2 __hgeu2(__half2 x, __half2 y) {
return __hge2(x, y); }
1183 __half2 __hltu2(__half2 x, __half2 y) {
return __hlt2(x, y); }
1186 __half2 __hgtu2(__half2 x, __half2 y) {
return __hgt2(x, y); }
1190 bool __hbeq2(__half2 x, __half2 y)
1193 return r.data.x != 0 && r.data.y != 0;
1197 bool __hbne2(__half2 x, __half2 y)
1200 return r.data.x != 0 && r.data.y != 0;
1204 bool __hble2(__half2 x, __half2 y)
1207 return r.data.x != 0 && r.data.y != 0;
1211 bool __hbge2(__half2 x, __half2 y)
1214 return r.data.x != 0 && r.data.y != 0;
1218 bool __hblt2(__half2 x, __half2 y)
1221 return r.data.x != 0 && r.data.y != 0;
1225 bool __hbgt2(__half2 x, __half2 y)
1228 return r.data.x != 0 && r.data.y != 0;
1232 bool __hbequ2(__half2 x, __half2 y) {
return __hbeq2(x, y); }
1235 bool __hbneu2(__half2 x, __half2 y) {
return __hbne2(x, y); }
1238 bool __hbleu2(__half2 x, __half2 y) {
return __hble2(x, y); }
1241 bool __hbgeu2(__half2 x, __half2 y) {
return __hbge2(x, y); }
1244 bool __hbltu2(__half2 x, __half2 y) {
return __hblt2(x, y); }
1247 bool __hbgtu2(__half2 x, __half2 y) {
return __hbgt2(x, y); }
1252 __half __clamp_01(__half x)
1263 __half __hadd(__half x, __half y)
1271 __half __habs(__half x)
1274 __ocml_fabs_f16(
static_cast<__half_raw>(x).data)};
1278 __half __hsub(__half x, __half y)
1286 __half __hmul(__half x, __half y)
1294 __half __hadd_sat(__half x, __half y)
1296 return __clamp_01(__hadd(x, y));
1300 __half __hsub_sat(__half x, __half y)
1302 return __clamp_01(__hsub(x, y));
1306 __half __hmul_sat(__half x, __half y)
1308 return __clamp_01(__hmul(x, y));
1312 __half __hfma(__half x, __half y, __half z)
1321 __half __hfma_sat(__half x, __half y, __half z)
1323 return __clamp_01(__hfma(x, y, z));
1327 __half __hdiv(__half x, __half y)
1336 __half2 __hadd2(__half2 x, __half2 y)
1344 __half2 __habs2(__half2 x)
1347 __ocml_fabs_2f16(
static_cast<__half2_raw>(x).data)};
1351 __half2 __hsub2(__half2 x, __half2 y)
1359 __half2 __hmul2(__half2 x, __half2 y)
1367 __half2 __hadd2_sat(__half2 x, __half2 y)
1376 __half2 __hsub2_sat(__half2 x, __half2 y)
1385 __half2 __hmul2_sat(__half2 x, __half2 y)
1394 __half2 __hfma2(__half2 x, __half2 y, __half2 z)
1400 __half2 __hfma2_sat(__half2 x, __half2 y, __half2 z)
1402 auto r =
static_cast<__half2_raw>(__hfma2(x, y, z));
1409 __half2 __h2div(__half2 x, __half2 y)
1417 #if __HIP_CLANG_ONLY__
1420 float amd_mixed_dot(__half2 a, __half2 b,
float c,
bool saturate) {
1421 return __ockl_fdot2(
static_cast<__half2_raw>(a).data,
1428 __half htrunc(__half x)
1431 __ocml_trunc_f16(
static_cast<__half_raw>(x).data)};
1435 __half hceil(__half x)
1438 __ocml_ceil_f16(
static_cast<__half_raw>(x).data)};
1442 __half hfloor(__half x)
1445 __ocml_floor_f16(
static_cast<__half_raw>(x).data)};
1449 __half hrint(__half x)
1452 __ocml_rint_f16(
static_cast<__half_raw>(x).data)};
1456 __half hsin(__half x)
1459 __ocml_sin_f16(
static_cast<__half_raw>(x).data)};
1463 __half hcos(__half x)
1466 __ocml_cos_f16(
static_cast<__half_raw>(x).data)};
1470 __half hexp(__half x)
1473 __ocml_exp_f16(
static_cast<__half_raw>(x).data)};
1477 __half hexp2(__half x)
1480 __ocml_exp2_f16(
static_cast<__half_raw>(x).data)};
1484 __half hexp10(__half x)
1487 __ocml_exp10_f16(
static_cast<__half_raw>(x).data)};
1491 __half hlog2(__half x)
1494 __ocml_log2_f16(
static_cast<__half_raw>(x).data)};
1498 __half hlog(__half x)
1501 __ocml_log_f16(
static_cast<__half_raw>(x).data)};
1505 __half hlog10(__half x)
1508 __ocml_log10_f16(
static_cast<__half_raw>(x).data)};
1512 __half hrcp(__half x)
1515 __llvm_amdgcn_rcp_f16(
static_cast<__half_raw>(x).data)};
1519 __half hrsqrt(__half x)
1522 __ocml_rsqrt_f16(
static_cast<__half_raw>(x).data)};
1526 __half hsqrt(__half x)
1529 __ocml_sqrt_f16(
static_cast<__half_raw>(x).data)};
1533 bool __hisinf(__half x)
1535 return __ocml_isinf_f16(
static_cast<__half_raw>(x).data);
1539 bool __hisnan(__half x)
1541 return __ocml_isnan_f16(
static_cast<__half_raw>(x).data);
1545 __half __hneg(__half x)
1552 __half2 h2trunc(__half2 x)
1558 __half2 h2ceil(__half2 x)
1564 __half2 h2floor(__half2 x)
1570 __half2 h2rint(__half2 x)
1576 __half2 h2sin(__half2 x)
1582 __half2 h2cos(__half2 x)
1588 __half2 h2exp(__half2 x)
1594 __half2 h2exp2(__half2 x)
1600 __half2 h2exp10(__half2 x)
1606 __half2 h2log2(__half2 x)
1612 __half2 h2log(__half2 x) {
return __ocml_log_2f16(x); }
1615 __half2 h2log10(__half2 x) {
return __ocml_log10_2f16(x); }
1618 __half2 h2rcp(__half2 x) {
return __llvm_amdgcn_rcp_2f16(x); }
1621 __half2 h2rsqrt(__half2 x) {
return __ocml_rsqrt_2f16(x); }
1624 __half2 h2sqrt(__half2 x) {
return __ocml_sqrt_2f16(x); }
1627 __half2 __hisinf2(__half2 x)
1629 auto r = __ocml_isinf_2f16(x);
1631 static_cast<_Float16
>(r.x),
static_cast<_Float16
>(r.y)}};
1635 __half2 __hisnan2(__half2 x)
1637 auto r = __ocml_isnan_2f16(x);
1639 static_cast<_Float16
>(r.x),
static_cast<_Float16
>(r.y)}};
1643 __half2 __hneg2(__half2 x)
1649 #if !defined(HIP_NO_HALF)
// Unprefixed aliases for the half-precision types. They are only declared
// when HIP_NO_HALF is not defined, so a translation unit can opt out of them
// by defining that macro.
1650 using half = __half;
1651 using half2 = __half2;
1653 #endif // defined(__cplusplus)
1654 #elif defined(__GNUC__)
1655 #include "hip_fp16_gcc.h"
1656 #endif // !defined(__clang__) && defined(__GNUC__)
1658 #endif // HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP16_H