mirror of
https://github.com/NVIDIA/cuda-samples.git
synced 2025-01-19 07:05:47 +08:00
add multi-warp cooperative groups based reduction kernel in reduction sample
This commit is contained in:
parent
dd2dba3489
commit
c4e2869a2b
|
@ -32,6 +32,7 @@
|
|||
#ifndef _REDUCE_KERNEL_H_
|
||||
#define _REDUCE_KERNEL_H_
|
||||
|
||||
#define _CG_ABI_EXPERIMENTAL
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <stdio.h>
|
||||
|
@ -550,6 +551,61 @@ __global__ void cg_reduce(T *g_idata, T *g_odata, unsigned int n) {
|
|||
if (threadRank == 0) g_odata[blockIdx.x] = threadVal;
|
||||
}
|
||||
|
||||
template <class T, size_t BlockSize, size_t MultiWarpGroupSize>
|
||||
__global__ void multi_warp_cg_reduce(T *g_idata, T *g_odata, unsigned int n) {
|
||||
// Shared memory for intermediate steps
|
||||
T *sdata = SharedMemory<T>();
|
||||
__shared__ cg::experimental::block_tile_memory<sizeof(T), BlockSize> scratch;
|
||||
|
||||
// Handle to thread block group
|
||||
auto cta = cg::experimental::this_thread_block(scratch);
|
||||
// Handle to multiWarpTile in thread block
|
||||
auto multiWarpTile = cg::experimental::tiled_partition<MultiWarpGroupSize>(cta);
|
||||
|
||||
unsigned int gridSize = BlockSize * gridDim.x;
|
||||
T threadVal = 0;
|
||||
|
||||
// we reduce multiple elements per thread. The number is determined by the
|
||||
// number of active thread blocks (via gridDim). More blocks will result
|
||||
// in a larger gridSize and therefore fewer elements per thread
|
||||
int nIsPow2 = !(n & n-1);
|
||||
if (nIsPow2) {
|
||||
unsigned int i = blockIdx.x * BlockSize * 2 + threadIdx.x;
|
||||
gridSize = gridSize << 1;
|
||||
|
||||
while (i < n) {
|
||||
threadVal += g_idata[i];
|
||||
// ensure we don't read out of bounds -- this is optimized away for
|
||||
// powerOf2 sized arrays
|
||||
if ((i + BlockSize) < n) {
|
||||
threadVal += g_idata[i + blockDim.x];
|
||||
}
|
||||
i += gridSize;
|
||||
}
|
||||
} else {
|
||||
unsigned int i = blockIdx.x * BlockSize + threadIdx.x;
|
||||
while (i < n) {
|
||||
threadVal += g_idata[i];
|
||||
i += gridSize;
|
||||
}
|
||||
}
|
||||
|
||||
threadVal = cg_reduce_n(threadVal, multiWarpTile);
|
||||
|
||||
if (multiWarpTile.thread_rank() == 0) {
|
||||
sdata[multiWarpTile.meta_group_rank()] = threadVal;
|
||||
}
|
||||
cg::sync(cta);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
threadVal = 0;
|
||||
for (int i=0; i < multiWarpTile.meta_group_size(); i++) {
|
||||
threadVal += sdata[i];
|
||||
}
|
||||
g_odata[blockIdx.x] = threadVal;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" bool isPow2(unsigned int x);
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
@ -566,6 +622,13 @@ void reduce(int size, int threads, int blocks, int whichKernel, T *d_idata,
|
|||
int smemSize =
|
||||
(threads <= 32) ? 2 * threads * sizeof(T) : threads * sizeof(T);
|
||||
|
||||
// as kernel 9 - multi_warp_cg_reduce cannot work for more than 64 threads
|
||||
// we choose to set kernel 7 for this purpose.
|
||||
if (threads < 64 && whichKernel == 9)
|
||||
{
|
||||
whichKernel = 7;
|
||||
}
|
||||
|
||||
// choose which of the optimized versions of reduction to launch
|
||||
switch (whichKernel) {
|
||||
case 0:
|
||||
|
@ -809,6 +872,10 @@ void reduce(int size, int threads, int blocks, int whichKernel, T *d_idata,
|
|||
smemSize = ((threads / 32) + 1) * sizeof(T);
|
||||
if (isPow2(size)) {
|
||||
switch (threads) {
|
||||
case 1024:
|
||||
reduce7<T, 1024, true>
|
||||
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
||||
break;
|
||||
case 512:
|
||||
reduce7<T, 512, true>
|
||||
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
||||
|
@ -861,6 +928,10 @@ void reduce(int size, int threads, int blocks, int whichKernel, T *d_idata,
|
|||
}
|
||||
} else {
|
||||
switch (threads) {
|
||||
case 1024:
|
||||
reduce7<T, 1024, true>
|
||||
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
||||
break;
|
||||
case 512:
|
||||
reduce7<T, 512, false>
|
||||
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
||||
|
@ -915,9 +986,42 @@ void reduce(int size, int threads, int blocks, int whichKernel, T *d_idata,
|
|||
|
||||
break;
|
||||
case 8:
|
||||
default:
|
||||
cg_reduce<T><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
||||
break;
|
||||
case 9:
|
||||
constexpr int numOfMultiWarpGroups = 2;
|
||||
smemSize = numOfMultiWarpGroups * sizeof(T);
|
||||
switch (threads) {
|
||||
case 1024:
|
||||
multi_warp_cg_reduce<T, 1024, 1024/numOfMultiWarpGroups>
|
||||
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
case 512:
|
||||
multi_warp_cg_reduce<T, 512, 512/numOfMultiWarpGroups>
|
||||
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
case 256:
|
||||
multi_warp_cg_reduce<T, 256, 256/numOfMultiWarpGroups>
|
||||
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
case 128:
|
||||
multi_warp_cg_reduce<T, 128, 128/numOfMultiWarpGroups>
|
||||
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
case 64:
|
||||
multi_warp_cg_reduce<T, 64, 64/numOfMultiWarpGroups>
|
||||
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
||||
break;
|
||||
|
||||
default:
|
||||
printf("thread block size of < 64 is not supported for this kernel\n");
|
||||
break;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user