add multi-warp cooperative groups based reduction kernel in reduction sample

This commit is contained in:
Mahesh Doijade 2020-09-24 16:49:58 +05:30
parent dd2dba3489
commit c4e2869a2b

View File

@ -32,6 +32,7 @@
#ifndef _REDUCE_KERNEL_H_
#define _REDUCE_KERNEL_H_
#define _CG_ABI_EXPERIMENTAL
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <stdio.h>
@ -550,6 +551,61 @@ __global__ void cg_reduce(T *g_idata, T *g_odata, unsigned int n) {
if (threadRank == 0) g_odata[blockIdx.x] = threadVal;
}
template <class T, size_t BlockSize, size_t MultiWarpGroupSize>
__global__ void multi_warp_cg_reduce(T *g_idata, T *g_odata, unsigned int n) {
// Shared memory for intermediate steps
T *sdata = SharedMemory<T>();
__shared__ cg::experimental::block_tile_memory<sizeof(T), BlockSize> scratch;
// Handle to thread block group
auto cta = cg::experimental::this_thread_block(scratch);
// Handle to multiWarpTile in thread block
auto multiWarpTile = cg::experimental::tiled_partition<MultiWarpGroupSize>(cta);
unsigned int gridSize = BlockSize * gridDim.x;
T threadVal = 0;
// we reduce multiple elements per thread. The number is determined by the
// number of active thread blocks (via gridDim). More blocks will result
// in a larger gridSize and therefore fewer elements per thread
int nIsPow2 = !(n & n-1);
if (nIsPow2) {
unsigned int i = blockIdx.x * BlockSize * 2 + threadIdx.x;
gridSize = gridSize << 1;
while (i < n) {
threadVal += g_idata[i];
// ensure we don't read out of bounds -- this is optimized away for
// powerOf2 sized arrays
if ((i + BlockSize) < n) {
threadVal += g_idata[i + blockDim.x];
}
i += gridSize;
}
} else {
unsigned int i = blockIdx.x * BlockSize + threadIdx.x;
while (i < n) {
threadVal += g_idata[i];
i += gridSize;
}
}
threadVal = cg_reduce_n(threadVal, multiWarpTile);
if (multiWarpTile.thread_rank() == 0) {
sdata[multiWarpTile.meta_group_rank()] = threadVal;
}
cg::sync(cta);
if (threadIdx.x == 0) {
threadVal = 0;
for (int i=0; i < multiWarpTile.meta_group_size(); i++) {
threadVal += sdata[i];
}
g_odata[blockIdx.x] = threadVal;
}
}
extern "C" bool isPow2(unsigned int x);
////////////////////////////////////////////////////////////////////////////////
@ -566,6 +622,13 @@ void reduce(int size, int threads, int blocks, int whichKernel, T *d_idata,
int smemSize =
(threads <= 32) ? 2 * threads * sizeof(T) : threads * sizeof(T);
// as kernel 9 - multi_warp_cg_reduce cannot work for more than 64 threads
// we choose to set kernel 7 for this purpose.
if (threads < 64 && whichKernel == 9)
{
whichKernel = 7;
}
// choose which of the optimized versions of reduction to launch
switch (whichKernel) {
case 0:
@ -809,6 +872,10 @@ void reduce(int size, int threads, int blocks, int whichKernel, T *d_idata,
smemSize = ((threads / 32) + 1) * sizeof(T);
if (isPow2(size)) {
switch (threads) {
case 1024:
reduce7<T, 1024, true>
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
break;
case 512:
reduce7<T, 512, true>
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
@ -861,6 +928,10 @@ void reduce(int size, int threads, int blocks, int whichKernel, T *d_idata,
}
} else {
switch (threads) {
case 1024:
reduce7<T, 1024, true>
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
break;
case 512:
reduce7<T, 512, false>
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
@ -915,9 +986,42 @@ void reduce(int size, int threads, int blocks, int whichKernel, T *d_idata,
break;
case 8:
default:
cg_reduce<T><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
break;
case 9:
constexpr int numOfMultiWarpGroups = 2;
smemSize = numOfMultiWarpGroups * sizeof(T);
switch (threads) {
case 1024:
multi_warp_cg_reduce<T, 1024, 1024/numOfMultiWarpGroups>
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
break;
case 512:
multi_warp_cg_reduce<T, 512, 512/numOfMultiWarpGroups>
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
break;
case 256:
multi_warp_cg_reduce<T, 256, 256/numOfMultiWarpGroups>
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
break;
case 128:
multi_warp_cg_reduce<T, 128, 128/numOfMultiWarpGroups>
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
break;
case 64:
multi_warp_cg_reduce<T, 64, 64/numOfMultiWarpGroups>
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
break;
default:
printf("thread block size of < 64 is not supported for this kernel\n");
break;
}
break;
}
}