226 lines
8.2 KiB
Diff
226 lines
8.2 KiB
Diff
From b9d45eb1cc90696a4de76676221219e24423c709 Mon Sep 17 00:00:00 2001
|
|
From: William Wen <williamwen@meta.com>
|
|
Date: Wed, 3 Apr 2024 17:58:46 -0700
|
|
Subject: [PATCH] [dynamo, 3.12] enable dynamo on 3.12, enable most dynamo
|
|
unittests on 3.12 (#123216)
|
|
|
|
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123216
|
|
Approved by: https://github.com/jansel, https://github.com/malfet
|
|
---
|
|
test/dynamo/test_autograd_function.py | 3 ++
|
|
test/dynamo/test_misc.py | 63 +++++++++++++++++++++++++
|
|
test/functorch/test_eager_transforms.py | 7 ++-
|
|
test/run_test.py | 3 --
|
|
torch/__init__.py | 5 +-
|
|
torch/_dynamo/eval_frame.py | 4 +-
|
|
torch/_dynamo/test_case.py | 8 +---
|
|
7 files changed, 74 insertions(+), 19 deletions(-)
|
|
|
|
diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py
|
|
index d23fec607afa..bc5ebc767038 100644
|
|
--- a/test/dynamo/test_autograd_function.py
|
|
+++ b/test/dynamo/test_autograd_function.py
|
|
@@ -2,6 +2,8 @@
|
|
|
|
import copy
|
|
import math
|
|
+import sys
|
|
+import unittest
|
|
|
|
import torch
|
|
|
|
@@ -528,6 +530,7 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase):
|
|
# I pulled all of these test cases from test_autograd.py
|
|
# In the future, we should make the Dynamo test suite actually
|
|
# run on test_autograd.py (it's disabled right now) and delete these.
|
|
+ @unittest.skipIf(sys.version_info >= (3, 12), "invalid free in 3.12+")
|
|
def test_smoke_from_test_autograd(self):
|
|
class Func(torch.autograd.Function):
|
|
@staticmethod
|
|
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
|
|
index a73de8b1c7e9..8f54e0564e6b 100644
|
|
--- a/test/dynamo/test_misc.py
|
|
+++ b/test/dynamo/test_misc.py
|
|
@@ -9760,6 +9760,69 @@ fn
|
|
lambda mod: mod,
|
|
)
|
|
|
|
+ @xfailIfPy311
|
|
+ def test_outside_linear_module_free(self):
|
|
+ # Compared to test_linear_module_free, the linear
|
|
+ # layer is not the code object that is directly compiled.
|
|
+ def model_inp_ctr():
|
|
+ fc = torch.nn.Linear(100, 100)
|
|
+
|
|
+ class Mod(torch.nn.Module):
|
|
+ def __init__(self):
|
|
+ super().__init__()
|
|
+ self.fc_ref = fc
|
|
+
|
|
+ def forward(self, x):
|
|
+ return fc(x[0])
|
|
+
|
|
+ # return fc to keep it alive in _test_compile_model_free
|
|
+ return Mod(), (torch.randn(100, 100), fc)
|
|
+
|
|
+ self._test_compile_model_free(model_inp_ctr, lambda mod: mod.fc_ref)
|
|
+
|
|
+ @unittest.skipIf(sys.version_info >= (3, 12), "leaks in 3.12+")
|
|
+ def test_parameter_free(self):
|
|
+ def model_inp_ctr():
|
|
+ param = torch.nn.Parameter(torch.randn(100, 100))
|
|
+
|
|
+ class Mod(torch.nn.Module):
|
|
+ def __init__(self):
|
|
+ super().__init__()
|
|
+ self.param = param
|
|
+
|
|
+ def forward(self, x):
|
|
+ return self.param * x[0]
|
|
+
|
|
+ # return param to keep it alive in _test_compile_model_free
|
|
+ return Mod(), (torch.randn(100, 100), param)
|
|
+
|
|
+ self._test_compile_model_free(model_inp_ctr, lambda mod: mod.param)
|
|
+
|
|
+ def test_raises_importerror1(self):
|
|
+ @torch.compile(backend="eager")
|
|
+ def fn(x):
|
|
+ try:
|
|
+ import some_module_that_surely_does_not_exist
|
|
+
|
|
+ return
|
|
+ except ImportError:
|
|
+ pass
|
|
+ return x.sin()
|
|
+
|
|
+ x = torch.randn(8)
|
|
+ self.assertEqual(fn(x), x.sin())
|
|
+
|
|
+ def test_raises_importerror2(self):
|
|
+ @torch.compile(backend="eager")
|
|
+ def fn(x):
|
|
+ import some_module_that_surely_does_not_exist
|
|
+
|
|
+ return x + 1
|
|
+
|
|
+ x = torch.randn(8)
|
|
+ with self.assertRaises(ImportError):
|
|
+ fn(x)
|
|
+
|
|
def test_dynamo_cache_move_to_front(self):
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self):
|
|
diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py
|
|
index 09415cf8f48e..60790ec06059 100644
|
|
--- a/test/functorch/test_eager_transforms.py
|
|
+++ b/test/functorch/test_eager_transforms.py
|
|
@@ -4762,8 +4762,7 @@ class TestCompileTransforms(TestCase):
|
|
# Triton only supports GPU with SM70 or later.
|
|
@expectedFailureIf((IS_ARM64 and not IS_MACOS) or
|
|
IS_WINDOWS or
|
|
- (TEST_CUDA and not SM70OrLater) or
|
|
- (sys.version_info >= (3, 12)))
|
|
+ (TEST_CUDA and not SM70OrLater))
|
|
def test_compile_vmap_hessian(self, device):
|
|
# The model and inputs are a smaller version
|
|
# of code at benchmark repo:
|
|
@@ -4792,8 +4791,8 @@ class TestCompileTransforms(TestCase):
|
|
actual = opt_fn(params_and_buffers, x)
|
|
self.assertEqual(actual, expected)
|
|
|
|
- # torch.compile is not supported on Windows or on Python 3.12+
|
|
- @expectedFailureIf(IS_WINDOWS or (sys.version_info >= (3, 12)))
|
|
+ # torch.compile is not supported on Windows
|
|
+ @expectedFailureIf(IS_WINDOWS)
|
|
@torch._dynamo.config.patch(suppress_errors=False)
|
|
@torch._dynamo.config.patch(capture_func_transforms=True)
|
|
@skipIfTorchDynamo("Do not test torch.compile on top of torch.compile")
|
|
diff --git a/test/run_test.py b/test/run_test.py
|
|
index e86af9623042..ebb14df4167d 100755
|
|
--- a/test/run_test.py
|
|
+++ b/test/run_test.py
|
|
@@ -74,7 +74,6 @@ sys.path.remove(str(REPO_ROOT))
|
|
RERUN_DISABLED_TESTS = os.getenv("PYTORCH_TEST_RERUN_DISABLED_TESTS", "0") == "1"
|
|
DISTRIBUTED_TEST_PREFIX = "distributed"
|
|
INDUCTOR_TEST_PREFIX = "inductor"
|
|
-DYNAMO_TEST_PREFIX = "dynamo"
|
|
|
|
|
|
# Note [ROCm parallel CI testing]
|
|
@@ -324,7 +323,6 @@ JIT_EXECUTOR_TESTS = [
|
|
]
|
|
|
|
INDUCTOR_TESTS = [test for test in TESTS if test.startswith(INDUCTOR_TEST_PREFIX)]
|
|
-DYNAMO_TESTS = [test for test in TESTS if test.startswith(DYNAMO_TEST_PREFIX)]
|
|
DISTRIBUTED_TESTS = [test for test in TESTS if test.startswith(DISTRIBUTED_TEST_PREFIX)]
|
|
TORCH_EXPORT_TESTS = [test for test in TESTS if test.startswith("export")]
|
|
FUNCTORCH_TESTS = [test for test in TESTS if test.startswith("functorch")]
|
|
@@ -1361,7 +1359,6 @@ def get_selected_tests(options) -> List[str]:
|
|
# these tests failing in Python 3.12 temporarily disabling
|
|
if sys.version_info >= (3, 12):
|
|
options.exclude.extend(INDUCTOR_TESTS)
|
|
- options.exclude.extend(DYNAMO_TESTS)
|
|
options.exclude.extend(
|
|
[
|
|
"functorch/test_dims",
|
|
diff --git a/torch/__init__.py b/torch/__init__.py
|
|
index d381712b4a35..26cdffe81d29 100644
|
|
--- a/torch/__init__.py
|
|
+++ b/torch/__init__.py
|
|
@@ -1861,9 +1861,8 @@ def compile(model: Optional[Callable] = None, *,
|
|
|
|
"""
|
|
_C._log_api_usage_once("torch.compile")
|
|
- # Temporary until we get proper support for python 3.12
|
|
- if sys.version_info >= (3, 12):
|
|
- raise RuntimeError("Dynamo is not supported on Python 3.12+")
|
|
+ if sys.version_info >= (3, 13):
|
|
+ raise RuntimeError("Dynamo is not supported on Python 3.13+")
|
|
|
|
# Decorator mode
|
|
if model is None:
|
|
diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
|
|
index 53ab0df3a947..0a80eeea99ed 100644
|
|
--- a/torch/_dynamo/eval_frame.py
|
|
+++ b/torch/_dynamo/eval_frame.py
|
|
@@ -589,8 +589,8 @@ class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg]
|
|
|
|
|
|
def check_if_dynamo_supported():
|
|
- if sys.version_info >= (3, 12):
|
|
- raise RuntimeError("Python 3.12+ not yet supported for torch.compile")
|
|
+ if sys.version_info >= (3, 13):
|
|
+ raise RuntimeError("Python 3.13+ not yet supported for torch.compile")
|
|
|
|
|
|
def is_dynamo_supported():
|
|
diff --git a/torch/_dynamo/test_case.py b/torch/_dynamo/test_case.py
|
|
index e3cbef09eaae..297ea6e2bc2a 100644
|
|
--- a/torch/_dynamo/test_case.py
|
|
+++ b/torch/_dynamo/test_case.py
|
|
@@ -1,7 +1,6 @@
|
|
import contextlib
|
|
import importlib
|
|
import logging
|
|
-import sys
|
|
|
|
import torch
|
|
import torch.testing
|
|
@@ -20,12 +19,7 @@ log = logging.getLogger(__name__)
|
|
def run_tests(needs=()):
|
|
from torch.testing._internal.common_utils import run_tests
|
|
|
|
- if (
|
|
- TEST_WITH_TORCHDYNAMO
|
|
- or IS_WINDOWS
|
|
- or TEST_WITH_CROSSREF
|
|
- or sys.version_info >= (3, 12)
|
|
- ):
|
|
+ if TEST_WITH_TORCHDYNAMO or IS_WINDOWS or TEST_WITH_CROSSREF:
|
|
return # skip testing
|
|
|
|
if isinstance(needs, str):
|
|
--
|
|
2.44.0
|
|
|