Update the next gitcommit to v2.8.0-rc6

Remove old patches.

Signed-off-by: Tom Rix <Tom.Rix@amd.com>
This commit is contained in:
Tom Rix 2025-07-20 12:44:41 -07:00
commit 42c33b8dcd
27 changed files with 154 additions and 4539 deletions

View file

@ -1,47 +0,0 @@
From 091b7fe1ccbb5e4ff4ac6017d42bacb869f61a27 Mon Sep 17 00:00:00 2001
From: Tom Rix <trix@redhat.com>
Date: Sat, 20 Jul 2024 05:37:15 -0600
Subject: [PATCH] Add cmake option USE_SYSTEM_FBGEMM
Signed-off-by: Tom Rix <trix@redhat.com>
---
CMakeLists.txt | 1 +
cmake/Dependencies.cmake | 3 ++-
2 files changed, 3 insertions(+), 1 deletion(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index c4cd4b2c2a98..2068f7c6c4f2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -253,6 +253,7 @@ cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
"USE_CUDNN" OFF)
cmake_dependent_option(USE_CUSPARSELT "Use cuSPARSELt" ON "USE_CUDA" OFF)
option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
+option(USE_SYSTEM_FBGEMM "Use system-wide FBGEMM" OFF)
option(USE_KINETO "Use Kineto profiling library" ON)
option(USE_CUPTI_SO "Use CUPTI as a shared library" ON)
option(USE_FAKELOWP "Use FakeLowp operators" OFF)
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index f1f2eb7cec31..192dac46f13b 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -706,6 +706,7 @@ endif()
# ---[ FBGEMM
if(USE_FBGEMM)
+ if (NOT USE_SYSTEM_FBGEMM)
set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
if(NOT DEFINED FBGEMM_SOURCE_DIR)
set(FBGEMM_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/fbgemm" CACHE STRING "FBGEMM source directory")
@@ -746,7 +747,7 @@ if(USE_FBGEMM)
target_compile_options_if_supported(asmjit -Wno-unused-but-set-variable)
endif()
endif()
-
+ endif()
if(USE_FBGEMM)
list(APPEND Caffe2_DEPENDENCY_LIBS fbgemm)
endif()
--
2.45.1

View file

@ -0,0 +1,149 @@
From 4cc5d88dfe7a45ab245648dc874645d32a24b98b Mon Sep 17 00:00:00 2001
From: Tom Rix <Tom.Rix@amd.com>
Date: Fri, 27 Jun 2025 13:52:51 -0700
Subject: [PATCH] Add cmake variable USE_ROCM_CK
---
CMakeLists.txt | 1 +
aten/src/ATen/CMakeLists.txt | 40 ++++++++++++++++-----------------
aten/src/ATen/cuda/CUDABlas.cpp | 10 ++++-----
cmake/Dependencies.cmake | 3 +++
4 files changed, 29 insertions(+), 25 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 99c0b9e0ea0c..4c632e42f531 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -240,6 +240,7 @@ cmake_dependent_option(
BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON
"USE_CUDA AND LINUX AND BUILD_PYTHON" OFF)
cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX" OFF)
+cmake_dependent_option(USE_ROCM_CK "Use ROCm Composable Kernel" ON "USE_ROCM" ON)
option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF)
cmake_dependent_option(USE_CUDNN "Use cuDNN" ON "USE_CUDA" OFF)
cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index c9cfd74b501e..59f6178218ee 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -373,26 +373,26 @@ if(USE_ROCM)
# is header only, so this should be ok, except that the CMake build generates
# a ck/config.h. We just do that part here. Without this, the ck.h from the
# ROCM SDK may get accidentally used instead.
- function(_pytorch_rocm_generate_ck_conf)
- set(CK_ENABLE_INT8 "ON")
- set(CK_ENABLE_FP16 "ON")
- set(CK_ENABLE_FP32 "ON")
- set(CK_ENABLE_FP64 "ON")
- set(CK_ENABLE_BF16 "ON")
- set(CK_ENABLE_FP8 "ON")
- set(CK_ENABLE_BF8 "ON")
- set(CK_USE_XDL "ON")
- set(CK_USE_WMMA "ON")
- configure_file(
- "${Torch_SOURCE_DIR}/third_party/composable_kernel/include/ck/config.h.in"
- "${CMAKE_CURRENT_BINARY_DIR}/composable_kernel/ck/config.h"
- )
- endfunction()
+# function(_pytorch_rocm_generate_ck_conf)
+# set(CK_ENABLE_INT8 "ON")
+# set(CK_ENABLE_FP16 "ON")
+# set(CK_ENABLE_FP32 "ON")
+# set(CK_ENABLE_FP64 "ON")
+# set(CK_ENABLE_BF16 "ON")
+# set(CK_ENABLE_FP8 "ON")
+# set(CK_ENABLE_BF8 "ON")
+# set(CK_USE_XDL "ON")
+# set(CK_USE_WMMA "ON")
+# configure_file(
+# "${Torch_SOURCE_DIR}/third_party/composable_kernel/include/ck/config.h.in"
+# "${CMAKE_CURRENT_BINARY_DIR}/composable_kernel/ck/config.h"
+# )
+# endfunction()
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
- list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
- list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
- list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
- _pytorch_rocm_generate_ck_conf()
+# list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
+# list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
+# list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
+# _pytorch_rocm_generate_ck_conf()
# Next two lines are needed because TunableOp uses third-party/fmt
list(APPEND ATen_HIP_INCLUDE $<TARGET_PROPERTY:fmt::fmt-header-only,INTERFACE_INCLUDE_DIRECTORIES>)
@@ -409,7 +409,7 @@ endif()
${native_quantized_hip_hip}
${native_transformers_hip_hip} ${native_transformers_src_hip_hip}
)
- if(WIN32) # Windows doesn't support Composable Kernels
+ if(NOT USE_ROCM_CK) # Windows doesn't support Composable Kernels
file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip")
file(GLOB native_hip_ck "native/hip/ck*.hip")
exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index 89350a11bea7..33e5f2808057 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -752,7 +752,7 @@ template <>
void bgemm_internal<double>(CUDABLAS_BGEMM_ARGTYPES(double))
{
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
-#ifdef USE_ROCM
+#ifdef USE_ROCM_CK
// hipblaslt does not support double gemm yet
bgemm_internal_cublas<double>(CUDABLAS_BGEMM_ARGS(double));
#else
@@ -1103,7 +1103,7 @@ inline void gemm_internal_cublas_half_helper(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(
void * beta_ptr = &fbeta;
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(at::Half);
-#ifdef USE_ROCM
+#ifdef USE_ROCM_CK
int flag = 0;
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
flag = at::ROCmBackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
@@ -1270,7 +1270,7 @@ template <>
void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
{
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
-#ifdef USE_ROCM
+#ifdef USE_ROCM_CK
// hipblaslt does not support double gemm yet
gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGS(double));
#else
@@ -1311,7 +1311,7 @@ template <>
void gemm_internal<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>))
{
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
-#ifdef USE_ROCM
+#ifdef USE_ROCM_CK
// hipblaslt does not support complex gemm yet
gemm_internal_cublas<c10::complex<double>>(CUDABLAS_GEMM_ARGS(c10::complex<double>));
#else
@@ -1327,7 +1327,7 @@ template <>
void gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>))
{
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
-#ifdef USE_ROCM
+#ifdef USE_ROCM_CK
// hipblaslt does not support complex gemm yet
gemm_internal_cublas<c10::complex<float>>(CUDABLAS_GEMM_ARGS(c10::complex<float>));
#else
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index a93386c27f8d..be1368999d38 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1031,6 +1031,9 @@ if(USE_ROCM)
if(HIPBLASLT_VEC_EXT)
list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_VEC_EXT)
endif()
+ if(USE_ROCM_CK)
+ list(APPEND HIP_CXX_FLAGS -DUSE_ROCM_CK)
+ endif()
list(APPEND HIP_HIPCC_FLAGS --offload-compress)
if(WIN32)
add_definitions(-DROCM_ON_WINDOWS)
--
2.49.0

