#ifndef HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H
#define HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H

#if __cplusplus
#include <hip/amd_detail/hip_cooperative_groups_helper.h>

namespace cooperative_groups {
// The base type of all cooperative group types.
class thread_group {
 protected:
  uint32_t _type;  // Type of the thread group (one of internal::group_type)
  uint32_t _size;  // Total number of threads in the thread group
  uint64_t _mask;  // Lane mask for tiled partitioned group types;
                   // LSB represents lane 0, MSB represents lane 63

  // Construct a thread group and record its type and other essential
  // properties. Only the derived group types and the friends below may do so.
  __CG_QUALIFIER__ thread_group(internal::group_type type, uint32_t size,
                                uint64_t mask = (uint64_t)0) {
    _type = type;
    _size = size;
    _mask = mask;
  }

  // Bookkeeping for tiled partitions of this group
  struct _tiled_info {
    bool is_tiled;
    unsigned int size;
  } tiled_info;

  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend class thread_block;
 public:
  // Total number of threads in the thread group; valid for all derived types
  // since the size is saved at construction time
  __CG_QUALIFIER__ uint32_t size() const { return _size; }
  __CG_QUALIFIER__ unsigned int cg_type() const { return _type; }
  // Rank of the calling thread within [0, size())
  __CG_QUALIFIER__ uint32_t thread_rank() const;
  // Is the group backed by a real group of the stored type?
  __CG_QUALIFIER__ bool is_valid() const;
  // Synchronize the threads in the thread group
  __CG_QUALIFIER__ void sync() const;
};
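// Usage sketch (illustrative, not part of this header): because the derived
// group types below share this base, a device function can accept a
// thread_group reference and serve any of them. The helper `vote_and_wait`
// and the `votes` counter are hypothetical.
//
//   __device__ void vote_and_wait(const cooperative_groups::thread_group& g,
//                                 unsigned int* votes) {
//     if (g.thread_rank() == 0) atomicAdd(votes, 1u);  // one vote per group
//     g.sync();  // dispatches on the stored group type (see definitions below)
//   }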
// The multi-grid cooperative group type: every thread of every grid launched
// as part of a multi-GPU cooperative launch.
class multi_grid_group : public thread_group {
  // Only this friend function may construct an object of this class
  friend __CG_QUALIFIER__ multi_grid_group this_multi_grid();

 protected:
  // Constructed only through the API this_multi_grid()
  explicit __CG_QUALIFIER__ multi_grid_group(uint32_t size)
      : thread_group(internal::cg_multi_grid, size) {}
 public:
  // Number of grids participating in this multi-grid group, i.e. the number
  // of GPUs the kernel was launched on
  __CG_QUALIFIER__ uint32_t num_grids() { return internal::multi_grid::num_grids(); }
  // Rank of the grid this thread belongs to, within [0, num_grids())
  __CG_QUALIFIER__ uint32_t grid_rank() { return internal::multi_grid::grid_rank(); }
  __CG_QUALIFIER__ uint32_t thread_rank() const { return internal::multi_grid::thread_rank(); }
  __CG_QUALIFIER__ bool is_valid() const { return internal::multi_grid::is_valid(); }
  __CG_QUALIFIER__ void sync() const { internal::multi_grid::sync(); }
};
// User-exposed API: get the multi-grid group the calling thread belongs to
__CG_QUALIFIER__ multi_grid_group this_multi_grid() {
  return multi_grid_group(internal::multi_grid::size());
}
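// Usage sketch (illustrative, not part of this header): a kernel that
// synchronizes across all participating GPUs. The kernel name and buffer are
// hypothetical; multi-grid sync is only valid for kernels launched through
// hipLaunchCooperativeKernelMultiDevice.
//
//   __global__ void multi_gpu_step(float* local_data) {
//     cooperative_groups::multi_grid_group mg = cooperative_groups::this_multi_grid();
//     // ... phase 1: each GPU updates its local_data ...
//     mg.sync();  // all threads on all GPUs reach this point first
//     // ... phase 2: safe to consume results produced by other GPUs ...
//   }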
// The grid cooperative group type: every thread of the grid the calling
// thread belongs to.
class grid_group : public thread_group {
  // Only this friend function may construct an object of this class
  friend __CG_QUALIFIER__ grid_group this_grid();

 protected:
  // Constructed only through the API this_grid()
  explicit __CG_QUALIFIER__ grid_group(uint32_t size) : thread_group(internal::cg_grid, size) {}
 public:
  __CG_QUALIFIER__ uint32_t thread_rank() const { return internal::grid::thread_rank(); }
  __CG_QUALIFIER__ bool is_valid() const { return internal::grid::is_valid(); }
  __CG_QUALIFIER__ void sync() const { internal::grid::sync(); }
};
// User-exposed API: get the grid group the calling thread belongs to
__CG_QUALIFIER__ grid_group this_grid() { return grid_group(internal::grid::size()); }
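// Usage sketch (illustrative, not part of this header): a grid-strided kernel
// with a grid-wide barrier between two phases. The kernel and parameters are
// hypothetical; grid-wide sync requires a cooperative launch via
// hipLaunchCooperativeKernel.
//
//   __global__ void two_phase(float* buf, size_t n) {
//     cooperative_groups::grid_group grid = cooperative_groups::this_grid();
//     for (size_t i = grid.thread_rank(); i < n; i += grid.size()) buf[i] *= 2.0f;
//     grid.sync();  // every block has finished phase 1
//     // ... phase 2 may now read any element of buf ...
//   }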
// The workgroup (thread block) cooperative group type: every thread of the
// workgroup the calling thread belongs to.
class thread_block : public thread_group {
  // Only these friend functions may construct an object of this class
  friend __CG_QUALIFIER__ thread_block this_thread_block();
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_block& parent,
                                                       unsigned int tile_size);

 protected:
  // Constructed only through the API this_thread_block()
  explicit __CG_QUALIFIER__ thread_block(uint32_t size)
      : thread_group(internal::cg_workgroup, size) {}
  __CG_QUALIFIER__ thread_group new_tiled_group(unsigned int tile_size) const {
    // A valid tile size is a nonzero power of 2 no larger than the wavefront
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);
    if (!tile_size || (tile_size > WAVEFRONT_SIZE) || !pow2) {
      assert(false && "invalid tile size");
    }

    thread_group tiledGroup = thread_group(internal::cg_tiled_group, tile_size);
    tiledGroup.tiled_info.size = tile_size;
    tiledGroup.tiled_info.is_tiled = true;
    return tiledGroup;
  }
 public:
  // 3-dimensional block index within the grid
  __CG_QUALIFIER__ dim3 group_index() { return internal::workgroup::group_index(); }
  // 3-dimensional thread index within the block
  __CG_QUALIFIER__ dim3 thread_index() { return internal::workgroup::thread_index(); }
  __CG_QUALIFIER__ uint32_t thread_rank() const { return internal::workgroup::thread_rank(); }
  __CG_QUALIFIER__ bool is_valid() const { return internal::workgroup::is_valid(); }
  __CG_QUALIFIER__ void sync() const { internal::workgroup::sync(); }
};
// User-exposed API: get the thread block the calling thread belongs to
__CG_QUALIFIER__ thread_block this_thread_block() {
  return thread_block(internal::workgroup::size());
}
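// Usage sketch (illustrative, not part of this header): block-level
// cooperation through shared memory. The kernel is hypothetical and assumes
// a launch with 256 threads per block.
//
//   __global__ void block_sum(const float* in, float* out) {
//     __shared__ float tmp[256];
//     cooperative_groups::thread_block tb = cooperative_groups::this_thread_block();
//     tmp[tb.thread_rank()] = in[blockIdx.x * blockDim.x + tb.thread_rank()];
//     tb.sync();  // equivalent to __syncthreads()
//     if (tb.thread_rank() == 0) {
//       float s = 0.0f;
//       for (uint32_t i = 0; i < tb.size(); ++i) s += tmp[i];
//       out[blockIdx.x] = s;
//     }
//   }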
// The tiled group type: a subset of a thread block, of power-of-2 size no
// larger than a wavefront.
class tiled_group : public thread_group {
  // Only these friend functions may construct an object of this class
  friend __CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent,
                                                       unsigned int tile_size);
  friend __CG_QUALIFIER__ tiled_group tiled_partition(const tiled_group& parent,
                                                      unsigned int tile_size);
  __CG_QUALIFIER__ tiled_group new_tiled_group(unsigned int tile_size) const {
    // A valid tile size is a nonzero power of 2 no larger than the wavefront
    const bool pow2 = ((tile_size & (tile_size - 1)) == 0);
    if (!tile_size || (tile_size > WAVEFRONT_SIZE) || !pow2) {
      assert(false && "invalid tile size");
    }

    // Re-partitioning into an equal or larger tile leaves the group unchanged
    if (size() <= tile_size) {
      return *this;
    }

    tiled_group tiledGroup = tiled_group(tile_size);
    tiledGroup.tiled_info.is_tiled = true;
    return tiledGroup;
  }
 protected:
  explicit __CG_QUALIFIER__ tiled_group(unsigned int tileSize)
      : thread_group(internal::cg_tiled_group, tileSize) {
    tiled_info.size = tileSize;
    tiled_info.is_tiled = true;
  }
 public:
  __CG_QUALIFIER__ unsigned int size() const { return (tiled_info.size); }

  // Rank within the tile; the tile size is a power of 2, so masking the
  // workgroup rank with (size - 1) yields the rank inside the tile
  __CG_QUALIFIER__ unsigned int thread_rank() const {
    return (internal::workgroup::thread_rank() & (tiled_info.size - 1));
  }

  __CG_QUALIFIER__ void sync() const {
    // Enforce ordering of memory operations; tiles never exceed the
    // wavefront size, so the threads of a tile execute in one wavefront and
    // no explicit execution barrier is needed
    __builtin_amdgcn_fence(__ATOMIC_ACQ_REL, "agent");
  }
};
// The base type's accessors dispatch manually on the stored group type and
// forward to the matching derived type, avoiding virtual functions in device
// code.
__CG_QUALIFIER__ uint32_t thread_group::thread_rank() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      return (static_cast<const multi_grid_group*>(this)->thread_rank());
    }
    case internal::cg_grid: {
      return (static_cast<const grid_group*>(this)->thread_rank());
    }
    case internal::cg_workgroup: {
      return (static_cast<const thread_block*>(this)->thread_rank());
    }
    case internal::cg_tiled_group: {
      return (static_cast<const tiled_group*>(this)->thread_rank());
    }
    default: {
      assert(false && "invalid cooperative group type");
      return -1;
    }
  }
}
__CG_QUALIFIER__ bool thread_group::is_valid() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      return (static_cast<const multi_grid_group*>(this)->is_valid());
    }
    case internal::cg_grid: {
      return (static_cast<const grid_group*>(this)->is_valid());
    }
    case internal::cg_workgroup: {
      return (static_cast<const thread_block*>(this)->is_valid());
    }
    case internal::cg_tiled_group: {
      return (static_cast<const tiled_group*>(this)->is_valid());
    }
    default: {
      assert(false && "invalid cooperative group type");
      return false;
    }
  }
}
__CG_QUALIFIER__ void thread_group::sync() const {
  switch (this->_type) {
    case internal::cg_multi_grid: {
      static_cast<const multi_grid_group*>(this)->sync();
      break;
    }
    case internal::cg_grid: {
      static_cast<const grid_group*>(this)->sync();
      break;
    }
    case internal::cg_workgroup: {
      static_cast<const thread_block*>(this)->sync();
      break;
    }
    case internal::cg_tiled_group: {
      static_cast<const tiled_group*>(this)->sync();
      break;
    }
    default: {
      assert(false && "invalid cooperative group type");
    }
  }
}
// User-exposed, generic free functions mirroring the member APIs; usable with
// any cooperative group type.
template <class CGTy> __CG_QUALIFIER__ uint32_t group_size(CGTy const& g) { return g.size(); }

template <class CGTy> __CG_QUALIFIER__ uint32_t thread_rank(CGTy const& g) {
  return g.thread_rank();
}

template <class CGTy> __CG_QUALIFIER__ bool is_valid(CGTy const& g) { return g.is_valid(); }

template <class CGTy> __CG_QUALIFIER__ void sync(CGTy const& g) { g.sync(); }
// Captures the tile size as a compile-time constant for the templated tile
// types below.
template <unsigned int tileSize> class tile_base {
 protected:
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;

 public:
  // Rank of the calling thread within the tile; numThreads is a power of 2,
  // so masking the workgroup rank with (numThreads - 1) suffices
  _CG_STATIC_CONST_DECL_ unsigned int thread_rank() {
    return (internal::workgroup::thread_rank() & (numThreads - 1));
  }

  // Number of threads within the tile
  __CG_STATIC_QUALIFIER__ unsigned int size() { return numThreads; }
};
template <unsigned int size> class thread_block_tile_base : public tile_base<size> {
  static_assert(is_valid_tile_size<size>::value,
                "Tile size is either not a power of 2 or greater than the wavefront size");
  using tile_base<size>::numThreads;

 public:
  __CG_STATIC_QUALIFIER__ void sync() {
    // Enforce ordering of memory operations; the tile fits in one wavefront,
    // so no explicit execution barrier is needed
    __builtin_amdgcn_fence(__ATOMIC_ACQ_REL, "agent");
  }
  // Read var from the lane of the tile with rank srcRank
  template <class T> __CG_QUALIFIER__ T shfl(T var, int srcRank) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl(var, srcRank, numThreads));
  }

  // Read var from the lane lane_delta ranks above the calling lane
  template <class T> __CG_QUALIFIER__ T shfl_down(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl_down(var, lane_delta, numThreads));
  }

  // Read var from the lane lane_delta ranks below the calling lane
  template <class T> __CG_QUALIFIER__ T shfl_up(T var, unsigned int lane_delta) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl_up(var, lane_delta, numThreads));
  }

  // Read var from the lane whose rank is the caller's rank XOR laneMask
  template <class T> __CG_QUALIFIER__ T shfl_xor(T var, unsigned int laneMask) const {
    static_assert(is_valid_type<T>::value, "Neither an integer nor a float type.");
    return (__shfl_xor(var, laneMask, numThreads));
  }
};
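// Usage sketch (illustrative, not part of this header): a tile-wide sum
// reduction built on shfl_down. The device helper `tile_sum` is hypothetical.
//
//   template <class TileTy>
//   __device__ float tile_sum(const TileTy& tile, float v) {
//     for (unsigned int d = tile.size() / 2; d > 0; d /= 2) {
//       v += tile.shfl_down(v, d);  // lanes [0, size - d) accumulate
//     }
//     return v;  // lane 0 ends up holding the tile-wide sum
//   }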
template <unsigned int tileSize, class ParentCGTy = void>
class thread_block_tile_type : public thread_block_tile_base<tileSize>, public tiled_group {
  _CG_STATIC_CONST_DECL_ unsigned int numThreads = tileSize;

  friend class thread_block_tile_type<tileSize, ParentCGTy>;

  typedef thread_block_tile_base<numThreads> tbtBase;

 protected:
  __CG_QUALIFIER__ thread_block_tile_type() : tiled_group(numThreads) {
    tiled_info.size = numThreads;
    tiled_info.is_tiled = true;
  }

 public:
  // Resolve the ambiguity between the tile_base and tiled_group bases in
  // favor of the compile-time tile implementations
  using tbtBase::size;
  using tbtBase::sync;
  using tbtBase::thread_rank;
};
// User-exposed API to partition a parent group into tiles of tile_size
// threads, where tile_size is chosen at run time.
__CG_QUALIFIER__ thread_group tiled_partition(const thread_group& parent, unsigned int tile_size) {
  if (parent.cg_type() == internal::cg_tiled_group) {
    const tiled_group* cg = static_cast<const tiled_group*>(&parent);
    return cg->new_tiled_group(tile_size);
  }

  const thread_block* tb = static_cast<const thread_block*>(&parent);
  return tb->new_tiled_group(tile_size);
}
// Overload for partitioning a thread block
__CG_QUALIFIER__ thread_group tiled_partition(const thread_block& parent, unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}
// Overload for re-partitioning an existing tiled group into smaller tiles
__CG_QUALIFIER__ tiled_group tiled_partition(const tiled_group& parent, unsigned int tile_size) {
  return (parent.new_tiled_group(tile_size));
}
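// Usage sketch (illustrative, not part of this header): successive runtime
// partitions; re-partitioning only ever shrinks a tile. The kernel is
// hypothetical.
//
//   __global__ void nested_tiles() {
//     cooperative_groups::thread_block tb = cooperative_groups::this_thread_block();
//     cooperative_groups::thread_group t16 = cooperative_groups::tiled_partition(tb, 16);
//     cooperative_groups::thread_group t4 = cooperative_groups::tiled_partition(t16, 4);
//     t4.sync();  // synchronizes only the 4-wide tile
//   }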
template <unsigned int size, class ParentCGTy> class thread_block_tile;

namespace impl {
template <unsigned int size, class ParentCGTy> class thread_block_tile_internal;
template <unsigned int size, class ParentCGTy>
class thread_block_tile_internal : public thread_block_tile_type<size, ParentCGTy> {
 protected:
  template <unsigned int tbtSize, class tbtParentT>
  __CG_QUALIFIER__ thread_block_tile_internal(
      const thread_block_tile_internal<tbtSize, tbtParentT>& g)
      : thread_block_tile_type<size, ParentCGTy>() {}

  __CG_QUALIFIER__ thread_block_tile_internal(const thread_block& g)
      : thread_block_tile_type<size, ParentCGTy>() {}
};
}  // namespace impl
template <unsigned int size, class ParentCGTy>
class thread_block_tile : public impl::thread_block_tile_internal<size, ParentCGTy> {
 protected:
  __CG_QUALIFIER__ thread_block_tile(const ParentCGTy& g)
      : impl::thread_block_tile_internal<size, ParentCGTy>(g) {}

 public:
  // Allow implicit conversion to the parent-erased tile type
  __CG_QUALIFIER__ operator thread_block_tile<size, void>() const {
    return thread_block_tile<size, void>(*this);
  }
};
template <unsigned int size>
class thread_block_tile<size, void> : public impl::thread_block_tile_internal<size, void> {
  template <unsigned int, class ParentCGTy> friend class thread_block_tile;

 public:
  template <class ParentCGTy>
  __CG_QUALIFIER__ thread_block_tile(const thread_block_tile<size, ParentCGTy>& g)
      : impl::thread_block_tile_internal<size, void>(g) {}
};
template <unsigned int size, class ParentCGTy = void> class thread_block_tile;

namespace impl {
template <unsigned int size, class ParentCGTy = void> struct tiled_partition_internal;
template <unsigned int size>
struct tiled_partition_internal<size, thread_block> : public thread_block_tile<size, thread_block> {
  __CG_QUALIFIER__ tiled_partition_internal(const thread_block& g)
      : thread_block_tile<size, thread_block>(g) {}
};
}  // namespace impl
// User-exposed API to partition a group into tiles whose size is a
// compile-time constant.
template <unsigned int size, class ParentCGTy>
__CG_QUALIFIER__ thread_block_tile<size, ParentCGTy> tiled_partition(const ParentCGTy& g) {
  static_assert(is_valid_tile_size<size>::value,
                "Tiled partition with size > wavefront size. Currently not supported.");
  return impl::tiled_partition_internal<size, ParentCGTy>(g);
}
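// Usage sketch (illustrative, not part of this header): compile-time tiles
// expose the shuffle operations defined above. The kernel is hypothetical.
//
//   __global__ void broadcast_leader(int* out) {
//     cooperative_groups::thread_block tb = cooperative_groups::this_thread_block();
//     auto tile = cooperative_groups::tiled_partition<16>(tb);
//     int v = tile.shfl(static_cast<int>(tb.thread_rank()), 0);  // lane 0's rank
//     out[tb.thread_rank()] = v;
//   }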
}  // namespace cooperative_groups

#endif  // __cplusplus
#endif  // HIP_INCLUDE_HIP_AMD_DETAIL_HIP_COOPERATIVE_GROUPS_H