mirror of
https://github.com/NVIDIA/cuda-samples.git
synced 2024-11-24 19:59:17 +08:00
add multi-warp cooperative groups based reduction kernel in reduction sample
This commit is contained in:
parent
dd2dba3489
commit
c4e2869a2b
|
@ -32,6 +32,7 @@
|
||||||
#ifndef _REDUCE_KERNEL_H_
|
#ifndef _REDUCE_KERNEL_H_
|
||||||
#define _REDUCE_KERNEL_H_
|
#define _REDUCE_KERNEL_H_
|
||||||
|
|
||||||
|
#define _CG_ABI_EXPERIMENTAL
|
||||||
#include <cooperative_groups.h>
|
#include <cooperative_groups.h>
|
||||||
#include <cooperative_groups/reduce.h>
|
#include <cooperative_groups/reduce.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
@ -550,6 +551,61 @@ __global__ void cg_reduce(T *g_idata, T *g_odata, unsigned int n) {
|
||||||
if (threadRank == 0) g_odata[blockIdx.x] = threadVal;
|
if (threadRank == 0) g_odata[blockIdx.x] = threadVal;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <class T, size_t BlockSize, size_t MultiWarpGroupSize>
|
||||||
|
__global__ void multi_warp_cg_reduce(T *g_idata, T *g_odata, unsigned int n) {
|
||||||
|
// Shared memory for intermediate steps
|
||||||
|
T *sdata = SharedMemory<T>();
|
||||||
|
__shared__ cg::experimental::block_tile_memory<sizeof(T), BlockSize> scratch;
|
||||||
|
|
||||||
|
// Handle to thread block group
|
||||||
|
auto cta = cg::experimental::this_thread_block(scratch);
|
||||||
|
// Handle to multiWarpTile in thread block
|
||||||
|
auto multiWarpTile = cg::experimental::tiled_partition<MultiWarpGroupSize>(cta);
|
||||||
|
|
||||||
|
unsigned int gridSize = BlockSize * gridDim.x;
|
||||||
|
T threadVal = 0;
|
||||||
|
|
||||||
|
// we reduce multiple elements per thread. The number is determined by the
|
||||||
|
// number of active thread blocks (via gridDim). More blocks will result
|
||||||
|
// in a larger gridSize and therefore fewer elements per thread
|
||||||
|
int nIsPow2 = !(n & n-1);
|
||||||
|
if (nIsPow2) {
|
||||||
|
unsigned int i = blockIdx.x * BlockSize * 2 + threadIdx.x;
|
||||||
|
gridSize = gridSize << 1;
|
||||||
|
|
||||||
|
while (i < n) {
|
||||||
|
threadVal += g_idata[i];
|
||||||
|
// ensure we don't read out of bounds -- this is optimized away for
|
||||||
|
// powerOf2 sized arrays
|
||||||
|
if ((i + BlockSize) < n) {
|
||||||
|
threadVal += g_idata[i + blockDim.x];
|
||||||
|
}
|
||||||
|
i += gridSize;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
unsigned int i = blockIdx.x * BlockSize + threadIdx.x;
|
||||||
|
while (i < n) {
|
||||||
|
threadVal += g_idata[i];
|
||||||
|
i += gridSize;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
threadVal = cg_reduce_n(threadVal, multiWarpTile);
|
||||||
|
|
||||||
|
if (multiWarpTile.thread_rank() == 0) {
|
||||||
|
sdata[multiWarpTile.meta_group_rank()] = threadVal;
|
||||||
|
}
|
||||||
|
cg::sync(cta);
|
||||||
|
|
||||||
|
if (threadIdx.x == 0) {
|
||||||
|
threadVal = 0;
|
||||||
|
for (int i=0; i < multiWarpTile.meta_group_size(); i++) {
|
||||||
|
threadVal += sdata[i];
|
||||||
|
}
|
||||||
|
g_odata[blockIdx.x] = threadVal;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
extern "C" bool isPow2(unsigned int x);
|
extern "C" bool isPow2(unsigned int x);
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
@ -566,6 +622,13 @@ void reduce(int size, int threads, int blocks, int whichKernel, T *d_idata,
|
||||||
int smemSize =
|
int smemSize =
|
||||||
(threads <= 32) ? 2 * threads * sizeof(T) : threads * sizeof(T);
|
(threads <= 32) ? 2 * threads * sizeof(T) : threads * sizeof(T);
|
||||||
|
|
||||||
|
// as kernel 9 - multi_warp_cg_reduce cannot work for more than 64 threads
|
||||||
|
// we choose to set kernel 7 for this purpose.
|
||||||
|
if (threads < 64 && whichKernel == 9)
|
||||||
|
{
|
||||||
|
whichKernel = 7;
|
||||||
|
}
|
||||||
|
|
||||||
// choose which of the optimized versions of reduction to launch
|
// choose which of the optimized versions of reduction to launch
|
||||||
switch (whichKernel) {
|
switch (whichKernel) {
|
||||||
case 0:
|
case 0:
|
||||||
|
@ -809,6 +872,10 @@ void reduce(int size, int threads, int blocks, int whichKernel, T *d_idata,
|
||||||
smemSize = ((threads / 32) + 1) * sizeof(T);
|
smemSize = ((threads / 32) + 1) * sizeof(T);
|
||||||
if (isPow2(size)) {
|
if (isPow2(size)) {
|
||||||
switch (threads) {
|
switch (threads) {
|
||||||
|
case 1024:
|
||||||
|
reduce7<T, 1024, true>
|
||||||
|
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
||||||
|
break;
|
||||||
case 512:
|
case 512:
|
||||||
reduce7<T, 512, true>
|
reduce7<T, 512, true>
|
||||||
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
||||||
|
@ -861,6 +928,10 @@ void reduce(int size, int threads, int blocks, int whichKernel, T *d_idata,
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
switch (threads) {
|
switch (threads) {
|
||||||
|
case 1024:
|
||||||
|
reduce7<T, 1024, true>
|
||||||
|
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
||||||
|
break;
|
||||||
case 512:
|
case 512:
|
||||||
reduce7<T, 512, false>
|
reduce7<T, 512, false>
|
||||||
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
||||||
|
@ -915,9 +986,42 @@ void reduce(int size, int threads, int blocks, int whichKernel, T *d_idata,
|
||||||
|
|
||||||
break;
|
break;
|
||||||
case 8:
|
case 8:
|
||||||
default:
|
|
||||||
cg_reduce<T><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
cg_reduce<T><<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
||||||
break;
|
break;
|
||||||
|
case 9:
|
||||||
|
constexpr int numOfMultiWarpGroups = 2;
|
||||||
|
smemSize = numOfMultiWarpGroups * sizeof(T);
|
||||||
|
switch (threads) {
|
||||||
|
case 1024:
|
||||||
|
multi_warp_cg_reduce<T, 1024, 1024/numOfMultiWarpGroups>
|
||||||
|
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case 512:
|
||||||
|
multi_warp_cg_reduce<T, 512, 512/numOfMultiWarpGroups>
|
||||||
|
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case 256:
|
||||||
|
multi_warp_cg_reduce<T, 256, 256/numOfMultiWarpGroups>
|
||||||
|
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case 128:
|
||||||
|
multi_warp_cg_reduce<T, 128, 128/numOfMultiWarpGroups>
|
||||||
|
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
||||||
|
break;
|
||||||
|
|
||||||
|
case 64:
|
||||||
|
multi_warp_cg_reduce<T, 64, 64/numOfMultiWarpGroups>
|
||||||
|
<<<dimGrid, dimBlock, smemSize>>>(d_idata, d_odata, size);
|
||||||
|
break;
|
||||||
|
|
||||||
|
default:
|
||||||
|
printf("thread block size of < 64 is not supported for this kernel\n");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user