-- Add partitioned cuda pipeline prod-cons gemm kernel

-- Add cudaCompressibleMemory sample to use copy engine vs SM writes
   depending on arch
Mahesh Doijade 2021-03-03 22:53:02 +05:30
parent b882fa00ee
commit 067cb65523
2 changed files with 199 additions and 43 deletions


@@ -48,26 +48,47 @@ __global__ void saxpy(const float a, const float4 *x, const float4 *y, float4 *z
}
}
__global__ void init(float4 *x, float4 *y, float4 *z, const float val, const size_t n)
__global__ void init(float4 *x, float4 *y, const float val, const size_t n)
{
const float4 val4 = make_float4(val, val, val, val);
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += gridDim.x * blockDim.x)
{
z[i] = x[i] = y[i] = val4;
x[i] = y[i] = val4;
}
}
void launchSaxpy(const float a, float4 *x, float4 *y, float4 *z, const size_t n, const float init_val)
void launchSaxpy(const float a, float4 *x, float4 *y, float4 *z, const size_t n, const float init_val, const bool compressibleZbuf)
{
cudaEvent_t start, stop;
float ms;
int blockSize;
int minGridSize;
dim3 threads, blocks;
checkCudaErrors(cudaOccupancyMaxPotentialBlockSize(&minGridSize, &blockSize, (void*)init));
dim3 threads = dim3(blockSize, 1, 1);
dim3 blocks = dim3(minGridSize, 1, 1);
init<<<blocks, threads>>>(x, y, z, init_val, n);
if (!compressibleZbuf)
{
// We are on a config where the compressible buffer can only be initialized
// through cudaMemcpy; hence, the x & y buffers are allocated as compressible
// and initialized via cudaMemcpy, whereas the z buffer is allocated as
// non-compressible.
float4 *h_x = (float4 *) malloc(sizeof(float4) * n);
float4 *h_y = (float4 *) malloc(sizeof(float4) * n);
for (size_t i = 0; i < n; i++)
{
h_x[i].x = h_x[i].y = h_x[i].z = h_x[i].w = init_val;
h_y[i].x = h_y[i].y = h_y[i].z = h_y[i].w = init_val;
}
checkCudaErrors(cudaMemcpy(x, h_x, sizeof(float4) * n, cudaMemcpyHostToDevice));
checkCudaErrors(cudaMemcpy(y, h_y, sizeof(float4) * n, cudaMemcpyHostToDevice));
free(h_x);
free(h_y);
}
else
{
checkCudaErrors(cudaOccupancyMaxPotentialBlockSize(&minGridSize, &blockSize, (void*)init));
threads = dim3(blockSize, 1, 1);
blocks = dim3(minGridSize, 1, 1);
init<<<blocks, threads>>>(x, y, init_val, n);
}
checkCudaErrors(cudaOccupancyMaxPotentialBlockSize(&minGridSize, &blockSize, (void*)saxpy));
threads = dim3(blockSize, 1, 1);
@@ -121,19 +142,39 @@ int main(int argc, char **argv)
printf("Generic memory compression support is available\n");
int major, minor;
checkCudaErrors(cuDeviceGetAttribute(&major,
CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR,
currentDevice));
checkCudaErrors(cuDeviceGetAttribute(&minor,
CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR,
currentDevice));
float4 *x, *y, *z;
const size_t size = n * sizeof(float4);
// Allocating compressible memory
checkCudaErrors(allocateCompressible((void **)&x, size, true));
checkCudaErrors(allocateCompressible((void **)&y, size, true));
checkCudaErrors(allocateCompressible((void **)&z, size, true));
bool compressibleZbuf = false;
if ((major == 8 && minor == 0) || (major == 8 && minor == 6))
{
// On SM 8.0 and 8.6 GPUs, the compressible buffer can only be initialized
// through cudaMemcpy.
printf("allocating non-compressible Z buffer\n");
checkCudaErrors(allocateCompressible((void **)&z, size, false));
compressibleZbuf = false;
}
else
{
checkCudaErrors(allocateCompressible((void **)&z, size, true));
compressibleZbuf = true;
}
printf("Running saxpy on %zu bytes of Compressible memory\n", size);
const float a = 1.0f;
const float init_val = 1.0f;
launchSaxpy(a, x, y, z, n, init_val);
launchSaxpy(a, x, y, z, n, init_val, compressibleZbuf);
checkCudaErrors(freeCompressible(x, size, true));
checkCudaErrors(freeCompressible(y, size, true));
@@ -145,8 +186,8 @@ int main(int argc, char **argv)
checkCudaErrors(allocateCompressible((void **)&y, size, false));
checkCudaErrors(allocateCompressible((void **)&z, size, false));
launchSaxpy(a, x, y, z, n, init_val);
launchSaxpy(a, x, y, z, n, init_val, compressibleZbuf);
checkCudaErrors(freeCompressible(x, size, false));
checkCudaErrors(freeCompressible(y, size, false));
checkCudaErrors(freeCompressible(z, size, false));
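
For context, the allocateCompressible/freeCompressible helpers used above are not part of this diff. A minimal sketch of how such a helper can request compressible memory through the CUDA driver virtual memory management API (error checks elided; the sample's actual helper may differ):

#include <cuda.h>

// Sketch: allocate a device buffer, optionally backed by compressible memory.
// CU_MEM_ALLOCATION_COMP_GENERIC requests generic memory compression.
CUresult allocateCompressibleSketch(void **addr, size_t size, CUdevice dev,
                                    bool compressible)
{
    CUmemAllocationProp prop = {};
    prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
    prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
    prop.location.id = dev;
    prop.allocFlags.compressionType = compressible ?
        CU_MEM_ALLOCATION_COMP_GENERIC : CU_MEM_ALLOCATION_COMP_NONE;

    // Physical allocations must be a multiple of the granularity.
    size_t granularity = 0;
    cuMemGetAllocationGranularity(&granularity, &prop,
                                  CU_MEM_ALLOC_GRANULARITY_MINIMUM);
    size = ((size + granularity - 1) / granularity) * granularity;

    CUdeviceptr dptr;
    cuMemAddressReserve(&dptr, size, 0, 0, 0);  // reserve a VA range

    CUmemGenericAllocationHandle handle;
    cuMemCreate(&handle, size, &prop, 0);       // physical allocation
    cuMemMap(dptr, size, 0, handle, 0);         // map the VA range onto it

    CUmemAccessDesc access = {};
    access.location = prop.location;
    access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
    cuMemSetAccess(dptr, size, &access, 1);     // enable read/write access

    *addr = (void *)dptr;
    return CUDA_SUCCESS;
}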


@@ -47,6 +47,9 @@
#if __CUDA_ARCH__ >= 700
#include <cuda/barrier>
#endif
#include <cooperative_groups.h>
namespace cg = cooperative_groups;
// Helper functions and utilities to work with CUDA
#include <helper_functions.h>
@@ -54,18 +57,19 @@
enum kernels
{
AsyncCopyMultiStageLargeChunk = 0,
AsyncCopyLargeChunk = 1,
AsyncCopyLargeChunkAWBarrier = 2,
AsyncCopyMultiStage = 3,
AsyncCopySingleStage = 4,
Naive = 5,
NaiveLargeChunk = 6
AsyncCopyMultiStageLargeChunk = 0,
AsyncCopyLargeChunk = 1,
AsyncCopyLargeChunkAWBarrier = 2,
AsyncCopyMultiStageSharedState = 3,
AsyncCopyMultiStage = 4,
AsyncCopySingleStage = 5,
Naive = 6,
NaiveLargeChunk = 7
};
const char* kernelNames[] = {"AsyncCopyMultiStageLargeChunk", "AsyncCopyLargeChunk",
"AsyncCopyLargeChunkAWBarrier", "AsyncCopyMultiStage",
"AsyncCopySingleStage", "Naive", "NaiveLargeChunk"};
"AsyncCopyLargeChunkAWBarrier", "AsyncCopyMultiStageSharedState",
"AsyncCopyMultiStage", "AsyncCopySingleStage", "Naive", "NaiveLargeChunk"};
constexpr int blockSize = 16;
@@ -143,8 +147,8 @@ template <int BLOCK_SIZE> __global__ void MatrixMulAsyncCopyMultiStageLargeChunk
}
pipe.consumer_release();
// Don't have to synchronize because
// next iteration is loading to a different buffer
// Don't have to synchronize because maxPipelineStages is greater than one,
// so the next iteration loads into a different buffer.
}
// Write the block sub-matrix to device memory;
@@ -227,8 +231,8 @@ template <int BLOCK_SIZE> __global__ void MatrixMulAsyncCopyLargeChunk(float* __
pipe.consumer_release();
// Synchronize to make sure that the preceding
// computation is done before loading two new
// sub-matrices of A and B in the next iteration
// computation is done before overwriting the
// shared memory sub-matrix buffers As and Bs in the next iteration.
__syncthreads();
}
@@ -310,9 +314,9 @@ template <int BLOCK_SIZE> __global__ void MatrixMulAsyncCopyLargeChunkAWBarrier(
}
// Synchronize to make sure that the preceding
// computation is done before loading two new
// sub-matrices of A and B in the next iteration
__syncthreads();
// computation is done before overwriting the
// shared memory sub-matrix buffers As and Bs in the next iteration.
bar.arrive_and_wait();
}
// Write the block sub-matrix to device memory;
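
As an aside, a minimal self-contained sketch of the cuda::barrier pattern behind bar.arrive_and_wait() above (kernel name and workload are illustrative):

#include <cuda/barrier>
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

__global__ void barrierSketch(float *data)
{
    // Block-scoped barrier used in place of __syncthreads(): each thread
    // signals arrival and blocks until all threads in the block have arrived.
    __shared__ cuda::barrier<cuda::thread_scope_block> bar;
    auto block = cg::this_thread_block();
    if (block.thread_rank() == 0) {
        init(&bar, block.size());  // one thread initializes the barrier
    }
    block.sync();                  // make the initialized barrier visible to all

    data[block.thread_rank()] *= 2.0f;
    bar.arrive_and_wait();         // synchronization point, like __syncthreads()
}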
@@ -384,8 +388,8 @@ template <int BLOCK_SIZE> __global__ void MatrixMulAsyncCopySingleStage(float *C
}
// Synchronize to make sure that the preceding
// computation is done before loading two new
// sub-matrices of A and B in the next iteration
// computation is done before overwriting the
// shared memory sub-matrix buffers As and Bs in the next iteration.
__syncthreads();
}
@@ -395,7 +399,7 @@ template <int BLOCK_SIZE> __global__ void MatrixMulAsyncCopySingleStage(float *C
C[c + wB * threadIdx.y + threadIdx.x] = Csub;
}
// Multi Stage memcpy_async pipeline with int copy
// Multi-stage memcpy_async pipeline with thread_scope_thread, using single-element async copies
template <int BLOCK_SIZE> __global__ void MatrixMulAsyncCopyMultiStage(float* __restrict__ C,
const float* __restrict__ A,
const float* __restrict__ B, int wA,
@@ -467,8 +471,8 @@ template <int BLOCK_SIZE> __global__ void MatrixMulAsyncCopyMultiStage(float* __
}
pipe.consumer_release();
// Don't have to synchronize because
// next iteration is loading to a different buffer
// Don't have to synchronize because maxPipelineStages is greater than one,
// so the next iteration loads into a different buffer.
}
// Write the block sub-matrix to device memory;
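
For contrast with the shared-state kernel added below, a minimal self-contained sketch of a thread-scope pipeline issuing a single-element async copy, in the spirit of the kernel above (names and sizes are illustrative, not the sample's code):

#include <cuda/pipeline>

__global__ void perThreadPipelineSketch(float *dst, const float *src, int n)
{
    // Each thread owns an independent pipeline (thread_scope_thread):
    // no cross-thread coordination, only copy/compute overlap per thread.
    __shared__ float staging[128];            // assumes blockDim.x <= 128
    auto pipe = cuda::make_pipeline();        // thread-scope pipeline

    const int tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= n) return;

    pipe.producer_acquire();
    cuda::memcpy_async(&staging[threadIdx.x], &src[tid],
                       cuda::aligned_size_t<alignof(float)>(sizeof(float)), pipe);
    pipe.producer_commit();                   // copy is now in flight

    pipe.consumer_wait();                     // wait for the copy to land
    dst[tid] = staging[threadIdx.x] * 2.0f;
    pipe.consumer_release();
}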
@@ -477,6 +481,102 @@ template <int BLOCK_SIZE> __global__ void MatrixMulAsyncCopyMultiStage(float* __
C[c + wB * threadIdx.y + threadIdx.x] = Csub;
}
// Multi-stage shared-state memcpy_async pipeline with thread_scope_block,
// with a partitioned producer & consumer: one warp acts as the producer
// group, issuing memcpy_async operations, and all remaining warps form the
// consumer group, performing the GEMM computation on the matrices the
// producer has loaded.
template <int BLOCK_SIZE_X> __global__ void MatrixMulAsyncCopyMultiStageSharedState(float* __restrict__ C,
const float* __restrict__ A,
const float* __restrict__ B, int wA,
int wB) {
// Multi-stage pipeline version
constexpr size_t maxPipelineStages = 4;
// Declaration of the shared memory array As used to
// store the sub-matrix of A for each stage
__shared__ float As[maxPipelineStages][BLOCK_SIZE_X][BLOCK_SIZE_X];
// Declaration of the shared memory array Bs used to
// store the sub-matrix of B for each stage
__shared__ float Bs[maxPipelineStages][BLOCK_SIZE_X][BLOCK_SIZE_X];
float Csub = 0.0;
// Index of the first sub-matrix of A processed by the block
const int aBegin = wA * BLOCK_SIZE_X * blockIdx.y;
// Index of the last sub-matrix of A processed by the block
const int aEnd = aBegin + wA - 1;
// Step size used to iterate through the sub-matrices of A
constexpr int aStep = BLOCK_SIZE_X;
// Index of the first sub-matrix of B processed by the block
const int bBegin = BLOCK_SIZE_X * blockIdx.x;
// Step size used to iterate through the sub-matrices of B
int bStep = BLOCK_SIZE_X * wB;
auto cta = cg::this_thread_block();
const auto shape1 = cuda::aligned_size_t<alignof(float)>(sizeof(float));
__shared__ cuda::pipeline_shared_state<cuda::thread_scope_block, maxPipelineStages> shared_state;
constexpr int consumer_row_count = BLOCK_SIZE_X;
const auto thread_role = (cta.thread_index().y < consumer_row_count)
? cuda::pipeline_role::consumer
: cuda::pipeline_role::producer;
auto pipe = cuda::make_pipeline(cta, &shared_state, thread_role);
// Loop over all the sub-matrices of A and B
// required to compute the block sub-matrix
for (int a = aBegin, b = bBegin, i = 0, aStage = aBegin, bStage = bBegin, iStage = 0;
a <= aEnd; a += aStep, b += bStep, ++i) {
if (threadIdx.y >= consumer_row_count) {
// This is the producer warp (threadIdx.y >= 16, where 16 == consumer_row_count);
// it loads the matrices from device memory into shared memory.
for (; aStage <= a + aStep * maxPipelineStages; aStage += aStep, bStage += bStep, ++iStage) {
if (aStage <= aEnd) {
// Rotating buffer
const int j = iStage % maxPipelineStages;
const int strideRows = (blockDim.y - consumer_row_count);
pipe.producer_acquire();
for (int rowId = threadIdx.y - consumer_row_count; rowId < BLOCK_SIZE_X; rowId += strideRows) {
cuda::memcpy_async(&As[j][rowId][threadIdx.x],
&A[aStage + wA * rowId + threadIdx.x], shape1, pipe);
cuda::memcpy_async(&Bs[j][rowId][threadIdx.x],
&B[bStage + wB * rowId + threadIdx.x], shape1, pipe);
}
pipe.producer_commit();
}
}
}
else {
// These threads form the consumer group (threadIdx.y < consumer_row_count,
// where consumer_row_count == 16); they perform the GEMM computation on the
// matrices the producer warp has loaded into shared memory.
const int j = i % maxPipelineStages;
// Synchronize consumer group to make sure the matrices are loaded by producer group.
pipe.consumer_wait();
// Multiply the two matrices together;
// each thread computes one element
// of the block sub-matrix
#pragma unroll
for (int k = 0; k < BLOCK_SIZE_X; ++k) {
Csub += As[j][threadIdx.y][k] * Bs[j][k][threadIdx.x];
}
pipe.consumer_release();
}
}
// Write the block sub-matrix to device memory;
// each thread writes one element
if (threadIdx.y < consumer_row_count)
{
const int c = wB * BLOCK_SIZE_X * blockIdx.y + BLOCK_SIZE_X * blockIdx.x;
C[c + wB * threadIdx.y + threadIdx.x] = Csub;
}
}
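
To isolate the partitioned producer/consumer mechanics from the GEMM logic above, here is a minimal sketch of a thread_scope_block pipeline with one producer warp and one consumer warp; launch it as <<<1, 64>>> (everything here is illustrative):

#include <cuda/pipeline>
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

__global__ void partitionedPipelineSketch(float *dst, const float *src)
{
    constexpr int stages = 2;
    constexpr int tileSize = 32;              // one element per producer thread
    __shared__ float tile[stages][tileSize];
    __shared__ cuda::pipeline_shared_state<cuda::thread_scope_block, stages> state;

    auto block = cg::this_thread_block();
    const bool isProducer = threadIdx.x < 32; // warp 0 produces, warp 1 consumes
    const auto role = isProducer ? cuda::pipeline_role::producer
                                 : cuda::pipeline_role::consumer;
    auto pipe = cuda::make_pipeline(block, &state, role);

    for (int s = 0; s < stages; ++s) {
        if (isProducer) {
            pipe.producer_acquire();          // wait for a free stage buffer
            cuda::memcpy_async(&tile[s][threadIdx.x],
                               &src[s * tileSize + threadIdx.x],
                               sizeof(float), pipe);
            pipe.producer_commit();           // stage s is now in flight
        } else {
            const int lane = threadIdx.x - 32;
            pipe.consumer_wait();             // wait for stage s to land
            dst[s * tileSize + lane] = tile[s][lane] + 1.0f;
            pipe.consumer_release();          // recycle the stage buffer
        }
    }
}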
/**
* Matrix multiplication (CUDA Kernel) on the device: C = A * B
* wA is A's width and wB is B's width
@@ -637,10 +737,12 @@ int MatrixMultiply(int argc, char **argv,
// Allocate host memory for matrices A and B
unsigned int size_A = dimsA.x * dimsA.y;
unsigned int mem_size_A = sizeof(float) * size_A;
float *h_A = reinterpret_cast<float *>(malloc(mem_size_A));
float* h_A;
checkCudaErrors(cudaMallocHost(&h_A, mem_size_A));
unsigned int size_B = dimsB.x * dimsB.y;
unsigned int mem_size_B = sizeof(float) * size_B;
float *h_B = reinterpret_cast<float *>(malloc(mem_size_B));
float* h_B;
checkCudaErrors(cudaMallocHost(&h_B, mem_size_B));
cudaStream_t stream;
// Initialize host memory
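
The malloc-to-cudaMallocHost switch above gives page-locked (pinned) host buffers. A minimal sketch of why that matters, with hypothetical names (d_buf, bytes, stream):

#include <cuda_runtime.h>
#include <helper_cuda.h>  // checkCudaErrors, as used elsewhere in this sample

// Pinned host memory can be DMA'd directly by the GPU, so cudaMemcpyAsync can
// overlap with other stream work; pageable (malloc) buffers are staged through
// an internal pinned buffer and the copy may not be truly asynchronous.
void pinnedCopySketch(float *d_buf, size_t bytes, cudaStream_t stream)
{
    float *h_buf = nullptr;
    checkCudaErrors(cudaMallocHost(&h_buf, bytes));  // page-locked allocation
    // ... fill h_buf ...
    checkCudaErrors(cudaMemcpyAsync(d_buf, h_buf, bytes,
                                    cudaMemcpyHostToDevice, stream));
    checkCudaErrors(cudaStreamSynchronize(stream));  // before touching h_buf again
    checkCudaErrors(cudaFreeHost(h_buf));
}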
@@ -654,7 +756,8 @@ int MatrixMultiply(int argc, char **argv,
// Allocate host matrix C
dim3 dimsC(dimsB.x, dimsA.y, 1);
unsigned int mem_size_C = dimsC.x * dimsC.y * sizeof(float);
float *h_C = reinterpret_cast<float *>(malloc(mem_size_C));
float* h_C;
checkCudaErrors(cudaMallocHost(&h_C, mem_size_C));
if (h_C == NULL) {
fprintf(stderr, "Failed to allocate host matrix C!\n");
@@ -680,6 +783,10 @@ int MatrixMultiply(int argc, char **argv,
dim3 threads(blockSize, blockSize);
dim3 grid(dimsB.x / threads.x, dimsA.y / threads.y);
// Here the block size is 16x18: the first 16 rows form the consumer thread
// group and the last 2 rows (16x2 = 32 threads, i.e. exactly 1 warp) form
// the producer thread group.
dim3 threadsSharedStateKernel(blockSize, blockSize + 2, 1);
dim3 gridSharedStateKernel(dimsB.x / threadsSharedStateKernel.x, dimsA.y / threadsSharedStateKernel.x);
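// Note: both grid dimensions divide by threadsSharedStateKernel.x because each
// block still computes a blockSize x blockSize output tile; the 2 extra rows
// in y are producer-only threads and do not widen the tile.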
printf("Running kernel = %d - %s\n", kernel_number, kernelNames[kernel_number]);
// Create and start timer
@@ -698,6 +805,10 @@ int MatrixMultiply(int argc, char **argv,
case AsyncCopyLargeChunkAWBarrier :
MatrixMulAsyncCopyLargeChunkAWBarrier<blockSize><<<grid, threads, 0, stream>>>(d_C, d_A, d_B, dimsA.x, dimsB.x);
break;
case AsyncCopyMultiStageSharedState :
MatrixMulAsyncCopyMultiStageSharedState<blockSize><<<gridSharedStateKernel, threadsSharedStateKernel, 0, stream>>>
(d_C, d_A, d_B, dimsA.x, dimsB.x);
break;
case AsyncCopyMultiStage :
MatrixMulAsyncCopyMultiStage<blockSize><<<grid, threads, 0, stream>>>(d_C, d_A, d_B, dimsA.x, dimsB.x);
break;
@@ -735,6 +846,10 @@ int MatrixMultiply(int argc, char **argv,
case AsyncCopyLargeChunkAWBarrier :
MatrixMulAsyncCopyLargeChunkAWBarrier<blockSize><<<grid, threads, 0, stream>>>(d_C, d_A, d_B, dimsA.x, dimsB.x);
break;
case AsyncCopyMultiStageSharedState :
MatrixMulAsyncCopyMultiStageSharedState<blockSize><<<gridSharedStateKernel, threadsSharedStateKernel, 0, stream>>>
(d_C, d_A, d_B, dimsA.x, dimsB.x);
break;
case AsyncCopyMultiStage :
MatrixMulAsyncCopyMultiStage<blockSize><<<grid, threads, 0, stream>>>(d_C, d_A, d_B, dimsA.x, dimsB.x);
break;
@@ -801,15 +916,15 @@ int MatrixMultiply(int argc, char **argv,
printf("%s\n", correct ? "Result = PASS" : "Result = FAIL");
// Clean up memory
free(h_A);
free(h_B);
free(h_C);
checkCudaErrors(cudaFreeHost(h_A));
checkCudaErrors(cudaFreeHost(h_B));
checkCudaErrors(cudaFreeHost(h_C));
checkCudaErrors(cudaFree(d_A));
checkCudaErrors(cudaFree(d_B));
checkCudaErrors(cudaFree(d_C));
checkCudaErrors(cudaEventDestroy(start));
checkCudaErrors(cudaEventDestroy(stop));
printf("\nNOTE: The CUDA Samples are not meant for performance"\
printf("\nNOTE: The CUDA Samples are not meant for performance "\
"measurements. Results may vary when GPU Boost is enabled.\n");
if (correct) {
@@ -829,9 +944,9 @@ int main(int argc, char **argv) {
printf(" -wA=WidthA -hA=HeightA (Width x Height of Matrix A)\n");
printf(" -wB=WidthB -hB=HeightB (Width x Height of Matrix B)\n");
printf(" -kernel=kernel_number (0 - AsyncCopyMultiStageLargeChunk; 1 - AsyncCopyLargeChunk)\n");
printf(" (2 - AsyncCopyLargeChunkAWBarrier; 3 - AsyncCopyMultiStage)\n");
printf(" (4 - AsyncCopySingleStage; 5 - Naive without memcpy_async)\n");
printf(" (6 - NaiveLargeChunk without memcpy_async)\n");
printf(" (2 - AsyncCopyLargeChunkAWBarrier; 3 - AsyncCopyMultiStageSharedState)\n");
printf(" (4 - AsyncCopyMultiStage; 5 - AsyncCopySingleStage; 6 - Naive without memcpy_async)\n");
printf(" (7 - NaiveLargeChunk without memcpy_async)\n");
printf(" Note: Outer matrix dimensions of A & B matrices must be equal.\n");
exit(EXIT_SUCCESS);
@@ -876,7 +991,7 @@ int main(int argc, char **argv) {
// kernel to run - default (AsyncCopyMultiStageLargeChunk == 0)
if (checkCmdLineFlag(argc, (const char **)argv, "kernel")) {
int kernel_number = getCmdLineArgumentInt(argc, (const char **)argv, "kernel");
if (kernel_number < 7)
if (kernel_number < 8)
{
selected_kernel = (kernels)kernel_number;
}