/* Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *  * Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *  * Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *  * Neither the name of NVIDIA CORPORATION nor the names of its
 *    contributors may be used to endorse or promote products derived
 *    from this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
 * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
 * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

// CUDA sample demonstrating a GEMM computation using the Warp Matrix Multiply
// and Accumulate API introduced in CUDA 9.

// In this program, the compute_gemm kernel computes the result of a matrix
// multiplication and addition: D = alpha * A * B + beta * C. The dimensions of
// both C and D matrices are M_GLOBAL x N_GLOBAL. The A matrix is M_GLOBAL x
// K_GLOBAL (row-major), the B matrix is K_GLOBAL x N_GLOBAL (column-major). In
// that kernel, each CTA computes one 128 x 128 tile of the resulting matrix
// per iteration. When the tile is computed, the CTA stores it to global memory
// and begins a new iteration, selecting a new 128 x 128 tile to compute.
// Each CTA consists of eight warps. For the 128 x 128 tile, each warp computes
// eight 16 x 16 subtiles, organized in a 2 x 4 two-dimensional array. Warps
// compute the 16 x 16 subtiles using nvcuda::wmma::mma_sync operations by
// moving through the K_GLOBAL dimension of the A and B matrices and
// accumulating the intermediate result in the local thread state.

// There are a number of simple optimizations used in the algorithm:
// - The CTA copies the 128 x 128 tile of the C matrix from global memory to
//   shared memory. After that is done, each warp loads the C matrix fragments
//   from shared memory, thus avoiding a random global memory access.
// - On each internal iteration, the CTA copies a portion of the A and B
//   matrices from global memory to shared memory. After that, all warps in the
//   CTA reuse the A and B data from shared memory, thus reducing the number of
//   data copies from global memory.
// - The portions of the A and B matrices are stored in shared memory with an
//   additional padding (skew) to reduce the number of shared memory access
//   bank conflicts. (See a detailed explanation near the SKEW_HALF macro
//   definition.)
// - When the CTA finishes computing the tiles of the resulting matrix, each
//   warp stores its subtiles to shared memory.
//   The CTA then copies the shared memory contents to global memory, again
//   avoiding redundant random global memory accesses.
// - Note that the CTA tile size is chosen to maximize the GPU register
//   utilization, but carefully enough to avoid local memory use.

#include <assert.h>
#include <cuda.h>
#include <mma.h>
#include <stdio.h>

// helper functions and utilities to work with CUDA
#include <helper_cuda.h>
#include <helper_functions.h>

// GPU configuration.

#define WARP_SIZE 32

// MMA matrix tile dimensions.
#define M 16
#define N 16
#define K 16

// GEMM configuration.
#define M_TILES 256
#define N_TILES 256
#define K_TILES 256

#define M_GLOBAL (M * M_TILES)
#define N_GLOBAL (N * N_TILES)
#define K_GLOBAL (K * K_TILES)

#define C_LAYOUT wmma::mem_row_major

// Implementation constants.
#define WARPS_PER_BLOCK 8
#define THREADS_PER_BLOCK (WARP_SIZE * WARPS_PER_BLOCK)

#define CHUNK_K 8

#define BLOCK_ROW_WARPS 2
#define BLOCK_COL_WARPS 4

#define WARP_ROW_TILES 4
#define WARP_COL_TILES 2

#define BLOCK_ROW_TILES (WARP_ROW_TILES * BLOCK_ROW_WARPS)
#define BLOCK_COL_TILES (WARP_COL_TILES * BLOCK_COL_WARPS)

#define GLOBAL_MEM_STRIDE N_GLOBAL

#define SHMEM_STRIDE (N * BLOCK_ROW_TILES)
#define SHMEM_OFFSET (N * WARP_ROW_TILES)

// The macro below is used to shift rows of the A matrix and columns of the B
// matrix in shared memory to minimize possible bank conflicts. Before
// performing the nvcuda::wmma::mma_sync operation, the warp must load the
// matrix data using the nvcuda::wmma::load_matrix_sync operation. Although the
// memory access pattern is not specified for that function, each lane in the
// warp can read one or multiple matrix elements from different matrix rows or
// columns. For shared memory, such access can result in bank conflicts if
// different rows / columns of the matrix map to the same bank. By shifting each
// row and column by a few bytes, we make sure that they map to different banks,
// thus reducing the number of possible bank conflicts. The number of 8 two-byte
// "half" elements is chosen as the minimum possible shift because we must keep
// each row and column 128-bit aligned, as required by
// nvcuda::wmma::load_matrix_sync.
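// As a concrete illustration (added for clarity; not part of the original
// explanation): each staged A / B row below holds CHUNK_K * K = 128 half
// elements, i.e. 256 bytes. Shared memory has 32 four-byte banks (128 bytes
// per bank cycle), so without a skew every row would start an exact multiple
// of the bank cycle apart and the same column of every row would fall into
// the same bank. Adding SKEW_HALF = 8 halves (16 bytes, i.e. 4 banks) per row
// shifts consecutive rows into different banks while keeping every row
// 128-bit aligned.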
#define SKEW_HALF 8

#define checkKernelErrors(expr)                             \
  do {                                                      \
    expr;                                                   \
                                                            \
    cudaError_t __err = cudaGetLastError();                 \
    if (__err != cudaSuccess) {                             \
      printf("Line %d: '%s' failed: %s\n", __LINE__, #expr, \
             cudaGetErrorString(__err));                    \
      abort();                                              \
    }                                                       \
  } while (0)

using namespace nvcuda;

__host__ void init_host_matrices(float *a, float *b, float *c) {
  for (int i = 0; i < M_GLOBAL; i++) {
    for (int j = 0; j < K_GLOBAL; j++) {
      a[i * K_GLOBAL + j] = static_cast<float>(rand() % 3);
    }
  }

  for (int i = 0; i < N_GLOBAL; i++) {
    for (int j = 0; j < K_GLOBAL; j++) {
      b[i * K_GLOBAL + j] = static_cast<float>(rand() % 3);
    }
  }

  for (int t = 0; t < M_GLOBAL * N_GLOBAL; t++) {
    c[t] = static_cast<float>(rand() % 3);
  }
}

__global__ void init_device_matrices(const float *A_h, const float *B_h,
                                     const float *C_h, half *A, half *B,
                                     float *C, float *D) {
  for (int i = blockDim.x * blockIdx.x + threadIdx.x; i < M_GLOBAL * K_GLOBAL;
       i += gridDim.x * blockDim.x)
    A[i] = __float2half(A_h[i]);

  for (int i = blockDim.x * blockIdx.x + threadIdx.x; i < N_GLOBAL * K_GLOBAL;
       i += gridDim.x * blockDim.x)
    B[i] = __float2half(B_h[i]);

  for (int i = blockDim.x * blockIdx.x + threadIdx.x; i < M_GLOBAL * N_GLOBAL;
       i += gridDim.x * blockDim.x)
    C[i] = C_h[i];

  for (int i = blockDim.x * blockIdx.x + threadIdx.x; i < M_GLOBAL * N_GLOBAL;
       i += gridDim.x * blockDim.x)
    D[i] = 0;
}

__global__ void compute_gemm(const half *A, const half *B, const float *C,
                             float *D, float alpha, float beta) {
  extern __shared__ half shmem[][CHUNK_K * K + SKEW_HALF];

  // Warp and lane identification.
  const unsigned int warpId = threadIdx.x / WARP_SIZE;
  const unsigned int laneId = threadIdx.x % WARP_SIZE;

  // Offset in shared memory from which the B matrix is stored.
  const size_t shmem_idx_b_off = BLOCK_COL_TILES * M;

  // This pointer is used to access the C and D matrix tiles this warp
  // computes. The offsets are in float elements.
  float *shmem_warp_tile_ptr = reinterpret_cast<float *>(&shmem[0][0]) +
                               (warpId / 2) * SHMEM_STRIDE * K * 2 +
                               (warpId % 2) * SHMEM_OFFSET;

  // This pointer is used to stream the C and D matrices block-wide tile to and
  // from shared memory.
  float *shmem_warp_stream_ptr =
      reinterpret_cast<float *>(&shmem[0][0]) + warpId * SHMEM_STRIDE * K;

  // Adjust the beta scaler, as it'll be multiplied by alpha at the end of
  // each tile computation. Technically this is not generally correct (may
  // result in a loss of precision). Zero still needs to be specially handled
  // though.
  beta /= alpha;

  // Each CTA slides along the 128 x 128 tiles from the top left corner of the
  // matrix to the right and down, and selects the next tile to compute. Once
  // there's no such tile, all warps in this CTA exit.
  for (unsigned int block_pos = blockIdx.x;; block_pos += gridDim.x) {
    const unsigned int block_tile_i =
        ((block_pos * BLOCK_ROW_TILES) / N_TILES) * (BLOCK_COL_TILES);
    const unsigned int block_tile_j = (block_pos * BLOCK_COL_TILES) % N_TILES;

    // Stop when there are no more D matrix tiles to compute in this CTA.
    if (block_tile_i >= M_TILES) {
      break;
    }

    // This warp's pointer to the C matrix data to copy from global memory to
    // shared memory.
    const size_t gmem_idx =
        (block_tile_i + warpId) * M * GLOBAL_MEM_STRIDE + block_tile_j * N;
    const float *src_gmem_warp_stream_ptr = &C[gmem_idx];

    // Stream multiple C tiles to shared memory.
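    // Note on the copy below (added comment): each lane moves one int4
    // (16 bytes = 4 floats), so the 32 lanes of a warp cover one full
    // SHMEM_STRIDE = 128-float row per iteration. With K = 16 iterations per
    // warp and 8 warps, the whole 128 x 128 C tile ends up in shared memory.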
#pragma unroll
    for (int i = 0; i < K; i++) {
      typedef int4 copy_t;

      *((copy_t *)(shmem_warp_stream_ptr + SHMEM_STRIDE * i) + laneId) =
          *((copy_t *)(src_gmem_warp_stream_ptr + GLOBAL_MEM_STRIDE * i) +
            laneId);
    }

    __syncthreads();

    // These fragments will accumulate the result of A and B matrix fragment
    // multiplications along the K_GLOBAL dimension.
    wmma::fragment<wmma::accumulator, M, N, K, float> c[WARP_COL_TILES]
                                                       [WARP_ROW_TILES];

    // Load the C matrix tiles into fragments from shared memory.
#pragma unroll
    for (int i = 0; i < WARP_COL_TILES; i++) {
#pragma unroll
      for (int j = 0; j < WARP_ROW_TILES; j++) {
        const float *tile_ptr =
            shmem_warp_tile_ptr + i * SHMEM_STRIDE * K + j * N;

        wmma::load_matrix_sync(c[i][j], tile_ptr, SHMEM_STRIDE, C_LAYOUT);
      }
    }

    __syncthreads();

    // Scale the C matrix.
#pragma unroll
    for (int i = 0; i < WARP_COL_TILES; i++) {
#pragma unroll
      for (int j = 0; j < WARP_ROW_TILES; j++) {
#pragma unroll
        for (int t = 0; t < c[i][j].num_elements; t++) {
          c[i][j].x[t] *= beta;
        }
      }
    }

    // Select what warp copies what matrix to shared memory.
    // Warps 0-3 copy the A matrix, warps 4-7 copy the B matrix.
    const half *warp_ptr = (warpId < 4) ? (&A[block_tile_i * M * K_GLOBAL] +
                                           M * K_GLOBAL * (warpId % 4) * 2)
                                        : (&B[block_tile_j * N * K_GLOBAL] +
                                           N * K_GLOBAL * (warpId % 4) * 2);

    // Go through the global K dimension by a fixed step at a time.
#pragma unroll
    for (int tile_k = 0; tile_k < K_TILES; tile_k += CHUNK_K) {
      // Copy slices of the A and B matrices to shared memory.
      // The first half of the warps in the CTA copy the A matrix, the rest
      // copy the B matrix.
      size_t shmem_idx =
          warpId < (WARPS_PER_BLOCK / 2)
              ? (M * (warpId % (WARPS_PER_BLOCK / 2)) * 2)
              : (N * (warpId % (WARPS_PER_BLOCK / 2)) * 2 + shmem_idx_b_off);

      // First half of the warp copies the first row / column of the matrix,
      // the second half of the warp copies the next.
      int4 *lane_ptr = (int4 *)(warp_ptr + tile_k * K +
                                (laneId / (WARP_SIZE / 2)) * K_GLOBAL) +
                       (laneId % (WARP_SIZE / 2));

      // Shift the second half of the warp to the next row / column in the
      // shared memory.
      shmem_idx += laneId / (WARP_SIZE / 2);

#pragma unroll
      for (int i = 0; i < (WARP_SIZE / 2); i++) {
        // Copy 16 bytes at once in each lane.
        *((int4 *)&shmem[shmem_idx][0] + (laneId % (WARP_SIZE / 2))) =
            *lane_ptr;

        // Advance the global memory pointer and the shared memory index.
        lane_ptr = reinterpret_cast<int4 *>(
            reinterpret_cast<half *>(lane_ptr) + K_GLOBAL * 2);
        shmem_idx += 2;
      }

      __syncthreads();

      // Compute a grid of C matrix tiles in each warp.
#pragma unroll
      for (int k_step = 0; k_step < CHUNK_K; k_step++) {
        wmma::fragment<wmma::matrix_a, M, N, K, half, wmma::row_major>
            a[WARP_COL_TILES];
        wmma::fragment<wmma::matrix_b, M, N, K, half, wmma::col_major>
            b[WARP_ROW_TILES];

#pragma unroll
        for (int i = 0; i < WARP_COL_TILES; i++) {
          size_t shmem_idx_a = (warpId / 2) * M * 2 + (i * M);
          const half *tile_ptr = &shmem[shmem_idx_a][k_step * K];

          wmma::load_matrix_sync(a[i], tile_ptr, K * CHUNK_K + SKEW_HALF);

#pragma unroll
          for (int j = 0; j < WARP_ROW_TILES; j++) {
            if (i == 0) {
              // Load the B matrix fragment once, because it is going to be
              // reused against the other A matrix fragments.
              size_t shmem_idx_b = shmem_idx_b_off +
                                   (WARP_ROW_TILES * N) * (warpId % 2) +
                                   (j * N);
              const half *tile_ptr = &shmem[shmem_idx_b][k_step * K];

              wmma::load_matrix_sync(b[j], tile_ptr, K * CHUNK_K + SKEW_HALF);
            }

            wmma::mma_sync(c[i][j], a[i], b[j], c[i][j]);
          }
        }
      }

      __syncthreads();
    }

    // Store the D fragments to shared memory.
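    // Added note: each warp scales its 2 x 4 grid of accumulator fragments by
    // alpha and writes them back to the same shared-memory region the C tile
    // was staged in. Since beta was divided by alpha earlier, the stored
    // result is alpha * (A * B) + beta * C, as intended.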
#pragma unroll
    for (int i = 0; i < WARP_COL_TILES; i++) {
#pragma unroll
      for (int j = 0; j < WARP_ROW_TILES; j++) {
#pragma unroll
        // Uniform, point-wise transformations of ALL fragment elements by ALL
        // threads in the warp are well-defined even though element indices
        // within fragment storage are not defined.
        for (int t = 0; t < c[i][j].num_elements; t++) c[i][j].x[t] *= alpha;

        float *tile_ptr = shmem_warp_tile_ptr + i * SHMEM_STRIDE * K + j * N;

        wmma::store_matrix_sync(tile_ptr, c[i][j], SHMEM_STRIDE, C_LAYOUT);
      }
    }

    __syncthreads();

    // Now that shared memory contains all the D tiles, stream them to global
    // memory.
    float *dst_gmem_warp_stream_ptr = &D[gmem_idx];

#pragma unroll
    for (int i = 0; i < K; i++) {
      *(reinterpret_cast<int4 *>(dst_gmem_warp_stream_ptr +
                                 GLOBAL_MEM_STRIDE * i) +
        laneId) =
          *(reinterpret_cast<int4 *>(shmem_warp_stream_ptr +
                                     SHMEM_STRIDE * i) +
            laneId);
    }

    __syncthreads();
  }
}

int main(int argc, char **argv) {
  printf("Initializing...\n");

  int dev = findCudaDevice(argc, (const char **)argv);

  cudaDeviceProp deviceProp;
  checkCudaErrors(cudaGetDeviceProperties(&deviceProp, dev));

  // Tensor cores require a GPU of Volta (SM7X) architecture or higher.
  if (deviceProp.major < 7) {
    printf(
        "cudaTensorCoreGemm requires SM 7.0 or higher to use Tensor "
        "Cores. Exiting...\n");
    exit(EXIT_WAIVED);
  }

  printf("M: %d (%d x %d)\n", M_GLOBAL, M, M_TILES);
  printf("N: %d (%d x %d)\n", N_GLOBAL, N, N_TILES);
  printf("K: %d (%d x %d)\n", K_GLOBAL, K, K_TILES);

  float *A_h = NULL;
  float *B_h = NULL;
  float *C_h = NULL;

  checkCudaErrors(cudaMallocManaged(reinterpret_cast<void **>(&A_h),
                                    sizeof(float) * M_GLOBAL * K_GLOBAL));
  checkCudaErrors(cudaMallocManaged(reinterpret_cast<void **>(&B_h),
                                    sizeof(float) * K_GLOBAL * N_GLOBAL));
  checkCudaErrors(cudaMallocManaged(reinterpret_cast<void **>(&C_h),
                                    sizeof(float) * M_GLOBAL * N_GLOBAL));

  half *A = NULL;
  half *B = NULL;
  float *C = NULL;
  float *D = NULL;

  checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&A),
                             sizeof(half) * M_GLOBAL * K_GLOBAL));
  checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&B),
                             sizeof(half) * N_GLOBAL * K_GLOBAL));
  checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&C),
                             sizeof(float) * M_GLOBAL * N_GLOBAL));
  checkCudaErrors(cudaMalloc(reinterpret_cast<void **>(&D),
                             sizeof(float) * M_GLOBAL * N_GLOBAL));

  assert(((unsigned long long)A) % 128 == 0);
  assert(((unsigned long long)B) % 128 == 0);
  assert(((unsigned long long)C) % 128 == 0);
  assert(((unsigned long long)D) % 128 == 0);

  init_host_matrices(A_h, B_h, C_h);

  printf("Preparing data for GPU...\n");

  // Grid sized to the number of SMs; each CTA iterates over result tiles.
  checkKernelErrors(
      (init_device_matrices<<<deviceProp.multiProcessorCount,
                              THREADS_PER_BLOCK>>>(A_h, B_h, C_h, A, B, C,
                                                   D)));

  checkCudaErrors(cudaDeviceSynchronize());

  enum {
    SHMEM_SZ =
        sizeof(half) * (BLOCK_COL_TILES * M) * (CHUNK_K * K + SKEW_HALF) * 2
  };

  printf("Required shared memory size: %lu Kb\n", SHMEM_SZ / 1024UL);

  checkCudaErrors(cudaFuncSetAttribute(
      compute_gemm, cudaFuncAttributeMaxDynamicSharedMemorySize, SHMEM_SZ));

  printf("Computing...\n");

  cudaEvent_t start, stop;

  checkCudaErrors(cudaEventCreate(&start));
  checkCudaErrors(cudaEventCreate(&stop));
  checkCudaErrors(cudaEventRecord(start));

  const float alpha = 1.1f;
  const float beta = 1.2f;

  checkKernelErrors(
      (compute_gemm<<<deviceProp.multiProcessorCount, THREADS_PER_BLOCK,
                      SHMEM_SZ>>>(A, B, C, D, alpha, beta)));

  checkCudaErrors(cudaEventRecord(stop));
  checkCudaErrors(cudaEventSynchronize(stop));

  float milliseconds = 0;

  checkCudaErrors(cudaEventElapsedTime(&milliseconds, start, stop));

  printf("Time: %f ms\n", milliseconds);
  printf("TFLOPS: %.2f\n",
         static_cast<double>((static_cast<double>(M_GLOBAL) * N_GLOBAL *
                              K_GLOBAL * 2) /
                             (milliseconds / 1000.)) /
             1e12);

  checkCudaErrors(cudaFree(reinterpret_cast<void *>(A_h)));
  checkCudaErrors(cudaFree(reinterpret_cast<void *>(B_h)));
  checkCudaErrors(cudaFree(reinterpret_cast<void *>(C_h)));
  checkCudaErrors(cudaFree(reinterpret_cast<void *>(A)));
  checkCudaErrors(cudaFree(reinterpret_cast<void *>(B)));
  checkCudaErrors(cudaFree(reinterpret_cast<void *>(C)));
  checkCudaErrors(cudaFree(reinterpret_cast<void *>(D)));

  return 0;
}
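// The helper below is NOT part of the original sample; it is a minimal,
// hypothetical sketch of how the result could be spot-checked on the host,
// assuming the layouts used above: A is row-major M_GLOBAL x K_GLOBAL, B is
// column-major K_GLOBAL x N_GLOBAL (so column n of B is contiguous in B_h),
// and C / D are row-major M_GLOBAL x N_GLOBAL. A full host reference of a
// 4096 x 4096 x 4096 GEMM would be very slow, so the idea is to copy D back
// with cudaMemcpy and compare only a handful of entries against this value
// within a small tolerance (the GPU accumulates half-precision products in
// float). It is not called from main() above.
static float host_reference_element(const float *a_h, const float *b_h,
                                    const float *c_h, int m, int n,
                                    float alpha, float beta) {
  double acc = 0.0;

  for (int k = 0; k < K_GLOBAL; k++) {
    // A(m, k) from the row-major A, B(k, n) from the column-major B.
    acc += static_cast<double>(a_h[m * K_GLOBAL + k]) *
           static_cast<double>(b_h[n * K_GLOBAL + k]);
  }

  return static_cast<float>(alpha * acc + beta * c_h[m * N_GLOBAL + n]);
}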