PyTorch 2.4

Signed-off-by: Tom Rix <trix@redhat.com>
Tom Rix 2024-07-25 16:27:17 -06:00
commit 86185b46a2
7 changed files with 1468 additions and 273 deletions

@@ -1,174 +1,398 @@
From d77e05d90df006322cda021f1a8affdcc2c7eaef Mon Sep 17 00:00:00 2001
From f1d65e958afa65882dbfea8b392ab847a84d41ed Mon Sep 17 00:00:00 2001
From: Tom Rix <trix@redhat.com>
Date: Fri, 23 Feb 2024 08:27:30 -0500
Date: Sat, 29 Jun 2024 04:18:34 -0700
Subject: [PATCH] Optionally use hipblaslt
The hipblaslt package is not available on Fedora.
Instead of requiring the package, make it optional.
If it is found, define the preprocessor variable HIPBLASLT.
Convert the checks for ROCM_VERSION >= 50700 to HIPBLASLT checks.
Signed-off-by: Tom Rix <trix@redhat.com>
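The approach described above reduces to a feature-macro guard: CMake defines the macro only when the hipblaslt package is found, and the C++ sources test that macro instead of comparing ROCM_VERSION. A minimal, self-contained sketch of that pattern follows; it is illustrative only and not part of the patch, gemm_with_hipblaslt and gemm_with_rocblas are placeholder names, and the macro spelling follows the HIPBLASLT name used in this message (the diff below also uses a USE_HIPBLASLT spelling).

// Sketch of the optional-dependency guard; HIPBLASLT is assumed to be added
// to HIP_CXX_FLAGS by CMake only when find_package(hipblaslt) succeeds.
#include <cstdio>

static void gemm_with_hipblaslt() { std::puts("hipBLASLt path"); }   // placeholder
static void gemm_with_rocblas()   { std::puts("rocBLAS fallback"); } // placeholder

void gemm_dispatch() {
#if defined(USE_ROCM) && defined(HIPBLASLT)
  // Feature check replaces the old ROCM_VERSION >= 50700 version check.
  gemm_with_hipblaslt();
#else
  gemm_with_rocblas();
#endif
}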
---
aten/src/ATen/cuda/CUDABlas.cpp | 7 ++++---
aten/src/ATen/cuda/CUDABlas.h | 2 +-
aten/src/ATen/cuda/CUDAContextLight.h | 4 ++--
aten/src/ATen/cuda/CublasHandlePool.cpp | 4 ++--
aten/src/ATen/cuda/tunable/TunableGemm.h | 6 +++---
aten/src/ATen/native/cuda/Blas.cpp | 14 ++++++++------
cmake/Dependencies.cmake | 3 +++
cmake/public/LoadHIP.cmake | 4 ++--
8 files changed, 25 insertions(+), 19 deletions(-)
aten/src/ATen/cuda/CUDABlas.cpp | 46 ++++++++++++++++++------
aten/src/ATen/cuda/CUDAContextLight.h | 4 +++
aten/src/ATen/cuda/CublasHandlePool.cpp | 10 ++++--
aten/src/ATen/cuda/tunable/TunableGemm.h | 18 +++++++---
aten/src/ATen/native/cuda/Blas.cpp | 18 +++++++++-
cmake/Dependencies.cmake | 3 ++
cmake/public/LoadHIP.cmake | 2 +-
7 files changed, 82 insertions(+), 19 deletions(-)
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index d534ec5a178..e815463f630 100644
index ce991a9bcad4..3f0d17b52778 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -14,7 +14,7 @@
@@ -14,7 +14,9 @@
#include <c10/util/irange.h>
#ifdef USE_ROCM
-#if ROCM_VERSION >= 60000
+#ifdef HIPBLASLT
+#ifdef USE_HIPBLASLT
#include <hipblaslt/hipblaslt-ext.hpp>
#endif
+#endif
// until hipblas has an API to accept flags, we must use rocblas here
@@ -781,7 +781,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
}
}
#include <hipblas/hipblas.h>
#include <rocblas/rocblas.h>
@@ -182,6 +184,9 @@ uint32_t _getAlignment(uintptr_t address) {
static size_t _parseChosenWorkspaceSize() {
const char * val = getenv("CUBLASLT_WORKSPACE_SIZE");
#ifdef USE_ROCM
+#ifndef USE_HIPBLASLT
+ return 0;
+#endif
if (!val) {
// accept either env var
val = getenv("HIPBLASLT_WORKSPACE_SIZE");
@@ -235,6 +240,7 @@ namespace at::cuda::blas {
} while (0)
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
#if defined(USE_ROCM) && ROCM_VERSION >= 50700 && ROCM_VERSION < 60000
// only for rocm 5.7 where we first supported hipblaslt, it was difficult
@@ -912,6 +912,7 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
namespace {
// Following the pattern of CuSparseDescriptor
// Defined here for now because this is the only place cublas_lt interface is
@@ -318,7 +324,6 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
};
} // namespace
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
-
template <typename Dtype>
inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
cudaDataType_t abcType = CUDA_R_32F;
@@ -452,7 +457,7 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
" scaleType ",
scaleType);
}
-
+#endif
template <typename Dtype>
inline void bgemm_internal_cublas(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
@@ -608,10 +613,13 @@ void bgemm_internal<double>(CUDABLAS_BGEMM_ARGTYPES(double))
template <>
void bgemm_internal<float>(CUDABLAS_BGEMM_ARGTYPES(float))
{
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
bgemm_internal_cublaslt<float>(CUDABLAS_BGEMM_ARGS(float));
}
- else {
+ else
+#endif
+ {
bgemm_internal_cublas<float>(CUDABLAS_BGEMM_ARGS(float));
}
}
@@ -651,10 +659,13 @@ void bgemm_internal<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<fl
template <>
void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half))
{
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
bgemm_internal_cublaslt<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
}
- else {
+ else
+#endif
+ {
bgemm_internal_cublas<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
}
}
@@ -662,10 +673,13 @@ void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half))
template <>
void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
{
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
bgemm_internal_cublaslt<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
}
- else {
+ else
+#endif
+ {
bgemm_internal_cublas<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
}
}
@@ -781,11 +795,13 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
}
}
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
template <typename Dtype>
inline void gemm_internal_cublaslt(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
// forward to bgemm implementation but set strides and batches to 0
bgemm_internal_cublaslt(transa, transb, m, n, k, alpha, a, lda, 0, b, ldb, 0, beta, c, ldc, 0, 0);
}
+#endif
template <typename Dtype>
inline void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
@@ -1008,10 +1024,13 @@ void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
template <>
void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
{
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
}
- else {
+ else
+#endif
+ {
gemm_internal_cublas<float>(CUDABLAS_GEMM_ARGS(float));
}
}
@@ -1051,10 +1070,13 @@ void gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<floa
template <>
void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
{
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
- else {
+ else
+#endif
+ {
gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
}
@@ -1062,10 +1084,13 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
template <>
void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
{
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
- else {
+ else
+#endif
+ {
gemm_internal_cublas<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
}
@@ -1177,7 +1202,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
}
}
-
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
template <typename Dtype>
void gemm_and_bias(
bool transpose_mat1,
@@ -1124,7 +1125,7 @@ template void gemm_and_bias(
at::BFloat16* result_ptr,
int64_t result_ld,
GEMMAndBiasActivationEpilogue activation);
-
@@ -1410,7 +1435,7 @@ void scaled_gemm(
ScalarType result_dtype,
void* amax_ptr,
bool use_fast_accum) {
-#if CUDA_VERSION >= 11080 || defined(USE_ROCM)
+#if CUDA_VERSION >= 11080 || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
const auto computeType = CUBLAS_COMPUTE_32F;
const auto scaleType = CUDA_R_32F;
const int8_t fastAccuMode = use_fast_accum ? 1 : 0;
@@ -1681,6 +1706,7 @@ void int8_gemm(
" scaleType ",
scaleType);
}
+#endif
void scaled_gemm(
char transa,
char transb,
diff --git a/aten/src/ATen/cuda/CUDABlas.h b/aten/src/ATen/cuda/CUDABlas.h
index eb12bb350c5..068607467dd 100644
--- a/aten/src/ATen/cuda/CUDABlas.h
+++ b/aten/src/ATen/cuda/CUDABlas.h
@@ -82,7 +82,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half));
template <>
void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16));
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
enum GEMMAndBiasActivationEpilogue {
None,
RELU,
template <>
void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float)) {
diff --git a/aten/src/ATen/cuda/CUDAContextLight.h b/aten/src/ATen/cuda/CUDAContextLight.h
index 4ec35f59a21..e28dc42034f 100644
index f2b657ced51b..f0ee613c4208 100644
--- a/aten/src/ATen/cuda/CUDAContextLight.h
+++ b/aten/src/ATen/cuda/CUDAContextLight.h
@@ -9,7 +9,7 @@
@@ -9,7 +9,9 @@
// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
// added bf16 support
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
+#if (!defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)))
#include <cublasLt.h>
#endif
+#endif
@@ -82,7 +82,7 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
#ifdef CUDART_VERSION
#include <cusolverDn.h>
@@ -80,7 +82,9 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
/* Handles */
TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
+#if (!defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)))
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
#endif
+#endif
TORCH_CUDA_CPP_API void clearCublasWorkspaces();
diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp
index 6913d2cd95e..3d4276be372 100644
index 8eac525b3695..abfdf7a23847 100644
--- a/aten/src/ATen/cuda/CublasHandlePool.cpp
+++ b/aten/src/ATen/cuda/CublasHandlePool.cpp
@@ -29,7 +29,7 @@ namespace at::cuda {
namespace {
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
-#if defined(USE_ROCM)
+#if defined(USE_ROCM) && defined(USE_HIPBLASLT)
void createCublasLtHandle(cublasLtHandle_t *handle) {
TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
}
@@ -190,7 +190,7 @@ cublasHandle_t getCurrentCUDABlasHandle() {
@@ -191,8 +191,9 @@ cublasHandle_t getCurrentCUDABlasHandle() {
return handle;
}
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
cublasLtHandle_t getCurrentCUDABlasLtHandle() {
-cublasLtHandle_t getCurrentCUDABlasLtHandle() {
#ifdef USE_ROCM
+#if defined(USE_HIPBLASLT)
+cublasLtHandle_t getCurrentCUDABlasLtHandle() {
c10::DeviceIndex device = 0;
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
@@ -213,9 +214,12 @@ cublasLtHandle_t getCurrentCUDABlasLtHandle() {
auto handle = myPoolWindow->reserve(device);
return handle;
+}
+#endif
#else
+cublasLtHandle_t getCurrentCUDABlasLtHandle() {
return reinterpret_cast<cublasLtHandle_t>(getCurrentCUDABlasHandle());
-#endif
}
+#endif
} // namespace at::cuda
diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h
index 3ba0d761277..dde1870cfbf 100644
index 53e6154120c9..fa1d664696db 100644
--- a/aten/src/ATen/cuda/tunable/TunableGemm.h
+++ b/aten/src/ATen/cuda/tunable/TunableGemm.h
@@ -11,7 +11,7 @@
@@ -11,7 +11,9 @@
#include <ATen/cuda/tunable/GemmCommon.h>
#ifdef USE_ROCM
-#if ROCM_VERSION >= 50700
+#ifdef HIPBLASLT
+#ifdef USE_HIPBLASLT
#include <ATen/cuda/tunable/GemmHipblaslt.h>
#endif
+#endif
#include <ATen/cuda/tunable/GemmRocblas.h>
@@ -166,7 +166,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
}
#endif
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
if (env == nullptr || strcmp(env, "1") == 0) {
// disallow tuning of hipblaslt with c10::complex
@@ -240,7 +240,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
#include <ATen/cuda/tunable/StreamTimer.h>
@@ -65,6 +67,7 @@ class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>>
}
#endif
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
static const char *env = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
if (env == nullptr || strcmp(env, "1") == 0) {
// disallow tuning of hipblaslt with c10::complex
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 29e5c5e3cf1..df56f3d7f1d 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -155,7 +155,7 @@ enum class Activation {
GELU,
};
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activation a) {
switch (a) {
case Activation::None:
@@ -193,6 +193,7 @@ static bool getDisableAddmmCudaLt() {
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
template <typename T>
class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
public:
@@ -94,6 +97,7 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
return OK;
}
};
+#endif
template <typename T>
inline bool IsZero(T v) {
@@ -191,6 +195,7 @@ static void AddRocblasValidator() {
}
}
+#ifdef USE_HIPBLASLT
static void AddHipblasltValidator() {
auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
if (validators.find("HIPBLASLT_VERSION") == validators.end()) {
@@ -205,6 +210,7 @@ static void AddHipblasltValidator() {
[hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; });
}
}
+#endif
static void AddRocmValidator() {
auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
@@ -243,7 +249,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
}
AddRocblasValidator();
}
-
+#ifdef USE_HIPBLASLT
static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
rocm_validators = true;
@@ -257,7 +263,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
}
AddHipblasltValidator();
}
-
+#endif
if (rocm_validators) {
AddRocmValidator();
}
@@ -286,7 +292,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
}
AddRocblasValidator();
}
-
+#ifdef USE_HIPBLASLT
static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
rocm_validators = true;
@@ -300,7 +306,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
}
AddHipblasltValidator();
}
-
+#endif
if (rocm_validators) {
AddRocmValidator();
}
@@ -312,6 +318,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
}
};
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer> {
public:
@@ -321,10 +328,12 @@ class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer>
auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
#if defined(USE_ROCM)
+#ifdef USE_HIPBLASLT
for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps<AT, BT, CT, ALayout, BLayout>()) {
this->RegisterOp(std::move(name), std::move(op));
}
AddHipblasltValidator();
+#endif
AddRocmValidator();
#endif
}
@@ -337,6 +346,7 @@ class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer>
"_", BlasOpToString(ALayout), BlasOpToString(BLayout));
}
};
+#endif
#undef XSTRINGIFY
#undef STRINGIFY
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 84c59a4fd0d7..56ad5de3bf2d 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -173,6 +173,7 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
}
static bool getDisableAddmmCudaLt() {
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT");
#ifdef USE_ROCM
// if we enable tunable op, it'll take priority over just hipblaslt (heuristics)
@@ -196,10 +197,14 @@ static bool getDisableAddmmCudaLt() {
}
return false;
#endif
+#else
+ return true;
+#endif
}
#ifdef USE_ROCM
static bool isSupportedHipLtROCmArch(int index) {
+#if defined(HIPBLASLT)
+#ifdef USE_HIPBLASLT
hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
std::string device_arch = prop->gcnArchName;
static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
@@ -203,6 +204,7 @@ static bool isSupportedHipLtROCmArch(int index) {
@@ -210,6 +215,7 @@ static bool isSupportedHipLtROCmArch(int index) {
}
}
TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!");
@@ -176,87 +400,107 @@ index 29e5c5e3cf1..df56f3d7f1d 100644
return false;
}
#endif
@@ -228,7 +230,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
@@ -235,6 +241,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
at::ScalarType scalar_type = self.scalar_type();
c10::MaybeOwned<Tensor> self_;
if (&result != &self) {
-#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if (defined(CUDA_VERSION) && CUDA_VERSION >= 11040 && !defined(_MSC_VER)) || defined(USE_ROCM) && defined(HIPBLASLT)
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11040)) || defined(USE_ROCM)
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
// self.dim() == 1 && result.dim() == 2 && self.sizes()[0] == mat2_sizes[1]
@@ -271,7 +273,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
@@ -276,13 +283,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
scalar_type != at::ScalarType::BFloat16));
#endif
}
+#endif
#endif
if (!useLtInterface) {
self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
}
self__sizes = self_->sizes();
} else {
-#if defined(USE_ROCM) && ROCM_VERSION >= 50700
+#if defined(USE_ROCM) && defined(HIPBLASLT)
-#if defined(USE_ROCM)
+#if defined(USE_ROCM) && defined(USE_HIPBLASLT)
useLtInterface = !disable_addmm_cuda_lt &&
result.dim() == 2 && result.is_contiguous() &&
isSupportedHipLtROCmArch(self.device().index()) &&
@@ -322,7 +324,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
@@ -334,6 +342,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
-#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && ROCM_VERSION >= 50700)
+#if (!defined(USE_ROCM) && !defined(_MSC_VER)) || (defined(USE_ROCM) && defined(HIPBLASLT))
if (useLtInterface) {
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
#if defined(USE_ROCM)
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
@@ -876,7 +878,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
at::native::resize_output(out, {mat1_sizes[0], mat2_sizes[1]});
at::native::resize_output(amax, {});
@@ -394,6 +403,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
activation_epilogue
);
});
+#endif
#endif
} else
{
@@ -803,6 +813,7 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) {
}
-#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && ROCM_VERSION >= 60000)
+#if !defined(USE_ROCM) && !defined(_MSC_VER) || (defined(USE_ROCM) && defined(HIPBLASLT))
cublasCommonArgs args(mat1, mat2, out);
const auto out_dtype_ = args.result->scalar_type();
TORCH_CHECK(args.transa == 't' && args.transb == 'n', "Only multiplication of row-major and column-major matrices is supported by cuBLASLt");
@@ -906,7 +908,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
TORCH_CHECK(false, "_scaled_mm_out_cuda is not compiled for this platform.");
static bool _scaled_mm_allowed_device() {
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
auto dprops = at::cuda::getCurrentDeviceProperties();
#ifdef USE_ROCM
std::string device_arch = dprops->gcnArchName;
@@ -817,6 +828,9 @@ static bool _scaled_mm_allowed_device() {
#else
return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
#endif
+#else
+ return false;
+#endif
}
// Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax
@@ -850,6 +864,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
// Check sizes
bool allowed_device = _scaled_mm_allowed_device();
TORCH_CHECK(allowed_device, "torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+");
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
TORCH_CHECK(
@@ -1025,6 +1040,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && ROCM_VERSION < 60200
// ROCm's hipBLASLt does not support amax before 6.2, so calculate separately
amax = at::max(at::abs(out.to(kFloat)));
+#endif
#endif
-#if defined(USE_ROCM) && ROCM_VERSION >= 60000
+#if defined(USE_ROCM) && defined(HIPBLASLT)
// rocm's hipblaslt does not yet support amax, so calculate separately
auto out_float32 = out.to(kFloat);
out_float32.abs_();
return {out, amax};
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index b7ffbeb07dc..2b6c3678984 100644
index f1f2eb7cec31..8d05e834bbc5 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1273,6 +1273,9 @@ if(USE_ROCM)
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "6.0.0")
list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
endif()
@@ -1052,6 +1052,9 @@ if(USE_ROCM)
list(APPEND HIP_CXX_FLAGS -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP)
list(APPEND HIP_CXX_FLAGS -std=c++17)
list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
+ if(hipblast_FOUND)
+ list(APPEND HIP_CXX_FLAGS -DHIPBLASLT)
+ list(APPEND HIP_CXX_FLAGS -DUSE_HIPBLASLT)
+ endif()
if(HIPBLASLT_CUSTOM_DATA_TYPE)
list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_CUSTOM_DATA_TYPE)
if(HIP_NEW_TYPE_ENUMS)
list(APPEND HIP_CXX_FLAGS -DHIP_NEW_TYPE_ENUMS)
endif()
diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake
index f6ca263c5e5..53eb0b63c1a 100644
index fa39156031ff..df4836847fdf 100644
--- a/cmake/public/LoadHIP.cmake
+++ b/cmake/public/LoadHIP.cmake
@@ -156,7 +156,7 @@ if(HIP_FOUND)
@@ -155,7 +155,7 @@ if(HIP_FOUND)
find_package_and_print_version(hiprand REQUIRED)
find_package_and_print_version(rocblas REQUIRED)
find_package_and_print_version(hipblas REQUIRED)
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
- find_package_and_print_version(hipblaslt REQUIRED)
+ find_package_and_print_version(hipblaslt)
endif()
- find_package_and_print_version(hipblaslt REQUIRED)
+ find_package_and_print_version(hipblaslt)
find_package_and_print_version(miopen REQUIRED)
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "4.1.0")
@@ -191,7 +191,7 @@ if(HIP_FOUND)
# roctx is part of roctracer
find_library(ROCM_ROCTX_LIB roctx64 HINTS ${ROCM_PATH}/lib)
- if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "5.7.0")
+ if(hipblastlt_FOUND)
# check whether hipblaslt is using its own datatype
set(file "${PROJECT_BINARY_DIR}/hipblaslt_test_data_type.cc")
file(WRITE ${file} ""
find_package_and_print_version(hipfft REQUIRED)
find_package_and_print_version(hipsparse REQUIRED)
--
2.43.2
2.45.2