From 214dc959acc809e1959643272c344ee5335d5a69 Mon Sep 17 00:00:00 2001 From: Tom Rix Date: Thu, 1 Feb 2024 11:29:47 -0500 Subject: [PATCH] cuda - hip signatures --- aten/src/ATen/cuda/detail/LazyNVRTC.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/aten/src/ATen/cuda/detail/LazyNVRTC.cpp b/aten/src/ATen/cuda/detail/LazyNVRTC.cpp index 1b85e7776e..bb6f88783a 100644 --- a/aten/src/ATen/cuda/detail/LazyNVRTC.cpp +++ b/aten/src/ATen/cuda/detail/LazyNVRTC.cpp @@ -134,8 +134,13 @@ nvrtcResult nvrtcCreateProgram(nvrtcProgram *prog, const char *src, const char *name, int numHeaders, +#if !defined(USE_ROCM) const char * const *headers, const char * const *includeNames) { +#else + const char **headers, + const char **includeNames) { +#endif auto fn = reinterpret_cast<decltype(&nvrtcCreateProgram)>(getNVRTCLibrary().sym(__func__)); if (!fn) throw std::runtime_error("Can't get nvrtcCreateProgram"); @@ -150,7 +155,11 @@ NVRTC_STUB2(nvrtcGetPTX, nvrtcProgram, char *); NVRTC_STUB2(nvrtcGetCUBINSize, nvrtcProgram, size_t *); NVRTC_STUB2(nvrtcGetCUBIN, nvrtcProgram, char *); #endif +#if !defined(USE_ROCM) NVRTC_STUB3(nvrtcCompileProgram, nvrtcProgram, int, const char * const *); +#else +NVRTC_STUB3(nvrtcCompileProgram, nvrtcProgram, int, const char **); +#endif _STUB_1(NVRTC, nvrtcGetErrorString, const char *, nvrtcResult); NVRTC_STUB2(nvrtcGetProgramLogSize,nvrtcProgram, size_t*); NVRTC_STUB2(nvrtcGetProgramLog, nvrtcProgram, char *); -- 2.43.0