Update hipblaslt patch

Upstream added some more things to change.

Signed-off-by: Tom Rix <trix@redhat.com>
This commit is contained in:
Tom Rix 2024-03-03 07:31:08 -05:00
commit b2eaedd0c2
2 changed files with 65 additions and 20 deletions

View file

@ -1,4 +1,4 @@
From 186728c9f4de720547fa3d1b7951c7d5563ad3c7 Mon Sep 17 00:00:00 2001 From d77e05d90df006322cda021f1a8affdcc2c7eaef Mon Sep 17 00:00:00 2001
From: Tom Rix <trix@redhat.com> From: Tom Rix <trix@redhat.com>
Date: Fri, 23 Feb 2024 08:27:30 -0500 Date: Fri, 23 Feb 2024 08:27:30 -0500
Subject: [PATCH] Optionally use hipblaslt Subject: [PATCH] Optionally use hipblaslt
@ -10,21 +10,30 @@ Convert the checks for ROCM_VERSION >= 507000 to HIPBLASLT checks
Signed-off-by: Tom Rix <trix@redhat.com> Signed-off-by: Tom Rix <trix@redhat.com>
--- ---
aten/src/ATen/cuda/CUDABlas.cpp | 5 +++-- aten/src/ATen/cuda/CUDABlas.cpp | 7 ++++---
aten/src/ATen/cuda/CUDABlas.h | 2 +- aten/src/ATen/cuda/CUDABlas.h | 2 +-
aten/src/ATen/cuda/CUDAContextLight.h | 4 ++-- aten/src/ATen/cuda/CUDAContextLight.h | 4 ++--
aten/src/ATen/cuda/CublasHandlePool.cpp | 4 ++-- aten/src/ATen/cuda/CublasHandlePool.cpp | 4 ++--
aten/src/ATen/cuda/tunable/TunableGemm.h | 6 +++--- aten/src/ATen/cuda/tunable/TunableGemm.h | 6 +++---
aten/src/ATen/native/cuda/Blas.cpp | 10 ++++++---- aten/src/ATen/native/cuda/Blas.cpp | 14 ++++++++------
cmake/Dependencies.cmake | 3 +++ cmake/Dependencies.cmake | 3 +++
cmake/public/LoadHIP.cmake | 4 ++-- cmake/public/LoadHIP.cmake | 4 ++--
8 files changed, 22 insertions(+), 16 deletions(-) 8 files changed, 25 insertions(+), 19 deletions(-)
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index 0a3de5f9d77..b67031e88a7 100644 index d534ec5a178..e815463f630 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp --- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -778,7 +778,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) { @@ -14,7 +14,7 @@
#include <c10/util/irange.h>
#ifdef USE_ROCM
-#if ROCM_VERSION >= 60000
+#ifdef HIPBLASLT
#include <hipblaslt/hipblaslt-ext.hpp>
#endif
// until hipblas has an API to accept flags, we must use rocblas here
@@ -781,7 +781,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
} }
} }
@ -33,7 +42,7 @@ index 0a3de5f9d77..b67031e88a7 100644
#if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000 #if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
// only for rocm 5.7 where we first supported hipblaslt, it was difficult // only for rocm 5.7 where we first supported hipblaslt, it was difficult
@@ -909,6 +909,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor< @@ -912,6 +912,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
}; };
} // namespace } // namespace
@ -41,7 +50,7 @@ index 0a3de5f9d77..b67031e88a7 100644
template <typename Dtype> template <typename Dtype>
void gemm_and_bias( void gemm_and_bias(
bool transpose_mat1, bool transpose_mat1,
@@ -1121,7 +1122,7 @@ template void gemm_and_bias( @@ -1124,7 +1125,7 @@ template void gemm_and_bias(
at::BFloat16* result_ptr, at::BFloat16* result_ptr,
int64_t result_ld, int64_t result_ld,
GEMMAndBiasActivationEpilogue activation); GEMMAndBiasActivationEpilogue activation);
@ -86,10 +95,10 @@ index 4ec35f59a21..e28dc42034f 100644
#endif #endif
diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp
index d8ee09b1486..4510193a5cc 100644 index 6913d2cd95e..3d4276be372 100644
--- a/aten/src/ATen/cuda/CublasHandlePool.cpp --- a/aten/src/ATen/cuda/CublasHandlePool.cpp
+++ b/aten/src/ATen/cuda/CublasHandlePool.cpp +++ b/aten/src/ATen/cuda/CublasHandlePool.cpp
@@ -28,7 +28,7 @@ namespace at::cuda { @@ -29,7 +29,7 @@ namespace at::cuda {
namespace { namespace {
@ -98,7 +107,7 @@ index d8ee09b1486..4510193a5cc 100644
void createCublasLtHandle(cublasLtHandle_t *handle) { void createCublasLtHandle(cublasLtHandle_t *handle) {
TORCH_CUDABLAS_CHECK(cublasLtCreate(handle)); TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
} }
@@ -177,7 +177,7 @@ cublasHandle_t getCurrentCUDABlasHandle() { @@ -190,7 +190,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
return handle; return handle;
} }
@ -139,10 +148,10 @@ index 3ba0d761277..dde1870cfbf 100644
if (env == nullptr || strcmp(env, "1") == 0) { if (env == nullptr || strcmp(env, "1") == 0) {
// disallow tuning of hipblaslt with c10::complex // disallow tuning of hipblaslt with c10::complex
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 2b7ff1b502b..9ae2ccbd16d 100644 index 29e5c5e3cf1..df56f3d7f1d 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp --- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp +++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -153,7 +153,7 @@ enum class Activation { @@ -155,7 +155,7 @@ enum class Activation {
GELU, GELU,
}; };
@ -151,7 +160,7 @@ index 2b7ff1b502b..9ae2ccbd16d 100644
cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) { cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
switch (a) { switch (a) {
case Activation::None: case Activation::None:
@@ -191,6 +191,7 @@ static bool getDisableAddmmCudaLt() { @@ -193,6 +193,7 @@ static bool getDisableAddmmCudaLt() {
#ifdef USE_ROCM #ifdef USE_ROCM
static bool isSupportedHipLtROCmArch(int index) { static bool isSupportedHipLtROCmArch(int index) {
@ -159,7 +168,7 @@ index 2b7ff1b502b..9ae2ccbd16d 100644
hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index); hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
std::string device_arch = prop->gcnArchName; std::string device_arch = prop->gcnArchName;
static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"}; static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
@@ -201,6 +202,7 @@ static bool isSupportedHipLtROCmArch(int index) { @@ -203,6 +204,7 @@ static bool isSupportedHipLtROCmArch(int index) {
} }
} }
TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!"); TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!");
@ -167,7 +176,7 @@ index 2b7ff1b502b..9ae2ccbd16d 100644
return false; return false;
} }
#endif #endif
@@ -226,7 +228,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma @@ -228,7 +230,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
at::ScalarType scalar_type = self.scalar_type(); at::ScalarType scalar_type = self.scalar_type();
c10::MaybeOwned<Tensor> self_; c10::MaybeOwned<Tensor> self_;
if (&result != &self) { if (&result != &self) {
@ -176,7 +185,7 @@ index 2b7ff1b502b..9ae2ccbd16d 100644
// Strangely, if mat2 has only 1 row or column, we get // Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic. // CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1] // self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
@@ -269,7 +271,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma @@ -271,7 +273,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
} }
self__sizes = self_->sizes(); self__sizes = self_->sizes();
} else { } else {
@ -185,7 +194,7 @@ index 2b7ff1b502b..9ae2ccbd16d 100644
useLtInterface = !disable_addmm_cuda_lt && useLtInterface = !disable_addmm_cuda_lt &&
result.dim() == 2 && result.is_contiguous() && result.dim() == 2 && result.is_contiguous() &&
isSupportedHipLtROCmArch(self.device().index()) && isSupportedHipLtROCmArch(self.device().index()) &&
@@ -320,7 +322,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma @@ -322,7 +324,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj()); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
@ -194,6 +203,24 @@ index 2b7ff1b502b..9ae2ccbd16d 100644
if (useLtInterface) { if (useLtInterface) {
AT_DISPATCH_FLOATING_TYPES_AND2( AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::Half,
@@ -876,7 +878,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
at::native::resize_output(amax, {});
-#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
+#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && defined(HIPBLASLT))
cublasCommonArgs args(mat1, mat2, out);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
@@ -906,7 +908,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
TORCH_CHECK(false, "_scaled_mm_out_cuda is not compiled for this platform.");
#endif
-#if defined(USE_ROCM) && ROCM_VERSION >= 60000
+#if defined(USE_ROCM) && defined(HIPBLASLT)
// rocm's hipblaslt does not yet support amax, so calculate separately
auto out_float32 = out.to(kFloat);
out_float32.abs_();
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index b7ffbeb07dc..2b6c3678984 100644 index b7ffbeb07dc..2b6c3678984 100644
--- a/cmake/Dependencies.cmake --- a/cmake/Dependencies.cmake

View file

@ -6,8 +6,8 @@
# So pre releases can be tried # So pre releases can be tried
%bcond_without gitcommit %bcond_without gitcommit
%if %{with gitcommit} %if %{with gitcommit}
# The top of tree ~2/22/24 # The top of tree ~2/28/24
%global commit0 5c5b71b6eebae76d744261715231093e62f0d090 %global commit0 3cfed0122829540444911c271ce5480832ea3526
%global shortcommit0 %(c=%{commit0}; echo ${c:0:7}) %global shortcommit0 %(c=%{commit0}; echo ${c:0:7})
%global pypi_version 2.3.0 %global pypi_version 2.3.0
@ -40,12 +40,18 @@
# For testing openmp # For testing openmp
%bcond_without openmp %bcond_without openmp
%bcond_with foxi
# For testing caffe2 # For testing caffe2
%if 0%{?fedora} %if 0%{?fedora}
# need foxi-devel
%if %{with foxi}
%bcond_without caffe2 %bcond_without caffe2
%else %else
%bcond_with caffe2 %bcond_with caffe2
%endif %endif
%else
%bcond_with caffe2
%endif
# For testing distributed # For testing distributed
%bcond_with distributed %bcond_with distributed
@ -182,23 +188,33 @@ BuildRequires: python3dist(fsspec)
BuildRequires: python3dist(sympy) BuildRequires: python3dist(sympy)
%endif %endif
%if %{with openmp}
BuildRequires: libomp-devel
%endif
%if %{with rocm} %if %{with rocm}
BuildRequires: compiler-rt
BuildRequires: hipblas-devel BuildRequires: hipblas-devel
%if %{with hipblaslt} %if %{with hipblaslt}
BuildRequires: hipblaslt-devel BuildRequires: hipblaslt-devel
%endif %endif
BuildRequires: hipcub-devel BuildRequires: hipcub-devel
BuildRequires: hipfft-devel BuildRequires: hipfft-devel
BuildRequires: hiprand-devel
BuildRequires: hipsparse-devel BuildRequires: hipsparse-devel
BuildRequires: hipsolver-devel BuildRequires: hipsolver-devel
BuildRequires: lld
BuildRequires: miopen-devel BuildRequires: miopen-devel
BuildRequires: rocblas-devel BuildRequires: rocblas-devel
BuildRequires: rocrand-devel
BuildRequires: rocfft-devel
%if %{with distributed} %if %{with distributed}
BuildRequires: rccl-devel BuildRequires: rccl-devel
%endif %endif
BuildRequires: rocprim-devel BuildRequires: rocprim-devel
BuildRequires: rocm-cmake BuildRequires: rocm-cmake
BuildRequires: rocm-comgr-devel BuildRequires: rocm-comgr-devel
BuildRequires: rocm-core-devel
BuildRequires: rocm-hip-devel BuildRequires: rocm-hip-devel
BuildRequires: rocm-runtime-devel BuildRequires: rocm-runtime-devel
BuildRequires: rocm-rpm-macros BuildRequires: rocm-rpm-macros
@ -208,9 +224,11 @@ BuildRequires: rocthrust-devel
Requires: rocm-rpm-macros-modules Requires: rocm-rpm-macros-modules
%endif %endif
%if %{with foxi}
%if %{with caffe2} %if %{with caffe2}
BuildRequires: foxi-devel BuildRequires: foxi-devel
%endif %endif
%endif
%if %{with test} %if %{with test}
BuildRequires: google-benchmark-devel BuildRequires: google-benchmark-devel