Use CTK-provided type for cuTensorMapEncodeTiled

This commit is contained in:
Allard Hendriksen 2023-08-07 17:36:27 +02:00
parent 5925483b33
commit a3b5b817e3
No known key found for this signature in database

View File

@ -32,9 +32,9 @@
* - Create a TensorMap (TMA descriptor)
* - Load a 2D tile of data into shared memory
*
* Compile with:
* Compile and run with:
*
* nvcc -arch sm_90 globalToShmemTMACopy.cu -o globalToShmemTMACopy
* nvcc -arch sm_90 -run globalToShmemTMACopy.cu
*
* It can be that the compiler issues the following note. This can be safely ignored.
*
@ -42,10 +42,12 @@
* GCC 4.6
*
*/
#include <cstdio> // fprintf
#include <vector> // std::vector
#include <cstdio> // printf
#include <vector> // std::vector
#include <cudaTypedefs.h> // PFN_cuTensorMapEncodeTiled
#include <cuda.h> // CUtensormap
#include <cuda.h> // CUtensorMap
#include <cuda_awbarrier_primitives.h> // __mbarrier_*
#include "util.h" // CUDA_CHECK macro
@ -63,19 +65,11 @@ constexpr int SMEM_H = 8; // Height of shared memory buffer (in # elements)
* CUDA Driver API
*/
// The type of the cuTensorMapEncodeTiled function.
using cuTensorMapEncodeTiled_t = decltype(cuTensorMapEncodeTiled);
// Get function pointer to driver API cuTensorMapEncodeTiled
cuTensorMapEncodeTiled_t * get_cuTensorMapEncodeTiled() {
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html
void* cuda_ptr = nullptr;
unsigned long long flags = cudaEnableDefault;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DRIVER__ENTRY__POINT.html
CUDA_CHECK(cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &cuda_ptr, flags));
return reinterpret_cast<cuTensorMapEncodeTiled_t*>(cuda_ptr);
PFN_cuTensorMapEncodeTiled get_cuTensorMapEncodeTiled() {
void* driver_ptr = nullptr;
cudaDriverEntryPointQueryResult driver_status;
CUDA_CHECK(cudaGetDriverEntryPoint("cuTensorMapEncodeTiled", &driver_ptr, cudaEnableDefault, &driver_status));
return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(driver_ptr);
}
/*