HIP: Heterogenous-computing Interface for Portability
hip_fp16_gcc.h
1 #pragma once
2 
3 #if defined(__cplusplus)
4  #include <cstring>
5 #endif
6 
7 struct __half_raw {
8  unsigned short x;
9 };
10 
11 struct __half2_raw {
12  unsigned short x;
13  unsigned short y;
14 };
15 
16 #if defined(__cplusplus)
17  struct __half;
18 
19  __half __float2half(float);
20  float __half2float(__half);
21 
22  // BEGIN STRUCT __HALF
23  struct __half {
24  protected:
25  unsigned short __x;
26  public:
27  // CREATORS
28  __half() = default;
29  __half(const __half_raw& x) : __x{x.x} {}
30  #if !defined(__HIP_NO_HALF_CONVERSIONS__)
31  __half(float x) : __x{__float2half(x).__x} {}
32  __half(double x) : __x{__float2half(x).__x} {}
33  #endif
34  __half(const __half&) = default;
35  __half(__half&&) = default;
36  ~__half() = default;
37 
38  // MANIPULATORS
39  __half& operator=(const __half&) = default;
40  __half& operator=(__half&&) = default;
41  __half& operator=(const __half_raw& x) { __x = x.x; return *this; }
42  #if !defined(__HIP_NO_HALF_CONVERSIONS__)
43  __half& operator=(float x)
44  {
45  __x = __float2half(x).__x;
46  return *this;
47  }
48  __half& operator=(double x)
49  {
50  return *this = static_cast<float>(x);
51  }
52  #endif
53 
54  // ACCESSORS
55  operator float() const { return __half2float(*this); }
56  operator __half_raw() const { return __half_raw{__x}; }
57  };
58  // END STRUCT __HALF
59 
60  // BEGIN STRUCT __HALF2
61  struct __half2 {
62  public:
63  __half x;
64  __half y;
65 
66  // CREATORS
67  __half2() = default;
68  __half2(const __half2_raw& ix)
69  :
70  x{reinterpret_cast<const __half&>(ix.x)},
71  y{reinterpret_cast<const __half&>(ix.y)}
72  {}
73  __half2(const __half& ix, const __half& iy) : x{ix}, y{iy} {}
74  __half2(const __half2&) = default;
75  __half2(__half2&&) = default;
76  ~__half2() = default;
77 
78  // MANIPULATORS
79  __half2& operator=(const __half2&) = default;
80  __half2& operator=(__half2&&) = default;
81  __half2& operator=(const __half2_raw& ix)
82  {
83  x = reinterpret_cast<const __half_raw&>(ix.x);
84  y = reinterpret_cast<const __half_raw&>(ix.y);
85  return *this;
86  }
87 
88  // ACCESSORS
89  operator __half2_raw() const
90  {
91  return __half2_raw{
92  reinterpret_cast<const unsigned short&>(x),
93  reinterpret_cast<const unsigned short&>(y)};
94  }
95  };
96  // END STRUCT __HALF2
97 
98  inline
99  unsigned short __internal_float2half(
100  float flt, unsigned int& sgn, unsigned int& rem)
101  {
102  unsigned int x{};
103  std::memcpy(&x, &flt, sizeof(flt));
104 
105  unsigned int u = (x & 0x7fffffffU);
106  sgn = ((x >> 16) & 0x8000U);
107 
108  // NaN/+Inf/-Inf
109  if (u >= 0x7f800000U) {
110  rem = 0;
111  return static_cast<unsigned short>(
112  (u == 0x7f800000U) ? (sgn | 0x7c00U) : 0x7fffU);
113  }
114  // Overflows
115  if (u > 0x477fefffU) {
116  rem = 0x80000000U;
117  return static_cast<unsigned short>(sgn | 0x7bffU);
118  }
119  // Normal numbers
120  if (u >= 0x38800000U) {
121  rem = u << 19;
122  u -= 0x38000000U;
123  return static_cast<unsigned short>(sgn | (u >> 13));
124  }
125  // +0/-0
126  if (u < 0x33000001U) {
127  rem = u;
128  return static_cast<unsigned short>(sgn);
129  }
130  // Denormal numbers
131  unsigned int exponent = u >> 23;
132  unsigned int mantissa = (u & 0x7fffffU);
133  unsigned int shift = 0x7eU - exponent;
134  mantissa |= 0x800000U;
135  rem = mantissa << (32 - shift);
136  return static_cast<unsigned short>(sgn | (mantissa >> shift));
137  }
138 
139  inline
140  __half __float2half(float x)
141  {
142  __half_raw r;
143  unsigned int sgn{};
144  unsigned int rem{};
145  r.x = __internal_float2half(x, sgn, rem);
146  if (rem > 0x80000000U || (rem == 0x80000000U && (r.x & 0x1))) ++r.x;
147 
148  return r;
149  }
150 
151  inline
152  __half __float2half_rn(float x) { return __float2half(x); }
153 
154  inline
155  __half __float2half_rz(float x)
156  {
157  __half_raw r;
158  unsigned int sgn{};
159  unsigned int rem{};
160  r.x = __internal_float2half(x, sgn, rem);
161 
162  return r;
163  }
164 
165  inline
166  __half __float2half_rd(float x)
167  {
168  __half_raw r;
169  unsigned int sgn{};
170  unsigned int rem{};
171  r.x = __internal_float2half(x, sgn, rem);
172  if (rem && sgn) ++r.x;
173 
174  return r;
175  }
176 
177  inline
178  __half __float2half_ru(float x)
179  {
180  __half_raw r;
181  unsigned int sgn{};
182  unsigned int rem{};
183  r.x = __internal_float2half(x, sgn, rem);
184  if (rem && !sgn) ++r.x;
185 
186  return r;
187  }
188 
189  inline
190  __half2 __float2half2_rn(float x)
191  {
192  return __half2{__float2half_rn(x), __float2half_rn(x)};
193  }
194 
195  inline
196  __half2 __floats2half2_rn(float x, float y)
197  {
198  return __half2{__float2half_rn(x), __float2half_rn(y)};
199  }
200 
201  inline
202  float __internal_half2float(unsigned short x)
203  {
204  unsigned int sign = ((x >> 15) & 1);
205  unsigned int exponent = ((x >> 10) & 0x1f);
206  unsigned int mantissa = ((x & 0x3ff) << 13);
207 
208  if (exponent == 0x1fU) { /* NaN or Inf */
209  mantissa = (mantissa ? (sign = 0, 0x7fffffU) : 0);
210  exponent = 0xffU;
211  } else if (!exponent) { /* Denorm or Zero */
212  if (mantissa) {
213  unsigned int msb;
214  exponent = 0x71U;
215  do {
216  msb = (mantissa & 0x400000U);
217  mantissa <<= 1; /* normalize */
218  --exponent;
219  } while (!msb);
220  mantissa &= 0x7fffffU; /* 1.mantissa is implicit */
221  }
222  } else {
223  exponent += 0x70U;
224  }
225  unsigned int u = ((sign << 31) | (exponent << 23) | mantissa);
226  float f;
227  memcpy(&f, &u, sizeof(u));
228 
229  return f;
230  }
231 
232  inline
233  float __half2float(__half x)
234  {
235  return __internal_half2float(static_cast<__half_raw>(x).x);
236  }
237 
238  inline
239  float __low2float(__half2 x)
240  {
241  return __internal_half2float(static_cast<__half2_raw>(x).x);
242  }
243 
244  inline
245  float __high2float(__half2 x)
246  {
247  return __internal_half2float(static_cast<__half2_raw>(x).y);
248  }
249 
250  #if !defined(HIP_NO_HALF)
251  using half = __half;
252  using half2 = __half2;
253  #endif
254 #endif // defined(__cplusplus)
__half2_raw
Definition: hip_fp16_gcc.h:11
__half_raw
Definition: hip_fp16_gcc.h:7