// Include guard for this cooperative-groups header.
// NOTE(review): this listing carries embedded original line numbers
// (e.g. "32 ", "33 ") and omits many original lines (the numbering jumps,
// and closing braces are absent), so it is not compilable as-is; restore
// the missing lines from the upstream header before building.
32 #ifndef HIP_INCLUDE_HIP_HCC_DETAIL_HIP_COOPERATIVE_GROUPS_H
33 #define HIP_INCLUDE_HIP_HCC_DETAIL_HIP_COOPERATIVE_GROUPS_H
// Every group type and free function below is declared in this namespace.
38 namespace cooperative_groups {
// Base-group constructor: records the group kind tag `type` and the group
// `size`; `mask` defaults to 0.
// NOTE(review): how `mask` is stored/used is not visible in this listing.
57 __CG_QUALIFIER__ thread_group(internal::group_type type, uint32_t size,
58 uint64_t mask = (uint64_t)0) {
// Group size accessor (presumably the thread count; body not visible).
68 __CG_QUALIFIER__ uint32_t size()
const {
// Calling thread's rank within the group. Defined out of line further
// down by switching on the stored group-type tag.
72 __CG_QUALIFIER__ uint32_t thread_rank()
const;
// Whether the group is valid; defined out of line via the same
// type-tag dispatch.
74 __CG_QUALIFIER__
bool is_valid()
const;
// Synchronizes the group; defined out of line via the same type-tag
// dispatch.
76 __CG_QUALIFIER__
void sync()
const;
// thread_group subtype tagged internal::cg_multi_grid. Each query forwards
// directly to the matching internal::multi_grid helper.
85 class multi_grid_group :
public thread_group {
// The factory this_multi_grid() is granted access to the constructor.
88 friend __CG_QUALIFIER__ multi_grid_group this_multi_grid();
// Builds the group with the given size and the cg_multi_grid tag.
// `explicit` blocks implicit integer-to-group conversion.
// NOTE(review): the access specifier for this constructor is not visible
// here; presumably non-public given the friend factory — confirm.
92 explicit __CG_QUALIFIER__ multi_grid_group(uint32_t size)
93 : thread_group(internal::cg_multi_grid, size) { }
// Number of grids, as reported by internal::multi_grid::num_grids().
// NOTE(review): not const-qualified, so it cannot be called through a
// const multi_grid_group&; consider adding `const`.
98 __CG_QUALIFIER__ uint32_t num_grids() {
99 return internal::multi_grid::num_grids();
// This grid's rank, as reported by internal::multi_grid::grid_rank().
// NOTE(review): also not const-qualified.
103 __CG_QUALIFIER__ uint32_t grid_rank() {
104 return internal::multi_grid::grid_rank();
// The following three members share names with thread_group's. No
// `virtual` appears on the base declarations, so they hide rather than
// override them; calls through a thread_group reach these via the base
// class's type-tag switch.
106 __CG_QUALIFIER__ uint32_t thread_rank()
const {
107 return internal::multi_grid::thread_rank();
109 __CG_QUALIFIER__
bool is_valid()
const {
110 return internal::multi_grid::is_valid();
112 __CG_QUALIFIER__
void sync()
const {
113 internal::multi_grid::sync();
// Factory returning a multi_grid_group sized by internal::multi_grid::size().
// NOTE(review): the line carrying the function name (original line 125,
// presumably `this_multi_grid() {` to match the friend declaration) is
// missing from this listing.
124 __CG_QUALIFIER__ multi_grid_group
126 return multi_grid_group(internal::multi_grid::size());
// thread_group subtype tagged internal::cg_grid. Each query forwards
// directly to the matching internal::grid helper.
135 class grid_group :
public thread_group {
// The factory this_grid() is granted access to the constructor.
138 friend __CG_QUALIFIER__ grid_group this_grid();
// Builds the group with the given size and the cg_grid tag; `explicit`
// blocks implicit integer-to-group conversion.
142 explicit __CG_QUALIFIER__ grid_group(uint32_t size)
143 : thread_group(internal::cg_grid, size) { }
// These hide (no `virtual` is visible on the base) the same-named
// thread_group members; the base class reaches them via its type-tag switch.
146 __CG_QUALIFIER__ uint32_t thread_rank()
const {
147 return internal::grid::thread_rank();
149 __CG_QUALIFIER__
bool is_valid()
const {
150 return internal::grid::is_valid();
152 __CG_QUALIFIER__
void sync()
const {
153 internal::grid::sync();
// Factory returning a grid_group sized by internal::grid::size().
// NOTE(review): the line carrying the function name (original line 165,
// presumably `this_grid() {` to match the friend declaration) is missing
// from this listing.
164 __CG_QUALIFIER__ grid_group
166 return grid_group(internal::grid::size());
// thread_group subtype tagged internal::cg_workgroup. Each query forwards
// directly to the matching internal::workgroup helper.
176 class thread_block :
public thread_group {
// The factory this_thread_block() is granted access to the constructor.
179 friend __CG_QUALIFIER__ thread_block this_thread_block();
// Builds the group with the given size and the cg_workgroup tag; `explicit`
// blocks implicit integer-to-group conversion.
183 explicit __CG_QUALIFIER__ thread_block(uint32_t size)
184 : thread_group(internal::cg_workgroup, size) { }
// Returns the dim3 from internal::workgroup::group_index().
// NOTE(review): not const-qualified; cannot be called on a const thread_block.
188 __CG_QUALIFIER__
dim3 group_index() {
189 return internal::workgroup::group_index();
// Returns the dim3 from internal::workgroup::thread_index().
// NOTE(review): not const-qualified either.
192 __CG_QUALIFIER__
dim3 thread_index() {
193 return internal::workgroup::thread_index();
// These hide (no `virtual` is visible on the base) the same-named
// thread_group members; the base class reaches them via its type-tag switch.
195 __CG_QUALIFIER__ uint32_t thread_rank()
const {
196 return internal::workgroup::thread_rank();
198 __CG_QUALIFIER__
bool is_valid()
const {
199 return internal::workgroup::is_valid();
201 __CG_QUALIFIER__
void sync()
const {
202 internal::workgroup::sync();
// Factory returning a thread_block sized by internal::workgroup::size().
213 __CG_QUALIFIER__ thread_block
214 this_thread_block() {
215 return thread_block(internal::workgroup::size());
// Rank of the calling thread within this group, resolved without virtual
// functions: switch on the stored _type tag and static_cast `this` to the
// matching subtype, then call that subtype's thread_rank().
// The static_cast is unchecked, so correctness depends on _type agreeing
// with the object's actual derived type.
221 __CG_QUALIFIER__ uint32_t thread_group::thread_rank()
const {
222 switch (this->_type) {
223 case internal::cg_multi_grid: {
224 return (
static_cast<const multi_grid_group*
>(
this)->thread_rank());
226 case internal::cg_grid: {
227 return (
static_cast<const grid_group*
>(
this)->thread_rank());
229 case internal::cg_workgroup: {
230 return (
static_cast<const thread_block*
>(
this)->thread_rank());
// Fallback for an unrecognized type tag.
// NOTE(review): assert() compiles out under NDEBUG; the lines after it
// (original 234-238, including any fallback return) are not visible here.
233 assert(
false &&
"invalid cooperative group type");
// Validity of this group, dispatched on the stored _type tag: `this` is
// static_cast to the matching subtype and that subtype's is_valid() result
// is returned. The cast is unchecked; _type must match the real derived type.
239 __CG_QUALIFIER__
bool thread_group::is_valid()
const {
240 switch (this->_type) {
241 case internal::cg_multi_grid: {
242 return (
static_cast<const multi_grid_group*
>(
this)->is_valid());
244 case internal::cg_grid: {
245 return (
static_cast<const grid_group*
>(
this)->is_valid());
247 case internal::cg_workgroup: {
248 return (
static_cast<const thread_block*
>(
this)->is_valid());
// Fallback for an unrecognized type tag.
// NOTE(review): assert() compiles out under NDEBUG; any fallback return
// after it is not visible in this listing.
251 assert(
false &&
"invalid cooperative group type");
// Synchronizes this group, dispatched on the stored _type tag: `this` is
// static_cast to the matching subtype and that subtype's sync() is called.
// The cast is unchecked; _type must match the real derived type.
// NOTE(review): the statements following each sync() call (original lines
// 261-262, 265-266, 269-271) are missing from this listing; confirm each
// case ends with `break`, otherwise control falls into the next case.
257 __CG_QUALIFIER__
void thread_group::sync()
const {
258 switch (this->_type) {
259 case internal::cg_multi_grid: {
260 static_cast<const multi_grid_group*
>(
this)->sync();
263 case internal::cg_grid: {
264 static_cast<const grid_group*
>(
this)->sync();
267 case internal::cg_workgroup: {
268 static_cast<const thread_block*
>(
this)->sync();
// Fallback for an unrecognized type tag (no-op once NDEBUG strips assert).
272 assert(
false &&
"invalid cooperative group type");
// Free-function query of a group's size for any group type CGTy.
// NOTE(review): the body (original lines 283-285) is missing from this
// listing; presumably it returns g.size() — confirm.
281 template <
class CGTy>
282 __CG_QUALIFIER__ uint32_t group_size(CGTy
const &g) {
// Free-function form of g.thread_rank(). Accepts any CGTy whose
// thread_rank() is callable on a const object.
286 template <
class CGTy>
287 __CG_QUALIFIER__ uint32_t thread_rank(CGTy
const &g) {
288 return g.thread_rank();
// Free-function validity query for any group type CGTy.
// NOTE(review): the body (original lines 293-295) is missing from this
// listing; presumably it returns g.is_valid() — confirm.
291 template <
class CGTy>
292 __CG_QUALIFIER__
bool is_valid(CGTy
const &g) {
// Free-function group synchronization for any group type CGTy.
// NOTE(review): the body (original lines 298-302) is missing from this
// listing; presumably it calls g.sync() — confirm.
296 template <
class CGTy>
297 __CG_QUALIFIER__
void sync(CGTy
const &g) {
303 #endif // __cplusplus
304 #endif // HIP_INCLUDE_HIP_HCC_DETAIL_HIP_COOPERATIVE_GROUPS_H