From a3b5b817e39b56d66d7bf8b33d369f2d66f64a60 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Mon, 7 Aug 2023 17:36:27 +0200 Subject: [PATCH] Use CTK-provided type for cuTensorMapEncodeTiled --- .../globalToShmemTMACopy.cu | 30 ++++++++----------- 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/Samples/3_CUDA_Features/globalToShmemTMACopy/globalToShmemTMACopy.cu b/Samples/3_CUDA_Features/globalToShmemTMACopy/globalToShmemTMACopy.cu index b775f6dd..2928873d 100644 --- a/Samples/3_CUDA_Features/globalToShmemTMACopy/globalToShmemTMACopy.cu +++ b/Samples/3_CUDA_Features/globalToShmemTMACopy/globalToShmemTMACopy.cu @@ -32,9 +32,9 @@ * - Create a TensorMap (TMA descriptor) * - Load a 2D tile of data into shared memory * - * Compile with: + * Compile and run with: * - * nvcc -arch sm_90 globalToShmemTMACopy.cu -o globalToShmemTMACopy + * nvcc -arch sm_90 -run globalToShmemTMACopy.cu * * It can be that the compiler issues the following note. This can be safely ignored. * @@ -42,10 +42,12 @@ * GCC 4.6 * */ -#include // fprintf -#include // std::vector +#include // printf +#include // std::vector + +#include // PFN_cuTensorMapEncodeTiled +#include // CUtensormap -#include // CUtensorMap #include // __mbarrier_* #include "util.h" // CUDA_CHECK macro @@ -63,19 +65,11 @@ constexpr int SMEM_H = 8; // Height of shared memory buffer (in # elements) * CUDA Driver API */ -// The type of the cuTensorMapEncodeTiled function. -using cuTensorMapEncodeTiled_t = decltype(cuTensorMapEncodeTiled); - -// Get function pointer to driver API cuTensorMapEncodeTiled -cuTensorMapEncodeTiled_t * get_cuTensorMapEncodeTiled() { - // https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html - void* cuda_ptr = nullptr; - unsigned long long flags = cudaEnableDefault; - - // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DRIVER__ENTRY__POINT.html - CUDA_CHECK(cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuda_ptr, flags)); - - return reinterpret_cast(cuda_ptr); +PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() { + void* driver_ptr = nullptr; + cudaDriverEntryPointQueryResult driver_status; + CUDA_CHECK(cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &driver_ptr, cudaEnableDefault, &driver_status)); + return reinterpret_cast(driver_ptr); } /*