From f1d65e958afa65882dbfea8b392ab847a84d41ed Mon Sep 17 00:00:00 2001 From: Tom Rix Date: Sat, 29 Jun 2024 04:18:34 -0700 Subject: [PATCH] Optionally use hipblaslt --- aten/src/ATen/cuda/CUDABlas.cpp | 46 ++++++++++++++++++------ aten/src/ATen/cuda/CUDAContextLight.h | 4 +++ aten/src/ATen/cuda/CublasHandlePool.cpp | 10 ++++-- aten/src/ATen/cuda/tunable/TunableGemm.h | 18 +++++++--- aten/src/ATen/native/cuda/Blas.cpp | 18 +++++++++- cmake/Dependencies.cmake | 3 ++ cmake/public/LoadHIP.cmake | 2 +- 7 files changed, 82 insertions(+), 19 deletions(-) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index ce991a9bcad4..3f0d17b52778 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -14,7 +14,9 @@ #include #ifdef USE_ROCM +#ifdef USE_HIPBLASLT #include +#endif // until hipblas has an API to accept flags, we must use rocblas here #include #include @@ -182,6 +184,9 @@ uint32_t _getAlignment(uintptr_t address) { static size_t _parseChosenWorkspaceSize() { const char * val = getenv("CUBLASLT_WORKSPACE_SIZE"); #ifdef USE_ROCM +#ifndef USE_HIPBLASLT + return 0; +#endif if (!val) { // accept either env var val = getenv("HIPBLASLT_WORKSPACE_SIZE"); @@ -235,6 +240,7 @@ namespace at::cuda::blas { } while (0) +#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) namespace { // Following the pattern of CuSparseDescriptor // Defined here for now because this is the only place cublas_lt interface is @@ -318,7 +324,6 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor< }; } // namespace - template inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { cudaDataType_t abcType = CUDA_R_32F; @@ -452,7 +457,7 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { " scaleType ", scaleType); } - +#endif template inline void bgemm_internal_cublas(CUDABLAS_BGEMM_ARGTYPES(Dtype)) { @@ -608,10 +613,13 @@ void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(double)) template <> 
void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(float)) { +#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(float)); } - else { + else +#endif + { bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(float)); } } @@ -651,10 +659,13 @@ void bgemm_internal>(CUDABLAS_BGEMM_ARGTYPES(c10::complex void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::Half)) { +#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(at::Half)); } - else { + else +#endif + { bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(at::Half)); } } @@ -662,10 +673,13 @@ void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::Half)) template <> void bgemm_internal(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) { +#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGS(at::BFloat16)); } - else { + else +#endif + { bgemm_internal_cublas(CUDABLAS_BGEMM_ARGS(at::BFloat16)); } } @@ -781,11 +795,13 @@ void bgemm(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) { } } +#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) template inline void gemm_internal_cublaslt(CUDABLAS_GEMM_ARGTYPES(Dtype)) { // forward to bgemm implementation but set strides and batches to 0 bgemm_internal_cublaslt(transa, transb, m, n, k, alpha, a, lda, 0, b, ldb, 0, beta, c, ldc, 0, 0); } +#endif template inline void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(Dtype)) { @@ -1008,10 +1024,13 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(double)) template <> void gemm_internal(CUDABLAS_GEMM_ARGTYPES(float)) { +#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { 
gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(float)); } - else { + else +#endif + { gemm_internal_cublas(CUDABLAS_GEMM_ARGS(float)); } } @@ -1051,10 +1070,13 @@ void gemm_internal>(CUDABLAS_GEMM_ARGTYPES(c10::complex void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::Half)) { +#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(at::Half)); } - else { + else +#endif + { gemm_internal_cublas(CUDABLAS_GEMM_ARGS(at::Half)); } } @@ -1062,10 +1084,13 @@ void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::Half)) template <> void gemm_internal(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { +#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) { gemm_internal_cublaslt(CUDABLAS_GEMM_ARGS(at::BFloat16)); } - else { + else +#endif + { gemm_internal_cublas(CUDABLAS_GEMM_ARGS(at::BFloat16)); } } @@ -1177,7 +1202,7 @@ void gemm(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { } } - +#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) template void gemm_and_bias( bool transpose_mat1, @@ -1410,7 +1435,7 @@ void scaled_gemm( ScalarType result_dtype, void* amax_ptr, bool use_fast_accum) { -#if CUDA_VERSION >= 11080 || defined(USE_ROCM) +#if CUDA_VERSION >= 11080 || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) const auto computeType = CUBLAS_COMPUTE_32F; const auto scaleType = CUDA_R_32F; const int8_t fastAccuMode = use_fast_accum ? 
1 : 0; @@ -1681,6 +1706,7 @@ void int8_gemm( " scaleType ", scaleType); } +#endif template <> void trsm(CUDABLAS_TRSM_ARGTYPES(float)) { diff --git a/aten/src/ATen/cuda/CUDAContextLight.h b/aten/src/ATen/cuda/CUDAContextLight.h index f2b657ced51b..f0ee613c4208 100644 --- a/aten/src/ATen/cuda/CUDAContextLight.h +++ b/aten/src/ATen/cuda/CUDAContextLight.h @@ -9,7 +9,9 @@ // cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also // added bf16 support +#if (!defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))) #include +#endif #ifdef CUDART_VERSION #include @@ -80,7 +82,9 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator(); /* Handles */ TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle(); TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle(); +#if (!defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))) TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle(); +#endif TORCH_CUDA_CPP_API void clearCublasWorkspaces(); diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp index 8eac525b3695..abfdf7a23847 100644 --- a/aten/src/ATen/cuda/CublasHandlePool.cpp +++ b/aten/src/ATen/cuda/CublasHandlePool.cpp @@ -29,7 +29,7 @@ namespace at::cuda { namespace { -#if defined(USE_ROCM) +#if defined(USE_ROCM) && defined(USE_HIPBLASLT) void createCublasLtHandle(cublasLtHandle_t *handle) { TORCH_CUDABLAS_CHECK(cublasLtCreate(handle)); } @@ -191,8 +191,9 @@ cublasHandle_t getCurrentCUDABlasHandle() { return handle; } -cublasLtHandle_t getCurrentCUDABlasLtHandle() { #ifdef USE_ROCM +#if defined(USE_HIPBLASLT) +cublasLtHandle_t getCurrentCUDABlasLtHandle() { c10::DeviceIndex device = 0; AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); @@ -213,9 +214,12 @@ cublasLtHandle_t getCurrentCUDABlasLtHandle() { auto handle = myPoolWindow->reserve(device); return handle; +} +#endif #else +cublasLtHandle_t getCurrentCUDABlasLtHandle() { return 
reinterpret_cast(getCurrentCUDABlasHandle()); -#endif } +#endif } // namespace at::cuda diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h index 53e6154120c9..fa1d664696db 100644 --- a/aten/src/ATen/cuda/tunable/TunableGemm.h +++ b/aten/src/ATen/cuda/tunable/TunableGemm.h @@ -11,7 +11,9 @@ #include #ifdef USE_ROCM +#ifdef USE_HIPBLASLT #include +#endif #include #endif #include @@ -65,6 +67,7 @@ class DefaultGemmStridedBatchedOp : public Callable> } }; +#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) template class DefaultScaledGemmOp : public Callable> { public: @@ -94,6 +97,7 @@ class DefaultScaledGemmOp : public Callable> { return OK; } }; +#endif template inline bool IsZero(T v) { @@ -191,6 +195,7 @@ static void AddRocblasValidator() { } } +#ifdef USE_HIPBLASLT static void AddHipblasltValidator() { auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators(); if (validators.find("HIPBLASLT_VERSION") == validators.end()) { @@ -205,6 +210,7 @@ static void AddHipblasltValidator() { [hipblaslt_version](auto&& k) { return hipblaslt_version == k ? 
OK : FAIL; }); } } +#endif static void AddRocmValidator() { auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators(); @@ -243,7 +249,7 @@ class GemmTunableOp : public TunableOp, StreamTimer> { } AddRocblasValidator(); } - +#ifdef USE_HIPBLASLT static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) { rocm_validators = true; @@ -257,7 +263,7 @@ class GemmTunableOp : public TunableOp, StreamTimer> { } AddHipblasltValidator(); } - +#endif if (rocm_validators) { AddRocmValidator(); } @@ -286,7 +292,7 @@ class GemmStridedBatchedTunableOp : public TunableOp } AddRocblasValidator(); } - +#ifdef USE_HIPBLASLT static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED"); if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) { rocm_validators = true; @@ -300,7 +306,7 @@ class GemmStridedBatchedTunableOp : public TunableOp } AddHipblasltValidator(); } - +#endif if (rocm_validators) { AddRocmValidator(); } @@ -312,6 +318,7 @@ class GemmStridedBatchedTunableOp : public TunableOp } }; +#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) template class ScaledGemmTunableOp : public TunableOp, StreamTimer> { public: @@ -321,10 +328,12 @@ class ScaledGemmTunableOp : public TunableOp, StreamTimer> auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators(); #if defined(USE_ROCM) +#ifdef USE_HIPBLASLT for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps()) { this->RegisterOp(std::move(name), std::move(op)); } AddHipblasltValidator(); +#endif AddRocmValidator(); #endif } @@ -337,6 +346,7 @@ class ScaledGemmTunableOp : public TunableOp, StreamTimer> "_", BlasOpToString(ALayout), BlasOpToString(BLayout)); } }; +#endif #undef XSTRINGIFY #undef STRINGIFY diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp index 84c59a4fd0d7..56ad5de3bf2d 100644 
--- a/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp @@ -173,6 +173,7 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa } static bool getDisableAddmmCudaLt() { +#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT"); #ifdef USE_ROCM // if we enable tunable op, it'll take priority over just hipblaslt (heuristics) @@ -196,10 +197,14 @@ static bool getDisableAddmmCudaLt() { } return false; #endif +#else + return true; +#endif } #ifdef USE_ROCM static bool isSupportedHipLtROCmArch(int index) { +#ifdef USE_HIPBLASLT hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index); std::string device_arch = prop->gcnArchName; static const std::vector archs = {"gfx90a", "gfx940", "gfx941", "gfx942"}; @@ -210,6 +215,7 @@ static bool isSupportedHipLtROCmArch(int index) { } } TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!"); +#endif return false; } #endif @@ -235,6 +241,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma at::ScalarType scalar_type = self.scalar_type(); c10::MaybeOwned self_; if (&result != &self) { +#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) #if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11040)) || defined(USE_ROCM) // Strangely, if mat2 has only 1 row or column, we get // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic. 
@@ -276,13 +283,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma scalar_type != at::ScalarType::BFloat16)); #endif } +#endif #endif if (!useLtInterface) { self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm"); } self__sizes = self_->sizes(); } else { -#if defined(USE_ROCM) +#if defined(USE_ROCM) && defined(USE_HIPBLASLT) useLtInterface = !disable_addmm_cuda_lt && result.dim() == 2 && result.is_contiguous() && isSupportedHipLtROCmArch(self.device().index()) && @@ -334,6 +342,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj()); if (useLtInterface) { +#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) #if defined(USE_ROCM) AT_DISPATCH_FLOATING_TYPES_AND2( at::ScalarType::Half, @@ -394,6 +403,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma activation_epilogue ); }); +#endif #endif } else { @@ -803,6 +813,7 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) { } static bool _scaled_mm_allowed_device() { +#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) auto dprops = at::cuda::getCurrentDeviceProperties(); #ifdef USE_ROCM std::string device_arch = dprops->gcnArchName; @@ -817,6 +828,9 @@ static bool _scaled_mm_allowed_device() { #else return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9); #endif +#else + return false; +#endif } // Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax @@ -850,6 +864,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, // Check sizes bool allowed_device = _scaled_mm_allowed_device(); TORCH_CHECK(allowed_device, "torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+"); +#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)) TORCH_CHECK(mat1.dim() == 2, "mat1 must be 
a matrix"); TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix"); TORCH_CHECK( @@ -1025,6 +1040,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, #if defined(USE_ROCM) && ROCM_VERSION >= 60000 && ROCM_VERSION < 60200 // ROCm's hipBLASLt does not support amax before 6.2, so calculate separately amax = at::max(at::abs(out.to(kFloat))); +#endif #endif return {out, amax}; diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index f1f2eb7cec31..8d05e834bbc5 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1052,6 +1052,9 @@ if(USE_ROCM) list(APPEND HIP_CXX_FLAGS -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP) list(APPEND HIP_CXX_FLAGS -std=c++17) list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2) + if(hipblaslt_FOUND) + list(APPEND HIP_CXX_FLAGS -DUSE_HIPBLASLT) + endif() if(HIP_NEW_TYPE_ENUMS) list(APPEND HIP_CXX_FLAGS -DHIP_NEW_TYPE_ENUMS) endif() diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake index fa39156031ff..df4836847fdf 100644 --- a/cmake/public/LoadHIP.cmake +++ b/cmake/public/LoadHIP.cmake @@ -155,7 +155,7 @@ if(HIP_FOUND) find_package_and_print_version(hiprand REQUIRED) find_package_and_print_version(rocblas REQUIRED) find_package_and_print_version(hipblas REQUIRED) - find_package_and_print_version(hipblaslt REQUIRED) + find_package_and_print_version(hipblaslt) find_package_and_print_version(miopen REQUIRED) find_package_and_print_version(hipfft REQUIRED) find_package_and_print_version(hipsparse REQUIRED) -- 2.45.2