diff --git a/Samples/reduction/reduction_kernel.cu b/Samples/reduction/reduction_kernel.cu
index 67c536a2..9e688fbf 100644
--- a/Samples/reduction/reduction_kernel.cu
+++ b/Samples/reduction/reduction_kernel.cu
@@ -32,6 +32,7 @@
 #ifndef _REDUCE_KERNEL_H_
 #define _REDUCE_KERNEL_H_
 
+#define _CG_ABI_EXPERIMENTAL
 #include <cooperative_groups.h>
 #include <cooperative_groups/reduce.h>
 #include <stdio.h>
@@ -550,6 +551,61 @@ __global__ void cg_reduce(T *g_idata, T *g_odata, unsigned int n) {
   if (threadRank == 0) g_odata[blockIdx.x] = threadVal;
 }
 
+template <unsigned int BlockSize, unsigned int MultiWarpGroupSize, typename T>
+__global__ void multi_warp_cg_reduce(T *g_idata, T *g_odata, unsigned int n) {
+  // Shared memory for intermediate steps
+  T *sdata = SharedMemory<T>();
+  __shared__ cg::experimental::block_tile_memory<sizeof(T), BlockSize> scratch;
+
+  // Handle to thread block group
+  auto cta = cg::experimental::this_thread_block(scratch);
+  // Handle to multiWarpTile in thread block
+  auto multiWarpTile = cg::experimental::tiled_partition<MultiWarpGroupSize>(cta);
+
+  unsigned int gridSize = BlockSize * gridDim.x;
+  T threadVal = 0;
+
+  // We reduce multiple elements per thread. The number is determined by the
+  // number of active thread blocks (via gridDim): more blocks result in a
+  // larger gridSize and therefore fewer elements per thread.
+  int nIsPow2 = !(n & (n - 1));
+  if (nIsPow2) {
+    unsigned int i = blockIdx.x * BlockSize * 2 + threadIdx.x;
+    gridSize = gridSize << 1;
+
+    while (i < n) {
+      threadVal += g_idata[i];
+      // Ensure we don't read out of bounds -- this check is optimized away
+      // for power-of-2 sized arrays.
+      if ((i + BlockSize) < n) {
+        threadVal += g_idata[i + BlockSize];
+      }
+      i += gridSize;
+    }
+  } else {
+    unsigned int i = blockIdx.x * BlockSize + threadIdx.x;
+    while (i < n) {
+      threadVal += g_idata[i];
+      i += gridSize;
+    }
+  }
+
+  threadVal = cg_reduce_n(threadVal, multiWarpTile);
+
+  if (multiWarpTile.thread_rank() == 0) {
+    sdata[multiWarpTile.meta_group_rank()] = threadVal;
+  }
+  cg::sync(cta);
+
+  if (threadIdx.x == 0) {
+    threadVal = 0;
+    for (int i = 0; i < multiWarpTile.meta_group_size(); i++) {
+      threadVal += sdata[i];
+    }
+    g_odata[blockIdx.x] = threadVal;
+  }
+}
+
 extern "C" bool isPow2(unsigned int x);
 
 ////////////////////////////////////////////////////////////////////////////////
@@ -566,6 +622,13 @@ void reduce(int size, int threads, int blocks, int whichKernel, T *d_idata,
   int smemSize =
       (threads <= 32) ? 2 * threads * sizeof(T) : threads * sizeof(T);
 
+  // As kernel 9 (multi_warp_cg_reduce) cannot work with fewer than 64
+  // threads, we fall back to kernel 7 in that case.
+  if (threads < 64 && whichKernel == 9)
+  {
+    whichKernel = 7;
+  }
+
   // choose which of the optimized versions of reduction to launch
   switch (whichKernel) {
     case 0:
@@ -809,6 +872,10 @@ void reduce(int size, int threads, int blocks, int whichKernel, T *d_idata,
       smemSize = ((threads / 32) + 1) * sizeof(T);
       if (isPow2(size)) {
        switch (threads) {
+          case 1024:
+            reduce7<T, 1024, true>
+                <<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
+            break;
           case 512:
             reduce7<T, 512, true>
                 <<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
@@ -861,6 +928,10 @@ void reduce(int size, int threads, int blocks, int whichKernel, T *d_idata,
         }
       } else {
         switch (threads) {
+          case 1024:
+            reduce7<T, 1024, false>
+                <<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
+            break;
           case 512:
             reduce7<T, 512, false>
                 <<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
@@ -915,9 +986,42 @@ void reduce(int size, int threads, int blocks, int whichKernel, T *d_idata,
       break;
 
     case 8:
-    default:
       cg_reduce<T><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
       break;
+
+    case 9:
+      constexpr int numOfMultiWarpGroups = 2;
+      smemSize = numOfMultiWarpGroups * sizeof(T);
+      switch (threads) {
+        case 1024:
+          multi_warp_cg_reduce<T, 1024, 1024 / numOfMultiWarpGroups>
+              <<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
+          break;
+
+        case 512:
+          multi_warp_cg_reduce<T, 512, 512 / numOfMultiWarpGroups>
+              <<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
+          break;
+
+        case 256:
+          multi_warp_cg_reduce<T, 256, 256 / numOfMultiWarpGroups>
+              <<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
+          break;
+
+        case 128:
+          multi_warp_cg_reduce<T, 128, 128 / numOfMultiWarpGroups>
+              <<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
+          break;
+
+        case 64:
+          multi_warp_cg_reduce<T, 64, 64 / numOfMultiWarpGroups>
+              <<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
+          break;
+
+        default:
+          printf("thread block size of < 64 is not supported for this kernel\n");
+          break;
+      }
+      break;
   }
 }
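Reviewer note, not part of the patch: kernel 9 depends on the experimental cooperative-groups ABI (`_CG_ABI_EXPERIMENTAL`, CUDA 11.x), where `block_tile_memory` scratch storage plus `this_thread_block(scratch)` allow `tiled_partition<Size>` with Size larger than one warp (32). The sketch below is a minimal, self-contained illustration of that same pattern under those assumptions; the kernel name `tile_sum_sketch`, the fixed `float` element type, the static `partials` array (the patch uses dynamic shared memory via smemSize instead), and the 256/128 launch geometry are illustrative inventions, not code from this sample.

// Standalone sketch (assumes CUDA 11.x experimental cooperative-groups ABI).
#define _CG_ABI_EXPERIMENTAL
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>

namespace cg = cooperative_groups;

template <unsigned int BlockSize, unsigned int TileSize>
__global__ void tile_sum_sketch(const float *in, float *out, unsigned int n) {
  static_assert(BlockSize % TileSize == 0, "tiles must evenly divide the block");

  // Scratch storage required to build tiles wider than a single warp.
  __shared__ cg::experimental::block_tile_memory<sizeof(float), BlockSize> scratch;
  // One partial sum per multi-warp tile (the patch sizes this dynamically).
  __shared__ float partials[BlockSize / TileSize];

  auto block = cg::experimental::this_thread_block(scratch);
  auto tile = cg::experimental::tiled_partition<TileSize>(block);

  unsigned int i = blockIdx.x * BlockSize + threadIdx.x;
  float v = (i < n) ? in[i] : 0.0f;

  // One call reduces across all TileSize threads, i.e. across several warps.
  v = cg::reduce(tile, v, cg::plus<float>());

  // Tile leaders publish their partials; thread 0 combines them,
  // mirroring the meta_group_rank/meta_group_size step in kernel 9.
  if (tile.thread_rank() == 0) partials[tile.meta_group_rank()] = v;
  cg::sync(block);

  if (block.thread_rank() == 0) {
    float sum = 0.0f;
    for (unsigned int t = 0; t < BlockSize / TileSize; ++t) sum += partials[t];
    out[blockIdx.x] = sum;
  }
}

// Hypothetical launch: 256-thread blocks split into two 128-thread (4-warp)
// tiles, matching numOfMultiWarpGroups = 2 in the dispatcher above:
//   tile_sum_sketch<256, 128><<<blocks, 256>>>(d_in, d_out, n);

This also shows why the dispatcher falls back to kernel 7 below 64 threads: with numOfMultiWarpGroups = 2, a 64-thread block is the smallest that still yields tiles of at least one full warp.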