Update gitcommit to 2.8.0-rc8

Patch problem with 3.14
Start converting over py3 macros
Handle new dependency on rocmsmi

Signed-off-by: Tom Rix <Tom.Rix@amd.com>
This commit is contained in:
Tom Rix 2025-07-24 06:07:03 -07:00
commit 61ccf033a8
4 changed files with 560 additions and 29 deletions

View file

@ -1,17 +1,17 @@
From 4cc5d88dfe7a45ab245648dc874645d32a24b98b Mon Sep 17 00:00:00 2001
From 193854993cd939de186de19589c1add4c4b2cf66 Mon Sep 17 00:00:00 2001
From: Tom Rix <Tom.Rix@amd.com>
Date: Fri, 27 Jun 2025 13:52:51 -0700
Date: Mon, 21 Jul 2025 11:35:03 -0700
Subject: [PATCH] Add cmake variable USE_ROCM_CK
---
CMakeLists.txt | 1 +
aten/src/ATen/CMakeLists.txt | 40 ++++++++++++++++-----------------
aten/src/ATen/cuda/CUDABlas.cpp | 10 ++++-----
aten/src/ATen/cuda/CUDABlas.cpp | 22 +++++++++---------
cmake/Dependencies.cmake | 3 +++
4 files changed, 29 insertions(+), 25 deletions(-)
4 files changed, 35 insertions(+), 31 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 99c0b9e0ea0c..4c632e42f531 100644
index a5d25e6afa0f..afc1b53efa64 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -240,6 +240,7 @@ cmake_dependent_option(
@ -82,7 +82,7 @@ index c9cfd74b501e..59f6178218ee 100644
file(GLOB native_hip_ck "native/hip/ck*.hip")
exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index 89350a11bea7..33e5f2808057 100644
index 89350a11bea7..e5b7960177cf 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -752,7 +752,7 @@ template <>
@ -94,16 +94,16 @@ index 89350a11bea7..33e5f2808057 100644
// hipblaslt does not support double gemm yet
bgemm_internal_cublas<double>(CUDABLAS_BGEMM_ARGS(double));
#else
@@ -1103,7 +1103,7 @@ inline void gemm_internal_cublas_half_helper(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(
void * beta_ptr = &fbeta;
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
GEMM_CHECK_ARGVALUES(at::Half);
-#ifdef USE_ROCM
+#ifdef USE_ROCM_CK
int flag = 0;
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
flag = at::ROCmBackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
@@ -1270,7 +1270,7 @@ template <>
@@ -836,7 +836,7 @@ void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
bgemm_internal_cublas<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
}
}
-#if defined(USE_ROCM) && !defined(_MSC_VER)
+#if defined(USE_ROCM) && defined(USE_ROCM_CK)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::bgemm_internal_ck<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
}
@@ -1270,14 +1270,14 @@ template <>
void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
{
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
@ -112,6 +112,23 @@ index 89350a11bea7..33e5f2808057 100644
// hipblaslt does not support double gemm yet
gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGS(double));
#else
gemm_internal_cublaslt<double>(CUDABLAS_GEMM_ARGS(double));
#endif
}
-#if defined(USE_ROCM) && !defined(_MSC_VER)
+#if defined(USE_ROCM) && defined(USE_ROCM_CK)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<double>(CUDABLAS_GEMM_ARGS(double));
}
@@ -1293,7 +1293,7 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
}
-#if defined(USE_ROCM) && !defined(_MSC_VER)
+#if defined(USE_ROCM) && defined(USE_ROCM_CK)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
if (at::detail::getCUDAHooks().isGPUArch({"gfx1100"})) { //no CK GEMM version for gfx1100
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
@@ -1311,7 +1311,7 @@ template <>
void gemm_internal<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>))
{
@ -130,6 +147,42 @@ index 89350a11bea7..33e5f2808057 100644
// hipblaslt does not support complex gemm yet
gemm_internal_cublas<c10::complex<float>>(CUDABLAS_GEMM_ARGS(c10::complex<float>));
#else
@@ -1345,7 +1345,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
-#if defined(USE_ROCM) && !defined(_MSC_VER)
+#if defined(USE_ROCM) && defined(USE_ROCM_CK)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
}
@@ -1361,7 +1361,7 @@ void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
-#if defined(USE_ROCM) && !defined(_MSC_VER)
+#if defined(USE_ROCM) && defined(USE_ROCM_CK)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
at::native::gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
@@ -1382,7 +1382,7 @@ void gemm_internal<at::Half, float>(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::Half,
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<at::Half, float>(CUDABLAS_GEMM_ARGS(at::Half));
}
-#if defined(USE_ROCM) && !defined(_MSC_VER)
+#if defined(USE_ROCM) && defined(USE_ROCM_CK)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
TORCH_CHECK(false, "gemm input type at::Half and output type float is not supported for ROCm");
}
@@ -1398,7 +1398,7 @@ void gemm_internal<at::BFloat16, float>(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(at::B
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
gemm_internal_cublaslt<at::BFloat16, float>(CUDABLAS_GEMM_ARGS(at::BFloat16));
}
-#if defined(USE_ROCM) && !defined(_MSC_VER)
+#if defined(USE_ROCM) && defined(USE_ROCM_CK)
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
TORCH_CHECK(false, "gemm input type at::Half and output type float is not supported for ROCm");
}
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
index a93386c27f8d..be1368999d38 100644
--- a/cmake/Dependencies.cmake