add --with hipblaslt option

This commit is contained in:
Tom Rix 2024-02-24 07:39:48 -05:00
commit 7a89201199

View file

@ -0,0 +1,235 @@
From 186728c9f4de720547fa3d1b7951c7d5563ad3c7 Mon Sep 17 00:00:00 2001
From: Tom Rix <trix@redhat.com>
Date: Fri, 23 Feb 2024 08:27:30 -0500
Subject: [PATCH] Optionally use hipblaslt
The hipblaslt package is not available on Fedora.
Instead of requiring the package, make it optional.
If it is found, define the preprocessor variable HIPBLASLT
Convert the checks for ROCM_VERSION >= 507000 to HIPBLASLT checks
Signed-off-by: Tom Rix <trix@redhat.com>
---
aten/src/ATen/cuda/CUDABlas.cpp | 5 +++--
aten/src/ATen/cuda/CUDABlas.h | 2 +-
aten/src/ATen/cuda/CUDAContextLight.h | 4 ++--
aten/src/ATen/cuda/CublasHandlePool.cpp | 4 ++--
aten/src/ATen/cuda/tunable/TunableGemm.h | 6 +++---
aten/src/ATen/native/cuda/Blas.cpp | 10 ++++++----
cmake/Dependencies.cmake | 3 +++
cmake/public/LoadHIP.cmake | 4 ++--
8 files changed, 22 insertions(+), 16 deletions(-)
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index 0a3de5f9d77..b67031e88a7 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -778,7 +778,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
}
}
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
#if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
// only for rocm 5.7 where we first supported hipblaslt, it was difficult
@@ -909,6 +909,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
};
} // namespace
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
template <typename Dtype>
void gemm_and_bias(
bool transpose_mat1,
@@ -1121,7 +1122,7 @@ template void gemm_and_bias(
at::BFloat16* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation);
-
+#endif
void scaled_gemm(
char transa,
char transb,
diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h
index eb12bb350c5..068607467dd 100644
--- a/aten/src/ATen/cuda/CUDABlas.h
+++ b/aten/src/ATen/cuda/CUDABlas.h
@@ -82,7 +82,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
template <>
void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
enum GEMMAndBiasActivationEpilogue {
None,
RELU,
diff --git a/aten/src/ATen/cuda/CUDAContextLight.h b/aten/src/ATen/cuda/CUDAContextLight.h
index 4ec35f59a21..e28dc42034f 100644
--- a/aten/src/ATen/cuda/CUDAContextLight.h
+++ b/aten/src/ATen/cuda/CUDAContextLight.h
@@ -9,7 +9,7 @@
// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
// added bf16 support
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
#include <cublasLt.h>
#endif
@@ -82,7 +82,7 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
/* Handles */
TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
#endif
diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp
index d8ee09b1486..4510193a5cc 100644
--- a/aten/src/ATen/cuda/CublasHandlePool.cpp
+++ b/aten/src/ATen/cuda/CublasHandlePool.cpp
@@ -28,7 +28,7 @@ namespace at::cuda {
namespace {
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
void createCublasLtHandle(cublasLtHandle_t *handle) {
TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
}
@@ -177,7 +177,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
return handle;
}
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
cublasLtHandle_t getCurrentCUDABlasLtHandle() {
#ifdef USE_ROCM
c10::DeviceIndex device = 0;
diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h
index 3ba0d761277..dde1870cfbf 100644
--- a/aten/src/ATen/cuda/tunable/TunableGemm.h
+++ b/aten/src/ATen/cuda/tunable/TunableGemm.h
@@ -11,7 +11,7 @@
#include <ATen/cuda/tunable/GemmCommon.h>
#ifdef USE_ROCM
-#if ROCM_VERSION >= 50700
+#ifdef HIPBLASLT
#include <ATen/cuda/tunable/GemmHipblaslt.h>
#endif
#include <ATen/cuda/tunable/GemmRocblas.h>
@@ -166,7 +166,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
}
#endif
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
if (env == nullptr || strcmp(env, "1") == 0) {
// disallow tuning of hipblaslt with c10::complex
@@ -240,7 +240,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
}
#endif
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
if (env == nullptr || strcmp(env, "1") == 0) {
// disallow tuning of hipblaslt with c10::complex
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 2b7ff1b502b..9ae2ccbd16d 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -153,7 +153,7 @@ enum class Activation {
GELU,
};
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
switch (a) {
case Activation::None:
@@ -191,6 +191,7 @@ static bool getDisableAddmmCudaLt() {
#ifdef USE_ROCM
static bool isSupportedHipLtROCmArch(int index) {
+#if defined(HIPBLASLT)
hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
std::string device_arch = prop->gcnArchName;
static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
@@ -201,6 +202,7 @@ static bool isSupportedHipLtROCmArch(int index) {
}
}
TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!");
+#endif
return false;
}
#endif
@@ -226,7 +228,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
at::ScalarType scalar_type = self.scalar_type();
c10::MaybeOwned<Tensor> self_;
if (&result != &self) {
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && defined(HIPBLASLT)
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
@@ -269,7 +271,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
}
self__sizes = self_->sizes();
} else {
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
useLtInterface = !disable_addmm_cuda_lt &&
result.dim() == 2 && result.is_contiguous() &&
isSupportedHipLtROCmArch(self.device().index()) &&
@@ -320,7 +322,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
if (useLtInterface) {
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index b7ffbeb07dc..2b6c3678984 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1273,6 +1273,9 @@ if(USE_ROCM)
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "6.0.0")
list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
endif()
+ if(hipblast_FOUND)
+ list(APPEND HIP_CXX_FLAGS -DHIPBLASLT)
+ endif()
if(HIPBLASLT_CUSTOM_DATA_TYPE)
list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_DATA_TYPE)
endif()
diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake
index f6ca263c5e5..53eb0b63c1a 100644
--- a/cmake/public/LoadHIP.cmake
+++ b/cmake/public/LoadHIP.cmake
@@ -156,7 +156,7 @@ if(HIP_FOUND)
find_package_and_print_version(rocblas REQUIRED)
find_package_and_print_version(hipblas REQUIRED)
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
- find_package_and_print_version(hipblaslt REQUIRED)
+ find_package_and_print_version(hipblaslt)
endif()
find_package_and_print_version(miopen REQUIRED)
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "4.1.0")
@@ -191,7 +191,7 @@ if(HIP_FOUND)
# roctx is part of roctracer
find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)
- if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
+ if(hipblastlt_FOUND)
# check whether hipblaslt is using its own datatype
set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_data_type.cc")
file(WRITE ${file} ""
--
2.43.2