python-torch/0001-Fix-compilation-and-import-torch-issues-for-cpython-.patch
Tom Rix cec8b79644 Update to 2.8.0-rc8
Signed-off-by: Tom Rix <Tom.Rix@amd.com>
2025-07-31 05:52:50 -07:00

359 lines
12 KiB
Diff

From f2a544b2e3a5bdc04985f6e06223c0c1700120a0 Mon Sep 17 00:00:00 2001
From: albanD <desmaison.alban@gmail.com>
Date: Sat, 12 Jul 2025 03:42:33 -0400
Subject: [PATCH] Fix compilation and "import torch" issues for cpython 3.14
Imported from
https://github.com/albanD/pytorch/tree/cpython314_build
commit 88bb9cdb72449f4277829e20d94ad8aec1894216
Signed-off-by: Tom Rix <Tom.Rix@amd.com>
---
torch/_dynamo/bytecode_analysis.py | 2 +-
torch/ao/quantization/__init__.py | 5 +++-
torch/ao/quantization/qconfig.py | 4 ++-
torch/ao/quantization/utils.py | 7 +++--
torch/csrc/dynamo/cpython_defs.c | 16 +++++++++++
torch/csrc/dynamo/cpython_includes.h | 17 ++++++++++++
torch/csrc/dynamo/eval_frame.c | 34 +++++++++++++++--------
torch/csrc/dynamo/framelocals_mapping.cpp | 14 ++++++++++
torch/csrc/utils/python_compat.h | 1 +
torch/onnx/__init__.py | 1 -
torch/utils/weak.py | 29 +++++++++++++++++--
11 files changed, 111 insertions(+), 19 deletions(-)
diff --git a/torch/_dynamo/bytecode_analysis.py b/torch/_dynamo/bytecode_analysis.py
index 3252ea91409f..2de74ee5bf8d 100644
--- a/torch/_dynamo/bytecode_analysis.py
+++ b/torch/_dynamo/bytecode_analysis.py
@@ -33,7 +33,7 @@ if sys.version_info >= (3, 11):
TERMINAL_OPCODES.add(dis.opmap["JUMP_FORWARD"])
else:
TERMINAL_OPCODES.add(dis.opmap["JUMP_ABSOLUTE"])
-if sys.version_info >= (3, 12):
+if (3, 12) <= sys.version_info < (3, 14):
TERMINAL_OPCODES.add(dis.opmap["RETURN_CONST"])
if sys.version_info >= (3, 13):
TERMINAL_OPCODES.add(dis.opmap["JUMP_BACKWARD_NO_INTERRUPT"])
diff --git a/torch/ao/quantization/__init__.py b/torch/ao/quantization/__init__.py
index ffc1792fd23f..cf5a8b99a894 100644
--- a/torch/ao/quantization/__init__.py
+++ b/torch/ao/quantization/__init__.py
@@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
+import sys
from typing import Callable, Optional, Union
import torch
@@ -33,7 +34,9 @@ from .stubs import * # noqa: F403
# ensure __module__ is set correctly for public APIs
ObserverOrFakeQuantize = Union[ObserverBase, FakeQuantizeBase]
-ObserverOrFakeQuantize.__module__ = "torch.ao.quantization"
+if sys.version_info < (3, 14):
+ ObserverOrFakeQuantize.__module__ = "torch.ao.quantization"
+
for _f in [
compare_results,
extract_results_from_loggers,
diff --git a/torch/ao/quantization/qconfig.py b/torch/ao/quantization/qconfig.py
index efee5302ad42..d9a8fc78bab4 100644
--- a/torch/ao/quantization/qconfig.py
+++ b/torch/ao/quantization/qconfig.py
@@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import copy
+import sys
import warnings
from collections import namedtuple
from typing import Any, Optional, Union
@@ -568,7 +569,8 @@ def _assert_valid_qconfig(qconfig: Optional[QConfig], mod: torch.nn.Module) -> N
QConfigAny = Optional[QConfig]
-QConfigAny.__module__ = "torch.ao.quantization.qconfig"
+if sys.version_info < (3, 14):
+ QConfigAny.__module__ = "torch.ao.quantization.qconfig"
def _add_module_to_qconfig_obs_ctr(
diff --git a/torch/ao/quantization/utils.py b/torch/ao/quantization/utils.py
index 4ac3112ec072..3b1503e01701 100644
--- a/torch/ao/quantization/utils.py
+++ b/torch/ao/quantization/utils.py
@@ -4,6 +4,7 @@ Utils shared by different modes of quantization (eager/graph)
"""
import functools
+import sys
import warnings
from collections import OrderedDict
from inspect import getfullargspec, signature
@@ -16,7 +17,8 @@ from torch.nn.utils.parametrize import is_parametrized
NodePattern = Union[tuple[Node, Node], tuple[Node, tuple[Node, Node]], Any]
-NodePattern.__module__ = "torch.ao.quantization.utils"
+if sys.version_info < (3, 14):
+ NodePattern.__module__ = "torch.ao.quantization.utils"
# This is the Quantizer class instance from torch/quantization/fx/quantize.py.
# Define separately to prevent circular imports.
@@ -31,7 +33,8 @@ QuantizerCls = Any
Pattern = Union[
Callable, tuple[Callable, Callable], tuple[Callable, tuple[Callable, Callable]], Any
]
-Pattern.__module__ = "torch.ao.quantization.utils"
+if sys.version_info < (3, 14):
+ Pattern.__module__ = "torch.ao.quantization.utils"
# TODO: maybe rename this to MatchInputNode
diff --git a/torch/csrc/dynamo/cpython_defs.c b/torch/csrc/dynamo/cpython_defs.c
index b68ef894aeaa..244d4165d5e8 100644
--- a/torch/csrc/dynamo/cpython_defs.c
+++ b/torch/csrc/dynamo/cpython_defs.c
@@ -2,6 +2,20 @@
#include <torch/csrc/dynamo/cpython_includes.h>
#include <torch/csrc/dynamo/debug_macros.h>
+#if IS_PYTHON_3_14_PLUS
+
+const uint8_t* THP_PyOpcode_Caches = NULL;
+const int THP_PyOpcode_Caches_size = 0;
+
+void
+THP_PyThreadState_PopFrame(PyThreadState *tstate, _PyInterpreterFrame * frame)
+{}
+void
+THP_PyFrame_Clear(_PyInterpreterFrame *frame)
+{}
+
+#else
+
#if IS_PYTHON_3_11_PLUS
#define Py_BUILD_CORE
@@ -360,3 +374,5 @@ const uint8_t* THP_PyOpcode_Caches = NULL;
const int THP_PyOpcode_Caches_size = 0;
#endif
+
+#endif // IS_PYTHON_3_14_PLUS
\ No newline at end of file
diff --git a/torch/csrc/dynamo/cpython_includes.h b/torch/csrc/dynamo/cpython_includes.h
index 6b99c1d5aec8..616be16563cf 100644
--- a/torch/csrc/dynamo/cpython_includes.h
+++ b/torch/csrc/dynamo/cpython_includes.h
@@ -21,6 +21,14 @@
#if IS_PYTHON_3_11_PLUS
#include <internal/pycore_frame.h>
+#if IS_PYTHON_3_14_PLUS
+#include <internal/pycore_interpframe_structs.h>
+#include <internal/pycore_stackref.h>
+#endif
+#endif
+
+#if IS_PYTHON_3_14_PLUS
+#include <internal/pycore_code.h>
#endif
#undef Py_BUILD_CORE
@@ -30,6 +38,13 @@
extern "C" {
#endif
+#if IS_PYTHON_3_14_PLUS
+// 3.14: f_executable is a stack reference; borrow its PyObject* (fully parenthesized, like the 3.13 form).
+#define F_CODE(x) ((PyCodeObject*)PyStackRef_AsPyObjectBorrow((x)->f_executable))
+#define PREV_INSTR(x) (x)->instr_ptr
+
+#else
+
#if IS_PYTHON_3_13_PLUS
#define F_CODE(x) ((PyCodeObject*)(x)->f_executable)
#define PREV_INSTR(x) (x)->instr_ptr
@@ -38,6 +53,8 @@ extern "C" {
#define PREV_INSTR(x) (x)->prev_instr
#endif
+#endif // IS_PYTHON_3_14_PLUS
+
#if IS_PYTHON_3_12_PLUS
#define FUNC(x) ((x)->f_funcobj)
#else
diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c
index f413782b2d30..72bb8839bac3 100644
--- a/torch/csrc/dynamo/eval_frame.c
+++ b/torch/csrc/dynamo/eval_frame.c
@@ -224,17 +224,6 @@ const char* get_frame_name(THP_EVAL_API_FRAME_OBJECT* frame) {
return PyUnicode_AsUTF8(F_CODE(frame)->co_name);
}
-void clear_old_frame_if_python_312_plus(
- PyThreadState* tstate,
- THP_EVAL_API_FRAME_OBJECT* frame) {
-#if IS_PYTHON_3_12_PLUS
-
- THP_PyFrame_Clear(frame);
- THP_PyThreadState_PopFrame(tstate, frame);
-
-#endif
-}
-
static PyObject* dynamo_eval_custom_code_impl(
PyThreadState* tstate,
THP_EVAL_API_FRAME_OBJECT* frame,
@@ -485,6 +474,18 @@ static PyObject* dynamo__custom_eval_frame_shim(
static void enable_eval_frame_shim(PyThreadState* tstate) {}
static void enable_eval_frame_default(PyThreadState* tstate) {}
+PyObject* dynamo_eval_custom_code( // 3.14+ stubs: raise and return NULL instead of an indeterminate value
+ PyThreadState* tstate,
+ THP_EVAL_API_FRAME_OBJECT* frame,
+ PyCodeObject* code,
+ const char* trace_annotation,
+ int throw_flag) { PyErr_SetString(PyExc_RuntimeError, "Python 3.14+ not supported"); return NULL; }
+THPPyInterpreterFrame* THPPyInterpreterFrame_New(
+ THP_EVAL_API_FRAME_OBJECT* frame) { PyErr_SetString(PyExc_RuntimeError, "Python 3.14+ not supported"); return NULL; }
+PyObject* dynamo_eval_frame_default(
+ PyThreadState* tstate,
+ THP_EVAL_API_FRAME_OBJECT* frame,
+ int throw_flag) { PyErr_SetString(PyExc_RuntimeError, "Python 3.14+ not supported"); return NULL; }
static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {NULL};
@@ -498,6 +499,17 @@ static PyTypeObject THPPyInterpreterFrameType = {
#endif // !(IS_PYTHON_3_14_PLUS)
+void clear_old_frame_if_python_312_plus(
+ PyThreadState* tstate,
+ THP_EVAL_API_FRAME_OBJECT* frame) {
+#if IS_PYTHON_3_12_PLUS
+
+ THP_PyFrame_Clear(frame);
+ THP_PyThreadState_PopFrame(tstate, frame);
+
+#endif
+}
+
static PyObject* increment_working_threads(
PyThreadState* tstate,
PyObject* module) {
diff --git a/torch/csrc/dynamo/framelocals_mapping.cpp b/torch/csrc/dynamo/framelocals_mapping.cpp
index b839fb26fc91..c4ee36d87767 100644
--- a/torch/csrc/dynamo/framelocals_mapping.cpp
+++ b/torch/csrc/dynamo/framelocals_mapping.cpp
@@ -26,9 +26,13 @@ FrameLocalsMapping::FrameLocalsMapping(FrameLocalsFrameType* frame)
PyCodeObject* co = F_CODE(frame);
_framelocals.resize(co->co_nlocalsplus, nullptr);
+#if IS_PYTHON_3_14_PLUS
+ TORCH_CHECK(false, "Python 3.14+ not supported");
+#else
if (!frame->stacktop) {
return;
}
+#endif
auto update_framelocals = [&](int i, PyObject* value) {
_PyLocals_Kind kind = _PyLocals_GetKind(co->co_localspluskinds, i);
@@ -53,11 +57,21 @@ FrameLocalsMapping::FrameLocalsMapping(FrameLocalsFrameType* frame)
};
auto offset = co->co_nlocalsplus - co->co_nfreevars;
+#if IS_PYTHON_3_14_PLUS
+ TORCH_CHECK(false, "Python 3.14+ not supported");
+#else
for (int i = 0; i < offset; i++) {
update_framelocals(i, frame->localsplus[i]);
}
+#endif
+
// Get references to closure variables
+#if IS_PYTHON_3_14_PLUS
+ PyObject* closure = nullptr; // never dereferenced: the TORCH_CHECK(false) below throws first
+ TORCH_CHECK(false, "Python 3.14+ not supported");
+#else
PyObject* closure = ((PyFunctionObject*)FUNC(frame))->func_closure;
+#endif
for (int i = 0; i < co->co_nfreevars; i++) {
update_framelocals(offset + i, PyTuple_GET_ITEM(closure, i));
}
diff --git a/torch/csrc/utils/python_compat.h b/torch/csrc/utils/python_compat.h
index a1537611cc47..16292e4fd030 100644
--- a/torch/csrc/utils/python_compat.h
+++ b/torch/csrc/utils/python_compat.h
@@ -13,6 +13,7 @@ extern "C" {
#define IS_PYTHON_3_12_PLUS PY_VERSION_HEX >= 0x030C0000
#define IS_PYTHON_3_13_PLUS PY_VERSION_HEX >= 0x030D0000
#define IS_PYTHON_3_14_PLUS PY_VERSION_HEX >= 0x030E0000
+#define IS_PYTHON_3_15_PLUS PY_VERSION_HEX >= 0x030F0000
static inline int PyCode_GetNCellvars(PyCodeObject* code) {
// gh-26364 added co_ncellvars to Python 3.11.0rc1
diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py
index 345ffd2a065b..ceeadde5365b 100644
--- a/torch/onnx/__init__.py
+++ b/torch/onnx/__init__.py
@@ -104,7 +104,6 @@ ONNXProgram.__module__ = "torch.onnx"
OnnxExporterError.__module__ = "torch.onnx"
_OrtBackend.__module__ = "torch.onnx"
_OrtBackendOptions.__module__ = "torch.onnx"
-_OrtExecutionProvider.__module__ = "torch.onnx"
enable_fake_mode.__module__ = "torch.onnx"
is_onnxrt_backend_supported.__module__ = "torch.onnx"
diff --git a/torch/utils/weak.py b/torch/utils/weak.py
index 8bf2ba5ed02b..9c7218cb2ad3 100644
--- a/torch/utils/weak.py
+++ b/torch/utils/weak.py
@@ -3,8 +3,6 @@ from __future__ import annotations
import collections.abc as _collections_abc
import weakref
-
-from _weakrefset import _IterationGuard # type: ignore[attr-defined]
from collections.abc import Mapping, MutableMapping
from weakref import ref
@@ -22,6 +20,33 @@ __all__ = [
]
+# TODO: make weakref properly thread safe following
+# https://github.com/python/cpython/pull/125325
+class _IterationGuard:
+ # This context manager registers itself in the current iterators of the
+ # weak container, such as to delay all removals until the context manager
+ # exits.
+ # This technique should be relatively thread-safe (since sets are).
+
+ def __init__(self, weakcontainer):
+ # Don't create cycles
+ self.weakcontainer = ref(weakcontainer)
+
+ def __enter__(self):
+ w = self.weakcontainer()
+ if w is not None:
+ w._iterating.add(self)
+ return self
+
+ def __exit__(self, e, t, b):
+ w = self.weakcontainer()
+ if w is not None:
+ s = w._iterating
+ s.remove(self)
+ if not s:
+ w._commit_removals()
+
+
# This file defines a variant of WeakKeyDictionary that overrides the hashing
# behavior of the key to use object identity, rather than the builtin
# __eq__/__hash__ functions. This is useful for Tensor weak keys, as their
--
2.49.0