29 #ifndef _HIP_BFLOAT16_H_ 30 #define _HIP_BFLOAT16_H_ 32 #if __cplusplus < 201103L || (!defined(__HCC__) && !defined(__HIPCC__)) 44 #else // __cplusplus < 201103L || (!defined(__HCC__) && !defined(__HIPCC__)) 49 #include <hip/hip_runtime.h> 51 #include <type_traits> 66 : data(float_to_bfloat16(f))
// Truncating constructor: the truncate_t tag selects truncate_float_to_bfloat16,
// which keeps the upper 16 bits of f's IEEE-754 pattern (preserving NaN-ness)
// instead of the round-to-nearest-even conversion float_to_bfloat16.
// NOTE(review): the constructor body braces are not visible in this excerpt.
70 explicit __host__ __device__ hip_bfloat16(
float f, truncate_t)
71 : data(truncate_float_to_bfloat16(f))
// Conversion to float: the 16 stored bits become the upper half of a 32-bit
// IEEE single-precision bit pattern, with the low 16 bits zero-filled. Every
// bfloat16 value is therefore represented exactly as a float.
// NOTE(review): the bit reinterpretation presumably goes through a
// uint32_t/float union whose declaration and return statement are not visible
// in this excerpt — confirm.
76 __host__ __device__
operator float()
const 82 } u = {uint32_t(data) << 16};
// Factory: builds a hip_bfloat16 whose bits come from float_to_bfloat16(f),
// i.e. the round-to-nearest-even conversion.
// NOTE(review): the local declaration of `output` and the return statement are
// not visible in this excerpt.
86 static __host__ __device__ hip_bfloat16 round_to_bfloat16(
float f)
89 output.data = float_to_bfloat16(f);
// Factory, truncating variant: builds a hip_bfloat16 whose bits come from
// truncate_float_to_bfloat16(f). The truncate_t tag only selects this overload.
// NOTE(review): the local declaration of `output` and the return statement are
// not visible in this excerpt.
93 static __host__ __device__ hip_bfloat16 round_to_bfloat16(
float f, truncate_t)
96 output.data = truncate_float_to_bfloat16(f);
// Converts a float to bfloat16 bits by rounding the 32-bit IEEE pattern
// (u.int32) and returning its upper 16 bits.
// NOTE(review): the declaration of `u` (presumably a float/uint32_t union
// initialized from f) and the branch braces are not visible in this excerpt.
101 static __host__ __device__ uint16_t float_to_bfloat16(
float f)
// Exponent field (0x7f800000) not all ones: the value is finite (zero,
// subnormal or normal), so round it.
108 if(~u.int32 & 0x7f800000)
// Round to nearest, ties to even: add 0x7fff, plus 1 when the bfloat16 LSB
// (bit 16) is set. A carry out of the mantissa correctly bumps the exponent,
// possibly up to infinity.
126 u.int32 += 0x7fff + ((u.int32 >> 16) & 1);
// Exponent all ones and some of the low 16 mantissa bits set: a NaN whose
// payload would be discarded by the >> 16 below.
// NOTE(review): this branch's statement is not visible; presumably it sets a
// bit that survives the shift so the result stays NaN rather than Inf — confirm.
128 else if(u.int32 & 0xffff)
// Keep sign, 8 exponent bits and the top 7 mantissa bits.
140 return uint16_t(u.int32 >> 16);
// Converts a float to bfloat16 bits by truncation: keeps the upper 16 bits of
// the IEEE pattern in u.int32 and applies no rounding.
// NOTE(review): the declaration of `u` (presumably a float/uint32_t union
// initialized from f) is not visible in this excerpt.
144 static __host__ __device__ uint16_t truncate_float_to_bfloat16(
float f)
// The OR'd boolean term is 1 only when the exponent is all ones and some of the
// discarded low 16 bits are set (a NaN). Setting bit 0 keeps the result a NaN,
// even if the surviving 7 mantissa bits are all zero (which would read as Inf).
151 return uint16_t(u.int32 >> 16) | (!(~u.int32 & 0x7f800000) && (u.int32 & 0xffff));
158 } hip_bfloat16_public;
// Compile-time guarantee that hip_bfloat16 is usable from C: standard layout.
160 static_assert(std::is_standard_layout<hip_bfloat16>{},
161 "hip_bfloat16 is not a standard layout type, and thus is " 162 "incompatible with C.");
// ...and trivial (no user-provided special members).
164 static_assert(std::is_trivial<hip_bfloat16>{},
165 "hip_bfloat16 is not a trivial type, and thus is " 166 "incompatible with C.");
// The C++ type and hip_bfloat16_public must agree in size and in the offset of
// `data`, so objects can be exchanged between the two definitions.
168 static_assert(
sizeof(
hip_bfloat16) ==
sizeof(hip_bfloat16_public)
169 && offsetof(
hip_bfloat16, data) == offsetof(hip_bfloat16_public, data),
170 "internal hip_bfloat16 does not match public hip_bfloat16");
// Stream insertion: writes the value after conversion to float. Formatting
// therefore follows the stream's float settings.
// NOTE(review): the function's opening/closing braces are not visible in this
// excerpt.
172 inline std::ostream& operator<<(std::ostream& os,
const hip_bfloat16& bf16)
174 return os << float(bf16);
// NOTE(review): enclosing signatures are not visible; presumably these are
// operator< and operator== for hip_bfloat16 — confirm.
// Ordering is decided on the exact float conversions, so IEEE float semantics
// apply: any comparison involving a NaN yields false.
203 return float(a) < float(b);
// Equality via float: +0 and -0 compare equal, and a NaN never equals itself.
207 return float(a) == float(b);
// NOTE(review): enclosing signatures are not visible; presumably these are
// isinf / isnan / iszero for hip_bfloat16 — confirm.
// `data` is the upper half of an IEEE float: sign 0x8000, exponent 0x7f80,
// mantissa 0x007f.
// Exponent all ones and mantissa zero: +/-infinity.
266 return !(~a.data & 0x7f80) && !(a.data & 0x7f);
// Exponent all ones and mantissa nonzero: NaN. (Unary + only promotes the
// masked value; && then treats it as a boolean.)
270 return !(~a.data & 0x7f80) && +(a.data & 0x7f);
// All bits except the sign are zero: +0 or -0.
274 return !(a.data & 0x7fff);
278 #endif // __cplusplus < 201103L || (!defined(__HCC__) && !defined(__HIPCC__)) 280 #endif // _HIP_BFLOAT16_H_ Struct to represent a 16 bit brain floating point number.
Definition: hip_bfloat16.h:39
#define __host__
Definition: host_defines.h:41