235 lines
9.7 KiB
Diff
235 lines
9.7 KiB
Diff
From 186728c9f4de720547fa3d1b7951c7d5563ad3c7 Mon Sep 17 00:00:00 2001
|
|
From: Tom Rix <trix@redhat.com>
|
|
Date: Fri, 23 Feb 2024 08:27:30 -0500
|
|
Subject: [PATCH] Optionally use hipblaslt
|
|
|
|
The hipblaslt package is not available on Fedora.
|
|
Instead of requiring the package, make it optional.
|
|
If it is found, define the preprocessor variable HIPBLASLT
|
|
Convert the checks for ROCM_VERSION >= 50700 to HIPBLASLT checks
|
|
|
|
Signed-off-by: Tom Rix <trix@redhat.com>
|
|
---
|
|
aten/src/ATen/cuda/CUDABlas.cpp | 5 +++--
|
|
aten/src/ATen/cuda/CUDABlas.h | 2 +-
|
|
aten/src/ATen/cuda/CUDAContextLight.h | 4 ++--
|
|
aten/src/ATen/cuda/CublasHandlePool.cpp | 4 ++--
|
|
aten/src/ATen/cuda/tunable/TunableGemm.h | 6 +++---
|
|
aten/src/ATen/native/cuda/Blas.cpp | 10 ++++++----
|
|
cmake/Dependencies.cmake | 3 +++
|
|
cmake/public/LoadHIP.cmake | 4 ++--
|
|
8 files changed, 22 insertions(+), 16 deletions(-)
|
|
|
|
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
|
|
index 0a3de5f9d77..b67031e88a7 100644
|
|
--- a/aten/src/ATen/cuda/CUDABlas.cpp
|
|
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
|
|
@@ -778,7 +778,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
|
}
|
|
}
|
|
|
|
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
|
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
|
|
|
|
#if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
|
|
// only for rocm 5.7 where we first supported hipblaslt, it was difficult
|
|
@@ -909,6 +909,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
|
|
};
|
|
} // namespace
|
|
|
|
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
|
|
template <typename Dtype>
|
|
void gemm_and_bias(
|
|
bool transpose_mat1,
|
|
@@ -1121,7 +1122,7 @@ template void gemm_and_bias(
|
|
at::BFloat16* result_ptr,
|
|
int64_t result_ld,
|
|
GEMMAndBiasActivationEpilogue activation);
|
|
-
|
|
+#endif
|
|
void scaled_gemm(
|
|
char transa,
|
|
char transb,
|
|
diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h
|
|
index eb12bb350c5..068607467dd 100644
|
|
--- a/aten/src/ATen/cuda/CUDABlas.h
|
|
+++ b/aten/src/ATen/cuda/CUDABlas.h
|
|
@@ -82,7 +82,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
|
|
template <>
|
|
void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
|
|
|
|
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
|
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
|
|
enum GEMMAndBiasActivationEpilogue {
|
|
None,
|
|
RELU,
|
|
diff --git a/aten/src/ATen/cuda/CUDAContextLight.h b/aten/src/ATen/cuda/CUDAContextLight.h
|
|
index 4ec35f59a21..e28dc42034f 100644
|
|
--- a/aten/src/ATen/cuda/CUDAContextLight.h
|
|
+++ b/aten/src/ATen/cuda/CUDAContextLight.h
|
|
@@ -9,7 +9,7 @@
|
|
|
|
// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
|
|
// added bf16 support
|
|
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
|
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
|
|
#include <cublasLt.h>
|
|
#endif
|
|
|
|
@@ -82,7 +82,7 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
|
|
/* Handles */
|
|
TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
|
|
TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
|
|
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
|
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
|
|
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
|
|
#endif
|
|
|
|
diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp
|
|
index d8ee09b1486..4510193a5cc 100644
|
|
--- a/aten/src/ATen/cuda/CublasHandlePool.cpp
|
|
+++ b/aten/src/ATen/cuda/CublasHandlePool.cpp
|
|
@@ -28,7 +28,7 @@ namespace at::cuda {
|
|
|
|
namespace {
|
|
|
|
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
|
|
+#if defined(USE_ROCM) && defined(HIPBLASLT)
|
|
void createCublasLtHandle(cublasLtHandle_t *handle) {
|
|
TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
|
|
}
|
|
@@ -177,7 +177,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
|
|
return handle;
|
|
}
|
|
|
|
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
|
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
|
|
cublasLtHandle_t getCurrentCUDABlasLtHandle() {
|
|
#ifdef USE_ROCM
|
|
c10::DeviceIndex device = 0;
|
|
diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h
|
|
index 3ba0d761277..dde1870cfbf 100644
|
|
--- a/aten/src/ATen/cuda/tunable/TunableGemm.h
|
|
+++ b/aten/src/ATen/cuda/tunable/TunableGemm.h
|
|
@@ -11,7 +11,7 @@
|
|
|
|
#include <ATen/cuda/tunable/GemmCommon.h>
|
|
#ifdef USE_ROCM
|
|
-#if ROCM_VERSION >= 50700
|
|
+#ifdef HIPBLASLT
|
|
#include <ATen/cuda/tunable/GemmHipblaslt.h>
|
|
#endif
|
|
#include <ATen/cuda/tunable/GemmRocblas.h>
|
|
@@ -166,7 +166,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
|
|
}
|
|
#endif
|
|
|
|
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
|
|
+#if defined(USE_ROCM) && defined(HIPBLASLT)
|
|
static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
|
|
if (env == nullptr || strcmp(env, "1") == 0) {
|
|
// disallow tuning of hipblaslt with c10::complex
|
|
@@ -240,7 +240,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
|
|
}
|
|
#endif
|
|
|
|
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
|
|
+#if defined(USE_ROCM) && defined(HIPBLASLT)
|
|
static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
|
|
if (env == nullptr || strcmp(env, "1") == 0) {
|
|
// disallow tuning of hipblaslt with c10::complex
|
|
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
|
|
index 2b7ff1b502b..9ae2ccbd16d 100644
|
|
--- a/aten/src/ATen/native/cuda/Blas.cpp
|
|
+++ b/aten/src/ATen/native/cuda/Blas.cpp
|
|
@@ -153,7 +153,7 @@ enum class Activation {
|
|
GELU,
|
|
};
|
|
|
|
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
|
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
|
|
cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
|
|
switch (a) {
|
|
case Activation::None:
|
|
@@ -191,6 +191,7 @@ static bool getDisableAddmmCudaLt() {
|
|
|
|
#ifdef USE_ROCM
|
|
static bool isSupportedHipLtROCmArch(int index) {
|
|
+#if defined(HIPBLASLT)
|
|
hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
|
|
std::string device_arch = prop->gcnArchName;
|
|
static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
|
|
@@ -201,6 +202,7 @@ static bool isSupportedHipLtROCmArch(int index) {
|
|
}
|
|
}
|
|
TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!");
|
|
+#endif
|
|
return false;
|
|
}
|
|
#endif
|
|
@@ -226,7 +228,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
|
at::ScalarType scalar_type = self.scalar_type();
|
|
c10::MaybeOwned<Tensor> self_;
|
|
if (&result != &self) {
|
|
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && ROCM_VERSION >= 50700
|
|
+#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && defined(HIPBLASLT)
|
|
// Strangely, if mat2 has only 1 row or column, we get
|
|
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
|
|
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
|
|
@@ -269,7 +271,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
|
}
|
|
self__sizes = self_->sizes();
|
|
} else {
|
|
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
|
|
+#if defined(USE_ROCM) && defined(HIPBLASLT)
|
|
useLtInterface = !disable_addmm_cuda_lt &&
|
|
result.dim() == 2 && result.is_contiguous() &&
|
|
isSupportedHipLtROCmArch(self.device().index()) &&
|
|
@@ -320,7 +322,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
|
|
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
|
|
|
|
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
|
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
|
|
if (useLtInterface) {
|
|
AT_DISPATCH_FLOATING_TYPES_AND2(
|
|
at::ScalarType::Half,
|
|
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
|
|
index b7ffbeb07dc..2b6c3678984 100644
|
|
--- a/cmake/Dependencies.cmake
|
|
+++ b/cmake/Dependencies.cmake
|
|
@@ -1273,6 +1273,9 @@ if(USE_ROCM)
|
|
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "6.0.0")
|
|
list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
|
|
endif()
|
|
+ if(hipblaslt_FOUND)
|
|
+ list(APPEND HIP_CXX_FLAGS -DHIPBLASLT)
|
|
+ endif()
|
|
if(HIPBLASLT_CUSTOM_DATA_TYPE)
|
|
list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_DATA_TYPE)
|
|
endif()
|
|
diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake
|
|
index f6ca263c5e5..53eb0b63c1a 100644
|
|
--- a/cmake/public/LoadHIP.cmake
|
|
+++ b/cmake/public/LoadHIP.cmake
|
|
@@ -156,7 +156,7 @@ if(HIP_FOUND)
|
|
find_package_and_print_version(rocblas REQUIRED)
|
|
find_package_and_print_version(hipblas REQUIRED)
|
|
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
|
|
- find_package_and_print_version(hipblaslt REQUIRED)
|
|
+ find_package_and_print_version(hipblaslt)
|
|
endif()
|
|
find_package_and_print_version(miopen REQUIRED)
|
|
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "4.1.0")
|
|
@@ -191,7 +191,7 @@ if(HIP_FOUND)
|
|
# roctx is part of roctracer
|
|
find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)
|
|
|
|
- if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
|
|
+ if(hipblaslt_FOUND)
|
|
# check whether hipblaslt is using its own datatype
|
|
set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_data_type.cc")
|
|
file(WRITE ${file} ""
|
|
--
|
|
2.43.2
|
|
|