View file

@ -1,506 +0,0 @@
From f1d65e958afa65882dbfea8b392ab847a84d41ed Mon Sep 17 00:00:00 2001
From: Tom Rix <trix@redhat.com>
Date: Sat, 29 Jun 2024 04:18:34 -0700
Subject: [PATCH] Optionally use hipblaslt
---
aten/src/ATen/cuda/CUDABlas.cpp | 46 ++++++++++++++++++------
aten/src/ATen/cuda/CUDAContextLight.h | 4 +++
aten/src/ATen/cuda/CublasHandlePool.cpp | 10 ++++--
aten/src/ATen/cuda/tunable/TunableGemm.h | 18 +++++++---
aten/src/ATen/native/cuda/Blas.cpp | 18 +++++++++-
cmake/Dependencies.cmake | 3 ++
cmake/public/LoadHIP.cmake | 2 +-
7 files changed, 82 insertions(+), 19 deletions(-)
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index ce991a9bcad4..3f0d17b52778 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -14,7 +14,9 @@
#include <c10/util/irange.h>
#ifdef USE_ROCM
+#ifdef USE_HIPBLASLT
#include <hipblaslt/hipblaslt-ext.hpp>
+#endif
// until hipblas has an API to accept flags, we must use rocblas here
#include <hipblas/hipblas.h>
#include <rocblas/rocblas.h>
@@ -182,6 +184,9 @@ uint32_t _getAlignment(uintptr_t address) {
static size_t _parseChosenWorkspaceSize() {
const char * val = getenv("CUBLASLT_WORKSPACE_SIZE");
#ifdef USE_ROCM
+#ifndef USE_HIPBLASLT
+ return 0;
+#endif
if (!val) {
// accept either env var
val = getenv("HIPBLASLT_WORKSPACE_SIZE");
@@ -235,6 +240,7 @@ namespace at::cuda::blas {
} while (0)
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
namespace {
// Following the pattern of CuSparseDescriptor
// Defined here for now because this is the only place cublas_lt interface is
@@ -318,7 +324,6 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
};
} // namespace
-
template <typename Dtype>
inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
cudaDataType_t abcType = CUDA_R_32F;
@@ -452,7 +457,7 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
" scaleType ",
scaleType);
}
-
+#endif
template <typename Dtype>
inline void bgemm_internal_cublas(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
@@ -608,10 +613,13 @@ void bgemm_internal<double>(CUDABLAS_BGEMM_ARGTYPES(double))
template <>
void bgemm_internal<float>(CUDABLAS_BGEMM_ARGTYPES(float))
{
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
bgemm_internal_cublaslt<float>(CUDABLAS_BGEMM_ARGS(float));
}
- else {
+ else
+#endif
+ {
bgemm_internal_cublas<float>(CUDABLAS_BGEMM_ARGS(float));
}
}
@@ -651,10 +659,13 @@ void bgemm_internal<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<fl
template <>
void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half))
{
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
bgemm_internal_cublaslt<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
}
- else {
+ else
+#endif
+ {
bgemm_internal_cublas<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
}
}
@@ -662,10 +673,13 @@ void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half))
template <>
void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
{
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
bgemm_internal_cublaslt<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
}
- else {
+ else
+#endif
+ {
bgemm_internal_cublas<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
}
}
@@ -781,11 +795,13 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
}
}
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
template <typename Dtype>
inline void gemm_internal_cublaslt(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
// forward to bgemm implementation but set strides and batches to 0
bgemm_internal_cublaslt(transa, transb, m, n, k, alpha, a, lda, 0, b, ldb, 0, beta, c, ldc, 0, 0);
}
+#endif
template <typename Dtype>
inline void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
@@ -1008,10 +1024,13 @@ void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
template <>
void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
{
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
}
- else {
+ else
+#endif
+ {
gemm_internal_cublas<float>(CUDABLAS_GEMM_ARGS(float));
}
}
@@ -1051,10 +1070,13 @@ void gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<floa
template <>
void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
{
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
- else {
+ else
+#endif
+ {
gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
}
@@ -1062,10 +1084,13 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
template <>
void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
{
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
- else {
+ else
+#endif
+ {
gemm_internal_cublas<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
}
@@ -1177,7 +1202,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
}
}
-
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
template <typename Dtype>
void gemm_and_bias(
bool transpose_mat1,
@@ -1410,7 +1435,7 @@ void scaled_gemm(
ScalarType result_dtype,
void* amax_ptr,
bool use_fast_accum) {
-#if CUDA_VERSION >= 11080 || defined(USE_ROCM)
+#if CUDA_VERSION >= 11080 || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
const auto computeType = CUBLAS_COMPUTE_32F;
const auto scaleType = CUDA_R_32F;
const int8_t fastAccuMode = use_fast_accum ? 1 : 0;
@@ -1681,6 +1706,7 @@ void int8_gemm(
" scaleType ",
scaleType);
}
+#endif
template <>
void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float)) {
diff --git a/aten/src/ATen/cuda/CUDAContextLight.h b/aten/src/ATen/cuda/CUDAContextLight.h
index f2b657ced51b..f0ee613c4208 100644
--- a/aten/src/ATen/cuda/CUDAContextLight.h
+++ b/aten/src/ATen/cuda/CUDAContextLight.h
@@ -9,7 +9,9 @@
// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
// added bf16 support
+#if (!defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)))
#include <cublasLt.h>
+#endif
#ifdef CUDART_VERSION
#include <cusolverDn.h>
@@ -80,7 +82,9 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
/* Handles */
TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
+#if (!defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)))
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
+#endif
TORCH_CUDA_CPP_API void clearCublasWorkspaces();
diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp
index 8eac525b3695..abfdf7a23847 100644
--- a/aten/src/ATen/cuda/CublasHandlePool.cpp
+++ b/aten/src/ATen/cuda/CublasHandlePool.cpp
@@ -29,7 +29,7 @@ namespace at::cuda {
namespace {
-#if defined(USE_ROCM)
+#if defined(USE_ROCM) && defined(USE_HIPBLASLT)
void createCublasLtHandle(cublasLtHandle_t *handle) {
TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
}
@@ -191,8 +191,9 @@ cublasHandle_t getCurrentCUDABlasHandle() {
return handle;
}
-cublasLtHandle_t getCurrentCUDABlasLtHandle() {
#ifdef USE_ROCM
+#if defined(USE_HIPBLASLT)
+cublasLtHandle_t getCurrentCUDABlasLtHandle() {
c10::DeviceIndex device = 0;
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
@@ -213,9 +214,12 @@ cublasLtHandle_t getCurrentCUDABlasLtHandle() {
auto handle = myPoolWindow->reserve(device);
return handle;
+}
+#endif
#else
+cublasLtHandle_t getCurrentCUDABlasLtHandle() {
return reinterpret_cast<cublasLtHandle_t>(getCurrentCUDABlasHandle());
-#endif
}
+#endif
} // namespace at::cuda
diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h
index 53e6154120c9..fa1d664696db 100644
--- a/aten/src/ATen/cuda/tunable/TunableGemm.h
+++ b/aten/src/ATen/cuda/tunable/TunableGemm.h
@@ -11,7 +11,9 @@
#include <ATen/cuda/tunable/GemmCommon.h>
#ifdef USE_ROCM
+#ifdef USE_HIPBLASLT
#include <ATen/cuda/tunable/GemmHipblaslt.h>
+#endif
#include <ATen/cuda/tunable/GemmRocblas.h>
#endif
#include <ATen/cuda/tunable/StreamTimer.h>
@@ -65,6 +67,7 @@ class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>>
}
};
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
template <typename T>
class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
public:
@@ -94,6 +97,7 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
return OK;
}
};
+#endif
template <typename T>
inline bool IsZero(T v) {
@@ -191,6 +195,7 @@ static void AddRocblasValidator() {
}
}
+#ifdef USE_HIPBLASLT
static void AddHipblasltValidator() {
auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
if (validators.find("HIPBLASLT_VERSION") == validators.end()) {
@@ -205,6 +210,7 @@ static void AddHipblasltValidator() {
[hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; });
}
}
+#endif
static void AddRocmValidator() {
auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
@@ -243,7 +249,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
}
AddRocblasValidator();
}
-
+#ifdef USE_HIPBLASLT
static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
rocm_validators = true;
@@ -257,7 +263,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
}
AddHipblasltValidator();
}
-
+#endif
if (rocm_validators) {
AddRocmValidator();
}
@@ -286,7 +292,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
}
AddRocblasValidator();
}
-
+#ifdef USE_HIPBLASLT
static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
rocm_validators = true;
@@ -300,7 +306,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
}
AddHipblasltValidator();
}
-
+#endif
if (rocm_validators) {
AddRocmValidator();
}
@@ -312,6 +318,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
}
};
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer> {
public:
@@ -321,10 +328,12 @@ class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer>
auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
#if defined(USE_ROCM)
+#ifdef USE_HIPBLASLT
for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps<AT, BT, CT, ALayout, BLayout>()) {
this->RegisterOp(std::move(name), std::move(op));
}
AddHipblasltValidator();
+#endif
AddRocmValidator();
#endif
}
@@ -337,6 +346,7 @@ class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer>
"_", BlasOpToString(ALayout), BlasOpToString(BLayout));
}
};
+#endif
#undef XSTRINGIFY
#undef STRINGIFY
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 84c59a4fd0d7..56ad5de3bf2d 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -173,6 +173,7 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
}
static bool getDisableAddmmCudaLt() {
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT");
#ifdef USE_ROCM
// if we enable tunable op, it'll take priority over just hipblaslt (heuristics)
@@ -196,10 +197,14 @@ static bool getDisableAddmmCudaLt() {
}
return false;
#endif
+#else
+ return true;
+#endif
}
#ifdef USE_ROCM
static bool isSupportedHipLtROCmArch(int index) {
+#ifdef USE_HIPBLASLT
hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
std::string device_arch = prop->gcnArchName;
static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
@@ -210,6 +215,7 @@ static bool isSupportedHipLtROCmArch(int index) {
}
}
TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!");
+#endif
return false;
}
#endif
@@ -235,6 +241,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
at::ScalarType scalar_type = self.scalar_type();
c10::MaybeOwned<Tensor> self_;
if (&result != &self) {
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11040)) || defined(USE_ROCM)
// Strangely, if mat2 has only 1 row or column, we get
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
@@ -276,13 +283,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
scalar_type != at::ScalarType::BFloat16));
#endif
}
+#endif
#endif
if (!useLtInterface) {
self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
}
self__sizes = self_->sizes();
} else {
-#if defined(USE_ROCM)
+#if defined(USE_ROCM) && defined(USE_HIPBLASLT)
useLtInterface = !disable_addmm_cuda_lt &&
result.dim() == 2 && result.is_contiguous() &&
isSupportedHipLtROCmArch(self.device().index()) &&
@@ -334,6 +342,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
if (useLtInterface) {
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
#if defined(USE_ROCM)
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
@@ -394,6 +403,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
activation_epilogue
);
});
+#endif
#endif
} else
{
@@ -803,6 +813,7 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) {
}
static bool _scaled_mm_allowed_device() {
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
auto dprops = at::cuda::getCurrentDeviceProperties();
#ifdef USE_ROCM
std::string device_arch = dprops->gcnArchName;
@@ -817,6 +828,9 @@ static bool _scaled_mm_allowed_device() {
#else
return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
#endif
+#else
+ return false;
+#endif
}
// Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax
@@ -850,6 +864,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
// Check sizes
bool allowed_device = _scaled_mm_allowed_device();
TORCH_CHECK(allowed_device, "torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+");
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
TORCH_CHECK(
@@ -1025,6 +1040,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && ROCM_VERSION < 60200
// ROCm's hipBLASLt does not support amax before 6.2, so calculate separately
amax = at::max(at::abs(out.to(kFloat)));
+#endif
#endif
return {out, amax};
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index f1f2eb7cec31..8d05e834bbc5 100644
--- a/cmake/Dependencies.cmake
+++ b/cmake/Dependencies.cmake
@@ -1052,6 +1052,9 @@ if(USE_ROCM)
list(APPEND HIP_CXX_FLAGS -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP)
list(APPEND HIP_CXX_FLAGS -std=c++17)
list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
+ if(hipblast_FOUND)
+ list(APPEND HIP_CXX_FLAGS -DUSE_HIPBLASLT)
+ endif()
if(HIP_NEW_TYPE_ENUMS)
list(APPEND HIP_CXX_FLAGS -DHIP_NEW_TYPE_ENUMS)
endif()
diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake
index fa39156031ff..df4836847fdf 100644
--- a/cmake/public/LoadHIP.cmake
+++ b/cmake/public/LoadHIP.cmake
@@ -155,7 +155,7 @@ if(HIP_FOUND)
find_package_and_print_version(hiprand REQUIRED)
find_package_and_print_version(rocblas REQUIRED)
find_package_and_print_version(hipblas REQUIRED)
- find_package_and_print_version(hipblaslt REQUIRED)
+ find_package_and_print_version(hipblaslt)
find_package_and_print_version(miopen REQUIRED)
find_package_and_print_version(hipfft REQUIRED)
find_package_and_print_version(hipsparse REQUIRED)
--
2.45.2

View file

@ -1,94 +0,0 @@
From 038ce9e44776e23f21c1816daa259bc0ea335088 Mon Sep 17 00:00:00 2001
From: Tom Rix <trix@redhat.com>
Date: Sat, 29 Jun 2024 07:06:09 -0700
Subject: [PATCH] disable use of aotriton
---
.../ATen/native/transformers/cuda/sdp_utils.cpp | 17 +++++++++++++++--
1 file changed, 15 insertions(+), 2 deletions(-)
diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
index 214b02d8262e..7b3eb9dcd8cd 100644
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
@@ -19,9 +19,12 @@
#include <c10/core/SymInt.h>
#include <c10/util/string_view.h>
+#ifdef USE_FLASH_ATTENTION
#if USE_ROCM
#include <aotriton/flash.h>
#endif
+#endif
+
/**
* Note [SDPA Runtime Dispatch]
@@ -182,6 +185,9 @@ bool check_sm_version(cudaDeviceProp * dprops) {
bool check_flash_attention_hardware_support(sdp_params const& params, bool debug) {
// Check that the gpu is capable of running flash attention
+#ifndef USE_FLASH_ATTENTION
+ return false;
+#else
using sm80 = SMVersion<8, 0>;
using sm90 = SMVersion<9, 0>;
#if USE_ROCM
@@ -209,9 +215,13 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
}
#endif
return true;
+#endif
}
bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) {
+#ifndef USE_FLASH_ATTENTION
+ return false;
+#else
// Mem Efficient attention supports hardware in the range [sm_50, sm_90]
using sm50 = SMVersion<5, 0>;
using sm90 = SMVersion<9, 0>;
@@ -240,6 +250,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
}
#endif
return true;
+#endif
}
bool check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89(
@@ -554,7 +565,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
#ifndef USE_FLASH_ATTENTION
TORCH_WARN_ONCE(!debug, "Torch was not compiled with flash attention.");
return false;
-#endif
+#else
// Define gate functions that determine if a flash kernel can be ran
// Replace with std::to_array when we migrate to c++20
@@ -597,13 +608,14 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
}
}
return true;
+#endif
}
bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
#ifndef USE_MEM_EFF_ATTENTION
TORCH_WARN_ONCE(!debug, "Torch was not compiled with memory efficient attention.");
return false;
-#endif
+#else
// Constraints specific to mem efficient attention
constexpr auto greater_than_or_equal_sm80_mem_efficient_dtypes =
array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
@@ -663,6 +675,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
}
#endif
return check_tensor_dtype(params, less_than_sm80_mem_efficient_dtypes, debug);
+#endif
}
SDPBackend select_sdp_backend(sdp_params const& kernel_params) {
--
2.45.2