HIP: Heterogenous-computing Interface for Portability
hip_fp16_gcc.h
1 #pragma once
2 
3 #if defined(__cplusplus)
4  #include <cstring>
5 #endif
6 
7 struct __half_raw {
8  unsigned short x;
9 };
10 
11 struct __half2_raw {
12  unsigned short x;
13  unsigned short y;
14 };
15 
16 #if defined(__cplusplus)
17  struct __half;
18 
19  __half __float2half(float);
20  float __half2float(__half);
21 
22  // BEGIN STRUCT __HALF
23  struct __half {
24  protected:
25  unsigned short __x;
26  public:
27  // CREATORS
28  __half() = default;
29  __half(const __half_raw& x) : __x{x.x} {}
30  #if !defined(__HIP_NO_HALF_CONVERSIONS__)
31  __half(float x) : __x{__float2half(x).__x} {}
32  __half(double x) : __x{__float2half(x).__x} {}
33  #endif
34  __half(const __half&) = default;
35  __half(__half&&) = default;
36  ~__half() = default;
37 
38  // MANIPULATORS
39  __half& operator=(const __half&) = default;
40  __half& operator=(__half&&) = default;
41  __half& operator=(const __half_raw& x) { __x = x.x; return *this; }
42  #if !defined(__HIP_NO_HALF_CONVERSIONS__)
43  __half& operator=(float x)
44  {
45  __x = __float2half(x).__x;
46  return *this;
47  }
48  __half& operator=(double x)
49  {
50  return *this = static_cast<float>(x);
51  }
52  #endif
53 
54  // ACCESSORS
55  operator float() const { return __half2float(*this); }
56  operator __half_raw() const { return __half_raw{__x}; }
57  };
58  // END STRUCT __HALF
59 
60  // BEGIN STRUCT __HALF2
61  struct __half2 {
62  protected:
63  __half x;
64  __half y;
65  public:
66  // CREATORS
67  __half2() = default;
68  __half2(const __half2_raw& ix)
69  :
70  x{reinterpret_cast<const __half&>(ix.x)},
71  y{reinterpret_cast<const __half&>(ix.y)}
72  {}
73  __half2(const __half& ix, const __half& iy) : x{ix}, y{iy} {}
74  __half2(const __half2&) = default;
75  __half2(__half2&&) = default;
76  ~__half2() = default;
77 
78  // MANIPULATORS
79  __half2& operator=(const __half2&) = default;
80  __half2& operator=(__half2&&) = default;
81  __half2& operator=(const __half2_raw& ix)
82  {
83  x = reinterpret_cast<const __half_raw&>(ix.x);
84  y = reinterpret_cast<const __half_raw&>(ix.y);
85  return *this;
86  }
87 
88  // ACCESSORS
89  operator __half2_raw() const
90  {
91  return __half2_raw{
92  reinterpret_cast<const unsigned short&>(x),
93  reinterpret_cast<const unsigned short&>(y)};
94  }
95  };
96  // END STRUCT __HALF2
97 
98  namespace
99  {
100  inline
101  unsigned short __internal_float2half(
102  float flt, unsigned int& sgn, unsigned int& rem)
103  {
104  unsigned int x{};
105  std::memcpy(&x, &flt, sizeof(flt));
106 
107  unsigned int u = (x & 0x7fffffffU);
108  sgn = ((x >> 16) & 0x8000U);
109 
110  // NaN/+Inf/-Inf
111  if (u >= 0x7f800000U) {
112  rem = 0;
113  return static_cast<unsigned short>(
114  (u == 0x7f800000U) ? (sgn | 0x7c00U) : 0x7fffU);
115  }
116  // Overflows
117  if (u > 0x477fefffU) {
118  rem = 0x80000000U;
119  return static_cast<unsigned short>(sgn | 0x7bffU);
120  }
121  // Normal numbers
122  if (u >= 0x38800000U) {
123  rem = u << 19;
124  u -= 0x38000000U;
125  return static_cast<unsigned short>(sgn | (u >> 13));
126  }
127  // +0/-0
128  if (u < 0x33000001U) {
129  rem = u;
130  return static_cast<unsigned short>(sgn);
131  }
132  // Denormal numbers
133  unsigned int exponent = u >> 23;
134  unsigned int mantissa = (u & 0x7fffffU);
135  unsigned int shift = 0x7eU - exponent;
136  mantissa |= 0x800000U;
137  rem = mantissa << (32 - shift);
138  return static_cast<unsigned short>(sgn | (mantissa >> shift));
139  }
140 
141  inline
142  __half __float2half(float x)
143  {
144  __half_raw r;
145  unsigned int sgn{};
146  unsigned int rem{};
147  r.x = __internal_float2half(x, sgn, rem);
148  if (rem > 0x80000000U || (rem == 0x80000000U && (r.x & 0x1))) ++r.x;
149 
150  return r;
151  }
152 
153  inline
154  __half __float2half_rn(float x) { return __float2half(x); }
155 
156  inline
157  __half __float2half_rz(float x)
158  {
159  __half_raw r;
160  unsigned int sgn{};
161  unsigned int rem{};
162  r.x = __internal_float2half(x, sgn, rem);
163 
164  return r;
165  }
166 
167  inline
168  __half __float2half_rd(float x)
169  {
170  __half_raw r;
171  unsigned int sgn{};
172  unsigned int rem{};
173  r.x = __internal_float2half(x, sgn, rem);
174  if (rem && sgn) ++r.x;
175 
176  return r;
177  }
178 
179  inline
180  __half __float2half_ru(float x)
181  {
182  __half_raw r;
183  unsigned int sgn{};
184  unsigned int rem{};
185  r.x = __internal_float2half(x, sgn, rem);
186  if (rem && !sgn) ++r.x;
187 
188  return r;
189  }
190 
191  inline
192  __half2 __float2half2_rn(float x)
193  {
194  return __half2{__float2half_rn(x), __float2half_rn(x)};
195  }
196 
197  inline
198  __half2 __floats2half2_rn(float x, float y)
199  {
200  return __half2{__float2half_rn(x), __float2half_rn(y)};
201  }
202 
203  inline
204  float __internal_half2float(unsigned short x)
205  {
206  unsigned int sign = ((x >> 15) & 1);
207  unsigned int exponent = ((x >> 10) & 0x1f);
208  unsigned int mantissa = ((x & 0x3ff) << 13);
209 
210  if (exponent == 0x1fU) { /* NaN or Inf */
211  mantissa = (mantissa ? (sign = 0, 0x7fffffU) : 0);
212  exponent = 0xffU;
213  } else if (!exponent) { /* Denorm or Zero */
214  if (mantissa) {
215  unsigned int msb;
216  exponent = 0x71U;
217  do {
218  msb = (mantissa & 0x400000U);
219  mantissa <<= 1; /* normalize */
220  --exponent;
221  } while (!msb);
222  mantissa &= 0x7fffffU; /* 1.mantissa is implicit */
223  }
224  } else {
225  exponent += 0x70U;
226  }
227  unsigned int u = ((sign << 31) | (exponent << 23) | mantissa);
228  float f;
229  memcpy(&f, &u, sizeof(u));
230 
231  return f;
232  }
233 
234  inline
235  float __half2float(__half x)
236  {
237  return __internal_half2float(static_cast<__half_raw>(x).x);
238  }
239 
240  inline
241  float __low2float(__half2 x)
242  {
243  return __internal_half2float(static_cast<__half2_raw>(x).x);
244  }
245 
246  inline
247  float __high2float(__half2 x)
248  {
249  return __internal_half2float(static_cast<__half2_raw>(x).y);
250  }
251  } // Anonymous namespace.
252 
253  #if !defined(HIP_NO_HALF)
254  using half = __half;
255  using half2 = __half2;
256  #endif
257 #endif // defined(__cplusplus)
Definition: hip_fp16_gcc.h:11
Definition: hip_fp16_gcc.h:7