Use CTK-provided type for cuTensorMapEncodeTiled

This commit is contained in:
Allard Hendriksen 2023-08-07 17:36:27 +02:00
parent 5925483b33
commit a3b5b817e3
No known key found for this signature in database

View File

@ -32,9 +32,9 @@
* - Create a TensorMap (TMA descriptor) * - Create a TensorMap (TMA descriptor)
* - Load a 2D tile of data into shared memory * - Load a 2D tile of data into shared memory
* *
* Compile with: * Compile and run with:
* *
* nvcc -arch sm_90 globalToShmemTMACopy.cu -o globalToShmemTMACopy * nvcc -arch sm_90 -run globalToShmemTMACopy.cu
* *
* It can be that the compiler issues the following note. This can be safely ignored. * It can be that the compiler issues the following note. This can be safely ignored.
* *
@ -42,10 +42,12 @@
* GCC 4.6 * GCC 4.6
* *
*/ */
#include <cstdio> // fprintf #include <cstdio> // printf
#include <vector> // std::vector #include <vector> // std::vector
#include <cudaTypedefs.h> // PFN_cuTensorMapEncodeTiled
#include <cuda.h> // CUtensormap
#include <cuda.h> // CUtensorMap
#include <cuda_awbarrier_primitives.h> // __mbarrier_* #include <cuda_awbarrier_primitives.h> // __mbarrier_*
#include "util.h" // CUDA_CHECK macro #include "util.h" // CUDA_CHECK macro
@ -63,19 +65,11 @@ constexpr int SMEM_H = 8; // Height of shared memory buffer (in # elements)
* CUDA Driver API * CUDA Driver API
*/ */
// The type of the cuTensorMapEncodeTiled function. PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
using cuTensorMapEncodeTiled_t = decltype(cuTensorMapEncodeTiled); void* driver_ptr = nullptr;
cudaDriverEntryPointQueryResult driver_status;
// Get function pointer to driver API cuTensorMapEncodeTiled CUDA_CHECK(cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &driver_ptr, cudaEnableDefault, &driver_status));
cuTensorMapEncodeTiled_t * get_cuTensorMapEncodeTiled() { return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(driver_ptr);
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html
void* cuda_ptr = nullptr;
unsigned long long flags = cudaEnableDefault;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DRIVER__ENTRY__POINT.html
CUDA_CHECK(cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuda_ptr, flags));
return reinterpret_cast<cuTensorMapEncodeTiled_t*>(cuda_ptr);
} }
/* /*