Update hipblaslt patch
Upstream add some more things to change Signed-off-by: Tom Rix <trix@redhat.com>
This commit is contained in:
parent
bbdb0aa112
commit
b2eaedd0c2
2 changed files with 65 additions and 20 deletions
|
|
@ -1,4 +1,4 @@
|
||||||
From 186728c9f4de720547fa3d1b7951c7d5563ad3c7 Mon Sep 17 00:00:00 2001
|
From d77e05d90df006322cda021f1a8affdcc2c7eaef Mon Sep 17 00:00:00 2001
|
||||||
From: Tom Rix <trix@redhat.com>
|
From: Tom Rix <trix@redhat.com>
|
||||||
Date: Fri, 23 Feb 2024 08:27:30 -0500
|
Date: Fri, 23 Feb 2024 08:27:30 -0500
|
||||||
Subject: [PATCH] Optionally use hipblaslt
|
Subject: [PATCH] Optionally use hipblaslt
|
||||||
|
|
@ -10,21 +10,30 @@ Convert the checks for ROCM_VERSION >= 507000 to HIPBLASLT checks
|
||||||
|
|
||||||
Signed-off-by: Tom Rix <trix@redhat.com>
|
Signed-off-by: Tom Rix <trix@redhat.com>
|
||||||
---
|
---
|
||||||
aten/src/ATen/cuda/CUDABlas.cpp | 5 +++--
|
aten/src/ATen/cuda/CUDABlas.cpp | 7 ++++---
|
||||||
aten/src/ATen/cuda/CUDABlas.h | 2 +-
|
aten/src/ATen/cuda/CUDABlas.h | 2 +-
|
||||||
aten/src/ATen/cuda/CUDAContextLight.h | 4 ++--
|
aten/src/ATen/cuda/CUDAContextLight.h | 4 ++--
|
||||||
aten/src/ATen/cuda/CublasHandlePool.cpp | 4 ++--
|
aten/src/ATen/cuda/CublasHandlePool.cpp | 4 ++--
|
||||||
aten/src/ATen/cuda/tunable/TunableGemm.h | 6 +++---
|
aten/src/ATen/cuda/tunable/TunableGemm.h | 6 +++---
|
||||||
aten/src/ATen/native/cuda/Blas.cpp | 10 ++++++----
|
aten/src/ATen/native/cuda/Blas.cpp | 14 ++++++++------
|
||||||
cmake/Dependencies.cmake | 3 +++
|
cmake/Dependencies.cmake | 3 +++
|
||||||
cmake/public/LoadHIP.cmake | 4 ++--
|
cmake/public/LoadHIP.cmake | 4 ++--
|
||||||
8 files changed, 22 insertions(+), 16 deletions(-)
|
8 files changed, 25 insertions(+), 19 deletions(-)
|
||||||
|
|
||||||
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
|
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
|
||||||
index 0a3de5f9d77..b67031e88a7 100644
|
index d534ec5a178..e815463f630 100644
|
||||||
--- a/aten/src/ATen/cuda/CUDABlas.cpp
|
--- a/aten/src/ATen/cuda/CUDABlas.cpp
|
||||||
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
|
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
|
||||||
@@ -778,7 +778,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
@@ -14,7 +14,7 @@
|
||||||
|
#include <c10/util/irange.h>
|
||||||
|
|
||||||
|
#ifdef USE_ROCM
|
||||||
|
-#if ROCM_VERSION >= 60000
|
||||||
|
+#ifdef HIPBLASLT
|
||||||
|
#include <hipblaslt/hipblaslt-ext.hpp>
|
||||||
|
#endif
|
||||||
|
// until hipblas has an API to accept flags, we must use rocblas here
|
||||||
|
@@ -781,7 +781,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -33,7 +42,7 @@ index 0a3de5f9d77..b67031e88a7 100644
|
||||||
|
|
||||||
#if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
|
#if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
|
||||||
// only for rocm 5.7 where we first supported hipblaslt, it was difficult
|
// only for rocm 5.7 where we first supported hipblaslt, it was difficult
|
||||||
@@ -909,6 +909,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
|
@@ -912,6 +912,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
|
||||||
};
|
};
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
|
|
@ -41,7 +50,7 @@ index 0a3de5f9d77..b67031e88a7 100644
|
||||||
template <typename Dtype>
|
template <typename Dtype>
|
||||||
void gemm_and_bias(
|
void gemm_and_bias(
|
||||||
bool transpose_mat1,
|
bool transpose_mat1,
|
||||||
@@ -1121,7 +1122,7 @@ template void gemm_and_bias(
|
@@ -1124,7 +1125,7 @@ template void gemm_and_bias(
|
||||||
at::BFloat16* result_ptr,
|
at::BFloat16* result_ptr,
|
||||||
int64_t result_ld,
|
int64_t result_ld,
|
||||||
GEMMAndBiasActivationEpilogue activation);
|
GEMMAndBiasActivationEpilogue activation);
|
||||||
|
|
@ -86,10 +95,10 @@ index 4ec35f59a21..e28dc42034f 100644
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp
|
diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp
|
||||||
index d8ee09b1486..4510193a5cc 100644
|
index 6913d2cd95e..3d4276be372 100644
|
||||||
--- a/aten/src/ATen/cuda/CublasHandlePool.cpp
|
--- a/aten/src/ATen/cuda/CublasHandlePool.cpp
|
||||||
+++ b/aten/src/ATen/cuda/CublasHandlePool.cpp
|
+++ b/aten/src/ATen/cuda/CublasHandlePool.cpp
|
||||||
@@ -28,7 +28,7 @@ namespace at::cuda {
|
@@ -29,7 +29,7 @@ namespace at::cuda {
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
|
|
||||||
|
|
@ -98,7 +107,7 @@ index d8ee09b1486..4510193a5cc 100644
|
||||||
void createCublasLtHandle(cublasLtHandle_t *handle) {
|
void createCublasLtHandle(cublasLtHandle_t *handle) {
|
||||||
TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
|
TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
|
||||||
}
|
}
|
||||||
@@ -177,7 +177,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
|
@@ -190,7 +190,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
|
||||||
return handle;
|
return handle;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -139,10 +148,10 @@ index 3ba0d761277..dde1870cfbf 100644
|
||||||
if (env == nullptr || strcmp(env, "1") == 0) {
|
if (env == nullptr || strcmp(env, "1") == 0) {
|
||||||
// disallow tuning of hipblaslt with c10::complex
|
// disallow tuning of hipblaslt with c10::complex
|
||||||
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
|
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
|
||||||
index 2b7ff1b502b..9ae2ccbd16d 100644
|
index 29e5c5e3cf1..df56f3d7f1d 100644
|
||||||
--- a/aten/src/ATen/native/cuda/Blas.cpp
|
--- a/aten/src/ATen/native/cuda/Blas.cpp
|
||||||
+++ b/aten/src/ATen/native/cuda/Blas.cpp
|
+++ b/aten/src/ATen/native/cuda/Blas.cpp
|
||||||
@@ -153,7 +153,7 @@ enum class Activation {
|
@@ -155,7 +155,7 @@ enum class Activation {
|
||||||
GELU,
|
GELU,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -151,7 +160,7 @@ index 2b7ff1b502b..9ae2ccbd16d 100644
|
||||||
cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
|
cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
|
||||||
switch (a) {
|
switch (a) {
|
||||||
case Activation::None:
|
case Activation::None:
|
||||||
@@ -191,6 +191,7 @@ static bool getDisableAddmmCudaLt() {
|
@@ -193,6 +193,7 @@ static bool getDisableAddmmCudaLt() {
|
||||||
|
|
||||||
#ifdef USE_ROCM
|
#ifdef USE_ROCM
|
||||||
static bool isSupportedHipLtROCmArch(int index) {
|
static bool isSupportedHipLtROCmArch(int index) {
|
||||||
|
|
@ -159,7 +168,7 @@ index 2b7ff1b502b..9ae2ccbd16d 100644
|
||||||
hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
|
hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
|
||||||
std::string device_arch = prop->gcnArchName;
|
std::string device_arch = prop->gcnArchName;
|
||||||
static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
|
static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
|
||||||
@@ -201,6 +202,7 @@ static bool isSupportedHipLtROCmArch(int index) {
|
@@ -203,6 +204,7 @@ static bool isSupportedHipLtROCmArch(int index) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!");
|
TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!");
|
||||||
|
|
@ -167,7 +176,7 @@ index 2b7ff1b502b..9ae2ccbd16d 100644
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
@@ -226,7 +228,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
@@ -228,7 +230,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||||
at::ScalarType scalar_type = self.scalar_type();
|
at::ScalarType scalar_type = self.scalar_type();
|
||||||
c10::MaybeOwned<Tensor> self_;
|
c10::MaybeOwned<Tensor> self_;
|
||||||
if (&result != &self) {
|
if (&result != &self) {
|
||||||
|
|
@ -176,7 +185,7 @@ index 2b7ff1b502b..9ae2ccbd16d 100644
|
||||||
// Strangely, if mat2 has only 1 row or column, we get
|
// Strangely, if mat2 has only 1 row or column, we get
|
||||||
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
|
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
|
||||||
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
|
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
|
||||||
@@ -269,7 +271,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
@@ -271,7 +273,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||||
}
|
}
|
||||||
self__sizes = self_->sizes();
|
self__sizes = self_->sizes();
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -185,7 +194,7 @@ index 2b7ff1b502b..9ae2ccbd16d 100644
|
||||||
useLtInterface = !disable_addmm_cuda_lt &&
|
useLtInterface = !disable_addmm_cuda_lt &&
|
||||||
result.dim() == 2 && result.is_contiguous() &&
|
result.dim() == 2 && result.is_contiguous() &&
|
||||||
isSupportedHipLtROCmArch(self.device().index()) &&
|
isSupportedHipLtROCmArch(self.device().index()) &&
|
||||||
@@ -320,7 +322,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
@@ -322,7 +324,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||||
|
|
||||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
|
||||||
|
|
||||||
|
|
@ -194,6 +203,24 @@ index 2b7ff1b502b..9ae2ccbd16d 100644
|
||||||
if (useLtInterface) {
|
if (useLtInterface) {
|
||||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||||
at::ScalarType::Half,
|
at::ScalarType::Half,
|
||||||
|
@@ -876,7 +878,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||||
|
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
|
||||||
|
at::native::resize_output(amax, {});
|
||||||
|
|
||||||
|
-#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
|
||||||
|
+#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && defined(HIPBLASLT))
|
||||||
|
cublasCommonArgs args(mat1, mat2, out);
|
||||||
|
const auto out_dtype_ = args.result->scalar_type();
|
||||||
|
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
|
||||||
|
@@ -906,7 +908,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||||
|
TORCH_CHECK(false, "_scaled_mm_out_cuda is not compiled for this platform.");
|
||||||
|
#endif
|
||||||
|
|
||||||
|
-#if defined(USE_ROCM) && ROCM_VERSION >= 60000
|
||||||
|
+#if defined(USE_ROCM) && defined(HIPBLASLT)
|
||||||
|
// rocm's hipblaslt does not yet support amax, so calculate separately
|
||||||
|
auto out_float32 = out.to(kFloat);
|
||||||
|
out_float32.abs_();
|
||||||
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
|
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
|
||||||
index b7ffbeb07dc..2b6c3678984 100644
|
index b7ffbeb07dc..2b6c3678984 100644
|
||||||
--- a/cmake/Dependencies.cmake
|
--- a/cmake/Dependencies.cmake
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,8 @@
|
||||||
# So pre releases can be tried
|
# So pre releases can be tried
|
||||||
%bcond_without gitcommit
|
%bcond_without gitcommit
|
||||||
%if %{with gitcommit}
|
%if %{with gitcommit}
|
||||||
# The top of tree ~2/22/24
|
# The top of tree ~2/28/24
|
||||||
%global commit0 5c5b71b6eebae76d744261715231093e62f0d090
|
%global commit0 3cfed0122829540444911c271ce5480832ea3526
|
||||||
%global shortcommit0 %(c=%{commit0}; echo ${c:0:7})
|
%global shortcommit0 %(c=%{commit0}; echo ${c:0:7})
|
||||||
|
|
||||||
%global pypi_version 2.3.0
|
%global pypi_version 2.3.0
|
||||||
|
|
@ -40,12 +40,18 @@
|
||||||
# For testing openmp
|
# For testing openmp
|
||||||
%bcond_without openmp
|
%bcond_without openmp
|
||||||
|
|
||||||
|
%bcond_with foxi
|
||||||
# For testing caffe2
|
# For testing caffe2
|
||||||
%if 0%{?fedora}
|
%if 0%{?fedora}
|
||||||
|
# need foxi-devel
|
||||||
|
%if %{with foxi}
|
||||||
%bcond_without caffe2
|
%bcond_without caffe2
|
||||||
%else
|
%else
|
||||||
%bcond_with caffe2
|
%bcond_with caffe2
|
||||||
%endif
|
%endif
|
||||||
|
%else
|
||||||
|
%bcond_with caffe2
|
||||||
|
%endif
|
||||||
|
|
||||||
# For testing distributed
|
# For testing distributed
|
||||||
%bcond_with distributed
|
%bcond_with distributed
|
||||||
|
|
@ -182,23 +188,33 @@ BuildRequires: python3dist(fsspec)
|
||||||
BuildRequires: python3dist(sympy)
|
BuildRequires: python3dist(sympy)
|
||||||
%endif
|
%endif
|
||||||
|
|
||||||
|
%if %{with openmp}
|
||||||
|
BuildRequires: libomp-devel
|
||||||
|
%endif
|
||||||
|
|
||||||
%if %{with rocm}
|
%if %{with rocm}
|
||||||
|
BuildRequires: compiler-rt
|
||||||
BuildRequires: hipblas-devel
|
BuildRequires: hipblas-devel
|
||||||
%if %{with hipblaslt}
|
%if %{with hipblaslt}
|
||||||
BuildRequires: hipblaslt-devel
|
BuildRequires: hipblaslt-devel
|
||||||
%endif
|
%endif
|
||||||
BuildRequires: hipcub-devel
|
BuildRequires: hipcub-devel
|
||||||
BuildRequires: hipfft-devel
|
BuildRequires: hipfft-devel
|
||||||
|
BuildRequires: hiprand-devel
|
||||||
BuildRequires: hipsparse-devel
|
BuildRequires: hipsparse-devel
|
||||||
BuildRequires: hipsolver-devel
|
BuildRequires: hipsolver-devel
|
||||||
|
BuildRequires: lld
|
||||||
BuildRequires: miopen-devel
|
BuildRequires: miopen-devel
|
||||||
BuildRequires: rocblas-devel
|
BuildRequires: rocblas-devel
|
||||||
|
BuildRequires: rocrand-devel
|
||||||
|
BuildRequires: rocfft-devel
|
||||||
%if %{with distributed}
|
%if %{with distributed}
|
||||||
BuildRequires: rccl-devel
|
BuildRequires: rccl-devel
|
||||||
%endif
|
%endif
|
||||||
BuildRequires: rocprim-devel
|
BuildRequires: rocprim-devel
|
||||||
BuildRequires: rocm-cmake
|
BuildRequires: rocm-cmake
|
||||||
BuildRequires: rocm-comgr-devel
|
BuildRequires: rocm-comgr-devel
|
||||||
|
BuildRequires: rocm-core-devel
|
||||||
BuildRequires: rocm-hip-devel
|
BuildRequires: rocm-hip-devel
|
||||||
BuildRequires: rocm-runtime-devel
|
BuildRequires: rocm-runtime-devel
|
||||||
BuildRequires: rocm-rpm-macros
|
BuildRequires: rocm-rpm-macros
|
||||||
|
|
@ -208,9 +224,11 @@ BuildRequires: rocthrust-devel
|
||||||
Requires: rocm-rpm-macros-modules
|
Requires: rocm-rpm-macros-modules
|
||||||
%endif
|
%endif
|
||||||
|
|
||||||
|
%if %{with foxi}
|
||||||
%if %{with caffe2}
|
%if %{with caffe2}
|
||||||
BuildRequires: foxi-devel
|
BuildRequires: foxi-devel
|
||||||
%endif
|
%endif
|
||||||
|
%endif
|
||||||
|
|
||||||
%if %{with test}
|
%if %{with test}
|
||||||
BuildRequires: google-benchmark-devel
|
BuildRequires: google-benchmark-devel
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue