Patch the dim issue needed for torchrl. Stage tensorpipe thirdparty needed for distributed. Signed-off-by: Tom Rix <trix@redhat.com>
115 lines
4.2 KiB
Diff
115 lines
4.2 KiB
Diff
From ee3fb343a376cdba6f4ce188cac90023f13e2aea Mon Sep 17 00:00:00 2001
|
|
From: Tom Rix <trix@redhat.com>
|
|
Date: Thu, 4 Apr 2024 14:21:38 -0600
|
|
Subject: [PATCH] Reenable dim for python 3.12
|
|
|
|
In 3.12:
|
|
|
|
_PyArg_Parser added an element to the start of the structure.
|
|
So the existing positional initialization is misaligned. Switch to designated
|
|
initialization.
|
|
|
|
_Py_CODEUNIT changed from an int to a union, but relevant_op
|
|
is passed an int for the return of decoder.opcode, so the parameter
|
|
type is wrong, switch it to int.
|
|
|
|
The opcode PRECALL was removed in 3.12, so restrict its handling to 3.11 only.
|
|
|
|
Signed-off-by: Tom Rix <trix@redhat.com>
|
|
---
|
|
functorch/csrc/dim/dim.cpp | 24 +++++-------------------
|
|
functorch/csrc/dim/minpybind.h | 4 ++--
|
|
2 files changed, 7 insertions(+), 21 deletions(-)
|
|
|
|
diff --git a/functorch/csrc/dim/dim.cpp b/functorch/csrc/dim/dim.cpp
|
|
index 4cc027504c77..e48b0d58081f 100644
|
|
--- a/functorch/csrc/dim/dim.cpp
|
|
+++ b/functorch/csrc/dim/dim.cpp
|
|
@@ -6,20 +6,6 @@
|
|
|
|
#include <torch/csrc/utils/python_compat.h>
|
|
|
|
-
|
|
-// Many APIs have changed/don't exist anymore
|
|
-#if IS_PYTHON_3_12_PLUS
|
|
-
|
|
-#include "dim.h"
|
|
-
|
|
-// Re-enable this some day
|
|
-PyObject* Dim_init() {
|
|
- PyErr_SetString(PyExc_RuntimeError, "First class dim doesn't work with python 3.12");
|
|
- return nullptr;
|
|
-}
|
|
-
|
|
-#else
|
|
-
|
|
#include "minpybind.h"
|
|
#include <frameobject.h>
|
|
#include <opcode.h>
|
|
@@ -441,7 +427,7 @@ static PyObject* DimList_bind(DimList *self,
|
|
PY_BEGIN
|
|
mpy::handle sizes;
|
|
static const char * const _keywords[] = {"sizes", nullptr};
|
|
- static _PyArg_Parser parser = {"O", _keywords, 0};
|
|
+ static _PyArg_Parser parser = { .format = "O", .keywords = _keywords};
|
|
if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &sizes)) {
|
|
return nullptr;
|
|
}
|
|
@@ -465,7 +451,7 @@ static PyObject* DimList_bind_len(DimList *self,
|
|
PY_BEGIN
|
|
int size;
|
|
static const char * const _keywords[] = {"N", nullptr};
|
|
- static _PyArg_Parser parser = {"i", _keywords, 0};
|
|
+ static _PyArg_Parser parser = { .format = "i", .keywords = _keywords};
|
|
if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, &size)) {
|
|
return nullptr;
|
|
}
|
|
@@ -1468,7 +1454,7 @@ PyTypeObject Tensor::Type = {
|
|
|
|
// dim() --------------------
|
|
|
|
-static bool relevant_op(_Py_CODEUNIT c) {
|
|
+static bool relevant_op(int c) {
|
|
switch(c) {
|
|
case STORE_NAME:
|
|
case STORE_GLOBAL:
|
|
@@ -1587,7 +1573,7 @@ static PyObject* _dims(PyObject *self,
|
|
auto c = mpy::obj<PyCodeObject>::steal(PyFrame_GetCode(f.ptr()));
|
|
auto lasti = PyFrame_GetLasti(f.ptr());
|
|
auto decoder = PyInstDecoder(c.ptr(), lasti);
|
|
- #if IS_PYTHON_3_11_PLUS
|
|
+ #if IS_PYTHON_3_11
|
|
// When py3.11 adapts bytecode lasti points to the precall
|
|
// rather than the call instruction after it
|
|
if (decoder.opcode() == PRECALL) {
|
|
@@ -3268,4 +3254,4 @@ PyObject* Dim_init() {
|
|
}
|
|
}
|
|
|
|
-#endif
|
|
+
|
|
diff --git a/functorch/csrc/dim/minpybind.h b/functorch/csrc/dim/minpybind.h
|
|
index de82b5af95a4..d76d4828bf80 100644
|
|
--- a/functorch/csrc/dim/minpybind.h
|
|
+++ b/functorch/csrc/dim/minpybind.h
|
|
@@ -621,7 +621,7 @@ struct vector_args {
|
|
PyObject *dummy = NULL;
|
|
_PyArg_ParseStackAndKeywords((PyObject*const*)args, nargs, kwnames.ptr(), _parser, &dummy, &dummy, &dummy, &dummy, &dummy);
|
|
#else
|
|
- _PyArg_Parser* _parser = new _PyArg_Parser{NULL, &names_buf[0], fname_cstr, 0};
|
|
+ _PyArg_Parser* _parser = new _PyArg_Parser{ .keywords = &names_buf[0], .fname = fname_cstr};
|
|
std::unique_ptr<PyObject*[]> buf(new PyObject*[names.size()]);
|
|
_PyArg_UnpackKeywords((PyObject*const*)args, nargs, NULL, kwnames.ptr(), _parser, required, (Py_ssize_t)values.size() - kwonly, 0, &buf[0]);
|
|
#endif
|
|
@@ -706,7 +706,7 @@ inline object handle::call_vector(vector_args args) {
|
|
#define MPY_PARSE_ARGS_KWNAMES(fmt, FORALL_ARGS) \
|
|
static const char * const kwlist[] = { FORALL_ARGS(MPY_ARGS_NAME) nullptr}; \
|
|
FORALL_ARGS(MPY_ARGS_DECLARE) \
|
|
- static _PyArg_Parser parser = {fmt, kwlist, 0}; \
|
|
+ static _PyArg_Parser parser = { .format = fmt, .keywords = kwlist}; \
|
|
if (!_PyArg_ParseStackAndKeywords(args, nargs, kwnames, &parser, FORALL_ARGS(MPY_ARGS_POINTER) nullptr)) { \
|
|
throw mpy::exception_set(); \
|
|
}
|
|
--
|
|
2.44.0
|
|
|