Fixed the error path to initialize error path function pointers. Exit with error in case of LOADLIBRARY failureas initialize of function pointers in case of LOADLIBRARY failure will fail

This commit is contained in:
Nikhil Talpallikar 2025-08-06 00:29:22 -07:00
parent 527b29dbd0
commit d2c52db3e0
2 changed files with 70 additions and 8 deletions

View File

@ -242,12 +242,23 @@ static CUresult LOAD_LIBRARY(CUDADRIVER *pInstance)
if (*pInstance == NULL) { if (*pInstance == NULL) {
printf("LoadLibrary \"%s\" failed!\n", __CudaLibName); printf("LoadLibrary \"%s\" failed!\n", __CudaLibName);
return CUDA_ERROR_UNKNOWN; exit(EXIT_FAILURE);
} }
return CUDA_SUCCESS; return CUDA_SUCCESS;
} }
CUresult GET_DRIVER_HANDLE(CUDADRIVER* pInstance)
{
*pInstance = GetModuleHandle(__CudaLibName);
if (*pInstance) {
return CUDA_SUCCESS;
}
else {
return CUDA_ERROR_UNKNOWN;
}
}
#define GET_PROC_EX(name, alias, required) \ #define GET_PROC_EX(name, alias, required) \
alias = (t##name *)GetProcAddress(CudaDrvLib, #name); \ alias = (t##name *)GetProcAddress(CudaDrvLib, #name); \
if (alias == NULL && required) { \ if (alias == NULL && required) { \
@ -269,6 +280,13 @@ static CUresult LOAD_LIBRARY(CUDADRIVER *pInstance)
return CUDA_ERROR_UNKNOWN; \ return CUDA_ERROR_UNKNOWN; \
} }
#define GET_PROC_ERROR_FUNCTIONS(name, alias, required) \
alias = (t##name *)GetProcAddress(CudaDrvLib, #name); \
if (alias == NULL && required) { \
printf("Failed to find error function \"%s\" in %s\n", #name, __CudaLibName); \
exit(EXIT_FAILURE); \
} \
#elif defined(__unix__) || defined(__QNX__) || defined(__APPLE__) || defined(__MACOSX) #elif defined(__unix__) || defined(__QNX__) || defined(__APPLE__) || defined(__MACOSX)
#include <dlfcn.h> #include <dlfcn.h>
@ -293,12 +311,23 @@ static CUresult LOAD_LIBRARY(CUDADRIVER *pInstance)
if (*pInstance == NULL) { if (*pInstance == NULL) {
printf("dlopen \"%s\" failed!\n", __CudaLibName); printf("dlopen \"%s\" failed!\n", __CudaLibName);
return CUDA_ERROR_UNKNOWN; exit(EXIT_FAILURE);
} }
return CUDA_SUCCESS; return CUDA_SUCCESS;
} }
CUresult GET_DRIVER_HANDLE(CUDADRIVER* pInstance)
{
*pInstance = dlopen(__CudaLibName, RTLD_NOLOAD);
if (*pInstance) {
return CUDA_SUCCESS;
}
else {
return CUDA_ERROR_UNKNOWN;
}
}
#define GET_PROC_EX(name, alias, required) \ #define GET_PROC_EX(name, alias, required) \
alias = (t##name *)dlsym(CudaDrvLib, #name); \ alias = (t##name *)dlsym(CudaDrvLib, #name); \
if (alias == NULL && required) { \ if (alias == NULL && required) { \
@ -320,33 +349,56 @@ static CUresult LOAD_LIBRARY(CUDADRIVER *pInstance)
return CUDA_ERROR_UNKNOWN; \ return CUDA_ERROR_UNKNOWN; \
} }
#define GET_PROC_ERROR_FUNCTIONS(name, alias, required) \
alias = (t##name *)dlsym(CudaDrvLib, #name); \
if (alias == NULL && required) { \
printf("Failed to find error function \"%s\" in %s\n", #name, __CudaLibName); \
exit(EXIT_FAILURE); \
}
#else #else
#error unsupported platform #error unsupported platform
#endif #endif
#define CHECKED_CALL(call) \
do { \
CUresult result = (call); \
if (CUDA_SUCCESS != result) { \
return result; \
} \
} while (0)
#define GET_PROC_REQUIRED(name) GET_PROC_EX(name, name, 1) #define GET_PROC_REQUIRED(name) GET_PROC_EX(name, name, 1)
#define GET_PROC_OPTIONAL(name) GET_PROC_EX(name, name, 0) #define GET_PROC_OPTIONAL(name) GET_PROC_EX(name, name, 0)
#define GET_PROC(name) GET_PROC_REQUIRED(name) #define GET_PROC(name) GET_PROC_REQUIRED(name)
#define GET_PROC_V2(name) GET_PROC_EX_V2(name, name, 1) #define GET_PROC_V2(name) GET_PROC_EX_V2(name, name, 1)
#define GET_PROC_V3(name) GET_PROC_EX_V3(name, name, 1) #define GET_PROC_V3(name) GET_PROC_EX_V3(name, name, 1)
CUresult INIT_ERROR_FUNCTIONS(void)
{
CUDADRIVER CudaDrvLib;
CUresult result = CUDA_SUCCESS;
result = GET_DRIVER_HANDLE(&CudaDrvLib);
GET_PROC_ERROR_FUNCTIONS(cuGetErrorString, cuGetErrorString, 1);
return result;
}
CUresult CUDAAPI cuInit(unsigned int Flags, int cudaVersion) CUresult CUDAAPI cuInit(unsigned int Flags, int cudaVersion)
{ {
CUDADRIVER CudaDrvLib; CUDADRIVER CudaDrvLib;
int driverVer = 1000; int driverVer = 1000;
CUresult result = CUDA_SUCCESS;
result = LOAD_LIBRARY(&CudaDrvLib); CHECKED_CALL(LOAD_LIBRARY(&CudaDrvLib));
// cuInit is required; alias it to _cuInit // cuInit is required; alias it to _cuInit
GET_PROC_EX(cuInit, _cuInit, 1); GET_PROC_EX(cuInit, _cuInit, 1);
result = _cuInit(Flags); CHECKED_CALL(_cuInit(Flags));
// available since 2.2. if not present, version 1.0 is assumed // available since 2.2. if not present, version 1.0 is assumed
GET_PROC_OPTIONAL(cuDriverGetVersion); GET_PROC_OPTIONAL(cuDriverGetVersion);
if (cuDriverGetVersion) { if (cuDriverGetVersion) {
result = cuDriverGetVersion(&driverVer); CHECKED_CALL(cuDriverGetVersion(&driverVer));
} }
// fetch all function pointers // fetch all function pointers
@ -612,5 +664,5 @@ CUresult CUDAAPI cuInit(unsigned int Flags, int cudaVersion)
GET_PROC(cuGraphicsD3D9RegisterResource); GET_PROC(cuGraphicsD3D9RegisterResource);
#endif #endif
} }
return result; return CUDA_SUCCESS;
} }

View File

@ -42,11 +42,21 @@ inline int ftoi(float value) { return (value >= 0 ? static_cast<int>(value + 0.5
#ifndef checkCudaErrors #ifndef checkCudaErrors
#define checkCudaErrors(err) __checkCudaErrors(err, __FILE__, __LINE__) #define checkCudaErrors(err) __checkCudaErrors(err, __FILE__, __LINE__)
extern "C" CUresult INIT_ERROR_FUNCTIONS(void);
// These are the inline versions for all of the SDK helper functions // These are the inline versions for all of the SDK helper functions
inline void __checkCudaErrors(CUresult err, const char *file, const int line) inline void __checkCudaErrors(CUresult err, const char *file, const int line)
{ {
if (CUDA_SUCCESS != err) { if (CUDA_SUCCESS != err) {
const char *errorStr = NULL; const char *errorStr = NULL;
if (!cuGetErrorString) {
CUresult result = INIT_ERROR_FUNCTIONS();
if (result != CUDA_SUCCESS) {
printf("CUDA driver API failed");
exit(EXIT_FAILURE);
}
}
cuGetErrorString(err, &errorStr); cuGetErrorString(err, &errorStr);
fprintf(stderr, fprintf(stderr,
"checkCudaErrors() Driver API error = %04d \"%s\" from file <%s>, " "checkCudaErrors() Driver API error = %04d \"%s\" from file <%s>, "