Update the next gitcommit to v2.8.0-rc6
Remove old patches. Signed-off-by: Tom Rix <Tom.Rix@amd.com>
This commit is contained in:
parent
27593d78b3
commit
42c33b8dcd
27 changed files with 154 additions and 4539 deletions
|
|
@ -1,47 +0,0 @@
|
|||
From 091b7fe1ccbb5e4ff4ac6017d42bacb869f61a27 Mon Sep 17 00:00:00 2001
|
||||
From: Tom Rix <trix@redhat.com>
|
||||
Date: Sat, 20 Jul 2024 05:37:15 -0600
|
||||
Subject: [PATCH] Add cmake option USE_SYSTEM_FBGEMM
|
||||
|
||||
Signed-off-by: Tom Rix <trix@redhat.com>
|
||||
---
|
||||
CMakeLists.txt | 1 +
|
||||
cmake/Dependencies.cmake | 3 ++-
|
||||
2 files changed, 3 insertions(+), 1 deletion(-)
|
||||
|
||||
diff --git a/CMakeLists.txt b/CMakeLists.txt
|
||||
index c4cd4b2c2a98..2068f7c6c4f2 100644
|
||||
--- a/CMakeLists.txt
|
||||
+++ b/CMakeLists.txt
|
||||
@@ -253,6 +253,7 @@ cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
|
||||
"USE_CUDNN" OFF)
|
||||
cmake_dependent_option(USE_CUSPARSELT "Use cuSPARSELt" ON "USE_CUDA" OFF)
|
||||
option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
|
||||
+option(USE_SYSTEM_FBGEMM "Use system-wide FBGEMM" OFF)
|
||||
option(USE_KINETO "Use Kineto profiling library" ON)
|
||||
option(USE_CUPTI_SO "Use CUPTI as a shared library" ON)
|
||||
option(USE_FAKELOWP "Use FakeLowp operators" OFF)
|
||||
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
|
||||
index f1f2eb7cec31..192dac46f13b 100644
|
||||
--- a/cmake/Dependencies.cmake
|
||||
+++ b/cmake/Dependencies.cmake
|
||||
@@ -706,6 +706,7 @@ endif()
|
||||
|
||||
# ---[ FBGEMM
|
||||
if(USE_FBGEMM)
|
||||
+ if (NOT USE_SYSTEM_FBGEMM)
|
||||
set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
|
||||
if(NOT DEFINED FBGEMM_SOURCE_DIR)
|
||||
set(FBGEMM_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/fbgemm" CACHE STRING "FBGEMM source directory")
|
||||
@@ -746,7 +747,7 @@ if(USE_FBGEMM)
|
||||
target_compile_options_if_supported(asmjit -Wno-unused-but-set-variable)
|
||||
endif()
|
||||
endif()
|
||||
-
|
||||
+ endif()
|
||||
if(USE_FBGEMM)
|
||||
list(APPEND Caffe2_DEPENDENCY_LIBS fbgemm)
|
||||
endif()
|
||||
--
|
||||
2.45.1
|
||||
|
||||
|
|
@ -1,222 +0,0 @@
|
|||
From 655a06444b261cb28e71a0973c0ab67aaa8261ab Mon Sep 17 00:00:00 2001
|
||||
From: albanD <desmaison.alban@gmail.com>
|
||||
Date: Tue, 14 May 2024 02:14:53 +0000
|
||||
Subject: [PATCH] Changes to compile with 3.13 (#126033)
|
||||
|
||||
This is mainly:
|
||||
- Fix refcount access macro
|
||||
- Hide all the Dynamo code that needs update as usual
|
||||
- Add _PyWeakref_ClearRef as an extern provided by CPython. Including the pycore header that defines it would require raw c include shenanigans that I don't think are worth it.
|
||||
This allows building with both the regular and nogil versions of CPython. Both
|
||||
|
||||
Note that this requires the 3.13 branch at least past [d3094744d40de2deefbda9b1996d5029c9ebf0b0](https://github.com/python/cpython/commit/d3094744d40de2deefbda9b1996d5029c9ebf0b0) which we need for mimalloc include and weakref function being exposed.
|
||||
|
||||
debug-only issues in pybind11 with PyMem_MALLOC vs PyObject_MALLOC being should be synced either by updating pybind or cpython. @colesbury I can send a PR to ifdef the proper use in pybind if you think that this is the best solution here?
|
||||
|
||||
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126033
|
||||
Approved by: https://github.com/colesbury
|
||||
---
|
||||
torch/csrc/Storage.cpp | 2 +-
|
||||
torch/csrc/autograd/python_variable.cpp | 2 +-
|
||||
torch/csrc/dynamo/cpython_defs.c | 15 +++++-
|
||||
torch/csrc/dynamo/cpython_defs.h | 2 +
|
||||
torch/csrc/dynamo/eval_frame.c | 67 ++++++++++++++++++-------
|
||||
torch/csrc/utils/python_compat.h | 4 ++
|
||||
6 files changed, 70 insertions(+), 22 deletions(-)
|
||||
|
||||
diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp
|
||||
index 93dbc9c09bb2..b22bbac35981 100644
|
||||
--- a/torch/csrc/Storage.cpp
|
||||
+++ b/torch/csrc/Storage.cpp
|
||||
@@ -236,7 +236,7 @@ static void THPStorage_subclass_dealloc(PyObject* self) {
|
||||
if (type->tp_del) {
|
||||
PyObject_GC_Track(self);
|
||||
type->tp_del(self);
|
||||
- if (self->ob_refcnt > 0) {
|
||||
+ if (Py_REFCNT(self) > 0) {
|
||||
// Resurrected (see above comment about resurrection from `__del__`)
|
||||
return;
|
||||
}
|
||||
diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp
|
||||
index 9e85f0026b35..8fd1129da63c 100644
|
||||
--- a/torch/csrc/autograd/python_variable.cpp
|
||||
+++ b/torch/csrc/autograd/python_variable.cpp
|
||||
@@ -1910,7 +1910,7 @@ void THPVariable_subclass_dealloc(PyObject* self) {
|
||||
if (type->tp_del) {
|
||||
PyObject_GC_Track(self);
|
||||
type->tp_del(self);
|
||||
- if (self->ob_refcnt > 0) {
|
||||
+ if (Py_REFCNT(self) > 0) {
|
||||
/* Resurrected */
|
||||
return;
|
||||
}
|
||||
diff --git a/torch/csrc/dynamo/cpython_defs.c b/torch/csrc/dynamo/cpython_defs.c
|
||||
index 4a1dba63009a..5e0945a052ae 100644
|
||||
--- a/torch/csrc/dynamo/cpython_defs.c
|
||||
+++ b/torch/csrc/dynamo/cpython_defs.c
|
||||
@@ -13,6 +13,17 @@
|
||||
} else { \
|
||||
}
|
||||
|
||||
+#if IS_PYTHON_3_13_PLUS
|
||||
+// Gave up after fixing a few of these
|
||||
+// pycore_opcode.h is gone (new is pycore_opcode_metadata.h ?)
|
||||
+// f_code is gone (new is f_executable?)
|
||||
+
|
||||
+// Fake definitions for what we removed
|
||||
+const uint8_t* THP_PyOpcode_Caches = NULL;
|
||||
+const int THP_PyOpcode_Caches_size = 0;
|
||||
+
|
||||
+#else
|
||||
+
|
||||
// NOTE: all `assert`s below are converted to `CHECK`s
|
||||
|
||||
#if IS_PYTHON_3_11_PLUS
|
||||
@@ -29,8 +40,8 @@
|
||||
#define NEED_OPCODE_TABLES // To get _PyOpcode_Deopt
|
||||
#include <internal/pycore_opcode.h>
|
||||
#undef NEED_OPCODE_TABLES
|
||||
-#undef Py_BUILD_CORE
|
||||
#include <internal/pycore_frame.h>
|
||||
+#undef Py_BUILD_CORE
|
||||
|
||||
// As a simple way to reduce the impact of ABI changes on the CPython side, this check forces
|
||||
// us to manually re-check that the function didn't change on the next major version
|
||||
@@ -364,3 +375,5 @@ THP_PyFrame_Clear(_PyInterpreterFrame *frame)
|
||||
}
|
||||
|
||||
#endif
|
||||
+
|
||||
+#endif // CPython 3.13
|
||||
\ No newline at end of file
|
||||
diff --git a/torch/csrc/dynamo/cpython_defs.h b/torch/csrc/dynamo/cpython_defs.h
|
||||
index a897c3e6c6e7..3b6c9667f8c9 100644
|
||||
--- a/torch/csrc/dynamo/cpython_defs.h
|
||||
+++ b/torch/csrc/dynamo/cpython_defs.h
|
||||
@@ -8,7 +8,9 @@
|
||||
|
||||
#if IS_PYTHON_3_11_PLUS
|
||||
|
||||
+#define Py_BUILD_CORE
|
||||
#include <internal/pycore_frame.h>
|
||||
+#undef Py_BUILD_CORE
|
||||
|
||||
int THP_PyFrame_FastToLocalsWithError(
|
||||
_PyInterpreterFrame* frame,
|
||||
diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c
|
||||
index c286e821f09d..e13cb5af2a0e 100644
|
||||
--- a/torch/csrc/dynamo/eval_frame.c
|
||||
+++ b/torch/csrc/dynamo/eval_frame.c
|
||||
@@ -8,6 +8,31 @@
|
||||
#include <opcode.h>
|
||||
#include <stdbool.h>
|
||||
|
||||
+
|
||||
+
|
||||
+PyObject* guard_error_hook = NULL;
|
||||
+const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup";
|
||||
+
|
||||
+static int active_dynamo_threads = 0;
|
||||
+
|
||||
+static Py_tss_t eval_frame_callback_key = Py_tss_NEEDS_INIT;
|
||||
+
|
||||
+inline static PyObject* eval_frame_callback_get(void) {
|
||||
+ void* result = PyThread_tss_get(&eval_frame_callback_key);
|
||||
+ if (unlikely(result == NULL)) {
|
||||
+ return (PyObject*)Py_None;
|
||||
+ } else {
|
||||
+ return (PyObject*)result;
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+inline static void eval_frame_callback_set(PyObject* obj) {
|
||||
+ PyThread_tss_set(&eval_frame_callback_key, obj);
|
||||
+}
|
||||
+
|
||||
+// 3.13 Not supported at all. See cpython_defs.c for hints
|
||||
+#if !(IS_PYTHON_3_13_PLUS)
|
||||
+
|
||||
// Problem in CPython includes when mixing core and non-core build
|
||||
// The fix was not backported to 3.12 so this is needed here
|
||||
// https://github.com/python/cpython/issues/105268
|
||||
@@ -138,24 +163,6 @@ THP_PyFrame_FastToLocalsWithError(THP_EVAL_API_FRAME_OBJECT *frame, int *free_va
|
||||
}
|
||||
#endif
|
||||
|
||||
-PyObject* guard_error_hook = NULL;
|
||||
-const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup";
|
||||
-
|
||||
-static Py_tss_t eval_frame_callback_key = Py_tss_NEEDS_INIT;
|
||||
-
|
||||
-inline static PyObject* eval_frame_callback_get(void) {
|
||||
- void* result = PyThread_tss_get(&eval_frame_callback_key);
|
||||
- if (unlikely(result == NULL)) {
|
||||
- return (PyObject*)Py_None;
|
||||
- } else {
|
||||
- return (PyObject*)result;
|
||||
- }
|
||||
-}
|
||||
-
|
||||
-inline static void eval_frame_callback_set(PyObject* obj) {
|
||||
- PyThread_tss_set(&eval_frame_callback_key, obj);
|
||||
-}
|
||||
-
|
||||
static PyObject* _custom_eval_frame_shim(
|
||||
PyThreadState* tstate,
|
||||
THP_EVAL_API_FRAME_OBJECT* frame,
|
||||
@@ -627,7 +634,29 @@ static PyObject* _custom_eval_frame(
|
||||
}
|
||||
}
|
||||
|
||||
-static int active_dynamo_threads = 0;
|
||||
+#else // IS_PYTHON_3_13_PLUS
|
||||
+
|
||||
+// Fake definitions for everything we removed
|
||||
+
|
||||
+typedef struct THPPyInterpreterFrame {
|
||||
+ PyObject_HEAD
|
||||
+ _PyInterpreterFrame* frame; // Borrowed reference
|
||||
+} THPPyInterpreterFrame;
|
||||
+
|
||||
+inline static void enable_eval_frame_shim(PyThreadState* tstate) {}
|
||||
+inline static void enable_eval_frame_default(PyThreadState* tstate) {}
|
||||
+
|
||||
+static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {NULL};
|
||||
+
|
||||
+static PyTypeObject THPPyInterpreterFrameType = {
|
||||
+ PyVarObject_HEAD_INIT(NULL, 0)
|
||||
+ .tp_name = "torch._C.dynamo.eval_frame._PyInterpreterFrame",
|
||||
+ .tp_basicsize = sizeof(THPPyInterpreterFrame),
|
||||
+ .tp_flags = Py_TPFLAGS_DEFAULT,
|
||||
+ .tp_getset = THPPyInterpreterFrame_properties,
|
||||
+};
|
||||
+
|
||||
+#endif // CPython 3.13
|
||||
|
||||
static PyObject* increment_working_threads(PyThreadState* tstate) {
|
||||
active_dynamo_threads = active_dynamo_threads + 1;
|
||||
diff --git a/torch/csrc/utils/python_compat.h b/torch/csrc/utils/python_compat.h
|
||||
index 73b991cf3fbf..b060db00db73 100644
|
||||
--- a/torch/csrc/utils/python_compat.h
|
||||
+++ b/torch/csrc/utils/python_compat.h
|
||||
@@ -11,6 +11,7 @@ extern "C" {
|
||||
|
||||
#define IS_PYTHON_3_11_PLUS PY_VERSION_HEX >= 0x030B00C1
|
||||
#define IS_PYTHON_3_12_PLUS PY_VERSION_HEX >= 0x030C0000
|
||||
+#define IS_PYTHON_3_13_PLUS PY_VERSION_HEX >= 0x030D0000
|
||||
|
||||
PYCAPI_COMPAT_STATIC_INLINE(int)
|
||||
PyCode_GetNCellvars(PyCodeObject* code) {
|
||||
@@ -32,6 +33,9 @@ PyCode_GetNFreevars(PyCodeObject* code) {
|
||||
#endif
|
||||
}
|
||||
|
||||
+// Provided by CPython but getting the header for them is very hard
|
||||
+extern void _PyWeakref_ClearRef(PyWeakReference* self);
|
||||
+
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
--
|
||||
2.45.1
|
||||
|
||||
|
|
@ -1,910 +0,0 @@
|
|||
From 3d1e4b3e5ddcdd2717e590c635097163fef64c83 Mon Sep 17 00:00:00 2001
|
||||
From: Xu Han <xu.han@intel.com>
|
||||
Date: Sun, 31 Mar 2024 03:07:32 +0000
|
||||
Subject: [PATCH] Enable x86 CPU vectorization on windows [submodule sleef]
|
||||
(#118980)
|
||||
|
||||
Enable VEC on Windows OS.
|
||||
1. Fix some type definition gaps between Windows and Linux.
|
||||
2. Fix some operators not supported on Windows, such as [], /.
|
||||
3. Enable static sleef library build on Windows.
|
||||
4. Disable unsupported function overloading on MSVC.
|
||||
5. Upgrade the sleef submodule, which fixes a build issue on Windows.
|
||||
6. Fixed bazel build issues.
|
||||
7. Fix test app not linking to sleef on Windows.
|
||||
|
||||
Note: If the rebuild fails after pulling this PR, please sync the `sleef` submodule by running:
|
||||
```cmd
|
||||
git submodule sync
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
|
||||
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118980
|
||||
Approved by: https://github.com/jgong5, https://github.com/ezyang, https://github.com/malfet
|
||||
---
|
||||
aten/src/ATen/CMakeLists.txt | 48 ++++++--------
|
||||
aten/src/ATen/cpu/vec/vec256/vec256.h | 14 ++--
|
||||
.../src/ATen/cpu/vec/vec256/vec256_bfloat16.h | 27 ++++++--
|
||||
.../cpu/vec/vec256/vec256_complex_double.h | 7 +-
|
||||
.../cpu/vec/vec256/vec256_complex_float.h | 7 +-
|
||||
aten/src/ATen/cpu/vec/vec256/vec256_double.h | 5 +-
|
||||
aten/src/ATen/cpu/vec/vec256/vec256_float.h | 15 +++--
|
||||
aten/src/ATen/cpu/vec/vec256/vec256_qint.h | 12 +++-
|
||||
aten/src/ATen/cpu/vec/vec512/vec512.h | 14 ++--
|
||||
.../src/ATen/cpu/vec/vec512/vec512_bfloat16.h | 27 ++++++--
|
||||
.../cpu/vec/vec512/vec512_complex_double.h | 7 +-
|
||||
.../cpu/vec/vec512/vec512_complex_float.h | 7 +-
|
||||
aten/src/ATen/cpu/vec/vec512/vec512_double.h | 5 +-
|
||||
aten/src/ATen/cpu/vec/vec512/vec512_float.h | 15 +++--
|
||||
aten/src/ATen/cpu/vec/vec512/vec512_qint.h | 66 ++++++++++++++++++-
|
||||
aten/src/ATen/cpu/vec/vec_base.h | 6 ++
|
||||
caffe2/CMakeLists.txt | 2 +-
|
||||
third_party/sleef.BUILD | 3 +-
|
||||
18 files changed, 194 insertions(+), 93 deletions(-)
|
||||
|
||||
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
|
||||
index bf425af5fa9..58d5828e8ca 100644
|
||||
--- a/aten/src/ATen/CMakeLists.txt
|
||||
+++ b/aten/src/ATen/CMakeLists.txt
|
||||
@@ -419,32 +419,25 @@ if(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "^(s390x|ppc64le)$")
|
||||
list(APPEND ATen_CPU_DEPENDENCY_LIBS cpuinfo)
|
||||
endif()
|
||||
|
||||
-if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
|
||||
- # Preserve values for the main build
|
||||
- set(__aten_sleef_build_shared_libs ${BUILD_SHARED_LIBS})
|
||||
- set(__aten_sleef_build_tests ${BUILD_TESTS})
|
||||
-
|
||||
- # Unset our restrictive C++ flags here and reset them later.
|
||||
- # Remove this once we use proper target_compile_options.
|
||||
- set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
|
||||
- set(CMAKE_CXX_FLAGS)
|
||||
-
|
||||
- # Bump up optimization level for sleef to -O1, since at -O0 the compiler
|
||||
- # excessively spills intermediate vector registers to the stack
|
||||
- # and makes things run impossibly slowly
|
||||
- set(OLD_CMAKE_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
|
||||
- if(${CMAKE_C_FLAGS_DEBUG} MATCHES "-O0")
|
||||
- string(REGEX REPLACE "-O0" "-O1" CMAKE_C_FLAGS_DEBUG ${OLD_CMAKE_C_FLAGS_DEBUG})
|
||||
- else()
|
||||
- set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -O1")
|
||||
+if(NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
|
||||
+ if(NOT MSVC)
|
||||
+ # Bump up optimization level for sleef to -O1, since at -O0 the compiler
|
||||
+ # excessively spills intermediate vector registers to the stack
|
||||
+ # and makes things run impossibly slowly
|
||||
+ set(OLD_CMAKE_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
|
||||
+ if(${CMAKE_C_FLAGS_DEBUG} MATCHES "-O0")
|
||||
+ string(REGEX REPLACE "-O0" "-O1" CMAKE_C_FLAGS_DEBUG ${OLD_CMAKE_C_FLAGS_DEBUG})
|
||||
+ else()
|
||||
+ set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -O1")
|
||||
+ endif()
|
||||
endif()
|
||||
|
||||
if(NOT USE_SYSTEM_SLEEF)
|
||||
- set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build sleef static" FORCE)
|
||||
- set(BUILD_DFT OFF CACHE BOOL "Don't build sleef DFT lib" FORCE)
|
||||
- set(BUILD_GNUABI_LIBS OFF CACHE BOOL "Don't build sleef gnuabi libs" FORCE)
|
||||
- set(BUILD_TESTS OFF CACHE BOOL "Don't build sleef tests" FORCE)
|
||||
- set(OLD_CMAKE_BUILD_TYPE ${CMAKE_BUILD_TYPE})
|
||||
+ set(SLEEF_BUILD_SHARED_LIBS OFF CACHE BOOL "Build sleef static" FORCE)
|
||||
+ set(SLEEF_BUILD_DFT OFF CACHE BOOL "Don't build sleef DFT lib" FORCE)
|
||||
+ set(SLEEF_BUILD_GNUABI_LIBS OFF CACHE BOOL "Don't build sleef gnuabi libs" FORCE)
|
||||
+ set(SLEEF_BUILD_TESTS OFF CACHE BOOL "Don't build sleef tests" FORCE)
|
||||
+ set(SLEEF_BUILD_SCALAR_LIB OFF CACHE BOOL "libsleefscalar will be built." FORCE)
|
||||
if(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
|
||||
if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" OR CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
|
||||
set(DISABLE_SVE ON CACHE BOOL "Xcode's clang-12.5 crashes while trying to compile SVE code" FORCE)
|
||||
@@ -465,12 +458,9 @@ if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
|
||||
endif()
|
||||
list(APPEND ATen_CPU_DEPENDENCY_LIBS sleef)
|
||||
|
||||
- set(CMAKE_C_FLAGS_DEBUG ${OLD_CMAKE_C_FLAGS_DEBUG})
|
||||
- set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS})
|
||||
-
|
||||
- # Set these back. TODO: Use SLEEF_ to pass these instead
|
||||
- set(BUILD_SHARED_LIBS ${__aten_sleef_build_shared_libs} CACHE BOOL "Build shared libs" FORCE)
|
||||
- set(BUILD_TESTS ${__aten_sleef_build_tests} CACHE BOOL "Build tests" FORCE)
|
||||
+ if(NOT MSVC)
|
||||
+ set(CMAKE_C_FLAGS_DEBUG ${OLD_CMAKE_C_FLAGS_DEBUG})
|
||||
+ endif()
|
||||
endif()
|
||||
|
||||
if(USE_CUDA AND NOT USE_ROCM)
|
||||
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256.h b/aten/src/ATen/cpu/vec/vec256/vec256.h
|
||||
index 800b027e469..c431fa3c605 100644
|
||||
--- a/aten/src/ATen/cpu/vec/vec256/vec256.h
|
||||
+++ b/aten/src/ATen/cpu/vec/vec256/vec256.h
|
||||
@@ -69,7 +69,7 @@ std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
|
||||
}
|
||||
|
||||
|
||||
-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX2)
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -94,7 +94,8 @@ inline Vectorized<double> cast<double, int64_t>(const Vectorized<int64_t>& src)
|
||||
}
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
-
|
||||
+#ifndef _MSC_VER
|
||||
+// MSVC is not working well on complex function overload.
|
||||
template<int64_t scale = 1>
|
||||
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
|
||||
inline gather(const double* base_addr, const Vectorized<int64_t>& vindex) {
|
||||
@@ -106,9 +107,10 @@ std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorize
|
||||
inline gather(const float* base_addr, const Vectorized<int32_t>& vindex) {
|
||||
return _mm256_i32gather_ps(base_addr, vindex, scale);
|
||||
}
|
||||
-
|
||||
+#endif
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
-
|
||||
+#ifndef _MSC_VER
|
||||
+// MSVC is not working well on complex function overload.
|
||||
template<int64_t scale = 1>
|
||||
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
|
||||
inline mask_gather(const Vectorized<double>& src, const double* base_addr,
|
||||
@@ -122,7 +124,7 @@ inline mask_gather(const Vectorized<float>& src, const float* base_addr,
|
||||
const Vectorized<int32_t>& vindex, Vectorized<float>& mask) {
|
||||
return _mm256_mask_i32gather_ps(src, base_addr, vindex, mask, scale);
|
||||
}
|
||||
-
|
||||
+#endif
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
// Only works for inputs in the range: [-2^51, 2^51]
|
||||
@@ -302,6 +304,6 @@ inline Vectorized<uint8_t> flip(const Vectorized<uint8_t> & v) {
|
||||
return flip8(v);
|
||||
}
|
||||
|
||||
-#endif // (defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
+#endif // (defined(CPU_CAPABILITY_AVX2)
|
||||
|
||||
}} // namepsace at::vec::CPU_CAPABILITY
|
||||
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
|
||||
index 3e26213d6d2..66557436c70 100644
|
||||
--- a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
|
||||
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
|
||||
@@ -7,7 +7,8 @@
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX2)
|
||||
+#define SLEEF_STATIC_LIBS
|
||||
#include <sleef.h>
|
||||
#endif
|
||||
|
||||
@@ -18,7 +19,18 @@ namespace at::vec {
|
||||
// See Note [CPU_CAPABILITY namespace]
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX2)
|
||||
+
|
||||
+#ifndef SLEEF_CONST
|
||||
+#if (defined(__GNUC__) || defined(__CLANG__)) && !defined(__INTEL_COMPILER)
|
||||
+#define SLEEF_CONST const
|
||||
+#else
|
||||
+#define SLEEF_CONST
|
||||
+#endif
|
||||
+#define SLEEF_CONST_OLD SLEEF_CONST
|
||||
+#else
|
||||
+#define SLEEF_CONST_OLD
|
||||
+#endif
|
||||
|
||||
// bfloat16 conversion
|
||||
static inline void cvtbf16_fp32(const __m128i& a, __m256& o) {
|
||||
@@ -265,7 +277,8 @@ public:
|
||||
}
|
||||
return b;
|
||||
}
|
||||
- Vectorized<T> map(const __m256 (*const vop)(__m256)) const {
|
||||
+
|
||||
+ Vectorized<T> map(SLEEF_CONST __m256 (*SLEEF_CONST_OLD vop)(__m256)) const {
|
||||
__m256 lo, hi;
|
||||
cvt_to_fp32<T>(values, lo, hi);
|
||||
const auto o1 = vop(lo);
|
||||
@@ -1026,7 +1039,7 @@ inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const V
|
||||
CONVERT_VECTORIZED_INIT(BFloat16, bfloat16);
|
||||
CONVERT_VECTORIZED_INIT(Half, half);
|
||||
|
||||
-#else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
+#else // defined(CPU_CAPABILITY_AVX2)
|
||||
|
||||
#define CONVERT_NON_VECTORIZED_INIT(type, name) \
|
||||
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_##name##_float(const Vectorized<type>& a) { \
|
||||
@@ -1051,9 +1064,9 @@ inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const V
|
||||
CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16);
|
||||
CONVERT_NON_VECTORIZED_INIT(Half, half);
|
||||
|
||||
-#endif // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
+#endif // defined(CPU_CAPABILITY_AVX2)
|
||||
|
||||
-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX2)
|
||||
#define LOAD_FP32_VECTORIZED_INIT(type, name) \
|
||||
inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
|
||||
auto values = _mm_loadu_si128(reinterpret_cast<const __m128i*>(data)); \
|
||||
@@ -1072,7 +1085,7 @@ inline void load_fp32_from_##name(const type *data, Vectorized<float>& out1, Vec
|
||||
LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16);
|
||||
LOAD_FP32_VECTORIZED_INIT(Half, fp16);
|
||||
|
||||
-#else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
+#else // defined(CPU_CAPABILITY_AVX2)
|
||||
#define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \
|
||||
inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
|
||||
__at_align__ float values[Vectorized<float>::size()]; \
|
||||
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h b/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h
|
||||
index f93ea1e63c3..6c198fb37d3 100644
|
||||
--- a/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h
|
||||
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h
|
||||
@@ -8,7 +8,8 @@
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
|
||||
-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX2)
|
||||
+#define SLEEF_STATIC_LIBS
|
||||
#include <sleef.h>
|
||||
#endif
|
||||
|
||||
@@ -16,7 +17,7 @@ namespace at::vec {
|
||||
// See Note [CPU_CAPABILITY namespace]
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX2)
|
||||
|
||||
template <> class Vectorized<c10::complex<double>> {
|
||||
private:
|
||||
@@ -145,7 +146,7 @@ public:
|
||||
auto abs = abs_();
|
||||
auto zero = _mm256_setzero_pd();
|
||||
auto mask = _mm256_cmp_pd(abs, zero, _CMP_EQ_OQ);
|
||||
- auto div = values / abs;
|
||||
+ auto div = _mm256_div_pd(values, abs);
|
||||
return _mm256_blendv_pd(div, zero, mask);
|
||||
}
|
||||
__m256d real_() const {
|
||||
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h
|
||||
index 7c142c04b79..c72d4d49274 100644
|
||||
--- a/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h
|
||||
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h
|
||||
@@ -7,7 +7,8 @@
|
||||
#include <c10/util/irange.h>
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX2)
|
||||
+#define SLEEF_STATIC_LIBS
|
||||
#include <sleef.h>
|
||||
#endif
|
||||
|
||||
@@ -15,7 +16,7 @@ namespace at::vec {
|
||||
// See Note [CPU_CAPABILITY namespace]
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX2)
|
||||
|
||||
template <> class Vectorized<c10::complex<float>> {
|
||||
private:
|
||||
@@ -180,7 +181,7 @@ public:
|
||||
auto abs = abs_();
|
||||
auto zero = _mm256_setzero_ps();
|
||||
auto mask = _mm256_cmp_ps(abs, zero, _CMP_EQ_OQ);
|
||||
- auto div = values / abs;
|
||||
+ auto div = _mm256_div_ps(values, abs);
|
||||
return _mm256_blendv_ps(div, zero, mask);
|
||||
}
|
||||
__m256 real_() const {
|
||||
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_double.h b/aten/src/ATen/cpu/vec/vec256/vec256_double.h
|
||||
index bc82d07edd1..bed6da627af 100644
|
||||
--- a/aten/src/ATen/cpu/vec/vec256/vec256_double.h
|
||||
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_double.h
|
||||
@@ -6,7 +6,8 @@
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
#include <c10/util/irange.h>
|
||||
-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX2)
|
||||
+#define SLEEF_STATIC_LIBS
|
||||
#include <sleef.h>
|
||||
#endif
|
||||
|
||||
@@ -15,7 +16,7 @@ namespace at::vec {
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
|
||||
-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX2)
|
||||
|
||||
template <> class Vectorized<double> {
|
||||
private:
|
||||
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_float.h
|
||||
index 886809a0b8a..0e3664cd37b 100644
|
||||
--- a/aten/src/ATen/cpu/vec/vec256/vec256_float.h
|
||||
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_float.h
|
||||
@@ -6,7 +6,8 @@
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
#include <c10/util/irange.h>
|
||||
-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX2)
|
||||
+#define SLEEF_STATIC_LIBS
|
||||
#include <sleef.h>
|
||||
#endif
|
||||
|
||||
@@ -14,7 +15,7 @@ namespace at::vec {
|
||||
// See Note [CPU_CAPABILITY namespace]
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX2)
|
||||
|
||||
template <> class Vectorized<float> {
|
||||
private:
|
||||
@@ -226,14 +227,14 @@ public:
|
||||
static __m256 vec_factorial_5 =
|
||||
_mm256_set1_ps(0.00828929059f); // 1/factorial(5)
|
||||
static __m256 vec_exp_log2ef =
|
||||
- (__m256)_mm256_set1_epi32(0x3fb8aa3b); // log2(e)
|
||||
+ _mm256_castsi256_ps(_mm256_set1_epi32(0x3fb8aa3b)); // log2(e)
|
||||
static __m256 vec_half = _mm256_set1_ps(0.5f);
|
||||
static __m256 vec_one = _mm256_set1_ps(1.f);
|
||||
static __m256 vec_zero = _mm256_set1_ps(0.f);
|
||||
static __m256 vec_two = _mm256_set1_ps(2.f);
|
||||
- static __m256 vec_ln2f = (__m256)_mm256_set1_epi32(0x3f317218); // ln(2)
|
||||
- static __m256 vec_ln_flt_min = (__m256)_mm256_set1_epi32(0xc2aeac50);
|
||||
- static __m256 vec_ln_flt_max = (__m256)_mm256_set1_epi32(0x42b17218);
|
||||
+ static __m256 vec_ln2f = _mm256_castsi256_ps(_mm256_set1_epi32(0x3f317218)); // ln(2)
|
||||
+ static __m256 vec_ln_flt_min = _mm256_castsi256_ps(_mm256_set1_epi32(0xc2aeac50));
|
||||
+ static __m256 vec_ln_flt_max = _mm256_castsi256_ps(_mm256_set1_epi32(0x42b17218));
|
||||
static __m256i vec_127 = _mm256_set1_epi32(0x0000007f);
|
||||
static int n_mantissa_bits = 23;
|
||||
|
||||
@@ -266,7 +267,7 @@ public:
|
||||
auto vec_exp_number_i = _mm256_cvtps_epi32(vec_exp_number);
|
||||
auto vec_two_pow_n_i = _mm256_add_epi32(vec_exp_number_i, vec_127);
|
||||
vec_two_pow_n_i = _mm256_slli_epi32(vec_two_pow_n_i, n_mantissa_bits);
|
||||
- auto vec_two_pow_n = (__m256)vec_two_pow_n_i;
|
||||
+ auto vec_two_pow_n = _mm256_castsi256_ps(vec_two_pow_n_i);
|
||||
vec_two_pow_n =
|
||||
_mm256_blendv_ps(vec_two_pow_n, vec_zero, less_ln_flt_min_mask);
|
||||
|
||||
diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h
|
||||
index 4128841701a..85e099904cd 100644
|
||||
--- a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h
|
||||
+++ b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h
|
||||
@@ -41,11 +41,17 @@
|
||||
namespace at::vec {
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX2)
|
||||
|
||||
+#ifdef _MSC_VER
|
||||
+__declspec(align(64)) struct Vectorizedqi {
|
||||
+ protected:
|
||||
+ __m256i vals;
|
||||
+#else
|
||||
struct Vectorizedqi {
|
||||
protected:
|
||||
__m256i vals __attribute__((aligned(64)));
|
||||
+#endif
|
||||
|
||||
public:
|
||||
Vectorizedqi() {}
|
||||
@@ -133,7 +139,7 @@ inline convert_float_to_int8(at::vec::Vectorized<float> src) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
-inline void __attribute__((always_inline)) QuantizeAvx2(
|
||||
+__FORCE_INLINE void QuantizeAvx2(
|
||||
const float* src,
|
||||
T* dst,
|
||||
int len,
|
||||
@@ -1331,5 +1337,5 @@ Vectorized<c10::quint8> inline maximum(const Vectorized<c10::quint8>& a, const V
|
||||
return a.maximum(b);
|
||||
}
|
||||
|
||||
-#endif // if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
+#endif // if defined(CPU_CAPABILITY_AVX2)
|
||||
}} // namespace at::vec::CPU_CAPABILITY
|
||||
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512.h b/aten/src/ATen/cpu/vec/vec512/vec512.h
|
||||
index fe96d123e64..87f723d782c 100644
|
||||
--- a/aten/src/ATen/cpu/vec/vec512/vec512.h
|
||||
+++ b/aten/src/ATen/cpu/vec/vec512/vec512.h
|
||||
@@ -55,7 +55,7 @@ std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
|
||||
}
|
||||
|
||||
|
||||
-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX512)
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX512) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
@@ -80,7 +80,8 @@ inline Vectorized<double> cast<double, int64_t>(const Vectorized<int64_t>& src)
|
||||
}
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
-
|
||||
+#ifndef _MSC_VER
|
||||
+// MSVC is not working well on complex function overload.
|
||||
template<int64_t scale = 1>
|
||||
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
|
||||
inline gather(const double* base_addr, const Vectorized<int64_t>& vindex) {
|
||||
@@ -92,9 +93,10 @@ std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorize
|
||||
inline gather(const float* base_addr, const Vectorized<int32_t>& vindex) {
|
||||
return _mm512_i32gather_ps(vindex, base_addr, scale);
|
||||
}
|
||||
-
|
||||
+#endif
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
-
|
||||
+#ifndef _MSC_VER
|
||||
+// MSVC is not working well on complex function overload.
|
||||
template<int64_t scale = 1>
|
||||
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
|
||||
inline mask_gather(const Vectorized<double>& src, const double* base_addr,
|
||||
@@ -112,7 +114,7 @@ inline mask_gather(const Vectorized<float>& src, const float* base_addr,
|
||||
auto mask_ = _mm512_cmp_ps_mask(all_ones, mask.values, _CMP_EQ_OQ);
|
||||
return _mm512_mask_i32gather_ps(src, mask_, vindex, base_addr, scale);
|
||||
}
|
||||
-
|
||||
+#endif
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
template<>
|
||||
@@ -270,6 +272,6 @@ inline Vectorized<uint8_t> flip(const Vectorized<uint8_t> & v) {
|
||||
return flip8(v);
|
||||
}
|
||||
|
||||
-#endif // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
+#endif // defined(CPU_CAPABILITY_AVX512)
|
||||
|
||||
}}}
|
||||
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h b/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h
|
||||
index f9fc92d52bf..eb3b6a72240 100644
|
||||
--- a/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h
|
||||
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h
|
||||
@@ -7,7 +7,8 @@
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX512)
|
||||
+#define SLEEF_STATIC_LIBS
|
||||
#include <sleef.h>
|
||||
#endif
|
||||
|
||||
@@ -16,7 +17,18 @@ namespace vec {
|
||||
// See Note [CPU_CAPABILITY namespace]
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX512)
|
||||
+
|
||||
+#ifndef SLEEF_CONST
|
||||
+#if (defined(__GNUC__) || defined(__CLANG__)) && !defined(__INTEL_COMPILER)
|
||||
+#define SLEEF_CONST const
|
||||
+#else
|
||||
+#define SLEEF_CONST
|
||||
+#endif
|
||||
+#define SLEEF_CONST_OLD SLEEF_CONST
|
||||
+#else
|
||||
+#define SLEEF_CONST_OLD
|
||||
+#endif
|
||||
|
||||
// bfloat16 conversion
|
||||
static inline void cvtbf16_fp32(const __m256i& a, __m512& o) {
|
||||
@@ -362,7 +374,8 @@ public:
|
||||
}
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wignored-qualifiers"
|
||||
- Vectorized<T> map(const __m512 (*const vop)(__m512)) const {
|
||||
+
|
||||
+ Vectorized<T> map(SLEEF_CONST __m512 (*SLEEF_CONST_OLD vop)(__m512)) const {
|
||||
__m512 lo, hi;
|
||||
cvt_to_fp32<T>(values, lo, hi);
|
||||
const auto o1 = vop(lo);
|
||||
@@ -1571,7 +1584,7 @@ inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const V
|
||||
CONVERT_VECTORIZED_INIT(BFloat16, bfloat16);
|
||||
CONVERT_VECTORIZED_INIT(Half, half);
|
||||
|
||||
-#else //defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
+#else //defined(CPU_CAPABILITY_AVX512)
|
||||
|
||||
#define CONVERT_NON_VECTORIZED_INIT(type, name) \
|
||||
inline std::tuple<Vectorized<float>, Vectorized<float>> convert_##name##_float(const Vectorized<type>& a) { \
|
||||
@@ -1601,9 +1614,9 @@ inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const V
|
||||
CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16);
|
||||
CONVERT_NON_VECTORIZED_INIT(Half, half);
|
||||
|
||||
-#endif // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
+#endif // defined(CPU_CAPABILITY_AVX512)
|
||||
|
||||
-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX512)
|
||||
#define LOAD_FP32_VECTORIZED_INIT(type, name) \
|
||||
inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
|
||||
auto values = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(data)); \
|
||||
@@ -1622,7 +1635,7 @@ inline void load_fp32_from_##name(const type *data, Vectorized<float>& out1, Vec
|
||||
LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16);
|
||||
LOAD_FP32_VECTORIZED_INIT(Half, fp16);
|
||||
|
||||
-#else // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
+#else // defined(CPU_CAPABILITY_AVX512)
|
||||
#define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \
|
||||
inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
|
||||
__at_align__ float values[Vectorized<float>::size()]; \
|
||||
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h b/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h
|
||||
index 02aa3a87cc1..c35204f9da2 100644
|
||||
--- a/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h
|
||||
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h
|
||||
@@ -7,7 +7,8 @@
|
||||
#include <c10/util/irange.h>
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX512)
|
||||
+#define SLEEF_STATIC_LIBS
|
||||
#include <sleef.h>
|
||||
#endif
|
||||
|
||||
@@ -16,7 +17,7 @@ namespace vec {
|
||||
// See Note [CPU_CAPABILITY namespace]
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX512)
|
||||
|
||||
template <> class Vectorized<c10::complex<double>> {
|
||||
private:
|
||||
@@ -203,7 +204,7 @@ public:
|
||||
auto abs = abs_();
|
||||
auto zero = _mm512_setzero_pd();
|
||||
auto mask = _mm512_cmp_pd_mask(abs, zero, _CMP_EQ_OQ);
|
||||
- auto div = values / abs;
|
||||
+ auto div = _mm512_div_pd(values, abs);
|
||||
return _mm512_mask_blend_pd(mask, div, zero);
|
||||
}
|
||||
__m512d real_() const {
|
||||
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h b/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h
|
||||
index a5d790c98b2..2801e484d94 100644
|
||||
--- a/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h
|
||||
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h
|
||||
@@ -7,7 +7,8 @@
|
||||
#include <c10/util/irange.h>
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX512)
|
||||
+#define SLEEF_STATIC_LIBS
|
||||
#include <sleef.h>
|
||||
#endif
|
||||
|
||||
@@ -16,7 +17,7 @@ namespace vec {
|
||||
// See Note [CPU_CAPABILITY namespace]
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX512)
|
||||
|
||||
template <> class Vectorized<c10::complex<float>> {
|
||||
private:
|
||||
@@ -708,7 +709,7 @@ public:
|
||||
auto abs = abs_();
|
||||
auto zero = _mm512_setzero_ps();
|
||||
auto mask = _mm512_cmp_ps_mask(abs, zero, _CMP_EQ_OQ);
|
||||
- auto div = values / abs;
|
||||
+ auto div = _mm512_div_ps(values, abs);
|
||||
return _mm512_mask_blend_ps(mask, div, zero);
|
||||
}
|
||||
__m512 real_() const {
|
||||
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_double.h b/aten/src/ATen/cpu/vec/vec512/vec512_double.h
|
||||
index 27b2753c903..508ab257e60 100644
|
||||
--- a/aten/src/ATen/cpu/vec/vec512/vec512_double.h
|
||||
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_double.h
|
||||
@@ -6,7 +6,8 @@
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
#include <c10/util/irange.h>
|
||||
-#if (defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
|
||||
+#if (defined(CPU_CAPABILITY_AVX512))
|
||||
+#define SLEEF_STATIC_LIBS
|
||||
#include <sleef.h>
|
||||
#endif
|
||||
|
||||
@@ -15,7 +16,7 @@ namespace vec {
|
||||
// See Note [CPU_CAPABILITY namespace]
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX512)
|
||||
|
||||
template <> class Vectorized<double> {
|
||||
private:
|
||||
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_float.h b/aten/src/ATen/cpu/vec/vec512/vec512_float.h
|
||||
index ba5738687fd..a08df3c141a 100644
|
||||
--- a/aten/src/ATen/cpu/vec/vec512/vec512_float.h
|
||||
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_float.h
|
||||
@@ -6,7 +6,8 @@
|
||||
#include <ATen/cpu/vec/intrinsics.h>
|
||||
#include <ATen/cpu/vec/vec_base.h>
|
||||
#include <c10/util/irange.h>
|
||||
-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX512)
|
||||
+#define SLEEF_STATIC_LIBS
|
||||
#include <sleef.h>
|
||||
#endif
|
||||
|
||||
@@ -15,7 +16,7 @@ namespace vec {
|
||||
// See Note [CPU_CAPABILITY namespace]
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX512)
|
||||
|
||||
template <> class Vectorized<float> {
|
||||
private:
|
||||
@@ -246,14 +247,14 @@ public:
|
||||
static __m512 vec_factorial_5 =
|
||||
_mm512_set1_ps(0.00828929059f); // 1/factorial(5)
|
||||
static __m512 vec_exp_log2ef =
|
||||
- (__m512)_mm512_set1_epi32(0x3fb8aa3b); // log2(e)
|
||||
+ _mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); // log2(e)
|
||||
static __m512 vec_half = _mm512_set1_ps(0.5f);
|
||||
static __m512 vec_one = _mm512_set1_ps(1.f);
|
||||
static __m512 vec_zero = _mm512_set1_ps(0.f);
|
||||
static __m512 vec_two = _mm512_set1_ps(2.f);
|
||||
- static __m512 vec_ln2f = (__m512)_mm512_set1_epi32(0x3f317218); // ln(2)
|
||||
- static __m512 vec_ln_flt_min = (__m512)_mm512_set1_epi32(0xc2aeac50);
|
||||
- static __m512 vec_ln_flt_max = (__m512)_mm512_set1_epi32(0x42b17218);
|
||||
+ static __m512 vec_ln2f = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); // ln(2)
|
||||
+ static __m512 vec_ln_flt_min = _mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50));
|
||||
+ static __m512 vec_ln_flt_max = _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218));
|
||||
static __m512i vec_127 = _mm512_set1_epi32(0x0000007f);
|
||||
static int n_mantissa_bits = 23;
|
||||
|
||||
@@ -288,7 +289,7 @@ public:
|
||||
auto vec_exp_number_i = _mm512_cvtps_epi32(vec_exp_number);
|
||||
auto vec_two_pow_n_i = _mm512_add_epi32(vec_exp_number_i, vec_127);
|
||||
vec_two_pow_n_i = _mm512_slli_epi32(vec_two_pow_n_i, n_mantissa_bits);
|
||||
- auto vec_two_pow_n = (__m512)vec_two_pow_n_i;
|
||||
+ auto vec_two_pow_n = _mm512_castsi512_ps(vec_two_pow_n_i);
|
||||
vec_two_pow_n =
|
||||
_mm512_mask_blend_ps(less_ln_flt_min_mask, vec_two_pow_n, vec_zero);
|
||||
|
||||
diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_qint.h b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h
|
||||
index e0713d01312..a5671ed4a50 100644
|
||||
--- a/aten/src/ATen/cpu/vec/vec512/vec512_qint.h
|
||||
+++ b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h
|
||||
@@ -42,11 +42,17 @@ namespace at {
|
||||
namespace vec {
|
||||
inline namespace CPU_CAPABILITY {
|
||||
|
||||
-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
+#if defined(CPU_CAPABILITY_AVX512)
|
||||
|
||||
+#ifdef _MSC_VER
|
||||
+__declspec(align(64)) struct Vectorizedqi {
|
||||
+ protected:
|
||||
+ __m512i vals;
|
||||
+#else
|
||||
struct Vectorizedqi {
|
||||
protected:
|
||||
__m512i vals __attribute__((aligned(64)));
|
||||
+#endif
|
||||
|
||||
public:
|
||||
Vectorizedqi() {}
|
||||
@@ -136,7 +142,7 @@ inline convert_float_to_int8(at::vec::Vectorized<float> src) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
-inline void __attribute__((always_inline)) QuantizeAvx512(
|
||||
+__FORCE_INLINE void QuantizeAvx512(
|
||||
const float* src,
|
||||
T* dst,
|
||||
int len,
|
||||
@@ -525,10 +531,17 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
|
||||
Vectorized<float> scale,
|
||||
Vectorized<float> zero_point,
|
||||
Vectorized<float> scale_neg_zp_premul) const {
|
||||
+ #if defined(_MSC_VER) && !defined(__clang__)
|
||||
+ __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]);
|
||||
+ __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]);
|
||||
+ __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]);
|
||||
+ __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]);
|
||||
+ #else
|
||||
__m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
|
||||
__m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
|
||||
__m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
|
||||
__m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
|
||||
+ #endif
|
||||
|
||||
__m512 float_val0 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val0));
|
||||
__m512 float_val1 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val1));
|
||||
@@ -549,10 +562,17 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
|
||||
float_vec_return_type dequantize(
|
||||
Vectorized<float> scale,
|
||||
Vectorized<float> zero_point) const {
|
||||
+ #if defined(_MSC_VER) && !defined(__clang__)
|
||||
+ __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]);
|
||||
+ __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]);
|
||||
+ __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]);
|
||||
+ __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]);
|
||||
+ #else
|
||||
__m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
|
||||
__m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
|
||||
__m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
|
||||
__m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
|
||||
+ #endif
|
||||
|
||||
__m512 float_val0 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val0));
|
||||
__m512 float_val1 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val1));
|
||||
@@ -598,20 +618,34 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
|
||||
}
|
||||
|
||||
int_vec_return_type widening_subtract(Vectorized<c10::qint8> b) const {
|
||||
+ #if defined(_MSC_VER) && !defined(__clang__)
|
||||
+ __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]);
|
||||
+ __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]);
|
||||
+ __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]);
|
||||
+ __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]);
|
||||
+ #else
|
||||
__m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
|
||||
__m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
|
||||
__m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
|
||||
__m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
|
||||
+ #endif
|
||||
|
||||
__m512i int32_val0 = cvtepi8_epi32(int_val0);
|
||||
__m512i int32_val1 = cvtepi8_epi32(int_val1);
|
||||
__m512i int32_val2 = cvtepi8_epi32(int_val2);
|
||||
__m512i int32_val3 = cvtepi8_epi32(int_val3);
|
||||
|
||||
+ #if defined(_MSC_VER) && !defined(__clang__)
|
||||
+ __m128i int_b0 = _mm_set_epi64x(b.vals.m512i_u64[1], b.vals.m512i_u64[0]);
|
||||
+ __m128i int_b1 = _mm_set_epi64x(b.vals.m512i_u64[3], b.vals.m512i_u64[2]);
|
||||
+ __m128i int_b2 = _mm_set_epi64x(b.vals.m512i_u64[5], b.vals.m512i_u64[4]);
|
||||
+ __m128i int_b3 = _mm_set_epi64x(b.vals.m512i_u64[7], b.vals.m512i_u64[6]);
|
||||
+ #else
|
||||
__m128i int_b0 = _mm_set_epi64x(b.vals[1], b.vals[0]);
|
||||
__m128i int_b1 = _mm_set_epi64x(b.vals[3], b.vals[2]);
|
||||
__m128i int_b2 = _mm_set_epi64x(b.vals[5], b.vals[4]);
|
||||
__m128i int_b3 = _mm_set_epi64x(b.vals[7], b.vals[6]);
|
||||
+ #endif
|
||||
|
||||
__m512i int32_b0 = cvtepi8_epi32(int_b0);
|
||||
__m512i int32_b1 = cvtepi8_epi32(int_b1);
|
||||
@@ -721,10 +755,17 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
|
||||
Vectorized<float> scale,
|
||||
Vectorized<float> zero_point,
|
||||
Vectorized<float> scale_zp_premul) const {
|
||||
+ #if defined(_MSC_VER) && !defined(__clang__)
|
||||
+ __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]);
|
||||
+ __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]);
|
||||
+ __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]);
|
||||
+ __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]);
|
||||
+ #else
|
||||
__m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
|
||||
__m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
|
||||
__m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
|
||||
__m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
|
||||
+ #endif
|
||||
|
||||
__m512 float_val0 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val0));
|
||||
__m512 float_val1 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val1));
|
||||
@@ -746,10 +787,17 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
|
||||
float_vec_return_type dequantize(
|
||||
Vectorized<float> scale,
|
||||
Vectorized<float> zero_point) const {
|
||||
+ #if defined(_MSC_VER) && !defined(__clang__)
|
||||
+ __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]);
|
||||
+ __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]);
|
||||
+ __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]);
|
||||
+ __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]);
|
||||
+ #else
|
||||
__m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
|
||||
__m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
|
||||
__m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
|
||||
__m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
|
||||
+ #endif
|
||||
|
||||
__m512 float_val0 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val0));
|
||||
__m512 float_val1 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val1));
|
||||
@@ -796,20 +844,34 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
|
||||
}
|
||||
|
||||
int_vec_return_type widening_subtract(Vectorized<c10::quint8> b) const {
|
||||
+ #if defined(_MSC_VER) && !defined(__clang__)
|
||||
+ __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]);
|
||||
+ __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]);
|
||||
+ __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]);
|
||||
+ __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]);
|
||||
+ #else
|
||||
__m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
|
||||
__m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
|
||||
__m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
|
||||
__m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
|
||||
+ #endif
|
||||
|
||||
__m512i int32_val0 = cvtepu8_epi32(int_val0);
|
||||
__m512i int32_val1 = cvtepu8_epi32(int_val1);
|
||||
__m512i int32_val2 = cvtepu8_epi32(int_val2);
|
||||
__m512i int32_val3 = cvtepu8_epi32(int_val3);
|
||||
|
||||
+ #if defined(_MSC_VER) && !defined(__clang__)
|
||||
+ __m128i int_b0 = _mm_set_epi64x(b.vals.m512i_u64[1], b.vals.m512i_u64[0]);
|
||||
+ __m128i int_b1 = _mm_set_epi64x(b.vals.m512i_u64[3], b.vals.m512i_u64[2]);
|
||||
+ __m128i int_b2 = _mm_set_epi64x(b.vals.m512i_u64[5], b.vals.m512i_u64[4]);
|
||||
+ __m128i int_b3 = _mm_set_epi64x(b.vals.m512i_u64[7], b.vals.m512i_u64[6]);
|
||||
+ #else
|
||||
__m128i int_b0 = _mm_set_epi64x(b.vals[1], b.vals[0]);
|
||||
__m128i int_b1 = _mm_set_epi64x(b.vals[3], b.vals[2]);
|
||||
__m128i int_b2 = _mm_set_epi64x(b.vals[5], b.vals[4]);
|
||||
__m128i int_b3 = _mm_set_epi64x(b.vals[7], b.vals[6]);
|
||||
+ #endif
|
||||
|
||||
__m512i int32_b0 = cvtepu8_epi32(int_b0);
|
||||
__m512i int32_b1 = cvtepu8_epi32(int_b1);
|
||||
diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h
|
||||
index adf81dd915c..20cb8ef6dbc 100644
|
||||
--- a/aten/src/ATen/cpu/vec/vec_base.h
|
||||
+++ b/aten/src/ATen/cpu/vec/vec_base.h
|
||||
@@ -36,6 +36,12 @@
|
||||
#include <c10/util/irange.h>
|
||||
#include <c10/util/Load.h>
|
||||
|
||||
+#if defined(__GNUC__)
|
||||
+#define __FORCE_INLINE __attribute__((always_inline)) inline
|
||||
+#elif defined(_MSC_VER)
|
||||
+#define __FORCE_INLINE __forceinline
|
||||
+#endif
|
||||
+
|
||||
// These macros helped us unify vec_base.h
|
||||
#ifdef CPU_CAPABILITY_AVX512
|
||||
#if defined(__GNUC__)
|
||||
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
|
||||
index a6b6f0f7d1d..15d37cf4861 100644
|
||||
--- a/caffe2/CMakeLists.txt
|
||||
+++ b/caffe2/CMakeLists.txt
|
||||
@@ -1787,7 +1787,7 @@ if(BUILD_TEST)
|
||||
endif()
|
||||
else()
|
||||
add_executable(${test_name}_${CPU_CAPABILITY} "${test_src}")
|
||||
- target_link_libraries(${test_name}_${CPU_CAPABILITY} torch_library gtest_main)
|
||||
+ target_link_libraries(${test_name}_${CPU_CAPABILITY} torch_library sleef gtest_main)
|
||||
endif()
|
||||
target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE $<INSTALL_INTERFACE:include>)
|
||||
target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE $<BUILD_INTERFACE:${CMAKE_BINARY_DIR}/include>)
|
||||
diff --git a/third_party/sleef.BUILD b/third_party/sleef.BUILD
|
||||
index 573f9c5b54a..f22a6e905e2 100644
|
||||
--- a/third_party/sleef.BUILD
|
||||
+++ b/third_party/sleef.BUILD
|
||||
@@ -38,6 +38,7 @@ SLEEF_PUBLIC_HEADERS = [
|
||||
SLEEF_PRIVATE_INCLUDES = [
|
||||
"-Iexternal/sleef/src/arch",
|
||||
"-Iexternal/sleef/src/common",
|
||||
+ "-Iexternal/sleef/src/libm",
|
||||
]
|
||||
|
||||
SLEEF_PUBLIC_INCLUDES = [
|
||||
@@ -201,8 +202,6 @@ cc_library(
|
||||
srcs = [
|
||||
"src/libm/rempitab.c",
|
||||
"src/libm/sleefdp.c",
|
||||
- "src/libm/sleefld.c",
|
||||
- "src/libm/sleefqp.c",
|
||||
"src/libm/sleefsp.c",
|
||||
],
|
||||
hdrs = SLEEF_PUBLIC_HEADERS,
|
||||
--
|
||||
2.45.1
|
||||
|
||||
|
|
@ -1,142 +0,0 @@
|
|||
From 201ac4618a1526e048a0d6c02d9bc4cf30bf0ee1 Mon Sep 17 00:00:00 2001
|
||||
From: Tom Rix <Tom.Rix@amd.com>
|
||||
Date: Wed, 14 Aug 2024 17:18:38 -0700
|
||||
Subject: [PATCH] Improve finding and using the rocm_version.h
|
||||
|
||||
On Fedora, the rocm_version.h's path is /usr/include/rocm_version.h
|
||||
So we get this build error:
|
||||
pytorch/aten/src/ATen/hip/tunable/Tunable.cpp:40:10: fatal error:
|
||||
rocm-core/rocm_version.h: No such file or directory
|
||||
40 | #include <rocm-core/rocm_version.h>
|
||||
| ^~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
In other cases, depending on the ROCm release, the header is in either
|
||||
/opt/rocm/include or /opt/rocm/include/rocm-core
|
||||
|
||||
Convert the EXISTS() checks into a find_path.
|
||||
Add -I${ROCM_VERSION_DIR} to the compile options so the header can be
|
||||
found by Tunable.cpp
|
||||
|
||||
Signed-off-by: Tom Rix <Tom.Rix@amd.com>
|
||||
---
|
||||
aten/src/ATen/cuda/tunable/Tunable.cpp | 2 +-
|
||||
cmake/Dependencies.cmake | 1 +
|
||||
cmake/public/LoadHIP.cmake | 72 ++++++++++----------------
|
||||
3 files changed, 30 insertions(+), 45 deletions(-)
|
||||
|
||||
diff --git a/aten/src/ATen/cuda/tunable/Tunable.cpp b/aten/src/ATen/cuda/tunable/Tunable.cpp
|
||||
index 1b7c89875855..32c1d70f3152 100644
|
||||
--- a/aten/src/ATen/cuda/tunable/Tunable.cpp
|
||||
+++ b/aten/src/ATen/cuda/tunable/Tunable.cpp
|
||||
@@ -36,7 +36,7 @@
|
||||
|
||||
// for validators
|
||||
#ifdef USE_ROCM
|
||||
-#include <rocm-core/rocm_version.h>
|
||||
+#include <rocm_version.h>
|
||||
#define ROCBLAS_BETA_FEATURES_API
|
||||
#include <rocblas/rocblas.h>
|
||||
#include <hipblaslt/hipblaslt.h>
|
||||
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
|
||||
index 7ef8eabb5162..61bc4d7a54b6 100644
|
||||
--- a/cmake/Dependencies.cmake
|
||||
+++ b/cmake/Dependencies.cmake
|
||||
@@ -1063,6 +1063,7 @@ if(USE_ROCM)
|
||||
list(APPEND HIP_CXX_FLAGS -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP)
|
||||
list(APPEND HIP_CXX_FLAGS -std=c++17)
|
||||
list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
|
||||
+ list(APPEND HIP_CXX_FLAGS -I${ROCM_VERSION_DIR})
|
||||
if(HIP_NEW_TYPE_ENUMS)
|
||||
list(APPEND HIP_CXX_FLAGS -DHIP_NEW_TYPE_ENUMS)
|
||||
endif()
|
||||
diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake
|
||||
index 1c0d3a203991..6a7e3bd163f5 100644
|
||||
--- a/cmake/public/LoadHIP.cmake
|
||||
+++ b/cmake/public/LoadHIP.cmake
|
||||
@@ -42,55 +42,39 @@ find_package_and_print_version(HIP 1.0)
|
||||
|
||||
if(HIP_FOUND)
|
||||
set(PYTORCH_FOUND_HIP TRUE)
|
||||
- set(FOUND_ROCM_VERSION_H FALSE)
|
||||
-
|
||||
set(PROJECT_RANDOM_BINARY_DIR "${PROJECT_BINARY_DIR}")
|
||||
- set(file "${PROJECT_BINARY_DIR}/detect_rocm_version.cc")
|
||||
|
||||
# Find ROCM version for checks
|
||||
# ROCM 5.0 and later will have header api for version management
|
||||
- if(EXISTS ${ROCM_INCLUDE_DIRS}/rocm_version.h)
|
||||
- set(FOUND_ROCM_VERSION_H TRUE)
|
||||
- file(WRITE ${file} ""
|
||||
- "#include <rocm_version.h>\n"
|
||||
- )
|
||||
- elseif(EXISTS ${ROCM_INCLUDE_DIRS}/rocm-core/rocm_version.h)
|
||||
- set(FOUND_ROCM_VERSION_H TRUE)
|
||||
- file(WRITE ${file} ""
|
||||
- "#include <rocm-core/rocm_version.h>\n"
|
||||
- )
|
||||
- else()
|
||||
- message("********************* rocm_version.h couldnt be found ******************\n")
|
||||
- endif()
|
||||
-
|
||||
- if(FOUND_ROCM_VERSION_H)
|
||||
- file(APPEND ${file} ""
|
||||
- "#include <cstdio>\n"
|
||||
-
|
||||
- "#ifndef ROCM_VERSION_PATCH\n"
|
||||
- "#define ROCM_VERSION_PATCH 0\n"
|
||||
- "#endif\n"
|
||||
- "#define STRINGIFYHELPER(x) #x\n"
|
||||
- "#define STRINGIFY(x) STRINGIFYHELPER(x)\n"
|
||||
- "int main() {\n"
|
||||
- " printf(\"%d.%d.%s\", ROCM_VERSION_MAJOR, ROCM_VERSION_MINOR, STRINGIFY(ROCM_VERSION_PATCH));\n"
|
||||
- " return 0;\n"
|
||||
- "}\n"
|
||||
- )
|
||||
-
|
||||
- try_run(run_result compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
|
||||
- CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_INCLUDE_DIRS}"
|
||||
- RUN_OUTPUT_VARIABLE rocm_version_from_header
|
||||
- COMPILE_OUTPUT_VARIABLE output_var
|
||||
- )
|
||||
- # We expect the compile to be successful if the include directory exists.
|
||||
- if(NOT compile_result)
|
||||
- message(FATAL_ERROR "Caffe2: Couldn't determine version from header: " ${output_var})
|
||||
- endif()
|
||||
- message(STATUS "Caffe2: Header version is: " ${rocm_version_from_header})
|
||||
- set(ROCM_VERSION_DEV_RAW ${rocm_version_from_header})
|
||||
- message("\n***** ROCm version from rocm_version.h ****\n")
|
||||
+ find_path(ROCM_VERSION_DIR rocm_version.h HINTS ${ROCM_INCLUDE_DIRS} ${ROCM_INCLUDE_DIRS}/rocm-core)
|
||||
+ set(file "${PROJECT_BINARY_DIR}/detect_rocm_version.cc")
|
||||
+ file(WRITE ${file} ""
|
||||
+ "#include <rocm_version.h>\n"
|
||||
+ "#include <cstdio>\n"
|
||||
+
|
||||
+ "#ifndef ROCM_VERSION_PATCH\n"
|
||||
+ "#define ROCM_VERSION_PATCH 0\n"
|
||||
+ "#endif\n"
|
||||
+ "#define STRINGIFYHELPER(x) #x\n"
|
||||
+ "#define STRINGIFY(x) STRINGIFYHELPER(x)\n"
|
||||
+ "int main() {\n"
|
||||
+ " printf(\"%d.%d.%s\", ROCM_VERSION_MAJOR, ROCM_VERSION_MINOR, STRINGIFY(ROCM_VERSION_PATCH));\n"
|
||||
+ " return 0;\n"
|
||||
+ "}\n"
|
||||
+ )
|
||||
+
|
||||
+ try_run(run_result compile_result ${PROJECT_RANDOM_BINARY_DIR} ${file}
|
||||
+ CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${ROCM_VERSION_DIR}"
|
||||
+ RUN_OUTPUT_VARIABLE rocm_version_from_header
|
||||
+ COMPILE_OUTPUT_VARIABLE output_var
|
||||
+ )
|
||||
+ # We expect the compile to be successful if the include directory exists.
|
||||
+ if(NOT compile_result)
|
||||
+ message(FATAL_ERROR "Caffe2: Couldn't determine version from header: " ${output_var})
|
||||
endif()
|
||||
+ message(STATUS "Caffe2: Header version is: " ${rocm_version_from_header})
|
||||
+ set(ROCM_VERSION_DEV_RAW ${rocm_version_from_header})
|
||||
+ message("\n***** ROCm version from rocm_version.h ****\n")
|
||||
|
||||
string(REGEX MATCH "^([0-9]+)\.([0-9]+)\.([0-9]+).*$" ROCM_VERSION_DEV_MATCH ${ROCM_VERSION_DEV_RAW})
|
||||
|
||||
--
|
||||
2.46.0
|
||||
|
||||
|
|
@ -1,506 +0,0 @@
|
|||
From f1d65e958afa65882dbfea8b392ab847a84d41ed Mon Sep 17 00:00:00 2001
|
||||
From: Tom Rix <trix@redhat.com>
|
||||
Date: Sat, 29 Jun 2024 04:18:34 -0700
|
||||
Subject: [PATCH] Optionally use hipblaslt
|
||||
|
||||
---
|
||||
aten/src/ATen/cuda/CUDABlas.cpp | 46 ++++++++++++++++++------
|
||||
aten/src/ATen/cuda/CUDAContextLight.h | 4 +++
|
||||
aten/src/ATen/cuda/CublasHandlePool.cpp | 10 ++++--
|
||||
aten/src/ATen/cuda/tunable/TunableGemm.h | 18 +++++++---
|
||||
aten/src/ATen/native/cuda/Blas.cpp | 18 +++++++++-
|
||||
cmake/Dependencies.cmake | 3 ++
|
||||
cmake/public/LoadHIP.cmake | 2 +-
|
||||
7 files changed, 82 insertions(+), 19 deletions(-)
|
||||
|
||||
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
|
||||
index ce991a9bcad4..3f0d17b52778 100644
|
||||
--- a/aten/src/ATen/cuda/CUDABlas.cpp
|
||||
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
|
||||
@@ -14,7 +14,9 @@
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
+#ifdef USE_HIPBLASLT
|
||||
#include <hipblaslt/hipblaslt-ext.hpp>
|
||||
+#endif
|
||||
// until hipblas has an API to accept flags, we must use rocblas here
|
||||
#include <hipblas/hipblas.h>
|
||||
#include <rocblas/rocblas.h>
|
||||
@@ -182,6 +184,9 @@ uint32_t _getAlignment(uintptr_t address) {
|
||||
static size_t _parseChosenWorkspaceSize() {
|
||||
const char * val = getenv("CUBLASLT_WORKSPACE_SIZE");
|
||||
#ifdef USE_ROCM
|
||||
+#ifndef USE_HIPBLASLT
|
||||
+ return 0;
|
||||
+#endif
|
||||
if (!val) {
|
||||
// accept either env var
|
||||
val = getenv("HIPBLASLT_WORKSPACE_SIZE");
|
||||
@@ -235,6 +240,7 @@ namespace at::cuda::blas {
|
||||
} while (0)
|
||||
|
||||
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
namespace {
|
||||
// Following the pattern of CuSparseDescriptor
|
||||
// Defined here for now because this is the only place cublas_lt interface is
|
||||
@@ -318,7 +324,6 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
|
||||
};
|
||||
} // namespace
|
||||
|
||||
-
|
||||
template <typename Dtype>
|
||||
inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
|
||||
cudaDataType_t abcType = CUDA_R_32F;
|
||||
@@ -452,7 +457,7 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
|
||||
" scaleType ",
|
||||
scaleType);
|
||||
}
|
||||
-
|
||||
+#endif
|
||||
|
||||
template <typename Dtype>
|
||||
inline void bgemm_internal_cublas(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
|
||||
@@ -608,10 +613,13 @@ void bgemm_internal<double>(CUDABLAS_BGEMM_ARGTYPES(double))
|
||||
template <>
|
||||
void bgemm_internal<float>(CUDABLAS_BGEMM_ARGTYPES(float))
|
||||
{
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
bgemm_internal_cublaslt<float>(CUDABLAS_BGEMM_ARGS(float));
|
||||
}
|
||||
- else {
|
||||
+ else
|
||||
+#endif
|
||||
+ {
|
||||
bgemm_internal_cublas<float>(CUDABLAS_BGEMM_ARGS(float));
|
||||
}
|
||||
}
|
||||
@@ -651,10 +659,13 @@ void bgemm_internal<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<fl
|
||||
template <>
|
||||
void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half))
|
||||
{
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
bgemm_internal_cublaslt<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
|
||||
}
|
||||
- else {
|
||||
+ else
|
||||
+#endif
|
||||
+ {
|
||||
bgemm_internal_cublas<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
|
||||
}
|
||||
}
|
||||
@@ -662,10 +673,13 @@ void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half))
|
||||
template <>
|
||||
void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
|
||||
{
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
bgemm_internal_cublaslt<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
- else {
|
||||
+ else
|
||||
+#endif
|
||||
+ {
|
||||
bgemm_internal_cublas<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
}
|
||||
@@ -781,11 +795,13 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
|
||||
}
|
||||
}
|
||||
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
template <typename Dtype>
|
||||
inline void gemm_internal_cublaslt(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
|
||||
// forward to bgemm implementation but set strides and batches to 0
|
||||
bgemm_internal_cublaslt(transa, transb, m, n, k, alpha, a, lda, 0, b, ldb, 0, beta, c, ldc, 0, 0);
|
||||
}
|
||||
+#endif
|
||||
|
||||
template <typename Dtype>
|
||||
inline void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
|
||||
@@ -1008,10 +1024,13 @@ void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
|
||||
template <>
|
||||
void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
|
||||
{
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
- else {
|
||||
+ else
|
||||
+#endif
|
||||
+ {
|
||||
gemm_internal_cublas<float>(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
}
|
||||
@@ -1051,10 +1070,13 @@ void gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<floa
|
||||
template <>
|
||||
void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
|
||||
{
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
- else {
|
||||
+ else
|
||||
+#endif
|
||||
+ {
|
||||
gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
}
|
||||
@@ -1062,10 +1084,13 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
|
||||
template <>
|
||||
void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
|
||||
{
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
- else {
|
||||
+ else
|
||||
+#endif
|
||||
+ {
|
||||
gemm_internal_cublas<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
}
|
||||
@@ -1177,7 +1202,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
}
|
||||
}
|
||||
|
||||
-
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
template <typename Dtype>
|
||||
void gemm_and_bias(
|
||||
bool transpose_mat1,
|
||||
@@ -1410,7 +1435,7 @@ void scaled_gemm(
|
||||
ScalarType result_dtype,
|
||||
void* amax_ptr,
|
||||
bool use_fast_accum) {
|
||||
-#if CUDA_VERSION >= 11080 || defined(USE_ROCM)
|
||||
+#if CUDA_VERSION >= 11080 || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
const auto computeType = CUBLAS_COMPUTE_32F;
|
||||
const auto scaleType = CUDA_R_32F;
|
||||
const int8_t fastAccuMode = use_fast_accum ? 1 : 0;
|
||||
@@ -1681,6 +1706,7 @@ void int8_gemm(
|
||||
" scaleType ",
|
||||
scaleType);
|
||||
}
|
||||
+#endif
|
||||
|
||||
template <>
|
||||
void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float)) {
|
||||
diff --git a/aten/src/ATen/cuda/CUDAContextLight.h b/aten/src/ATen/cuda/CUDAContextLight.h
|
||||
index f2b657ced51b..f0ee613c4208 100644
|
||||
--- a/aten/src/ATen/cuda/CUDAContextLight.h
|
||||
+++ b/aten/src/ATen/cuda/CUDAContextLight.h
|
||||
@@ -9,7 +9,9 @@
|
||||
|
||||
// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
|
||||
// added bf16 support
|
||||
+#if (!defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)))
|
||||
#include <cublasLt.h>
|
||||
+#endif
|
||||
|
||||
#ifdef CUDART_VERSION
|
||||
#include <cusolverDn.h>
|
||||
@@ -80,7 +82,9 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
|
||||
/* Handles */
|
||||
TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
|
||||
TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
|
||||
+#if (!defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)))
|
||||
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
|
||||
+#endif
|
||||
|
||||
TORCH_CUDA_CPP_API void clearCublasWorkspaces();
|
||||
|
||||
diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp
|
||||
index 8eac525b3695..abfdf7a23847 100644
|
||||
--- a/aten/src/ATen/cuda/CublasHandlePool.cpp
|
||||
+++ b/aten/src/ATen/cuda/CublasHandlePool.cpp
|
||||
@@ -29,7 +29,7 @@ namespace at::cuda {
|
||||
|
||||
namespace {
|
||||
|
||||
-#if defined(USE_ROCM)
|
||||
+#if defined(USE_ROCM) && defined(USE_HIPBLASLT)
|
||||
void createCublasLtHandle(cublasLtHandle_t *handle) {
|
||||
TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
|
||||
}
|
||||
@@ -191,8 +191,9 @@ cublasHandle_t getCurrentCUDABlasHandle() {
|
||||
return handle;
|
||||
}
|
||||
|
||||
-cublasLtHandle_t getCurrentCUDABlasLtHandle() {
|
||||
#ifdef USE_ROCM
|
||||
+#if defined(USE_HIPBLASLT)
|
||||
+cublasLtHandle_t getCurrentCUDABlasLtHandle() {
|
||||
c10::DeviceIndex device = 0;
|
||||
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
|
||||
|
||||
@@ -213,9 +214,12 @@ cublasLtHandle_t getCurrentCUDABlasLtHandle() {
|
||||
|
||||
auto handle = myPoolWindow->reserve(device);
|
||||
return handle;
|
||||
+}
|
||||
+#endif
|
||||
#else
|
||||
+cublasLtHandle_t getCurrentCUDABlasLtHandle() {
|
||||
return reinterpret_cast<cublasLtHandle_t>(getCurrentCUDABlasHandle());
|
||||
-#endif
|
||||
}
|
||||
+#endif
|
||||
|
||||
} // namespace at::cuda
|
||||
diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h
|
||||
index 53e6154120c9..fa1d664696db 100644
|
||||
--- a/aten/src/ATen/cuda/tunable/TunableGemm.h
|
||||
+++ b/aten/src/ATen/cuda/tunable/TunableGemm.h
|
||||
@@ -11,7 +11,9 @@
|
||||
|
||||
#include <ATen/cuda/tunable/GemmCommon.h>
|
||||
#ifdef USE_ROCM
|
||||
+#ifdef USE_HIPBLASLT
|
||||
#include <ATen/cuda/tunable/GemmHipblaslt.h>
|
||||
+#endif
|
||||
#include <ATen/cuda/tunable/GemmRocblas.h>
|
||||
#endif
|
||||
#include <ATen/cuda/tunable/StreamTimer.h>
|
||||
@@ -65,6 +67,7 @@ class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>>
|
||||
}
|
||||
};
|
||||
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
template <typename T>
|
||||
class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
|
||||
public:
|
||||
@@ -94,6 +97,7 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
|
||||
return OK;
|
||||
}
|
||||
};
|
||||
+#endif
|
||||
|
||||
template <typename T>
|
||||
inline bool IsZero(T v) {
|
||||
@@ -191,6 +195,7 @@ static void AddRocblasValidator() {
|
||||
}
|
||||
}
|
||||
|
||||
+#ifdef USE_HIPBLASLT
|
||||
static void AddHipblasltValidator() {
|
||||
auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
|
||||
if (validators.find("HIPBLASLT_VERSION") == validators.end()) {
|
||||
@@ -205,6 +210,7 @@ static void AddHipblasltValidator() {
|
||||
[hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; });
|
||||
}
|
||||
}
|
||||
+#endif
|
||||
|
||||
static void AddRocmValidator() {
|
||||
auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
|
||||
@@ -243,7 +249,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
|
||||
}
|
||||
AddRocblasValidator();
|
||||
}
|
||||
-
|
||||
+#ifdef USE_HIPBLASLT
|
||||
static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
|
||||
if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
|
||||
rocm_validators = true;
|
||||
@@ -257,7 +263,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
|
||||
}
|
||||
AddHipblasltValidator();
|
||||
}
|
||||
-
|
||||
+#endif
|
||||
if (rocm_validators) {
|
||||
AddRocmValidator();
|
||||
}
|
||||
@@ -286,7 +292,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
|
||||
}
|
||||
AddRocblasValidator();
|
||||
}
|
||||
-
|
||||
+#ifdef USE_HIPBLASLT
|
||||
static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
|
||||
if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
|
||||
rocm_validators = true;
|
||||
@@ -300,7 +306,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
|
||||
}
|
||||
AddHipblasltValidator();
|
||||
}
|
||||
-
|
||||
+#endif
|
||||
if (rocm_validators) {
|
||||
AddRocmValidator();
|
||||
}
|
||||
@@ -312,6 +318,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
|
||||
}
|
||||
};
|
||||
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
|
||||
class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer> {
|
||||
public:
|
||||
@@ -321,10 +328,12 @@ class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer>
|
||||
auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
+#ifdef USE_HIPBLASLT
|
||||
for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps<AT, BT, CT, ALayout, BLayout>()) {
|
||||
this->RegisterOp(std::move(name), std::move(op));
|
||||
}
|
||||
AddHipblasltValidator();
|
||||
+#endif
|
||||
AddRocmValidator();
|
||||
#endif
|
||||
}
|
||||
@@ -337,6 +346,7 @@ class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer>
|
||||
"_", BlasOpToString(ALayout), BlasOpToString(BLayout));
|
||||
}
|
||||
};
|
||||
+#endif
|
||||
|
||||
#undef XSTRINGIFY
|
||||
#undef STRINGIFY
|
||||
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
|
||||
index 84c59a4fd0d7..56ad5de3bf2d 100644
|
||||
--- a/aten/src/ATen/native/cuda/Blas.cpp
|
||||
+++ b/aten/src/ATen/native/cuda/Blas.cpp
|
||||
@@ -173,6 +173,7 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
|
||||
}
|
||||
|
||||
static bool getDisableAddmmCudaLt() {
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT");
|
||||
#ifdef USE_ROCM
|
||||
// if we enable tunable op, it'll take priority over just hipblaslt (heuristics)
|
||||
@@ -196,10 +197,14 @@ static bool getDisableAddmmCudaLt() {
|
||||
}
|
||||
return false;
|
||||
#endif
|
||||
+#else
|
||||
+ return true;
|
||||
+#endif
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
static bool isSupportedHipLtROCmArch(int index) {
|
||||
+#ifdef USE_HIPBLASLT
|
||||
hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
|
||||
std::string device_arch = prop->gcnArchName;
|
||||
static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
|
||||
@@ -210,6 +215,7 @@ static bool isSupportedHipLtROCmArch(int index) {
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!");
|
||||
+#endif
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
@@ -235,6 +241,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
at::ScalarType scalar_type = self.scalar_type();
|
||||
c10::MaybeOwned<Tensor> self_;
|
||||
if (&result != &self) {
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11040)) || defined(USE_ROCM)
|
||||
// Strangely, if mat2 has only 1 row or column, we get
|
||||
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
|
||||
@@ -276,13 +283,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
scalar_type != at::ScalarType::BFloat16));
|
||||
#endif
|
||||
}
|
||||
+#endif
|
||||
#endif
|
||||
if (!useLtInterface) {
|
||||
self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
|
||||
}
|
||||
self__sizes = self_->sizes();
|
||||
} else {
|
||||
-#if defined(USE_ROCM)
|
||||
+#if defined(USE_ROCM) && defined(USE_HIPBLASLT)
|
||||
useLtInterface = !disable_addmm_cuda_lt &&
|
||||
result.dim() == 2 && result.is_contiguous() &&
|
||||
isSupportedHipLtROCmArch(self.device().index()) &&
|
||||
@@ -334,6 +342,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
|
||||
|
||||
if (useLtInterface) {
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
#if defined(USE_ROCM)
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
@@ -394,6 +403,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
activation_epilogue
|
||||
);
|
||||
});
|
||||
+#endif
|
||||
#endif
|
||||
} else
|
||||
{
|
||||
@@ -803,6 +813,7 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) {
|
||||
}
|
||||
|
||||
static bool _scaled_mm_allowed_device() {
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
#ifdef USE_ROCM
|
||||
std::string device_arch = dprops->gcnArchName;
|
||||
@@ -817,6 +828,9 @@ static bool _scaled_mm_allowed_device() {
|
||||
#else
|
||||
return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
|
||||
#endif
|
||||
+#else
|
||||
+ return false;
|
||||
+#endif
|
||||
}
|
||||
|
||||
// Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax
|
||||
@@ -850,6 +864,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
// Check sizes
|
||||
bool allowed_device = _scaled_mm_allowed_device();
|
||||
TORCH_CHECK(allowed_device, "torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+");
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
|
||||
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
|
||||
TORCH_CHECK(
|
||||
@@ -1025,6 +1040,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && ROCM_VERSION < 60200
|
||||
// ROCm's hipBLASLt does not support amax before 6.2, so calculate separately
|
||||
amax = at::max(at::abs(out.to(kFloat)));
|
||||
+#endif
|
||||
#endif
|
||||
|
||||
return {out, amax};
|
||||
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
|
||||
index f1f2eb7cec31..8d05e834bbc5 100644
|
||||
--- a/cmake/Dependencies.cmake
|
||||
+++ b/cmake/Dependencies.cmake
|
||||
@@ -1052,6 +1052,9 @@ if(USE_ROCM)
|
||||
list(APPEND HIP_CXX_FLAGS -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP)
|
||||
list(APPEND HIP_CXX_FLAGS -std=c++17)
|
||||
list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
|
||||
+ if(hipblaslt_FOUND)
|
||||
+ list(APPEND HIP_CXX_FLAGS -DUSE_HIPBLASLT)
|
||||
+ endif()
|
||||
if(HIP_NEW_TYPE_ENUMS)
|
||||
list(APPEND HIP_CXX_FLAGS -DHIP_NEW_TYPE_ENUMS)
|
||||
endif()
|
||||
diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake
|
||||
index fa39156031ff..df4836847fdf 100644
|
||||
--- a/cmake/public/LoadHIP.cmake
|
||||
+++ b/cmake/public/LoadHIP.cmake
|
||||
@@ -155,7 +155,7 @@ if(HIP_FOUND)
|
||||
find_package_and_print_version(hiprand REQUIRED)
|
||||
find_package_and_print_version(rocblas REQUIRED)
|
||||
find_package_and_print_version(hipblas REQUIRED)
|
||||
- find_package_and_print_version(hipblaslt REQUIRED)
|
||||
+ find_package_and_print_version(hipblaslt)
|
||||
find_package_and_print_version(miopen REQUIRED)
|
||||
find_package_and_print_version(hipfft REQUIRED)
|
||||
find_package_and_print_version(hipsparse REQUIRED)
|
||||
--
|
||||
2.45.2
|
||||
|
||||
|
|
@ -1,952 +0,0 @@
|
|||
From 273f23698c887b52e66c2abec8101b7398f0f9c4 Mon Sep 17 00:00:00 2001
|
||||
From: "Benjamin A. Beasley" <code@musicinmybrain.net>
|
||||
Date: Wed, 5 Jun 2024 11:06:02 -0400
|
||||
Subject: [PATCH] Patch for sleef 3.6
|
||||
|
||||
---
|
||||
...ectorization-on-windows-submodule-sl.patch | 910 ++++++++++++++++++
|
||||
python-torch.spec | 11 +
|
||||
2 files changed, 921 insertions(+)
|
||||
create mode 100644 0001-Enable-x86-CPU-vectorization-on-windows-submodule-sl.patch
|
||||
|
||||
diff --git a/0001-Enable-x86-CPU-vectorization-on-windows-submodule-sl.patch b/0001-Enable-x86-CPU-vectorization-on-windows-submodule-sl.patch
|
||||
new file mode 100644
|
||||
index 000000000000..562f55b742c2
|
||||
--- /dev/null
|
||||
+++ b/0001-Enable-x86-CPU-vectorization-on-windows-submodule-sl.patch
|
||||
@@ -0,0 +1,910 @@
|
||||
+From 3d1e4b3e5ddcdd2717e590c635097163fef64c83 Mon Sep 17 00:00:00 2001
|
||||
+From: Xu Han <xu.han@intel.com>
|
||||
+Date: Sun, 31 Mar 2024 03:07:32 +0000
|
||||
+Subject: [PATCH] Enable x86 CPU vectorization on windows [submodule sleef]
|
||||
+ (#118980)
|
||||
+
|
||||
+Enable VEC on Windows OS.
|
||||
+1. Fix some type defination gap between Windows and Linux.
|
||||
+2. Fix some operator not support on Windows, such as [], /.
|
||||
+3. Enable static sleef library build on Windows.
|
||||
+4. Disable unsupported function overloading on MSVC.
|
||||
+5. Upgrade submodule sleef lib, which fixed build issue on Windows.
|
||||
+6. Fixed bazel build issues.
|
||||
+7. Fix test app not link to sleef on Windows.
|
||||
+
|
||||
+Note: If rebuild fail after pulled this PR, please sync `sleef` submodule by run:
|
||||
+```cmd
|
||||
+git submodule sync
|
||||
+git submodule update --init --recursive
|
||||
+```
|
||||
+
|
||||
+Pull Request resolved: https://github.com/pytorch/pytorch/pull/118980
|
||||
+Approved by: https://github.com/jgong5, https://github.com/ezyang, https://github.com/malfet
|
||||
+---
|
||||
+ aten/src/ATen/CMakeLists.txt | 48 ++++++--------
|
||||
+ aten/src/ATen/cpu/vec/vec256/vec256.h | 14 ++--
|
||||
+ .../src/ATen/cpu/vec/vec256/vec256_bfloat16.h | 27 ++++++--
|
||||
+ .../cpu/vec/vec256/vec256_complex_double.h | 7 +-
|
||||
+ .../cpu/vec/vec256/vec256_complex_float.h | 7 +-
|
||||
+ aten/src/ATen/cpu/vec/vec256/vec256_double.h | 5 +-
|
||||
+ aten/src/ATen/cpu/vec/vec256/vec256_float.h | 15 +++--
|
||||
+ aten/src/ATen/cpu/vec/vec256/vec256_qint.h | 12 +++-
|
||||
+ aten/src/ATen/cpu/vec/vec512/vec512.h | 14 ++--
|
||||
+ .../src/ATen/cpu/vec/vec512/vec512_bfloat16.h | 27 ++++++--
|
||||
+ .../cpu/vec/vec512/vec512_complex_double.h | 7 +-
|
||||
+ .../cpu/vec/vec512/vec512_complex_float.h | 7 +-
|
||||
+ aten/src/ATen/cpu/vec/vec512/vec512_double.h | 5 +-
|
||||
+ aten/src/ATen/cpu/vec/vec512/vec512_float.h | 15 +++--
|
||||
+ aten/src/ATen/cpu/vec/vec512/vec512_qint.h | 66 ++++++++++++++++++-
|
||||
+ aten/src/ATen/cpu/vec/vec_base.h | 6 ++
|
||||
+ caffe2/CMakeLists.txt | 2 +-
|
||||
+ third_party/sleef.BUILD | 3 +-
|
||||
+ 18 files changed, 194 insertions(+), 93 deletions(-)
|
||||
+
|
||||
+diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
|
||||
+index bf425af5fa9..58d5828e8ca 100644
|
||||
+--- a/aten/src/ATen/CMakeLists.txt
|
||||
++++ b/aten/src/ATen/CMakeLists.txt
|
||||
+@@ -419,32 +419,25 @@ if(NOT CMAKE_SYSTEM_PROCESSOR MATCHES "^(s390x|ppc64le)$")
|
||||
+ list(APPEND ATen_CPU_DEPENDENCY_LIBS cpuinfo)
|
||||
+ endif()
|
||||
+
|
||||
+-if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
|
||||
+- # Preserve values for the main build
|
||||
+- set(__aten_sleef_build_shared_libs ${BUILD_SHARED_LIBS})
|
||||
+- set(__aten_sleef_build_tests ${BUILD_TESTS})
|
||||
+-
|
||||
+- # Unset our restrictive C++ flags here and reset them later.
|
||||
+- # Remove this once we use proper target_compile_options.
|
||||
+- set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS})
|
||||
+- set(CMAKE_CXX_FLAGS)
|
||||
+-
|
||||
+- # Bump up optimization level for sleef to -O1, since at -O0 the compiler
|
||||
+- # excessively spills intermediate vector registers to the stack
|
||||
+- # and makes things run impossibly slowly
|
||||
+- set(OLD_CMAKE_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
|
||||
+- if(${CMAKE_C_FLAGS_DEBUG} MATCHES "-O0")
|
||||
+- string(REGEX REPLACE "-O0" "-O1" CMAKE_C_FLAGS_DEBUG ${OLD_CMAKE_C_FLAGS_DEBUG})
|
||||
+- else()
|
||||
+- set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -O1")
|
||||
++if(NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
|
||||
++ if(NOT MSVC)
|
||||
++ # Bump up optimization level for sleef to -O1, since at -O0 the compiler
|
||||
++ # excessively spills intermediate vector registers to the stack
|
||||
++ # and makes things run impossibly slowly
|
||||
++ set(OLD_CMAKE_C_FLAGS_DEBUG ${CMAKE_C_FLAGS_DEBUG})
|
||||
++ if(${CMAKE_C_FLAGS_DEBUG} MATCHES "-O0")
|
||||
++ string(REGEX REPLACE "-O0" "-O1" CMAKE_C_FLAGS_DEBUG ${OLD_CMAKE_C_FLAGS_DEBUG})
|
||||
++ else()
|
||||
++ set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -O1")
|
||||
++ endif()
|
||||
+ endif()
|
||||
+
|
||||
+ if(NOT USE_SYSTEM_SLEEF)
|
||||
+- set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build sleef static" FORCE)
|
||||
+- set(BUILD_DFT OFF CACHE BOOL "Don't build sleef DFT lib" FORCE)
|
||||
+- set(BUILD_GNUABI_LIBS OFF CACHE BOOL "Don't build sleef gnuabi libs" FORCE)
|
||||
+- set(BUILD_TESTS OFF CACHE BOOL "Don't build sleef tests" FORCE)
|
||||
+- set(OLD_CMAKE_BUILD_TYPE ${CMAKE_BUILD_TYPE})
|
||||
++ set(SLEEF_BUILD_SHARED_LIBS OFF CACHE BOOL "Build sleef static" FORCE)
|
||||
++ set(SLEEF_BUILD_DFT OFF CACHE BOOL "Don't build sleef DFT lib" FORCE)
|
||||
++ set(SLEEF_BUILD_GNUABI_LIBS OFF CACHE BOOL "Don't build sleef gnuabi libs" FORCE)
|
||||
++ set(SLEEF_BUILD_TESTS OFF CACHE BOOL "Don't build sleef tests" FORCE)
|
||||
++ set(SLEEF_BUILD_SCALAR_LIB OFF CACHE BOOL "libsleefscalar will be built." FORCE)
|
||||
+ if(CMAKE_SYSTEM_NAME STREQUAL "Darwin")
|
||||
+ if(CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64" OR CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
|
||||
+ set(DISABLE_SVE ON CACHE BOOL "Xcode's clang-12.5 crashes while trying to compile SVE code" FORCE)
|
||||
+@@ -465,12 +458,9 @@ if(NOT MSVC AND NOT EMSCRIPTEN AND NOT INTERN_BUILD_MOBILE)
|
||||
+ endif()
|
||||
+ list(APPEND ATen_CPU_DEPENDENCY_LIBS sleef)
|
||||
+
|
||||
+- set(CMAKE_C_FLAGS_DEBUG ${OLD_CMAKE_C_FLAGS_DEBUG})
|
||||
+- set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS})
|
||||
+-
|
||||
+- # Set these back. TODO: Use SLEEF_ to pass these instead
|
||||
+- set(BUILD_SHARED_LIBS ${__aten_sleef_build_shared_libs} CACHE BOOL "Build shared libs" FORCE)
|
||||
+- set(BUILD_TESTS ${__aten_sleef_build_tests} CACHE BOOL "Build tests" FORCE)
|
||||
++ if(NOT MSVC)
|
||||
++ set(CMAKE_C_FLAGS_DEBUG ${OLD_CMAKE_C_FLAGS_DEBUG})
|
||||
++ endif()
|
||||
+ endif()
|
||||
+
|
||||
+ if(USE_CUDA AND NOT USE_ROCM)
|
||||
+diff --git a/aten/src/ATen/cpu/vec/vec256/vec256.h b/aten/src/ATen/cpu/vec/vec256/vec256.h
|
||||
+index 800b027e469..c431fa3c605 100644
|
||||
+--- a/aten/src/ATen/cpu/vec/vec256/vec256.h
|
||||
++++ b/aten/src/ATen/cpu/vec/vec256/vec256.h
|
||||
+@@ -69,7 +69,7 @@ std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
|
||||
+ }
|
||||
+
|
||||
+
|
||||
+-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX2)
|
||||
+
|
||||
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
+
|
||||
+@@ -94,7 +94,8 @@ inline Vectorized<double> cast<double, int64_t>(const Vectorized<int64_t>& src)
|
||||
+ }
|
||||
+
|
||||
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
+-
|
||||
++#ifndef _MSC_VER
|
||||
++// MSVC is not working well on complex function overload.
|
||||
+ template<int64_t scale = 1>
|
||||
+ std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
|
||||
+ inline gather(const double* base_addr, const Vectorized<int64_t>& vindex) {
|
||||
+@@ -106,9 +107,10 @@ std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorize
|
||||
+ inline gather(const float* base_addr, const Vectorized<int32_t>& vindex) {
|
||||
+ return _mm256_i32gather_ps(base_addr, vindex, scale);
|
||||
+ }
|
||||
+-
|
||||
++#endif
|
||||
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
+-
|
||||
++#ifndef _MSC_VER
|
||||
++// MSVC is not working well on complex function overload.
|
||||
+ template<int64_t scale = 1>
|
||||
+ std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
|
||||
+ inline mask_gather(const Vectorized<double>& src, const double* base_addr,
|
||||
+@@ -122,7 +124,7 @@ inline mask_gather(const Vectorized<float>& src, const float* base_addr,
|
||||
+ const Vectorized<int32_t>& vindex, Vectorized<float>& mask) {
|
||||
+ return _mm256_mask_i32gather_ps(src, base_addr, vindex, mask, scale);
|
||||
+ }
|
||||
+-
|
||||
++#endif
|
||||
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
+
|
||||
+ // Only works for inputs in the range: [-2^51, 2^51]
|
||||
+@@ -302,6 +304,6 @@ inline Vectorized<uint8_t> flip(const Vectorized<uint8_t> & v) {
|
||||
+ return flip8(v);
|
||||
+ }
|
||||
+
|
||||
+-#endif // (defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
++#endif // (defined(CPU_CAPABILITY_AVX2)
|
||||
+
|
||||
+ }} // namepsace at::vec::CPU_CAPABILITY
|
||||
+diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
|
||||
+index 3e26213d6d2..66557436c70 100644
|
||||
+--- a/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
|
||||
++++ b/aten/src/ATen/cpu/vec/vec256/vec256_bfloat16.h
|
||||
+@@ -7,7 +7,8 @@
|
||||
+ #include <ATen/cpu/vec/vec_base.h>
|
||||
+ #include <c10/util/irange.h>
|
||||
+
|
||||
+-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX2)
|
||||
++#define SLEEF_STATIC_LIBS
|
||||
+ #include <sleef.h>
|
||||
+ #endif
|
||||
+
|
||||
+@@ -18,7 +19,18 @@ namespace at::vec {
|
||||
+ // See Note [CPU_CAPABILITY namespace]
|
||||
+ inline namespace CPU_CAPABILITY {
|
||||
+
|
||||
+-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX2)
|
||||
++
|
||||
++#ifndef SLEEF_CONST
|
||||
++#if (defined(__GNUC__) || defined(__CLANG__)) && !defined(__INTEL_COMPILER)
|
||||
++#define SLEEF_CONST const
|
||||
++#else
|
||||
++#define SLEEF_CONST
|
||||
++#endif
|
||||
++#define SLEEF_CONST_OLD SLEEF_CONST
|
||||
++#else
|
||||
++#define SLEEF_CONST_OLD
|
||||
++#endif
|
||||
+
|
||||
+ // bfloat16 conversion
|
||||
+ static inline void cvtbf16_fp32(const __m128i& a, __m256& o) {
|
||||
+@@ -265,7 +277,8 @@ public:
|
||||
+ }
|
||||
+ return b;
|
||||
+ }
|
||||
+- Vectorized<T> map(const __m256 (*const vop)(__m256)) const {
|
||||
++
|
||||
++ Vectorized<T> map(SLEEF_CONST __m256 (*SLEEF_CONST_OLD vop)(__m256)) const {
|
||||
+ __m256 lo, hi;
|
||||
+ cvt_to_fp32<T>(values, lo, hi);
|
||||
+ const auto o1 = vop(lo);
|
||||
+@@ -1026,7 +1039,7 @@ inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const V
|
||||
+ CONVERT_VECTORIZED_INIT(BFloat16, bfloat16);
|
||||
+ CONVERT_VECTORIZED_INIT(Half, half);
|
||||
+
|
||||
+-#else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
++#else // defined(CPU_CAPABILITY_AVX2)
|
||||
+
|
||||
+ #define CONVERT_NON_VECTORIZED_INIT(type, name) \
|
||||
+ inline std::tuple<Vectorized<float>, Vectorized<float>> convert_##name##_float(const Vectorized<type>& a) { \
|
||||
+@@ -1051,9 +1064,9 @@ inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const V
|
||||
+ CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16);
|
||||
+ CONVERT_NON_VECTORIZED_INIT(Half, half);
|
||||
+
|
||||
+-#endif // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
++#endif // defined(CPU_CAPABILITY_AVX2)
|
||||
+
|
||||
+-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX2)
|
||||
+ #define LOAD_FP32_VECTORIZED_INIT(type, name) \
|
||||
+ inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
|
||||
+ auto values = _mm_loadu_si128(reinterpret_cast<const __m128i*>(data)); \
|
||||
+@@ -1072,7 +1085,7 @@ inline void load_fp32_from_##name(const type *data, Vectorized<float>& out1, Vec
|
||||
+ LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16);
|
||||
+ LOAD_FP32_VECTORIZED_INIT(Half, fp16);
|
||||
+
|
||||
+-#else // defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
++#else // defined(CPU_CAPABILITY_AVX2)
|
||||
+ #define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \
|
||||
+ inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
|
||||
+ __at_align__ float values[Vectorized<float>::size()]; \
|
||||
+diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h b/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h
|
||||
+index f93ea1e63c3..6c198fb37d3 100644
|
||||
+--- a/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h
|
||||
++++ b/aten/src/ATen/cpu/vec/vec256/vec256_complex_double.h
|
||||
+@@ -8,7 +8,8 @@
|
||||
+ #include <ATen/cpu/vec/intrinsics.h>
|
||||
+ #include <ATen/cpu/vec/vec_base.h>
|
||||
+
|
||||
+-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX2)
|
||||
++#define SLEEF_STATIC_LIBS
|
||||
+ #include <sleef.h>
|
||||
+ #endif
|
||||
+
|
||||
+@@ -16,7 +17,7 @@ namespace at::vec {
|
||||
+ // See Note [CPU_CAPABILITY namespace]
|
||||
+ inline namespace CPU_CAPABILITY {
|
||||
+
|
||||
+-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX2)
|
||||
+
|
||||
+ template <> class Vectorized<c10::complex<double>> {
|
||||
+ private:
|
||||
+@@ -145,7 +146,7 @@ public:
|
||||
+ auto abs = abs_();
|
||||
+ auto zero = _mm256_setzero_pd();
|
||||
+ auto mask = _mm256_cmp_pd(abs, zero, _CMP_EQ_OQ);
|
||||
+- auto div = values / abs;
|
||||
++ auto div = _mm256_div_pd(values, abs);
|
||||
+ return _mm256_blendv_pd(div, zero, mask);
|
||||
+ }
|
||||
+ __m256d real_() const {
|
||||
+diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h
|
||||
+index 7c142c04b79..c72d4d49274 100644
|
||||
+--- a/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h
|
||||
++++ b/aten/src/ATen/cpu/vec/vec256/vec256_complex_float.h
|
||||
+@@ -7,7 +7,8 @@
|
||||
+ #include <c10/util/irange.h>
|
||||
+ #include <ATen/cpu/vec/intrinsics.h>
|
||||
+ #include <ATen/cpu/vec/vec_base.h>
|
||||
+-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX2)
|
||||
++#define SLEEF_STATIC_LIBS
|
||||
+ #include <sleef.h>
|
||||
+ #endif
|
||||
+
|
||||
+@@ -15,7 +16,7 @@ namespace at::vec {
|
||||
+ // See Note [CPU_CAPABILITY namespace]
|
||||
+ inline namespace CPU_CAPABILITY {
|
||||
+
|
||||
+-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX2)
|
||||
+
|
||||
+ template <> class Vectorized<c10::complex<float>> {
|
||||
+ private:
|
||||
+@@ -180,7 +181,7 @@ public:
|
||||
+ auto abs = abs_();
|
||||
+ auto zero = _mm256_setzero_ps();
|
||||
+ auto mask = _mm256_cmp_ps(abs, zero, _CMP_EQ_OQ);
|
||||
+- auto div = values / abs;
|
||||
++ auto div = _mm256_div_ps(values, abs);
|
||||
+ return _mm256_blendv_ps(div, zero, mask);
|
||||
+ }
|
||||
+ __m256 real_() const {
|
||||
+diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_double.h b/aten/src/ATen/cpu/vec/vec256/vec256_double.h
|
||||
+index bc82d07edd1..bed6da627af 100644
|
||||
+--- a/aten/src/ATen/cpu/vec/vec256/vec256_double.h
|
||||
++++ b/aten/src/ATen/cpu/vec/vec256/vec256_double.h
|
||||
+@@ -6,7 +6,8 @@
|
||||
+ #include <ATen/cpu/vec/intrinsics.h>
|
||||
+ #include <ATen/cpu/vec/vec_base.h>
|
||||
+ #include <c10/util/irange.h>
|
||||
+-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX2)
|
||||
++#define SLEEF_STATIC_LIBS
|
||||
+ #include <sleef.h>
|
||||
+ #endif
|
||||
+
|
||||
+@@ -15,7 +16,7 @@ namespace at::vec {
|
||||
+ inline namespace CPU_CAPABILITY {
|
||||
+
|
||||
+
|
||||
+-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX2)
|
||||
+
|
||||
+ template <> class Vectorized<double> {
|
||||
+ private:
|
||||
+diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_float.h b/aten/src/ATen/cpu/vec/vec256/vec256_float.h
|
||||
+index 886809a0b8a..0e3664cd37b 100644
|
||||
+--- a/aten/src/ATen/cpu/vec/vec256/vec256_float.h
|
||||
++++ b/aten/src/ATen/cpu/vec/vec256/vec256_float.h
|
||||
+@@ -6,7 +6,8 @@
|
||||
+ #include <ATen/cpu/vec/intrinsics.h>
|
||||
+ #include <ATen/cpu/vec/vec_base.h>
|
||||
+ #include <c10/util/irange.h>
|
||||
+-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX2)
|
||||
++#define SLEEF_STATIC_LIBS
|
||||
+ #include <sleef.h>
|
||||
+ #endif
|
||||
+
|
||||
+@@ -14,7 +15,7 @@ namespace at::vec {
|
||||
+ // See Note [CPU_CAPABILITY namespace]
|
||||
+ inline namespace CPU_CAPABILITY {
|
||||
+
|
||||
+-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX2)
|
||||
+
|
||||
+ template <> class Vectorized<float> {
|
||||
+ private:
|
||||
+@@ -226,14 +227,14 @@ public:
|
||||
+ static __m256 vec_factorial_5 =
|
||||
+ _mm256_set1_ps(0.00828929059f); // 1/factorial(5)
|
||||
+ static __m256 vec_exp_log2ef =
|
||||
+- (__m256)_mm256_set1_epi32(0x3fb8aa3b); // log2(e)
|
||||
++ _mm256_castsi256_ps(_mm256_set1_epi32(0x3fb8aa3b)); // log2(e)
|
||||
+ static __m256 vec_half = _mm256_set1_ps(0.5f);
|
||||
+ static __m256 vec_one = _mm256_set1_ps(1.f);
|
||||
+ static __m256 vec_zero = _mm256_set1_ps(0.f);
|
||||
+ static __m256 vec_two = _mm256_set1_ps(2.f);
|
||||
+- static __m256 vec_ln2f = (__m256)_mm256_set1_epi32(0x3f317218); // ln(2)
|
||||
+- static __m256 vec_ln_flt_min = (__m256)_mm256_set1_epi32(0xc2aeac50);
|
||||
+- static __m256 vec_ln_flt_max = (__m256)_mm256_set1_epi32(0x42b17218);
|
||||
++ static __m256 vec_ln2f = _mm256_castsi256_ps(_mm256_set1_epi32(0x3f317218)); // ln(2)
|
||||
++ static __m256 vec_ln_flt_min = _mm256_castsi256_ps(_mm256_set1_epi32(0xc2aeac50));
|
||||
++ static __m256 vec_ln_flt_max = _mm256_castsi256_ps(_mm256_set1_epi32(0x42b17218));
|
||||
+ static __m256i vec_127 = _mm256_set1_epi32(0x0000007f);
|
||||
+ static int n_mantissa_bits = 23;
|
||||
+
|
||||
+@@ -266,7 +267,7 @@ public:
|
||||
+ auto vec_exp_number_i = _mm256_cvtps_epi32(vec_exp_number);
|
||||
+ auto vec_two_pow_n_i = _mm256_add_epi32(vec_exp_number_i, vec_127);
|
||||
+ vec_two_pow_n_i = _mm256_slli_epi32(vec_two_pow_n_i, n_mantissa_bits);
|
||||
+- auto vec_two_pow_n = (__m256)vec_two_pow_n_i;
|
||||
++ auto vec_two_pow_n = _mm256_castsi256_ps(vec_two_pow_n_i);
|
||||
+ vec_two_pow_n =
|
||||
+ _mm256_blendv_ps(vec_two_pow_n, vec_zero, less_ln_flt_min_mask);
|
||||
+
|
||||
+diff --git a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h
|
||||
+index 4128841701a..85e099904cd 100644
|
||||
+--- a/aten/src/ATen/cpu/vec/vec256/vec256_qint.h
|
||||
++++ b/aten/src/ATen/cpu/vec/vec256/vec256_qint.h
|
||||
+@@ -41,11 +41,17 @@
|
||||
+ namespace at::vec {
|
||||
+ inline namespace CPU_CAPABILITY {
|
||||
+
|
||||
+-#if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX2)
|
||||
+
|
||||
++#ifdef _MSC_VER
|
||||
++__declspec(align(64)) struct Vectorizedqi {
|
||||
++ protected:
|
||||
++ __m256i vals;
|
||||
++#else
|
||||
+ struct Vectorizedqi {
|
||||
+ protected:
|
||||
+ __m256i vals __attribute__((aligned(64)));
|
||||
++#endif
|
||||
+
|
||||
+ public:
|
||||
+ Vectorizedqi() {}
|
||||
+@@ -133,7 +139,7 @@ inline convert_float_to_int8(at::vec::Vectorized<float> src) {
|
||||
+ }
|
||||
+
|
||||
+ template <typename T>
|
||||
+-inline void __attribute__((always_inline)) QuantizeAvx2(
|
||||
++__FORCE_INLINE void QuantizeAvx2(
|
||||
+ const float* src,
|
||||
+ T* dst,
|
||||
+ int len,
|
||||
+@@ -1331,5 +1337,5 @@ Vectorized<c10::quint8> inline maximum(const Vectorized<c10::quint8>& a, const V
|
||||
+ return a.maximum(b);
|
||||
+ }
|
||||
+
|
||||
+-#endif // if defined(CPU_CAPABILITY_AVX2) && !defined(_MSC_VER)
|
||||
++#endif // if defined(CPU_CAPABILITY_AVX2)
|
||||
+ }} // namespace at::vec::CPU_CAPABILITY
|
||||
+diff --git a/aten/src/ATen/cpu/vec/vec512/vec512.h b/aten/src/ATen/cpu/vec/vec512/vec512.h
|
||||
+index fe96d123e64..87f723d782c 100644
|
||||
+--- a/aten/src/ATen/cpu/vec/vec512/vec512.h
|
||||
++++ b/aten/src/ATen/cpu/vec/vec512/vec512.h
|
||||
+@@ -55,7 +55,7 @@ std::ostream& operator<<(std::ostream& stream, const Vectorized<T>& vec) {
|
||||
+ }
|
||||
+
|
||||
+
|
||||
+-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX512)
|
||||
+
|
||||
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX512) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
+
|
||||
+@@ -80,7 +80,8 @@ inline Vectorized<double> cast<double, int64_t>(const Vectorized<int64_t>& src)
|
||||
+ }
|
||||
+
|
||||
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
+-
|
||||
++#ifndef _MSC_VER
|
||||
++// MSVC is not working well on complex function overload.
|
||||
+ template<int64_t scale = 1>
|
||||
+ std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
|
||||
+ inline gather(const double* base_addr, const Vectorized<int64_t>& vindex) {
|
||||
+@@ -92,9 +93,10 @@ std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorize
|
||||
+ inline gather(const float* base_addr, const Vectorized<int32_t>& vindex) {
|
||||
+ return _mm512_i32gather_ps(vindex, base_addr, scale);
|
||||
+ }
|
||||
+-
|
||||
++#endif
|
||||
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
+-
|
||||
++#ifndef _MSC_VER
|
||||
++// MSVC is not working well on complex function overload.
|
||||
+ template<int64_t scale = 1>
|
||||
+ std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
|
||||
+ inline mask_gather(const Vectorized<double>& src, const double* base_addr,
|
||||
+@@ -112,7 +114,7 @@ inline mask_gather(const Vectorized<float>& src, const float* base_addr,
|
||||
+ auto mask_ = _mm512_cmp_ps_mask(all_ones, mask.values, _CMP_EQ_OQ);
|
||||
+ return _mm512_mask_i32gather_ps(src, mask_, vindex, base_addr, scale);
|
||||
+ }
|
||||
+-
|
||||
++#endif
|
||||
+ // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
+
|
||||
+ template<>
|
||||
+@@ -270,6 +272,6 @@ inline Vectorized<uint8_t> flip(const Vectorized<uint8_t> & v) {
|
||||
+ return flip8(v);
|
||||
+ }
|
||||
+
|
||||
+-#endif // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
++#endif // defined(CPU_CAPABILITY_AVX512)
|
||||
+
|
||||
+ }}}
|
||||
+diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h b/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h
|
||||
+index f9fc92d52bf..eb3b6a72240 100644
|
||||
+--- a/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h
|
||||
++++ b/aten/src/ATen/cpu/vec/vec512/vec512_bfloat16.h
|
||||
+@@ -7,7 +7,8 @@
|
||||
+ #include <ATen/cpu/vec/vec_base.h>
|
||||
+ #include <c10/util/irange.h>
|
||||
+
|
||||
+-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX512)
|
||||
++#define SLEEF_STATIC_LIBS
|
||||
+ #include <sleef.h>
|
||||
+ #endif
|
||||
+
|
||||
+@@ -16,7 +17,18 @@ namespace vec {
|
||||
+ // See Note [CPU_CAPABILITY namespace]
|
||||
+ inline namespace CPU_CAPABILITY {
|
||||
+
|
||||
+-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX512)
|
||||
++
|
||||
++#ifndef SLEEF_CONST
|
||||
++#if (defined(__GNUC__) || defined(__CLANG__)) && !defined(__INTEL_COMPILER)
|
||||
++#define SLEEF_CONST const
|
||||
++#else
|
||||
++#define SLEEF_CONST
|
||||
++#endif
|
||||
++#define SLEEF_CONST_OLD SLEEF_CONST
|
||||
++#else
|
||||
++#define SLEEF_CONST_OLD
|
||||
++#endif
|
||||
+
|
||||
+ // bfloat16 conversion
|
||||
+ static inline void cvtbf16_fp32(const __m256i& a, __m512& o) {
|
||||
+@@ -362,7 +374,8 @@ public:
|
||||
+ }
|
||||
+ #pragma clang diagnostic push
|
||||
+ #pragma clang diagnostic ignored "-Wignored-qualifiers"
|
||||
+- Vectorized<T> map(const __m512 (*const vop)(__m512)) const {
|
||||
++
|
||||
++ Vectorized<T> map(SLEEF_CONST __m512 (*SLEEF_CONST_OLD vop)(__m512)) const {
|
||||
+ __m512 lo, hi;
|
||||
+ cvt_to_fp32<T>(values, lo, hi);
|
||||
+ const auto o1 = vop(lo);
|
||||
+@@ -1571,7 +1584,7 @@ inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const V
|
||||
+ CONVERT_VECTORIZED_INIT(BFloat16, bfloat16);
|
||||
+ CONVERT_VECTORIZED_INIT(Half, half);
|
||||
+
|
||||
+-#else //defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
++#else //defined(CPU_CAPABILITY_AVX512)
|
||||
+
|
||||
+ #define CONVERT_NON_VECTORIZED_INIT(type, name) \
|
||||
+ inline std::tuple<Vectorized<float>, Vectorized<float>> convert_##name##_float(const Vectorized<type>& a) { \
|
||||
+@@ -1601,9 +1614,9 @@ inline Vectorized<type> convert_float_##name(const Vectorized<float>& a, const V
|
||||
+ CONVERT_NON_VECTORIZED_INIT(BFloat16, bfloat16);
|
||||
+ CONVERT_NON_VECTORIZED_INIT(Half, half);
|
||||
+
|
||||
+-#endif // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
++#endif // defined(CPU_CAPABILITY_AVX512)
|
||||
+
|
||||
+-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX512)
|
||||
+ #define LOAD_FP32_VECTORIZED_INIT(type, name) \
|
||||
+ inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
|
||||
+ auto values = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(data)); \
|
||||
+@@ -1622,7 +1635,7 @@ inline void load_fp32_from_##name(const type *data, Vectorized<float>& out1, Vec
|
||||
+ LOAD_FP32_VECTORIZED_INIT(BFloat16, bf16);
|
||||
+ LOAD_FP32_VECTORIZED_INIT(Half, fp16);
|
||||
+
|
||||
+-#else // defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
++#else // defined(CPU_CAPABILITY_AVX512)
|
||||
+ #define LOAD_FP32_NON_VECTORIZED_INIT(type, name) \
|
||||
+ inline void load_fp32_from_##name(const type *data, Vectorized<float>& out) { \
|
||||
+ __at_align__ float values[Vectorized<float>::size()]; \
|
||||
+diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h b/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h
|
||||
+index 02aa3a87cc1..c35204f9da2 100644
|
||||
+--- a/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h
|
||||
++++ b/aten/src/ATen/cpu/vec/vec512/vec512_complex_double.h
|
||||
+@@ -7,7 +7,8 @@
|
||||
+ #include <c10/util/irange.h>
|
||||
+ #include <ATen/cpu/vec/intrinsics.h>
|
||||
+ #include <ATen/cpu/vec/vec_base.h>
|
||||
+-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX512)
|
||||
++#define SLEEF_STATIC_LIBS
|
||||
+ #include <sleef.h>
|
||||
+ #endif
|
||||
+
|
||||
+@@ -16,7 +17,7 @@ namespace vec {
|
||||
+ // See Note [CPU_CAPABILITY namespace]
|
||||
+ inline namespace CPU_CAPABILITY {
|
||||
+
|
||||
+-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX512)
|
||||
+
|
||||
+ template <> class Vectorized<c10::complex<double>> {
|
||||
+ private:
|
||||
+@@ -203,7 +204,7 @@ public:
|
||||
+ auto abs = abs_();
|
||||
+ auto zero = _mm512_setzero_pd();
|
||||
+ auto mask = _mm512_cmp_pd_mask(abs, zero, _CMP_EQ_OQ);
|
||||
+- auto div = values / abs;
|
||||
++ auto div = _mm512_div_pd(values, abs);
|
||||
+ return _mm512_mask_blend_pd(mask, div, zero);
|
||||
+ }
|
||||
+ __m512d real_() const {
|
||||
+diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h b/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h
|
||||
+index a5d790c98b2..2801e484d94 100644
|
||||
+--- a/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h
|
||||
++++ b/aten/src/ATen/cpu/vec/vec512/vec512_complex_float.h
|
||||
+@@ -7,7 +7,8 @@
|
||||
+ #include <c10/util/irange.h>
|
||||
+ #include <ATen/cpu/vec/intrinsics.h>
|
||||
+ #include <ATen/cpu/vec/vec_base.h>
|
||||
+-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX512)
|
||||
++#define SLEEF_STATIC_LIBS
|
||||
+ #include <sleef.h>
|
||||
+ #endif
|
||||
+
|
||||
+@@ -16,7 +17,7 @@ namespace vec {
|
||||
+ // See Note [CPU_CAPABILITY namespace]
|
||||
+ inline namespace CPU_CAPABILITY {
|
||||
+
|
||||
+-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX512)
|
||||
+
|
||||
+ template <> class Vectorized<c10::complex<float>> {
|
||||
+ private:
|
||||
+@@ -708,7 +709,7 @@ public:
|
||||
+ auto abs = abs_();
|
||||
+ auto zero = _mm512_setzero_ps();
|
||||
+ auto mask = _mm512_cmp_ps_mask(abs, zero, _CMP_EQ_OQ);
|
||||
+- auto div = values / abs;
|
||||
++ auto div = _mm512_div_ps(values, abs);
|
||||
+ return _mm512_mask_blend_ps(mask, div, zero);
|
||||
+ }
|
||||
+ __m512 real_() const {
|
||||
+diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_double.h b/aten/src/ATen/cpu/vec/vec512/vec512_double.h
|
||||
+index 27b2753c903..508ab257e60 100644
|
||||
+--- a/aten/src/ATen/cpu/vec/vec512/vec512_double.h
|
||||
++++ b/aten/src/ATen/cpu/vec/vec512/vec512_double.h
|
||||
+@@ -6,7 +6,8 @@
|
||||
+ #include <ATen/cpu/vec/intrinsics.h>
|
||||
+ #include <ATen/cpu/vec/vec_base.h>
|
||||
+ #include <c10/util/irange.h>
|
||||
+-#if (defined(CPU_CAPABILITY_AVX512)) && !defined(_MSC_VER)
|
||||
++#if (defined(CPU_CAPABILITY_AVX512))
|
||||
++#define SLEEF_STATIC_LIBS
|
||||
+ #include <sleef.h>
|
||||
+ #endif
|
||||
+
|
||||
+@@ -15,7 +16,7 @@ namespace vec {
|
||||
+ // See Note [CPU_CAPABILITY namespace]
|
||||
+ inline namespace CPU_CAPABILITY {
|
||||
+
|
||||
+-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX512)
|
||||
+
|
||||
+ template <> class Vectorized<double> {
|
||||
+ private:
|
||||
+diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_float.h b/aten/src/ATen/cpu/vec/vec512/vec512_float.h
|
||||
+index ba5738687fd..a08df3c141a 100644
|
||||
+--- a/aten/src/ATen/cpu/vec/vec512/vec512_float.h
|
||||
++++ b/aten/src/ATen/cpu/vec/vec512/vec512_float.h
|
||||
+@@ -6,7 +6,8 @@
|
||||
+ #include <ATen/cpu/vec/intrinsics.h>
|
||||
+ #include <ATen/cpu/vec/vec_base.h>
|
||||
+ #include <c10/util/irange.h>
|
||||
+-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX512)
|
||||
++#define SLEEF_STATIC_LIBS
|
||||
+ #include <sleef.h>
|
||||
+ #endif
|
||||
+
|
||||
+@@ -15,7 +16,7 @@ namespace vec {
|
||||
+ // See Note [CPU_CAPABILITY namespace]
|
||||
+ inline namespace CPU_CAPABILITY {
|
||||
+
|
||||
+-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX512)
|
||||
+
|
||||
+ template <> class Vectorized<float> {
|
||||
+ private:
|
||||
+@@ -246,14 +247,14 @@ public:
|
||||
+ static __m512 vec_factorial_5 =
|
||||
+ _mm512_set1_ps(0.00828929059f); // 1/factorial(5)
|
||||
+ static __m512 vec_exp_log2ef =
|
||||
+- (__m512)_mm512_set1_epi32(0x3fb8aa3b); // log2(e)
|
||||
++ _mm512_castsi512_ps(_mm512_set1_epi32(0x3fb8aa3b)); // log2(e)
|
||||
+ static __m512 vec_half = _mm512_set1_ps(0.5f);
|
||||
+ static __m512 vec_one = _mm512_set1_ps(1.f);
|
||||
+ static __m512 vec_zero = _mm512_set1_ps(0.f);
|
||||
+ static __m512 vec_two = _mm512_set1_ps(2.f);
|
||||
+- static __m512 vec_ln2f = (__m512)_mm512_set1_epi32(0x3f317218); // ln(2)
|
||||
+- static __m512 vec_ln_flt_min = (__m512)_mm512_set1_epi32(0xc2aeac50);
|
||||
+- static __m512 vec_ln_flt_max = (__m512)_mm512_set1_epi32(0x42b17218);
|
||||
++ static __m512 vec_ln2f = _mm512_castsi512_ps(_mm512_set1_epi32(0x3f317218)); // ln(2)
|
||||
++ static __m512 vec_ln_flt_min = _mm512_castsi512_ps(_mm512_set1_epi32(0xc2aeac50));
|
||||
++ static __m512 vec_ln_flt_max = _mm512_castsi512_ps(_mm512_set1_epi32(0x42b17218));
|
||||
+ static __m512i vec_127 = _mm512_set1_epi32(0x0000007f);
|
||||
+ static int n_mantissa_bits = 23;
|
||||
+
|
||||
+@@ -288,7 +289,7 @@ public:
|
||||
+ auto vec_exp_number_i = _mm512_cvtps_epi32(vec_exp_number);
|
||||
+ auto vec_two_pow_n_i = _mm512_add_epi32(vec_exp_number_i, vec_127);
|
||||
+ vec_two_pow_n_i = _mm512_slli_epi32(vec_two_pow_n_i, n_mantissa_bits);
|
||||
+- auto vec_two_pow_n = (__m512)vec_two_pow_n_i;
|
||||
++ auto vec_two_pow_n = _mm512_castsi512_ps(vec_two_pow_n_i);
|
||||
+ vec_two_pow_n =
|
||||
+ _mm512_mask_blend_ps(less_ln_flt_min_mask, vec_two_pow_n, vec_zero);
|
||||
+
|
||||
+diff --git a/aten/src/ATen/cpu/vec/vec512/vec512_qint.h b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h
|
||||
+index e0713d01312..a5671ed4a50 100644
|
||||
+--- a/aten/src/ATen/cpu/vec/vec512/vec512_qint.h
|
||||
++++ b/aten/src/ATen/cpu/vec/vec512/vec512_qint.h
|
||||
+@@ -42,11 +42,17 @@ namespace at {
|
||||
+ namespace vec {
|
||||
+ inline namespace CPU_CAPABILITY {
|
||||
+
|
||||
+-#if defined(CPU_CAPABILITY_AVX512) && !defined(_MSC_VER)
|
||||
++#if defined(CPU_CAPABILITY_AVX512)
|
||||
+
|
||||
++#ifdef _MSC_VER
|
||||
++__declspec(align(64)) struct Vectorizedqi {
|
||||
++ protected:
|
||||
++ __m512i vals;
|
||||
++#else
|
||||
+ struct Vectorizedqi {
|
||||
+ protected:
|
||||
+ __m512i vals __attribute__((aligned(64)));
|
||||
++#endif
|
||||
+
|
||||
+ public:
|
||||
+ Vectorizedqi() {}
|
||||
+@@ -136,7 +142,7 @@ inline convert_float_to_int8(at::vec::Vectorized<float> src) {
|
||||
+ }
|
||||
+
|
||||
+ template <typename T>
|
||||
+-inline void __attribute__((always_inline)) QuantizeAvx512(
|
||||
++__FORCE_INLINE void QuantizeAvx512(
|
||||
+ const float* src,
|
||||
+ T* dst,
|
||||
+ int len,
|
||||
+@@ -525,10 +531,17 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
|
||||
+ Vectorized<float> scale,
|
||||
+ Vectorized<float> zero_point,
|
||||
+ Vectorized<float> scale_neg_zp_premul) const {
|
||||
++ #if defined(_MSC_VER) && !defined(__clang__)
|
||||
++ __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]);
|
||||
++ __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]);
|
||||
++ __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]);
|
||||
++ __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]);
|
||||
++ #else
|
||||
+ __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
|
||||
+ __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
|
||||
+ __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
|
||||
+ __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
|
||||
++ #endif
|
||||
+
|
||||
+ __m512 float_val0 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val0));
|
||||
+ __m512 float_val1 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val1));
|
||||
+@@ -549,10 +562,17 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
|
||||
+ float_vec_return_type dequantize(
|
||||
+ Vectorized<float> scale,
|
||||
+ Vectorized<float> zero_point) const {
|
||||
++ #if defined(_MSC_VER) && !defined(__clang__)
|
||||
++ __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]);
|
||||
++ __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]);
|
||||
++ __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]);
|
||||
++ __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]);
|
||||
++ #else
|
||||
+ __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
|
||||
+ __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
|
||||
+ __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
|
||||
+ __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
|
||||
++ #endif
|
||||
+
|
||||
+ __m512 float_val0 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val0));
|
||||
+ __m512 float_val1 = _mm512_cvtepi32_ps(cvtepi8_epi32(int_val1));
|
||||
+@@ -598,20 +618,34 @@ struct Vectorized<c10::qint8> : public Vectorizedqi {
|
||||
+ }
|
||||
+
|
||||
+ int_vec_return_type widening_subtract(Vectorized<c10::qint8> b) const {
|
||||
++ #if defined(_MSC_VER) && !defined(__clang__)
|
||||
++ __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]);
|
||||
++ __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]);
|
||||
++ __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]);
|
||||
++ __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]);
|
||||
++ #else
|
||||
+ __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
|
||||
+ __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
|
||||
+ __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
|
||||
+ __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
|
||||
++ #endif
|
||||
+
|
||||
+ __m512i int32_val0 = cvtepi8_epi32(int_val0);
|
||||
+ __m512i int32_val1 = cvtepi8_epi32(int_val1);
|
||||
+ __m512i int32_val2 = cvtepi8_epi32(int_val2);
|
||||
+ __m512i int32_val3 = cvtepi8_epi32(int_val3);
|
||||
+
|
||||
++ #if defined(_MSC_VER) && !defined(__clang__)
|
||||
++ __m128i int_b0 = _mm_set_epi64x(b.vals.m512i_u64[1], b.vals.m512i_u64[0]);
|
||||
++ __m128i int_b1 = _mm_set_epi64x(b.vals.m512i_u64[3], b.vals.m512i_u64[2]);
|
||||
++ __m128i int_b2 = _mm_set_epi64x(b.vals.m512i_u64[5], b.vals.m512i_u64[4]);
|
||||
++ __m128i int_b3 = _mm_set_epi64x(b.vals.m512i_u64[7], b.vals.m512i_u64[6]);
|
||||
++ #else
|
||||
+ __m128i int_b0 = _mm_set_epi64x(b.vals[1], b.vals[0]);
|
||||
+ __m128i int_b1 = _mm_set_epi64x(b.vals[3], b.vals[2]);
|
||||
+ __m128i int_b2 = _mm_set_epi64x(b.vals[5], b.vals[4]);
|
||||
+ __m128i int_b3 = _mm_set_epi64x(b.vals[7], b.vals[6]);
|
||||
++ #endif
|
||||
+
|
||||
+ __m512i int32_b0 = cvtepi8_epi32(int_b0);
|
||||
+ __m512i int32_b1 = cvtepi8_epi32(int_b1);
|
||||
+@@ -721,10 +755,17 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
|
||||
+ Vectorized<float> scale,
|
||||
+ Vectorized<float> zero_point,
|
||||
+ Vectorized<float> scale_zp_premul) const {
|
||||
++ #if defined(_MSC_VER) && !defined(__clang__)
|
||||
++ __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]);
|
||||
++ __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]);
|
||||
++ __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]);
|
||||
++ __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]);
|
||||
++ #else
|
||||
+ __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
|
||||
+ __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
|
||||
+ __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
|
||||
+ __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
|
||||
++ #endif
|
||||
+
|
||||
+ __m512 float_val0 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val0));
|
||||
+ __m512 float_val1 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val1));
|
||||
+@@ -746,10 +787,17 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
|
||||
+ float_vec_return_type dequantize(
|
||||
+ Vectorized<float> scale,
|
||||
+ Vectorized<float> zero_point) const {
|
||||
++ #if defined(_MSC_VER) && !defined(__clang__)
|
||||
++ __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]);
|
||||
++ __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]);
|
||||
++ __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]);
|
||||
++ __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]);
|
||||
++ #else
|
||||
+ __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
|
||||
+ __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
|
||||
+ __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
|
||||
+ __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
|
||||
++ #endif
|
||||
+
|
||||
+ __m512 float_val0 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val0));
|
||||
+ __m512 float_val1 = _mm512_cvtepi32_ps(cvtepu8_epi32(int_val1));
|
||||
+@@ -796,20 +844,34 @@ struct Vectorized<c10::quint8> : public Vectorizedqi {
|
||||
+ }
|
||||
+
|
||||
+ int_vec_return_type widening_subtract(Vectorized<c10::quint8> b) const {
|
||||
++ #if defined(_MSC_VER) && !defined(__clang__)
|
||||
++ __m128i int_val0 = _mm_set_epi64x(vals.m512i_u64[1], vals.m512i_u64[0]);
|
||||
++ __m128i int_val1 = _mm_set_epi64x(vals.m512i_u64[3], vals.m512i_u64[2]);
|
||||
++ __m128i int_val2 = _mm_set_epi64x(vals.m512i_u64[5], vals.m512i_u64[4]);
|
||||
++ __m128i int_val3 = _mm_set_epi64x(vals.m512i_u64[7], vals.m512i_u64[6]);
|
||||
++ #else
|
||||
+ __m128i int_val0 = _mm_set_epi64x(vals[1], vals[0]);
|
||||
+ __m128i int_val1 = _mm_set_epi64x(vals[3], vals[2]);
|
||||
+ __m128i int_val2 = _mm_set_epi64x(vals[5], vals[4]);
|
||||
+ __m128i int_val3 = _mm_set_epi64x(vals[7], vals[6]);
|
||||
++ #endif
|
||||
+
|
||||
+ __m512i int32_val0 = cvtepu8_epi32(int_val0);
|
||||
+ __m512i int32_val1 = cvtepu8_epi32(int_val1);
|
||||
+ __m512i int32_val2 = cvtepu8_epi32(int_val2);
|
||||
+ __m512i int32_val3 = cvtepu8_epi32(int_val3);
|
||||
+
|
||||
++ #if defined(_MSC_VER) && !defined(__clang__)
|
||||
++ __m128i int_b0 = _mm_set_epi64x(b.vals.m512i_u64[1], b.vals.m512i_u64[0]);
|
||||
++ __m128i int_b1 = _mm_set_epi64x(b.vals.m512i_u64[3], b.vals.m512i_u64[2]);
|
||||
++ __m128i int_b2 = _mm_set_epi64x(b.vals.m512i_u64[5], b.vals.m512i_u64[4]);
|
||||
++ __m128i int_b3 = _mm_set_epi64x(b.vals.m512i_u64[7], b.vals.m512i_u64[6]);
|
||||
++ #else
|
||||
+ __m128i int_b0 = _mm_set_epi64x(b.vals[1], b.vals[0]);
|
||||
+ __m128i int_b1 = _mm_set_epi64x(b.vals[3], b.vals[2]);
|
||||
+ __m128i int_b2 = _mm_set_epi64x(b.vals[5], b.vals[4]);
|
||||
+ __m128i int_b3 = _mm_set_epi64x(b.vals[7], b.vals[6]);
|
||||
++ #endif
|
||||
+
|
||||
+ __m512i int32_b0 = cvtepu8_epi32(int_b0);
|
||||
+ __m512i int32_b1 = cvtepu8_epi32(int_b1);
|
||||
+diff --git a/aten/src/ATen/cpu/vec/vec_base.h b/aten/src/ATen/cpu/vec/vec_base.h
|
||||
+index adf81dd915c..20cb8ef6dbc 100644
|
||||
+--- a/aten/src/ATen/cpu/vec/vec_base.h
|
||||
++++ b/aten/src/ATen/cpu/vec/vec_base.h
|
||||
+@@ -36,6 +36,12 @@
|
||||
+ #include <c10/util/irange.h>
|
||||
+ #include <c10/util/Load.h>
|
||||
+
|
||||
++#if defined(__GNUC__)
|
||||
++#define __FORCE_INLINE __attribute__((always_inline)) inline
|
||||
++#elif defined(_MSC_VER)
|
||||
++#define __FORCE_INLINE __forceinline
|
||||
++#endif
|
||||
++
|
||||
+ // These macros helped us unify vec_base.h
|
||||
+ #ifdef CPU_CAPABILITY_AVX512
|
||||
+ #if defined(__GNUC__)
|
||||
+diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
|
||||
+index a6b6f0f7d1d..15d37cf4861 100644
|
||||
+--- a/caffe2/CMakeLists.txt
|
||||
++++ b/caffe2/CMakeLists.txt
|
||||
+@@ -1787,7 +1787,7 @@ if(BUILD_TEST)
|
||||
+ endif()
|
||||
+ else()
|
||||
+ add_executable(${test_name}_${CPU_CAPABILITY} "${test_src}")
|
||||
+- target_link_libraries(${test_name}_${CPU_CAPABILITY} torch_library gtest_main)
|
||||
++ target_link_libraries(${test_name}_${CPU_CAPABILITY} torch_library sleef gtest_main)
|
||||
+ endif()
|
||||
+ target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE $<INSTALL_INTERFACE:include>)
|
||||
+ target_include_directories(${test_name}_${CPU_CAPABILITY} PRIVATE $<BUILD_INTERFACE:${CMAKE_BINARY_DIR}/include>)
|
||||
+diff --git a/third_party/sleef.BUILD b/third_party/sleef.BUILD
|
||||
+index 573f9c5b54a..f22a6e905e2 100644
|
||||
+--- a/third_party/sleef.BUILD
|
||||
++++ b/third_party/sleef.BUILD
|
||||
+@@ -38,6 +38,7 @@ SLEEF_PUBLIC_HEADERS = [
|
||||
+ SLEEF_PRIVATE_INCLUDES = [
|
||||
+ "-Iexternal/sleef/src/arch",
|
||||
+ "-Iexternal/sleef/src/common",
|
||||
++ "-Iexternal/sleef/src/libm",
|
||||
+ ]
|
||||
+
|
||||
+ SLEEF_PUBLIC_INCLUDES = [
|
||||
+@@ -201,8 +202,6 @@ cc_library(
|
||||
+ srcs = [
|
||||
+ "src/libm/rempitab.c",
|
||||
+ "src/libm/sleefdp.c",
|
||||
+- "src/libm/sleefld.c",
|
||||
+- "src/libm/sleefqp.c",
|
||||
+ "src/libm/sleefsp.c",
|
||||
+ ],
|
||||
+ hdrs = SLEEF_PUBLIC_HEADERS,
|
||||
+--
|
||||
+2.45.1
|
||||
+
|
||||
diff --git a/python-torch.spec b/python-torch.spec
|
||||
index d50687a5174a..63600c2e8c39 100644
|
||||
--- a/python-torch.spec
|
||||
+++ b/python-torch.spec
|
||||
@@ -176,6 +176,17 @@ Patch7: 0001-Reenable-dim-for-python-3.12.patch
|
||||
Patch8: 0001-dynamo-3.12-enable-dynamo-on-3.12-enable-most-dynamo.patch
|
||||
%endif
|
||||
|
||||
+# Enable x86 CPU vectorization on windows [submodule sleef] (#118980)
|
||||
+# https://github.com/pytorch/pytorch/commit/56451cd49d9cf94b49197e09dec13426bb1a5370
|
||||
+#
|
||||
+# Despite the title, this patch fixes compatibility with sleef 3.6 by including
|
||||
+# a backwards-compatible version of the fix from
|
||||
+# https://github.com/pytorch/pytorch/pull/122723.
|
||||
+# Cherry-picked on v2.3.0, without the commit to update the third_party/sleef
|
||||
+# git submodule (because the release archive contains an actual sleef source
|
||||
+# tree instead, so this would not apply.)
|
||||
+Patch9: 0001-Enable-x86-CPU-vectorization-on-windows-submodule-sl.patch
|
||||
+
|
||||
%if %{with rocm}
|
||||
# ROCm patches
|
||||
# https://github.com/pytorch/pytorch/pull/120551
|
||||
--
|
||||
2.45.1
|
||||
|
||||
|
|
@ -1,115 +0,0 @@
|
|||
From ee3fb343a376cdba6f4ce188cac90023f13e2aea Mon Sep 17 00:00:00 2001
|
||||
From: Tom Rix <trix@redhat.com>
|
||||
Date: Thu, 4 Apr 2024 14:21:38 -0600
|
||||
Subject: [PATCH] Reenable dim for python 3.12
|
||||
|
||||
In 3.12:
|
||||
|
||||
_PyArg_Parser added an element to the start of the structure.
|
||||
So existing positional initialization is off. Switch to element
|
||||
initialization.
|
||||
|
||||
_Py_CODEUNIT changed from an int to a union, but relevant_op
|
||||
is passed an int for the return of decoder.opcode, so the parameter
|
||||
type is wrong, switch it to int.
|
||||
|
||||
The opcode PRECALL was removed, so restrict its handling to 3.11
|
||||
|
||||
Signed-off-by: Tom Rix <trix@redhat.com>
|
||||
---
|
||||
functorch/csrc/dim/dim.cpp | 24 +++++-------------------
|
||||
functorch/csrc/dim/minpybind.h | 4 ++--
|
||||
2 files changed, 7 insertions(+), 21 deletions(-)
|
||||
|
||||
diff --git a/functorch/csrc/dim/dim.cpp b/functorch/csrc/dim/dim.cpp
|
||||
index 4cc027504c77..e48b0d58081f 100644
|
||||
--- a/functorch/csrc/dim/dim.cpp
|
||||
+++ b/functorch/csrc/dim/dim.cpp
|
||||
@@ -6,20 +6,6 @@
|
||||
|
||||
#include <torch/csrc/utils/python_compat.h>
|
||||
|
||||
-
|
||||
-// Many APIs have changed/don't exist anymore
|
||||
-#if IS_PYTHON_3_12_PLUS
|
||||
-
|
||||
-#include "dim.h"
|
||||
-
|
||||
-// Re-enable this some day
|
||||
-PyObject* Dim_init() {
|
||||
- PyErr_SetString(PyExc_RuntimeError, "First class dim doesn't work with python 3.12");
|
||||
- return nullptr;
|
||||
-}
|
||||
-
|
||||
-#else
|
||||
-
|
||||
#include "minpybind.h"
|
||||
#include <frameobject.h>
|
||||
#include <opcode.h>
|
||||
@@ -441,7 +427,7 @@ static PyObject* DimList_bind(DimList *self,
|
||||
PY_BEGIN
|
||||
mpy::handle sizes;
|
||||
static const char * const _keywords[] = {"sizes", nullptr};
|
||||
- static _PyArg_Parser parser = {"O", _keywords, 0};
|
||||
+ static _PyArg_Parser parser = { .format = "O", .keywords = _keywords};
|
||||
if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &sizes)) {
|
||||
return nullptr;
|
||||
}
|
||||
@@ -465,7 +451,7 @@ static PyObject* DimList_bind_len(DimList *self,
|
||||
PY_BEGIN
|
||||
int size;
|
||||
static const char * const _keywords[] = {"N", nullptr};
|
||||
- static _PyArg_Parser parser = {"i", _keywords, 0};
|
||||
+ static _PyArg_Parser parser = { .format = "i", .keywords = _keywords};
|
||||
if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &size)) {
|
||||
return nullptr;
|
||||
}
|
||||
@@ -1468,7 +1454,7 @@ PyTypeObject Tensor::Type = {
|
||||
|
||||
// dim() --------------------
|
||||
|
||||
-static bool relevant_op(_Py_CODEUNIT c) {
|
||||
+static bool relevant_op(int c) {
|
||||
switch(c) {
|
||||
case STORE_NAME:
|
||||
case STORE_GLOBAL:
|
||||
@@ -1587,7 +1573,7 @@ static PyObject* _dims(PyObject *self,
|
||||
auto c = mpy::obj<PyCodeObject>::steal(PyFrame_GetCode(f.ptr()));
|
||||
auto lasti = PyFrame_GetLasti(f.ptr());
|
||||
auto decoder = PyInstDecoder(c.ptr(), lasti);
|
||||
- #if IS_PYTHON_3_11_PLUS
|
||||
+ #if IS_PYTHON_3_11
|
||||
// When py3.11 adapts bytecode lasti points to the precall
|
||||
// rather than the call instruction after it
|
||||
if (decoder.opcode() == PRECALL) {
|
||||
@@ -3268,4 +3254,4 @@ PyObject* Dim_init() {
|
||||
}
|
||||
}
|
||||
|
||||
-#endif
|
||||
+
|
||||
diff --git a/functorch/csrc/dim/minpybind.h b/functorch/csrc/dim/minpybind.h
|
||||
index de82b5af95a4..d76d4828bf80 100644
|
||||
--- a/functorch/csrc/dim/minpybind.h
|
||||
+++ b/functorch/csrc/dim/minpybind.h
|
||||
@@ -621,7 +621,7 @@ struct vector_args {
|
||||
PyObject *dummy = NULL;
|
||||
_PyArg_ParseStackAndKeywords((PyObject*const*)args, nargs, kwnames.ptr(), _parser, &dummy, &dummy, &dummy, &dummy, &dummy);
|
||||
#else
|
||||
- _PyArg_Parser* _parser = new _PyArg_Parser{NULL, &names_buf[0], fname_cstr, 0};
|
||||
+ _PyArg_Parser* _parser = new _PyArg_Parser{ .keywords = &names_buf[0], .fname = fname_cstr};
|
||||
std::unique_ptr<PyObject*[]> buf(new PyObject*[names.size()]);
|
||||
_PyArg_UnpackKeywords((PyObject*const*)args, nargs, NULL, kwnames.ptr(), _parser, required, (Py_ssize_t)values.size() - kwonly, 0, &buf[0]);
|
||||
#endif
|
||||
@@ -706,7 +706,7 @@ inline object handle::call_vector(vector_args args) {
|
||||
#define MPY_PARSE_ARGS_KWNAMES(fmt, FORALL_ARGS) \
|
||||
static const char * const kwlist[] = { FORALL_ARGS(MPY_ARGS_NAME) nullptr}; \
|
||||
FORALL_ARGS(MPY_ARGS_DECLARE) \
|
||||
- static _PyArg_Parser parser = {fmt, kwlist, 0}; \
|
||||
+ static _PyArg_Parser parser = { .format = fmt, .keywords = kwlist}; \
|
||||
if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, FORALL_ARGS(MPY_ARGS_POINTER) nullptr)) { \
|
||||
throw mpy::exception_set(); \
|
||||
}
|
||||
--
|
||||
2.44.0
|
||||
|
||||
|
|
@ -1,39 +0,0 @@
|
|||
From 5b8e51b24513fa851eeff42f23d942bde301e321 Mon Sep 17 00:00:00 2001
|
||||
From: Tom Rix <trix@redhat.com>
|
||||
Date: Fri, 29 Sep 2023 06:19:29 -0700
|
||||
Subject: [PATCH] Regenerate flatbuffer header
|
||||
|
||||
For this error
|
||||
torch/csrc/jit/serialization/mobile_bytecode_generated.h:12:41:
|
||||
error: static assertion failed: Non-compatible flatbuffers version included
|
||||
12 | FLATBUFFERS_VERSION_MINOR == 3 &&
|
||||
|
||||
PyTorch is expecting 23.3.3, what f38 has
|
||||
Rawhide is at 23.5.26
|
||||
|
||||
Regenerate with
|
||||
flatc --cpp --gen-mutable --no-prefix --scoped-enums mobile_bytecode.fbs
|
||||
|
||||
Signed-off-by: Tom Rix <trix@redhat.com>
|
||||
---
|
||||
torch/csrc/jit/serialization/mobile_bytecode_generated.h | 4 ++--
|
||||
1 file changed, 2 insertions(+), 2 deletions(-)
|
||||
|
||||
diff --git a/torch/csrc/jit/serialization/mobile_bytecode_generated.h b/torch/csrc/jit/serialization/mobile_bytecode_generated.h
|
||||
index cffe8bc7a6..83575e4c19 100644
|
||||
--- a/torch/csrc/jit/serialization/mobile_bytecode_generated.h
|
||||
+++ b/torch/csrc/jit/serialization/mobile_bytecode_generated.h
|
||||
@@ -9,8 +9,8 @@
|
||||
// Ensure the included flatbuffers.h is the same version as when this file was
|
||||
// generated, otherwise it may not be compatible.
|
||||
static_assert(FLATBUFFERS_VERSION_MAJOR == 23 &&
|
||||
- FLATBUFFERS_VERSION_MINOR == 3 &&
|
||||
- FLATBUFFERS_VERSION_REVISION == 3,
|
||||
+ FLATBUFFERS_VERSION_MINOR == 5 &&
|
||||
+ FLATBUFFERS_VERSION_REVISION == 26,
|
||||
"Non-compatible flatbuffers version included");
|
||||
|
||||
namespace torch {
|
||||
--
|
||||
2.43.0
|
||||
|
||||
|
|
@ -1,73 +0,0 @@
|
|||
From 3ef82b814179da571b2478f61d4279717ab0b23a Mon Sep 17 00:00:00 2001
|
||||
From: Tom Rix <trix@redhat.com>
|
||||
Date: Fri, 29 Sep 2023 06:25:23 -0700
|
||||
Subject: [PATCH] Stub in kineto ActivityType
|
||||
|
||||
There is an error when kineto is not used, the shim still
|
||||
requires the ActivityType.h header to get the ActivityType enum.
|
||||
So cut-n-paste just enough of the header in to do this.
|
||||
|
||||
Signed-off-by: Tom Rix <trix@redhat.com>
|
||||
---
|
||||
torch/csrc/profiler/kineto_shim.h | 44 +++++++++++++++++++++++++++++++
|
||||
1 file changed, 44 insertions(+)
|
||||
|
||||
diff --git a/torch/csrc/profiler/kineto_shim.h b/torch/csrc/profiler/kineto_shim.h
|
||||
index e92cbf003d..68985ab7d0 100644
|
||||
--- a/torch/csrc/profiler/kineto_shim.h
|
||||
+++ b/torch/csrc/profiler/kineto_shim.h
|
||||
@@ -12,7 +12,51 @@
|
||||
#undef USE_KINETO
|
||||
#endif
|
||||
|
||||
+#ifdef USE_KINETO
|
||||
#include <ActivityType.h>
|
||||
+#else
|
||||
+namespace libkineto {
|
||||
+// copied from header
|
||||
+/*
|
||||
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
+ * All rights reserved.
|
||||
+ *
|
||||
+ * This source code is licensed under the BSD-style license found in the
|
||||
+ * LICENSE file in the root directory of this source tree.
|
||||
+ */
|
||||
+
|
||||
+// Note : All activity types are not enabled by default. Please add them
|
||||
+// at correct position in the enum
|
||||
+enum class ActivityType {
|
||||
+ // Activity types enabled by default
|
||||
+ CPU_OP = 0, // cpu side ops
|
||||
+ USER_ANNOTATION,
|
||||
+ GPU_USER_ANNOTATION,
|
||||
+ GPU_MEMCPY,
|
||||
+ GPU_MEMSET,
|
||||
+ CONCURRENT_KERNEL, // on-device kernels
|
||||
+ EXTERNAL_CORRELATION,
|
||||
+ CUDA_RUNTIME, // host side cuda runtime events
|
||||
+ CUDA_DRIVER, // host side cuda driver events
|
||||
+ CPU_INSTANT_EVENT, // host side point-like events
|
||||
+ PYTHON_FUNCTION,
|
||||
+ OVERHEAD, // CUPTI induced overhead events sampled from its overhead API.
|
||||
+
|
||||
+ // Optional Activity types
|
||||
+ CUDA_SYNC, // synchronization events between runtime and kernels
|
||||
+ GLOW_RUNTIME, // host side glow runtime events
|
||||
+ MTIA_RUNTIME, // host side MTIA runtime events
|
||||
+ CUDA_PROFILER_RANGE, // CUPTI Profiler range for performance metrics
|
||||
+ MTIA_CCP_EVENTS, // MTIA ondevice CCP events
|
||||
+ HPU_OP, // HPU host side runtime event
|
||||
+ XPU_RUNTIME, // host side xpu runtime events
|
||||
+
|
||||
+ ENUM_COUNT, // This is to add buffer and not used for any profiling logic. Add your new type before it.
|
||||
+ OPTIONAL_ACTIVITY_TYPE_START = CUDA_SYNC,
|
||||
+};
|
||||
+}
|
||||
+
|
||||
+#endif
|
||||
|
||||
#include <torch/csrc/Export.h>
|
||||
#include <torch/csrc/profiler/api.h>
|
||||
--
|
||||
2.43.0
|
||||
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
From a5dff521691a17701b5a02ec75e84cfe1bf605f7 Mon Sep 17 00:00:00 2001
|
||||
From: Tom Rix <trix@redhat.com>
|
||||
Date: Sat, 3 Feb 2024 06:41:49 -0500
|
||||
Subject: [PATCH] can not use with c files
|
||||
|
||||
---
|
||||
cmake/Dependencies.cmake | 2 +-
|
||||
1 file changed, 1 insertion(+), 1 deletion(-)
|
||||
|
||||
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
|
||||
index 4dd8042058..5f91f3ffab 100644
|
||||
--- a/cmake/Dependencies.cmake
|
||||
+++ b/cmake/Dependencies.cmake
|
||||
@@ -1269,7 +1269,7 @@ if(USE_ROCM)
|
||||
list(APPEND HIP_CXX_FLAGS -Wno-duplicate-decl-specifier)
|
||||
list(APPEND HIP_CXX_FLAGS -DCAFFE2_USE_MIOPEN)
|
||||
list(APPEND HIP_CXX_FLAGS -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP)
|
||||
- list(APPEND HIP_CXX_FLAGS -std=c++17)
|
||||
+# list(APPEND HIP_CXX_FLAGS -std=c++17)
|
||||
if(ROCM_VERSION_DEV VERSION_GREATER_EQUAL "6.0.0")
|
||||
list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
|
||||
endif()
|
||||
--
|
||||
2.43.0
|
||||
|
||||
|
|
@ -1,42 +0,0 @@
|
|||
From 214dc959acc809e1959643272c344ee5335d5a69 Mon Sep 17 00:00:00 2001
|
||||
From: Tom Rix <trix@redhat.com>
|
||||
Date: Thu, 1 Feb 2024 11:29:47 -0500
|
||||
Subject: [PATCH] cuda - hip signatures
|
||||
|
||||
---
|
||||
aten/src/ATen/cuda/detail/LazyNVRTC.cpp | 9 +++++++++
|
||||
1 file changed, 9 insertions(+)
|
||||
|
||||
diff --git a/aten/src/ATen/cuda/detail/LazyNVRTC.cpp b/aten/src/ATen/cuda/detail/LazyNVRTC.cpp
|
||||
index 1b85e7776e..bb6f88783a 100644
|
||||
--- a/aten/src/ATen/cuda/detail/LazyNVRTC.cpp
|
||||
+++ b/aten/src/ATen/cuda/detail/LazyNVRTC.cpp
|
||||
@@ -134,8 +134,13 @@ nvrtcResult nvrtcCreateProgram(nvrtcProgram *prog,
|
||||
const char *src,
|
||||
const char *name,
|
||||
int numHeaders,
|
||||
+#if !defined(USE_ROCM)
|
||||
const char * const *headers,
|
||||
const char * const *includeNames) {
|
||||
+#else
|
||||
+ const char **headers,
|
||||
+ const char **includeNames) {
|
||||
+#endif
|
||||
auto fn = reinterpret_cast<decltype(&nvrtcCreateProgram)>(getNVRTCLibrary().sym(__func__));
|
||||
if (!fn)
|
||||
throw std::runtime_error("Can't get nvrtcCreateProgram");
|
||||
@@ -150,7 +155,11 @@ NVRTC_STUB2(nvrtcGetPTX, nvrtcProgram, char *);
|
||||
NVRTC_STUB2(nvrtcGetCUBINSize, nvrtcProgram, size_t *);
|
||||
NVRTC_STUB2(nvrtcGetCUBIN, nvrtcProgram, char *);
|
||||
#endif
|
||||
+#if !defined(USE_ROCM)
|
||||
NVRTC_STUB3(nvrtcCompileProgram, nvrtcProgram, int, const char * const *);
|
||||
+#else
|
||||
+NVRTC_STUB3(nvrtcCompileProgram, nvrtcProgram, int, const char **);
|
||||
+#endif
|
||||
_STUB_1(NVRTC, nvrtcGetErrorString, const char *, nvrtcResult);
|
||||
NVRTC_STUB2(nvrtcGetProgramLogSize,nvrtcProgram, size_t*);
|
||||
NVRTC_STUB2(nvrtcGetProgramLog, nvrtcProgram, char *);
|
||||
--
|
||||
2.43.0
|
||||
|
||||
|
|
@ -1,94 +0,0 @@
|
|||
From 038ce9e44776e23f21c1816daa259bc0ea335088 Mon Sep 17 00:00:00 2001
|
||||
From: Tom Rix <trix@redhat.com>
|
||||
Date: Sat, 29 Jun 2024 07:06:09 -0700
|
||||
Subject: [PATCH] disable use of aotriton
|
||||
|
||||
---
|
||||
.../ATen/native/transformers/cuda/sdp_utils.cpp | 17 +++++++++++++++--
|
||||
1 file changed, 15 insertions(+), 2 deletions(-)
|
||||
|
||||
diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
|
||||
index 214b02d8262e..7b3eb9dcd8cd 100644
|
||||
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
|
||||
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
|
||||
@@ -19,9 +19,12 @@
|
||||
#include <c10/core/SymInt.h>
|
||||
#include <c10/util/string_view.h>
|
||||
|
||||
+#ifdef USE_FLASH_ATTENTION
|
||||
#if USE_ROCM
|
||||
#include <aotriton/flash.h>
|
||||
#endif
|
||||
+#endif
|
||||
+
|
||||
|
||||
/**
|
||||
* Note [SDPA Runtime Dispatch]
|
||||
@@ -182,6 +185,9 @@ bool check_sm_version(cudaDeviceProp * dprops) {
|
||||
|
||||
bool check_flash_attention_hardware_support(sdp_params const& params, bool debug) {
|
||||
// Check that the gpu is capable of running flash attention
|
||||
+#ifndef USE_FLASH_ATTENTION
|
||||
+ return false;
|
||||
+#else
|
||||
using sm80 = SMVersion<8, 0>;
|
||||
using sm90 = SMVersion<9, 0>;
|
||||
#if USE_ROCM
|
||||
@@ -209,9 +215,13 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
|
||||
}
|
||||
#endif
|
||||
return true;
|
||||
+#endif
|
||||
}
|
||||
|
||||
bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) {
|
||||
+#ifndef USE_FLASH_ATTENTION
|
||||
+ return false;
|
||||
+#else
|
||||
// Mem Efficient attention supports hardware in the range [sm_50, sm_90]
|
||||
using sm50 = SMVersion<5, 0>;
|
||||
using sm90 = SMVersion<9, 0>;
|
||||
@@ -240,6 +250,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
|
||||
}
|
||||
#endif
|
||||
return true;
|
||||
+#endif
|
||||
}
|
||||
|
||||
bool check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89(
|
||||
@@ -554,7 +565,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
|
||||
#ifndef USE_FLASH_ATTENTION
|
||||
TORCH_WARN_ONCE(!debug, "Torch was not compiled with flash attention.");
|
||||
return false;
|
||||
-#endif
|
||||
+#else
|
||||
|
||||
// Define gate functions that determine if a flash kernel can be ran
|
||||
// Replace with std::to_array when we migrate to c++20
|
||||
@@ -597,13 +608,14 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
|
||||
}
|
||||
}
|
||||
return true;
|
||||
+#endif
|
||||
}
|
||||
|
||||
bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
|
||||
#ifndef USE_MEM_EFF_ATTENTION
|
||||
TORCH_WARN_ONCE(!debug, "Torch was not compiled with memory efficient attention.");
|
||||
return false;
|
||||
-#endif
|
||||
+#else
|
||||
// Constraints specific to mem efficient attention
|
||||
constexpr auto greater_than_or_equal_sm80_mem_efficient_dtypes =
|
||||
array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
|
||||
@@ -663,6 +675,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
|
||||
}
|
||||
#endif
|
||||
return check_tensor_dtype(params, less_than_sm80_mem_efficient_dtypes, debug);
|
||||
+#endif
|
||||
}
|
||||
|
||||
SDPBackend select_sdp_backend(sdp_params const& kernel_params) {
|
||||
--
|
||||
2.45.2
|
||||
|
||||
|
|
@ -1,226 +0,0 @@
|
|||
From b9d45eb1cc90696a4de76676221219e24423c709 Mon Sep 17 00:00:00 2001
|
||||
From: William Wen <williamwen@meta.com>
|
||||
Date: Wed, 3 Apr 2024 17:58:46 -0700
|
||||
Subject: [PATCH] [dynamo, 3.12] enable dynamo on 3.12, enable most dynamo
|
||||
unittests on 3.12 (#123216)
|
||||
|
||||
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123216
|
||||
Approved by: https://github.com/jansel, https://github.com/malfet
|
||||
---
|
||||
test/dynamo/test_autograd_function.py | 3 ++
|
||||
test/dynamo/test_misc.py | 63 +++++++++++++++++++++++++
|
||||
test/functorch/test_eager_transforms.py | 7 ++-
|
||||
test/run_test.py | 3 --
|
||||
torch/__init__.py | 5 +-
|
||||
torch/_dynamo/eval_frame.py | 4 +-
|
||||
torch/_dynamo/test_case.py | 8 +---
|
||||
7 files changed, 74 insertions(+), 19 deletions(-)
|
||||
|
||||
diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py
|
||||
index d23fec607afa..bc5ebc767038 100644
|
||||
--- a/test/dynamo/test_autograd_function.py
|
||||
+++ b/test/dynamo/test_autograd_function.py
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
import copy
|
||||
import math
|
||||
+import sys
|
||||
+import unittest
|
||||
|
||||
import torch
|
||||
|
||||
@@ -528,6 +530,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
|
||||
# I pulled all of these test cases from test_autograd.py
|
||||
# In the future, we should make the Dynamo test suite actually
|
||||
# run on test_autograd.py (it's disabled right now) and delete these.
|
||||
+ @unittest.skipIf(sys.version_info >= (3, 12), "invalid free in 3.12+")
|
||||
def test_smoke_from_test_autograd(self):
|
||||
class Func(torch.autograd.Function):
|
||||
@staticmethod
|
||||
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
|
||||
index a73de8b1c7e9..8f54e0564e6b 100644
|
||||
--- a/test/dynamo/test_misc.py
|
||||
+++ b/test/dynamo/test_misc.py
|
||||
@@ -9760,6 +9760,69 @@ fn
|
||||
lambda mod: mod,
|
||||
)
|
||||
|
||||
+ @xfailIfPy311
|
||||
+ def test_outside_linear_module_free(self):
|
||||
+ # Compared to test_linear_module_free, the linear
|
||||
+ # layer is not the code object that is directly compiled.
|
||||
+ def model_inp_ctr():
|
||||
+ fc = torch.nn.Linear(100, 100)
|
||||
+
|
||||
+ class Mod(torch.nn.Module):
|
||||
+ def __init__(self):
|
||||
+ super().__init__()
|
||||
+ self.fc_ref = fc
|
||||
+
|
||||
+ def forward(self, x):
|
||||
+ return fc(x[0])
|
||||
+
|
||||
+ # return fc to keep it alive in _test_compile_model_free
|
||||
+ return Mod(), (torch.randn(100, 100), fc)
|
||||
+
|
||||
+ self._test_compile_model_free(model_inp_ctr, lambda mod: mod.fc_ref)
|
||||
+
|
||||
+ @unittest.skipIf(sys.version_info >= (3, 12), "leaks in 3.12+")
|
||||
+ def test_parameter_free(self):
|
||||
+ def model_inp_ctr():
|
||||
+ param = torch.nn.Parameter(torch.randn(100, 100))
|
||||
+
|
||||
+ class Mod(torch.nn.Module):
|
||||
+ def __init__(self):
|
||||
+ super().__init__()
|
||||
+ self.param = param
|
||||
+
|
||||
+ def forward(self, x):
|
||||
+ return self.param * x[0]
|
||||
+
|
||||
+ # return param to keep it alive in _test_compile_model_free
|
||||
+ return Mod(), (torch.randn(100, 100), param)
|
||||
+
|
||||
+ self._test_compile_model_free(model_inp_ctr, lambda mod: mod.param)
|
||||
+
|
||||
+ def test_raises_importerror1(self):
|
||||
+ @torch.compile(backend="eager")
|
||||
+ def fn(x):
|
||||
+ try:
|
||||
+ import some_module_that_surely_does_not_exist
|
||||
+
|
||||
+ return
|
||||
+ except ImportError:
|
||||
+ pass
|
||||
+ return x.sin()
|
||||
+
|
||||
+ x = torch.randn(8)
|
||||
+ self.assertEqual(fn(x), x.sin())
|
||||
+
|
||||
+ def test_raises_importerror2(self):
|
||||
+ @torch.compile(backend="eager")
|
||||
+ def fn(x):
|
||||
+ import some_module_that_surely_does_not_exist
|
||||
+
|
||||
+ return x + 1
|
||||
+
|
||||
+ x = torch.randn(8)
|
||||
+ with self.assertRaises(ImportError):
|
||||
+ fn(x)
|
||||
+
|
||||
def test_dynamo_cache_move_to_front(self):
|
||||
class Mod(torch.nn.Module):
|
||||
def __init__(self):
|
||||
diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py
|
||||
index 09415cf8f48e..60790ec06059 100644
|
||||
--- a/test/functorch/test_eager_transforms.py
|
||||
+++ b/test/functorch/test_eager_transforms.py
|
||||
@@ -4762,8 +4762,7 @@ class TestCompileTransforms(TestCase):
|
||||
# Triton only supports GPU with SM70 or later.
|
||||
@expectedFailureIf((IS_ARM64 and not IS_MACOS) or
|
||||
IS_WINDOWS or
|
||||
- (TEST_CUDA and not SM70OrLater) or
|
||||
- (sys.version_info >= (3, 12)))
|
||||
+ (TEST_CUDA and not SM70OrLater))
|
||||
def test_compile_vmap_hessian(self, device):
|
||||
# The model and inputs are a smaller version
|
||||
# of code at benchmark repo:
|
||||
@@ -4792,8 +4791,8 @@ class TestCompileTransforms(TestCase):
|
||||
actual = opt_fn(params_and_buffers, x)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
- # torch.compile is not supported on Windows or on Python 3.12+
|
||||
- @expectedFailureIf(IS_WINDOWS or (sys.version_info >= (3, 12)))
|
||||
+ # torch.compile is not supported on Windows
|
||||
+ @expectedFailureIf(IS_WINDOWS)
|
||||
@torch._dynamo.config.patch(suppress_errors=False)
|
||||
@torch._dynamo.config.patch(capture_func_transforms=True)
|
||||
@skipIfTorchDynamo("Do not test torch.compile on top of torch.compile")
|
||||
diff --git a/test/run_test.py b/test/run_test.py
|
||||
index e86af9623042..ebb14df4167d 100755
|
||||
--- a/test/run_test.py
|
||||
+++ b/test/run_test.py
|
||||
@@ -74,7 +74,6 @@ sys.path.remove(str(REPO_ROOT))
|
||||
RERUN_DISABLED_TESTS = os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1"
|
||||
DISTRIBUTED_TEST_PREFIX = "distributed"
|
||||
INDUCTOR_TEST_PREFIX = "inductor"
|
||||
-DYNAMO_TEST_PREFIX = "dynamo"
|
||||
|
||||
|
||||
# Note [ROCm parallel CI testing]
|
||||
@@ -324,7 +323,6 @@ JIT_EXECUTOR_TESTS = [
|
||||
]
|
||||
|
||||
INDUCTOR_TESTS = [test for test in TESTS if test.startswith(INDUCTOR_TEST_PREFIX)]
|
||||
-DYNAMO_TESTS = [test for test in TESTS if test.startswith(DYNAMO_TEST_PREFIX)]
|
||||
DISTRIBUTED_TESTS = [test for test in TESTS if test.startswith(DISTRIBUTED_TEST_PREFIX)]
|
||||
TORCH_EXPORT_TESTS = [test for test in TESTS if test.startswith("export")]
|
||||
FUNCTORCH_TESTS = [test for test in TESTS if test.startswith("functorch")]
|
||||
@@ -1361,7 +1359,6 @@ def get_selected_tests(options) -> List[str]:
|
||||
# these tests failing in Python 3.12 temporarily disabling
|
||||
if sys.version_info >= (3, 12):
|
||||
options.exclude.extend(INDUCTOR_TESTS)
|
||||
- options.exclude.extend(DYNAMO_TESTS)
|
||||
options.exclude.extend(
|
||||
[
|
||||
"functorch/test_dims",
|
||||
diff --git a/torch/__init__.py b/torch/__init__.py
|
||||
index d381712b4a35..26cdffe81d29 100644
|
||||
--- a/torch/__init__.py
|
||||
+++ b/torch/__init__.py
|
||||
@@ -1861,9 +1861,8 @@ def compile(model: Optional[Callable] = None, *,
|
||||
|
||||
"""
|
||||
_C._log_api_usage_once("torch.compile")
|
||||
- # Temporary until we get proper support for python 3.12
|
||||
- if sys.version_info >= (3, 12):
|
||||
- raise RuntimeError("Dynamo is not supported on Python 3.12+")
|
||||
+ if sys.version_info >= (3, 13):
|
||||
+ raise RuntimeError("Dynamo is not supported on Python 3.13+")
|
||||
|
||||
# Decorator mode
|
||||
if model is None:
|
||||
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
|
||||
index 53ab0df3a947..0a80eeea99ed 100644
|
||||
--- a/torch/_dynamo/eval_frame.py
|
||||
+++ b/torch/_dynamo/eval_frame.py
|
||||
@@ -589,8 +589,8 @@ class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg]
|
||||
|
||||
|
||||
def check_if_dynamo_supported():
|
||||
- if sys.version_info >= (3, 12):
|
||||
- raise RuntimeError("Python 3.12+ not yet supported for torch.compile")
|
||||
+ if sys.version_info >= (3, 13):
|
||||
+ raise RuntimeError("Python 3.13+ not yet supported for torch.compile")
|
||||
|
||||
|
||||
def is_dynamo_supported():
|
||||
diff --git a/torch/_dynamo/test_case.py b/torch/_dynamo/test_case.py
|
||||
index e3cbef09eaae..297ea6e2bc2a 100644
|
||||
--- a/torch/_dynamo/test_case.py
|
||||
+++ b/torch/_dynamo/test_case.py
|
||||
@@ -1,7 +1,6 @@
|
||||
import contextlib
|
||||
import importlib
|
||||
import logging
|
||||
-import sys
|
||||
|
||||
import torch
|
||||
import torch.testing
|
||||
@@ -20,12 +19,7 @@ log = logging.getLogger(__name__)
|
||||
def run_tests(needs=()):
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
|
||||
- if (
|
||||
- TEST_WITH_TORCHDYNAMO
|
||||
- or IS_WINDOWS
|
||||
- or TEST_WITH_CROSSREF
|
||||
- or sys.version_info >= (3, 12)
|
||||
- ):
|
||||
+ if TEST_WITH_TORCHDYNAMO or IS_WINDOWS or TEST_WITH_CROSSREF:
|
||||
return # skip testing
|
||||
|
||||
if isinstance(needs, str):
|
||||
--
|
||||
2.44.0
|
||||
|
||||
|
|
@ -1,54 +0,0 @@
|
|||
From ba2cf11d1bf1dd5086c8e793198a697d4179cca7 Mon Sep 17 00:00:00 2001
|
||||
From: Kefu Chai <tchaikov@gmail.com>
|
||||
Date: Tue, 16 Jul 2024 08:00:22 +0800
|
||||
Subject: [PATCH] include fmt/ranges.h for using fmt::join()
|
||||
|
||||
fmt::join() was moved into fmt/ranges.h in fmt 11, so include this
|
||||
header for using it.
|
||||
|
||||
Signed-off-by: Kefu Chai <tchaikov@gmail.com>
|
||||
---
|
||||
torch/csrc/distributed/c10d/socket.cpp | 1 +
|
||||
torch/csrc/profiler/standalone/execution_trace_observer.cpp | 1 +
|
||||
torch/csrc/profiler/util.cpp | 1 +
|
||||
3 files changed, 3 insertions(+)
|
||||
|
||||
diff --git a/torch/csrc/distributed/c10d/socket.cpp b/torch/csrc/distributed/c10d/socket.cpp
|
||||
index 5013f2540..cbcd33a19 100644
|
||||
--- a/torch/csrc/distributed/c10d/socket.cpp
|
||||
+++ b/torch/csrc/distributed/c10d/socket.cpp
|
||||
@@ -31,6 +31,7 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated")
|
||||
#include <fmt/chrono.h>
|
||||
C10_DIAGNOSTIC_POP()
|
||||
#include <fmt/format.h>
|
||||
+#include <fmt/ranges.h>
|
||||
|
||||
#include <torch/csrc/distributed/c10d/error.h>
|
||||
#include <torch/csrc/distributed/c10d/exception.h>
|
||||
diff --git a/torch/csrc/profiler/standalone/execution_trace_observer.cpp b/torch/csrc/profiler/standalone/execution_trace_observer.cpp
|
||||
index 2ef2e5423..fb053e916 100644
|
||||
--- a/torch/csrc/profiler/standalone/execution_trace_observer.cpp
|
||||
+++ b/torch/csrc/profiler/standalone/execution_trace_observer.cpp
|
||||
@@ -10,6 +10,7 @@
|
||||
#endif // _WIN32
|
||||
|
||||
#include <fmt/format.h>
|
||||
+#include <fmt/ranges.h>
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
diff --git a/torch/csrc/profiler/util.cpp b/torch/csrc/profiler/util.cpp
|
||||
index 896bf606c..c229ce130 100644
|
||||
--- a/torch/csrc/profiler/util.cpp
|
||||
+++ b/torch/csrc/profiler/util.cpp
|
||||
@@ -5,6 +5,7 @@
|
||||
#include <c10/util/ArrayRef.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <fmt/format.h>
|
||||
+#include <fmt/ranges.h>
|
||||
|
||||
#ifdef USE_KINETO
|
||||
#include <libkineto.h>
|
||||
--
|
||||
2.45.2
|
||||
|
||||
|
|
@ -1,54 +0,0 @@
|
|||
From b3b307add5724ee5730f161e16594fa702f34a19 Mon Sep 17 00:00:00 2001
|
||||
From: Tom Rix <trix@redhat.com>
|
||||
Date: Sat, 3 Feb 2024 08:20:28 -0500
|
||||
Subject: [PATCH] no third_party FXdiv
|
||||
|
||||
---
|
||||
caffe2/CMakeLists.txt | 24 ++++++++++++------------
|
||||
1 file changed, 12 insertions(+), 12 deletions(-)
|
||||
|
||||
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
|
||||
index b2f3adbfae..80a5625c8d 100644
|
||||
--- a/caffe2/CMakeLists.txt
|
||||
+++ b/caffe2/CMakeLists.txt
|
||||
@@ -110,15 +110,15 @@ endif()
|
||||
# Note: the folders that are being commented out have not been properly
|
||||
# addressed yet.
|
||||
|
||||
-if(NOT MSVC AND USE_XNNPACK)
|
||||
- if(NOT TARGET fxdiv)
|
||||
- set(FXDIV_BUILD_TESTS OFF CACHE BOOL "")
|
||||
- set(FXDIV_BUILD_BENCHMARKS OFF CACHE BOOL "")
|
||||
- add_subdirectory(
|
||||
- "${FXDIV_SOURCE_DIR}"
|
||||
- "${CMAKE_BINARY_DIR}/FXdiv")
|
||||
- endif()
|
||||
-endif()
|
||||
+#if(NOT MSVC AND USE_XNNPACK)
|
||||
+# if(NOT TARGET fxdiv)
|
||||
+# set(FXDIV_BUILD_TESTS OFF CACHE BOOL "")
|
||||
+# set(FXDIV_BUILD_BENCHMARKS OFF CACHE BOOL "")
|
||||
+# add_subdirectory(
|
||||
+# "${FXDIV_SOURCE_DIR}"
|
||||
+# "${CMAKE_BINARY_DIR}/FXdiv")
|
||||
+# endif()
|
||||
+#endif()
|
||||
|
||||
add_subdirectory(core)
|
||||
add_subdirectory(serialize)
|
||||
@@ -1081,9 +1081,9 @@ if(USE_XPU)
|
||||
target_compile_definitions(torch_xpu PRIVATE USE_XPU)
|
||||
endif()
|
||||
|
||||
-if(NOT MSVC AND USE_XNNPACK)
|
||||
- TARGET_LINK_LIBRARIES(torch_cpu PRIVATE fxdiv)
|
||||
-endif()
|
||||
+#if(NOT MSVC AND USE_XNNPACK)
|
||||
+# TARGET_LINK_LIBRARIES(torch_cpu PRIVATE fxdiv)
|
||||
+#endif()
|
||||
|
||||
# ==========================================================
|
||||
# formerly-libtorch flags
|
||||
--
|
||||
2.43.0
|
||||
|
||||
|
|
@ -1,65 +0,0 @@
|
|||
From 2ce255b75760a0a513fb1706629b416f76a5c822 Mon Sep 17 00:00:00 2001
|
||||
From: Tom Rix <trix@redhat.com>
|
||||
Date: Sat, 3 Feb 2024 08:16:04 -0500
|
||||
Subject: [PATCH] no third_party fmt
|
||||
|
||||
---
|
||||
c10/CMakeLists.txt | 2 +-
|
||||
cmake/Dependencies.cmake | 6 +++---
|
||||
torch/CMakeLists.txt | 2 +-
|
||||
3 files changed, 5 insertions(+), 5 deletions(-)
|
||||
|
||||
diff --git a/c10/CMakeLists.txt b/c10/CMakeLists.txt
|
||||
index 1f742f4c176..4fa08913bdd 100644
|
||||
--- a/c10/CMakeLists.txt
|
||||
+++ b/c10/CMakeLists.txt
|
||||
@@ -87,7 +87,7 @@ endif()
|
||||
if(C10_USE_GLOG)
|
||||
target_link_libraries(c10 PUBLIC glog::glog)
|
||||
endif()
|
||||
-target_link_libraries(c10 PRIVATE fmt::fmt-header-only)
|
||||
+target_link_libraries(c10 PRIVATE fmt)
|
||||
|
||||
if(C10_USE_NUMA)
|
||||
message(STATUS "NUMA paths:")
|
||||
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
|
||||
index 6f5a2d5feff..42fbf80f6e8 100644
|
||||
--- a/cmake/Dependencies.cmake
|
||||
+++ b/cmake/Dependencies.cmake
|
||||
@@ -1837,7 +1837,7 @@ endif()
|
||||
#
|
||||
set(TEMP_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS})
|
||||
set(BUILD_SHARED_LIBS OFF CACHE BOOL "Build shared libs" FORCE)
|
||||
-add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/fmt)
|
||||
+# add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/fmt)
|
||||
|
||||
# Disable compiler feature checks for `fmt`.
|
||||
#
|
||||
@@ -1846,9 +1846,9 @@ add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/fmt)
|
||||
# CMAKE_CXX_FLAGS in ways that break feature checks. Since we already know
|
||||
# `fmt` is compatible with a superset of the compilers that PyTorch is, it
|
||||
# shouldn't be too bad to just disable the checks.
|
||||
-set_target_properties(fmt-header-only PROPERTIES INTERFACE_COMPILE_FEATURES "")
|
||||
+# set_target_properties(fmt-header-only PROPERTIES INTERFACE_COMPILE_FEATURES "")
|
||||
|
||||
-list(APPEND Caffe2_DEPENDENCY_LIBS fmt::fmt-header-only)
|
||||
+# list(APPEND Caffe2_DEPENDENCY_LIBS fmt::fmt-header-only)
|
||||
set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS} CACHE BOOL "Build shared libs" FORCE)
|
||||
|
||||
# ---[ Kineto
|
||||
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
|
||||
index 97a72eed55b..9e5014d1980 100644
|
||||
--- a/torch/CMakeLists.txt
|
||||
+++ b/torch/CMakeLists.txt
|
||||
@@ -80,7 +80,7 @@ set(TORCH_PYTHON_LINK_LIBRARIES
|
||||
python::python
|
||||
pybind::pybind11
|
||||
shm
|
||||
- fmt::fmt-header-only
|
||||
+ fmt
|
||||
ATEN_CPU_FILES_GEN_LIB)
|
||||
|
||||
if(USE_ASAN AND TARGET Sanitizer::address)
|
||||
--
|
||||
2.43.2
|
||||
|
||||
|
|
@ -1,36 +0,0 @@
|
|||
From 8cb61cf9282102ac225645fcc9fb4a1bb7cb15a2 Mon Sep 17 00:00:00 2001
|
||||
From: Tom Rix <trix@redhat.com>
|
||||
Date: Sat, 3 Feb 2024 08:11:55 -0500
|
||||
Subject: [PATCH] no third_party foxi
|
||||
|
||||
---
|
||||
cmake/Dependencies.cmake | 6 +++---
|
||||
1 file changed, 3 insertions(+), 3 deletions(-)
|
||||
|
||||
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
|
||||
index 5f91f3ffab..8e1461af81 100644
|
||||
--- a/cmake/Dependencies.cmake
|
||||
+++ b/cmake/Dependencies.cmake
|
||||
@@ -1567,7 +1567,7 @@ if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO AND NOT INTERN_DISABLE_ONNX)
|
||||
set_target_properties(onnx_proto PROPERTIES CXX_STANDARD 17)
|
||||
endif()
|
||||
endif()
|
||||
- add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/foxi EXCLUDE_FROM_ALL)
|
||||
+ # add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/../third_party/foxi EXCLUDE_FROM_ALL)
|
||||
|
||||
add_definitions(-DONNX_NAMESPACE=${ONNX_NAMESPACE})
|
||||
if(NOT USE_SYSTEM_ONNX)
|
||||
@@ -1600,8 +1600,8 @@ if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO AND NOT INTERN_DISABLE_ONNX)
|
||||
message("-- Found onnx: ${ONNX_LIBRARY} ${ONNX_PROTO_LIBRARY}")
|
||||
list(APPEND Caffe2_DEPENDENCY_LIBS onnx_proto onnx)
|
||||
endif()
|
||||
- include_directories(${FOXI_INCLUDE_DIRS})
|
||||
- list(APPEND Caffe2_DEPENDENCY_LIBS foxi_loader)
|
||||
+# include_directories(${FOXI_INCLUDE_DIRS})
|
||||
+# list(APPEND Caffe2_DEPENDENCY_LIBS foxi_loader)
|
||||
# Recover the build shared libs option.
|
||||
set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS})
|
||||
endif()
|
||||
--
|
||||
2.43.0
|
||||
|
||||
|
|
@ -1,112 +0,0 @@
|
|||
From 027dad1eaed51c1172e2497da611e3267d42d2f0 Mon Sep 17 00:00:00 2001
|
||||
From: Tom Rix <Tom.Rix@amd.com>
|
||||
Date: Fri, 28 Mar 2025 09:16:03 -0700
|
||||
Subject: [PATCH] python-torch: disable ck
|
||||
|
||||
---
|
||||
aten/src/ATen/CMakeLists.txt | 7 +++----
|
||||
aten/src/ATen/Context.cpp | 1 +
|
||||
aten/src/ATen/cuda/CUDABlas.cpp | 10 +++++-----
|
||||
3 files changed, 9 insertions(+), 9 deletions(-)
|
||||
|
||||
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
|
||||
index 085af373ec22..84808880e51c 100644
|
||||
--- a/aten/src/ATen/CMakeLists.txt
|
||||
+++ b/aten/src/ATen/CMakeLists.txt
|
||||
@@ -134,7 +134,7 @@ file(GLOB native_cuda_cu "native/cuda/*.cu")
|
||||
file(GLOB native_cuda_cpp "native/cuda/*.cpp")
|
||||
file(GLOB native_cuda_h "native/cuda/*.h" "native/cuda/*.cuh")
|
||||
file(GLOB native_cuda_linalg_cpp "native/cuda/linalg/*.cpp")
|
||||
-file(GLOB native_hip_h "native/hip/*.h" "native/hip/*.cuh" "native/hip/bgemm_kernels/*.h")
|
||||
+file(GLOB native_hip_h "native/hip/*.h" "native/hip/*.cuh" )
|
||||
file(GLOB native_cudnn_cpp "native/cudnn/*.cpp")
|
||||
file(GLOB native_sparse_cuda_cu "native/sparse/cuda/*.cu")
|
||||
file(GLOB native_sparse_cuda_cpp "native/sparse/cuda/*.cpp")
|
||||
@@ -145,7 +145,7 @@ file(GLOB native_nested_h "native/nested/*.h")
|
||||
file(GLOB native_nested_cuda_cu "native/nested/cuda/*.cu")
|
||||
file(GLOB native_nested_cuda_cpp "native/nested/cuda/*.cpp")
|
||||
|
||||
-file(GLOB native_hip_hip "native/hip/*.hip" "native/hip/bgemm_kernels/*.hip")
|
||||
+file(GLOB native_hip_hip "native/hip/*.hip" )
|
||||
file(GLOB native_hip_cpp "native/hip/*.cpp")
|
||||
file(GLOB native_hip_linalg_cpp "native/hip/linalg/*.cpp")
|
||||
file(GLOB native_miopen_cpp "native/miopen/*.cpp")
|
||||
@@ -361,13 +361,12 @@ endif()
|
||||
${native_quantized_hip_hip}
|
||||
${native_transformers_hip_hip} ${native_transformers_src_hip_hip}
|
||||
)
|
||||
- if(WIN32) # Windows doesn't support Composable Kernels and Triton
|
||||
file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip")
|
||||
file(GLOB native_hip_ck "native/hip/ck*.hip")
|
||||
exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
|
||||
${native_hip_bgemm} ${native_hip_ck}
|
||||
${native_transformers_hip_hip} ${native_transformers_hip_cpp})
|
||||
- endif()
|
||||
+
|
||||
# TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources)
|
||||
list(APPEND all_hip_cpp
|
||||
${native_nested_hip_cpp}
|
||||
diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp
|
||||
index f598fc3a39d3..03dab6ff38fe 100644
|
||||
--- a/aten/src/ATen/Context.cpp
|
||||
+++ b/aten/src/ATen/Context.cpp
|
||||
@@ -355,6 +355,7 @@ at::BlasBackend Context::blasPreferredBackend() {
|
||||
}
|
||||
|
||||
void Context::setBlasPreferredBackend(at::BlasBackend b) {
|
||||
+ return;
|
||||
#ifdef _MSC_VER
|
||||
TORCH_WARN_ONCE(
|
||||
"torch.backends.cuda.preferred_blas_library is an experimental feature. "
|
||||
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
|
||||
index a62b028fd4ff..cba38426ea1f 100644
|
||||
--- a/aten/src/ATen/cuda/CUDABlas.cpp
|
||||
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
|
||||
@@ -708,7 +708,7 @@ void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
bgemm_internal_cublaslt<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
-#ifdef USE_ROCM
|
||||
+#ifdef USE_ROCM_NO_CK
|
||||
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
||||
at::native::bgemm_internal_ck<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
@@ -1061,7 +1061,7 @@ void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
|
||||
gemm_internal_cublaslt<double>(CUDABLAS_GEMM_ARGS(double));
|
||||
#endif
|
||||
}
|
||||
-#ifdef USE_ROCM
|
||||
+#ifdef USE_ROCM_NO_CK
|
||||
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
||||
at::native::gemm_internal_ck<double>(CUDABLAS_GEMM_ARGS(double));
|
||||
}
|
||||
@@ -1077,7 +1077,7 @@ void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
-#ifdef USE_ROCM
|
||||
+#ifdef USE_ROCM_NO_CK
|
||||
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
||||
at::native::gemm_internal_ck<float>(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
@@ -1125,7 +1125,7 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
-#ifdef USE_ROCM
|
||||
+#ifdef USE_ROCM_NO_CK
|
||||
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
||||
at::native::gemm_internal_ck<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
@@ -1141,7 +1141,7 @@ void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
-#ifdef USE_ROCM
|
||||
+#ifdef USE_ROCM_NO_CK
|
||||
else if (at::globalContext().blasPreferredBackend() == BlasBackend::Ck) {
|
||||
at::native::gemm_internal_ck<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
--
|
||||
2.48.1
|
||||
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
From 58ccda271e8f51c3fa5b7518cf6ee52ce204fd37 Mon Sep 17 00:00:00 2001
|
||||
From: Tom Rix <trix@redhat.com>
|
||||
Date: Thu, 22 Feb 2024 09:28:11 -0500
|
||||
Subject: [PATCH] reenable foxi linking
|
||||
|
||||
---
|
||||
cmake/Dependencies.cmake | 2 +-
|
||||
1 file changed, 1 insertion(+), 1 deletion(-)
|
||||
|
||||
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
|
||||
index 42fbf80f6e8..bc3a2dc6fee 100644
|
||||
--- a/cmake/Dependencies.cmake
|
||||
+++ b/cmake/Dependencies.cmake
|
||||
@@ -1604,7 +1604,7 @@ if(CAFFE2_CMAKE_BUILDING_WITH_MAIN_REPO AND NOT INTERN_DISABLE_ONNX)
|
||||
list(APPEND Caffe2_DEPENDENCY_LIBS onnx_proto onnx)
|
||||
endif()
|
||||
# include_directories(${FOXI_INCLUDE_DIRS})
|
||||
-# list(APPEND Caffe2_DEPENDENCY_LIBS foxi_loader)
|
||||
+ list(APPEND Caffe2_DEPENDENCY_LIBS foxi_loader)
|
||||
# Recover the build shared libs option.
|
||||
set(BUILD_SHARED_LIBS ${TEMP_BUILD_SHARED_LIBS})
|
||||
endif()
|
||||
--
|
||||
2.43.2
|
||||
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
From 04dd33db93b852fdfd7ea408813080b2e2026650 Mon Sep 17 00:00:00 2001
|
||||
From: Tom Rix <trix@redhat.com>
|
||||
Date: Sat, 3 Feb 2024 06:41:20 -0500
|
||||
Subject: [PATCH] silence an assert
|
||||
|
||||
---
|
||||
aten/src/ATen/native/cuda/IndexKernel.cu | 2 +-
|
||||
1 file changed, 1 insertion(+), 1 deletion(-)
|
||||
|
||||
diff --git a/aten/src/ATen/native/cuda/IndexKernel.cu b/aten/src/ATen/native/cuda/IndexKernel.cu
|
||||
index 657c0c77b3..b406aa6687 100644
|
||||
--- a/aten/src/ATen/native/cuda/IndexKernel.cu
|
||||
+++ b/aten/src/ATen/native/cuda/IndexKernel.cu
|
||||
@@ -249,7 +249,7 @@ void index_put_kernel_quantized_cuda(TensorIterator& iter, const IntArrayRef ind
|
||||
|
||||
gpu_index_kernel(iter, index_size, index_stride, [inv_scale, zero_point, qmin, qmax]C10_DEVICE(char* const out_data, const char* const in_data, const int64_t offset) {
|
||||
int64_t qvalue = static_cast<int64_t>(zero_point + nearbyintf(*(float*)in_data * inv_scale));
|
||||
- qvalue = std::clamp(qvalue, qmin, qmax);
|
||||
+ //qvalue = std::clamp(qvalue, qmin, qmax);
|
||||
*(scalar_t*)(out_data + offset) = static_cast<scalar_t>(qvalue);
|
||||
});
|
||||
});
|
||||
--
|
||||
2.43.0
|
||||
|
||||
|
|
@ -1,88 +0,0 @@
|
|||
From f646e0f04ae591c8f2d8a0cd24b035725c57659b Mon Sep 17 00:00:00 2001
|
||||
From: Tom Rix <Tom.Rix@amd.com>
|
||||
Date: Thu, 23 Jan 2025 08:24:22 -0800
|
||||
Subject: [PATCH] torch: paper over c++ assert
|
||||
|
||||
---
|
||||
aten/src/ATen/native/sparse/FlattenIndicesCommon.h | 2 ++
|
||||
.../ATen/native/sparse/SparseBinaryOpIntersectionCommon.h | 5 +++++
|
||||
.../src/ATen/native/sparse/ValidateCompressedIndicesCommon.h | 2 ++
|
||||
3 files changed, 9 insertions(+)
|
||||
|
||||
diff --git a/aten/src/ATen/native/sparse/FlattenIndicesCommon.h b/aten/src/ATen/native/sparse/FlattenIndicesCommon.h
|
||||
index 0e79ed809ae6..a3cec8aaf78b 100644
|
||||
--- a/aten/src/ATen/native/sparse/FlattenIndicesCommon.h
|
||||
+++ b/aten/src/ATen/native/sparse/FlattenIndicesCommon.h
|
||||
@@ -69,11 +69,13 @@ Tensor _flatten_indices_impl(const Tensor& indices, IntArrayRef size) {
|
||||
[=] FUNCAPI (int64_t nnz_idx) -> int64_t {
|
||||
const auto* RESTRICT ptr_indices_dim = ptr_indices + nnz_idx * indices_nnz_stride;
|
||||
auto hash = static_cast<int64_t>(0);
|
||||
+#if 0
|
||||
for (int64_t dim = 0; dim < sparse_dim; ++dim) {
|
||||
const auto dim_hash_coeff = hash_coeffs[dim];
|
||||
const auto dim_index = ptr_indices_dim[dim * indices_dim_stride];
|
||||
hash += dim_index * dim_hash_coeff;
|
||||
}
|
||||
+#endif
|
||||
return hash;
|
||||
});
|
||||
}
|
||||
diff --git a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h
|
||||
index c0b94bf39d54..8de4900b7a01 100644
|
||||
--- a/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h
|
||||
+++ b/aten/src/ATen/native/sparse/SparseBinaryOpIntersectionCommon.h
|
||||
@@ -279,12 +279,15 @@ void _sparse_binary_op_intersection_kernel_impl(
|
||||
if (!ptr_indices) {
|
||||
return hash;
|
||||
}
|
||||
+#if 0
|
||||
+// /usr/lib/gcc/x86_64-redhat-linux/15/../../../../include/c++/15/array:219:2: error: reference to __host__ function '__glibcxx_assert_fail' in __host__ __device__ function
|
||||
const auto* RESTRICT ptr_indices_dim = ptr_indices + nnz_idx * indices_nnz_stride;
|
||||
for (int64_t dim = 0; dim < sparse_dim; ++dim) {
|
||||
const auto dim_hash_coeff = hash_coeffs[dim];
|
||||
const auto dim_index = ptr_indices_dim[dim * indices_dim_stride];
|
||||
hash += dim_index * dim_hash_coeff;
|
||||
}
|
||||
+#endif
|
||||
return hash;
|
||||
});
|
||||
}
|
||||
@@ -364,6 +367,7 @@ void _sparse_binary_op_intersection_kernel_impl(
|
||||
if (hash_ptr) {
|
||||
hash = hash_ptr[nnz_idx];
|
||||
} else if (sparse_dim) {
|
||||
+#if 0
|
||||
// Compute hash value
|
||||
const auto* RESTRICT ptr_indices_dim = ptr_indices + nnz_idx * indices_nnz_stride;
|
||||
for (int64_t dim = 0; dim < sparse_dim; ++dim) {
|
||||
@@ -371,6 +375,7 @@ void _sparse_binary_op_intersection_kernel_impl(
|
||||
const auto dim_index = ptr_indices_dim[dim * indices_dim_stride];
|
||||
hash += dim_index * dim_hash_coeff;
|
||||
}
|
||||
+#endif
|
||||
}
|
||||
|
||||
// Perform hash values intersection
|
||||
diff --git a/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h b/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h
|
||||
index ec4c084a39cc..9bc9655b0afa 100644
|
||||
--- a/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h
|
||||
+++ b/aten/src/ATen/native/sparse/ValidateCompressedIndicesCommon.h
|
||||
@@ -341,6 +341,7 @@ void _validate_compressed_sparse_indices_kernel(
|
||||
// assuming idx contiguity per batch:
|
||||
int64_t tmp = batch_idx * nnz;
|
||||
// `nnz == idx_sizes[idx_ndims - 1]` is checked above as `nnz == idx.size(-1)`
|
||||
+#if 0
|
||||
for (int i = idx_ndims - 1;
|
||||
i >= 0 && nnz > 0; // break early when nnz==0
|
||||
i--) {
|
||||
@@ -348,6 +349,7 @@ void _validate_compressed_sparse_indices_kernel(
|
||||
idx_offset += (tmp - div * idx_sizes[i]) * idx_strides[i];
|
||||
tmp = div;
|
||||
}
|
||||
+#endif
|
||||
const auto* RESTRICT ptr_idx_batch = ptr_idx + idx_offset;
|
||||
_check_idx_sorted_distinct_vals_slices_with_cidx<
|
||||
cdim_name,
|
||||
--
|
||||
2.48.1
|
||||
|
||||
|
|
@ -1,34 +0,0 @@
|
|||
From 4248211ce9a9de81bb3ade5d421ba709b19ead08 Mon Sep 17 00:00:00 2001
|
||||
From: Tom Rix <trix@redhat.com>
|
||||
Date: Sat, 3 Feb 2024 15:01:28 -0500
|
||||
Subject: [PATCH] use any hip
|
||||
|
||||
---
|
||||
cmake/public/LoadHIP.cmake | 4 ++--
|
||||
1 file changed, 2 insertions(+), 2 deletions(-)
|
||||
|
||||
diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake
|
||||
index 1abeb06228..28458c4146 100644
|
||||
--- a/cmake/public/LoadHIP.cmake
|
||||
+++ b/cmake/public/LoadHIP.cmake
|
||||
@@ -30,7 +30,7 @@ endif()
|
||||
message("Building PyTorch for GPU arch: ${PYTORCH_ROCM_ARCH}")
|
||||
|
||||
# Add HIP to the CMAKE Module Path
|
||||
-set(CMAKE_MODULE_PATH ${ROCM_PATH}/lib/cmake/hip ${CMAKE_MODULE_PATH})
|
||||
+set(CMAKE_MODULE_PATH ${ROCM_PATH}/lib64/cmake/hip ${CMAKE_MODULE_PATH})
|
||||
|
||||
macro(find_package_and_print_version PACKAGE_NAME)
|
||||
find_package("${PACKAGE_NAME}" ${ARGN})
|
||||
@@ -38,7 +38,7 @@ macro(find_package_and_print_version PACKAGE_NAME)
|
||||
endmacro()
|
||||
|
||||
# Find the HIP Package
|
||||
-find_package_and_print_version(HIP 1.0)
|
||||
+find_package_and_print_version(HIP MODULE)
|
||||
|
||||
if(HIP_FOUND)
|
||||
set(PYTORCH_FOUND_HIP TRUE)
|
||||
--
|
||||
2.43.0
|
||||
|
||||
|
|
@ -1,47 +0,0 @@
|
|||
From 091b7fe1ccbb5e4ff4ac6017d42bacb869f61a27 Mon Sep 17 00:00:00 2001
|
||||
From: Tom Rix <trix@redhat.com>
|
||||
Date: Sat, 20 Jul 2024 05:37:15 -0600
|
||||
Subject: [PATCH] Add cmake option USE_SYSTEM_FBGEMM
|
||||
|
||||
Signed-off-by: Tom Rix <trix@redhat.com>
|
||||
---
|
||||
CMakeLists.txt | 1 +
|
||||
cmake/Dependencies.cmake | 3 ++-
|
||||
2 files changed, 3 insertions(+), 1 deletion(-)
|
||||
|
||||
diff --git a/CMakeLists.txt b/CMakeLists.txt
|
||||
index c4cd4b2c2a98..2068f7c6c4f2 100644
|
||||
--- a/CMakeLists.txt
|
||||
+++ b/CMakeLists.txt
|
||||
@@ -253,6 +253,7 @@ cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
|
||||
"USE_CUDNN" OFF)
|
||||
cmake_dependent_option(USE_CUSPARSELT "Use cuSPARSELt" ON "USE_CUDA" OFF)
|
||||
option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
|
||||
+option(USE_SYSTEM_FBGEMM "Use system-wide FBGEMM" OFF)
|
||||
option(USE_KINETO "Use Kineto profiling library" ON)
|
||||
option(USE_CUPTI_SO "Use CUPTI as a shared library" ON)
|
||||
option(USE_FAKELOWP "Use FakeLowp operators" OFF)
|
||||
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
|
||||
index f1f2eb7cec31..192dac46f13b 100644
|
||||
--- a/cmake/Dependencies.cmake
|
||||
+++ b/cmake/Dependencies.cmake
|
||||
@@ -706,6 +706,7 @@ endif()
|
||||
|
||||
# ---[ FBGEMM
|
||||
if(USE_FBGEMM)
|
||||
+ if (NOT USE_SYSTEM_FBGEMM)
|
||||
set(CAFFE2_THIRD_PARTY_ROOT "${PROJECT_SOURCE_DIR}/third_party")
|
||||
if(NOT DEFINED FBGEMM_SOURCE_DIR)
|
||||
set(FBGEMM_SOURCE_DIR "${CAFFE2_THIRD_PARTY_ROOT}/fbgemm" CACHE STRING "FBGEMM source directory")
|
||||
@@ -746,7 +747,7 @@ if(USE_FBGEMM)
|
||||
target_compile_options_if_supported(asmjit -Wno-unused-but-set-variable)
|
||||
endif()
|
||||
endif()
|
||||
-
|
||||
+ endif()
|
||||
if(USE_FBGEMM)
|
||||
list(APPEND Caffe2_DEPENDENCY_LIBS fbgemm)
|
||||
endif()
|
||||
--
|
||||
2.45.1
|
||||
|
||||
149
next/0001-Add-cmake-variable-USE_ROCM_CK.patch
Normal file
149
next/0001-Add-cmake-variable-USE_ROCM_CK.patch
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
From 4cc5d88dfe7a45ab245648dc874645d32a24b98b Mon Sep 17 00:00:00 2001
|
||||
From: Tom Rix <Tom.Rix@amd.com>
|
||||
Date: Fri, 27 Jun 2025 13:52:51 -0700
|
||||
Subject: [PATCH] Add cmake variable USE_ROCM_CK
|
||||
|
||||
---
|
||||
CMakeLists.txt | 1 +
|
||||
aten/src/ATen/CMakeLists.txt | 40 ++++++++++++++++-----------------
|
||||
aten/src/ATen/cuda/CUDABlas.cpp | 10 ++++-----
|
||||
cmake/Dependencies.cmake | 3 +++
|
||||
4 files changed, 29 insertions(+), 25 deletions(-)
|
||||
|
||||
diff --git a/CMakeLists.txt b/CMakeLists.txt
|
||||
index 99c0b9e0ea0c..4c632e42f531 100644
|
||||
--- a/CMakeLists.txt
|
||||
+++ b/CMakeLists.txt
|
||||
@@ -240,6 +240,7 @@ cmake_dependent_option(
|
||||
BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON
|
||||
"USE_CUDA AND LINUX AND BUILD_PYTHON" OFF)
|
||||
cmake_dependent_option(USE_ROCM "Use ROCm" ON "LINUX" OFF)
|
||||
+cmake_dependent_option(USE_ROCM_CK "Use ROCm Composable Kernel" ON "USE_ROCM" ON)
|
||||
option(CAFFE2_STATIC_LINK_CUDA "Statically link CUDA libraries" OFF)
|
||||
cmake_dependent_option(USE_CUDNN "Use cuDNN" ON "USE_CUDA" OFF)
|
||||
cmake_dependent_option(USE_STATIC_CUDNN "Use cuDNN static libraries" OFF
|
||||
diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
|
||||
index c9cfd74b501e..59f6178218ee 100644
|
||||
--- a/aten/src/ATen/CMakeLists.txt
|
||||
+++ b/aten/src/ATen/CMakeLists.txt
|
||||
@@ -373,26 +373,26 @@ if(USE_ROCM)
|
||||
# is header only, so this should be ok, except that the CMake build generates
|
||||
# a ck/config.h. We just do that part here. Without this, the ck.h from the
|
||||
# ROCM SDK may get accidentally used instead.
|
||||
- function(_pytorch_rocm_generate_ck_conf)
|
||||
- set(CK_ENABLE_INT8 "ON")
|
||||
- set(CK_ENABLE_FP16 "ON")
|
||||
- set(CK_ENABLE_FP32 "ON")
|
||||
- set(CK_ENABLE_FP64 "ON")
|
||||
- set(CK_ENABLE_BF16 "ON")
|
||||
- set(CK_ENABLE_FP8 "ON")
|
||||
- set(CK_ENABLE_BF8 "ON")
|
||||
- set(CK_USE_XDL "ON")
|
||||
- set(CK_USE_WMMA "ON")
|
||||
- configure_file(
|
||||
- "${Torch_SOURCE_DIR}/third_party/composable_kernel/include/ck/config.h.in"
|
||||
- "${CMAKE_CURRENT_BINARY_DIR}/composable_kernel/ck/config.h"
|
||||
- )
|
||||
- endfunction()
|
||||
+# function(_pytorch_rocm_generate_ck_conf)
|
||||
+# set(CK_ENABLE_INT8 "ON")
|
||||
+# set(CK_ENABLE_FP16 "ON")
|
||||
+# set(CK_ENABLE_FP32 "ON")
|
||||
+# set(CK_ENABLE_FP64 "ON")
|
||||
+# set(CK_ENABLE_BF16 "ON")
|
||||
+# set(CK_ENABLE_FP8 "ON")
|
||||
+# set(CK_ENABLE_BF8 "ON")
|
||||
+# set(CK_USE_XDL "ON")
|
||||
+# set(CK_USE_WMMA "ON")
|
||||
+# configure_file(
|
||||
+# "${Torch_SOURCE_DIR}/third_party/composable_kernel/include/ck/config.h.in"
|
||||
+# "${CMAKE_CURRENT_BINARY_DIR}/composable_kernel/ck/config.h"
|
||||
+# )
|
||||
+# endfunction()
|
||||
list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/hip)
|
||||
- list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
|
||||
- list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
|
||||
- list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
|
||||
- _pytorch_rocm_generate_ck_conf()
|
||||
+# list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include)
|
||||
+# list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include)
|
||||
+# list(APPEND ATen_HIP_INCLUDE ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
|
||||
+# _pytorch_rocm_generate_ck_conf()
|
||||
|
||||
# Next two lines are needed because TunableOp uses third-party/fmt
|
||||
list(APPEND ATen_HIP_INCLUDE $<TARGET_PROPERTY:fmt::fmt-header-only,INTERFACE_INCLUDE_DIRECTORIES>)
|
||||
@@ -409,7 +409,7 @@ endif()
|
||||
${native_quantized_hip_hip}
|
||||
${native_transformers_hip_hip} ${native_transformers_src_hip_hip}
|
||||
)
|
||||
- if(WIN32) # Windows doesn't support Composable Kernels
|
||||
+ if(NOT USE_ROCM_CK) # Exclude Composable Kernel sources when CK is disabled (e.g. on Windows, which doesn't support Composable Kernels)
|
||||
file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip")
|
||||
file(GLOB native_hip_ck "native/hip/ck*.hip")
|
||||
exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
|
||||
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
|
||||
index 89350a11bea7..33e5f2808057 100644
|
||||
--- a/aten/src/ATen/cuda/CUDABlas.cpp
|
||||
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
|
||||
@@ -752,7 +752,7 @@ template <>
|
||||
void bgemm_internal<double>(CUDABLAS_BGEMM_ARGTYPES(double))
|
||||
{
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
-#ifdef USE_ROCM
|
||||
+#ifdef USE_ROCM_CK
|
||||
// hipblaslt does not support double gemm yet
|
||||
bgemm_internal_cublas<double>(CUDABLAS_BGEMM_ARGS(double));
|
||||
#else
|
||||
@@ -1103,7 +1103,7 @@ inline void gemm_internal_cublas_half_helper(CUDABLAS_GEMM_ARGTYPES_AND_C_DTYPE(
|
||||
void * beta_ptr = &fbeta;
|
||||
_cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
|
||||
GEMM_CHECK_ARGVALUES(at::Half);
|
||||
-#ifdef USE_ROCM
|
||||
+#ifdef USE_ROCM_CK
|
||||
int flag = 0;
|
||||
#if USE_GEMM_FLAGS_FP16_ALT_IMPL
|
||||
flag = at::ROCmBackwardPassGuard::is_backward_pass() ? rocblas_gemm_flags_fp16_alt_impl : 0;
|
||||
@@ -1270,7 +1270,7 @@ template <>
|
||||
void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
|
||||
{
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
-#ifdef USE_ROCM
|
||||
+#ifdef USE_ROCM_CK
|
||||
// hipblaslt does not support double gemm yet
|
||||
gemm_internal_cublas<double>(CUDABLAS_GEMM_ARGS(double));
|
||||
#else
|
||||
@@ -1311,7 +1311,7 @@ template <>
|
||||
void gemm_internal<c10::complex<double>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<double>))
|
||||
{
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
-#ifdef USE_ROCM
|
||||
+#ifdef USE_ROCM_CK
|
||||
// hipblaslt does not support complex gemm yet
|
||||
gemm_internal_cublas<c10::complex<double>>(CUDABLAS_GEMM_ARGS(c10::complex<double>));
|
||||
#else
|
||||
@@ -1327,7 +1327,7 @@ template <>
|
||||
void gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<float>))
|
||||
{
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
-#ifdef USE_ROCM
|
||||
+#ifdef USE_ROCM_CK
|
||||
// hipblaslt does not support complex gemm yet
|
||||
gemm_internal_cublas<c10::complex<float>>(CUDABLAS_GEMM_ARGS(c10::complex<float>));
|
||||
#else
|
||||
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
|
||||
index a93386c27f8d..be1368999d38 100644
|
||||
--- a/cmake/Dependencies.cmake
|
||||
+++ b/cmake/Dependencies.cmake
|
||||
@@ -1031,6 +1031,9 @@ if(USE_ROCM)
|
||||
if(HIPBLASLT_VEC_EXT)
|
||||
list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_VEC_EXT)
|
||||
endif()
|
||||
+ if(USE_ROCM_CK)
|
||||
+ list(APPEND HIP_CXX_FLAGS -DUSE_ROCM_CK)
|
||||
+ endif()
|
||||
list(APPEND HIP_HIPCC_FLAGS --offload-compress)
|
||||
if(WIN32)
|
||||
add_definitions(-DROCM_ON_WINDOWS)
|
||||
--
|
||||
2.49.0
|
||||
|
||||
|
|
@ -1,506 +0,0 @@
|
|||
From f1d65e958afa65882dbfea8b392ab847a84d41ed Mon Sep 17 00:00:00 2001
|
||||
From: Tom Rix <trix@redhat.com>
|
||||
Date: Sat, 29 Jun 2024 04:18:34 -0700
|
||||
Subject: [PATCH] Optionally use hipblaslt
|
||||
|
||||
---
|
||||
aten/src/ATen/cuda/CUDABlas.cpp | 46 ++++++++++++++++++------
|
||||
aten/src/ATen/cuda/CUDAContextLight.h | 4 +++
|
||||
aten/src/ATen/cuda/CublasHandlePool.cpp | 10 ++++--
|
||||
aten/src/ATen/cuda/tunable/TunableGemm.h | 18 +++++++---
|
||||
aten/src/ATen/native/cuda/Blas.cpp | 18 +++++++++-
|
||||
cmake/Dependencies.cmake | 3 ++
|
||||
cmake/public/LoadHIP.cmake | 2 +-
|
||||
7 files changed, 82 insertions(+), 19 deletions(-)
|
||||
|
||||
diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
|
||||
index ce991a9bcad4..3f0d17b52778 100644
|
||||
--- a/aten/src/ATen/cuda/CUDABlas.cpp
|
||||
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
|
||||
@@ -14,7 +14,9 @@
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
+#ifdef USE_HIPBLASLT
|
||||
#include <hipblaslt/hipblaslt-ext.hpp>
|
||||
+#endif
|
||||
// until hipblas has an API to accept flags, we must use rocblas here
|
||||
#include <hipblas/hipblas.h>
|
||||
#include <rocblas/rocblas.h>
|
||||
@@ -182,6 +184,9 @@ uint32_t _getAlignment(uintptr_t address) {
|
||||
static size_t _parseChosenWorkspaceSize() {
|
||||
const char * val = getenv("CUBLASLT_WORKSPACE_SIZE");
|
||||
#ifdef USE_ROCM
|
||||
+#ifndef USE_HIPBLASLT
|
||||
+ return 0;
|
||||
+#endif
|
||||
if (!val) {
|
||||
// accept either env var
|
||||
val = getenv("HIPBLASLT_WORKSPACE_SIZE");
|
||||
@@ -235,6 +240,7 @@ namespace at::cuda::blas {
|
||||
} while (0)
|
||||
|
||||
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
namespace {
|
||||
// Following the pattern of CuSparseDescriptor
|
||||
// Defined here for now because this is the only place cublas_lt interface is
|
||||
@@ -318,7 +324,6 @@ class CuBlasLtMatmulPreference : public CuBlasLtDescriptor<
|
||||
};
|
||||
} // namespace
|
||||
|
||||
-
|
||||
template <typename Dtype>
|
||||
inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
|
||||
cudaDataType_t abcType = CUDA_R_32F;
|
||||
@@ -452,7 +457,7 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
|
||||
" scaleType ",
|
||||
scaleType);
|
||||
}
|
||||
-
|
||||
+#endif
|
||||
|
||||
template <typename Dtype>
|
||||
inline void bgemm_internal_cublas(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
|
||||
@@ -608,10 +613,13 @@ void bgemm_internal<double>(CUDABLAS_BGEMM_ARGTYPES(double))
|
||||
template <>
|
||||
void bgemm_internal<float>(CUDABLAS_BGEMM_ARGTYPES(float))
|
||||
{
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
bgemm_internal_cublaslt<float>(CUDABLAS_BGEMM_ARGS(float));
|
||||
}
|
||||
- else {
|
||||
+ else
|
||||
+#endif
|
||||
+ {
|
||||
bgemm_internal_cublas<float>(CUDABLAS_BGEMM_ARGS(float));
|
||||
}
|
||||
}
|
||||
@@ -651,10 +659,13 @@ void bgemm_internal<c10::complex<float>>(CUDABLAS_BGEMM_ARGTYPES(c10::complex<fl
|
||||
template <>
|
||||
void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half))
|
||||
{
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
bgemm_internal_cublaslt<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
|
||||
}
|
||||
- else {
|
||||
+ else
|
||||
+#endif
|
||||
+ {
|
||||
bgemm_internal_cublas<at::Half>(CUDABLAS_BGEMM_ARGS(at::Half));
|
||||
}
|
||||
}
|
||||
@@ -662,10 +673,13 @@ void bgemm_internal<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half))
|
||||
template <>
|
||||
void bgemm_internal<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16))
|
||||
{
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
bgemm_internal_cublaslt<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
- else {
|
||||
+ else
|
||||
+#endif
|
||||
+ {
|
||||
bgemm_internal_cublas<at::BFloat16>(CUDABLAS_BGEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
}
|
||||
@@ -781,11 +795,13 @@ void bgemm<at::BFloat16>(CUDABLAS_BGEMM_ARGTYPES(at::BFloat16)) {
|
||||
}
|
||||
}
|
||||
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
template <typename Dtype>
|
||||
inline void gemm_internal_cublaslt(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
|
||||
// forward to bgemm implementation but set strides and batches to 0
|
||||
bgemm_internal_cublaslt(transa, transb, m, n, k, alpha, a, lda, 0, b, ldb, 0, beta, c, ldc, 0, 0);
|
||||
}
|
||||
+#endif
|
||||
|
||||
template <typename Dtype>
|
||||
inline void gemm_internal_cublas(CUDABLAS_GEMM_ARGTYPES(Dtype)) {
|
||||
@@ -1008,10 +1024,13 @@ void gemm_internal<double>(CUDABLAS_GEMM_ARGTYPES(double))
|
||||
template <>
|
||||
void gemm_internal<float>(CUDABLAS_GEMM_ARGTYPES(float))
|
||||
{
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
gemm_internal_cublaslt<float>(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
- else {
|
||||
+ else
|
||||
+#endif
|
||||
+ {
|
||||
gemm_internal_cublas<float>(CUDABLAS_GEMM_ARGS(float));
|
||||
}
|
||||
}
|
||||
@@ -1051,10 +1070,13 @@ void gemm_internal<c10::complex<float>>(CUDABLAS_GEMM_ARGTYPES(c10::complex<floa
|
||||
template <>
|
||||
void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
|
||||
{
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
gemm_internal_cublaslt<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
- else {
|
||||
+ else
|
||||
+#endif
|
||||
+ {
|
||||
gemm_internal_cublas<at::Half>(CUDABLAS_GEMM_ARGS(at::Half));
|
||||
}
|
||||
}
|
||||
@@ -1062,10 +1084,13 @@ void gemm_internal<at::Half>(CUDABLAS_GEMM_ARGTYPES(at::Half))
|
||||
template <>
|
||||
void gemm_internal<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16))
|
||||
{
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
if (at::globalContext().blasPreferredBackend() == BlasBackend::Cublaslt) {
|
||||
gemm_internal_cublaslt<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
- else {
|
||||
+ else
|
||||
+#endif
|
||||
+ {
|
||||
gemm_internal_cublas<at::BFloat16>(CUDABLAS_GEMM_ARGS(at::BFloat16));
|
||||
}
|
||||
}
|
||||
@@ -1177,7 +1202,7 @@ void gemm<at::BFloat16>(CUDABLAS_GEMM_ARGTYPES(at::BFloat16)) {
|
||||
}
|
||||
}
|
||||
|
||||
-
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
template <typename Dtype>
|
||||
void gemm_and_bias(
|
||||
bool transpose_mat1,
|
||||
@@ -1410,7 +1435,7 @@ void scaled_gemm(
|
||||
ScalarType result_dtype,
|
||||
void* amax_ptr,
|
||||
bool use_fast_accum) {
|
||||
-#if CUDA_VERSION >= 11080 || defined(USE_ROCM)
|
||||
+#if CUDA_VERSION >= 11080 || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
const auto computeType = CUBLAS_COMPUTE_32F;
|
||||
const auto scaleType = CUDA_R_32F;
|
||||
const int8_t fastAccuMode = use_fast_accum ? 1 : 0;
|
||||
@@ -1681,6 +1706,7 @@ void int8_gemm(
|
||||
" scaleType ",
|
||||
scaleType);
|
||||
}
|
||||
+#endif
|
||||
|
||||
template <>
|
||||
void trsm<float>(CUDABLAS_TRSM_ARGTYPES(float)) {
|
||||
diff --git a/aten/src/ATen/cuda/CUDAContextLight.h b/aten/src/ATen/cuda/CUDAContextLight.h
|
||||
index f2b657ced51b..f0ee613c4208 100644
|
||||
--- a/aten/src/ATen/cuda/CUDAContextLight.h
|
||||
+++ b/aten/src/ATen/cuda/CUDAContextLight.h
|
||||
@@ -9,7 +9,9 @@
|
||||
|
||||
// cublasLT was introduced in CUDA 10.1 but we enable only for 11.1 that also
|
||||
// added bf16 support
|
||||
+#if (!defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)))
|
||||
#include <cublasLt.h>
|
||||
+#endif
|
||||
|
||||
#ifdef CUDART_VERSION
|
||||
#include <cusolverDn.h>
|
||||
@@ -80,7 +82,9 @@ TORCH_CUDA_CPP_API c10::Allocator* getCUDADeviceAllocator();
|
||||
/* Handles */
|
||||
TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
|
||||
TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();
|
||||
+#if (!defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT)))
|
||||
TORCH_CUDA_CPP_API cublasLtHandle_t getCurrentCUDABlasLtHandle();
|
||||
+#endif
|
||||
|
||||
TORCH_CUDA_CPP_API void clearCublasWorkspaces();
|
||||
|
||||
diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp
|
||||
index 8eac525b3695..abfdf7a23847 100644
|
||||
--- a/aten/src/ATen/cuda/CublasHandlePool.cpp
|
||||
+++ b/aten/src/ATen/cuda/CublasHandlePool.cpp
|
||||
@@ -29,7 +29,7 @@ namespace at::cuda {
|
||||
|
||||
namespace {
|
||||
|
||||
-#if defined(USE_ROCM)
|
||||
+#if defined(USE_ROCM) && defined(USE_HIPBLASLT)
|
||||
void createCublasLtHandle(cublasLtHandle_t *handle) {
|
||||
TORCH_CUDABLAS_CHECK(cublasLtCreate(handle));
|
||||
}
|
||||
@@ -191,8 +191,9 @@ cublasHandle_t getCurrentCUDABlasHandle() {
|
||||
return handle;
|
||||
}
|
||||
|
||||
-cublasLtHandle_t getCurrentCUDABlasLtHandle() {
|
||||
#ifdef USE_ROCM
|
||||
+#if defined(USE_HIPBLASLT)
|
||||
+cublasLtHandle_t getCurrentCUDABlasLtHandle() {
|
||||
c10::DeviceIndex device = 0;
|
||||
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
|
||||
|
||||
@@ -213,9 +214,12 @@ cublasLtHandle_t getCurrentCUDABlasLtHandle() {
|
||||
|
||||
auto handle = myPoolWindow->reserve(device);
|
||||
return handle;
|
||||
+}
|
||||
+#endif
|
||||
#else
|
||||
+cublasLtHandle_t getCurrentCUDABlasLtHandle() {
|
||||
return reinterpret_cast<cublasLtHandle_t>(getCurrentCUDABlasHandle());
|
||||
-#endif
|
||||
}
|
||||
+#endif
|
||||
|
||||
} // namespace at::cuda
|
||||
diff --git a/aten/src/ATen/cuda/tunable/TunableGemm.h b/aten/src/ATen/cuda/tunable/TunableGemm.h
|
||||
index 53e6154120c9..fa1d664696db 100644
|
||||
--- a/aten/src/ATen/cuda/tunable/TunableGemm.h
|
||||
+++ b/aten/src/ATen/cuda/tunable/TunableGemm.h
|
||||
@@ -11,7 +11,9 @@
|
||||
|
||||
#include <ATen/cuda/tunable/GemmCommon.h>
|
||||
#ifdef USE_ROCM
|
||||
+#ifdef USE_HIPBLASLT
|
||||
#include <ATen/cuda/tunable/GemmHipblaslt.h>
|
||||
+#endif
|
||||
#include <ATen/cuda/tunable/GemmRocblas.h>
|
||||
#endif
|
||||
#include <ATen/cuda/tunable/StreamTimer.h>
|
||||
@@ -65,6 +67,7 @@ class DefaultGemmStridedBatchedOp : public Callable<GemmStridedBatchedParams<T>>
|
||||
}
|
||||
};
|
||||
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
template <typename T>
|
||||
class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
|
||||
public:
|
||||
@@ -94,6 +97,7 @@ class DefaultScaledGemmOp : public Callable<ScaledGemmParams<T>> {
|
||||
return OK;
|
||||
}
|
||||
};
|
||||
+#endif
|
||||
|
||||
template <typename T>
|
||||
inline bool IsZero(T v) {
|
||||
@@ -191,6 +195,7 @@ static void AddRocblasValidator() {
|
||||
}
|
||||
}
|
||||
|
||||
+#ifdef USE_HIPBLASLT
|
||||
static void AddHipblasltValidator() {
|
||||
auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
|
||||
if (validators.find("HIPBLASLT_VERSION") == validators.end()) {
|
||||
@@ -205,6 +210,7 @@ static void AddHipblasltValidator() {
|
||||
[hipblaslt_version](auto&& k) { return hipblaslt_version == k ? OK : FAIL; });
|
||||
}
|
||||
}
|
||||
+#endif
|
||||
|
||||
static void AddRocmValidator() {
|
||||
auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
|
||||
@@ -243,7 +249,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
|
||||
}
|
||||
AddRocblasValidator();
|
||||
}
|
||||
-
|
||||
+#ifdef USE_HIPBLASLT
|
||||
static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
|
||||
if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
|
||||
rocm_validators = true;
|
||||
@@ -257,7 +263,7 @@ class GemmTunableOp : public TunableOp<GemmParams<T>, StreamTimer> {
|
||||
}
|
||||
AddHipblasltValidator();
|
||||
}
|
||||
-
|
||||
+#endif
|
||||
if (rocm_validators) {
|
||||
AddRocmValidator();
|
||||
}
|
||||
@@ -286,7 +292,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
|
||||
}
|
||||
AddRocblasValidator();
|
||||
}
|
||||
-
|
||||
+#ifdef USE_HIPBLASLT
|
||||
static const char *env_hipblaslt = std::getenv("PYTORCH_TUNABLEOP_HIPBLASLT_ENABLED");
|
||||
if (env_hipblaslt == nullptr || strcmp(env_hipblaslt, "1") == 0) {
|
||||
rocm_validators = true;
|
||||
@@ -300,7 +306,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
|
||||
}
|
||||
AddHipblasltValidator();
|
||||
}
|
||||
-
|
||||
+#endif
|
||||
if (rocm_validators) {
|
||||
AddRocmValidator();
|
||||
}
|
||||
@@ -312,6 +318,7 @@ class GemmStridedBatchedTunableOp : public TunableOp<GemmStridedBatchedParams<T>
|
||||
}
|
||||
};
|
||||
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
template <typename AT, typename BT, typename CT, BlasOp ALayout, BlasOp BLayout>
|
||||
class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer> {
|
||||
public:
|
||||
@@ -321,10 +328,12 @@ class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer>
|
||||
auto validators = getTuningContext()->GetTuningResultsValidator().GetAllValidators();
|
||||
|
||||
#if defined(USE_ROCM)
|
||||
+#ifdef USE_HIPBLASLT
|
||||
for (auto&& [name, op] : GetHipBlasLtScaledGemmTypeStringAndOps<AT, BT, CT, ALayout, BLayout>()) {
|
||||
this->RegisterOp(std::move(name), std::move(op));
|
||||
}
|
||||
AddHipblasltValidator();
|
||||
+#endif
|
||||
AddRocmValidator();
|
||||
#endif
|
||||
}
|
||||
@@ -337,6 +346,7 @@ class ScaledGemmTunableOp : public TunableOp<ScaledGemmParams<CT>, StreamTimer>
|
||||
"_", BlasOpToString(ALayout), BlasOpToString(BLayout));
|
||||
}
|
||||
};
|
||||
+#endif
|
||||
|
||||
#undef XSTRINGIFY
|
||||
#undef STRINGIFY
|
||||
diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
|
||||
index 84c59a4fd0d7..56ad5de3bf2d 100644
|
||||
--- a/aten/src/ATen/native/cuda/Blas.cpp
|
||||
+++ b/aten/src/ATen/native/cuda/Blas.cpp
|
||||
@@ -173,6 +173,7 @@ cuda::blas::GEMMAndBiasActivationEpilogue activation_to_gemm_and_blas_arg(Activa
|
||||
}
|
||||
|
||||
static bool getDisableAddmmCudaLt() {
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
static const char* env_value = std::getenv("DISABLE_ADDMM_CUDA_LT");
|
||||
#ifdef USE_ROCM
|
||||
// if we enable tunable op, it'll take priority over just hipblaslt (heuristics)
|
||||
@@ -196,10 +197,14 @@ static bool getDisableAddmmCudaLt() {
|
||||
}
|
||||
return false;
|
||||
#endif
|
||||
+#else
|
||||
+ return true;
|
||||
+#endif
|
||||
}
|
||||
|
||||
#ifdef USE_ROCM
|
||||
static bool isSupportedHipLtROCmArch(int index) {
|
||||
+#ifdef USE_HIPBLASLT
|
||||
hipDeviceProp_t* prop = at::cuda::getDeviceProperties(index);
|
||||
std::string device_arch = prop->gcnArchName;
|
||||
static const std::vector<std::string> archs = {"gfx90a", "gfx940", "gfx941", "gfx942"};
|
||||
@@ -210,6 +215,7 @@ static bool isSupportedHipLtROCmArch(int index) {
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(false, "Attempting to use hipBLASLt on a unsupported architecture!");
|
||||
+#endif
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
@@ -235,6 +241,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
at::ScalarType scalar_type = self.scalar_type();
|
||||
c10::MaybeOwned<Tensor> self_;
|
||||
if (&result != &self) {
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
#if (defined(CUDA_VERSION) && (CUDA_VERSION >= 11040)) || defined(USE_ROCM)
|
||||
// Strangely, if mat2 has only 1 row or column, we get
|
||||
// CUBLAS_STATUS_INVALID_VALUE error from cublasLtMatmulAlgoGetHeuristic.
|
||||
@@ -276,13 +283,14 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
scalar_type != at::ScalarType::BFloat16));
|
||||
#endif
|
||||
}
|
||||
+#endif
|
||||
#endif
|
||||
if (!useLtInterface) {
|
||||
self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
|
||||
}
|
||||
self__sizes = self_->sizes();
|
||||
} else {
|
||||
-#if defined(USE_ROCM)
|
||||
+#if defined(USE_ROCM) && defined(USE_HIPBLASLT)
|
||||
useLtInterface = !disable_addmm_cuda_lt &&
|
||||
result.dim() == 2 && result.is_contiguous() &&
|
||||
isSupportedHipLtROCmArch(self.device().index()) &&
|
||||
@@ -334,6 +342,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!args.result->is_conj());
|
||||
|
||||
if (useLtInterface) {
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
#if defined(USE_ROCM)
|
||||
AT_DISPATCH_FLOATING_TYPES_AND2(
|
||||
at::ScalarType::Half,
|
||||
@@ -394,6 +403,7 @@ Tensor& addmm_out_cuda_impl(Tensor& result, const Tensor& self, const Tensor& ma
|
||||
activation_epilogue
|
||||
);
|
||||
});
|
||||
+#endif
|
||||
#endif
|
||||
} else
|
||||
{
|
||||
@@ -803,6 +813,7 @@ Tensor _int_mm_cuda(const Tensor& self, const Tensor& mat2) {
|
||||
}
|
||||
|
||||
static bool _scaled_mm_allowed_device() {
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
auto dprops = at::cuda::getCurrentDeviceProperties();
|
||||
#ifdef USE_ROCM
|
||||
std::string device_arch = dprops->gcnArchName;
|
||||
@@ -817,6 +828,9 @@ static bool _scaled_mm_allowed_device() {
|
||||
#else
|
||||
return dprops->major >= 9 || (dprops->major == 8 && dprops->minor == 9);
|
||||
#endif
|
||||
+#else
|
||||
+ return false;
|
||||
+#endif
|
||||
}
|
||||
|
||||
// Computes matrix multiply + bias while applying scaling to input and output matrices and computes amax
|
||||
@@ -850,6 +864,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
// Check sizes
|
||||
bool allowed_device = _scaled_mm_allowed_device();
|
||||
TORCH_CHECK(allowed_device, "torch._scaled_mm is only supported on CUDA devices with compute capability >= 9.0 or 8.9, or ROCm MI300+");
|
||||
+#if !defined(USE_ROCM) || (defined(USE_ROCM) && defined(USE_HIPBLASLT))
|
||||
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix");
|
||||
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix");
|
||||
TORCH_CHECK(
|
||||
@@ -1025,6 +1040,7 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
|
||||
#if defined(USE_ROCM) && ROCM_VERSION >= 60000 && ROCM_VERSION < 60200
|
||||
// ROCm's hipBLASLt does not support amax before 6.2, so calculate separately
|
||||
amax = at::max(at::abs(out.to(kFloat)));
|
||||
+#endif
|
||||
#endif
|
||||
|
||||
return {out, amax};
|
||||
diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake
|
||||
index f1f2eb7cec31..8d05e834bbc5 100644
|
||||
--- a/cmake/Dependencies.cmake
|
||||
+++ b/cmake/Dependencies.cmake
|
||||
@@ -1052,6 +1052,9 @@ if(USE_ROCM)
|
||||
list(APPEND HIP_CXX_FLAGS -DTHRUST_DEVICE_SYSTEM=THRUST_DEVICE_SYSTEM_HIP)
|
||||
list(APPEND HIP_CXX_FLAGS -std=c++17)
|
||||
list(APPEND HIP_CXX_FLAGS -DHIPBLAS_V2)
|
||||
+ if(hipblaslt_FOUND)
|
||||
+ list(APPEND HIP_CXX_FLAGS -DUSE_HIPBLASLT)
|
||||
+ endif()
|
||||
if(HIP_NEW_TYPE_ENUMS)
|
||||
list(APPEND HIP_CXX_FLAGS -DHIP_NEW_TYPE_ENUMS)
|
||||
endif()
|
||||
diff --git a/cmake/public/LoadHIP.cmake b/cmake/public/LoadHIP.cmake
|
||||
index fa39156031ff..df4836847fdf 100644
|
||||
--- a/cmake/public/LoadHIP.cmake
|
||||
+++ b/cmake/public/LoadHIP.cmake
|
||||
@@ -155,7 +155,7 @@ if(HIP_FOUND)
|
||||
find_package_and_print_version(hiprand REQUIRED)
|
||||
find_package_and_print_version(rocblas REQUIRED)
|
||||
find_package_and_print_version(hipblas REQUIRED)
|
||||
- find_package_and_print_version(hipblaslt REQUIRED)
|
||||
+ find_package_and_print_version(hipblaslt)
|
||||
find_package_and_print_version(miopen REQUIRED)
|
||||
find_package_and_print_version(hipfft REQUIRED)
|
||||
find_package_and_print_version(hipsparse REQUIRED)
|
||||
--
|
||||
2.45.2
|
||||
|
||||
|
|
@ -1,94 +0,0 @@
|
|||
From 038ce9e44776e23f21c1816daa259bc0ea335088 Mon Sep 17 00:00:00 2001
|
||||
From: Tom Rix <trix@redhat.com>
|
||||
Date: Sat, 29 Jun 2024 07:06:09 -0700
|
||||
Subject: [PATCH] disable use of aotriton
|
||||
|
||||
---
|
||||
.../ATen/native/transformers/cuda/sdp_utils.cpp | 17 +++++++++++++++--
|
||||
1 file changed, 15 insertions(+), 2 deletions(-)
|
||||
|
||||
diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
|
||||
index 214b02d8262e..7b3eb9dcd8cd 100644
|
||||
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
|
||||
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp
|
||||
@@ -19,9 +19,12 @@
|
||||
#include <c10/core/SymInt.h>
|
||||
#include <c10/util/string_view.h>
|
||||
|
||||
+#ifdef USE_FLASH_ATTENTION
|
||||
#if USE_ROCM
|
||||
#include <aotriton/flash.h>
|
||||
#endif
|
||||
+#endif
|
||||
+
|
||||
|
||||
/**
|
||||
* Note [SDPA Runtime Dispatch]
|
||||
@@ -182,6 +185,9 @@ bool check_sm_version(cudaDeviceProp * dprops) {
|
||||
|
||||
bool check_flash_attention_hardware_support(sdp_params const& params, bool debug) {
|
||||
// Check that the gpu is capable of running flash attention
|
||||
+#ifndef USE_FLASH_ATTENTION
|
||||
+ return false;
|
||||
+#else
|
||||
using sm80 = SMVersion<8, 0>;
|
||||
using sm90 = SMVersion<9, 0>;
|
||||
#if USE_ROCM
|
||||
@@ -209,9 +215,13 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug
|
||||
}
|
||||
#endif
|
||||
return true;
|
||||
+#endif
|
||||
}
|
||||
|
||||
bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug) {
|
||||
+#ifndef USE_FLASH_ATTENTION
|
||||
+ return false;
|
||||
+#else
|
||||
// Mem Efficient attention supports hardware in the range [sm_50, sm_90]
|
||||
using sm50 = SMVersion<5, 0>;
|
||||
using sm90 = SMVersion<9, 0>;
|
||||
@@ -240,6 +250,7 @@ bool check_mem_efficient_hardware_support(sdp_params const& params, bool debug)
|
||||
}
|
||||
#endif
|
||||
return true;
|
||||
+#endif
|
||||
}
|
||||
|
||||
bool check_requires_grad_and_head_dim_gt192_constraints_on_sm86_89(
|
||||
@@ -554,7 +565,7 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
|
||||
#ifndef USE_FLASH_ATTENTION
|
||||
TORCH_WARN_ONCE(!debug, "Torch was not compiled with flash attention.");
|
||||
return false;
|
||||
-#endif
|
||||
+#else
|
||||
|
||||
// Define gate functions that determine if a flash kernel can be ran
|
||||
// Replace with std::to_array when we migrate to c++20
|
||||
@@ -597,13 +608,14 @@ bool can_use_flash_attention(sdp_params const& params, bool debug) {
|
||||
}
|
||||
}
|
||||
return true;
|
||||
+#endif
|
||||
}
|
||||
|
||||
bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
|
||||
#ifndef USE_MEM_EFF_ATTENTION
|
||||
TORCH_WARN_ONCE(!debug, "Torch was not compiled with memory efficient attention.");
|
||||
return false;
|
||||
-#endif
|
||||
+#else
|
||||
// Constraints specific to mem efficient attention
|
||||
constexpr auto greater_than_or_equal_sm80_mem_efficient_dtypes =
|
||||
array_of<at::ScalarType>(at::kHalf, at::kFloat, at::kBFloat16);
|
||||
@@ -663,6 +675,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) {
|
||||
}
|
||||
#endif
|
||||
return check_tensor_dtype(params, less_than_sm80_mem_efficient_dtypes, debug);
|
||||
+#endif
|
||||
}
|
||||
|
||||
SDPBackend select_sdp_backend(sdp_params const& kernel_params) {
|
||||
--
|
||||
2.45.2
|
||||
|
||||
|
|
@ -6,10 +6,10 @@
|
|||
# So pre releases can be tried
|
||||
%bcond_with gitcommit
|
||||
%if %{with gitcommit}
|
||||
# v2.8.0-rc3
|
||||
%global commit0 3d53a53e504089a52a149791fd33d7fc898bd055
|
||||
# v2.8.0-rc6
|
||||
%global commit0 f2b69a083d15e3d0083bb304302a3fd0b5fb8705
|
||||
%global shortcommit0 %(c=%{commit0}; echo ${c:0:7})
|
||||
%global date0 20250625
|
||||
%global date0 20250718
|
||||
%global pypi_version 2.8.0
|
||||
%global flatbuffers_version 24.12.23
|
||||
%global miniz_version 3.0.2
|
||||
|
|
@ -357,11 +357,9 @@ sed -i -e 's@list(APPEND Caffe2_DEPENDENCY_LIBS fmt::fmt-header-only)@#list(APPE
|
|||
sed -i -e 's@if(NOT TARGET fxdiv)@if(MSVC AND USE_XNNPACK)@' caffe2/CMakeLists.txt
|
||||
sed -i -e 's@TARGET_LINK_LIBRARIES(torch_cpu PRIVATE fxdiv)@#TARGET_LINK_LIBRARIES(torch_cpu PRIVATE fxdiv)@' caffe2/CMakeLists.txt
|
||||
|
||||
%if %{without gitcommit}
|
||||
# https://github.com/pytorch/pytorch/issues/149803
|
||||
# Tries to checkout nccl
|
||||
sed -i -e 's@ checkout_nccl()@# checkout_nccl()@' tools/build_pytorch_libs.py
|
||||
%endif
|
||||
sed -i -e 's@ checkout_nccl()@ True@' tools/build_pytorch_libs.py
|
||||
|
||||
# Disable the use of check_submodule's in the setup.py, we are a tarball, not a git repo
|
||||
sed -i -e 's@check_submodules()$@#check_submodules()@' setup.py
|
||||
|
|
@ -541,6 +539,7 @@ export USE_SYSTEM_EIGEN_INSTALL=ON
|
|||
export USE_SYSTEM_ONNX=ON
|
||||
export USE_SYSTEM_PYBIND11=OFF
|
||||
export USE_SYSTEM_LIBS=OFF
|
||||
export USE_SYSTEM_NCCL=OFF
|
||||
export USE_TENSORPIPE=OFF
|
||||
export USE_XNNPACK=OFF
|
||||
export USE_XPU=OFF
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue