24 #ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP16_H
25 #define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP16_H
27 #include <hip/amd_detail/hip_common.h>
29 #if defined(__HIPCC_RTC__)
30 #define __HOST_DEVICE__ __device__
32 #define __HOST_DEVICE__ __host__ __device__
34 #if defined(__cplusplus)
36 #include <type_traits>
39 #endif // !defined(__HIPCC_RTC__)
41 #if __HIP_CLANG_ONLY__
42 typedef _Float16 _Float16_2 __attribute__((ext_vector_type(2)));
46 static_assert(
sizeof(_Float16) ==
sizeof(
unsigned short),
"");
55 static_assert(
sizeof(_Float16_2) ==
sizeof(
unsigned short[2]),
"");
65 #if defined(__cplusplus)
66 #include "hip_fp16_math_fwd.h"
67 #include "hip_vector_types.h"
72 template<>
struct is_floating_point<_Float16> : std::true_type {};
75 template<
bool cond,
typename T =
void>
76 using Enable_if_t =
typename std::enable_if<cond, T>::type;
82 static_assert(
sizeof(_Float16) ==
sizeof(
unsigned short),
"");
93 #if !defined(__HIP_NO_HALF_CONVERSIONS__)
95 __half(decltype(data) x) : data{x} {}
98 Enable_if_t<std::is_floating_point<T>{}>* =
nullptr>
100 __half(T x) : data{
static_cast<_Float16
>(x)} {}
103 __half(
const __half&) =
default;
105 __half(__half&&) =
default;
110 #if !defined(__HIP_NO_HALF_CONVERSIONS__)
112 typename T, Enable_if_t<std::is_integral<T>{}>* =
nullptr>
114 __half(T x) : data{
static_cast<_Float16
>(x)} {}
119 __half& operator=(
const __half&) =
default;
121 __half& operator=(__half&&) =
default;
129 volatile __half& operator=(
const __half_raw& x)
volatile
134 volatile __half& operator=(
const volatile __half_raw& x)
volatile
144 volatile __half& operator=(
__half_raw&& x)
volatile
149 volatile __half& operator=(
volatile __half_raw&& x)
volatile
154 #if !defined(__HIP_NO_HALF_CONVERSIONS__)
157 Enable_if_t<std::is_floating_point<T>{}>* =
nullptr>
159 __half& operator=(T x)
161 data =
static_cast<_Float16
>(x);
167 #if !defined(__HIP_NO_HALF_CONVERSIONS__)
169 typename T, Enable_if_t<std::is_integral<T>{}>* =
nullptr>
171 __half& operator=(T x)
173 data =
static_cast<_Float16
>(x);
178 #if !defined(__HIP_NO_HALF_OPERATORS__)
180 __half& operator+=(
const __half& x)
186 __half& operator-=(
const __half& x)
192 __half& operator*=(
const __half& x)
198 __half& operator/=(
const __half& x)
204 __half& operator++() { ++data;
return *
this; }
206 __half operator++(
int)
213 __half& operator--() { --data;
return *
this; }
215 __half operator--(
int)
224 #if !defined(__HIP_NO_HALF_CONVERSIONS__)
227 Enable_if_t<std::is_floating_point<T>{}>* =
nullptr>
229 operator T()
const {
return data; }
239 #if !defined(__HIP_NO_HALF_CONVERSIONS__)
241 typename T, Enable_if_t<std::is_integral<T>{}>* =
nullptr>
243 operator T()
const {
return data; }
246 #if !defined(__HIP_NO_HALF_OPERATORS__)
248 __half operator+()
const {
return *
this; }
250 __half operator-()
const
253 tmp.data = -tmp.data;
259 #if !defined(__HIP_NO_HALF_OPERATORS__)
263 __half operator+(
const __half& x,
const __half& y)
265 return __half{x} += y;
270 __half operator-(
const __half& x,
const __half& y)
272 return __half{x} -= y;
277 __half operator*(
const __half& x,
const __half& y)
279 return __half{x} *= y;
284 __half operator/(
const __half& x,
const __half& y)
286 return __half{x} /= y;
291 bool operator==(
const __half& x,
const __half& y)
293 return x.data == y.data;
298 bool operator!=(
const __half& x,
const __half& y)
305 bool operator<(
const __half& x,
const __half& y)
307 return x.data < y.data;
312 bool operator>(
const __half& x,
const __half& y)
314 return y.data < x.data;
319 bool operator<=(
const __half& x,
const __half& y)
326 bool operator>=(
const __half& x,
const __half& y)
330 #endif // !defined(__HIP_NO_HALF_OPERATORS__)
339 sizeof(_Float16_2) ==
sizeof(
unsigned short[2]),
"");
354 __half2(decltype(data) x) : data{x} {}
356 __half2(
const __half& x,
const __half& y)
363 __half2(
const __half2&) =
default;
365 __half2(__half2&&) =
default;
367 ~__half2() =
default;
371 __half2& operator=(
const __half2&) =
default;
373 __half2& operator=(__half2&&) =
default;
382 #if !defined(__HIP_NO_HALF_OPERATORS__)
384 __half2& operator+=(
const __half2& x)
390 __half2& operator-=(
const __half2& x)
396 __half2& operator*=(
const __half2& x)
402 __half2& operator/=(
const __half2& x)
408 __half2& operator++() {
return *
this += _Float16_2{1, 1}; }
410 __half2 operator++(
int)
417 __half2& operator--() {
return *
this -= _Float16_2{1, 1}; }
419 __half2 operator--(
int)
429 operator decltype(data)()
const {
return data; }
434 #if !defined(__HIP_NO_HALF_OPERATORS__)
436 __half2 operator+()
const {
return *
this; }
438 __half2 operator-()
const
441 tmp.data = -tmp.data;
447 #if !defined(__HIP_NO_HALF_OPERATORS__)
451 __half2 operator+(
const __half2& x,
const __half2& y)
453 return __half2{x} += y;
458 __half2 operator-(
const __half2& x,
const __half2& y)
460 return __half2{x} -= y;
465 __half2 operator*(
const __half2& x,
const __half2& y)
467 return __half2{x} *= y;
472 __half2 operator/(
const __half2& x,
const __half2& y)
474 return __half2{x} /= y;
479 bool operator==(
const __half2& x,
const __half2& y)
481 auto r = x.data == y.data;
482 return r.x != 0 && r.y != 0;
487 bool operator!=(
const __half2& x,
const __half2& y)
494 bool operator<(
const __half2& x,
const __half2& y)
496 auto r = x.data < y.data;
497 return r.x != 0 && r.y != 0;
502 bool operator>(
const __half2& x,
const __half2& y)
509 bool operator<=(
const __half2& x,
const __half2& y)
516 bool operator>=(
const __half2& x,
const __half2& y)
520 #endif // !defined(__HIP_NO_HALF_OPERATORS__)
528 __half2 make_half2(__half x, __half y)
530 return __half2{x, y};
535 __half __low2half(__half2 x)
542 __half __high2half(__half2 x)
549 __half2 __half2half2(__half x)
551 return __half2{x, x};
556 __half2 __halves2half2(__half x, __half y)
558 return __half2{x, y};
563 __half2 __low2half2(__half2 x)
573 __half2 __high2half2(__half2 x)
583 __half2 __lows2half2(__half2 x, __half2 y)
593 __half2 __highs2half2(__half2 x, __half2 y)
603 __half2 __lowhigh2highlow(__half2 x)
614 short __half_as_short(__half x)
621 unsigned short __half_as_ushort(__half x)
628 __half __short_as_half(
short x)
636 __half __ushort_as_half(
unsigned short x)
646 __half __float2half(
float x)
652 __half __float2half_rn(
float x)
658 __half __float2half_rz(
float x)
664 __half __float2half_rd(
float x)
670 __half __float2half_ru(
float x)
676 __half2 __float2half2_rn(
float x)
680 static_cast<_Float16
>(x),
static_cast<_Float16
>(x)}};
684 __half2 __floats2half2_rn(
float x,
float y)
687 static_cast<_Float16
>(x),
static_cast<_Float16
>(y)}};
691 __half2 __float22half2_rn(
float2 x)
693 return __floats2half2_rn(x.x, x.y);
699 float __half2float(__half x)
705 float __low2float(__half2 x)
711 float __high2float(__half2 x)
717 float2 __half22float2(__half2 x)
727 int __half2int_rn(__half x)
733 int __half2int_rz(__half x)
739 int __half2int_rd(__half x)
745 int __half2int_ru(__half x)
753 __half __int2half_rn(
int x)
759 __half __int2half_rz(
int x)
765 __half __int2half_rd(
int x)
771 __half __int2half_ru(
int x)
779 short __half2short_rn(__half x)
785 short __half2short_rz(__half x)
791 short __half2short_rd(__half x)
797 short __half2short_ru(__half x)
805 __half __short2half_rn(
short x)
811 __half __short2half_rz(
short x)
817 __half __short2half_rd(
short x)
823 __half __short2half_ru(
short x)
831 long long __half2ll_rn(__half x)
837 long long __half2ll_rz(__half x)
843 long long __half2ll_rd(__half x)
849 long long __half2ll_ru(__half x)
857 __half __ll2half_rn(
long long x)
863 __half __ll2half_rz(
long long x)
869 __half __ll2half_rd(
long long x)
875 __half __ll2half_ru(
long long x)
883 unsigned int __half2uint_rn(__half x)
889 unsigned int __half2uint_rz(__half x)
895 unsigned int __half2uint_rd(__half x)
901 unsigned int __half2uint_ru(__half x)
909 __half __uint2half_rn(
unsigned int x)
915 __half __uint2half_rz(
unsigned int x)
921 __half __uint2half_rd(
unsigned int x)
927 __half __uint2half_ru(
unsigned int x)
935 unsigned short __half2ushort_rn(__half x)
941 unsigned short __half2ushort_rz(__half x)
947 unsigned short __half2ushort_rd(__half x)
953 unsigned short __half2ushort_ru(__half x)
961 __half __ushort2half_rn(
unsigned short x)
967 __half __ushort2half_rz(
unsigned short x)
973 __half __ushort2half_rd(
unsigned short x)
979 __half __ushort2half_ru(
unsigned short x)
987 unsigned long long __half2ull_rn(__half x)
993 unsigned long long __half2ull_rz(__half x)
999 unsigned long long __half2ull_rd(__half x)
1005 unsigned long long __half2ull_ru(__half x)
1013 __half __ull2half_rn(
unsigned long long x)
1019 __half __ull2half_rz(
unsigned long long x)
1025 __half __ull2half_rd(
unsigned long long x)
1031 __half __ull2half_ru(
unsigned long long x)
1039 __half __ldg(
const __half* ptr) {
return *ptr; }
1042 __half __ldcg(
const __half* ptr) {
return *ptr; }
1045 __half __ldca(
const __half* ptr) {
return *ptr; }
1048 __half __ldcs(
const __half* ptr) {
return *ptr; }
1052 __half2 __ldg(
const __half2* ptr) {
return *ptr; }
1055 __half2 __ldcg(
const __half2* ptr) {
return *ptr; }
1058 __half2 __ldca(
const __half2* ptr) {
return *ptr; }
1061 __half2 __ldcs(
const __half2* ptr) {
return *ptr; }
1066 bool __heq(__half x, __half y)
1073 bool __hne(__half x, __half y)
1080 bool __hle(__half x, __half y)
1087 bool __hge(__half x, __half y)
1094 bool __hlt(__half x, __half y)
1101 bool __hgt(__half x, __half y)
1108 bool __hequ(__half x, __half y) {
return __heq(x, y); }
1111 bool __hneu(__half x, __half y) {
return __hne(x, y); }
1114 bool __hleu(__half x, __half y) {
return __hle(x, y); }
1117 bool __hgeu(__half x, __half y) {
return __hge(x, y); }
1120 bool __hltu(__half x, __half y) {
return __hlt(x, y); }
1123 bool __hgtu(__half x, __half y) {
return __hgt(x, y); }
1127 __half2 __heq2(__half2 x, __half2 y)
1131 return __builtin_convertvector(-r, _Float16_2);
1135 __half2 __hne2(__half2 x, __half2 y)
1139 return __builtin_convertvector(-r, _Float16_2);
1143 __half2 __hle2(__half2 x, __half2 y)
1147 return __builtin_convertvector(-r, _Float16_2);
1151 __half2 __hge2(__half2 x, __half2 y)
1155 return __builtin_convertvector(-r, _Float16_2);
1159 __half2 __hlt2(__half2 x, __half2 y)
1163 return __builtin_convertvector(-r, _Float16_2);
1167 __half2 __hgt2(__half2 x, __half2 y)
1171 return __builtin_convertvector(-r, _Float16_2);
1175 __half2 __hequ2(__half2 x, __half2 y) {
return __heq2(x, y); }
1178 __half2 __hneu2(__half2 x, __half2 y) {
return __hne2(x, y); }
1181 __half2 __hleu2(__half2 x, __half2 y) {
return __hle2(x, y); }
1184 __half2 __hgeu2(__half2 x, __half2 y) {
return __hge2(x, y); }
1187 __half2 __hltu2(__half2 x, __half2 y) {
return __hlt2(x, y); }
1190 __half2 __hgtu2(__half2 x, __half2 y) {
return __hgt2(x, y); }
1194 bool __hbeq2(__half2 x, __half2 y)
1197 return r.data.x != 0 && r.data.y != 0;
1201 bool __hbne2(__half2 x, __half2 y)
1204 return r.data.x != 0 && r.data.y != 0;
1208 bool __hble2(__half2 x, __half2 y)
1211 return r.data.x != 0 && r.data.y != 0;
1215 bool __hbge2(__half2 x, __half2 y)
1218 return r.data.x != 0 && r.data.y != 0;
1222 bool __hblt2(__half2 x, __half2 y)
1225 return r.data.x != 0 && r.data.y != 0;
1229 bool __hbgt2(__half2 x, __half2 y)
1232 return r.data.x != 0 && r.data.y != 0;
1236 bool __hbequ2(__half2 x, __half2 y) {
return __hbeq2(x, y); }
1239 bool __hbneu2(__half2 x, __half2 y) {
return __hbne2(x, y); }
1242 bool __hbleu2(__half2 x, __half2 y) {
return __hble2(x, y); }
1245 bool __hbgeu2(__half2 x, __half2 y) {
return __hbge2(x, y); }
1248 bool __hbltu2(__half2 x, __half2 y) {
return __hblt2(x, y); }
1251 bool __hbgtu2(__half2 x, __half2 y) {
return __hbgt2(x, y); }
1256 __half __clamp_01(__half x)
1267 __half __hadd(__half x, __half y)
1275 __half __habs(__half x)
1278 __ocml_fabs_f16(
static_cast<__half_raw>(x).data)};
1282 __half __hsub(__half x, __half y)
1290 __half __hmul(__half x, __half y)
1298 __half __hadd_sat(__half x, __half y)
1300 return __clamp_01(__hadd(x, y));
1304 __half __hsub_sat(__half x, __half y)
1306 return __clamp_01(__hsub(x, y));
1310 __half __hmul_sat(__half x, __half y)
1312 return __clamp_01(__hmul(x, y));
1316 __half __hfma(__half x, __half y, __half z)
1325 __half __hfma_sat(__half x, __half y, __half z)
1327 return __clamp_01(__hfma(x, y, z));
1331 __half __hdiv(__half x, __half y)
1340 __half2 __hadd2(__half2 x, __half2 y)
1348 __half2 __habs2(__half2 x)
1351 __ocml_fabs_2f16(
static_cast<__half2_raw>(x).data)};
1355 __half2 __hsub2(__half2 x, __half2 y)
1363 __half2 __hmul2(__half2 x, __half2 y)
1371 __half2 __hadd2_sat(__half2 x, __half2 y)
1380 __half2 __hsub2_sat(__half2 x, __half2 y)
1389 __half2 __hmul2_sat(__half2 x, __half2 y)
1398 __half2 __hfma2(__half2 x, __half2 y, __half2 z)
1404 __half2 __hfma2_sat(__half2 x, __half2 y, __half2 z)
1406 auto r =
static_cast<__half2_raw>(__hfma2(x, y, z));
1413 __half2 __h2div(__half2 x, __half2 y)
1421 #if __HIP_CLANG_ONLY__
1424 float amd_mixed_dot(__half2 a, __half2 b,
float c,
bool saturate) {
1425 return __ockl_fdot2(
static_cast<__half2_raw>(a).data,
1432 __half htrunc(__half x)
1435 __ocml_trunc_f16(
static_cast<__half_raw>(x).data)};
1439 __half hceil(__half x)
1442 __ocml_ceil_f16(
static_cast<__half_raw>(x).data)};
1446 __half hfloor(__half x)
1449 __ocml_floor_f16(
static_cast<__half_raw>(x).data)};
1453 __half hrint(__half x)
1456 __ocml_rint_f16(
static_cast<__half_raw>(x).data)};
1460 __half hsin(__half x)
1463 __ocml_sin_f16(
static_cast<__half_raw>(x).data)};
1467 __half hcos(__half x)
1470 __ocml_cos_f16(
static_cast<__half_raw>(x).data)};
1474 __half hexp(__half x)
1477 __ocml_exp_f16(
static_cast<__half_raw>(x).data)};
1481 __half hexp2(__half x)
1484 __ocml_exp2_f16(
static_cast<__half_raw>(x).data)};
1488 __half hexp10(__half x)
1491 __ocml_exp10_f16(
static_cast<__half_raw>(x).data)};
1495 __half hlog2(__half x)
1498 __ocml_log2_f16(
static_cast<__half_raw>(x).data)};
1502 __half hlog(__half x)
1505 __ocml_log_f16(
static_cast<__half_raw>(x).data)};
1509 __half hlog10(__half x)
1512 __ocml_log10_f16(
static_cast<__half_raw>(x).data)};
1516 __half hrcp(__half x)
1519 __llvm_amdgcn_rcp_f16(
static_cast<__half_raw>(x).data)};
1523 __half hrsqrt(__half x)
1526 __ocml_rsqrt_f16(
static_cast<__half_raw>(x).data)};
1530 __half hsqrt(__half x)
1533 __ocml_sqrt_f16(
static_cast<__half_raw>(x).data)};
1537 bool __hisinf(__half x)
1539 return __ocml_isinf_f16(
static_cast<__half_raw>(x).data);
1543 bool __hisnan(__half x)
1545 return __ocml_isnan_f16(
static_cast<__half_raw>(x).data);
1549 __half __hneg(__half x)
1556 __half2 h2trunc(__half2 x)
1562 __half2 h2ceil(__half2 x)
1568 __half2 h2floor(__half2 x)
1574 __half2 h2rint(__half2 x)
1580 __half2 h2sin(__half2 x)
1586 __half2 h2cos(__half2 x)
1592 __half2 h2exp(__half2 x)
1598 __half2 h2exp2(__half2 x)
1604 __half2 h2exp10(__half2 x)
1610 __half2 h2log2(__half2 x)
1616 __half2 h2log(__half2 x) {
return __ocml_log_2f16(x); }
1619 __half2 h2log10(__half2 x) {
return __ocml_log10_2f16(x); }
1622 __half2 h2rcp(__half2 x) {
return __llvm_amdgcn_rcp_2f16(x); }
1625 __half2 h2rsqrt(__half2 x) {
return __ocml_rsqrt_2f16(x); }
1628 __half2 h2sqrt(__half2 x) {
return __ocml_sqrt_2f16(x); }
1631 __half2 __hisinf2(__half2 x)
1633 auto r = __ocml_isinf_2f16(x);
1635 static_cast<_Float16
>(r.x),
static_cast<_Float16
>(r.y)}};
1639 __half2 __hisnan2(__half2 x)
1641 auto r = __ocml_isnan_2f16(x);
1643 static_cast<_Float16
>(r.x),
static_cast<_Float16
>(r.y)}};
1647 __half2 __hneg2(__half2 x)
1653 #if !defined(HIP_NO_HALF)
1654 using half = __half;
1655 using half2 = __half2;
1657 #endif // defined(__cplusplus)
1658 #elif defined(__GNUC__)
1659 #include "hip_fp16_gcc.h"
1660 #endif // !defined(__clang__) && defined(__GNUC__)
1662 #endif // HIP_INCLUDE_HIP_AMD_DETAIL_HIP_FP16_H