PyTorch 2.4
Signed-off-by: Tom Rix <trix@redhat.com>
This commit is contained in:
parent
2debc89ffd
commit
86185b46a2
7 changed files with 1468 additions and 273 deletions
|
|
@ -1,174 +1,398 @@
|
|||
From d77e05d90df006322cda021f1a8affdcc2c7eaef Mon Sep 17 00:00:00 2001
|
||||
From f1d65e958afa65882dbfea8b392ab847a84d41ed Mon Sep 17 00:00:00 2001
|
||||
From: Tom Rix <trix@redhat.com>
|
||||
Date: Fri, 23 Feb 2024 08:27:30 -0500
|
||||
Date: Sat, 29 Jun 2024 04:18:34 -0700
|
||||
Subject: [PATCH] Optionally use hipblaslt
|
||||
|
||||
The hipblaslt package is not available on Fedora.
|
||||
Instead of requiring the package, make it optional.
|
||||
If it is found, define the preprocessor variable USE_HIPBLASLT
|
||||
Convert the checks for ROCM_VERSION >= 50700 to HIPBLASLT checks
|
||||
|
||||
Signed-off-by: Tom Rix <trix@redhat.com>
|
||||
---
|
||||
aten/src/ATen/cuda/CUDABlas.cpp | 7 ++++---
|
||||
aten/src/ATen/cuda/CUDABlas.h | 2 +-
|
||||
aten/src/ATen/cuda/CUDAContextLight.h | 4 ++--
|
||||
aten/src/ATen/cuda/CublasHandlePool.cpp | 4 ++--
|
||||
aten/src/ATen/cuda/tunable/TunableGemm.h | 6 +++---
|
||||
aten/src/ATen/native/cuda/Blas.cpp | 14 ++++++++------
|
||||
cmake/Dependencies.cmake | 3 +++
|
||||
cmake/public/LoadHIP.cmake | 4 ++--
|
||||
8 files changed, 25 insertions(+), 19 deletions(-)
|
||||
aten/src/ATen/cuda/CUDABlas.cpp | 46 ++++++++++++++++++------
|
||||
aten/src/ATen/cuda/CUDAContextLight.h | 4 +++
|
||||
aten/src/ATen/cuda/CublasHandlePool.cpp | 10 ++++--
|
||||
aten/src/ATen/cuda/tunable/TunableGemm.h | 18 +++++++---
|
||||
aten/src/ATen/native/cuda/Blas.cpp | 18 +++++++++-
|
||||
cmake/Dependencies.cmake | 3 ++
|
||||
cmake/public/LoadHIP.cmake | 2 +-
|
||||
7 files changed, 82 insertions(+), 19 deletions(-)
|
||||
|
||||
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
|
||||
index d534ec5a178..e815463f630 100644
|
||||
index ce991a9bcad4..3f0d17b52778 100644
|
||||
--- a/aten/src/ATen/cuda/CUDABlas.cpp
|
||||
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
|
||||
@@ -14,7 +14,7 @@
|
||||
@@ -14,7 +14,9 @@
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
-#if ROCM_VERSION >= 60000
|
||||
+#ifdef HIPBLASLT
|
||||
+#ifdef USE_HIPBLASLT
|
||||
#include <hipblaslt/hipblaslt-ext.hpp>
|
||||
#endif
|
||||
+#endif
|
||||
// until hipblas has an API to accept flags, we must use rocblas here
|
||||
@@ -781,7 +781,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
}
|
||||
}
|
||||
#include <hipblas/hipblas.h>
|
||||
#include <rocblas/rocblas.h>
|
||||
@@ -182,6 +184,9 @@ uint32_t _getAlignment(uintptr_t address) {
|
||||
static size_t _parseChosenWorkspaceSize() {
|
||||
const char * val = getenv("CUBLASLT_WORKSPACE_SIZE");
|
||||
#ifdef USE_ROCM
|
||||
+#ifndef USE_HIPBLASLT
|
||||
+ return 0;
|
||||
+#endif
|
||||
if (!val) {
|
||||
// accept either env var
|
||||
val = getenv("HIPBLASLT_WORKSPACE_SIZE");
|
||||
@@ -235,6 +240,7 @@ namespace at::cuda::blas {
|
||||
} while (0)
|
||||
|
||||
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
|
||||
|
||||
#if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
|
||||
// only for rocm 5.7 where we first supported hipblaslt, it was difficult
|
||||
@@ -912,6 +912,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
namespace {
|
||||
// Following the pattern of CuSparseDescriptor
|
||||
// Defined here for now because this is the only place cublas_lt interface is
|
||||
@@ -318,7 +324,6 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
|
||||
};
|
||||
} // namespace
|
||||
|
||||
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
|
||||
-
|
||||
template <typename Dtype>
|
||||
inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
|
||||
cudaDataType_t abcType = CUDA_R_32F;
|
||||
@@ -452,7 +457,7 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
|
||||
" scaleType ",
|
||||
scaleType);
|
||||
}
|
||||
-
|
||||
+#endif
|
||||
|
||||
template <typename Dtype>
|
||||
inline void bgemm_internal_cublas(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
|
||||
@@ -608,10 +613,13 @@ void bgemm_internal<double>(CUDABLAS_BGEMM_ARGTYPES(double))
|
||||
template <>
|
||||
void bgemm_internal<float>(CUDABLAS_BGEMM_ARGTYPES(float))
|
||||
{
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
bgemm_internal_cublaslt<float>(CUDABLAS_BGEMM_ARGS(float));
|
||||
}
|
||||
- else {
|
||||
+ else
|
||||
+#endif
|
||||
+ {
|
||||
bgemm_internal_cublas<float>(CUDABLAS_BGEMM_ARGS(float));
|
||||
}
|
||||
}
|
||||
@@ -651,10 +659,13 @@ void bgemm_internal<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<fl
|
||||
template <>
|
||||
void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half))
|
||||
{
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
bgemm_internal_cublaslt<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
|
||||
}
|
||||
- else {
|
||||
+ else
|
||||
+#endif
|
||||
+ {
|
||||
bgemm_internal_cublas<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
|
||||
}
|
||||
}
|
||||
@@ -662,10 +673,13 @@ void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half))
|
||||
template <>
|
||||
void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
|
||||
{
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
bgemm_internal_cublaslt<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
- else {
|
||||
+ else
|
||||
+#endif
|
||||
+ {
|
||||
bgemm_internal_cublas<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
}
|
||||
@@ -781,11 +795,13 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
|
||||
}
|
||||
}
|
||||
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
template <typename Dtype>
|
||||
inline void gemm_internal_cublaslt(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
|
||||
// forward to bgemm implementation but set strides and batches to 0
|
||||
bgemm_internal_cublaslt(transa, transb, m, n, k, alpha, a, lda, 0, b, ldb, 0, beta, c, ldc, 0, 0);
|
||||
}
|
||||
+#endif
|
||||
|
||||
template <typename Dtype>
|
||||
inline void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
|
||||
@@ -1008,10 +1024,13 @@ void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
|
||||
template <>
|
||||
void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
|
||||
{
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
- else {
|
||||
+ else
|
||||
+#endif
|
||||
+ {
|
||||
gemm_internal_cublas<float>(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
}
|
||||
@@ -1051,10 +1070,13 @@ void gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<floa
|
||||
template <>
|
||||
void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
|
||||
{
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
- else {
|
||||
+ else
|
||||
+#endif
|
||||
+ {
|
||||
gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
}
|
||||
@@ -1062,10 +1084,13 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
|
||||
template <>
|
||||
void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
|
||||
{
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
- else {
|
||||
+ else
|
||||
+#endif
|
||||
+ {
|
||||
gemm_internal_cublas<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
}
|
||||
@@ -1177,7 +1202,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
}
|
||||
}
|
||||
|
||||
-
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
template <typename Dtype>
|
||||
void gemm_and_bias(
|
||||
bool transpose_mat1,
|
||||
@@ -1124,7 +1125,7 @@ template void gemm_and_bias(
|
||||
at::BFloat16* result_ptr,
|
||||
int64_t result_ld,
|
||||
GEMMAndBiasActivationEpilogue activation);
|
||||
-
|
||||
@@ -1410,7 +1435,7 @@ void scaled_gemm(
|
||||
ScalarType result_dtype,
|
||||
void* amax_ptr,
|
||||
bool use_fast_accum) {
|
||||
-#if CUDA_VERSION >= 11080 || defined(USE_ROCM)
|
||||
+#if CUDA_VERSION >= 11080 || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
const auto computeType = CUBLAS_COMPUTE_32F;
|
||||
const auto scaleType = CUDA_R_32F;
|
||||
const int8_t fastAccuMode = use_fast_accum ? 1 : 0;
|
||||
@@ -1681,6 +1706,7 @@ void int8_gemm(
|
||||
" scaleType ",
|
||||
scaleType);
|
||||
}
|
||||
+#endif
|
||||
void scaled_gemm(
|
||||
char transa,
|
||||
char transb,
|
||||
diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h
|
||||
index eb12bb350c5..068607467dd 100644
|
||||
--- a/aten/src/ATen/cuda/CUDABlas.h
|
||||
+++ b/aten/src/ATen/cuda/CUDABlas.h
|
||||
@@ -82,7 +82,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
|
||||
template <>
|
||||
void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
|
||||
|
||||
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
|
||||
enum GEMMAndBiasActivationEpilogue {
|
||||
None,
|
||||
RELU,
|
||||
template <>
|
||||
void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float)) {
|
||||
diff --git a/aten/src/ATen/cuda/CUDAContextLight.h b/aten/src/ATen/cuda/CUDAContextLight.h
|
||||
index 4ec35f59a21..e28dc42034f 100644
|
||||
index f2b657ced51b..f0ee613c4208 100644
|
||||
--- a/aten/src/ATen/cuda/CUDAContextLight.h
|
||||
+++ b/aten/src/ATen/cuda/CUDAContextLight.h
|
||||
@@ -9,7 +9,7 @@
|
||||
@@ -9,7 +9,9 @@
|
||||
|
||||
// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
|
||||
// added bf16 support
|
||||
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
|
||||
+#if (!defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)))
|
||||
#include <cublasLt.h>
|
||||
#endif
|
||||
+#endif
|
||||
|
||||
@@ -82,7 +82,7 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
|
||||
#ifdef CUDART_VERSION
|
||||
#include <cusolverDn.h>
|
||||
@@ -80,7 +82,9 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
|
||||
/* Handles */
|
||||
TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
|
||||
TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
|
||||
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
|
||||
+#if (!defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)))
|
||||
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
|
||||
#endif
|
||||
+#endif
|
||||
|
||||
TORCH_CUDA_CPP_API void clearCublasWorkspaces();
|
||||
|
||||
diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp
|
||||
index 6913d2cd95e..3d4276be372 100644
|
||||
index 8eac525b3695..abfdf7a23847 100644
|
||||
--- a/aten/src/ATen/cuda/CublasHandlePool.cpp
|
||||
+++ b/aten/src/ATen/cuda/CublasHandlePool.cpp
|
||||
@@ -29,7 +29,7 @@ namespace at::cuda {
|
||||
|
||||
namespace {
|
||||
|
||||
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
|
||||
+#if defined(USE_ROCM) && defined(HIPBLASLT)
|
||||
-#if defined(USE_ROCM)
|
||||
+#if defined(USE_ROCM) && defined(USE_HIPBLASLT)
|
||||
void createCublasLtHandle(cublasLtHandle_t *handle) {
|
||||
TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
|
||||
}
|
||||
@@ -190,7 +190,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
|
||||
@@ -191,8 +191,9 @@ cublasHandle_t getCurrentCUDABlasHandle() {
|
||||
return handle;
|
||||
}
|
||||
|
||||
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
|
||||
cublasLtHandle_t getCurrentCUDABlasLtHandle() {
|
||||
-cublasLtHandle_t getCurrentCUDABlasLtHandle() {
|
||||
#ifdef USE_ROCM
|
||||
+#if defined(USE_HIPBLASLT)
|
||||
+cublasLtHandle_t getCurrentCUDABlasLtHandle() {
|
||||
c10::DeviceIndex device = 0;
|
||||
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
|
||||
|
||||
@@ -213,9 +214,12 @@ cublasLtHandle_t getCurrentCUDABlasLtHandle() {
|
||||
|
||||
auto handle = myPoolWindow->reserve(device);
|
||||
return handle;
|
||||
+}
|
||||
+#endif
|
||||
#else
|
||||
+cublasLtHandle_t getCurrentCUDABlasLtHandle() {
|
||||
return reinterpret_cast<cublasLtHandle_t>(getCurrentCUDABlasHandle());
|
||||
-#endif
|
||||
}
|
||||
+#endif
|
||||
|
||||
} // namespace at::cuda
|
||||
diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h
|
||||
index 3ba0d761277..dde1870cfbf 100644
|
||||
index 53e6154120c9..fa1d664696db 100644
|
||||
--- a/aten/src/ATen/cuda/tunable/TunableGemm.h
|
||||
+++ b/aten/src/ATen/cuda/tunable/TunableGemm.h
|
||||
@@ -11,7 +11,7 @@
|
||||
@@ -11,7 +11,9 @@
|
||||
|
||||
#include <ATen/cuda/tunable/GemmCommon.h>
|
||||
#ifdef USE_ROCM
|
||||
-#if ROCM_VERSION >= 50700
|
||||
+#ifdef HIPBLASLT
|
||||
+#ifdef USE_HIPBLASLT
|
||||
#include <ATen/cuda/tunable/GemmHipblaslt.h>
|
||||
#endif
|
||||
+#endif
|
||||
#include <ATen/cuda/tunable/GemmRocblas.h>
|
||||
@@ -166,7 +166,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
|
||||
}
|
||||
#endif
|
||||
|
||||
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
|
||||
+#if defined(USE_ROCM) && defined(HIPBLASLT)
|
||||
static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
|
||||
if (env == nullptr || strcmp(env, "1") == 0) {
|
||||
// disallow tuning of hipblaslt with c10::complex
|
||||
@@ -240,7 +240,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
|
||||
#include <ATen/cuda/tunable/StreamTimer.h>
|
||||
@@ -65,6 +67,7 @@ class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>>
|
||||
}
|
||||
#endif
|
||||
|
||||
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
|
||||
+#if defined(USE_ROCM) && defined(HIPBLASLT)
|
||||
static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
|
||||
if (env == nullptr || strcmp(env, "1") == 0) {
|
||||
// disallow tuning of hipblaslt with c10::complex
|
||||
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
|
||||
index 29e5c5e3cf1..df56f3d7f1d 100644
|
||||
--- a/aten/src/ATen/native/cuda/Blas.cpp
|
||||
+++ b/aten/src/ATen/native/cuda/Blas.cpp
|
||||
@@ -155,7 +155,7 @@ enum class Activation {
|
||||
GELU,
|
||||
};
|
||||
|
||||
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
|
||||
cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
|
||||
switch (a) {
|
||||
case Activation::None:
|
||||
@@ -193,6 +193,7 @@ static bool getDisableAddmmCudaLt() {
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
template <typename T>
|
||||
class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
|
||||
public:
|
||||
@@ -94,6 +97,7 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
|
||||
return OK;
|
||||
}
|
||||
};
|
||||
+#endif
|
||||
|
||||
template <typename T>
|
||||
inline bool IsZero(T v) {
|
||||
@@ -191,6 +195,7 @@ static void AddRocblasValidator() {
|
||||
}
|
||||
}
|
||||
|
||||
+#ifdef USE_HIPBLASLT
|
||||
static void AddHipblasltValidator() {
|
||||
auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
|
||||
if (validators.find("HIPBLASLT_VERSION") == validators.end()) {
|
||||
@@ -205,6 +210,7 @@ static void AddHipblasltValidator() {
|
||||
[hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; });
|
||||
}
|
||||
}
|
||||
+#endif
|
||||
|
||||
static void AddRocmValidator() {
|
||||
auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
|
||||
@@ -243,7 +249,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
|
||||
}
|
||||
AddRocblasValidator();
|
||||
}
|
||||
-
|
||||
+#ifdef USE_HIPBLASLT
|
||||
static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
|
||||
if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
|
||||
rocm_validators = true;
|
||||
@@ -257,7 +263,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
|
||||
}
|
||||
AddHipblasltValidator();
|
||||
}
|
||||
-
|
||||
+#endif
|
||||
if (rocm_validators) {
|
||||
AddRocmValidator();
|
||||
}
|
||||
@@ -286,7 +292,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
|
||||
}
|
||||
AddRocblasValidator();
|
||||
}
|
||||
-
|
||||
+#ifdef USE_HIPBLASLT
|
||||
static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
|
||||
if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
|
||||
rocm_validators = true;
|
||||
@@ -300,7 +306,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
|
||||
}
|
||||
AddHipblasltValidator();
|
||||
}
|
||||
-
|
||||
+#endif
|
||||
if (rocm_validators) {
|
||||
AddRocmValidator();
|
||||
}
|
||||
@@ -312,6 +318,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
|
||||
}
|
||||
};
|
||||
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
|
||||
class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer> {
|
||||
public:
|
||||
@@ -321,10 +328,12 @@ class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer>
|
||||
auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
+#ifdef USE_HIPBLASLT
|
||||
for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps<AT, BT, CT, ALayout, BLayout>()) {
|
||||
this->RegisterOp(std::move(name), std::move(op));
|
||||
}
|
||||
AddHipblasltValidator();
|
||||
+#endif
|
||||
AddRocmValidator();
|
||||
#endif
|
||||
}
|
||||
@@ -337,6 +346,7 @@ class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer>
|
||||
"_", BlasOpToString(ALayout), BlasOpToString(BLayout));
|
||||
}
|
||||
};
|
||||
+#endif
|
||||
|
||||
#undef XSTRINGIFY
|
||||
#undef STRINGIFY
|
||||
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
|
||||
index 84c59a4fd0d7..56ad5de3bf2d 100644
|
||||
--- a/aten/src/ATen/native/cuda/Blas.cpp
|
||||
+++ b/aten/src/ATen/native/cuda/Blas.cpp
|
||||
@@ -173,6 +173,7 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
|
||||
}
|
||||
|
||||
static bool getDisableAddmmCudaLt() {
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT");
|
||||
#ifdef USE_ROCM
|
||||
// if we enable tunable op, it'll take priority over just hipblaslt (heuristics)
|
||||
@@ -196,10 +197,14 @@ static bool getDisableAddmmCudaLt() {
|
||||
}
|
||||
return false;
|
||||
#endif
|
||||
+#else
|
||||
+ return true;
|
||||
+#endif
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
static bool isSupportedHipLtROCmArch(int index) {
|
||||
+#if defined(HIPBLASLT)
|
||||
+#ifdef USE_HIPBLASLT
|
||||
hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
|
||||
std::string device_arch = prop->gcnArchName;
|
||||
static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
|
||||
@@ -203,6 +204,7 @@ static bool isSupportedHipLtROCmArch(int index) {
|
||||
@@ -210,6 +215,7 @@ static bool isSupportedHipLtROCmArch(int index) {
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!");
|
||||
|
|
@ -176,87 +400,107 @@ index 29e5c5e3cf1..df56f3d7f1d 100644
|
|||
return false;
|
||||
}
|
||||
#endif
|
||||
@@ -228,7 +230,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
@@ -235,6 +241,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
at::ScalarType scalar_type = self.scalar_type();
|
||||
c10::MaybeOwned<Tensor> self_;
|
||||
if (&result != &self) {
|
||||
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && ROCM_VERSION >= 50700
|
||||
+#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && defined(HIPBLASLT)
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11040)) || defined(USE_ROCM)
|
||||
// Strangely, if mat2 has only 1 row or column, we get
|
||||
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
|
||||
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
|
||||
@@ -271,7 +273,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
@@ -276,13 +283,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
scalar_type != at::ScalarType::BFloat16));
|
||||
#endif
|
||||
}
|
||||
+#endif
|
||||
#endif
|
||||
if (!useLtInterface) {
|
||||
self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
|
||||
}
|
||||
self__sizes = self_->sizes();
|
||||
} else {
|
||||
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
|
||||
+#if defined(USE_ROCM) && defined(HIPBLASLT)
|
||||
-#if defined(USE_ROCM)
|
||||
+#if defined(USE_ROCM) && defined(USE_HIPBLASLT)
|
||||
useLtInterface = !disable_addmm_cuda_lt &&
|
||||
result.dim() == 2 && result.is_contiguous() &&
|
||||
isSupportedHipLtROCmArch(self.device().index()) &&
|
||||
@@ -322,7 +324,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
|
||||
@@ -334,6 +342,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
|
||||
|
||||
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
|
||||
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
|
||||
if (useLtInterface) {
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
#if defined(USE_ROCM)
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
@@ -876,7 +878,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
|
||||
at::native::resize_output(amax, {});
|
||||
@@ -394,6 +403,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
activation_epilogue
|
||||
);
|
||||
});
|
||||
+#endif
|
||||
#endif
|
||||
} else
|
||||
{
|
||||
@@ -803,6 +813,7 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) {
|
||||
}
|
||||
|
||||
-#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
|
||||
+#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && defined(HIPBLASLT))
|
||||
cublasCommonArgs args(mat1, mat2, out);
|
||||
const auto out_dtype_ = args.result->scalar_type();
|
||||
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
|
||||
@@ -906,7 +908,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
TORCH_CHECK(false, "_scaled_mm_out_cuda is not compiled for this platform.");
|
||||
static bool _scaled_mm_allowed_device() {
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
#ifdef USE_ROCM
|
||||
std::string device_arch = dprops->gcnArchName;
|
||||
@@ -817,6 +828,9 @@ static bool _scaled_mm_allowed_device() {
|
||||
#else
|
||||
return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
|
||||
#endif
|
||||
+#else
|
||||
+ return false;
|
||||
+#endif
|
||||
}
|
||||
|
||||
// Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax
|
||||
@@ -850,6 +864,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
// Check sizes
|
||||
bool allowed_device = _scaled_mm_allowed_device();
|
||||
TORCH_CHECK(allowed_device, "torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+");
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
|
||||
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
|
||||
TORCH_CHECK(
|
||||
@@ -1025,6 +1040,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && ROCM_VERSION < 60200
|
||||
// ROCm's hipBLASLt does not support amax before 6.2, so calculate separately
|
||||
amax = at::max(at::abs(out.to(kFloat)));
|
||||
+#endif
|
||||
#endif
|
||||
|
||||
-#if defined(USE_ROCM) && ROCM_VERSION >= 60000
|
||||
+#if defined(USE_ROCM) && defined(HIPBLASLT)
|
||||
// rocm's hipblaslt does not yet support amax, so calculate separately
|
||||
auto out_float32 = out.to(kFloat);
|
||||
out_float32.abs_();
|
||||
return {out, amax};
|
||||
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
|
||||
index b7ffbeb07dc..2b6c3678984 100644
|
||||
index f1f2eb7cec31..8d05e834bbc5 100644
|
||||
--- a/cmake/Dependencies.cmake
|
||||
+++ b/cmake/Dependencies.cmake
|
||||
@@ -1273,6 +1273,9 @@ if(USE_ROCM)
|
||||
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "6.0.0")
|
||||
list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
|
||||
endif()
|
||||
@@ -1052,6 +1052,9 @@ if(USE_ROCM)
|
||||
list(APPEND HIP_CXX_FLAGS -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP)
|
||||
list(APPEND HIP_CXX_FLAGS -std=c++17)
|
||||
list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
|
||||
+ if(hipblaslt_FOUND)
|
||||
+ list(APPEND HIP_CXX_FLAGS -DHIPBLASLT)
|
||||
+ list(APPEND HIP_CXX_FLAGS -DUSE_HIPBLASLT)
|
||||
+ endif()
|
||||
if(HIPBLASLT_CUSTOM_DATA_TYPE)
|
||||
list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_DATA_TYPE)
|
||||
if(HIP_NEW_TYPE_ENUMS)
|
||||
list(APPEND HIP_CXX_FLAGS -DHIP_NEW_TYPE_ENUMS)
|
||||
endif()
|
||||
diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake
|
||||
index f6ca263c5e5..53eb0b63c1a 100644
|
||||
index fa39156031ff..df4836847fdf 100644
|
||||
--- a/cmake/public/LoadHIP.cmake
|
||||
+++ b/cmake/public/LoadHIP.cmake
|
||||
@@ -156,7 +156,7 @@ if(HIP_FOUND)
|
||||
@@ -155,7 +155,7 @@ if(HIP_FOUND)
|
||||
find_package_and_print_version(hiprand REQUIRED)
|
||||
find_package_and_print_version(rocblas REQUIRED)
|
||||
find_package_and_print_version(hipblas REQUIRED)
|
||||
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
|
||||
- find_package_and_print_version(hipblaslt REQUIRED)
|
||||
+ find_package_and_print_version(hipblaslt)
|
||||
endif()
|
||||
- find_package_and_print_version(hipblaslt REQUIRED)
|
||||
+ find_package_and_print_version(hipblaslt)
|
||||
find_package_and_print_version(miopen REQUIRED)
|
||||
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "4.1.0")
|
||||
@@ -191,7 +191,7 @@ if(HIP_FOUND)
|
||||
# roctx is part of roctracer
|
||||
find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)
|
||||
|
||||
- if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
|
||||
+ if(hipblastlt_FOUND)
|
||||
# check whether hipblaslt is using its own datatype
|
||||
set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_data_type.cc")
|
||||
file(WRITE ${file} ""
|
||||
find_package_and_print_version(hipfft REQUIRED)
|
||||
find_package_and_print_version(hipsparse REQUIRED)
|
||||
--
|
||||
2.43.2
|
||||
2.45.2
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue