From 56dfd4c74bfa2fb03fbfeb28f7774870248910c4 Mon Sep 17 00:00:00 2001 From: Simon Layton Date: Mon, 3 Nov 2025 22:12:22 +0000 Subject: [PATCH 001/130] Add CUDA MXFP4 scaled mm support via. FBGEMM (#166526) Summary: * Pull in `f4f4bf16` from FBGemm to provide MXFP4 support for CUDA * Add testing Test Plan: Reviewers: Subscribers: Tasks: Tags: Signed-off-by: Simon Layton Pull Request resolved: https://github.com/pytorch/pytorch/pull/166526 Approved by: https://github.com/drisspg, https://github.com/ngimel --- aten/src/ATen/CMakeLists.txt | 2 +- aten/src/ATen/native/cuda/ScaledBlas.cpp | 81 +++++++++++++++++++++--- test/test_scaled_matmul_cuda.py | 40 ++++++------ 3 files changed, 91 insertions(+), 32 deletions(-) diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 8b283c417b74b..ae762e1def3ec 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -260,7 +260,7 @@ IF(USE_FBGEMM_GENAI) if(USE_CUDA) # To avoid increasing the build time/binary size unnecessarily, use an allow-list of kernels to build. # If you want to integrate a kernel from FBGEMM into torch, you have to add it here. - set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped).*") + set(FBGEMM_CUTLASS_KERNELS_REGEX ".*(mx8mx8bf16_grouped|f4f4bf16_grouped|f4f4bf16).*") file(GLOB_RECURSE fbgemm_genai_native_cuda_cu "${FBGEMM_GENAI_SRCS}/cutlass_extensions/*.cu" "${FBGEMM_GENAI_SRCS}/cutlass_extensions/**/*.cu") diff --git a/aten/src/ATen/native/cuda/ScaledBlas.cpp b/aten/src/ATen/native/cuda/ScaledBlas.cpp index 0d2963874abbd..9065d79929360 100644 --- a/aten/src/ATen/native/cuda/ScaledBlas.cpp +++ b/aten/src/ATen/native/cuda/ScaledBlas.cpp @@ -59,6 +59,24 @@ // forward declare class cublasCommonArgs; +#ifndef _WIN32 +namespace fbgemm_gpu { + +// NOTE(slayton58): FBGemm_GPU kernels come from within the FBGemm repo. +// To update supported ops means a submodule bump, which is.. painful. Instead, we +// can simply forward-declare the methods we want to use.. Works at least as a short-term +// thing, but should still be fixed somewhere/somehow. +at::Tensor f4f4bf16( + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + std::optional, + bool use_mx); + +} // namespace fbgemm_gpu +#endif + using at::blas::ScalingType; using at::blas::SwizzleType; @@ -1087,26 +1105,47 @@ _scaled_mxfp4_mxfp4( const std::optional& bias, const c10::ScalarType out_dtype, Tensor& out) { -#ifndef USE_ROCM - TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM only"); -#endif +#if defined(_WIN32) || (!defined(USE_ROCM) && !defined(USE_FBGEMM_GENAI)) + TORCH_CHECK_NOT_IMPLEMENTED(false, "MXFP4 scaling supported on ROCM and CUDA+FBGEMM_GENAI only"); +#else // Restrictions: // A, B are FP4, scales are e8m0, A: shape K//32, B: K, N//32 TORCH_CHECK_VALUE(mat_a.scalar_type() == at::kFloat4_e2m1fn_x2 && mat_b.scalar_type() == at::kFloat4_e2m1fn_x2, "mat_a and mat_b must be fp4 types, got: ", mat_a.scalar_type(), mat_b.scalar_type()); - auto scale_a_elems = ceil_div(2 * mat_a.size(0), 32) * mat_a.size(1); - auto scale_b_elems = ceil_div(2 * mat_b.size(1), 32) * mat_b.size(0); + // Packed FP4 format means actual-K = 2 * reported-K -- adjust + auto K_multiplier = 2; +#ifdef USE_ROCM + // AMD + auto scale_a_elems = ceil_div(K_multiplier * mat_a.size(0), 32) * mat_a.size(1); + auto scale_b_elems = ceil_div(K_multiplier * mat_b.size(1), 32) * mat_b.size(0); +#else + // NVIDIA + auto scale_a_elems = round_up(mat_a.size(0), 128) * round_up(ceil_div(K_multiplier * mat_a.size(1), 32), 4); + auto scale_b_elems = round_up(mat_b.size(1), 128) * round_up(ceil_div(K_multiplier * mat_b.size(0), 32), 4); +#endif TORCH_CHECK_VALUE(scale_a_elems == scale_a.numel(), "For Blockwise scaling scale_a should have ", scale_a_elems, " elements, got: ", scale_a.numel()); TORCH_CHECK_VALUE(scale_b_elems == scale_b.numel(), "For Blockwise scaling scale_b should have ", scale_b_elems, " elements, got: ", scale_b.numel()); +#ifdef USE_ROCM + // AMD + TORCH_CHECK_VALUE(swizzle_a == SwizzleType::NO_SWIZZLE, "scale_a must not be swizzled (NO_SWIZZLE format)"); + TORCH_CHECK_VALUE(swizzle_b == SwizzleType::NO_SWIZZLE, "scale_b must not be swizzled (NO_SWIZZLE format)"); +#else + // NVIDIA + TORCH_CHECK_VALUE(swizzle_a == SwizzleType::SWIZZLE_32_4_4, "scale_a must be swizzled to SWIZZLE_32_4_4 format"); + TORCH_CHECK_VALUE(swizzle_b == SwizzleType::SWIZZLE_32_4_4, "scale_b must be swizzled to SWIZZLE_32_4_4 format"); +#endif + TORCH_CHECK_VALUE(scale_a.is_contiguous() && scale_b.is_contiguous(), "For Blockwise scaling both scales should be contiguous"); TORCH_CHECK_VALUE(out.scalar_type() == out_dtype, "expected out.scalar_type() to be ", out_dtype, ", but got ", out_dtype); +#ifdef USE_ROCM + // AMD auto scaling_choice_a = ScalingType::BlockWise1x32; auto scaling_choice_b = ScalingType::BlockWise1x32; @@ -1121,11 +1160,30 @@ _scaled_mxfp4_mxfp4( TORCH_CHECK_VALUE(out.scalar_type() == ScalarType::BFloat16 || out.scalar_type() == ScalarType::Half, "Block-wise scaling only supports BFloat16 or Half output types"); -#else - TORCH_CHECK_NOT_IMPLEMENTED(false, "Block-wise scaling for Float8_e8m0fnu requires ROCm 7.0 or later"); #endif return _scaled_gemm(mat_a, mat_b, scale_a, scale_b, scaling_choice_a, scaling_choice_b, bias, false /* use_fast_accum */, out); +#else + // NVIDIA + // NOTE(slayton58): fbgemm_gpu::f4f4bf16 does *not* allow passing an output tensor, + // but we have one we need to use. Two clear options are to copy into + // our output (slow), or use a move-assignment-operator (faster). + // However, the compiler can complain about the explicit move preventing + // copy elision because the return from f4f4bf16 is a temporary object. + // So we don't explicitly move, and trust the compiler here... + // In the longer term this should be fixed on the FBGemm side. + out = fbgemm_gpu::f4f4bf16( + mat_a, + mat_b.transpose(-2, -1), + scale_a, + scale_b, + std::nullopt, /* global_scale */ + true /* use_mx */ + ); + + return out; +#endif +#endif } Tensor& @@ -1250,17 +1308,20 @@ _scaled_mm_cuda_v2_out( mat_a.size(0), "x", mat_a.size(1), " and ", mat_b.size(0), "x", mat_b.size(1), ")"); } + // Handle fp4 packed-K dimension + int K_multiplier = (mat_a.scalar_type() == ScalarType::Float4_e2m1fn_x2) ? 2 : 1; + TORCH_CHECK_VALUE(!bias || bias->numel() == mat_b.sizes()[1], "Bias must be size ", mat_b.sizes()[1], " but got ", bias->numel()); TORCH_CHECK_VALUE( - mat_a.sizes()[1] % 16 == 0, + K_multiplier * mat_a.sizes()[1] % 16 == 0, "Expected trailing dimension of mat1 to be divisible by 16 ", "but got mat1 shape: (", mat_a.sizes()[0], "x", - mat_a.sizes()[1], + K_multiplier * mat_a.sizes()[1], ")."); - TORCH_CHECK_VALUE(mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x", + TORCH_CHECK_VALUE(K_multiplier * mat_b.sizes()[0] % 16 == 0 && mat_b.sizes()[1] % 16 == 0, "mat2 shape (", mat_b.sizes()[0], "x", mat_b.sizes()[1], ") must be divisible by 16"); // TODO(slayton): Existing checks, not sure if they should really be here. diff --git a/test/test_scaled_matmul_cuda.py b/test/test_scaled_matmul_cuda.py index 4d88ccd9cc7dd..9738ac4ac6fbf 100644 --- a/test/test_scaled_matmul_cuda.py +++ b/test/test_scaled_matmul_cuda.py @@ -209,42 +209,36 @@ def infer_scale_swizzle(mat, scale): ] == math.ceil(mat.shape[1] // 128): return ScalingType.BlockWise128x128, SwizzleType.NO_SWIZZLE + # if we're checking for nvfp4, need to adjust for packed-K + K_multiplier = 2 if mat.dtype == torch.float4_e2m1fn_x2 else 1 # NVFP4 if ( (scale.numel() - == round_up(mat.shape[0], 128) * round_up(math.ceil(2 * mat.shape[1] // 16), 4) + == round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 16), 4) or scale.numel() - == round_up(mat.shape[1], 128) * round_up(math.ceil(2 * mat.shape[0] // 16), 4)) + == round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 16), 4)) and mat.dtype == torch.float4_e2m1fn_x2 and scale.dtype == torch.float8_e4m3fn ): return ScalingType.BlockWise1x16, SwizzleType.SWIZZLE_32_4_4 - # MXFP4 w/o swizzle - if ( - (scale.numel() == 2 * math.ceil(mat.shape[0] // 32) * mat.shape[1] - or scale.numel() == 2 * math.ceil(mat.shape[1] // 32) * mat.shape[0]) - and mat.dtype == torch.float4_e2m1fn_x2 - and scale.dtype == torch.float8_e8m0fnu - ): - return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE - + # MX formats if not torch.version.hip: - # MXFP8 w/ swizzle + # MX w/swizzle (NVIDIA) if ( (scale.numel() - == round_up(mat.shape[0], 128) * round_up(math.ceil(mat.shape[1] // 32), 4) + == round_up(mat.shape[0], 128) * round_up(math.ceil(K_multiplier * mat.shape[1] // 32), 4) or scale.numel() - == round_up(mat.shape[1], 128) * round_up(math.ceil(mat.shape[0] // 32), 4)) + == round_up(mat.shape[1], 128) * round_up(math.ceil(K_multiplier * mat.shape[0] // 32), 4)) and scale.dtype == torch.float8_e8m0fnu ): return ScalingType.BlockWise1x32, SwizzleType.SWIZZLE_32_4_4 else: - # MXFP8 w/o swizzle + # MX w/o swizzle (AMD) if ( - (scale.numel() == math.ceil(mat.shape[0] // 32) * mat.shape[1] - or scale.numel() == math.ceil(mat.shape[1] // 32) * mat.shape[0]) + (scale.numel() == math.ceil(mat.shape[0] // 32) * K_multiplier * mat.shape[1] + or scale.numel() == math.ceil(K_multiplier * mat.shape[1] // 32) * mat.shape[0]) and scale.dtype == torch.float8_e8m0fnu ): return ScalingType.BlockWise1x32, SwizzleType.NO_SWIZZLE @@ -1868,7 +1862,7 @@ def test_blockwise_nvfp4_with_global_scale(self, mkn) -> None: (127, 96, 1024), (1025, 128, 96) ], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}") - @parametrize("recipe", ["mxfp8", "mxfp4" if torch.version.hip else "nvfp4"]) + @parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"]) def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None: if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum: raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping") @@ -1882,8 +1876,12 @@ def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, if not (M % 16 == 0 and K % 128 == 0 and N % 16 == 0): raise unittest.SkipTest("M and N must be multiples of 16 and K must be multiple of 128 on ROCm, skipping") - fp4_scaling_dtype = torch.float8_e8m0fnu if torch.version.hip else torch.float8_e4m3fn - BLOCK_SIZE = 32 if torch.version.hip else (16 if recipe == "nvfp4" else 32) + fp4_scaling_dtype = torch.float8_e8m0fnu if recipe == "mxfp4" else torch.float8_e4m3fn + BLOCK_SIZE = 16 if recipe == "nvfp4" else 32 + + if K % BLOCK_SIZE != 0: + raise unittest.SkipTest(f"K ({K}) must be divisible by BLOCK_SIZE ({BLOCK_SIZE}), skipping") + require_exact_match = True approx_match_sqnr_target = 22.0 @@ -2061,7 +2059,7 @@ def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, B = B.clamp(min=min_val, max=max_val) B = _bfloat16_to_float4_e2m1fn_x2(B) - approx_match_sqnr_target = 15 if torch.version.hip else 15.8 + approx_match_sqnr_target = 15 if recipe == "mxfp4" else 15.8 C_ref = A_ref @ B_ref.t() From afd50bdd290d1ff8976d8477efb9ad9256705d88 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Tue, 4 Nov 2025 16:43:06 +0000 Subject: [PATCH 002/130] [CI] Use smaller amx + avx2 runners for inductor test? (#164989) Results from CI: No failures but generally takes longer, maybe ~20% increase in time? But the smaller runner is ~25% of the cost of the current runner, so in terms of cost this is a decrease If the 20% is too much, we can try the 4x larger runners, which are about half the cost of the current runner, so it would probably still result in cost savings with hopefully less impact to time Pull Request resolved: https://github.com/pytorch/pytorch/pull/164989 Approved by: https://github.com/BoyuanFeng, https://github.com/huydhn --- .github/workflows/inductor-unittest.yml | 8 ++++---- .github/workflows/inductor.yml | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/.github/workflows/inductor-unittest.yml b/.github/workflows/inductor-unittest.yml index 6ab276a57fc4d..3ce917567aec2 100644 --- a/.github/workflows/inductor-unittest.yml +++ b/.github/workflows/inductor-unittest.yml @@ -115,10 +115,10 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, - { config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.10xlarge.avx2" }, + { config: "inductor_amx", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "inductor_amx", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "inductor_avx2", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" }, + { config: "inductor_avx2", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.avx2" }, ]} secrets: inherit diff --git a/.github/workflows/inductor.yml b/.github/workflows/inductor.yml index 2616141c0dc2a..8a913c3b36a11 100644 --- a/.github/workflows/inductor.yml +++ b/.github/workflows/inductor.yml @@ -84,13 +84,13 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" test-matrix: | { include: [ - { config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, - { config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.8xlarge.amx" }, + { config: "cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "dynamic_cpu_inductor_huggingface", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "dynamic_cpu_inductor_timm", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "dynamic_cpu_inductor_timm", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "dynamic_cpu_inductor_torchbench", shard: 1, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, + { config: "dynamic_cpu_inductor_torchbench", shard: 2, num_shards: 2, runner: "${{ needs.get-label-type.outputs.label-type }}linux.2xlarge.amx" }, { config: "inductor_torchbench_cpu_smoketest_perf", shard: 1, num_shards: 1, runner: "${{ needs.get-label-type.outputs.label-type }}linux.24xl.spr-metal" }, ]} build-additional-packages: "vision audio torchao" From 8d4b8ab43033667f66a1180974d8faf9b1b8b93d Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Tue, 4 Nov 2025 16:45:22 +0000 Subject: [PATCH 003/130] [ez] Print some more test timing info in the logs (#166447) You can just subtract timestamps, but this makes it easier Pull Request resolved: https://github.com/pytorch/pytorch/pull/166447 Approved by: https://github.com/Skylion007 --- test/run_test.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/run_test.py b/test/run_test.py index 4b7030d461529..448fbc28751f3 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -1826,9 +1826,14 @@ def run_test_module( test_name = test.name # Printing the date here can help diagnose which tests are slow - print_to_stderr(f"Running {str(test)} ... [{datetime.now()}]") + start = time.perf_counter() + print_to_stderr(f"Running {str(test)} ... [{datetime.now()}][{start}]") handler = CUSTOM_HANDLERS.get(test_name, run_test) return_code = handler(test, test_directory, options) + end = time.perf_counter() + print_to_stderr( + f"Finished {str(test)} ... [{datetime.now()}][{end}], took {(end - start) / 60:.2f}min" + ) assert isinstance(return_code, int) and not isinstance(return_code, bool), ( f"While running {str(test)} got non integer return code {return_code}" ) From 68eb55c4b23babd005267dfd322dc4b070041f58 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Mon, 3 Nov 2025 17:40:23 -0800 Subject: [PATCH 004/130] Add model code stack trace to cuda.memory._snapshot (#166676) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We store a mapping between generated fx graph code and original model code stack trace in `fx.traceback._FX_METADATA_REGISTRY`. And we do a post-processing on the memory snapshot to append the original model stack trace information. To achieve this, the biggest change we had to do in `aot_eager` mode is to give each generated fx graph a unique stack trace, i.e. it cannot just be ``. We set co_filename to **pretend** that the code is from `co_filename` file. Now instead of `` in stack trace, we get something like `fx_generated_3a4b5c6d7e8f9a0.py`. `augment_with_fx_traces` arg is added to `torch.cuda.memory._snapshot` and `_dump_snapshot`. When the arg is set to True, a post-processing will run to populate the original model stack trace to the snapshot frames. The new behavior of GraphModule can be controlled by `TORCH_ENRICH_RPOFILER_STACK_TRACE` or `_dynamo.config.enrich_profiler_metadata=True`. Alternative: Instead of setting co_filename, we can also do it like below: Note that if we do it this way, we will need to dump the file to make the graph module torch-scriptable. TorchScript requires source access in order to carry out compilation, so we need to make sure original .py files are available. ``` key = filename globals_copy = globals.copy() globals_copy["__file__"] = key globals_copy["__name__"] = key linecache.lazycache(key, globals_copy) exec(compile(src, key, "exec"), globals) ```` Other changes: - Update `MemoryViz.js` to display fx node information and original model code if exist ``` python test/test_fx.py -k test_lineno_map python test/test_fx.py -k test_custom_traceback_raised python test/test_public_bindings.py python test/test_cuda.py -k test_fx_memory python test/test_fx.py -k test_informative_co_filename python test/test_fx.py -k test_autowrap_functions python test/dynamo/test_utils.py -k test_inductor_provenance ``` ```python # Profile with memory snapshot torch.cuda.memory._record_memory_history() with torch._dynamo.config.patch("enrich_profiler_stack_trace", True): compiled = torch.compile(mod, backend="aot_eager", fullgraph=True) result = compiled(torch.randn(10, 10, device="cuda:0")) torch.cuda.memory._dump_snapshot("memory_snapshot.pickle", augment_with_fx_traces=True) torch.cuda.memory._record_memory_history(enabled=None) ``` Screenshot 2025-10-30 at 10 40 44 AM Pull Request resolved: https://github.com/pytorch/pytorch/pull/166676 Approved by: https://github.com/albanD, https://github.com/ezyang --- test/test_cuda.py | 134 +++++++++++++++++++++++ test/test_fx.py | 2 + torch/_dynamo/config.py | 6 ++ torch/cuda/memory.py | 202 +++++++++++++++++++++++++++++++++-- torch/fx/graph.py | 13 ++- torch/fx/graph_module.py | 79 +++++++++++++- torch/fx/traceback.py | 22 ++++ torch/utils/viz/MemoryViz.js | 24 ++++- 8 files changed, 472 insertions(+), 10 deletions(-) diff --git a/test/test_cuda.py b/test/test_cuda.py index a7e373da63824..00c3b00d6049c 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -7413,6 +7413,140 @@ def test_graph_external_wait_and_record(self): ) +class TestFXMemoryProfiler(TestCase): + """Tests for memory profiler augmentation with original stack traces.""" + + def collect_frames( + self, augmented_snapshot, collect_device_traces=True, collect_segments=True + ): + """Collects all frames that has node metadata from a memory snapshot.""" + # Collect all frames with FX metadata + fx_frames = [] + + # Check device traces for FX debug fields + if collect_device_traces and "device_traces" in augmented_snapshot: + for trace_list in augmented_snapshot["device_traces"]: + for trace_entry in trace_list: + if isinstance(trace_entry, dict) and "frames" in trace_entry: + for frame in trace_entry["frames"]: + if isinstance(frame, dict): + # Check for FX debug fields + if "fx_node_op" in frame or "fx_node_name" in frame: + fx_frames.append(frame) + + # Check segments/blocks for FX debug fields + if collect_segments and "segments" in augmented_snapshot: + for segment in augmented_snapshot["segments"]: + if "blocks" in segment: + for block in segment["blocks"]: + if "frames" in block: + for frame in block["frames"]: + if isinstance(frame, dict): + if "fx_node_op" in frame or "fx_node_name" in frame: + fx_frames.append(frame) + return fx_frames + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_fx_memory_profiler_augmentation(self): + """Test that memory snapshots are augmented with FX debug information.""" + + # Create a simple model + class MLPModule(nn.Module): + def __init__(self, device): + super().__init__() + torch.manual_seed(5) + self.net1 = nn.Linear(10, 16, bias=True, device=device) + self.relu = nn.ReLU() + self.net2 = nn.Linear(16, 10, bias=True, device=device) + + def forward(self, x): + a = self.net1(x) + b = self.relu(a) + c = self.net2(b) + return c + + device = "cuda" + mod = MLPModule(device) + with tempfile.TemporaryDirectory() as tmpdir: + torch.cuda.memory._record_memory_history() + compiled = torch.compile(mod, backend="aot_eager", fullgraph=True) + result = compiled(torch.randn(10, 10, device=device)) + augmented_snapshot = torch.cuda.memory._snapshot( + augment_with_fx_traces=True + ) + torch.cuda.memory._record_memory_history(enabled=None, clear_history=True) + torch.cuda.empty_cache() + + fx_frames = self.collect_frames(augmented_snapshot) + if TEST_WITH_ROCM: + self.assertGreater(len(fx_frames), 0) + else: + self.assertEqual(len(fx_frames), 12) + + for frame in fx_frames: + # Every FX frame should have both node_op and node_name + self.assertIn("fx_node_op", frame) + self.assertIn("fx_node_name", frame) + self.assertIn("fx_node_target", frame) + self.assertIn("fx_original_trace", frame) + + self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"]) + fx_node_name = frame["fx_node_name"] + if fx_node_name == "addmm": + self.assertIn("a = self.net1(x)", frame["fx_original_trace"]) + elif fx_node_name == "addmm_1": + self.assertIn("c = self.net2(b)", frame["fx_original_trace"]) + elif fx_node_name == "relu": + self.assertIn("b = self.relu(a)", frame["fx_original_trace"]) + + # Test that when we have two graphs with the same src_code, they're not hashed + # to the same metadata + class MLPModule2(nn.Module): + def __init__(self, device): + super().__init__() + torch.manual_seed(5) + self.net1 = nn.Linear(10, 16, bias=True, device=device) + self.relu = nn.ReLU() + self.net2 = nn.Linear(16, 10, bias=True, device=device) + + def forward(self, x): + d = self.net1(x) + e = self.relu(d) + f = self.net2(e) + return f + + mod = MLPModule2(device) + with tempfile.TemporaryDirectory() as tmpdir: + torch.cuda.memory._record_memory_history() + compiled = torch.compile(mod, backend="aot_eager", fullgraph=True) + result = compiled(torch.randn(10, 10, device=device)) + augmented_snapshot = torch.cuda.memory._snapshot( + augment_with_fx_traces=True + ) + torch.cuda.memory._record_memory_history(enabled=None, clear_history=True) + + # avoid collecting segments from previous run for unit test purpose + fx_frames = self.collect_frames(augmented_snapshot, collect_segments=False) + self.assertGreater(len(fx_frames), 0) + + for frame in fx_frames: + # Every FX frame should have both node_op and node_name + self.assertIn("fx_node_op", frame) + self.assertIn("fx_node_name", frame) + self.assertIn("fx_node_target", frame) + self.assertIn("fx_original_trace", frame) + + self.assertIn(frame["fx_node_name"], ["addmm", "relu", "addmm_1"]) + fx_node_name = frame["fx_node_name"] + if fx_node_name == "addmm": + self.assertIn("d = self.net1(x)", frame["fx_original_trace"]) + elif fx_node_name == "addmm_1": + self.assertIn("f = self.net2(e)", frame["fx_original_trace"]) + elif fx_node_name == "relu": + self.assertIn("e = self.relu(d)", frame["fx_original_trace"]) + + instantiate_parametrized_tests(TestCuda) instantiate_parametrized_tests(TestCudaMallocAsync) instantiate_parametrized_tests(TestCompileKernel) diff --git a/test/test_fx.py b/test/test_fx.py index 880cc91edc067..d6f33d426aee7 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -771,6 +771,7 @@ def forward(self, a, b): gm = GraphModule(tracer.root, graph) expected = {1: 2, 2: 3, 3: 4, 4: 5} self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items()))) + self.assertEqual(gm._prologue_start, 4) # test custom codegen def transform_code(code): @@ -780,6 +781,7 @@ def transform_code(code): gm.recompile() expected = {2: 2, 3: 3, 4: 4, 5: 5} self.assertTrue(set(expected.items()).issubset(set(gm._lineno_map.items()))) + self.assertEqual(gm._prologue_start, 4) def test_graph_unique_names_manual(self): graph: torch.fx.Graph = torch.fx.Graph() diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 5858a4584b3dd..0c95408401c79 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -739,6 +739,12 @@ def default_debug_dir_root() -> str: # HACK: this is for testing custom ops profiling only _custom_ops_profile: Optional[Any] = None +# Experimental: If True, graph module will register fx metadata during recompile() +enrich_profiler_metadata: bool = Config( # type: ignore[var-annotated] + default=False, + env_name_default="TORCH_ENRICH_RPOFILER_STACK_TRACE", +) + if TYPE_CHECKING: from torch.utils._config_typing import * # noqa: F401, F403 diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 2dfd5f9479499..6834ffb5706a0 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -4,12 +4,14 @@ import collections import contextlib import ctypes +import os import pickle +import re import sys import warnings from inspect import signature -from typing import Any, Literal, Optional, TYPE_CHECKING -from typing_extensions import deprecated +from typing import Any, cast, Literal, Optional, TYPE_CHECKING, TypedDict +from typing_extensions import deprecated, NotRequired import torch from torch import _C @@ -29,6 +31,60 @@ from torch.types import Device +# Type definitions for memory profiler +class _Frame(TypedDict): + """Frame information from memory profiler snapshots.""" + + filename: str + line: int + name: str + # Fields added by FX augmentation (optional) + fx_node_op: NotRequired[str] + fx_node_name: NotRequired[str] + fx_node_target: NotRequired[str] + fx_original_trace: NotRequired[str] + + +class _Block(TypedDict): + """Memory block information.""" + + size: int + requested_size: int + address: int + state: str + frames: list[_Frame] + + +class _Segment(TypedDict): + """Memory segment information.""" + + address: int + total_size: int + stream: int + segment_type: str + allocated_size: int + active_size: int + blocks: list[_Block] + + +class _TraceEntry(TypedDict): + """Memory trace entry information.""" + + action: str + addr: NotRequired[int] + frames: list[_Frame] + size: int + stream: int + device_free: NotRequired[int] + + +class _Snapshot(TypedDict): + """Memory snapshot structure.""" + + segments: list[_Segment] + device_traces: NotRequired[list[list[_TraceEntry]]] + + __all__ = [ "caching_allocator_alloc", "caching_allocator_delete", @@ -964,7 +1020,120 @@ def _record_memory_history_impl( _record_memory_history.__signature__ = signature(_record_memory_history_impl) # type: ignore[attr-defined] -def _snapshot(device: "Device" = None): +def _augment_frames(frames: list[_Frame]) -> int: + """ + Augment a list of frames with FX debug information. + + Args: + frames: List of frame dictionaries to augment + + Returns: + The count of frames that were augmented. + """ + from torch.fx.graph_module import FX_GRAPH_MODULE_FILE_PREFIX + + # Regex pattern to match FX generated files + _FX_GENERATED_PATTERN = re.compile( + rf"{re.escape(FX_GRAPH_MODULE_FILE_PREFIX)}.*\.py$" + ) + + count = 0 + if not frames: + return count + + for frame in frames: + if "filename" in frame and "line" in frame: + filename = frame["filename"] + lineno = frame["line"] + + # Check if this looks like an FX generated file + if not _FX_GENERATED_PATTERN.search(os.path.basename(filename)): + continue + + # Look up metadata from the global registry + from torch.fx.traceback import _FX_METADATA_REGISTRY + + metadata = _FX_METADATA_REGISTRY.get(filename) + if metadata is None: + continue + + lineno_map = metadata.get("lineno_map", {}) + node_metadata = metadata.get("node_metadata", {}) + prologue_start = metadata.get("prologue_start", 0) + + # Get the node index for this line + node_idx = lineno_map.get(lineno - prologue_start) + + if node_idx is not None and node_idx in node_metadata: + node_info = node_metadata[node_idx] + original_trace = node_info.get("stack_trace") + node_op = node_info.get("op") + node_name = node_info.get("name") + node_target = node_info.get("target") + + # Always add node metadata + frame["fx_node_op"] = node_op + frame["fx_node_name"] = node_name + frame["fx_node_target"] = str(node_target) + + # Add original trace if available + if original_trace: + frame["fx_original_trace"] = original_trace + + count += 1 + + return count + + +def _augment_memory_snapshot_stack_traces( + snapshot: str | _Snapshot, +) -> _Snapshot: + """ + Augment a memory snapshot with original source stack traces from FX metadata. + + IMPORTANT: This function reads from a global in-memory registry (_FX_METADATA_REGISTRY) + that is populated during graph module compilation. It must be called in the same + Python process where the FX graphs were compiled. It cannot be used to augment + snapshots loaded from disk in a different process. + + Args: + snapshot: Either a memory snapshot dict or path to a snapshot pickle file + + Returns: + The augmented snapshot dictionary with fx_node_op, fx_node_name, + fx_original_trace, and fx_node_info fields added to frames + """ + + snapshot_dict: _Snapshot + if isinstance(snapshot, str): + # Load the memory snapshot + with open(snapshot, "rb") as f: + snapshot_dict = cast(_Snapshot, pickle.load(f)) + else: + snapshot_dict = snapshot + + # Process stack traces in the snapshot + augmented_count = 0 + + # Process blocks in segments (for regular allocations) + if "segments" in snapshot_dict: + for segment in snapshot_dict["segments"]: + if "blocks" in segment: + for block in segment["blocks"]: + if "frames" in block: + augmented_count += _augment_frames(block["frames"]) + + # Process device traces (for memory history) + if "device_traces" in snapshot_dict: + for trace_list in snapshot_dict["device_traces"]: + for trace_entry in trace_list: + if isinstance(trace_entry, dict) and "frames" in trace_entry: + augmented_count += _augment_frames(trace_entry["frames"]) + + return snapshot_dict + + +def _snapshot(device: "Device" = None, augment_with_fx_traces=False): """Save a snapshot of CUDA memory state at the time it was called. The state is represented as a dictionary with the following structure. @@ -1012,6 +1181,11 @@ class Frame(TypedDict): filename: str line: int name: str + # Optional FX debug fields (present when augment_with_fx_traces=True + # and the frame corresponds to FX-generated code) + fx_node_op: str # FX node operation type (e.g., 'call_function', 'output') + fx_node_name: str # FX node name (e.g., 'linear', 'relu_1') + fx_original_trace: str # Original model source code stack trace class TraceEntry(TypedDict): @@ -1041,13 +1215,23 @@ class TraceEntry(TypedDict): device_free: int # only present for OOM, the amount of # memory cuda still reports to be free + Args: + device: Device to capture snapshot for. If None, captures for current device. + augment_with_fx_traces: If True, augment stack trace frames with FX debug information + that maps generated FX code back to original model source code. + This adds fx_node_op, fx_node_name, fx_original_trace, and + fx_node_info fields to Frame objects. Default: False. + Returns: The Snapshot dictionary object """ - return _C._cuda_memorySnapshot(None) + s = _C._cuda_memorySnapshot(None) + if augment_with_fx_traces: + s = _augment_memory_snapshot_stack_traces(s) # type: ignore[assignment, arg-type] + return s -def _dump_snapshot(filename="dump_snapshot.pickle"): +def _dump_snapshot(filename="dump_snapshot.pickle", augment_with_fx_traces=False): """ Save a pickled version of the `torch.memory._snapshot()` dictionary to a file. @@ -1059,8 +1243,14 @@ def _dump_snapshot(filename="dump_snapshot.pickle"): Args: filename (str, optional): Name of the file to create. Defaults to "dump_snapshot.pickle". + augment_with_fx_traces (bool, optional): If True, augment the snapshot with FX debug information + before dumping. This maps generated FX code stack traces + back to original model source code. Defaults to False. + verbose (bool, optional): If True and augment_with_fx_traces is True, print verbose debug output + during augmentation. Defaults to False. """ - s = _snapshot() + s = _snapshot(augment_with_fx_traces=augment_with_fx_traces) + with open(filename, "wb") as f: pickle.dump(s, f) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index fc6f4c5b27021..697b2f4084ca5 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -226,8 +226,10 @@ class PythonCode: # Values in global scope during execution of `src_def`. globals: dict[str, Any] # Optional mapping from the forward function's line number to - # node index. + # node index. Line number starts at the prologue (i.e. forward()). _lineno_map: Optional[dict[int, Optional[int]]] + # The line number of prologue in fn_code + _prologue_start: int = 0 def _format_target(base: str, target: str) -> str: @@ -854,7 +856,14 @@ def _tensor_annotation(t: torch.Tensor) -> str: {prologue} {code}""" - return PythonCode(fn_code, globals_, _lineno_map=lineno_map) + # The +4 accounts for the empty lines before prologue in fn_code + prologue_start = wrap_stmts.count("\n") + 4 + return PythonCode( + fn_code, + globals_, + _lineno_map=lineno_map, + _prologue_start=prologue_start, + ) # Ideally, we'd like to refactor all of the pytree logic into this codegen diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 159926bc8ba49..297f76732584f 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -1,6 +1,8 @@ # mypy: allow-untyped-defs +import base64 import contextlib import copy +import hashlib import itertools import linecache import os @@ -36,6 +38,7 @@ ] _USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes" +FX_GRAPH_MODULE_FILE_PREFIX = "fx_generated_" # Normal exec loses the source code, however we can work with @@ -61,7 +64,13 @@ def cache(self, src: str, globals: dict[str, Any], co_fields=None): key = self._get_key() if co_fields: - key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}" + if "co_filename" in co_fields: + # If only co_filename is provided, use it directly as the key + if "co_firstlineno" not in co_fields or "co_name" not in co_fields: + key = co_fields["co_filename"] + else: + # Full co_fields with all three components + key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}" self.eval_cache[key] = src # Don't mutate globals so that this loader is only used @@ -353,6 +362,36 @@ def _print_readable( return output +def _metadata_hash(code: str, node_metadata: dict) -> str: + """ + Create a content-addressed hash from code and metadata. + + Args: + code: The source code string + lineno_map: Mapping from line numbers to node indices + node_metadata: Metadata for each node + + Returns: + A 51-character base32-encoded hash + """ + import json + + # Create a deterministic string representation of all components + # We use JSON to ensure consistent serialization + hash_data = { + "code": code, + "node_metadata": node_metadata, + } + hashing_str = json.dumps(hash_data).encode("utf-8") + + # [:51] to strip off the "Q====" suffix common to every hash value. + return ( + base64.b32encode(hashlib.sha256(hashing_str).digest())[:51] + .decode("utf-8") + .lower() + ) + + class _WrappedCall: def __init__(self, cls, cls_call): self.cls = cls @@ -825,9 +864,47 @@ def recompile(self) -> PythonCode: python_code = self._graph.python_code(root_module="self") self._code = python_code.src self._lineno_map = python_code._lineno_map + self._prologue_start = python_code._prologue_start cls = type(self) co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} + from torch._dynamo import config as dynamo_config + + if dynamo_config.enrich_profiler_metadata: + # Generate metadata and register for profiler augmentation + node_metadata: dict[int, dict[str, Any]] = {} + for i, node in enumerate(self._graph.nodes): + node_metadata[i] = { + "name": node.name, + "op": node.op, + "target": str(node.target), + "stack_trace": node.meta.get("stack_trace", None), + } + + # Generate a content-addressed filename based on hash of code and metadata + # This ensures the same code+metadata always generates the same filename + hash_value = _metadata_hash(self._code, node_metadata) + file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}" + + filename = f"{file_stem}.py" + + # Only include co_filename to use it directly as the cache key + co_fields = { + "co_filename": filename, + } + + # Store metadata in global in-memory registry + metadata = { + "lineno_map": python_code._lineno_map, + "prologue_start": python_code._prologue_start, + "node_metadata": node_metadata, + } + + # Register metadata in the global registry + from torch.fx.traceback import _register_fx_metadata + + _register_fx_metadata(filename, metadata) + cls.forward = _forward_from_src(self._code, python_code.globals, co_fields) # Determine whether this class explicitly defines a __call__ implementation diff --git a/torch/fx/traceback.py b/torch/fx/traceback.py index a143119cd78b0..25fb81a5aa016 100644 --- a/torch/fx/traceback.py +++ b/torch/fx/traceback.py @@ -38,6 +38,28 @@ current_replay_node: Optional[Node] = None should_preserve_node_meta = False +# ============================================================================= +# FX Metadata Registry for Memory Profiler +# ============================================================================= +# Global in-memory registry for FX metadata +# Maps module_name -> metadata dict containing lineno_map and node_metadata +_FX_METADATA_REGISTRY: dict[str, dict[str, Any]] = {} + + +def _register_fx_metadata(module_name: str, metadata: dict[str, Any]) -> None: + """ + Register FX metadata in the global in-memory registry. + + This is called automatically during graph module compilation to store metadata + for later use by memory profiler augmentation. + + Args: + module_name: The module identifier (content-addressed filename) + metadata: Metadata dict containing lineno_map, node_metadata, and source_code + """ + # TODO: add logging to tlparse + _FX_METADATA_REGISTRY[module_name] = metadata + @compatibility(is_backward_compatible=False) class NodeSourceAction(Enum): diff --git a/torch/utils/viz/MemoryViz.js b/torch/utils/viz/MemoryViz.js index 09f8c444f600c..dfeae36cebab7 100644 --- a/torch/utils/viz/MemoryViz.js +++ b/torch/utils/viz/MemoryViz.js @@ -806,7 +806,29 @@ function format_frames(frames) { } const frame_strings = frames .filter(frameFilter) - .map(f => `${f.filename}:${f.line}:${f.name}`); + .map(f => { + let frame_str = `${f.filename}:${f.line}:${f.name}`; + + // Add FX debug information if available + if (f.fx_node_op || f.fx_node_name || f.fx_node_target) { + const fx_parts = []; + if (f.fx_node_name) fx_parts.push(`node=${f.fx_node_name}`); + if (f.fx_node_op) fx_parts.push(`op=${f.fx_node_op}`); + if (f.fx_node_target) fx_parts.push(`target=${f.fx_node_target}`); + frame_str += `\n >> FX: ${fx_parts.join(', ')}`; + } + + if (f.fx_original_trace) { + frame_str += `\n >> Original Model Code:`; + const original_lines = f.fx_original_trace.trim().split('\n'); + // Show all lines of the original trace + for (const line of original_lines) { + frame_str += `\n ${line}`; + } + } + + return frame_str; + }); return elideRepeats(frame_strings).join('\n'); } From d02f68f4840af4ff2431a3015ff8d64aea43e720 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 3 Nov 2025 10:27:22 -0800 Subject: [PATCH 005/130] [BE] Use `[[maybe_unused]]` (#166865) Instead of `(void) foo; // Unused parameter` trick, as this is a C++17 standard feature Will replace further repetitions of the same pattern soon after Pull Request resolved: https://github.com/pytorch/pytorch/pull/166865 Approved by: https://github.com/mikaylagawarecki, https://github.com/Skylion007, https://github.com/janeyx99 --- torch/csrc/stable/stableivalue_conversions.h | 48 +++++++------------- 1 file changed, 16 insertions(+), 32 deletions(-) diff --git a/torch/csrc/stable/stableivalue_conversions.h b/torch/csrc/stable/stableivalue_conversions.h index f35ed50d99be4..8004e91b77f8e 100644 --- a/torch/csrc/stable/stableivalue_conversions.h +++ b/torch/csrc/stable/stableivalue_conversions.h @@ -31,10 +31,8 @@ template struct FromImpl { static StableIValue call( T val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { static_assert( sizeof(T) <= sizeof(StableIValue), "StableLibrary stack does not support parameter types larger than 64 bits."); @@ -75,10 +73,8 @@ template <> struct FromImpl { static StableIValue call( ScalarType val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { switch (val) { case ScalarType::Byte: return from(aoti_torch_dtype_uint8()); @@ -133,10 +129,8 @@ template <> struct FromImpl { static StableIValue call( std::nullopt_t val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { return from(nullptr); } }; @@ -190,10 +184,8 @@ template <> struct FromImpl { static StableIValue call( const torch::stable::Tensor& val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { AtenTensorHandle new_ath; TORCH_ERROR_CODE_CHECK(aoti_torch_new_tensor_handle(val.get(), &new_ath)); return from(new_ath); @@ -209,10 +201,8 @@ template struct ToImpl { static T call( StableIValue val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { static_assert(std::is_trivially_copyable_v); // T may not have a default constructor. (For example, it might be // c10::Device.) However, std::memcpy implicitly creates a T at the @@ -249,10 +239,8 @@ template <> struct ToImpl { static ScalarType call( StableIValue val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { int32_t shim_scalartype = to(val); if (shim_scalartype == aoti_torch_dtype_uint8()) { return ScalarType::Byte; @@ -309,10 +297,8 @@ template <> struct ToImpl { static std::nullopt_t call( StableIValue val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { // val should be equivalent to from(nullptr) return std::nullopt; } @@ -350,10 +336,8 @@ template <> struct ToImpl { static torch::stable::Tensor call( StableIValue val, - uint64_t extension_build_version, - bool is_internal) { - (void)extension_build_version; // Unused parameter - (void)is_internal; // Unused parameter + [[maybe_unused]] uint64_t extension_build_version, + [[maybe_unused]] bool is_internal) { return torch::stable::Tensor(to(val)); } }; From eefa16342c9f322b56c7c0cd6d309c3ed8f0b882 Mon Sep 17 00:00:00 2001 From: Nikita Vedeneev Date: Tue, 4 Nov 2025 12:59:31 +0000 Subject: [PATCH 006/130] [Inductor] addmm with bias -> unfuse bias if there is a pointwise/reduction consumer (#166165) Prefer unfused addmm when there is at least a single elemwise/reduction consumer.. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166165 Approved by: https://github.com/eellison --- test/inductor/test_padding.py | 7 ++- test/inductor/test_torchinductor.py | 4 +- torch/_inductor/fx_passes/post_grad.py | 8 ++-- torch/_inductor/utils.py | 64 ++++++++++++++++++++++++++ 4 files changed, 77 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_padding.py b/test/inductor/test_padding.py index c67bde87a369b..5e599110d29d6 100644 --- a/test/inductor/test_padding.py +++ b/test/inductor/test_padding.py @@ -500,8 +500,13 @@ def test_LinearAndSoftmax_codegen(self, bias=True): forward_wrapper = wrapper_codes[0] # make sure the load for softmax is aligned + if bias: + # addmm -> mm + bias and bias is fused with softmax + softmax_load_str = "tl.load(in_out_ptr0 + (r0_1 + 30528*x0)" + else: + softmax_load_str = "tl.load(in_ptr0 + (r0_1 + 30528*x0)" self.assertTrue( - "tl.load(in_ptr0 + (r0_1 + 30528*x0)" in forward_wrapper, + softmax_load_str in forward_wrapper, f"forward_wrapper: {forward_wrapper}", ) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 675d912c0c01f..dad2de9bde327 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -15280,7 +15280,7 @@ def fn3(x): ), ( fn3, - "triton_poi_fused_native_layer_norm_relu", + "triton_poi_fused_addmm_native_layer_norm", (torch.randn(4, 4, device=GPU_TYPE),), ), ] @@ -15293,7 +15293,7 @@ def fn3(x): ), ( fn3, - "triton_poi_fused_LayerNorm_ReLU", + "triton_poi_fused_LayerNorm_Linear_ReLU", (torch.randn(4, 4, device=GPU_TYPE),), ), ] diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index f11817e1d4c51..7d995adec04ef 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -51,8 +51,8 @@ decode_device, get_all_devices, get_gpu_type, + has_uses_tagged_as, is_gpu, - is_pointwise_use, OPTIMUS_EXCLUDE_POST_GRAD, ) from ..virtualized import V @@ -1510,8 +1510,10 @@ def should_prefer_unfused_addmm(match): if not is_gpu(inp.meta["val"].device.type): return False - output = match.output_node() - return all(is_pointwise_use(use) for use in output.users) + return has_uses_tagged_as( + match.output_node(), + (torch.Tag.pointwise, torch.Tag.reduction), + ) @register_graph_pattern( diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 13938f6ec1e55..6b34ef28b2c10 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -549,6 +549,70 @@ def is_pointwise_use( return torch.Tag.pointwise in target.tags or is_pointwise_fn(target) +class LogicalConnective(enum.Enum): + OR = enum.auto() + AND = enum.auto() + + +def has_uses( + target: Node, + use_selector_fn: Callable[[torch._ops.OpOverload], bool] = lambda _: False, + use_aggregate_type: LogicalConnective = LogicalConnective.OR, +) -> bool: + """ + Given a target, explore the uses of `target` by applying `use_selector_fn` + on them, and then aggregate these booleans with the `use_aggregate_type` + logical connective. + + Uses in view ops will follow the views uses. + """ + + def get_use_aggregate_fn( + use_aggregate_type: LogicalConnective, + ) -> Callable[[Iterator[Any]], bool]: + match use_aggregate_type: + case LogicalConnective.AND: + return all + case LogicalConnective.OR: + return any + case _: + return any + + use_aggregate_fn = get_use_aggregate_fn(use_aggregate_type) + + def has_uses_impl(use: Node) -> bool: + if use.op != "call_function": + return False + if not ( + isinstance(use.target, torch._ops.OpOverload) + or use.target is operator.getitem + ): + return False + + target = cast(torch._ops.OpOverload, use.target) + # Process getitem and view + if target is operator.getitem or is_view(target): + return use_aggregate_fn(has_uses_impl(user) for user in use.users) + + return use_selector_fn(target) + + return use_aggregate_fn(has_uses_impl(user) for user in target.users) + + +def has_uses_tagged_as( + target: Node, + use_tags: Collection[torch.Tag], + use_aggregate_type: LogicalConnective = LogicalConnective.OR, +) -> bool: + """ + Is there a use with given tags? + """ + + return has_uses( + target, lambda use: any(tag in use_tags for tag in use.tags), use_aggregate_type + ) + + def gen_gm_and_inputs( target: Any, args: list[Any], kwargs: dict[str, Any] ) -> tuple[GraphModule, list[torch.Tensor]]: From 3144713325de01b478e9b469f546d61903cb570a Mon Sep 17 00:00:00 2001 From: clr Date: Mon, 3 Nov 2025 15:25:04 -0800 Subject: [PATCH 007/130] subproc_pool: Add support for enabling quiesce via a timer (#166467) This adds the capability to subproc pool to enable quiesce via a timer Pull Request resolved: https://github.com/pytorch/pytorch/pull/166467 Approved by: https://github.com/masnesral --- test/inductor/test_compile_worker.py | 20 ++++++++++++++----- .../_inductor/compile_worker/subproc_pool.py | 13 ++++++++++++ torch/_inductor/compile_worker/timer.py | 2 +- torch/_inductor/config.py | 5 +++++ 4 files changed, 34 insertions(+), 6 deletions(-) diff --git a/test/inductor/test_compile_worker.py b/test/inductor/test_compile_worker.py index 50a389e8663f9..7237d5a01c6b2 100644 --- a/test/inductor/test_compile_worker.py +++ b/test/inductor/test_compile_worker.py @@ -4,6 +4,7 @@ import tempfile from threading import Event +import torch._inductor.config as config from torch._inductor.compile_worker.subproc_pool import ( raise_testexc, SubprocException, @@ -16,9 +17,12 @@ class TestCompileWorker(TestCase): + def make_pool(self, size): + return SubprocPool(size) + @skipIfWindows(msg="pass_fds not supported on Windows.") def test_basic_jobs(self): - pool = SubprocPool(2) + pool = self.make_pool(2) try: a = pool.submit(operator.add, 100, 1) b = pool.submit(operator.sub, 100, 1) @@ -29,7 +33,7 @@ def test_basic_jobs(self): @skipIfWindows(msg="pass_fds not supported on Windows.") def test_exception(self): - pool = SubprocPool(2) + pool = self.make_pool(2) try: a = pool.submit(raise_testexc) with self.assertRaisesRegex( @@ -42,7 +46,7 @@ def test_exception(self): @skipIfWindows(msg="pass_fds not supported on Windows.") def test_crash(self): - pool = SubprocPool(2) + pool = self.make_pool(2) try: with self.assertRaises(Exception): a = pool.submit(os._exit, 1) @@ -58,7 +62,7 @@ def test_crash(self): @skipIfWindows(msg="pass_fds not supported on Windows.") def test_quiesce(self): - pool = SubprocPool(2) + pool = self.make_pool(2) try: a = pool.submit(operator.add, 100, 1) pool.quiesce() @@ -75,7 +79,7 @@ def test_logging(self): os.environ["ROLE_RANK"] = "0" with tempfile.NamedTemporaryFile(delete=True) as temp_log: os.environ["TORCHINDUCTOR_WORKER_LOGPATH"] = temp_log.name - pool = SubprocPool(2) + pool = self.make_pool(2) try: pool.submit(operator.add, 100, 1) self.assertEqual(os.path.exists(temp_log.name), True) @@ -83,6 +87,12 @@ def test_logging(self): pool.shutdown() +@config.patch("quiesce_async_compile_time", 0.1) +class TestCompileWorkerWithTimer(TestCompileWorker): + def make_pool(self, size): + return SubprocPool(size, quiesce=True) + + class TestTimer(TestCase): def test_basics(self): done = Event() diff --git a/torch/_inductor/compile_worker/subproc_pool.py b/torch/_inductor/compile_worker/subproc_pool.py index 037b0e438adaa..a4114644026ca 100644 --- a/torch/_inductor/compile_worker/subproc_pool.py +++ b/torch/_inductor/compile_worker/subproc_pool.py @@ -24,6 +24,7 @@ import torch._thread_safe_fork # noqa: F401 from torch._inductor import config from torch._inductor.codecache import torch_key +from torch._inductor.compile_worker.timer import Timer from torch._inductor.compile_worker.tracked_process_pool import ( TrackedProcessPoolExecutor, ) @@ -132,6 +133,7 @@ def __init__( nprocs: int, pickler: Optional[SubprocPickler] = None, kind: SubprocKind = SubprocKind.FORK, + quiesce: bool = False, ) -> None: entry = os.path.join(os.path.dirname(__file__), "__main__.py") self.pickler = pickler or SubprocPickler() @@ -216,6 +218,13 @@ def __init__( "pytorch.wait_counter.subproc_pool.first_job" ).guard() + if quiesce: + self.timer: Optional[Timer] = Timer( + config.quiesce_async_compile_time, self.quiesce + ) + else: + self.timer = None + # Start thread last to ensure all member variables are initialized # before any access. self.read_thread.start() @@ -288,6 +297,8 @@ def _read_thread(self) -> None: with self.futures_lock: if not self.running: return + if self.timer: + self.timer.record_call() if isinstance(result, _SubprocExceptionInfo): # An exception occurred in the submitted job self.pending_futures[job_id].set_exception( @@ -322,6 +333,8 @@ def shutdown(self) -> None: with self.write_lock: if not self.running: return + if self.timer: + self.timer.quit() self.running = False self.running_waitcounter.__exit__() _send_msg(self.write_pipe, MsgHeader.SHUTDOWN) diff --git a/torch/_inductor/compile_worker/timer.py b/torch/_inductor/compile_worker/timer.py index d4b0c0dc9e281..7cfeb4217e26b 100644 --- a/torch/_inductor/compile_worker/timer.py +++ b/torch/_inductor/compile_worker/timer.py @@ -17,7 +17,7 @@ def __init__( self.background_thread: Optional[Thread] = None self.last_called: Optional[float] = None self.duration = duration - self.sleep_time = 60 + self.sleep_time = duration / 2 self.call = call self.exit = False diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index b78ade758f80b..08cc2b2bd861a 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -964,6 +964,11 @@ def decide_compile_threads() -> int: default=False, ) +# Time in seconds to wait before quiescing +quiesce_async_compile_time: int = Config( + default=60, +) + # Whether or not to enable statically launching CUDA kernels # compiled by triton (instead of using triton's own launcher) use_static_cuda_launcher: bool = static_cuda_launcher_default() From 527b1109a8a8d8ae9e1c76c057468aacb302ed84 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Tue, 4 Nov 2025 07:36:39 -0800 Subject: [PATCH 008/130] Delete deprecated fp32 precision warnings (#166956) The deprecation warning led to warning spamming in PyTorch APIs, like torch.compile. This is not how a deprecation warning should go: if we add a deprecation warning, we'd better update our built-in APIs to prevent warning spam. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166956 Approved by: https://github.com/albanD --- aten/src/ATen/Context.cpp | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index a354b41912406..6bc321887502d 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -23,8 +23,6 @@ C10_DIAGNOSTIC_POP() #endif namespace at { -namespace { - /* These const variables defined the fp32 precisions for different backend We have "generic", "cuda", "mkldnn" backend now and we can choose fp32 @@ -41,16 +39,6 @@ namespace { ->rnn */ - C10_ALWAYS_INLINE void warn_deprecated_fp32_precision_api(){ - TORCH_WARN_ONCE( - "Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' " - "or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, " - "torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see " - "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices" - ); - } -} // namespace - Float32Backend str2backend(const std::string& name) { if (name == "generic") return Float32Backend::GENERIC; @@ -206,7 +194,6 @@ bool Context::allowTF32CuDNN(std::optional op) const { } else { return float32Precision(Float32Backend::CUDA, op.value()) == Float32Precision::TF32; } - warn_deprecated_fp32_precision_api(); return allow_tf32_cudnn; } @@ -214,7 +201,6 @@ void Context::setAllowTF32CuDNN(bool b) { setFloat32Precision(Float32Backend::CUDA, Float32Op::RNN, b ? Float32Precision::TF32 : Float32Precision::NONE); setFloat32Precision(Float32Backend::CUDA, Float32Op::CONV, b ? Float32Precision::TF32 : Float32Precision::NONE); allow_tf32_cudnn = b; - warn_deprecated_fp32_precision_api(); } void Context::setSDPPriorityOrder(const std::vector& order) { @@ -325,7 +311,6 @@ bool Context::allowTF32CuBLAS() const { "Current status indicate that you have used mix of the legacy and new APIs to set the TF32 status for cublas matmul. ", "We suggest only using the new API to set the TF32 flag. See also: ", "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"); - warn_deprecated_fp32_precision_api(); return allow_tf32_new; } @@ -349,7 +334,6 @@ Float32MatmulPrecision Context::float32MatmulPrecision() const { "Current status indicate that you have used mix of the legacy and new APIs to set the matmul precision. ", "We suggest only using the new API for matmul precision. See also: ", "https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices"); - warn_deprecated_fp32_precision_api(); return float32_matmul_precision; } @@ -377,7 +361,6 @@ Float32Precision Context::float32Precision(Float32Backend backend, Float32Op op) void Context::setFloat32MatmulPrecision(const std::string &s) { auto match = [this](const std::string & s_) { - warn_deprecated_fp32_precision_api(); // TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention if (s_ == "highest") { float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST; From 53f75cd5ba933148b21e4b1763a1a0790b0f3744 Mon Sep 17 00:00:00 2001 From: Wenlin Chong Date: Tue, 4 Nov 2025 18:18:34 +0000 Subject: [PATCH 009/130] Fixed some syntax errors in SECURITY.md file. (#166718) Fixed some syntax errors in SECURITY.md file including PyTorch's capitalization problems, some grammatical inconsistencies, etc Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/166718 Approved by: https://github.com/mikaylagawarecki --- SECURITY.md | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/SECURITY.md b/SECURITY.md index ed8228af36724..375f94547941f 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -1,7 +1,7 @@ # Security Policy - [**Reporting a Vulnerability**](#reporting-a-vulnerability) - - [**Using Pytorch Securely**](#using-pytorch-securely) + - [**Using PyTorch Securely**](#using-pytorch-securely) - [Untrusted models](#untrusted-models) - [TorchScript models](#torchscript-models) - [Untrusted inputs](#untrusted-inputs) @@ -10,28 +10,28 @@ - [**CI/CD security principles**](#cicd-security-principles) ## Reporting Security Issues -Beware that none of the topics under [Using Pytorch Securely](#using-pytorch-securely) are considered vulnerabilities of Pytorch. +Beware that none of the topics under [Using PyTorch Securely](#using-pytorch-securely) are considered vulnerabilities of PyTorch. However, if you believe you have found a security vulnerability in PyTorch, we encourage you to let us know right away. We will investigate all legitimate reports and do our best to quickly fix the problem. Please report security issues using https://github.com/pytorch/pytorch/security/advisories/new -All reports submitted thru the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework. +All reports submitted through the security advisories mechanism would **either be made public or dismissed by the team within 90 days of the submission**. If advisory has been closed on the grounds that it is not a security issue, please do not hesitate to create an [new issue](https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml) as it is still likely a valid issue within the framework. Please refer to the following page for our responsible disclosure policy, reward guidelines, and those things that should not be reported: https://www.facebook.com/whitehat -## Using Pytorch Securely -**Pytorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package). +## Using PyTorch Securely +**PyTorch models are programs**, so treat its security seriously -- running untrusted models is equivalent to running untrusted code. In general we recommend that model weights and the python code for the model are distributed independently. That said, be careful about where you get the python code from and who wrote it (preferentially check for a provenance or checksums, do not run any pip installed package). ### Untrusted models Be careful when running untrusted models. This classification includes models created by unknown developers or utilizing data obtained from unknown sources[^data-poisoning-sources]. **Prefer to execute untrusted models within a secure, isolated environment such as a sandbox** (e.g., containers, virtual machines). This helps protect your system from potentially malicious code. You can find further details and instructions in [this page](https://developers.google.com/code-sandboxing). -**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details. +**Be mindful of risky model formats**. Give preference to share and load weights with the appropriate format for your use case. [Safetensors](https://huggingface.co/docs/safetensors/en/index) gives the most safety but is the most restricted in what it supports. [`torch.load`](https://pytorch.org/docs/stable/generated/torch.load.html#torch.load) has a significantly larger surface of attack but is more flexible in what it can serialize. See the documentation for more details. Even for more secure serialization formats, unexpected inputs to the downstream system can cause diverse security threats (e.g. denial of service, out of bound reads/writes) and thus we recommend extensive validation of any untrusted inputs. @@ -43,7 +43,7 @@ Important Note: The trustworthiness of a model is not binary. You must always de ### TorchScript models -TorchScript models should treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load. +TorchScript models should be treated the same way as locally executable code from an unknown source. Only run TorchScript models if you trust the provider. Please note, that tools for introspecting TorchScript models (such as `torch.utils.model_dump`) may also execute partial or full code stored in those models, therefore they should be used only if you trust the provider of the binary you are about to load. ### Untrusted inputs during training and prediction @@ -59,9 +59,9 @@ If applicable, prepare your model against bad inputs and prompt injections. Some ### Data privacy -**Take special security measures if your model if you train models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and: -- Do not feed sensitive data to untrusted model (even if runs in a sandboxed environment) -- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if model overfits). +**Take special security measures if you train your models with sensitive data**. Prioritize [sandboxing](https://developers.google.com/code-sandboxing) your models and: +- Do not feed sensitive data to an untrusted model (even if runs in a sandboxed environment) +- If you consider publishing a model that was partially trained with sensitive data, be aware that data can potentially be recovered from the trained weights (especially if the model overfits). ### Using distributed features From 496277a8ffcb29c9976fe93f91ab8232e29764b9 Mon Sep 17 00:00:00 2001 From: amdfaa <107946068+amdfaa@users.noreply.github.com> Date: Tue, 4 Nov 2025 18:44:21 +0000 Subject: [PATCH 010/130] [ROCm][CI] Lower runner check gpu count for distributed jobs (#166961) This is a PR to temporarily relieve the queueing that is caused by an mi250 node outage. See this ticket for more information: https://github.com/pytorch/pytorch/issues/166866 It relaxes the GPU count check to allow distributed jobs to run on 2-GPU runners Pull Request resolved: https://github.com/pytorch/pytorch/pull/166961 Approved by: https://github.com/jeffdaily --- .github/workflows/_rocm-test.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/_rocm-test.yml b/.github/workflows/_rocm-test.yml index 43ed76a63cc67..608aeba53e6d8 100644 --- a/.github/workflows/_rocm-test.yml +++ b/.github/workflows/_rocm-test.yml @@ -97,8 +97,8 @@ jobs: shell: bash run: | ngpu=$(rocminfo | grep -c -E 'Name:.*\sgfx') - if [[ $ngpu -lt 4 ]]; then - echo "Error: only $ngpu GPU(s) detected, at least 4 GPUs are needed for distributed jobs" + if [[ $ngpu -lt 2 ]]; then #We are temporarily reducing this down to 2 from 4 so that we can run tests on nodes with less gpus. + echo "Error: only $ngpu GPU(s) detected, at least 2 GPUs are needed for distributed jobs" exit 1 fi From 1d3f5e19da068ec1340db041b7105b287a513578 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Tue, 4 Nov 2025 18:46:43 +0000 Subject: [PATCH 011/130] [cuDNN] Smoke-test runtime cuDNN version matches compile time version in CI (#165922) Fix and regression test for https://github.com/pytorch/pytorch/issues/165801 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165922 Approved by: https://github.com/malfet, https://github.com/atalman, https://github.com/Skylion007, https://github.com/drisspg Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> Co-authored-by: Andrey Talman --- .ci/docker/common/install_cuda.sh | 2 +- .ci/pytorch/smoke_test/smoke_test.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/.ci/docker/common/install_cuda.sh b/.ci/docker/common/install_cuda.sh index fe2f9ae3185a3..fe0cb8cc79c4f 100644 --- a/.ci/docker/common/install_cuda.sh +++ b/.ci/docker/common/install_cuda.sh @@ -129,7 +129,7 @@ function install_129 { } function install_128 { - CUDNN_VERSION=9.8.0.87 + CUDNN_VERSION=9.10.2.21 echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1" # install CUDA 12.8.1 in the same container install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux diff --git a/.ci/pytorch/smoke_test/smoke_test.py b/.ci/pytorch/smoke_test/smoke_test.py index 675d58a3e283d..3642f29684cf0 100644 --- a/.ci/pytorch/smoke_test/smoke_test.py +++ b/.ci/pytorch/smoke_test/smoke_test.py @@ -272,6 +272,18 @@ def smoke_test_cuda( torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version()) print(f"Torch cuDNN version: {torch_cudnn_version}") + torch_cudnn_compile_version = torch._C._cudnn.getCompileVersion() + print(f"Torch cuDNN compile-time version: {torch_cudnn_compile_version}") + torch_cudnn_runtime_version = tuple( + [int(x) for x in torch_cudnn_version.split(".")] + ) + if torch_cudnn_runtime_version != torch_cudnn_compile_version: + raise RuntimeError( + "cuDNN runtime version doesn't match comple version. " + f"Loaded: {torch_cudnn_runtime_version} " + f"Expected: {torch_cudnn_compile_version}" + ) + if sys.platform in ["linux", "linux2"]: torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version()) print(f"Torch nccl; version: {torch_nccl_version}") From a5f3035aafd5113dd7641a95a3e919d4a4c8781f Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Tue, 4 Nov 2025 10:46:38 -0800 Subject: [PATCH 012/130] More pyrefly local errors (#166976) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166976 Approved by: https://github.com/maggiemoss, https://github.com/Skylion007 --- torch/_higher_order_ops/triton_kernel_wrap.py | 1 + torch/cuda/__init__.py | 2 +- torch/cuda/memory.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index 8ffab37699422..0e398897a7eab 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -498,6 +498,7 @@ def get_signature_value(idx: int, arg: Any) -> str: # pyrefly: ignore # missing-attribute codegen_fns = backend.get_codegen_implementation(*codegen_args) module_map = backend.get_module_map() + # pyrefly: ignore[missing-argument,bad-argument-type] ttir_module = src.make_ir(options, codegen_fns, module_map, context) else: codegen_args = [options] if get_codegen_implementation_sig_params == 1 else [] diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index dff869742df56..23d297b6d95e0 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -1228,7 +1228,7 @@ def _get_pynvml_handler(device: "Device" = None): "nvidia-ml-py does not seem to be installed or it can't be imported." # pyrefly: ignore [invalid-inheritance] ) from _PYNVML_ERR - # pyrefly: ignore [import-error] + # pyrefly: ignore [import-error,missing-module-attribute] from pynvml import NVMLError_DriverNotLoaded try: diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 6834ffb5706a0..a1decc20cc9a8 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -828,7 +828,7 @@ def list_gpu_processes(device: "Device" = None) -> str: import pynvml # type: ignore[import] except ModuleNotFoundError: return "pynvml module not found, please install nvidia-ml-py" - # pyrefly: ignore [import-error] + # pyrefly: ignore [import-error,missing-module-attribute] from pynvml import NVMLError_DriverNotLoaded try: From 52ea135f77f2469a8c15f2051260584ddd7c3bb8 Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Mon, 3 Nov 2025 07:19:37 -0800 Subject: [PATCH 013/130] [BE] Delete Python-3.9 stdlib definitions from torch.package (#166768) And simplify the entire function to just assert and return Pull Request resolved: https://github.com/pytorch/pytorch/pull/166768 Approved by: https://github.com/cyyever, https://github.com/atalman --- torch/package/_stdlib.py | 229 +-------------------------------------- 1 file changed, 2 insertions(+), 227 deletions(-) diff --git a/torch/package/_stdlib.py b/torch/package/_stdlib.py index 57a51ac41cfd9..e07b20a83cc6d 100644 --- a/torch/package/_stdlib.py +++ b/torch/package/_stdlib.py @@ -17,230 +17,5 @@ def is_stdlib_module(module: str) -> bool: def _get_stdlib_modules(): - if sys.version_info.major == 3: # noqa: UP036 - if sys.version_info.minor == 9: - return stdlib3_9 - if sys.version_info.minor >= 10: # noqa: YTT204 - return sys.stdlib_module_names # type: ignore[attr-defined] - elif sys.version_info.major > 3: # noqa: UP036 - return sys.stdlib_module_names # type: ignore[attr-defined] - - raise RuntimeError(f"Unsupported Python version: {sys.version_info}") - - -stdlib3_9 = { - "_thread", - "abc", - "aifc", - "argparse", - "array", - "ast", - "asynchat", - "asyncio", - "asyncore", - "atexit", - "audioop", - "base64", - "bdb", - "binascii", - "binhex", - "bisect", - "builtins", - "bz2", - "cProfile", - "calendar", - "cgi", - "cgitb", - "chunk", - "cmath", - "cmd", - "code", - "codecs", - "codeop", - "collections", - "colorsys", - "compileall", - "concurrent", - "configparser", - "contextlib", - "contextvars", - "copy", - "copyreg", - "crypt", - "csv", - "ctypes", - "curses", - "dataclasses", - "datetime", - "dbm", - "decimal", - "difflib", - "dis", - "distutils", - "doctest", - "email", - "encodings", - "ensurepip", - "enum", - "errno", - "faulthandler", - "fcntl", - "filecmp", - "fileinput", - "fnmatch", - "formatter", - "fractions", - "ftplib", - "functools", - "gc", - "getopt", - "getpass", - "gettext", - "glob", - "graphlib", - "grp", - "gzip", - "hashlib", - "heapq", - "hmac", - "html", - "http", - "imaplib", - "imghdr", - "imp", - "importlib", - "inspect", - "io", - "ipaddress", - "itertools", - "json", - "keyword", - "lib2to3", - "linecache", - "locale", - "logging", - "lzma", - "mailbox", - "mailcap", - "marshal", - "math", - "mimetypes", - "mmap", - "modulefinder", - "msilib", - "msvcrt", - "multiprocessing", - "netrc", - "nis", - "nntplib", - "ntpath", - "numbers", - "operator", - "optparse", - "os", - "ossaudiodev", - "parser", - "pathlib", - "pdb", - "pickle", - "pickletools", - "pipes", - "pkgutil", - "platform", - "plistlib", - "poplib", - "posix", - "posixpath", - "pprint", - "profile", - "pstats", - "pty", - "pwd", - "py_compile", - "pyclbr", - "pydoc", - "queue", - "quopri", - "random", - "re", - "readline", - "reprlib", - "resource", - "rlcompleter", - "runpy", - "sched", - "secrets", - "select", - "selectors", - "shelve", - "shlex", - "shutil", - "signal", - "site", - "smtpd", - "smtplib", - "sndhdr", - "socket", - "socketserver", - "spwd", - "sqlite3", - "sre", - "sre_compile", - "sre_constants", - "sre_parse", - "ssl", - "stat", - "statistics", - "string", - "stringprep", - "struct", - "subprocess", - "sunau", - "symbol", - "symtable", - "sys", - "sysconfig", - "syslog", - "tabnanny", - "tarfile", - "telnetlib", - "tempfile", - "termios", - "test", - "textwrap", - "threading", - "time", - "timeit", - "tkinter", - "token", - "tokenize", - "trace", - "traceback", - "tracemalloc", - "tty", - "turtle", - "turtledemo", - "types", - "typing", - "unicodedata", - "unittest", - "urllib", - "uu", - "uuid", - "venv", - "warnings", - "wave", - "weakref", - "webbrowser", - "winreg", - "winsound", - "wsgiref", - "xdrlib", - "xml", - "xmlrpc", - "zipapp", - "zipfile", - "zipimport", - "zlib", - "zoneinfo", -} + assert sys.version_info >= (3, 10) + return sys.stdlib_module_names From cef98ae5cbb484483b8cfe1b720e74fa10c7e720 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Tue, 4 Nov 2025 07:41:25 -0800 Subject: [PATCH 014/130] [aotd] Compiled saved tensor hooks context (#166887) Draft to expose compiled saved tensor hook context to selectively apply them. Exposing node, fw_graph, bw_graph. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166887 Approved by: https://github.com/bdhirsh --- test/functorch/test_aotdispatch.py | 15 ++++ .../_functorch/_aot_autograd/graph_compile.py | 86 ++++++++++++++----- 2 files changed, 81 insertions(+), 20 deletions(-) diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index fba7a96288caf..b0dd1ff8fa75d 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -167,6 +167,14 @@ def _pack_fp8_wrap(x): if not x.dtype.is_floating_point: return x + if type(x) is not torch.Tensor: + # Check only during compilation + # Test calls hooks to get reference output + ctx = torch._functorch._aot_autograd.graph_compile._get_saved_tensor_hook_context() + assert ctx["_fw_graph"] is not None + assert ctx["_bw_graph"] is not None + assert ctx["_node"] is not None + return (x.dtype, x.to(torch.float8_e5m2)) @@ -176,6 +184,13 @@ def _unpack_fp8_wrap(x): return x dtype, tensor = x + if type(tensor) is not torch.Tensor: + # Check only during compilation + # Test calls hooks to get reference output + ctx = torch._functorch._aot_autograd.graph_compile._get_saved_tensor_hook_context() + assert ctx["_fw_graph"] is not None + assert ctx["_bw_graph"] is not None + assert ctx["_node"] is not None return tensor.to(dtype) diff --git a/torch/_functorch/_aot_autograd/graph_compile.py b/torch/_functorch/_aot_autograd/graph_compile.py index 60ee3bc2973b1..b11eb87dc1720 100644 --- a/torch/_functorch/_aot_autograd/graph_compile.py +++ b/torch/_functorch/_aot_autograd/graph_compile.py @@ -25,6 +25,9 @@ if TYPE_CHECKING: from collections.abc import Sequence +import threading +from contextlib import contextmanager + import torch import torch.utils._pytree as pytree import torch.utils.dlpack @@ -97,6 +100,43 @@ ) +_thread_local = threading.local() + + +# Saved tensor hooks context +# Compiled saved tensor hooks are convenient way to inline some logic in the graphs +# for saved nodes from forward to backward. (E.g. activations quantization) +# In base implementation user does not have any additional information about saved value +# in the hook, except FakeTensor shape, dtype, device etc. +# _get_saved_tensor_hook_context gives additional graph information about that saved value, +# that can be used to make a decisions which pack/unpack to apply for particular saved value. +# This allows user to reuse saved tensors hooks api to apply selective pack/unpack in +# graph aware way. +# Alternative to this will be making user to write a custom pass that mucks with forward outputs, +# backward input metadata, which requires significantly more effort. +# +# As for now in context we expose forward graph, backward graph and current saved node, +# which contains node.meta with additional information about that fx.Node. +# Warning: This API may change without backward compatibility. +@contextmanager +def _saved_tensor_hook_context(state: dict[str, Any]): + previous_state = getattr(_thread_local, "state", None) + try: + _thread_local.state = state + yield + finally: + # Clean up: restore previous state or remove attribute + if previous_state is not None: + _thread_local.state = previous_state + else: + if hasattr(_thread_local, "state"): + delattr(_thread_local, "state") + + +def _get_saved_tensor_hook_context() -> dict[str, Any] | None: + return getattr(_thread_local, "state", None) + + zip = strict_zip log = logging.getLogger(__name__) @@ -1097,7 +1137,11 @@ def _gen_unused_name(candidate: str): if not isinstance(val, torch.Tensor): continue - pack_out_val = pack_hook_gm(val) + def _get_extra_info() -> dict[str, Any]: + return {"_fw_graph": fw_g, "_bw_graph": bw_g, "_node": saved} + + with _saved_tensor_hook_context(_get_extra_info()): + pack_out_val = pack_hook_gm(val) requires_sc_handling = any( is_traceable_wrapper_subclass(x) for x in pytree.tree_leaves(pack_out_val) @@ -1109,16 +1153,17 @@ def _gen_unused_name(candidate: str): " in the pack hook, and reconstructing the subclass in the unpack hook" ) - pack_gm = prepare_hook_gm(aot_config, pack_hook_gm, (val,)) - pack_g = pack_gm.graph - maybe_log_graph( - pack_gm, - f"saved_tensors_pack_hook {saved.name}", - aot_config, - lambda: f"aot_saved_tensors_hooks_pack {saved.name}", - structured_logs, - ) - pack_out_val = pack_gm(val) + with _saved_tensor_hook_context(_get_extra_info()): + pack_gm = prepare_hook_gm(aot_config, pack_hook_gm, (val,)) + pack_g = pack_gm.graph + maybe_log_graph( + pack_gm, + f"saved_tensors_pack_hook {saved.name}", + aot_config, + lambda: f"aot_saved_tensors_hooks_pack {saved.name}", + structured_logs, + ) + pack_out_val = pack_gm(val) # Install pack hook graph as eiplogue of fw_module. # Saved tensor output becomes input of pack hook graph. @@ -1188,15 +1233,16 @@ def _gen_unused_name(candidate: str): # Install unpack hook graph as a prologue of backward graph # Saved tensors inputs are replaced with packed tensors and packed sym scalars. # The saved tensors inputs usages in the graph are replaced with unpack hook graph outputs. - unpack_gm = prepare_hook_gm(aot_config, unpack_hook_gm, (pack_out_val,)) - unpack_g = unpack_gm.graph - maybe_log_graph( - unpack_gm, - f"saved_tensors_unpack_hook {saved.name}", - aot_config, - lambda: f"aot_saved_tensors_hooks_unpack {saved.name}", - structured_logs, - ) + with _saved_tensor_hook_context(_get_extra_info()): + unpack_gm = prepare_hook_gm(aot_config, unpack_hook_gm, (pack_out_val,)) + unpack_g = unpack_gm.graph + maybe_log_graph( + unpack_gm, + f"saved_tensors_unpack_hook {saved.name}", + aot_config, + lambda: f"aot_saved_tensors_hooks_unpack {saved.name}", + structured_logs, + ) def find_saved_in_bw_inputs(bw_inputs): for n in bw_inputs: From d77c24caac4b42c56b4fa6a156ce85fb4907643e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 4 Nov 2025 20:13:33 +0000 Subject: [PATCH 015/130] Revert "[Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel (#165036)" This reverts commit 0e1a88904f4a5e30634b196678b56e1d6ec074f5. Reverted https://github.com/pytorch/pytorch/pull/165036 on behalf of https://github.com/atalman due to regressed vllm signal: [GH job link](https://github.com/pytorch/pytorch/actions/runs/19059329909/job/54439919668) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/0e1a88904f4a5e30634b196678b56e1d6ec074f5) ([comment](https://github.com/pytorch/pytorch/pull/165036#issuecomment-3487846555)) --- .ci/pytorch/test.sh | 2 +- .gitignore | 1 - setup.py | 34 -- test/inductor/test_cutedsl_grouped_mm.py | 154 -------- torch/_inductor/config.py | 4 - torch/_inductor/kernel/mm_common.py | 7 - torch/_inductor/kernel/mm_grouped.py | 93 ++--- .../templates/cutedsl_mm_grouped.py.jinja | 333 ------------------ .../_inductor/template_heuristics/cutedsl.py | 141 -------- torch/_inductor/utils.py | 71 ---- 10 files changed, 33 insertions(+), 807 deletions(-) delete mode 100644 test/inductor/test_cutedsl_grouped_mm.py delete mode 100644 torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja delete mode 100644 torch/_inductor/template_heuristics/cutedsl.py diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 9ae2578758939..26996b5a32d56 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -337,7 +337,7 @@ test_python() { test_python_smoke() { # Smoke tests for H100/B200 - time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running + time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running assert_git_not_dirty } diff --git a/.gitignore b/.gitignore index 3b4323051073a..d1b3b17445dac 100644 --- a/.gitignore +++ b/.gitignore @@ -127,7 +127,6 @@ torch/test/ torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h torch/version.py -torch/_inductor/kernel/vendored_templates/* minifier_launcher.py aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d* aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d* diff --git a/setup.py b/setup.py index dd8a52cbeb7c7..31e78d0245d93 100644 --- a/setup.py +++ b/setup.py @@ -630,37 +630,6 @@ def mirror_files_into_torchgen() -> None: raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`") -def mirror_inductor_external_kernels() -> None: - """ - Copy external kernels into Inductor so they are importable. - """ - paths = [ - ( - CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py", - CWD - / "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py", - ), - ] - for new_path, orig_path in paths: - # Create the dirs involved in new_path if they don't exist - if not new_path.exists(): - new_path.parent.mkdir(parents=True, exist_ok=True) - - # Copy the files from the orig location to the new location - if orig_path.is_file(): - shutil.copyfile(orig_path, new_path) - continue - if orig_path.is_dir(): - if new_path.exists(): - # copytree fails if the tree exists already, so remove it. - shutil.rmtree(new_path) - shutil.copytree(orig_path, new_path) - continue - raise RuntimeError( - "Check the file paths in `mirror_inductor_external_kernels()`" - ) - - # ATTENTION: THIS IS AI SLOP def extract_variant_from_version(version: str) -> str: """Extract variant from version string, defaulting to 'cpu'.""" @@ -1647,8 +1616,6 @@ def main() -> None: if RUN_BUILD_DEPS: build_deps() - mirror_inductor_external_kernels() - ( ext_modules, cmdclass, @@ -1682,7 +1649,6 @@ def main() -> None: "_inductor/codegen/aoti_runtime/*.cpp", "_inductor/script.ld", "_inductor/kernel/flex/templates/*.jinja", - "_inductor/kernel/templates/*.jinja", "_export/serde/*.yaml", "_export/serde/*.thrift", "share/cmake/ATen/*.cmake", diff --git a/test/inductor/test_cutedsl_grouped_mm.py b/test/inductor/test_cutedsl_grouped_mm.py deleted file mode 100644 index c26def3a54099..0000000000000 --- a/test/inductor/test_cutedsl_grouped_mm.py +++ /dev/null @@ -1,154 +0,0 @@ -# Owner(s): ["module: inductor"] - - -import unittest - -import torch -from torch import Tensor -from torch._inductor import config -from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch -from torch._inductor.test_case import run_tests, TestCase as InductorTestCase -from torch._inductor.utils import ensure_cute_available -from torch.testing._internal.common_utils import ( - instantiate_parametrized_tests, - parametrize, -) - - -@unittest.skipIf( - not (ensure_cute_available() and is_datacenter_blackwell_arch()), - "CuTeDSL library or Blackwell device not available", -) -@instantiate_parametrized_tests -class TestCuTeDSLGroupedGemm(InductorTestCase): - def _get_inputs( - self, - group_size: int, - M_hint: int, - K: int, - N: int, - device: str, - dtype: torch.dtype, - alignment: int = 16, - ) -> tuple[Tensor, Tensor, Tensor]: - # --- Random, tile-aligned M sizes --- - M_sizes = ( - torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int) - * alignment - ) - - M_total = torch.sum(M_sizes).item() - - # --- Construct input tensors --- - A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1 - B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01 - - # --- Build offsets (no leading zero, strictly increasing) --- - offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device) - - return (A, B, offsets) - - @parametrize("group_size", (2, 8)) - @parametrize("M_hint", (256, 1024)) - @parametrize("K", (64, 128)) - @parametrize("N", (128, 256)) - def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int): - device = "cuda" - dtype = torch.bfloat16 - - A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype) - - def grouped_gemm_fn(A_packed, B_batched, offs): - return torch._grouped_mm(A_packed, B_batched, offs=offs) - - # Eager execution - c_eager = grouped_gemm_fn(A, B, offsets) - - # Test with Cute backend - with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "CUTEDSL", - "test_configs.autotune_choice_name_regex": "cutedsl", - "autotune_fallback_to_aten": False, - } - ): - grouped_gemm_compiled = torch.compile( - grouped_gemm_fn, backend="inductor", dynamic=False - ) - c_compiled = grouped_gemm_compiled(A, B, offsets) - - self.assertEqual(c_eager.dtype, dtype) - self.assertEqual(c_compiled.dtype, dtype) - torch.testing.assert_close(c_eager, c_compiled) - - @parametrize("layout_A", ("contiguous", "offset", "padded", "view")) - @parametrize("layout_B", ("contiguous", "broadcasted")) - def test_grouped_gemm_assorted_layouts( - self, - layout_A: str, - layout_B: str, - ): - device = "cuda" - dtype = torch.bfloat16 - - G, K, N = 8, 64, 128 - M_sizes = [128] * G - sum_M = sum(M_sizes) - offsets = torch.tensor( - [sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device - ) - - A_base = torch.randn(sum_M, K, device=device, dtype=dtype) - A = A_base - - if layout_A == "offset": - # allocate bigger buffer than needed, use nonzero storage offset - storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype) - offset = 128 # skip first 128 elements - A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1)) - elif layout_A == "padded": - # simulate row pitch > K (row_stride = K + pad) - row_pitch = K + 8 - storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype) - A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1)) - elif layout_A == "view": - A_storage = torch.randn(sum_M * K, device=device, dtype=dtype) - A = A_storage.view(sum_M, K) - assert A._base is not None - assert A.shape == (sum_M, K) - - B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01 - - if layout_B == "broadcasted": - # Broadcast B across groups (zero stride along G) - B = B[0].expand(G, K, N) - assert B.stride(0) == 0 - - def grouped_gemm_fn(A_packed, B_batched, offs): - return torch._grouped_mm(A_packed, B_batched, offs=offs) - - # --- eager --- - c_eager = grouped_gemm_fn(A, B, offsets) - - # --- compiled (CUTE backend) --- - with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "CUTEDSL", - "test_configs.autotune_choice_name_regex": "cutedsl", - "autotune_fallback_to_aten": False, - } - ): - grouped_gemm_compiled = torch.compile( - grouped_gemm_fn, backend="inductor", dynamic=False - ) - c_compiled = grouped_gemm_compiled(A, B, offsets) - - self.assertEqual(c_eager.dtype, dtype) - self.assertEqual(c_compiled.dtype, dtype) - torch.testing.assert_close(c_eager, c_compiled) - - -if __name__ == "__main__": - run_tests() diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 08cc2b2bd861a..457f86fe7a77e 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -546,10 +546,6 @@ def prologue_fusion_enabled() -> bool: "TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT" ).upper() # type: ignore[assignment] -cutedsl_enable_autotuning: bool = ( - os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1" -) - # DEPRECATED. This setting is ignored. autotune_fallback_to_aten = False diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index eb22b95af2afc..b95073e769f31 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -1,8 +1,6 @@ # mypy: allow-untyped-defs import logging from collections.abc import Sequence -from functools import partial -from pathlib import Path from typing import Any import torch @@ -14,7 +12,6 @@ from .. import config from ..codegen.wrapper import PythonWrapperCodegen from ..ir import _IntLike, Layout, TensorBox -from ..utils import load_template log = logging.getLogger(__name__) @@ -257,7 +254,3 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool: return False return True - - -_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates" -load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR) diff --git a/torch/_inductor/kernel/mm_grouped.py b/torch/_inductor/kernel/mm_grouped.py index 0a44b728a5a93..881c14fd43d0d 100644 --- a/torch/_inductor/kernel/mm_grouped.py +++ b/torch/_inductor/kernel/mm_grouped.py @@ -1,11 +1,10 @@ # mypy: allow-untyped-defs import logging -from dataclasses import asdict, dataclass +from dataclasses import dataclass from typing import Any, Optional import torch from torch._dynamo.utils import counters -from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate from torch._inductor.runtime.triton_compat import tl from torch._inductor.virtualized import V from torch.utils._triton import has_triton @@ -19,25 +18,19 @@ TritonTemplate, ) from ..utils import ( - ensure_cute_available, get_gpu_shared_memory, get_num_sms, has_free_symbols, use_aten_gemm_kernels, - use_blackwell_cutedsl_grouped_mm, use_triton_template, ) from .mm_common import ( _is_static_problem, check_supported_striding, - load_kernel_template, persistent_grouped_mm_grid, ) -if ensure_cute_available(): - from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs - log = logging.getLogger(__name__) aten = torch.ops.aten @@ -520,11 +513,6 @@ def do_mma(a, b, accumulator): source=triton_grouped_mm_source, ) -cutedsl_grouped_mm_template = CuteDSLTemplate( - name="grouped_gemm_cutedsl", - source=load_kernel_template("cutedsl_mm_grouped"), -) - def grouped_mm_args( mat1: TensorBox, @@ -726,44 +714,43 @@ def _tuned_grouped_mm_common( # Checking only for the equality of corresponding dims of # multiplicands here, relying on meta function checks for # everything else. - if len(m1_size) == 2: - if len(m2_size) == 2: - m, k1 = m1_size - k2, _ = m2_size - # pyrefly: ignore [missing-attribute] - g = offs.get_size()[0] - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = True, True - else: - # pyrefly: ignore [missing-attribute] - g1 = offs.layout.size[0] - m, k1 = m1_size - g2, k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = True, False - else: - if len(m2_size) == 2: - # pyrefly: ignore [missing-attribute] - g1 = offs.layout.size[0] - g2, m, k1 = m1_size - k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = False, True - else: - g1, m, k1 = m1_size - g2, k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = False, False - if ( is_nonzero and use_triton_template(layout) and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result) ): scaled = scale_a is not None + if len(m1_size) == 2: + if len(m2_size) == 2: + m, k1 = m1_size + k2, _ = m2_size + # pyrefly: ignore [missing-attribute] + g = offs.get_size()[0] + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, True + else: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, False + else: + if len(m2_size) == 2: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + g2, m, k1 = m1_size + k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, True + else: + g1, m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, False a_is_k_major = mat_a.get_stride()[-1] == 1 b_is_k_major = mat_b.get_stride()[-2] == 1 @@ -801,22 +788,6 @@ def _tuned_grouped_mm_common( **config.kwargs, ) - if use_blackwell_cutedsl_grouped_mm( - mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result - ): - for config in get_groupgemm_configs(): - kwargs = dict( - ACC_DTYPE="cutlass.Float32", - ) - - cutedsl_grouped_mm_template.maybe_append_choice( - choices, - input_nodes=input_nodes, - layout=layout, - **kwargs, - **asdict(config), - ) - input_gen_fns = { 4: lambda x: create_offsets( x, m1_size, m2_size, offs.get_size() if offs is not None else None diff --git a/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja b/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja deleted file mode 100644 index 989f297c5f80f..0000000000000 --- a/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja +++ /dev/null @@ -1,333 +0,0 @@ -import functools -from torch._inductor.runtime.runtime_utils import ceildiv -from cutlass.utils import TensorMapUpdateMode -{{gen_defines()}} -# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ---- -from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import ( - GroupedGemmKernel, -) - - -# Note about caching: -# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor -# maintains its own local caching system. At this stage, all compile-time -# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel -# name itself ({{kernel_name}}) are permanently baked into the file, so they -# do not need to be included in any cache key. -# -# The caching mechanism is split into two levels: -# -# 1. prep_cache -# Caches the compiled executor for build_group_ptrs_from_bases(). This -# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C, -# and can therefore be safely reused across runs with different group -# partitioning (`offs`). -# -# 2. gemm_cache -# Caches the compiled Grouped GEMM executor. Its key extends the prep -# cache key with hardware- and grid-specific parameters: -# (prep_cache_key, max_active_clusters, total_num_clusters). -# This is necessary because different `offs` tensors can change the -# per-group problem sizes and thus alter `total_num_clusters`, which in -# turn changes the grid shape and persistent scheduler configuration. -# Kernels compiled for one grid cannot be safely reused for another. -# -# -# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically, -# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead, -# despite depending only on the GPU type. We cache this function to mitigate -# redundant recompiles even when shape/stride/dtype cache misses force kernel -# regeneration. A follow-up study will investigate the root cause. - -prep_cache = {} -gemm_cache = {} - - -@functools.lru_cache -def get_hardware_info(): - hw = cutlass.utils.HardwareInfo() - sm_count = hw.get_max_active_clusters(1) - max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N) - - return (sm_count, max_active_clusters) - - -def get_prep_cache_key(input_a, input_b, output): - """ - Returns a tuple key for caching the preprocessing kernel executor based on kernel name, - shapes, strides, and dtypes of input/output tensors. - """ - return ( - tuple(input_a.shape), - tuple(input_a.stride()), - input_a.dtype, - tuple(input_b.shape), - tuple(input_b.stride()), - input_b.dtype, - tuple(output.shape), - tuple(output.stride()), - output.dtype, - ) - - -def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters): - """ - Returns a tuple key for caching the gemm kernel executor by extending the - prep cache key with hardware- and grid-specific parameters. - """ - return ( - prep_cache_key, - max_active_clusters, - total_num_clusters, - ) - - -@cute.kernel -def build_group_ptrs_from_bases_kernel( - base_A_u64: cutlass.Int64, # device addr of input_a (bytes) - base_B_u64: cutlass.Int64, # device addr of input_b (bytes) - base_C_u64: cutlass.Int64, # device addr of Output (bytes) - offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative - K: cutlass.Constexpr, - N: cutlass.Constexpr, - sizeof_element: cutlass.Int32, # bytes - # -------- STRIDES (in ELEMENTS) -------- - stride_A_m_elems: cutlass.Constexpr, # A.stride(0) - stride_A_k_elems: cutlass.Constexpr, # A.stride(1) - stride_B0_elems: cutlass.Constexpr, # B.stride(0) - stride_Bk_elems: cutlass.Constexpr, # B.stride(1) - stride_Bn_elems: cutlass.Constexpr, # B.stride(2) - stride_C_m_elems: cutlass.Constexpr, # C.stride(0) - stride_C_n_elems: cutlass.Constexpr, # C.stride(1) - # -------- OUTPUTS -------- - out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr) - out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1) - out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]] -): - tidx, _, _ = cute.arch.thread_idx() - g = tidx - - m_beg_i32 = 0 - if g > 0: - m_beg_i32 = offs[g - 1] - m_end_i32 = offs[g] - m_g_i32 = m_end_i32 - m_beg_i32 - - a_byte_off = ( - cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element) - ) - c_byte_off = ( - cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element) - ) - b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element) - - # ---- pointers ---- - out_ptrs[g, 0] = base_A_u64 + a_byte_off - out_ptrs[g, 1] = base_B_u64 + b_byte_off - out_ptrs[g, 2] = base_C_u64 + c_byte_off - - # ---- (m, n, k, 1) ---- - out_problem[g, 0] = m_g_i32 - out_problem[g, 1] = N - out_problem[g, 2] = K - out_problem[g, 3] = cutlass.Int32(1) - - # ---- strides ---- - out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems) - out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems) - out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems) - out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems) - out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems) - out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems) - - -@cute.jit -def launch_build_group_ptrs_from_bases( - base_A_u64: cutlass.Int64, - base_B_u64: cutlass.Int64, - base_C_u64: cutlass.Int64, - offs: cute.Tensor, - G: cutlass.Constexpr, - K: cutlass.Constexpr, - N: cutlass.Constexpr, - sizeof_element: cutlass.Constexpr, - stride_A_m_elems: cutlass.Constexpr, - stride_A_k_elems: cutlass.Constexpr, - stride_B0_elems: cutlass.Constexpr, - stride_Bk_elems: cutlass.Constexpr, - stride_Bn_elems: cutlass.Constexpr, - stride_C_m_elems: cutlass.Constexpr, - stride_C_n_elems: cutlass.Constexpr, - out_ptrs: cute.Tensor, # [G,3] cutlass.Int64 - out_problem: cute.Tensor, # [G,4] cutlass.Int32 - out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32 - stream: cuda.CUstream, -): - build_group_ptrs_from_bases_kernel( - base_A_u64, - base_B_u64, - base_C_u64, - offs, - K, - N, - sizeof_element, - stride_A_m_elems, - stride_A_k_elems, - stride_B0_elems, - stride_Bk_elems, - stride_Bn_elems, - stride_C_m_elems, - stride_C_n_elems, - out_ptrs, - out_problem, - out_strides_abc, - ).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream) - - -{{def_kernel("input_a", "input_b", "input_a_offs")}} - stream = cuda.CUstream(stream) - - input_b = input_b.transpose(1, 2) - - sumM, K = input_a.shape - G, N, Kb = input_b.shape - - dev = input_a.device - - base_A_u64 = int(input_a.data_ptr()) - base_B_u64 = int(input_b.data_ptr()) - base_C_u64 = int({{get_output()}}.data_ptr()) - - ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64) - probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32) - strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32) - ptrs = from_dlpack(ptrs_t) - probs = from_dlpack(probs_t) - strides = from_dlpack(strides_t) - - prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}}) - prep_executor = prep_cache.get(prep_cache_key) - - if prep_executor is None: - sizeof_element = int(input_a.element_size()) - sA_m, sA_k = map(int, input_a.stride()) - sB_0, sB_n, sB_k = map(int, input_b.stride()) - sC_m, sC_n = map(int, {{get_output()}}.stride()) - - prep_executor = cute.compile( - launch_build_group_ptrs_from_bases, - base_A_u64=base_A_u64, - base_B_u64=base_B_u64, - base_C_u64=base_C_u64, - offs=from_dlpack(input_a_offs), - G=int(G), - K=int(K), - N=int(N), - sizeof_element=sizeof_element, - stride_A_m_elems=sA_m, - stride_A_k_elems=sA_k, - stride_B0_elems=sB_0, - stride_Bk_elems=sB_k, - stride_Bn_elems=sB_n, - stride_C_m_elems=sC_m, - stride_C_n_elems=sC_n, - out_ptrs=ptrs, - out_problem=probs, - out_strides_abc=strides, - stream=stream, - ) - - prep_cache[prep_cache_key] = prep_executor - - prep_executor( - base_A_u64=base_A_u64, - base_B_u64=base_B_u64, - base_C_u64=base_C_u64, - offs=from_dlpack(input_a_offs), - out_ptrs=ptrs, - out_problem=probs, - out_strides_abc=strides, - stream=stream, - ) - - # --- Tensormap workspace per SM --- - num_tensormap_buffers, max_active_clusters = get_hardware_info() - tensormap_shape = ( - num_tensormap_buffers, - GroupedGemmKernel.num_tensormaps, - GroupedGemmKernel.bytes_per_tensormap // 8, - ) - tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64) - tensormap_workspace = from_dlpack(tensormap_workspace_t) - - # --- Total clusters --- - def compute_total_num_clusters( - problem_sizes_mnkl, - cluster_tile_shape_mn, - ): - total_num_clusters = 0 - for m, n, _, _ in problem_sizes_mnkl: - num_clusters_mn = tuple( - ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn) - ) - total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) - return total_num_clusters - - # Compute cluster tile shape - def compute_cluster_tile_shape( - mma_tiler_mn, - cluster_shape_mn, - use_2cta_instrs, - ): - cta_tile_shape_mn = list(mma_tiler_mn) - if use_2cta_instrs: - cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2 - return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn)) - - cluster_tile_shape_mn = compute_cluster_tile_shape( - (TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA) - ) - - total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn)) - - gemm_cache_key = get_gemm_cache_key( - prep_cache_key, max_active_clusters, total_num_clusters - ) - gemm_executor = gemm_cache.get(gemm_cache_key) - - if gemm_executor is None: - grouped_gemm = GroupedGemmKernel( - acc_dtype=ACC_DTYPE, - use_2cta_instrs=USE_2_CTA, - mma_tiler_mn=(TILE_M, TILE_N), - cluster_shape_mn=(CLUSTER_M, CLUSTER_N), - tensormap_update_mode=TENSORMAP_UPDATE_MODE, - ) - - gemm_executor = cute.compile( - grouped_gemm, - from_dlpack(input_a.unsqueeze(-1), assumed_align=16), - from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), - from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), - G, - probs, - strides, - ptrs, - total_num_clusters, - tensormap_workspace, - max_active_clusters, - stream, - ) - - gemm_cache[gemm_cache_key] = gemm_executor - - gemm_executor( - from_dlpack(input_a.unsqueeze(-1), assumed_align=16), - from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), - from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), - probs, - strides, - ptrs, - tensormap_workspace, - stream, - ) diff --git a/torch/_inductor/template_heuristics/cutedsl.py b/torch/_inductor/template_heuristics/cutedsl.py deleted file mode 100644 index db337b9d8a271..0000000000000 --- a/torch/_inductor/template_heuristics/cutedsl.py +++ /dev/null @@ -1,141 +0,0 @@ -from dataclasses import dataclass -from enum import auto, Enum -from itertools import product - -import torch._inductor.config as config - - -class TensorMapUpdateMode(Enum): - """Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency.""" - - SMEM = auto() - GMEM = auto() - - -@dataclass(frozen=True) -class CuTeGemmConfig: - TILE_M: int = 128 - TILE_N: int = 192 - CLUSTER_M: int = 2 - CLUSTER_N: int = 1 - USE_2_CTA: bool = False - TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM - - -def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]: - """ - Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. - For information regarding valid config sets, see: - https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py - """ - - # Tile_n is always the same regardless of 2cta - tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256] - - # Valid clusters - clusters_no_2cta = [ - (1, 1), - (1, 2), - (1, 4), - (1, 8), - (1, 16), - (2, 1), - (2, 2), - (2, 4), - (2, 8), - (4, 1), - (4, 2), - (4, 4), - (8, 1), - (8, 2), - (16, 1), - ] - clusters_2cta = [ - (2, 1), - (2, 2), - (2, 4), - (2, 8), - (4, 1), - (4, 2), - (4, 4), - (8, 1), - (8, 2), - (16, 1), - ] - - configs: list[CuTeGemmConfig] = [] - - for use_2cta, cluster_set, tile_m_range in [ - (False, clusters_no_2cta, [64, 128]), - (True, clusters_2cta, [128, 256]), - ]: - for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product( - [TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM], - tile_m_range, - tile_n_vals, - cluster_set, - ): - configs.append( - CuTeGemmConfig( - tile_m, - tile_n, - cluster_m, - cluster_n, - USE_2_CTA=use_2cta, - TENSORMAP_UPDATE_MODE=tensormap_update_mode, - ) - ) - - return configs - - -def get_default_groupgemm_configs() -> list[CuTeGemmConfig]: - """ - Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. - """ - - config_tuples = [ - (128, 256, 2, 1, False, TensorMapUpdateMode.SMEM), - (256, 160, 2, 1, True, TensorMapUpdateMode.GMEM), - (256, 256, 2, 1, True, TensorMapUpdateMode.GMEM), - (64, 32, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 256, 1, 2, False, TensorMapUpdateMode.SMEM), - (128, 256, 1, 2, False, TensorMapUpdateMode.SMEM), - (256, 256, 2, 2, True, TensorMapUpdateMode.GMEM), - (128, 256, 1, 2, False, TensorMapUpdateMode.GMEM), - (64, 32, 1, 1, False, TensorMapUpdateMode.SMEM), - (256, 256, 2, 1, True, TensorMapUpdateMode.SMEM), - (128, 256, 1, 1, False, TensorMapUpdateMode.GMEM), - (256, 256, 8, 1, True, TensorMapUpdateMode.GMEM), - (64, 32, 1, 2, False, TensorMapUpdateMode.SMEM), - (256, 192, 2, 1, True, TensorMapUpdateMode.GMEM), - (256, 256, 2, 2, True, TensorMapUpdateMode.SMEM), - (128, 96, 1, 2, False, TensorMapUpdateMode.SMEM), - (64, 192, 1, 1, False, TensorMapUpdateMode.SMEM), - (64, 64, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 192, 1, 1, False, TensorMapUpdateMode.GMEM), - (128, 64, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 160, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 256, 1, 1, False, TensorMapUpdateMode.GMEM), - ] - - return [CuTeGemmConfig(*args) for args in config_tuples] - - -def get_groupgemm_configs() -> list[CuTeGemmConfig]: - """ - Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. - - Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures - or unstable results. By default, autotuning is disabled and we return only - a single baseline config. - """ - if ( - config.cutedsl_enable_autotuning - and config.max_autotune_gemm_search_space == "EXHAUSTIVE" - ): - return get_exhaustive_groupgemm_configs() - elif config.cutedsl_enable_autotuning: - return get_default_groupgemm_configs() - else: - return [get_default_groupgemm_configs()[0]] diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 6b34ef28b2c10..2cf915d9e61de 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1975,77 +1975,6 @@ def use_triton_blackwell_tma_template( return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch() -@functools.lru_cache(maxsize=1) -def ensure_cute_available() -> bool: - """Check if CuTeDSL is importable; cache the result for reuse. - - Call ensure_cute_available.cache_clear() after installing CuTeDSL - in the same interpreter to retry the import. - """ - try: - return importlib.util.find_spec("cutlass.cute") is not None - except ImportError: - return False - - -def use_blackwell_cutedsl_grouped_mm( - mat_a: Any, - mat_b: Any, - layout: Layout, - a_is_2d: bool, - b_is_2d: bool, - offs: Optional[Any], - bias: Optional[Any], - scale_result: Optional[Any], -) -> bool: - """ - Returns True if we can use the blackwell kernel for grouped mm. - Required conditions: - 1. CuTeDSL is available - 2. We are on a blackwell arch - 3. The dtype is bf16 - 4. Max autotune or max autotune gemm is enabled - 6. A, B, and the output are 16B aligned - 7. We are not using dynamic shapes - 8. A is 2d - 9. B is 3d - 10. Offsets are provided - 11. Bias and Scale are not provided - """ - if not ensure_cute_available(): - return False - - from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch - - if not is_gpu(layout.device.type) and is_datacenter_blackwell_arch(): - return False - - layout_dtypes = [torch.bfloat16] - if not _use_template_for_gpu(layout, layout_dtypes): - return False - - if not (config.max_autotune or config.max_autotune_gemm): - return False - - # Checks for 16B ptr and stride alignment - if not can_use_tma(mat_a, mat_b, output_layout=layout): - return False - - if any(is_dynamic(x) for x in [mat_a, mat_b]): - return False - - if not a_is_2d or b_is_2d: - return False - - if offs is None: - return False - - if bias is not None or scale_result is not None: - return False - - return True - - def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: from .virtualized import V From 397d9fe2aea0dc60ea19ffddf6ac750420362867 Mon Sep 17 00:00:00 2001 From: Shunting Zhang Date: Tue, 4 Nov 2025 00:19:13 -0800 Subject: [PATCH 016/130] [inductor] coordesc not tune XBLOCK for mix-order-reduction (#166669) For mix-order reduction, we current force XBLOCK to be 1 to simplify codegen. Don't tune it in CDT. Differential Revision: [](https://our.internmc.facebook.com/intern/diff/) Differential Revision: [D86224689](https://our.internmc.facebook.com/intern/diff/D86224689) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166669 Approved by: https://github.com/jansel, https://github.com/mlazos, https://github.com/eellison, https://github.com/v0i0 --- test/inductor/test_mix_order_reduction.py | 16 ++++++++++++++++ .../runtime/coordinate_descent_tuner.py | 8 +++++++- torch/_inductor/runtime/triton_heuristics.py | 8 ++++++++ 3 files changed, 31 insertions(+), 1 deletion(-) diff --git a/test/inductor/test_mix_order_reduction.py b/test/inductor/test_mix_order_reduction.py index 230a2514b9171..0dcc37ee359d8 100644 --- a/test/inductor/test_mix_order_reduction.py +++ b/test/inductor/test_mix_order_reduction.py @@ -117,6 +117,22 @@ def outer_red(): metrics.codegen_mix_order_reduction, ) + @inductor_config.patch(coordinate_descent_tuning=True) + def test_XBLOCK_coordest_tuning(self): + """ + We should skip XBLOCK coordinate descent tuning for + mix order reduction. + """ + if not inductor_config.triton.mix_order_reduction: + self.skipTest("Mix order reduction not enabled") + + def f(x): + return x.sum(dim=-1), x.sum(dim=0) + + x = torch.randn(32768, 256, dtype=torch.float, device=GPU_TYPE) + self.check_numeric(f, (x,)) + self.assertEqual(metrics.codegen_mix_order_reduction, 1) + @inductor_config.patch(unroll_reductions_threshold=1) def test_3layer_split_reduction(self): """ diff --git a/torch/_inductor/runtime/coordinate_descent_tuner.py b/torch/_inductor/runtime/coordinate_descent_tuner.py index 341475ef1d6fb..7ea22bdcddf0b 100644 --- a/torch/_inductor/runtime/coordinate_descent_tuner.py +++ b/torch/_inductor/runtime/coordinate_descent_tuner.py @@ -5,6 +5,8 @@ from collections.abc import Callable from typing import TYPE_CHECKING +from torch.utils._ordered_set import OrderedSet + from .hints import TRITON_MAX_BLOCK from .runtime_utils import red_text, triton_config_to_hashable @@ -54,6 +56,7 @@ def __init__( name="unknown", size_hints=None, inductor_meta=None, + frozen_fields=None, ): self.is_mm = is_mm # we will tune num_stages for mm @@ -66,6 +69,9 @@ def __init__( self.name = name self.size_hints = size_hints self.inductor_meta = inductor_meta or {} + self.frozen_fields: OrderedSet[str] = ( + OrderedSet(frozen_fields) if frozen_fields is not None else OrderedSet() + ) def get_config_max(self, prefix: str) -> int: max_block = TRITON_MAX_BLOCK[prefix.upper()] @@ -117,7 +123,7 @@ def tunable_fields(self): out.append("num_stages") out.remove("ZBLOCK") # ZBLOCK=1 always in native matmul - return out + return [f for f in out if f not in self.frozen_fields] def value_too_large(self, name: str, val: int) -> bool: block_suffix = "BLOCK" diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index fe6788fb21e91..cb43d55bc86b3 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -336,6 +336,7 @@ def __init__( name=self.fn.__name__, size_hints=size_hints, inductor_meta=self.inductor_meta, + frozen_fields=self.get_coordesc_frozen_fields(), ) self.filename = filename @@ -365,6 +366,13 @@ def __init__( # Mode for launch grid calculation self.grid_mode: Literal["python", "cpp"] = "python" + def get_coordesc_frozen_fields(self) -> OrderedSet[str]: + out: OrderedSet[str] = OrderedSet() + if self.inductor_meta.get("RSPLIT_SIZE"): + # We fix XBLOCK for mix order reduction + out.add("XBLOCK") + return out + def is_statically_launchable(self): """ Checks if every compiled kernel is statically launchable, which From 3283eaa5ba901b518fe971e3a35434982034e061 Mon Sep 17 00:00:00 2001 From: Ivan Zaitsev Date: Tue, 4 Nov 2025 20:33:56 +0000 Subject: [PATCH 017/130] Upload test stats for trunk/sha tag (#166916) Noticed that workflow runs for `trunk/{sha}` tags (issued by autorevert) don't populate test_run_s3 Clickhouse table. This PR is addressing this by changing the gate condition to upload tests stats. see https://github.com/pytorch/pytorch/actions/runs/19054297956/job/54421254448#step:8:23 as an evidence that HEAD_BRANCH is correctly populated for trunk tags. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166916 Approved by: https://github.com/huydhn, https://github.com/clee2000 --- tools/stats/upload_test_stats.py | 17 ++++++++++++++++- tools/test/test_upload_gate.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) create mode 100644 tools/test/test_upload_gate.py diff --git a/tools/stats/upload_test_stats.py b/tools/stats/upload_test_stats.py index b5802e8032419..6c0232c5e5a17 100644 --- a/tools/stats/upload_test_stats.py +++ b/tools/stats/upload_test_stats.py @@ -2,6 +2,7 @@ import argparse import os +import re import sys import xml.etree.ElementTree as ET from multiprocessing import cpu_count, Pool @@ -19,6 +20,19 @@ ) +def should_upload_full_test_run(head_branch: str | None, head_repository: str) -> bool: + """Return True if we should upload the full test_run dataset. + + Rules: + - Only for the main repository (pytorch/pytorch) + - If head_branch is 'main', or a tag of form 'trunk/{40-hex-sha}' + """ + is_trunk_tag = bool(re.fullmatch(r"trunk/[0-9a-fA-F]{40}", (head_branch or ""))) + return head_repository == "pytorch/pytorch" and ( + head_branch == "main" or is_trunk_tag + ) + + def parse_xml_report( tag: str, report: Path, @@ -287,7 +301,8 @@ def init_value(test_case: dict[str, Any]) -> dict[str, Any]: remove_nan_inf(failed_tests_cases), ) - if args.head_branch == "main" and args.head_repository == "pytorch/pytorch": + # Upload full test_run only for trusted refs (main or trunk/{sha} tags) + if should_upload_full_test_run(args.head_branch, args.head_repository): # For jobs on main branch, upload everything. upload_workflow_stats_to_s3( args.workflow_run_id, diff --git a/tools/test/test_upload_gate.py b/tools/test/test_upload_gate.py new file mode 100644 index 0000000000000..7d9a2e5fe3b0b --- /dev/null +++ b/tools/test/test_upload_gate.py @@ -0,0 +1,28 @@ +import unittest + +from tools.stats.upload_test_stats import should_upload_full_test_run + + +class TestUploadGate(unittest.TestCase): + def test_main_branch_on_pytorch_repo(self) -> None: + self.assertTrue(should_upload_full_test_run("main", "pytorch/pytorch")) + + def test_trunk_tag_valid_sha_on_pytorch_repo(self) -> None: + sha = "a" * 40 + self.assertTrue(should_upload_full_test_run(f"trunk/{sha}", "pytorch/pytorch")) + + def test_trunk_tag_invalid_sha_on_pytorch_repo(self) -> None: + # Not 40 hex chars + self.assertFalse(should_upload_full_test_run("trunk/12345", "pytorch/pytorch")) + + def test_non_main_branch_on_pytorch_repo(self) -> None: + self.assertFalse( + should_upload_full_test_run("feature-branch", "pytorch/pytorch") + ) + + def test_main_branch_on_fork_repo(self) -> None: + self.assertFalse(should_upload_full_test_run("main", "someone/fork")) + + +if __name__ == "__main__": + unittest.main() From b4e4ee81d386db922d8f63359f9870eff1f44052 Mon Sep 17 00:00:00 2001 From: Andrey Talman Date: Tue, 4 Nov 2025 20:34:11 +0000 Subject: [PATCH 018/130] Update triton to 3.5.1 release (#166968) This includes sm103 https://github.com/triton-lang/triton/pull/8485 fix Pull Request resolved: https://github.com/pytorch/pytorch/pull/166968 Approved by: https://github.com/Lucaskabela, https://github.com/njriasan --- .ci/docker/ci_commit_pins/triton.txt | 2 +- .ci/docker/triton_version.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 10f1207e60e6c..7aab8bed1c108 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -7416ffcb92cdbe98d9f97e4e6f95247e46dfc9fd +bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7 diff --git a/.ci/docker/triton_version.txt b/.ci/docker/triton_version.txt index 1545d966571dc..d5c0c99142898 100644 --- a/.ci/docker/triton_version.txt +++ b/.ci/docker/triton_version.txt @@ -1 +1 @@ -3.5.0 +3.5.1 From 2bba37309bc8996fc6a190592e5ad9aac53761c9 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Tue, 4 Nov 2025 09:52:20 -0800 Subject: [PATCH 019/130] [inductor] runtime estimations disable use_nccl_estimator by default (#166973) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166973 Approved by: https://github.com/eellison, https://github.com/jathu --- torch/_inductor/comm_analysis.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/comm_analysis.py b/torch/_inductor/comm_analysis.py index afa569ff97da2..61af576772c16 100644 --- a/torch/_inductor/comm_analysis.py +++ b/torch/_inductor/comm_analysis.py @@ -359,7 +359,8 @@ def estimate_fx_collective_size(fx_node: torch.fx.Node) -> int: def estimate_nccl_collective_runtime_from_fx_node( fx_node: torch.fx.Node, override_size: Optional[int] = None, - use_nccl_estimator: bool = True, + # TODO(ivankobzarev): NCCL estimator sometimes fail unexpectedly, enable back after fix. + use_nccl_estimator: bool = False, ) -> float: """ Returns estimated NCCL collective runtime in nanoseconds (ns). From 871d0cd19651ce569fbc1b5dbc28195f8ae78315 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Tue, 4 Nov 2025 10:46:42 -0800 Subject: [PATCH 020/130] If USE_CUDA=1 is set, do not fallback to no CUDA (#166982) So many times i build pytorch only to notice chef nuked my nvcc and i wasted 30m building a cpu version, lets hard error fast Pull Request resolved: https://github.com/pytorch/pytorch/pull/166982 Approved by: https://github.com/malfet ghstack dependencies: #166976 --- CMakeLists.txt | 10 ++++++++++ cmake/public/cuda.cmake | 9 +++++++++ 2 files changed, 19 insertions(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2bbb8797b78cd..86f43f58817ba 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -234,7 +234,17 @@ option(USE_COLORIZE_OUTPUT "Colorize output during compilation" ON) option(USE_ASAN "Use Address+Undefined Sanitizers" OFF) option(USE_LSAN "Use Leak Sanitizer" OFF) option(USE_TSAN "Use Thread Sanitizer" OFF) + +# Track whether USE_CUDA was explicitly set by the user (before option() is called) +# If USE_CUDA is already defined in cache, it means user explicitly set it +if(DEFINED CACHE{USE_CUDA}) + set(_USE_CUDA_EXPLICITLY_SET TRUE) +else() + set(_USE_CUDA_EXPLICITLY_SET FALSE) +endif() + option(USE_CUDA "Use CUDA" ON) + option(USE_XPU "Use XPU" ON) cmake_dependent_option( BUILD_LAZY_CUDA_LINALG "Build cuda linalg ops as separate library" ON diff --git a/cmake/public/cuda.cmake b/cmake/public/cuda.cmake index 218c50a69c6fb..bc8855d23e61f 100644 --- a/cmake/public/cuda.cmake +++ b/cmake/public/cuda.cmake @@ -28,6 +28,15 @@ endif() # Find CUDA. find_package(CUDA) if(NOT CUDA_FOUND) + # If user explicitly set USE_CUDA=1, error out instead of falling back + if(_USE_CUDA_EXPLICITLY_SET AND USE_CUDA) + message(FATAL_ERROR + "PyTorch: CUDA was explicitly requested (USE_CUDA=1) but cannot be found. " + "Please check your CUDA installation, ensure CUDA toolkit is installed, " + "and that CUDA_HOME or CMAKE_CUDA_COMPILER is set correctly. " + "If you want to build without CUDA, please set USE_CUDA=0.") + endif() + message(WARNING "PyTorch: CUDA cannot be found. Depending on whether you are building " "PyTorch or a PyTorch dependent library, the next warning / error will " From 4e1bd1673855356402ffcee4d254129f7848f402 Mon Sep 17 00:00:00 2001 From: "Colin L. Rice" Date: Tue, 4 Nov 2025 10:05:48 -0800 Subject: [PATCH 021/130] inductor: Switch quiesce to use timer based implementation. (#166581) Major change is to switch to a timer based implementation. Additionally, we get rid of the context manager for turning of the compile pool. We still have the warmup calls. Note that this only modifies the async_compile methods, the fx pool is left running. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166581 Approved by: https://github.com/masnesral ghstack dependencies: #166467 --- torch/_dynamo/convert_frame.py | 2 -- .../_aot_autograd/runtime_wrappers.py | 3 -- torch/_inductor/async_compile.py | 31 ++----------------- 3 files changed, 3 insertions(+), 33 deletions(-) diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 875f640194e42..4439c7dc09efe 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -1285,7 +1285,6 @@ def _compile( # in the case of normal and exception code paths convert_frame_box: Optional[ConvertFrameBox] = None, ) -> ConvertFrameReturn: - from torch._inductor.async_compile import async_compile_pool_manager from torch.fx.experimental.validator import ( BisectValidationException, ValidationException, @@ -1479,7 +1478,6 @@ def count_args(code: CodeType) -> int: with ( _use_lazy_graph_module(config.use_lazy_graph_module), compile_context(CompileContext(compile_id)), - async_compile_pool_manager(), chromium_event_timed( "dynamo", reset_event_log_on_exit=True, log_pt2_compile_event=True ), diff --git a/torch/_functorch/_aot_autograd/runtime_wrappers.py b/torch/_functorch/_aot_autograd/runtime_wrappers.py index 4846f1ca74edb..86202e2cd319d 100644 --- a/torch/_functorch/_aot_autograd/runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/runtime_wrappers.py @@ -2365,8 +2365,6 @@ def backward(double_ctx, *args): @staticmethod def _backward_impl(ctx, all_args): - from torch._inductor.async_compile import async_compile_pool_manager - # compiled autograd reimplements this function at proxy_call_aot_backward assert not backward_state_indices, ( "BackwardState requires CompiledAutograd" @@ -2446,7 +2444,6 @@ def _backward_impl(ctx, all_args): with ( tracing(saved_context), compile_context(saved_compile_context), - async_compile_pool_manager(), context(), track_graph_compiling(aot_config, "backward"), metrics_context, diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index ac0d60bdebd71..a2c80002eb928 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -2,7 +2,6 @@ from __future__ import annotations import atexit -import contextlib import functools import json import logging @@ -230,18 +229,6 @@ def remove_future(kernel_src: str) -> None: del CompiledTritonKernels._cache[key] -@contextlib.contextmanager -def async_compile_pool_manager(): - """ - Context manager to quiesce the subproc pool at the end of compilation, i.e., - when dynamo is done. - """ - try: - yield - finally: - AsyncCompile.quiesce() - - class AsyncCompile: """ Utilities to compile in thread pools or subprocess pools (in the case of Triton). @@ -277,7 +264,9 @@ def process_pool() -> AnyPool: pool: AnyPool if config.worker_start_method == "subprocess": # Wrapper around ProcessPoolExecutor forks in a new process we control - pool = SubprocPool(get_compile_threads()) + pool = SubprocPool( + get_compile_threads(), quiesce=config.quiesce_async_compile_pool + ) else: if config.worker_start_method == "spawn": # Avoid creating pools in the spawned subprocs themselves: @@ -333,20 +322,6 @@ def use_process_pool(cls): cls._ready_future = cls.process_pool().submit(cls._get_ready) return cls._ready_future.done() - @classmethod - def quiesce(cls) -> None: - """ - If using a SubprocPool, signal the sidecar process to shut down its - ProcessPoolExecutor. - """ - # Don't inadvertently create a process pool if it doesn't already exist: - if not cls.process_pool.cache_info().currsize: - return - if config.quiesce_async_compile_pool: - pool = cls.process_pool() - if isinstance(pool, SubprocPool): - pool.quiesce() - @classmethod def wakeup(cls) -> None: """ From 2673f8b00705d9dd537f2bfcce6a5a1dbf4b2a31 Mon Sep 17 00:00:00 2001 From: Parshant Sharma Date: Tue, 4 Nov 2025 21:06:55 +0000 Subject: [PATCH 022/130] Fix torch.linalg.eig inductor stride mismatch (#162484) Fixes #159445 ### Summary - Fixed a stride layout issue in the `torch.linalg.eig` meta kernel that prevented successful compilation with the inductor backend. The meta kernel was producing incorrect row-major strides. - LAPACK/BLAS libraries (underlying implementation) expect column-major layout Pull Request resolved: https://github.com/pytorch/pytorch/pull/162484 Approved by: https://github.com/isuruf --- test/inductor/test_torchinductor.py | 16 ++++++++++++++++ .../test_torchinductor_codegen_dynamic_shapes.py | 3 +++ torch/_meta_registrations.py | 4 ++++ 3 files changed, 23 insertions(+) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index dad2de9bde327..ed8993a1c9a39 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -5876,6 +5876,22 @@ def fn(x, y): reference_in_float=False, ) + @skipIfMPS + def test_linalg_eig_stride_consistency(self): + def fn(x): + eigenvals, eigenvecs = torch.linalg.eig(x) + return eigenvecs + + x = torch.randn(5, 5, device=self.device, dtype=torch.float32) + + self.common( + fn, + [x], + exact_stride=True, + exact_dtype=True, + check_lowp=False, + ) + def test_view_as_complex(self): class Repro(torch.nn.Module): def __init__(self) -> None: diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index 2244af38f635a..e73f82ab64911 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -159,6 +159,9 @@ def run(*ex, **kwargs): # "test_complex_fallback_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_adaptive_avg_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), + "test_linalg_eig_stride_consistency_dynamic_shapes": TestFailure( + ("cpu", "cuda", "xpu") + ), "test_adaptive_max_pool2d2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_argmax_to_float_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), "test_avg_pool2d7_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index f84b77e630bf3..fe0492ff19c1c 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -1021,6 +1021,10 @@ def meta_linalg_eig(input: Tensor): ) values = input.new_empty(input.shape[:-1], dtype=complex_dtype) vectors = input.new_empty(input.shape, dtype=complex_dtype) + is_cuda = device_hint(input) == "cuda" + vectors.as_strided_( + input.shape, make_contiguous_strides_for(input.shape, row_major=is_cuda) + ) return values, vectors From 7f0e9321360cb13563a11bf9c720464c3dbf1ece Mon Sep 17 00:00:00 2001 From: William Wen Date: Tue, 4 Nov 2025 10:42:01 -0800 Subject: [PATCH 023/130] [dynamo] don't use LocalSource for temp variables created by side_effects (#166917) Fixes https://github.com/pytorch/pytorch/issues/166900 Implementation notes: - I tried to disallow guard generation before side effect application in order to futureproof improper guard generation. However, this was not feasible since it is possible to realize lazy VTs while generating side effects (e.g. realizing a constant variable that is used in a deque update). - `codegen_save_tempvars` now generates `TempLocalSource` for create temporary variables now, so that they won't get confused with `LocalSource` - we should error out when we attempt to create guards for `TempLocalSource`. I considered using `SyntheticLocalSource`, but that has additional `subguards_allowed` behavior that we may not want to have for temp variables. - We moved the guard installation for constant user-defined pytree objects from `as_python_constant` to `__init__`. Objects created outside the compile-region will be guarded, while objects created inside the compile-region will not be guarded. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166917 Approved by: https://github.com/anijain2305 --- test/dynamo/test_misc.py | 21 +++++++++++++++++++++ torch/_dynamo/side_effects.py | 8 ++++---- torch/_dynamo/source.py | 17 +++++++++++++++++ torch/_dynamo/variables/user_defined.py | 16 ++++++++++------ torch/_guards.py | 1 + 5 files changed, 53 insertions(+), 10 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 169f43ce0a077..b8727208a5bfa 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -13219,6 +13219,27 @@ def mapper(x): self.assertEqual(counter.frame_count, 1) self.assertEqual(counter.op_count, 9) + def test_pytree_register_constant_with_side_effect(self): + class Foo: + pass + + class Bar: + def __eq__(self, other): + return super().__eq__(other) + + def __hash__(self): + return 0 + + python_pytree.register_constant(Bar) + + @torch.compile(backend="eager", fullgraph=True) + def fn(x, obj): + obj.attr = {3: Bar()} + return x + 1 + + inp = torch.ones(3) + self.assertEqual(fn(inp, Foo()), inp + 1) + class TestTracer(JitTestCase): def test_jit_save(self): diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index bd38e9295a05a..688a05f26ae64 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -42,7 +42,7 @@ ) from .codegen import PyCodegen from .exc import SideEffectsError, unimplemented_v2 -from .source import GlobalSource, LocalCellSource, LocalSource, Source +from .source import GlobalSource, LocalCellSource, Source, TempLocalSource from .utils import is_frozen_dataclass, nn_module_new, object_new from .variables.base import ( AttributeMutation, @@ -704,7 +704,7 @@ def codegen_save_tempvars(self, cg: PyCodegen) -> None: ) cg.extend_output(create_call_function(0, False)) cg.add_cache(var) - var.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined] + var.source = TempLocalSource(cg.tempvars[var]) # type: ignore[attr-defined] elif var.source is None: # pyrefly: ignore [bad-assignment] var.source = LocalCellSource(var.local_name) @@ -729,7 +729,7 @@ def codegen_save_tempvars(self, cg: PyCodegen) -> None: # `add_cache` generates STORE and consumes TOS, but we never # cleared it. TODO move this call into `add_cache` cg.clear_tos() - var.source = LocalSource(cg.tempvars[var]) + var.source = TempLocalSource(cg.tempvars[var]) elif isinstance(var, variables.AutogradFunctionContextVariable): unimplemented_v2( gb_type="AutogradFunctionContextVariable escaped Dynamo-traced region", @@ -764,7 +764,7 @@ def load_new_method() -> None: cg.extend_output(create_call_function(1 + len(var.init_args), False)) # type: ignore[attr-defined] cg.add_cache(var) - var.source = LocalSource(cg.tempvars[var]) + var.source = TempLocalSource(cg.tempvars[var]) for ctx, args in self.save_for_backward: cg(ctx.source) diff --git a/torch/_dynamo/source.py b/torch/_dynamo/source.py index 8edd8f7540e31..5be6b8ccbf41d 100644 --- a/torch/_dynamo/source.py +++ b/torch/_dynamo/source.py @@ -151,6 +151,23 @@ def name(self) -> str: return f"L[{repr(self.local_name)}]" +@dataclasses.dataclass(frozen=True) +class TempLocalSource(Source): + # like LocalSource, but cannot be guarded on + local_name: str + + def reconstruct(self, codegen: "PyCodegen") -> None: + codegen.append_output(codegen.create_load(self.local_name)) + + def guard_source(self) -> GuardSource: + return GuardSource.TEMP_LOCAL + + def name(self) -> str: + raise NotImplementedError( + "Cannot create guard on TempLocalSource - this is an internal Dynamo bug. Please file an issue on GitHub." + ) + + @dataclasses.dataclass(frozen=True) class SyntheticLocalSource(Source): local_name: str diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 707ad7b3d9d18..085b5e0c648c5 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -968,6 +968,12 @@ def __init__( # rid of these workarounds here and in `GetAttrVariable`. self.attrs_directly_modifed_on_dict = set() + import torch.utils._pytree as pytree + + self.is_pytree_constant_class = pytree.is_constant_class(self.value_type) + if pytree.is_constant_class(self.value_type) and self.source: + install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH)) + def __str__(self) -> str: inner = self.value_type.__name__ if inner in [ @@ -989,12 +995,10 @@ def python_type(self): return self.value_type def as_python_constant(self): - import torch.utils._pytree as pytree - - if pytree.is_constant_class(self.value_type): - if self.source is not None: - install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH)) - return self.value + if self.is_pytree_constant_class and self.source: + # NOTE pytree constants created in the torch.compile region will + # NOT be guarded (even though they have a source set) + return self.value # TODO else try reconstructing the object by, e.g., leveraging side # effects and `as_python_constant`. return super().as_python_constant() diff --git a/torch/_guards.py b/torch/_guards.py index bac59965a3aef..b321c5f968b16 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -145,6 +145,7 @@ class GuardSource(enum.Enum): GLOBAL_UNSPECIALIZED_NN_MODULE = 13 LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 14 GLOBAL_UNSPECIALIZED_BUILTIN_NN_MODULE = 15 + TEMP_LOCAL = 16 def is_fsdp_module(self) -> bool: return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE) From ed45c5f38df6aa419c67d139d932c2c94404223a Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Tue, 4 Nov 2025 09:14:26 -0800 Subject: [PATCH 024/130] Avoid DDE in narrow with unbacked start (#166361) Slice knows how to handle unbacked start, we do not need to offset start before calling slice, we can leave it for slice. The only edge case is when start<0 and start+length ==0 in that case slice and narrow would deviate, for that case we shall pass dim_size instead of start+length Pull Request resolved: https://github.com/pytorch/pytorch/pull/166361 Approved by: https://github.com/aorenste --- aten/src/ATen/native/TensorShape.cpp | 38 +++++++++++++++--- c10/core/SymBool.cpp | 14 +++++++ c10/core/SymBool.h | 6 +++ test/export/test_export.py | 31 +++++++++----- test/test_dynamic_shapes.py | 51 ++++++++++++++++++++++++ test/test_torchfuzz_repros.py | 5 ++- torch/_inductor/codegen/wrapper.py | 3 +- torch/fx/experimental/symbolic_shapes.py | 19 ++++++++- torch/utils/_sympy/printers.py | 36 +++++++++++++++++ 9 files changed, 184 insertions(+), 19 deletions(-) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 6df7761d822db..6136a6aa8c520 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1,5 +1,6 @@ #include #include +#include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include @@ -1710,11 +1711,14 @@ Tensor narrow_symint( "], but got ", start, ")") - if (start < 0) { - start = start + cur_size; - } + // Bounds check without converting start: + // - If start < 0: need (start + cur_size) + length <= cur_size, i.e., start + + // length <= 0 + // - If start >= 0: need start + length <= cur_size + auto end = start + length; TORCH_SYM_CHECK( - start.sym_le(cur_size - length), + (start.sym_lt(0).sym_and((end).sym_le(0))) + .sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))), "start (", start, ") + length (", @@ -1722,7 +1726,31 @@ Tensor narrow_symint( ") exceeds dimension size (", cur_size, ")."); - return at::slice_symint(self, dim, start, start + length, 1); + + if (TORCH_GUARD_OR_FALSE(start.sym_ge(0).sym_or(end.sym_ne(0)))) { + return at::slice_symint(self, dim, start, end, 1); + } else if (TORCH_GUARD_OR_FALSE(start.sym_lt(0))) { + // Avoid the complex symbolic expressions path for non-unbacked. + return at::slice_symint(self, dim, start + cur_size, end + cur_size, 1); + } else { + // Cannot statically determine the condition due to unbacked. + // This is an interesting situation; when start is negative and + // start + length == 0, slice and narrow do different things. + // i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case, we want to + // pass curr_size instead of 0. Otherwise, they would do the same thing. + // This says at runtime: if start < 0 and end == 0, then pass curr_size + // instead of 0. + + auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt(); + auto result = + at::slice_symint(self, dim, start, end + use_different * cur_size, 1); + + // Ensure slice allocated unbacked size is specialized to length. + SymInt new_size = result.sym_size(dim); + TORCH_SYM_CHECK(new_size.sym_eq(length), "") + + return result; + } } // This overload exists purely for XLA, because they wanted to pass in diff --git a/c10/core/SymBool.cpp b/c10/core/SymBool.cpp index d804eb9d27409..48c407b8b069c 100644 --- a/c10/core/SymBool.cpp +++ b/c10/core/SymBool.cpp @@ -1,4 +1,5 @@ #include +#include #include namespace c10 { @@ -111,4 +112,17 @@ bool SymBool::has_hint() const { return toSymNodeImpl()->has_hint(); } +SymInt SymBool::toSymInt() const { + // If concrete bool, return concrete SymInt + if (auto ma = maybe_as_bool()) { + return SymInt(*ma ? 1 : 0); + } + + // Symbolic case: use sym_ite to convert bool to int (0 or 1) + auto node = toSymNodeImpl(); + auto one_node = node->wrap_int(1); + auto zero_node = node->wrap_int(0); + return SymInt(node->sym_ite(one_node, zero_node)); +} + } // namespace c10 diff --git a/c10/core/SymBool.h b/c10/core/SymBool.h index d5d509e239b1d..a27a28a5bf8a3 100644 --- a/c10/core/SymBool.h +++ b/c10/core/SymBool.h @@ -12,6 +12,8 @@ namespace c10 { +class SymInt; + class C10_API SymBool { public: /*implicit*/ SymBool(bool b) : data_(b) {} @@ -80,6 +82,10 @@ class C10_API SymBool { return toSymNodeImplUnowned()->constant_bool(); } + // Convert SymBool to SymInt (0 or 1) + // This is the C++ equivalent of Python's cast_symbool_to_symint_guardless + SymInt toSymInt() const; + bool is_heap_allocated() const { return ptr_; } diff --git a/test/export/test_export.py b/test/export/test_export.py index 3908f03b11e55..cdc18b1d4c564 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -6093,26 +6093,19 @@ def forward(self, x, y, fixes): retry_export( cf_implicitsize(), (torch.tensor(2), torch.randn(10)), - fixes=[ - # Could not guard on data-dependent expression u0 < 0 - "torch._check(i >= 0)", - ], + fixes=[], ) class cf_stacklist(torch.nn.Module): def forward(self, xs, y, fixes): i = y.item() eval(fixes) - # instead of xs[i] return torch.stack(xs, 0).narrow(0, i, 1).squeeze() retry_export( cf_stacklist(), ([torch.ones(5) * i for i in range(10)], torch.tensor(2)), - fixes=[ - # Could not guard on data-dependent expression u0 < 0 - "torch._check(i >= 0)", - ], + fixes=[], ) class cf_tensorsplit(torch.nn.Module): @@ -6166,7 +6159,12 @@ def test_no_suggested_fixes_for_data_dependent_errors(self): class cf_stacklist(torch.nn.Module): def forward(self, xs, y): # y.item() is not a local, so we can't suggest a fix - return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze() + if y.item() < 0: + return ( + torch.stack(xs, 0).narrow(0, y.item() + xs.size(), 1).squeeze() + ) + else: + return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze() with self.assertRaisesRegex( error_type, @@ -6196,7 +6194,18 @@ class cf_stacklist_udd(torch.nn.Module): def forward(self, xs, y): box = Box(y.item()) # box.content is not a local, so we can't suggest a fix - return torch.stack(xs, 0).narrow(0, box.content, 1).squeeze() + if box.content < 0: + return ( + torch.stack(xs, 0) + .narrow(0, box.content + xs.size(), 1) + .squeeze() + ) + else: + return ( + torch.stack(xs, 0) + .narrow(0, box.content + xs.size(), 1) + .squeeze() + ) with self.assertRaisesRegex( error_type, diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index fb1d22805d50a..b63e0427c26c3 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -4401,6 +4401,57 @@ def func(x, y): self.assertEqual(compiled(a, b), func(a, b)) + @fresh_cache() + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_narrow_unbacked_start(self): + def func(x, start, length): + # unbacked start + u0 = start.item() + return torch.narrow(x, 0, u0, length) + + compiled_func = torch.compile(func, fullgraph=True, backend="inductor") + + x = torch.tensor([1, 2, 3, 4, 5, 6]) + + # Test cases: (start, length) + test_cases = [ + # Negative starts + (-2, 2), # Start from second-to-last element + (-1, 1), # Start from last element + (-3, 3), # Start from third-to-last element + (-6, 2), # Start from beginning (negative) + (-4, 1), # Start from fourth-to-last element + # Positive starts + (0, 2), # Start from beginning + (1, 3), # Start from second element + (2, 2), # Start from third element + (4, 2), # Start near end + # Edge cases + (0, 6), # Full tensor + (0, 1), # Single element from start + (5, 1), # Single element from end + ] + + for start_val, length in test_cases: + with self.subTest(start=start_val, length=length): + start = torch.tensor([start_val]) + + # Test with compiled function + result_compiled = compiled_func(x, start, length) + + # Test with eager function (expected behavior) + result_eager = func(x, start, length) + + # Compare results + self.assertEqual(result_compiled, result_eager) + + @fresh_cache() + @torch._dynamo.config.patch("capture_scalar_outputs", True) + @torch._inductor.config.patch("cpp_wrapper", True) + def test_narrow_unbacked_start_cpp_wrapper(self): + """Test narrow with unbacked start with cpp_wrapper""" + self.test_narrow_unbacked_start() + instantiate_parametrized_tests(TestUnbacked) diff --git a/test/test_torchfuzz_repros.py b/test/test_torchfuzz_repros.py index 3b864aae4f477..84a00430420cf 100644 --- a/test/test_torchfuzz_repros.py +++ b/test/test_torchfuzz_repros.py @@ -16,6 +16,10 @@ from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON +# Skip all tests in this file if CUDA is not available +pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") + + class TestFuzzerCompileIssues(TestCase): """Test cases for fuzzer-discovered eager/compile divergence issues.""" @@ -257,7 +261,6 @@ def foo(arg0, arg1): out_compiled.sum().backward() print("Compile Success! ✅") - @pytest.mark.xfail(reason="Issue #163971") def test_fuzzer_issue_163971(self): torch.manual_seed(0) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index e629d9c7bdebd..947166cf216cd 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -2063,7 +2063,8 @@ def clamp_index(x): neg = self.codegen_sizevar( sympy.Max(0, sympy.Min(x + node.size, node.size)) ) - return f"{pos} if {x} >= 0 else {neg}" + x_cond = self.codegen_sizevar(x) + return f"{pos} if {x_cond} >= 0 else {neg}" def codegen_with_step(start_var, end_var, step): if step == 1: diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index aeccdfbe000db..693d25aea6130 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -547,6 +547,7 @@ def rebind_unbacked( assert shape_env is not None for raw_u0, path in bindings.items(): u1 = pytree.key_get(result, path) + # Sometimes, things were previously unbacked bindings become constants. # There are two situations this can happen. # @@ -602,7 +603,23 @@ def rebind_unbacked( if u1.node.hint is not None: continue - raw_u1 = u1.node.expr + # unbacked symbols bindings might be replaced to other backed or + # unbacked replacements. + # + # Example: + # u = x.item() + # torch._check(u == 5) + # + # The safest approach is to retrieve raw_u1 from u1.node._expr + # and perform the rebinding on the original unbacked symbol, + # even if it’s no longer directly referenced. + # + # In other words, we should always rebind the original symbol + # before any replacements are applied. + # u0 -> u0 == s1 + raw_u1 = u1.node._expr + + # TODO Do we still need this logic below? # Simplify SymBool binding if ( isinstance(raw_u1, sympy.Piecewise) diff --git a/torch/utils/_sympy/printers.py b/torch/utils/_sympy/printers.py index 526443577b3f8..915d0e5461f1e 100644 --- a/torch/utils/_sympy/printers.py +++ b/torch/utils/_sympy/printers.py @@ -306,6 +306,24 @@ def _print_RoundDecimal(self, expr: sympy.Expr) -> str: raise TypeError("ndigits must be an instance of sympy.Integer") return f"round({self._print(number)}, {ndigits})" + def _print_Piecewise(self, expr: sympy.Expr) -> str: + # Convert Piecewise(expr_cond_pairs) to nested ternary expressions + # Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) + # becomes: e1 if c1 else (e2 if c2 else (... else eN)) + result: Optional[str] = None + for expr_i, cond_i in reversed(expr.args): + expr_str = self._print(expr_i) + if cond_i == True: # noqa: E712 + # This is the default case + result = expr_str + else: + cond_str = self._print(cond_i) + if result is None: + result = expr_str + else: + result = f"({expr_str} if {cond_str} else {result})" + return result if result else "0" + class CppPrinter(ExprPrinter): def _print_Integer(self, expr: sympy.Expr) -> str: @@ -327,6 +345,24 @@ def _print_Where(self, expr: sympy.Expr) -> str: ) return f"{c} ? {p} : {q}" + def _print_Piecewise(self, expr: sympy.Expr) -> str: + # Convert Piecewise(expr_cond_pairs) to nested ternary operators + # Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) + # becomes: c1 ? e1 : (c2 ? e2 : (... : eN)) + result: Optional[str] = None + for expr_i, cond_i in reversed(expr.args): + expr_str = self.parenthesize(expr_i, PRECEDENCE["Atom"] - 0.5) + if cond_i == True: # noqa: E712 + # This is the default case + result = expr_str + else: + cond_str = self.parenthesize(cond_i, PRECEDENCE["Atom"] - 0.5) + if result is None: + result = expr_str + else: + result = f"{cond_str} ? {expr_str} : {result}" + return f"({result})" if result else "0" + def _print_ModularIndexing(self, expr: sympy.Expr) -> str: x, div, mod = expr.args x = self.doprint(x) From cdca63db8c0f30a6fcc181784411bbd8913aa1db Mon Sep 17 00:00:00 2001 From: Alexander Grund Date: Tue, 4 Nov 2025 21:28:14 +0000 Subject: [PATCH 025/130] Fix quoting in pytest_cache.py invocations (#166955) Especially the job identifier can contain spaces so needs to be quoted Fixes e.g. https://github.com/pytorch/pytorch/actions/runs/19063797853/job/54449422160#step:15:52 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166955 Approved by: https://github.com/Skylion007 --- .github/actions/pytest-cache-download/action.yml | 12 ++++++------ .github/actions/pytest-cache-upload/action.yml | 16 ++++++++-------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/.github/actions/pytest-cache-download/action.yml b/.github/actions/pytest-cache-download/action.yml index 1406f962c4ca8..3f51f6a5525bc 100644 --- a/.github/actions/pytest-cache-download/action.yml +++ b/.github/actions/pytest-cache-download/action.yml @@ -38,9 +38,9 @@ runs: run: | python3 .github/scripts/pytest_cache.py \ --download \ - --cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \ - --pr_identifier $GITHUB_REF \ - --job_identifier $JOB_IDENTIFIER \ - --temp_dir $RUNNER_TEMP \ - --repo $REPO \ - --bucket $BUCKET \ + --cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \ + --pr_identifier "$GITHUB_REF" \ + --job_identifier "$JOB_IDENTIFIER" \ + --temp_dir "$RUNNER_TEMP" \ + --repo "$REPO" \ + --bucket "$BUCKET" \ diff --git a/.github/actions/pytest-cache-upload/action.yml b/.github/actions/pytest-cache-upload/action.yml index 2652d019075f7..9fbb63a760f27 100644 --- a/.github/actions/pytest-cache-upload/action.yml +++ b/.github/actions/pytest-cache-upload/action.yml @@ -47,11 +47,11 @@ runs: run: | python3 .github/scripts/pytest_cache.py \ --upload \ - --cache_dir $GITHUB_WORKSPACE/$CACHE_DIR \ - --pr_identifier $GITHUB_REF \ - --job_identifier $JOB_IDENTIFIER \ - --sha $SHA \ - --test_config $TEST_CONFIG \ - --shard $SHARD \ - --repo $REPO \ - --temp_dir $RUNNER_TEMP \ + --cache_dir "$GITHUB_WORKSPACE/$CACHE_DIR" \ + --pr_identifier "$GITHUB_REF" \ + --job_identifier "$JOB_IDENTIFIER" \ + --sha "$SHA" \ + --test_config "$TEST_CONFIG" \ + --shard "$SHARD" \ + --repo "$REPO" \ + --temp_dir "$RUNNER_TEMP" \ From a64c7d740428010d700b4bcd395af8a7b2d5c21f Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Tue, 4 Nov 2025 21:30:43 +0000 Subject: [PATCH 026/130] [DebugMode] output, tensor id annotations for DebugMode (#165076) Adds optional "node" id for tensors, output info annotations to DebugMode, with `DebugMode(record_output=True, record_ids=True)` Example output for `test_debug_mode_mm`, with both enabled: ``` torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) -> dt$12: f32[8, 32]| S(0) aten::mm(dt$2: f32[8, 8]| S(0), dt$3: f32[8, 32]| S(0)) redistribute_input(1, S(0) -> R) redistribute_input(t$4: f32[1, 32], trace: S(0)->R) _c10d_functional::all_gather_into_tensor(t$5: f32[1, 32], 8, 0) -> t$6: f32[8, 32] _c10d_functional::wait_tensor(t$7: f32[8, 32]) -> t$8: f32[8, 32] aten::mm(t$9: f32[1, 8], t$10: f32[8, 32]) -> t$11: f32[1, 32] (dt$13: f32[8, 32]| S(0)) -> dt$17: f32[]| P aten::sum(dt$14: f32[8, 32]| S(0)) aten::sum(t$15: f32[1, 32]) -> t$16: f32[]""" ``` Sadly the only way to get DTensor op outputs is to set `record_torchfunction=True`, as dispatch calls just defer to DTensor's dispatch logic. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165076 Approved by: https://github.com/zpcore --- .../tensor/debug/test_debug_mode.py | 22 +-- torch/utils/_debug_mode.py | 126 +++++++++++++++--- 2 files changed, 117 insertions(+), 31 deletions(-) diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py index 07442f34c8946..9acfcb15804e5 100644 --- a/test/distributed/tensor/debug/test_debug_mode.py +++ b/test/distributed/tensor/debug/test_debug_mode.py @@ -42,22 +42,24 @@ def test_debug_mode_mm(self): x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False) y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False) - with DebugMode(record_torchfunction=True) as debug_mode: + with DebugMode( + record_torchfunction=True, record_ids=True, record_output=True + ) as debug_mode: torch.mm(x_dtensor, y_dtensor).sum() self.assertExpectedInline( debug_mode.debug_string(), """\ - torch.mm(dt: f32[8, 8]| S(0), dt: f32[8, 32]| S(0)) - aten::mm(dt: f32[8, 8]| S(0), dt: f32[8, 32]| S(0)) + torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) -> dt$6: f32[8, 32]| S(0) + aten::mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) redistribute_input(1, S(0) -> R) - redistribute_input(t: f32[1, 32], trace: S(0)->R) - _c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0) - _c10d_functional::wait_tensor(t: f32[8, 32]) - aten::mm(t: f32[1, 8], t: f32[8, 32]) - (dt: f32[8, 32]| S(0)) - aten::sum(dt: f32[8, 32]| S(0)) - aten::sum(t: f32[1, 32])""", + redistribute_input(t$2: f32[1, 32], trace: S(0)->R) + _c10d_functional::all_gather_into_tensor(t$2: f32[1, 32], 8, 0) -> t$3: f32[8, 32] + _c10d_functional::wait_tensor(t$3: f32[8, 32]) -> t$3: f32[8, 32] + aten::mm(t$4: f32[1, 8], t$3: f32[8, 32]) -> t$5: f32[1, 32] + (dt$6: f32[8, 32]| S(0)) -> dt$8: f32[]| P + aten::sum(dt$6: f32[8, 32]| S(0)) + aten::sum(t$5: f32[1, 32]) -> t$7: f32[]""", ) self.assertTrue(isinstance(debug_mode.operators[0], _OpCall)) diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 09435aa07e68b..5e24ce086e1aa 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -2,6 +2,7 @@ import contextlib import functools import traceback +import weakref from typing import Any, Callable, Optional, TYPE_CHECKING import torch @@ -14,6 +15,7 @@ ) from torch.utils._pytree import tree_all, tree_map from torch.utils._traceback import CapturedTraceback +from torch.utils.weak import WeakIdRef if TYPE_CHECKING: @@ -56,29 +58,48 @@ def _stringify_dtensor_spec(spec) -> str: return DTensorSpec.format_shard_order_str(spec.placements, spec.shard_order) -def _tensor_debug_string(tensor, attributes) -> str: +class TensorIdTracker: + def __init__(self): + self.tensor_memo: dict[WeakIdRef, int] = {} + self.next_tensor_id = 0 + + def _id(self, tensor) -> int: + with torch._C._DisablePythonDispatcher(): + o = WeakIdRef(tensor) + + def del_memo(): + self.tensor_memo.pop(o, None) + + weakref.finalize(tensor, del_memo) + if o not in self.tensor_memo: + self.tensor_memo[o] = self.next_tensor_id + self.next_tensor_id += 1 + return self.tensor_memo[o] + + +def _tensor_debug_string(tensor, attributes, tensor_memo=None) -> str: """Convert tensor to debug string representation.""" if isinstance(tensor, torch.Tensor): tensor_debug_str = f"{dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}{_stringify_attributes(tensor, attributes)}" - + id_str = f"${tensor_memo._id(tensor)}" if tensor_memo is not None else "" if isinstance(tensor, torch.distributed.tensor.DTensor): # omitted device mesh - return f"dt: {tensor_debug_str}| {_stringify_dtensor_spec(tensor._spec)}" + return f"dt{id_str}: {tensor_debug_str}| {_stringify_dtensor_spec(tensor._spec)}" elif isinstance(tensor, FakeTensor): - return f"ft: {tensor_debug_str}" + return f"ft{id_str}: {tensor_debug_str}" else: - return f"t: {tensor_debug_str}" + return f"t{id_str}: {tensor_debug_str}" else: raise RuntimeError(f"Unsupported tensor type: {type(tensor)}") -def _arg_to_str(arg, attributes) -> str: +def _arg_to_str(arg, attributes, tensor_memo=None) -> str: from torch.distributed.tensor._dtensor_spec import DTensorSpec def to_str(x): if isinstance(x, torch.Tensor): - return _tensor_debug_string(x, attributes) + return _tensor_debug_string(x, attributes, tensor_memo) elif isinstance(x, DTensorSpec): return _stringify_dtensor_spec(x) return x @@ -144,8 +165,11 @@ def __init__( # results from dispatch hooks self.record = record self.log = log + self.output_str: Optional[str] = None - def stringify_args(self, attributes: list[str]) -> None: + def stringify_args( + self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + ) -> None: """ To reduce memory consumption, this method stringifies args/kwargs, stores the result, and deletes original args/kwargs. """ @@ -153,6 +177,18 @@ def stringify_args(self, attributes: list[str]) -> None: "Subclasses must implement stringify_args(), even if no-op" ) + def stringify_output( + self, + output: Any, + attributes: list[str], + tensor_memo: Optional[TensorIdTracker] = None, + ) -> None: + """Store stringified version of call output in self.output_str""" + if tree_all(lambda x: x is None, output): + return + output_str = tree_map(lambda x: _arg_to_str(x, attributes, tensor_memo), output) + self.output_str = f" -> {str(output_str)}" + def render(self, attributes: list[str]) -> str: raise NotImplementedError("Subclasses must implement string render()") @@ -179,11 +215,16 @@ def __init__( self.args_str: Optional[str] = None self.kwargs_str: Optional[str] = None - def stringify_args(self, attributes: list[str]) -> None: - self.args_str = ", ".join(_arg_to_str(arg, attributes) for arg in self.args) + def stringify_args( + self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + ) -> None: + self.args_str = ", ".join( + _arg_to_str(arg, attributes, tensor_memo) for arg in self.args + ) if self.kwargs: self.kwargs_str = ", " + ", ".join( - f"{k}={_arg_to_str(v, attributes)}" for k, v in self.kwargs.items() + f"{k}={_arg_to_str(v, attributes, tensor_memo)}" + for k, v in self.kwargs.items() ) else: self.kwargs_str = "" @@ -215,6 +256,8 @@ def render(self, attributes: list[str]) -> str: base_str = f"{op_name}({args_str}{kwargs_str})" + if self.output_str: + base_str += self.output_str if self.log: base_str += f" # {self.log}" return base_str @@ -247,8 +290,10 @@ def __init__( self.arg_str: Optional[str] = None - def stringify_args(self, attributes: list[str]) -> None: - self.arg_str = f"{_arg_to_str(self.arg, attributes)}" + def stringify_args( + self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + ) -> None: + self.arg_str = f"{_arg_to_str(self.arg, attributes, tensor_memo)}" del self.arg def render(self, attributes: list[str]) -> str: @@ -263,7 +308,11 @@ def render(self, attributes: list[str]) -> str: src_placement_str = _arg_to_str(self.src_placement, attributes) dst_placement_str = _arg_to_str(self.dst_placement, attributes) placement_str = f"{src_placement_str} -> {dst_placement_str}" - return f"{REDISTRIBUTE_FUNC}({arg_str}, {placement_str})" + + base_str = f"{REDISTRIBUTE_FUNC}({arg_str}, {placement_str})" + if self.output_str: + base_str += self.output_str + return base_str def __iter__(self): # for BC; tuple(self) returns (op, placement info, kwargs, call_depth) @@ -288,7 +337,9 @@ def __init__(self, module_name: str, call_depth: int, stack: bool = False): super().__init__(call_depth, stack=stack) self.module_name = module_name - def stringify_args(self, attributes: list[str]) -> None: + def stringify_args( + self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + ) -> None: pass # nothing to stringify def render(self, attributes: list[str]) -> str: @@ -341,6 +392,8 @@ def __init__( record_nn_module=False, store_original_args=False, record_stack_trace=False, + record_output=False, + record_ids=False, ): super().__init__() import torch.distributed.tensor # noqa: F401 @@ -378,8 +431,24 @@ def __init__( # e.g. via DebugMode(record_stack_trace=True), or torch.autograd.set_detect_anomaly(). self.record_stack_trace = record_stack_trace + # Records call outputs in logs (e.g. for __torch_dispatch__, __torch_function__, redistribute_input) + self.record_output: bool = record_output + + # Annotates string dumps with graph-style tensor ids, e.g. op($1, $2) -> $3. + self.record_ids: bool = record_ids + + self.reset() + + def reset(self): self.operators = [] self.call_depth = 0 + self._tensor_memo = TensorIdTracker() + self._output_info: dict[int, object] = {} + + def _track_op_output(self, op_index, result): + """Assign IDs to output tensors and store in output_info""" + # self._track_tensor_ids(result) + self._output_info[op_index] = result # Without this override, running torch.compile under DebugMode # will force torch.compile to always use the “eager” backend @@ -390,20 +459,35 @@ def ignore_compile_internals(cls): def _record_call(self, call): if not self.store_original_args: - call.stringify_args(self.record_tensor_attributes) + call.stringify_args( + self.record_tensor_attributes, + self._tensor_memo if self.record_ids else None, + ) self.operators.append(call) + def _record_call_output(self, call, output): + if not self.record_output: + return + call.stringify_output( + output, + self.record_tensor_attributes, + self._tensor_memo if self.record_ids else None, + ) + def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - self._record_call( - _OpCall(func, args, kwargs, self.call_depth, stack=self.record_stack_trace) + call = _OpCall( + func, args, kwargs, self.call_depth, stack=self.record_stack_trace ) + self._record_call(call) try: self.call_depth += 1 - return func(*args, **kwargs) + result = func(*args, **kwargs) + self._record_call_output(call, result) + return result finally: self.call_depth -= 1 @@ -445,13 +529,13 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): result = func(*args, **kwargs) if call: + self._record_call_output(call, result) _run_dispatch_hooks(call, func, types, args, kwargs, result) return result def __enter__(self): - self.operators = [] - self.call_depth = 0 + self.reset() if self.record_torchfunction: torch._C._push_on_torch_function_stack(self) From e8052f2f99de1fb7284e38082ff5714e17cd9562 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Tue, 4 Nov 2025 11:20:38 -0800 Subject: [PATCH 027/130] Add model code stack trace to torch.profile (#166677) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ```python python test/test_fx.py -k profiler ``` Insert `torch._C._profiler._RecordFunctionFast` to fx graph codegen. We post-process the profiler dump using `map_recorded_events_to_aten_ops_with_stack_trace` to add the stack trace to the dump'd trace. `map_recorded_events_to_aten_ops_with_stack_trace` queries `fx.traceback._FX_METADATA_REGISTRY` for node metadata. Each graph module has a hash'd fake file name (e.g. `fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py`), which is the key to the registry. One can do `fx_g.enrich_profiler_metadata()` to add debugging info. Or `fx_g.enrich_profiler_metadata(enable=False)` to remove. `aot_eager` makes calls `fx_g.enrich_profiler_metadata()` if TORCH_ENRICH_RPOFILER_STACK_TRACE is set or _dynamo.config.enrich_profiler_metadata=True. Screenshot 2025-10-31 at 4 40 52 PM Example code gen'd. ``` def forward(self, args_list): args_iter = iter(args_list) arg0_1 = next(args_iter) arg1_1 = next(args_iter) args_list.clear() _rf = torch._C._profiler._RecordFunctionFast('## fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py ##'); _rf.__enter__() repeated_subgraph0 = self.repeated_subgraph0 _rf_invoke_subgraph = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_invoke_subgraph.__enter__() invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = arg0_1 = arg1_1 = None _rf_invoke_subgraph.__exit__(None, None, None) _rf_getitem = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_getitem.__enter__() getitem = invoke_subgraph[0]; invoke_subgraph = None _rf_getitem.__exit__(None, None, None) return (getitem,) _rf.__exit__(None, None, None) def forward(self, arg0_1, arg1_1): _rf = torch._C._profiler._RecordFunctionFast('## fx_generated__ozpadpj5cxoalxeyopej33g2vvtvhxg4xsk7bhx7ldmcibtybyn.py ##'); _rf.__enter__() _rf_mul = torch._C._profiler._RecordFunctionFast('## 2 ##'); _rf_mul.__enter__() mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None _rf_mul.__exit__(None, None, None) _rf_sin = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_sin.__enter__() sin = torch.ops.aten.sin.default(mul); mul = None _rf_sin.__exit__(None, None, None) _rf_add = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_add.__enter__() add = torch.ops.aten.add.Tensor(sin, 5); sin = None _rf_add.__exit__(None, None, None) return (add,) _rf.__exit__(None, None, None) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166677 Approved by: https://github.com/ezyang ghstack dependencies: #166676 --- ...t-fx_backcompat_function_signatures.expect | 2 +- test/test_fx.py | 180 ++++++++++++++++++ torch/autograd/profiler_util.py | 40 ++++ torch/fx/graph.py | 23 +++ torch/fx/graph_module.py | 16 +- torch/profiler/_utils.py | 169 +++++++++++++++- 6 files changed, 425 insertions(+), 5 deletions(-) diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect index a404e15a977ee..12f6ba2228db8 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect @@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None) torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node torch.fx.graph.Graph.print_tabular(self) -torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode +torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False, record_func: bool = False) -> torch.fx.graph.PythonCode torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule') torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None diff --git a/test/test_fx.py b/test/test_fx.py index d6f33d426aee7..c16c42805b921 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -75,6 +75,12 @@ ) from torch.testing._internal.jit_utils import JitTestCase +import json +import tempfile +from torch.profiler import profile, ProfilerActivity +from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace +from torch.autograd.profiler_util import _canonicalize_profiler_events + try: from torchvision import models as torchvision_models @@ -201,6 +207,36 @@ def side_effect_func(x: torch.Tensor): print(x) +def _enrich_profiler_traces(prof): + """ + Helper function to extract and augment profiler events with stack traces. + + Args: + prof: A torch.profiler.profile object + + Returns: + A string representing enriched events + """ + with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f: + trace_file = f.name + prof.export_chrome_trace(trace_file) + + with open(trace_file) as f: + trace_data = json.load(f) + + map_recorded_events_to_aten_ops_with_stack_trace( + trace_data + ) + + events = [] + for event in trace_data["traceEvents"]: + if "args" in event and "stack_trace" in event["args"]: + events.append(event) + + actual_traces = _canonicalize_profiler_events(events) + return actual_traces + + class TestFX(JitTestCase): def setUp(self): super().setUp() @@ -4187,6 +4223,150 @@ def fn(a, b, c, d): # recorver mutable checking flag torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_profiler_stack_trace_augmentation(self): + """ + Test that map_recorded_events_to_aten_ops_with_stack_trace correctly + augments profiler events with stack traces from FX metadata registry. + """ + + # Simple test model + class TestModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(10, 16) + self.relu = torch.nn.ReLU() + self.linear2 = torch.nn.Linear(16, 10) + + def forward(self, x): + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x + + model = TestModel().cuda() + + # Compile the model + compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True) + + # Warmup + for _ in range(3): + _ = compiled_model(torch.randn(10, 10, device="cuda")) + + # Profile with the compiled model + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) as prof: + result = compiled_model(torch.randn(10, 10, device="cuda")) + + actual_traces = _enrich_profiler_traces(prof) + + self.assertExpectedInline(actual_traces, """\ +event=aten::t node=t stack_trace=x = self.linear1(x) +event=aten::transpose node=t stack_trace=x = self.linear1(x) +event=aten::as_strided node=t stack_trace=x = self.linear1(x) +event=aten::addmm node=addmm stack_trace=x = self.linear1(x) +event=cudaLaunchKernel node=addmm stack_trace=x = self.linear1(x) +event=aten::relu node=relu stack_trace=x = self.relu(x) +event=aten::clamp_min node=relu stack_trace=x = self.relu(x) +event=cudaLaunchKernel node=relu stack_trace=x = self.relu(x) +event=aten::t node=t_1 stack_trace=x = self.linear2(x) +event=aten::transpose node=t_1 stack_trace=x = self.linear2(x) +event=aten::as_strided node=t_1 stack_trace=x = self.linear2(x) +event=aten::addmm node=addmm_1 stack_trace=x = self.linear2(x) +event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)""" + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_profiler_multiple_modules(self): + """ + Test that multiple compiled modules under the same profiler session + have their events correctly augmented with stack traces. + """ + + class ModelA(torch.nn.Module): + def forward(self, x): + return x + 1 + + class ModelB(torch.nn.Module): + def forward(self, x): + return x - 1 + + model_a = ModelA().cuda() + model_b = ModelB().cuda() + + # Compile both models + compiled_a = torch.compile(model_a, backend="aot_eager", fullgraph=True) + compiled_b = torch.compile(model_b, backend="aot_eager", fullgraph=True) + + # Warmup + for _ in range(3): + _ = compiled_a(torch.randn(10, 10, device="cuda")) + _ = compiled_b(torch.randn(1, 3, 8, 8, device="cuda")) + + # Profile both models in the same session + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) as prof: + result_a = compiled_a(torch.randn(10, 10, device="cuda")) + result_b = compiled_b(torch.randn(1, 3, 8, 8, device="cuda")) + + actual_traces = _enrich_profiler_traces(prof) + self.assertExpectedInline(actual_traces, """\ +event=aten::add node=add stack_trace=return x + 1 +event=cudaLaunchKernel node=add stack_trace=return x + 1 +event=aten::sub node=sub stack_trace=return x - 1 +event=cudaLaunchKernel node=sub stack_trace=return x - 1""" + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_profiler_nested_graph_modules(self): + """ + Test that nested graph modules (e.g., graph modules calling subgraphs) + have their events correctly augmented with stack traces. + """ + + # Model with nested structure + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.c = 5 + + @torch.compiler.nested_compile_region + def forward(self, x, y): + m = torch.mul(x, y) + s = m.sin() + a = s + self.c + return a + + model = Mod().cuda() + + # Compile the model (this may create nested graph modules) + compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True) + + # Warmup + for _ in range(3): + _ = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) + + # Profile + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) as prof: + result = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) + + actual_traces = _enrich_profiler_traces(prof) + self.assertExpectedInline(actual_traces, """\ +event=aten::mul node=mul stack_trace=m = torch.mul(x, y) +event=cudaLaunchKernel node=mul stack_trace=m = torch.mul(x, y) +event=aten::sin node=sin stack_trace=s = m.sin() +event=cudaLaunchKernel node=sin stack_trace=s = m.sin() +event=aten::add node=add stack_trace=a = s + self.c +event=cudaLaunchKernel node=add stack_trace=a = s + self.c""" + ) + def run_getitem_target(): from torch.fx._symbolic_trace import _wrapped_methods_to_patch diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index b2d6530049e61..4b8a6d221b4e0 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -1224,3 +1224,43 @@ def override_time_unit(time_us, default_str, time_unit): f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}" ) return "".join(result) + + +# Collect all events with stack traces and format them canonically +def _canonicalize_profiler_events(events): + """ + Extract and format all events with stack traces in a canonical way + for deterministic testing. + """ + events_with_traces = [] + + for event in events: + # Extract relevant fields + event_name = event.get("name", "") + node_name = event["args"].get("node_name", "") + stack_trace = event["args"].get("stack_trace", "") + + # Get the last non-empty line of the stack trace + lines = [s.strip() for s in stack_trace.split("\n") if s.strip()] + stack_trace = lines[-1] if lines else "" + + events_with_traces.append( + { + "event_name": event_name[:20], + "node_name": node_name, + "stack_trace": stack_trace, + "start_time": event.get("ts", 0), + } + ) + + # Sort by node_name for deterministic ordering + events_with_traces.sort(key=lambda x: x["start_time"]) + + # Format as a string + lines = [] + for evt in events_with_traces: + lines.append( + f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}" + ) + + return "\n".join(lines) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 697b2f4084ca5..fd6835d2b301b 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -443,6 +443,7 @@ def _gen_python_code( colored: bool = False, # Render each argument on its own line expanded_def: bool = False, + record_func: bool = False, ) -> PythonCode: free_vars: list[str] = [] body: list[str] = [] @@ -798,6 +799,10 @@ def _tensor_annotation(t: torch.Tensor) -> str: return raise NotImplementedError(f"node: {node.op} {node.target}") + if record_func: + body.append( + "_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()\n" + ) for i, node in enumerate(nodes): # NOTE: emit_node does not emit a string with newline. It depends # on delete_unused_values to append one @@ -807,8 +812,22 @@ def _tensor_annotation(t: torch.Tensor) -> str: # node index, which will be deleted later # after going through _body_transformer body.append(f"# COUNTER: {i}\n") + do_record = record_func and node.op in ( + "call_function", + "call_method", + "call_module", + ) + if do_record: + # The double hash ## convention is used by post-processing to find the fx markers + body.append( + f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {i} ##'); _rf_{node.name}.__enter__()\n" + ) emit_node(node) delete_unused_values(node) + if do_record: + body.append(f"_rf_{node.name}.__exit__(None, None, None)\n") + if record_func: + body.append("_rf.__exit__(None, None, None)\n") if len(body) == 0: # If the Graph has no non-placeholder nodes, no lines for the body @@ -1760,6 +1779,7 @@ def python_code( include_device: bool = False, colored: bool = False, expanded_def: bool = False, + record_func: bool = False, ) -> PythonCode: """ Turn this ``Graph`` into valid Python code. @@ -1827,6 +1847,7 @@ def override_node_repr(graph: Graph): include_device=include_device, colored=colored, expanded_def=expanded_def, + record_func=record_func, ) def _python_code( @@ -1839,6 +1860,7 @@ def _python_code( include_device: bool = False, colored: bool = False, expanded_def: bool = False, + record_func: bool = False, ) -> PythonCode: return self._codegen._gen_python_code( self.nodes, @@ -1849,6 +1871,7 @@ def _python_code( include_device=include_device, colored=colored, expanded_def=expanded_def, + record_func=record_func, ) def __str__(self) -> str: diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 297f76732584f..8360c96630d6c 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -861,14 +861,18 @@ def recompile(self) -> PythonCode: if isinstance(self._graph._codegen, _PyTreeCodeGen): self._in_spec = self._graph._codegen.pytree_info.in_spec self._out_spec = self._graph._codegen.pytree_info.out_spec - python_code = self._graph.python_code(root_module="self") + + from torch._dynamo import config as dynamo_config + + python_code = self._graph.python_code( + root_module="self", record_func=dynamo_config.enrich_profiler_metadata + ) self._code = python_code.src self._lineno_map = python_code._lineno_map self._prologue_start = python_code._prologue_start cls = type(self) co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} - from torch._dynamo import config as dynamo_config if dynamo_config.enrich_profiler_metadata: # Generate metadata and register for profiler augmentation @@ -885,7 +889,6 @@ def recompile(self) -> PythonCode: # This ensures the same code+metadata always generates the same filename hash_value = _metadata_hash(self._code, node_metadata) file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}" - filename = f"{file_stem}.py" # Only include co_filename to use it directly as the cache key @@ -905,6 +908,13 @@ def recompile(self) -> PythonCode: _register_fx_metadata(filename, metadata) + # Replace the placeholder in generated code with actual filename + # The double hash ## convention is used by post-processing to find the fx markers + self._code = self._code.replace( + "torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')", + f"torch._C._profiler._RecordFunctionFast('## {filename} ##')", + ) + cls.forward = _forward_from_src(self._code, python_code.globals, co_fields) # Determine whether this class explicitly defines a __call__ implementation diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index 2c6e06b2cb3c9..47df87ce1678d 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -4,7 +4,7 @@ import re from collections import deque from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import Any, Literal, Optional, TYPE_CHECKING from torch.autograd.profiler import profile from torch.profiler import DeviceType @@ -400,3 +400,170 @@ def _init_for_cuda_graphs() -> None: with profile(): pass + + +@dataclass +class TimelineEvent: + """Represents an event in the profiler timeline.""" + + timestamp: int + event_type: Literal["start", "end", "regular"] + marker_type: Optional[Literal["filename", "node"]] + identifier: Optional[str | int] + event: dict[str, Any] + + +@dataclass +class ContextStackEntry: + """Represents a context (filename or node) in the stack.""" + + context_type: Literal["filename", "node"] + identifier: str | int + metadata: Optional[dict] + tid: Optional[int] = None # Thread ID associated with this context + + +def map_recorded_events_to_aten_ops_with_stack_trace(traced_data): + """ + Maps recorded profiler events to their corresponding fx nodes and adds stack traces. + + Builds a timeline of all events (regular ops and FX markers for filenames/nodes), + sorts by timestamp, then processes chronologically while maintaining a context stack of active + filename/node scopes. Regular events are augmented with stack traces and node names from the + innermost active context. Runtime is O(n log n) for n events. + + Args: + traced_data: Json of profiler events from Chrome trace + + Returns: + Dict mapping recorded event names to their aten operations with added stack traces + """ + from torch.fx.traceback import _FX_METADATA_REGISTRY + + trace_events = traced_data.get("traceEvents", []) + + # Create event timeline + event_timeline: list[TimelineEvent] = [] + + def is_fx_marker_event(event): + return ( + event.get("cat") == "cpu_op" + and event.get("name", "").startswith("## ") + and event.get("name", "").endswith(" ##") + ) + + def append_fx_marker_event(event_type, identifier, event): + start_ts = event["ts"] + end_ts = start_ts + event["dur"] + event_timeline.append( + TimelineEvent(start_ts, "start", event_type, identifier, event) + ) + event_timeline.append( + TimelineEvent(end_ts, "end", event_type, identifier, event) + ) + + for event in trace_events: + if "ts" not in event or "dur" not in event: + continue + + if is_fx_marker_event(event): + content = event["name"][3:-3] + + if content.endswith(".py"): + append_fx_marker_event("filename", content, event) + else: + try: + node_index = int(content) + except ValueError: + pass + append_fx_marker_event("node", node_index, event) # type: ignore[possibly-undefined] + + else: + # Regular event that needs augmentation + start_ts = event["ts"] + event_timeline.append(TimelineEvent(start_ts, "regular", None, None, event)) + + # Sort by timestamp + event_timeline.sort(key=lambda x: x.timestamp) + + # Process events in chronological order with a stack + context_stack: list[ContextStackEntry] = [] + + # Invariant: all start event has a corresponding end event + for timeline_event in event_timeline: + match timeline_event.event_type: + case "start": + assert timeline_event.identifier is not None + + if timeline_event.marker_type == "filename": + assert isinstance(timeline_event.identifier, str) + # Push filename context - query metadata registry on-demand + metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier) + tid = timeline_event.event.get("tid") + context_stack.append( + ContextStackEntry( + "filename", timeline_event.identifier, metadata, tid + ) + ) + elif timeline_event.marker_type == "node": + # Find the current filename from stack + current_file_metadata = None + tid = timeline_event.event.get("tid") + for ctx_entry in reversed(context_stack): + if ( + ctx_entry.context_type == "filename" + and ctx_entry.tid == tid + ): + current_file_metadata = ctx_entry.metadata + break + + if current_file_metadata: + node_metadata = current_file_metadata.get("node_metadata", {}) + if timeline_event.identifier in node_metadata: + node_meta: Optional[dict] = node_metadata[ + timeline_event.identifier + ] + context_stack.append( + ContextStackEntry( + "node", timeline_event.identifier, node_meta, tid + ) + ) + + case "end": + # Pop from stack - search backwards to find matching context + for i in range(len(context_stack) - 1, -1, -1): + ctx_entry = context_stack[i] + if ( + timeline_event.marker_type == ctx_entry.context_type + and timeline_event.identifier == ctx_entry.identifier + ): + context_stack.pop(i) + break + + case "regular": + # Apply metadata from current context stack + # Find the most specific context (node takes precedence over filename) + # Only augment events with the same tid as the file/node event matched + current_stack_trace = None + current_node_name = None + event_tid = timeline_event.event.get("tid") + + for ctx_entry in reversed(context_stack): + # Only apply metadata from contexts with matching tid + if ctx_entry.tid == event_tid: + if ctx_entry.context_type == "node" and ctx_entry.metadata: + current_stack_trace = ctx_entry.metadata.get( + "stack_trace", "No model stack trace available" + ) + current_node_name = ctx_entry.metadata.get("name", "") + # Do we want to only attach the stack trace of the lowest node or stack trace of all nodes + # if nodes are nested, e.g. in nested graph modules + break + + # Augment the event + if current_stack_trace or current_node_name: + args = timeline_event.event.setdefault("args", {}) + if current_stack_trace: + args["stack_trace"] = current_stack_trace + if current_node_name: + args["node_name"] = current_node_name From e020fb3431371ea335a0d5db5094810c9f1e104d Mon Sep 17 00:00:00 2001 From: Boyuan Feng Date: Tue, 4 Nov 2025 22:09:24 +0000 Subject: [PATCH 028/130] [Minor][Inductor] move some combo kernel log from warning to debug (#166993) Combo kernel warns for long reduction and large pointwise. This becomes too spammy for users such as vLLM. This PR moves these logs from warn to debug. I validated the spammy log is removed on llama-3.1-8B. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166993 Approved by: https://github.com/zou3519, https://github.com/eellison --- torch/_inductor/codegen/triton_combo_kernel.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index e86753348c6b1..3e58e95ef9e9c 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -98,7 +98,7 @@ def _default_custom_combo_kernel_horizontal_partition( ] short_reduction = [n for n in reduction if n not in long_reduction] if long_reduction: - log.warning( + log.debug( "ComboKernels: %d long reduction nodes are separated", len(long_reduction), ) @@ -112,7 +112,7 @@ def _default_custom_combo_kernel_horizontal_partition( ] if large_pointwise: # TODO benchmark the performance when large pointwise nodes combining with others - log.warning( + log.debug( "ComboKernels: %d large pointwise nodes are separated", len(large_pointwise), ) From 81038fd3268074a43d0b8fc4de9cf22a6d71a896 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 4 Nov 2025 22:26:35 +0000 Subject: [PATCH 029/130] Revert "Add model code stack trace to torch.profile (#166677)" This reverts commit e8052f2f99de1fb7284e38082ff5714e17cd9562. Reverted https://github.com/pytorch/pytorch/pull/166677 on behalf of https://github.com/malfet due to Broke lint, please rebase, we've moved from mypy to pyrefly ([comment](https://github.com/pytorch/pytorch/pull/166677#issuecomment-3488219996)) --- ...t-fx_backcompat_function_signatures.expect | 2 +- test/test_fx.py | 180 ------------------ torch/autograd/profiler_util.py | 40 ---- torch/fx/graph.py | 23 --- torch/fx/graph_module.py | 16 +- torch/profiler/_utils.py | 169 +--------------- 6 files changed, 5 insertions(+), 425 deletions(-) diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect index 12f6ba2228db8..a404e15a977ee 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect @@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None) torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node torch.fx.graph.Graph.print_tabular(self) -torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False, record_func: bool = False) -> torch.fx.graph.PythonCode +torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule') torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None diff --git a/test/test_fx.py b/test/test_fx.py index c16c42805b921..d6f33d426aee7 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -75,12 +75,6 @@ ) from torch.testing._internal.jit_utils import JitTestCase -import json -import tempfile -from torch.profiler import profile, ProfilerActivity -from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace -from torch.autograd.profiler_util import _canonicalize_profiler_events - try: from torchvision import models as torchvision_models @@ -207,36 +201,6 @@ def side_effect_func(x: torch.Tensor): print(x) -def _enrich_profiler_traces(prof): - """ - Helper function to extract and augment profiler events with stack traces. - - Args: - prof: A torch.profiler.profile object - - Returns: - A string representing enriched events - """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f: - trace_file = f.name - prof.export_chrome_trace(trace_file) - - with open(trace_file) as f: - trace_data = json.load(f) - - map_recorded_events_to_aten_ops_with_stack_trace( - trace_data - ) - - events = [] - for event in trace_data["traceEvents"]: - if "args" in event and "stack_trace" in event["args"]: - events.append(event) - - actual_traces = _canonicalize_profiler_events(events) - return actual_traces - - class TestFX(JitTestCase): def setUp(self): super().setUp() @@ -4223,150 +4187,6 @@ def fn(a, b, c, d): # recorver mutable checking flag torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - @torch._dynamo.config.patch("enrich_profiler_metadata", True) - def test_profiler_stack_trace_augmentation(self): - """ - Test that map_recorded_events_to_aten_ops_with_stack_trace correctly - augments profiler events with stack traces from FX metadata registry. - """ - - # Simple test model - class TestModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear1 = torch.nn.Linear(10, 16) - self.relu = torch.nn.ReLU() - self.linear2 = torch.nn.Linear(16, 10) - - def forward(self, x): - x = self.linear1(x) - x = self.relu(x) - x = self.linear2(x) - return x - - model = TestModel().cuda() - - # Compile the model - compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True) - - # Warmup - for _ in range(3): - _ = compiled_model(torch.randn(10, 10, device="cuda")) - - # Profile with the compiled model - with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - ) as prof: - result = compiled_model(torch.randn(10, 10, device="cuda")) - - actual_traces = _enrich_profiler_traces(prof) - - self.assertExpectedInline(actual_traces, """\ -event=aten::t node=t stack_trace=x = self.linear1(x) -event=aten::transpose node=t stack_trace=x = self.linear1(x) -event=aten::as_strided node=t stack_trace=x = self.linear1(x) -event=aten::addmm node=addmm stack_trace=x = self.linear1(x) -event=cudaLaunchKernel node=addmm stack_trace=x = self.linear1(x) -event=aten::relu node=relu stack_trace=x = self.relu(x) -event=aten::clamp_min node=relu stack_trace=x = self.relu(x) -event=cudaLaunchKernel node=relu stack_trace=x = self.relu(x) -event=aten::t node=t_1 stack_trace=x = self.linear2(x) -event=aten::transpose node=t_1 stack_trace=x = self.linear2(x) -event=aten::as_strided node=t_1 stack_trace=x = self.linear2(x) -event=aten::addmm node=addmm_1 stack_trace=x = self.linear2(x) -event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)""" - ) - - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - @torch._dynamo.config.patch("enrich_profiler_metadata", True) - def test_profiler_multiple_modules(self): - """ - Test that multiple compiled modules under the same profiler session - have their events correctly augmented with stack traces. - """ - - class ModelA(torch.nn.Module): - def forward(self, x): - return x + 1 - - class ModelB(torch.nn.Module): - def forward(self, x): - return x - 1 - - model_a = ModelA().cuda() - model_b = ModelB().cuda() - - # Compile both models - compiled_a = torch.compile(model_a, backend="aot_eager", fullgraph=True) - compiled_b = torch.compile(model_b, backend="aot_eager", fullgraph=True) - - # Warmup - for _ in range(3): - _ = compiled_a(torch.randn(10, 10, device="cuda")) - _ = compiled_b(torch.randn(1, 3, 8, 8, device="cuda")) - - # Profile both models in the same session - with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - ) as prof: - result_a = compiled_a(torch.randn(10, 10, device="cuda")) - result_b = compiled_b(torch.randn(1, 3, 8, 8, device="cuda")) - - actual_traces = _enrich_profiler_traces(prof) - self.assertExpectedInline(actual_traces, """\ -event=aten::add node=add stack_trace=return x + 1 -event=cudaLaunchKernel node=add stack_trace=return x + 1 -event=aten::sub node=sub stack_trace=return x - 1 -event=cudaLaunchKernel node=sub stack_trace=return x - 1""" - ) - - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - @torch._dynamo.config.patch("enrich_profiler_metadata", True) - def test_profiler_nested_graph_modules(self): - """ - Test that nested graph modules (e.g., graph modules calling subgraphs) - have their events correctly augmented with stack traces. - """ - - # Model with nested structure - class Mod(torch.nn.Module): - def __init__(self): - super().__init__() - self.c = 5 - - @torch.compiler.nested_compile_region - def forward(self, x, y): - m = torch.mul(x, y) - s = m.sin() - a = s + self.c - return a - - model = Mod().cuda() - - # Compile the model (this may create nested graph modules) - compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True) - - # Warmup - for _ in range(3): - _ = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) - - # Profile - with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - ) as prof: - result = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) - - actual_traces = _enrich_profiler_traces(prof) - self.assertExpectedInline(actual_traces, """\ -event=aten::mul node=mul stack_trace=m = torch.mul(x, y) -event=cudaLaunchKernel node=mul stack_trace=m = torch.mul(x, y) -event=aten::sin node=sin stack_trace=s = m.sin() -event=cudaLaunchKernel node=sin stack_trace=s = m.sin() -event=aten::add node=add stack_trace=a = s + self.c -event=cudaLaunchKernel node=add stack_trace=a = s + self.c""" - ) - def run_getitem_target(): from torch.fx._symbolic_trace import _wrapped_methods_to_patch diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index 4b8a6d221b4e0..b2d6530049e61 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -1224,43 +1224,3 @@ def override_time_unit(time_us, default_str, time_unit): f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}" ) return "".join(result) - - -# Collect all events with stack traces and format them canonically -def _canonicalize_profiler_events(events): - """ - Extract and format all events with stack traces in a canonical way - for deterministic testing. - """ - events_with_traces = [] - - for event in events: - # Extract relevant fields - event_name = event.get("name", "") - node_name = event["args"].get("node_name", "") - stack_trace = event["args"].get("stack_trace", "") - - # Get the last non-empty line of the stack trace - lines = [s.strip() for s in stack_trace.split("\n") if s.strip()] - stack_trace = lines[-1] if lines else "" - - events_with_traces.append( - { - "event_name": event_name[:20], - "node_name": node_name, - "stack_trace": stack_trace, - "start_time": event.get("ts", 0), - } - ) - - # Sort by node_name for deterministic ordering - events_with_traces.sort(key=lambda x: x["start_time"]) - - # Format as a string - lines = [] - for evt in events_with_traces: - lines.append( - f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}" - ) - - return "\n".join(lines) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index fd6835d2b301b..697b2f4084ca5 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -443,7 +443,6 @@ def _gen_python_code( colored: bool = False, # Render each argument on its own line expanded_def: bool = False, - record_func: bool = False, ) -> PythonCode: free_vars: list[str] = [] body: list[str] = [] @@ -799,10 +798,6 @@ def _tensor_annotation(t: torch.Tensor) -> str: return raise NotImplementedError(f"node: {node.op} {node.target}") - if record_func: - body.append( - "_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()\n" - ) for i, node in enumerate(nodes): # NOTE: emit_node does not emit a string with newline. It depends # on delete_unused_values to append one @@ -812,22 +807,8 @@ def _tensor_annotation(t: torch.Tensor) -> str: # node index, which will be deleted later # after going through _body_transformer body.append(f"# COUNTER: {i}\n") - do_record = record_func and node.op in ( - "call_function", - "call_method", - "call_module", - ) - if do_record: - # The double hash ## convention is used by post-processing to find the fx markers - body.append( - f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {i} ##'); _rf_{node.name}.__enter__()\n" - ) emit_node(node) delete_unused_values(node) - if do_record: - body.append(f"_rf_{node.name}.__exit__(None, None, None)\n") - if record_func: - body.append("_rf.__exit__(None, None, None)\n") if len(body) == 0: # If the Graph has no non-placeholder nodes, no lines for the body @@ -1779,7 +1760,6 @@ def python_code( include_device: bool = False, colored: bool = False, expanded_def: bool = False, - record_func: bool = False, ) -> PythonCode: """ Turn this ``Graph`` into valid Python code. @@ -1847,7 +1827,6 @@ def override_node_repr(graph: Graph): include_device=include_device, colored=colored, expanded_def=expanded_def, - record_func=record_func, ) def _python_code( @@ -1860,7 +1839,6 @@ def _python_code( include_device: bool = False, colored: bool = False, expanded_def: bool = False, - record_func: bool = False, ) -> PythonCode: return self._codegen._gen_python_code( self.nodes, @@ -1871,7 +1849,6 @@ def _python_code( include_device=include_device, colored=colored, expanded_def=expanded_def, - record_func=record_func, ) def __str__(self) -> str: diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 8360c96630d6c..297f76732584f 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -861,18 +861,14 @@ def recompile(self) -> PythonCode: if isinstance(self._graph._codegen, _PyTreeCodeGen): self._in_spec = self._graph._codegen.pytree_info.in_spec self._out_spec = self._graph._codegen.pytree_info.out_spec - - from torch._dynamo import config as dynamo_config - - python_code = self._graph.python_code( - root_module="self", record_func=dynamo_config.enrich_profiler_metadata - ) + python_code = self._graph.python_code(root_module="self") self._code = python_code.src self._lineno_map = python_code._lineno_map self._prologue_start = python_code._prologue_start cls = type(self) co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} + from torch._dynamo import config as dynamo_config if dynamo_config.enrich_profiler_metadata: # Generate metadata and register for profiler augmentation @@ -889,6 +885,7 @@ def recompile(self) -> PythonCode: # This ensures the same code+metadata always generates the same filename hash_value = _metadata_hash(self._code, node_metadata) file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}" + filename = f"{file_stem}.py" # Only include co_filename to use it directly as the cache key @@ -908,13 +905,6 @@ def recompile(self) -> PythonCode: _register_fx_metadata(filename, metadata) - # Replace the placeholder in generated code with actual filename - # The double hash ## convention is used by post-processing to find the fx markers - self._code = self._code.replace( - "torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')", - f"torch._C._profiler._RecordFunctionFast('## {filename} ##')", - ) - cls.forward = _forward_from_src(self._code, python_code.globals, co_fields) # Determine whether this class explicitly defines a __call__ implementation diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index 47df87ce1678d..2c6e06b2cb3c9 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -4,7 +4,7 @@ import re from collections import deque from dataclasses import dataclass -from typing import Any, Literal, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from torch.autograd.profiler import profile from torch.profiler import DeviceType @@ -400,170 +400,3 @@ def _init_for_cuda_graphs() -> None: with profile(): pass - - -@dataclass -class TimelineEvent: - """Represents an event in the profiler timeline.""" - - timestamp: int - event_type: Literal["start", "end", "regular"] - marker_type: Optional[Literal["filename", "node"]] - identifier: Optional[str | int] - event: dict[str, Any] - - -@dataclass -class ContextStackEntry: - """Represents a context (filename or node) in the stack.""" - - context_type: Literal["filename", "node"] - identifier: str | int - metadata: Optional[dict] - tid: Optional[int] = None # Thread ID associated with this context - - -def map_recorded_events_to_aten_ops_with_stack_trace(traced_data): - """ - Maps recorded profiler events to their corresponding fx nodes and adds stack traces. - - Builds a timeline of all events (regular ops and FX markers for filenames/nodes), - sorts by timestamp, then processes chronologically while maintaining a context stack of active - filename/node scopes. Regular events are augmented with stack traces and node names from the - innermost active context. Runtime is O(n log n) for n events. - - Args: - traced_data: Json of profiler events from Chrome trace - - Returns: - Dict mapping recorded event names to their aten operations with added stack traces - """ - from torch.fx.traceback import _FX_METADATA_REGISTRY - - trace_events = traced_data.get("traceEvents", []) - - # Create event timeline - event_timeline: list[TimelineEvent] = [] - - def is_fx_marker_event(event): - return ( - event.get("cat") == "cpu_op" - and event.get("name", "").startswith("## ") - and event.get("name", "").endswith(" ##") - ) - - def append_fx_marker_event(event_type, identifier, event): - start_ts = event["ts"] - end_ts = start_ts + event["dur"] - event_timeline.append( - TimelineEvent(start_ts, "start", event_type, identifier, event) - ) - event_timeline.append( - TimelineEvent(end_ts, "end", event_type, identifier, event) - ) - - for event in trace_events: - if "ts" not in event or "dur" not in event: - continue - - if is_fx_marker_event(event): - content = event["name"][3:-3] - - if content.endswith(".py"): - append_fx_marker_event("filename", content, event) - else: - try: - node_index = int(content) - except ValueError: - pass - append_fx_marker_event("node", node_index, event) # type: ignore[possibly-undefined] - - else: - # Regular event that needs augmentation - start_ts = event["ts"] - event_timeline.append(TimelineEvent(start_ts, "regular", None, None, event)) - - # Sort by timestamp - event_timeline.sort(key=lambda x: x.timestamp) - - # Process events in chronological order with a stack - context_stack: list[ContextStackEntry] = [] - - # Invariant: all start event has a corresponding end event - for timeline_event in event_timeline: - match timeline_event.event_type: - case "start": - assert timeline_event.identifier is not None - - if timeline_event.marker_type == "filename": - assert isinstance(timeline_event.identifier, str) - # Push filename context - query metadata registry on-demand - metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier) - tid = timeline_event.event.get("tid") - context_stack.append( - ContextStackEntry( - "filename", timeline_event.identifier, metadata, tid - ) - ) - elif timeline_event.marker_type == "node": - # Find the current filename from stack - current_file_metadata = None - tid = timeline_event.event.get("tid") - for ctx_entry in reversed(context_stack): - if ( - ctx_entry.context_type == "filename" - and ctx_entry.tid == tid - ): - current_file_metadata = ctx_entry.metadata - break - - if current_file_metadata: - node_metadata = current_file_metadata.get("node_metadata", {}) - if timeline_event.identifier in node_metadata: - node_meta: Optional[dict] = node_metadata[ - timeline_event.identifier - ] - context_stack.append( - ContextStackEntry( - "node", timeline_event.identifier, node_meta, tid - ) - ) - - case "end": - # Pop from stack - search backwards to find matching context - for i in range(len(context_stack) - 1, -1, -1): - ctx_entry = context_stack[i] - if ( - timeline_event.marker_type == ctx_entry.context_type - and timeline_event.identifier == ctx_entry.identifier - ): - context_stack.pop(i) - break - - case "regular": - # Apply metadata from current context stack - # Find the most specific context (node takes precedence over filename) - # Only augment events with the same tid as the file/node event matched - current_stack_trace = None - current_node_name = None - event_tid = timeline_event.event.get("tid") - - for ctx_entry in reversed(context_stack): - # Only apply metadata from contexts with matching tid - if ctx_entry.tid == event_tid: - if ctx_entry.context_type == "node" and ctx_entry.metadata: - current_stack_trace = ctx_entry.metadata.get( - "stack_trace", "No model stack trace available" - ) - current_node_name = ctx_entry.metadata.get("name", "") - # Do we want to only attach the stack trace of the lowest node or stack trace of all nodes - # if nodes are nested, e.g. in nested graph modules - break - - # Augment the event - if current_stack_trace or current_node_name: - args = timeline_event.event.setdefault("args", {}) - if current_stack_trace: - args["stack_trace"] = current_stack_trace - if current_node_name: - args["node_name"] = current_node_name From d7e2d0ad301b5d0db049bf5d2a2fc7ff9c89c58c Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Sat, 1 Nov 2025 16:37:39 -0700 Subject: [PATCH 030/130] make narrow_tensor_symint DDE-free (#166379) https://github.com/pytorch/pytorch/issues/158081 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166379 Approved by: https://github.com/Lucaskabela ghstack dependencies: #166361 --- aten/src/ATen/native/TensorShape.cpp | 4 ++-- test/functorch/test_aotdispatch.py | 2 +- test/test_dynamic_shapes.py | 13 +++++++++++++ test/test_proxy_tensor.py | 1 - 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 6136a6aa8c520..b3fff5a4bb42f 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1764,8 +1764,8 @@ Tensor narrow_tensor_symint( start.dim() == 0 && isIntegralType(start.scalar_type(), /*includeBool=*/false), "start must be an 0-dim integral Tensor."); - int64_t st = start.item(); - return at::narrow_symint(self, dim, c10::SymInt(st), std::move(length)); + c10::SymInt st = start.item().toSymInt(); + return at::narrow_symint(self, dim, std::move(st), std::move(length)); } std:: diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index b0dd1ff8fa75d..6cae42d8929da 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -8126,7 +8126,7 @@ def fn(x): xfail("corrcoef"), xfail("quantile"), xfail("nanquantile"), - xfail("narrow"), + skip("narrow"), xfail("istft"), xfail("linalg.eig"), skip("as_strided_scatter"), diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index b63e0427c26c3..d3f9e415ff944 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -4452,6 +4452,19 @@ def test_narrow_unbacked_start_cpp_wrapper(self): """Test narrow with unbacked start with cpp_wrapper""" self.test_narrow_unbacked_start() + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_narrow_with_tensor_start(self): + @torch.compile(backend="inductor", fullgraph=True) + def f(x, start, end): + return torch.narrow(x, 0, start, end) + + x = torch.tensor( + [False], device="cuda:0" if torch.cuda.is_available() else "cpu" + ) + start = torch.tensor(0) + res = f(x, start, 0) + self.assertEqual(res.shape, torch.Size([0])) + instantiate_parametrized_tests(TestUnbacked) diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index b76895a0a91f3..0487995a2d1c5 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1987,7 +1987,6 @@ def f(t): } only_fake_tensor_failures = { - xfail('narrow'), xfail('tensor_split'), } From c1e91bd4c3bef209d43896d18abbf638bed356a9 Mon Sep 17 00:00:00 2001 From: Zhengxu Chen Date: Tue, 4 Nov 2025 22:55:26 +0000 Subject: [PATCH 031/130] [export] Codemod unittests to use new graph capture API (#166957) Summary: as title. Test Plan: pytest test/functorch/test_aot_joint_with_descriptors.py pytest test/higher_order_ops/test_local_map.py Fixes #ISSUE_NUMBER Pull Request resolved: https://github.com/pytorch/pytorch/pull/166957 Approved by: https://github.com/angelayi, https://github.com/yushangdi --- .../test_aot_joint_with_descriptors.py | 32 +++++++++---------- test/higher_order_ops/test_local_map.py | 21 ++---------- 2 files changed, 17 insertions(+), 36 deletions(-) diff --git a/test/functorch/test_aot_joint_with_descriptors.py b/test/functorch/test_aot_joint_with_descriptors.py index 7949d2bb46cbf..13277fccaea11 100644 --- a/test/functorch/test_aot_joint_with_descriptors.py +++ b/test/functorch/test_aot_joint_with_descriptors.py @@ -13,7 +13,7 @@ import torch.nn as nn import torch.utils._pytree as pytree from torch._decomp import decomposition_table -from torch._dynamo.functional_export import _dynamo_graph_capture_for_export +from torch._dynamo.functional_export import dynamo_graph_capture_for_export from torch._dynamo.testing import normalize_gm from torch._functorch._aot_autograd.descriptors import ( BufferAOTInput, @@ -48,17 +48,13 @@ def graph_capture(model, inputs, with_export): gm = model - fake_mode = None + tracing_context = None if with_export: - with ( - torch._dynamo.config.patch(install_free_tensors=True), - fx_traceback.preserve_node_meta(), - ): - # TODO: switch to use the official graph_capture API once it is ready - gm = _dynamo_graph_capture_for_export(model)(*inputs) - fake_mode = gm.meta.get("fake_mode", None) - - with tracing(TracingContext(fake_mode)): + with fx_traceback.preserve_node_meta(): + gm = dynamo_graph_capture_for_export(model)(*inputs) + tracing_context = gm.meta.get("tracing_context", None) + + with tracing(tracing_context): with ExitStack() as stack: joint_with_descriptors = aot_export_joint_with_descriptors( stack, @@ -325,7 +321,7 @@ def forward(self, x, *, scale): inputs = (torch.randn(4, 3),) kwargs = {"scale": torch.tensor(2.0)} - gm = _dynamo_graph_capture_for_export(model)(*inputs, **kwargs) + gm = dynamo_graph_capture_for_export(model)(*inputs, **kwargs) with ExitStack() as stack: # Export joint with descriptors @@ -356,8 +352,8 @@ def forward( primals, tangents, ): - primals_1: "f32[2, 3]" # ParamAOTInput(target='L__self___linear_weight') - primals_2: "f32[2]" # ParamAOTInput(target='L__self___linear_bias') + primals_1: "f32[2, 3]" # ParamAOTInput(target='linear.weight') + primals_2: "f32[2]" # ParamAOTInput(target='linear.bias') primals_3: "f32[4, 3]" # PlainAOTInput(idx=0) primals_4: "f32[]" # PlainAOTInput(idx=1) tangents_1: "f32[4, 2]" # TangentAOTInput(output=PlainAOTOutput(idx=0)) @@ -379,8 +375,8 @@ def forward( transpose_3: "f32[2, 3]" = torch.ops.prims.transpose.default(transpose_2, [1, 0]); transpose_2 = None return pytree.tree_unflatten([ mul_2, # PlainAOTOutput(idx=0) - transpose_3, # GradAOTOutput(grad_of=ParamAOTInput(target='L__self___linear_weight')) - as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='L__self___linear_bias')) + transpose_3, # GradAOTOutput(grad_of=ParamAOTInput(target='linear.weight')) + as_strided, # GradAOTOutput(grad_of=ParamAOTInput(target='linear.bias')) None, # None None, # None ], self._out_spec)""", @@ -1063,9 +1059,11 @@ def forward(self, x): str(custom_metadata), """\ ('call_function', 'new_empty', {'pp_stage': 0}) +('get_attr', '_tensor_constant0', {'pp_stage': 0}) ('call_function', 'index_put', {'pp_stage': 0}) ('call_function', 'slice_2', {'pp_stage': 0}) ('call_function', 'slice_backward', {'pp_stage': 0}) +('get_attr', '_tensor_constant0_1', {'pp_stage': 0}) ('call_function', 'index', {'pp_stage': 0})""", ) @@ -1082,7 +1080,7 @@ def forward(self, x): model = SimpleLinear() inputs = (torch.randn(4, 3),) - gm = _dynamo_graph_capture_for_export(model)(*inputs) + gm = dynamo_graph_capture_for_export(model)(*inputs) fake_mode = gm.meta.get("fake_mode", None) with tracing(TracingContext(fake_mode)): diff --git a/test/higher_order_ops/test_local_map.py b/test/higher_order_ops/test_local_map.py index 5f37d8e1768d6..9d2870d3b5fdd 100644 --- a/test/higher_order_ops/test_local_map.py +++ b/test/higher_order_ops/test_local_map.py @@ -15,6 +15,7 @@ import torch.fx.traceback as fx_traceback import torch.nn.functional as F from torch import nn +from torch._dynamo.functional_export import dynamo_graph_capture_for_export from torch._dynamo.variables.higher_order_ops import LocalMapWrappedHigherOrderVariable from torch._functorch.aot_autograd import aot_export_joint_with_descriptors from torch._subclasses.fake_tensor import FakeTensorMode @@ -51,24 +52,6 @@ def enable_local_map_wrapping(): yield -def _export(model: torch.nn.Module, inputs: tuple[Any]) -> torch.nn.Module: - from torch._dynamo.functional_export import _dynamo_graph_capture_for_export - from torch.export._trace import _restore_state_dict - - """ - Thin wrapper around graph capture output that restores the - original calling convention and attribute fqn. TODO: - 1) Use bytecode for calling convention instead of pytree for more - seamless UX. - 2) Attach guards - 3) Be more careful about tensor constants names. - """ - with torch._dynamo.config.patch(install_free_tensors=True): - gm = _dynamo_graph_capture_for_export(model)(*inputs) - _restore_state_dict(model, gm) - return gm - - def ap_style_initial_capture( model: torch.nn.Module, inputs_fn: Callable ) -> torch.nn.Module: @@ -90,7 +73,7 @@ def ap_style_initial_capture( enable_local_map_wrapping(), torch._dynamo.utils._disable_saved_tensors_hooks_during_tracing(), ): - torch_ir_with_fqn = _export(model, inputs) + torch_ir_with_fqn = dynamo_graph_capture_for_export(model)(*inputs) unused = ExitStack() joint_with_descriptors = aot_export_joint_with_descriptors( unused, From a96728d1885548cbf696d1c40fc990d1cbe1699b Mon Sep 17 00:00:00 2001 From: Daniel Galvez Date: Tue, 4 Nov 2025 23:35:59 +0000 Subject: [PATCH 032/130] Clarify safety of CUDA graph memory pool sharing across graphs that are replayed in arbtirary order. (#166975) Some users at pytorch conference were asking me about whether it is safe to share a memory pool among cuda graphs that never run concurrently, but may run in arbitrary order, if they don't depend upon each other's output. Even though your capture order doesn't match replay order in this situation, this is safe. However, our documents confusingly said this wasn't allowed. This update is intended to help with that. Since vLLM essentially depends upon this behavior, I call it out specifically. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166975 Approved by: https://github.com/eellison, https://github.com/BoyuanFeng --- docs/source/notes/cuda.rst | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index c7d3a93f73523..caabeb399c722 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -1720,6 +1720,16 @@ and can be used to share memory across graphs as shown:: g1.replay() g2.replay() +It's also safe to share a memory pool across separate graphs that do not depend +on each other's outputs, provided they never run concurrently. +Be aware that replaying one graph can clobber another graph's outputs when +they share a pool, unless :meth:`~torch.Tensor.clone` is called on the outputs +beforehand. +This pattern is frequently used in inference servers that accept variable batch +sizes at runtime. +vLLM is a notable example; see `here `__ +and `here `__. + With :func:`torch.cuda.make_graphed_callables`, if you want to graph several callables and you know they'll always run in the same order (and never concurrently) pass them as a tuple in the same order they'll run in the live workload, and From 0cd809f60c79bb808f2736fa4ac5f602f63caf8f Mon Sep 17 00:00:00 2001 From: Jason Xie Date: Tue, 4 Nov 2025 23:47:11 +0000 Subject: [PATCH 033/130] [inductor][AMD] Filter out invalid Triton Configs for MI350X _scaled_mm (#166442) Summary: Mirrors change done in D81180838 but for inductor. Without this change, running _scaled_mm on MI350X accelerator would crash. Test Plan: HIP_VISIBLE_DEVICES=7 TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 buck2 run mode/opt-amd-gpu -m rocm70 -c fbcode.rocm_arch=mi350 scripts/jchunx/gemm:scaled_mm_microbench -- --csv_file /home/jchunx/scripts/fp8_shapes.csv --backend triton,aten --fast_accum=true 2>&1 | tee ~/logs/scaled_mm.log Reviewed By: bilal Differential Revision: D85694383 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166442 Approved by: https://github.com/bilal --- torch/_inductor/template_heuristics/triton.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/torch/_inductor/template_heuristics/triton.py b/torch/_inductor/template_heuristics/triton.py index 61616d81c2878..8cbbf5073d5ef 100644 --- a/torch/_inductor/template_heuristics/triton.py +++ b/torch/_inductor/template_heuristics/triton.py @@ -1946,6 +1946,29 @@ def _valid(self, kernel_inputs: KernelInputs) -> bool: return False return True + # pyrefly: ignore [bad-override] + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: + """ + Filter out bad configs for specific hardware. + On AMD MI350X (GFX 9.5+), skip configs with BLOCK_K<=64 due to lack of corresponding MFMA instructions. + """ + + def should_skip_mi350x_config(config: BaseConfig) -> bool: + """Skip config if BLOCK_K<=64 on MI350X (GFX 9.5+)""" + try: + return ( + config.block_k <= 64 + and torch.version.hip is not None + and torch.cuda.get_device_capability() >= (9, 5) + ) + except RuntimeError: + # If no HIP GPUs are available, we can't check device capability + # so we don't skip any configs + return False + + filtered_configs = [c for c in configs if not should_skip_mi350x_config(c)] + return super()._filter_configs(filtered_configs) + # Scaled TMA-specific mixin for scaled MM templates with TMA class ScaledTMAConfigMixin(TMAWorkspaceMixin, BaseScaledMMConfigMixin): From 661b63966341d9569829c6ec6799be0757db1a6c Mon Sep 17 00:00:00 2001 From: "Xiangyang (Mark) Guo" Date: Tue, 4 Nov 2025 23:47:12 +0000 Subject: [PATCH 034/130] use_cpp_bmm_template supports more use cases (#165469) Summary: In certain scenarios, such as when the first stride is 0, the entire tensor may not be contiguous, but the 2D matrix within each batch can still be contiguous, allowing us to apply max autotune. This diff specifically checks for contiguity within the 2D matrix of each batch, and enables more uses for cpp bmm template. Differential Revision: D84561331 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165469 Approved by: https://github.com/desertfire --- test/inductor/test_cpu_select_algorithm.py | 26 ++++++++++++++++++++++ torch/_inductor/utils.py | 18 ++++++++++++--- 2 files changed, 41 insertions(+), 3 deletions(-) diff --git a/test/inductor/test_cpu_select_algorithm.py b/test/inductor/test_cpu_select_algorithm.py index 4e1c48496ebc5..ca520ab66bcc2 100644 --- a/test/inductor/test_cpu_select_algorithm.py +++ b/test/inductor/test_cpu_select_algorithm.py @@ -2697,6 +2697,32 @@ def forward(self, x): self.common(mod, (u,), atol=atol, rtol=rtol) self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) + @patches + @torch.no_grad + @unittest.skipIf(not TEST_MKL, "Test requires MKL") + @parametrize("bs", (5,)) + @parametrize("Mdim", (16,)) + @parametrize("Kdim", (32,)) + @parametrize("Ndim", (64,)) + @dtypes(torch.float) + def test_bmm_with_broadcasted_mat1(self, bs, Mdim, Kdim, Ndim, dtype): + class M(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x, w): + assert x.dim() == 2, f"Expected x to be 2D, got {x.dim()}D" + x_expanded = x.unsqueeze(0).expand(bs, -1, -1) + return x_expanded @ w + + counters.clear() + u = torch.randn(Mdim, Kdim).to(dtype=dtype) + v = torch.randn(bs, Kdim, Ndim).to(dtype=dtype) + mod = M().to(dtype=dtype).eval() + with verify(dtype) as (atol, rtol): + self.common(mod, (u, v), atol=atol, rtol=rtol) + self.assertEqual(counters["inductor"]["cpp_templated_kernel_counter"], 1) + @patches @torch.no_grad @unittest.skipIf(not TEST_MKL, "Test requires MKL") diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 2cf915d9e61de..3f8652882af79 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -2231,9 +2231,21 @@ def use_cpp_bmm_template( assert isinstance(mat1.layout, Layout) - return ( - use_cpp_gemm_template(layout, mat1, mat2, require_constant_mat2=False) - and mat1.layout.is_contiguous() + # In certain scenarios, such as when the first stride is 0, the entire tensor may not be contiguous. + # But the 2D matrix within each batch can still be contiguous, allowing us to apply max autotune. + # So here we specifically check for contiguity within the 2D matrix of each batch. + mat1_size = mat1.layout.size + mat1_stride = mat1.layout.stride + mat1_each_batch_is_contiguous = ( + _use_template_for_cpu(layout) + and mat1.get_dtype() == torch.float32 + and (len(mat1_size) == 3) + and (len(mat1_stride) == 3) + and (mat1_stride[1] == mat1_size[2]) + and (mat1_stride[2] == 1) + ) + return use_cpp_gemm_template(layout, mat1, mat2, require_constant_mat2=False) and ( + mat1.layout.is_contiguous() or mat1_each_batch_is_contiguous ) From 4b12c0344d0b1a6536a2659c4e498c805efdc1f1 Mon Sep 17 00:00:00 2001 From: Karhou Tam Date: Tue, 4 Nov 2025 23:53:56 +0000 Subject: [PATCH 035/130] Add default `.github/copilot-instructions.md` and item in `.gitignore` for allowing local changes (#166864) Fixes [#166850](https://github.com/pytorch/pytorch/issues/166850) - Create a default `.github/copilot-instructions.md` file (used Claude Sonnet 4.5 in Copilot). - Add `.github/copilot-instructions.md` to the `.gitignore` file. The prompt used is below, which is preset by Copilot: ``` Analyze this codebase to generate or update `.github/copilot-instructions.md` for guiding AI coding agents. Focus on discovering the essential knowledge that would help an AI agents be immediately productive in this codebase. Consider aspects like: - The "big picture" architecture that requires reading multiple files to understand - major components, service boundaries, data flows, and the "why" behind structural decisions - Critical developer workflows (builds, tests, debugging) especially commands that aren't obvious from file inspection alone - Project-specific conventions and patterns that differ from common practices - Integration points, external dependencies, and cross-component communication patterns Source existing AI conventions from `**/{.github/copilot-instructions.md,AGENT.md,AGENTS.md,CLAUDE.md,.cursorrules,.windsurfrules,.clinerules,.cursor/rules/**,.windsurf/rules/**,.clinerules/**,README.md}` (do one glob search). Guidelines (read more at https://aka.ms/vscode-instructions-docs): - If `.github/copilot-instructions.md` exists, merge intelligently - preserve valuable content while updating outdated sections - Write concise, actionable instructions (~20-50 lines) using markdown structure - Include specific examples from the codebase when describing patterns - Avoid generic advice ("write tests", "handle errors") - focus on THIS project's specific approaches - Document only discoverable patterns, not aspirational practices - Reference key files/directories that exemplify important patterns Update `.github/copilot-instructions.md` for the user, then ask for feedback on any unclear or incomplete sections to iterate. ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166864 Approved by: https://github.com/malfet Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com> --- .github/copilot-instructions.md | 125 ++++++++++++++++++++++++++++++++ 1 file changed, 125 insertions(+) create mode 100644 .github/copilot-instructions.md diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 0000000000000..06c3f32abd5e1 --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,125 @@ +# PyTorch Copilot Instructions + +This is the PyTorch machine learning framework codebase. These instructions help AI agents navigate and contribute effectively. + +## Architecture Overview + +### Core Components + +- **c10/** - Core library (C++-10 compatible) for essential, binary-size-conscious functionality +- **aten/** - ATen tensor library (C++), PyTorch's foundation without autograd + - `aten/src/ATen/native/` - Modern operator implementations (CPU/CUDA/MPS/sparse) + - `aten/src/ATen/native/native_functions.yaml` - **Critical**: Declarative operator registry +- **torch/** - Python bindings and public API + - `torch/csrc/` - C++ Python bindings (hand-written and generated) + - `torch/csrc/autograd/` - Reverse-mode automatic differentiation + - `torch/csrc/jit/` - TorchScript JIT compiler +- **torchgen/** - Code generation tooling that reads `native_functions.yaml` +- **tools/** - Build scripts, autograd derivatives, code generation + +### The Code Generation Workflow + +**Most operator changes require editing `native_functions.yaml`**, not direct C++ files. This YAML file: +1. Declares operator signatures, variants (function/method), and dispatch behavior +2. Gets processed by `torchgen/` to generate C++/Python bindings +3. Produces headers in `build/aten/src/ATen/` during compilation + +Example entry structure: +```yaml +- func: my_op(Tensor self, Scalar alpha=1) -> Tensor + variants: function, method + dispatch: + CPU: my_op_cpu + CUDA: my_op_cuda +``` + +After editing `native_functions.yaml`, implement kernels in `aten/src/ATen/native/` (see `aten/src/ATen/native/README.md`). + +## Development Workflows + +### Building from Source + +**Never run `setup.py` directly** - use pip with editable install: +```bash +python -m pip install --no-build-isolation -v -e . +``` + +Speed up builds: +- `DEBUG=1` - Debug symbols with `-g -O0` +- `USE_CUDA=0` - Skip CUDA compilation +- `BUILD_TEST=0` - Skip C++ test binaries +- Install `ninja` (`pip install ninja`) for faster builds +- Use `ccache` for incremental compilation caching + +Rebuild specific targets: `(cd build && ninja )` + +### Testing + +**Critical**: DO NOT run entire test suites. Run specific tests only: +```bash +python test/test_torch.py TestTorch.test_specific_case +``` + +**Test structure**: All tests use `torch.testing._internal.common_utils`: +```python +from torch.testing._internal.common_utils import run_tests, TestCase + +class TestFeature(TestCase): + def test_something(self): + # Use self.assertEqual for tensor comparisons + pass + +if __name__ == "__main__": + run_tests() +``` + +**For bug fixes**: Create a standalone reproduction script first, verify it fails, then fix and add to appropriate test file. + +### Linting + +Run linter (not pre-commit): `lintrunner -a` (auto-applies fixes) + +## Project-Specific Conventions + +### Memory and Storage +- **Storage is never nullptr** (but `StorageImpl.data` may be nullptr for unallocated outputs) +- CUDA device info lives in storage objects + +### Python-C++ Integration (`torch/csrc/`) +- Always include `Python.h` **first** to avoid `_XOPEN_SOURCE` redefinition errors +- Use `pybind11::gil_scoped_acquire` before calling Python API or using `THPObjectPtr` +- Wrap entry points with `HANDLE_TH_ERRORS` / `END_HANDLE_TH_ERRORS` for exception conversion + +### Dispatch System +- PyTorch uses operator dispatch to route calls to backend-specific kernels +- Prefer `CompositeExplicitAutograd` dispatch when writing device-agnostic compound ops +- See `aten/src/ATen/native/README.md` for dispatch keyword guidance + +## Git Workflow (AI Agent Specific) + +When preparing PRs from this environment: +```bash +git stash -u +git reset --hard $(cat /tmp/orig_work.txt) # Reset to LOCAL branch +git stash pop +# Resolve conflicts if necessary +``` + +## Common Gotchas + +1. **Editing generated files** - If it's in `build/`, don't edit it. Edit the source template or `native_functions.yaml` +2. **NVCC template compilation** - NVCC is stricter about C++ than gcc/clang; code working on Linux may fail Windows CI +3. **Windows symbol visibility** - Use `TORCH_API` macros for exported symbols (required on Windows, optional on Linux) +4. **No internet access** - DO NOT attempt to install dependencies during development + +## Key Files Reference + +- `AGENTS.md` - Instructions specific to AI coding agents +- `CONTRIBUTING.md` - Comprehensive human contributor guide +- `GLOSSARY.md` - Terminology (ATen, kernels, operations, JIT, TorchScript) +- `aten/src/ATen/native/README.md` - Operator implementation guide +- `tools/autograd/derivatives.yaml` - Gradient definitions for autograd + +## Performance Debugging + +Use `TORCH_SHOW_CPP_STACKTRACES=1` for C++ traces in Python errors. For profiling, prefer `py-spy` over manual instrumentation. From 7eefcfb1db5995739a2614f368594cb266d33173 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Tue, 4 Nov 2025 23:54:15 +0000 Subject: [PATCH 036/130] [BE][Typing][Dynamo] Type torch/_dynamo/variables/ctx_manager.py (#166878) Provides type coverage to torch/_dynamo/variables/ctx_manager.py Coverage report: `mypy torch/_dynamo/variables/ctx_manager.py --linecount-report /tmp/coverage_log` Compare before to after - we go from 0 lines and 0 funcs covered to 1541 lines and 144 funcs covered Pull Request resolved: https://github.com/pytorch/pytorch/pull/166878 Approved by: https://github.com/Skylion007 --- torch/_C/_functorch.pyi | 2 + torch/_dynamo/polyfills/pytree.py | 16 +- torch/_dynamo/symbolic_convert.py | 7 +- torch/_dynamo/variables/ctx_manager.py | 571 ++++++++++++++-------- torch/_dynamo/variables/streams.py | 16 +- torch/_dynamo/variables/torch.py | 1 + torch/_dynamo/variables/torch_function.py | 3 +- 7 files changed, 387 insertions(+), 229 deletions(-) diff --git a/torch/_C/_functorch.pyi b/torch/_C/_functorch.pyi index c23240e13170a..a35befcad392d 100644 --- a/torch/_C/_functorch.pyi +++ b/torch/_C/_functorch.pyi @@ -5,6 +5,8 @@ from torch import Tensor # Defined in torch/csrc/functorch/init.cpp +def set_inplace_requires_grad_allowed(allowed: bool) -> None: ... +def get_inplace_requires_grad_allowed() -> bool: ... def _set_dynamic_layer_keys_included(included: bool) -> None: ... def get_unwrapped(tensor: Tensor) -> Tensor: ... def is_batchedtensor(tensor: Tensor) -> bool: ... diff --git a/torch/_dynamo/polyfills/pytree.py b/torch/_dynamo/polyfills/pytree.py index f9bdc0cce4a00..d86fe054b2ebc 100644 --- a/torch/_dynamo/polyfills/pytree.py +++ b/torch/_dynamo/polyfills/pytree.py @@ -64,7 +64,7 @@ def _(*args: Any, **kwargs: Any) -> bool: del __func del __name - @substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True) + @substitute_in_graph(optree.tree_is_leaf, can_constant_fold_through=True) # type: ignore[arg-type] def tree_is_leaf( tree: PyTree, /, @@ -79,7 +79,7 @@ def tree_is_leaf( return True return False - @substitute_in_graph(optree.tree_iter, can_constant_fold_through=False) + @substitute_in_graph(optree.tree_iter, can_constant_fold_through=False) # type: ignore[arg-type] def tree_iter( tree: PyTree, /, @@ -110,7 +110,7 @@ def tree_iter( __all__ += ["tree_iter"] - @substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True) + @substitute_in_graph(optree.tree_leaves, can_constant_fold_through=True) # type: ignore[arg-type] def tree_leaves( tree: PyTree, /, @@ -451,7 +451,7 @@ def treespec_dict( dict, metadata, entries, - unflatten_func, + unflatten_func, # type: ignore[arg-type] none_is_leaf=none_is_leaf, namespace=namespace, ) @@ -507,7 +507,7 @@ def helper(node: PyTree, leaves: list[Any]) -> PyTreeSpec: type(node), metadata, entries, - unflatten_func, + unflatten_func, # type: ignore[arg-type] none_is_leaf=none_is_leaf, namespace=namespace, ) # type: ignore[arg-type] @@ -557,7 +557,7 @@ def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[Any]) -> PyTree: __all__ += ["tree_unflatten"] - @substitute_in_graph(optree.tree_map, can_constant_fold_through=True) + @substitute_in_graph(optree.tree_map, can_constant_fold_through=True) # type: ignore[arg-type] def tree_map( func: Callable[..., Any], tree: PyTree, @@ -578,7 +578,7 @@ def tree_map( __all__ += ["tree_map"] - @substitute_in_graph(optree.tree_map_, can_constant_fold_through=True) + @substitute_in_graph(optree.tree_map_, can_constant_fold_through=True) # type: ignore[arg-type] def tree_map_( func: Callable[..., Any], tree: PyTree, @@ -600,7 +600,7 @@ def tree_map_( __all__ += ["tree_map_"] - _none_unflatten = optree.register_pytree_node.get(type(None)).unflatten_func # type: ignore[union-attr] + _none_unflatten = optree.register_pytree_node.get(type(None)).unflatten_func # type: ignore[union-attr, attr-defined] @substitute_in_graph( # type: ignore[arg-type] _none_unflatten, diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 9d0d87c5f8a06..53ec0ee412849 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -434,12 +434,15 @@ def resume_fn(self) -> ReenterWith: else: return ReenterWith(self.stack_index - 1) - def exit(self, tx: InstructionTranslatorBase, is_graph_break: bool) -> None: + def exit( + self, tx: InstructionTranslatorBase, is_graph_break: bool + ) -> VariableTracker | None: assert self.with_context is not None if ( is_graph_break and self.with_context.exit_on_graph_break() ) or not is_graph_break: return self.with_context.exit(tx) # type: ignore[arg-type] + return None class SpeculationLogDivergence(AssertionError): @@ -3860,7 +3863,7 @@ def enter_ctx( else: self.block_stack.append(BlockStackEntry(inst, target, len(self.stack))) - return ctx.enter(self) + return ctx.enter(self) # type: ignore[arg-type] @staticmethod def unsupported_ctx_graph_break(ctx: VariableTracker) -> NoReturn: diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 0502c58a78420..4eac189b65fdd 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - """ This file contains a collection of context manager classes used by Dynamo for tracking and managing various PyTorch runtime states during graph compilation. These context @@ -23,8 +21,9 @@ import inspect import sys import warnings +from collections.abc import Callable, Sequence from contextlib import ExitStack -from typing import TYPE_CHECKING, Union +from typing import Any, ContextManager, Optional, Sized, TYPE_CHECKING, Union import torch._C from torch._guards import Guard @@ -67,35 +66,43 @@ class ContextWrappingVariable(VariableTracker): *VariableTracker._nonvar_fields, } - def __init__(self, target_values, initial_values=None, **kwargs) -> None: + def __init__( + self, target_values: Any, initial_values: Optional[Any] = None, **kwargs: Any + ) -> None: super().__init__(**kwargs) self.target_values = target_values self.initial_values = initial_values - def enter(self, tx): - self._call_func(tx, self.target_values) + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + if hasattr(self, "_call_func"): + self._call_func(tx, self.target_values) self.set_cleanup_hook(tx) return variables.ConstantVariable.create(None) - def set_cleanup_hook(self, tx: "InstructionTranslator", fn=None): + def set_cleanup_hook( + self, tx: "InstructionTranslator", fn: Optional[Callable[..., Any]] = None + ) -> None: if fn is None: - def fn(): - self._call_func(tx, self.initial_values) + def fn() -> None: + if hasattr(self, "_call_func"): + self._call_func(tx, self.initial_values) - self.cleanup_fn = fn + self.cleanup_fn: Optional[Callable[..., Any]] = fn tx.output.add_cleanup_hook(self.cleanup) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup_assert() return variables.ConstantVariable.create(None) - def reconstruct_type(self, codegen: "PyCodegen"): + def reconstruct_type(self, codegen: "PyCodegen") -> None: codegen( AttrSource(codegen.tx.import_source(self.module_name()), self.fn_name()) ) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.add_push_null(lambda: self.reconstruct_type(codegen)) target_values = self.target_values if not target_values: @@ -103,18 +110,18 @@ def reconstruct(self, codegen: "PyCodegen"): codegen.extend_output([codegen.create_load_const(val) for val in target_values]) codegen.extend_output(create_call_function(len(target_values), False)) - def module_name(self): + def module_name(self) -> str: raise NotImplementedError("module_name called on base") - def fn_name(self): + def fn_name(self) -> str: raise NotImplementedError("fn_name called on base") def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: assert len(args) == 1 assert isinstance( args[0], @@ -128,28 +135,27 @@ def call_function( if isinstance(args[0], NestedUserFunctionVariable): return WrappedNestedUserFunctionVariable(args[0], self) - - if isinstance(args[0], SkipFunctionVariable): + elif isinstance(args[0], SkipFunctionVariable): return WrappedSkipFunctionVariable(args[0], self) - - if isinstance(args[0], UserMethodVariable): + elif isinstance(args[0], UserMethodVariable): return WrappedUserMethodVariable(args[0], self) - - if isinstance(args[0], UserFunctionVariable): + elif isinstance(args[0], UserFunctionVariable): return WrappedUserFunctionVariable(args[0], self) + else: + raise AssertionError("Unexpected arg type") - def supports_graph_breaks(self): + def supports_graph_breaks(self) -> bool: return True - def exit_on_graph_break(self): + def exit_on_graph_break(self) -> bool: return True - def cleanup(self): + def cleanup(self) -> None: if self.cleanup_fn is not None: self.cleanup_fn() self.cleanup_fn = None - def cleanup_assert(self): + def cleanup_assert(self) -> None: assert self.cleanup_fn, "multiple exits?" self.cleanup() @@ -157,7 +163,7 @@ def cleanup_assert(self): class GenericContextWrappingVariable(UserDefinedObjectVariable): # Some methods in ContextWrappingVariable assumes the arguments are # python constants. Which might not always be the case here. - def __init__(self, cm_obj, **kwargs) -> None: + def __init__(self, cm_obj: ContextManager[Any], **kwargs: Any) -> None: assert cm_obj is not None super().__init__( value=cm_obj, @@ -166,44 +172,46 @@ def __init__(self, cm_obj, **kwargs) -> None: ) self.cm_obj = cm_obj - def module_name(self): + def module_name(self) -> str: return self.cm_obj.__module__ - def fn_name(self): + def fn_name(self) -> str: return type(self.cm_obj).__name__ - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: source = None if self.source is None else AttrSource(self.source, "__enter__") return variables.UserMethodVariable( - self.cm_obj.__enter__.__func__, + self.cm_obj.__enter__.__func__, # type: ignore[attr-defined] self, source=source, ).call_function(tx, [], {}) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: source = None if self.source is None else AttrSource(self.source, "__exit__") x = variables.UserMethodVariable( - self.cm_obj.__exit__.__func__, + self.cm_obj.__exit__.__func__, # type: ignore[attr-defined] self, source=source, - ).call_function(tx, args, {}) + ).call_function(tx, list(args), {}) tx.active_generic_context_managers.pop() return x - def supports_graph_breaks(self): + def supports_graph_breaks(self) -> bool: return False - def exit_on_graph_break(self): + def exit_on_graph_break(self) -> bool: return True class RepararametrizeModuleContextVariable(GenericContextWrappingVariable): - def __init__(self, ctx_manager_vt, mod): + def __init__(self, ctx_manager_vt: ContextWrappingVariable, mod: Any) -> None: self.cm_vt = ctx_manager_vt self.mod = mod # We don't call super().__init__() because we're delegating most methods to cm_vt - def enter(self, tx: "InstructionTranslator"): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: # Custom enter implementation with side effects self.old_parameters_var = self.mod.var_getattr(tx, "_parameters").realize() @@ -212,7 +220,9 @@ def enter(self, tx: "InstructionTranslator"): tx.output.side_effects.ignore_mutations_on(self.old_buffer_var) return self.cm_vt.enter(tx) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: # Custom exit implementation with side effects x = self.cm_vt.exit(tx, *args) tx.output.side_effects.stop_ignoring_mutations_on(self.old_buffer_var) @@ -220,7 +230,7 @@ def exit(self, tx: "InstructionTranslator", *args): return x # Forward all other method calls to self.cm_vt - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: # This will be called for any attribute not explicitly defined in this class return getattr(self.cm_vt, name) @@ -229,14 +239,16 @@ class GradInplaceRequiresGradCtxManagerVariable(ContextWrappingVariable): """represents torch grad requires grad""" @staticmethod - def create(tx: "InstructionTranslator", target_values, **kwargs): + def create( + tx: "InstructionTranslator", target_values: Any, **kwargs: Any + ) -> "GradInplaceRequiresGradCtxManagerVariable": return GradInplaceRequiresGradCtxManagerVariable( target_values=target_values, initial_values=None, **kwargs, ) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: [enabled] = self.target_values self.prev_state = torch._C._functorch.get_inplace_requires_grad_allowed() torch._C._functorch.set_inplace_requires_grad_allowed(enabled) @@ -254,7 +266,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(None) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup() tx.output.create_node( "call_function", @@ -269,14 +283,16 @@ class TemporarilyPopInterpreterStackCtxManagerVariable(ContextWrappingVariable): """represents torch._functorch.pyfunction.temporarily_pop_interpreter_stack()""" @staticmethod - def create(tx: "InstructionTranslator", target_values, **kwargs): + def create( + tx: "InstructionTranslator", target_values: Any, **kwargs: Any + ) -> "TemporarilyPopInterpreterStackCtxManagerVariable": return TemporarilyPopInterpreterStackCtxManagerVariable( target_values=target_values, initial_values=None, **kwargs, ) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: self.saved = torch._C._functorch.pop_dynamic_layer_stack() self.set_cleanup_hook( tx, @@ -290,7 +306,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(None) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup() tx.output.create_node( "call_function", @@ -309,10 +327,12 @@ class JvpIncrementNestingCtxManagerVariable(ContextWrappingVariable): # being compiled. But the FX graph may be invalid in the case of a jvp # call from eager that calls the compiled function, as the jvp levels # may be different. - _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) # type: ignore[arg-type] @staticmethod - def create(tx: "InstructionTranslator", **kwargs): + def create( + tx: "InstructionTranslator", **kwargs: Any + ) -> "JvpIncrementNestingCtxManagerVariable": var = JvpIncrementNestingCtxManagerVariable( target_values=None, initial_values=None, @@ -320,7 +340,7 @@ def create(tx: "InstructionTranslator", **kwargs): ) return var - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: install_guard(self._guards_singleton) jvp_level = torch._functorch.eager_transforms.enter_jvp_nesting() self.set_cleanup_hook( @@ -334,7 +354,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(jvp_level) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup() tx.output.create_node( "call_function", torch._C._functorch._jvp_decrement_nesting, (), {} @@ -346,14 +368,16 @@ class SetFwdGradEnabledContextManager(ContextWrappingVariable): """represents torch.autograd.forward_ad._set_fwd_grad_enabled() to enable/disable fwd grad""" @staticmethod - def create(tx: "InstructionTranslator", target_values, **kwargs): + def create( + tx: "InstructionTranslator", target_values: Any, **kwargs: Any + ) -> "SetFwdGradEnabledContextManager": return SetFwdGradEnabledContextManager( target_values=target_values, initial_values=None, **kwargs, ) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: [mode] = self.target_values self.prev_state = torch._C._is_fwd_grad_enabled() torch._C._set_fwd_grad_enabled(mode) @@ -369,7 +393,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(None) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup() tx.output.create_node( "call_function", @@ -383,17 +409,17 @@ def exit(self, tx: "InstructionTranslator", *args): class DualLevelContextManager(ContextWrappingVariable): """Represents torch.autograd.forward_ad.dual_level ctx manager""" - _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.DUAL_LEVEL) + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.DUAL_LEVEL) # type: ignore[arg-type] @staticmethod - def create(tx: "InstructionTranslator", **kwargs): + def create(tx: "InstructionTranslator", **kwargs: Any) -> "DualLevelContextManager": return DualLevelContextManager( target_values=None, initial_values=None, **kwargs, ) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: install_guard(self._guards_singleton) self.new_level = torch.autograd.forward_ad.enter_dual_level() self.set_cleanup_hook( @@ -407,7 +433,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(self.new_level) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup() tx.output.create_node( "call_function", @@ -426,10 +454,12 @@ class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable): # being compiled. But the FX graph may be invalid in the case of a grad # call from eager that calls the compiled function, as the grad levels # may be different. - _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) # type: ignore[arg-type] @staticmethod - def create(tx: "InstructionTranslator", **kwargs): + def create( + tx: "InstructionTranslator", **kwargs: Any + ) -> "GradIncrementNestingCtxManagerVariable": var = GradIncrementNestingCtxManagerVariable( target_values=None, initial_values=None, @@ -437,7 +467,7 @@ def create(tx: "InstructionTranslator", **kwargs): ) return var - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: install_guard(self._guards_singleton) grad_level = torch._C._functorch._grad_increment_nesting() self.set_cleanup_hook(tx, lambda: torch._C._functorch._grad_decrement_nesting()) @@ -449,7 +479,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(grad_level) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup() tx.output.create_node( "call_function", torch._C._functorch._grad_decrement_nesting, (), {} @@ -461,19 +493,29 @@ class CatchWarningsCtxManagerVariable(ContextWrappingVariable): """Delay a call to warnings.catch_warnings""" @staticmethod - def create(tx: "InstructionTranslator", catch_warnings_args): + def create( + tx: "InstructionTranslator", catch_warnings_args: dict[str, VariableTracker] + ) -> "CatchWarningsCtxManagerVariable": return CatchWarningsCtxManagerVariable( catch_warnings_args=catch_warnings_args, target_values=None, initial_values=None, ) - def __init__(self, catch_warnings_args, **kwargs) -> None: + def __init__( + self, + catch_warnings_args: dict[str, VariableTracker], + target_values: Optional[Any] = None, + initial_values: Optional[Any] = None, + **kwargs: Any, + ) -> None: assert isinstance(catch_warnings_args, dict), catch_warnings_args - super().__init__(**kwargs) + super().__init__( + target_values=target_values, initial_values=initial_values, **kwargs + ) self.catch_warnings_args = catch_warnings_args - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: kwargs = { k: v.as_python_constant() for k, v in self.catch_warnings_args.items() } @@ -481,7 +523,7 @@ def enter(self, tx): self.set_cleanup_hook(tx, lambda: ctx_val.__exit__(None, None, None)) return variables.ConstantVariable.create(ctx_val.__enter__()) - def reconstruct(self, cg): + def reconstruct(self, cg: "PyCodegen") -> None: cg.add_push_null(lambda: cg.load_import_from("warnings", "catch_warnings")) cg.foreach(self.catch_warnings_args.values()) keys = tuple(self.catch_warnings_args.keys()) @@ -496,10 +538,14 @@ class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable): # being compiled. But the FX graph may be invalid in the case of a vmap # call from eager that calls the compiled function, as the vmap levels # may be different. - _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FUNCTORCH_STACK_MATCH) # type: ignore[arg-type] @staticmethod - def create(tx: "InstructionTranslator", target_values, **kwargs): + def create( + tx: "InstructionTranslator", + target_values: Sequence[VariableTracker], + **kwargs: Any, + ) -> "VmapIncrementNestingCtxManagerVariable": var = VmapIncrementNestingCtxManagerVariable( target_values=target_values, initial_values=None, @@ -507,7 +553,7 @@ def create(tx: "InstructionTranslator", target_values, **kwargs): ) return var - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: install_guard(self._guards_singleton) batch_size, randomness = self.target_values if isinstance(batch_size, variables.SymNodeVariable): @@ -527,7 +573,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(vmap_level) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup() tx.output.create_node( "call_function", @@ -541,10 +589,15 @@ def exit(self, tx: "InstructionTranslator", *args): class GradModeVariable(ContextWrappingVariable): """represents torch.{no_grad,enable_grad,set_grad_mode}()""" - _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE) + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.GRAD_MODE) # type: ignore[arg-type] @staticmethod - def create(tx: "InstructionTranslator", target_value, initialized=False, **kwargs): + def create( + tx: "InstructionTranslator", + target_value: Any, + initialized: bool = False, + **kwargs: Any, + ) -> "GradModeVariable": var = GradModeVariable( target_values=[target_value], initial_values=[torch.is_grad_enabled()], @@ -555,31 +608,37 @@ def create(tx: "InstructionTranslator", target_value, initialized=False, **kwarg return var def __init__( - self, target_values, initial_values=None, initialized=True, **kwargs + self, + target_values: Any, + initial_values: Optional[Sequence[bool]] = None, + initialized: bool = True, + **kwargs: Any, ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) install_guard(self._guards_singleton) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: self._call_func(tx, self.target_values) return variables.ConstantVariable.create(None) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self._call_func(tx, self.initial_values) return variables.ConstantVariable.create(None) def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ): + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: self._call_func(tx, self.initial_values) # undo eager initialization return super().call_function(tx, args, kwargs) - def _call_func(self, tx: "InstructionTranslator", values): + def _call_func(self, tx: "InstructionTranslator", values: Any) -> None: assert len(values) == 1 value = values[0] # Coalesce grad mode mutations @@ -589,16 +648,18 @@ def _call_func(self, tx: "InstructionTranslator", values): ) torch._C._set_grad_enabled(value) - def module_name(self): + def module_name(self) -> str: return "torch" - def fn_name(self): + def fn_name(self) -> str: return "set_grad_enabled" class InferenceModeVariable(ContextWrappingVariable): @staticmethod - def create(tx: "InstructionTranslator", target_value, **kwargs): + def create( + tx: "InstructionTranslator", target_value: Any, **kwargs: Any + ) -> "InferenceModeVariable": var = InferenceModeVariable( [target_value], initial_values=torch.is_inference_mode_enabled(), **kwargs ) @@ -606,9 +667,9 @@ def create(tx: "InstructionTranslator", target_value, **kwargs): def __init__( self, - target_values, - initial_values=None, - **kwargs, + target_values: Any, + initial_values: Optional[bool] = None, + **kwargs: Any, ) -> None: if initial_values is None: # This must be called here since function defaults are evaluated at import time @@ -616,9 +677,10 @@ def __init__( super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) - self.target_values = target_values - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup_assert() tx.output.create_node( "call_function", @@ -626,8 +688,9 @@ def exit(self, tx: "InstructionTranslator", *args): (self.proxy,), {}, ) + return variables.ConstantVariable.create(None) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: disabled_inference_mode_forcibly = False if ( torch._dynamo.config.fake_tensor_disable_inference_mode @@ -642,7 +705,7 @@ def enter(self, tx): else: ctx = torch.autograd.grad_mode._enter_inference_mode(*self.target_values) - def cleanup_hook(): + def cleanup_hook() -> None: if disabled_inference_mode_forcibly: torch._C._set_grad_enabled(prior) else: @@ -655,11 +718,12 @@ def cleanup_hook(): (*self.target_values,), {}, ) + return variables.ConstantVariable.create(None) - def module_name(self): + def module_name(self) -> str: return "torch" - def fn_name(self): + def fn_name(self) -> str: return "inference_mode" @@ -667,7 +731,9 @@ class CUDADeviceVariable(ContextWrappingVariable): """represents torch.cuda.device""" @staticmethod - def create(tx: "InstructionTranslator", device, **kwargs): + def create( + tx: "InstructionTranslator", device: Any, **kwargs: Any + ) -> "CUDADeviceVariable": var = CUDADeviceVariable( target_values=[torch.cuda._get_device_index(device, optional=True)], initial_values=None, @@ -677,16 +743,17 @@ def create(tx: "InstructionTranslator", device, **kwargs): def __init__( self, - target_values, - initial_values=None, - **kwargs, + target_values: Any, + initial_values: Optional[Any] = None, + **kwargs: Any, ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) - self.target_values = target_values - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup_assert() tx.output.create_node( "call_function", @@ -696,7 +763,7 @@ def exit(self, tx: "InstructionTranslator", *args): ) return variables.ConstantVariable.create(False) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: prev_idx = torch.cuda._exchange_device(*self.target_values) self.set_cleanup_hook(tx, lambda: torch.cuda._maybe_exchange_device(prev_idx)) self.proxy = tx.output.create_node( @@ -705,21 +772,24 @@ def enter(self, tx): (*self.target_values,), {}, ) + return variables.ConstantVariable.create(None) - def module_name(self): + def module_name(self) -> str: return "torch.cuda" - def fn_name(self): + def fn_name(self) -> str: return "device" class TorchFunctionDisableVariable(ContextWrappingVariable): """represents whether torch function overrides are enabled or not""" - _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE) + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.TORCH_FUNCTION_STATE) # type: ignore[arg-type] @staticmethod - def create(tx: "InstructionTranslator", **kwargs): + def create( + tx: "InstructionTranslator", **kwargs: Any + ) -> "TorchFunctionDisableVariable": var = TorchFunctionDisableVariable( target_values=[], initial_values=[], @@ -728,10 +798,14 @@ def create(tx: "InstructionTranslator", **kwargs): return var def __init__( - self, target_values, initial_values=None, only_subclass=True, **kwargs + self, + target_values: Sized, + initial_values: Optional[Sized] = None, + only_subclass: bool = True, + **kwargs: Any, ) -> None: assert len(target_values) == 0 - assert len(initial_values) == 0 + assert initial_values is not None and len(initial_values) == 0 from ..symbolic_convert import InstructionTranslator tx = InstructionTranslator.current_tx() @@ -748,10 +822,14 @@ def __init__( ) install_guard(self._guards_singleton) - def set_cleanup_hook(self, tx: "InstructionTranslator", fn=None): - if fn is None: + def set_cleanup_hook( + self, + tx: "InstructionTranslator", + cleanup_fn: Optional[Callable[..., Any]] = None, + ) -> None: + if cleanup_fn is None: - def fn(): + def cleanup_fn() -> None: tx.symbolic_torch_function_state.torch_function_subclass_enabled = ( self.initial_torch_function_subclass_enabled ) @@ -760,19 +838,19 @@ def fn(): self.initial_torch_function_subclass_enabled ) - self.cleanup_fn = fn + self.cleanup_fn = cleanup_fn tx.output.add_cleanup_hook(self.cleanup) - def _call_func(self, tx: "InstructionTranslator", values): + def _call_func(self, tx: "InstructionTranslator", values: Sized) -> None: assert len(values) == 0 tx.symbolic_torch_function_state.torch_function_subclass_enabled = False if not self.only_subclass: tx.symbolic_torch_function_state.torch_function_mode_enabled = False - def module_name(self): + def module_name(self) -> str: return "torch._C" - def fn_name(self): + def fn_name(self) -> str: if self.only_subclass: return "DisableTorchFunctionSubclass" return "DisableTorchFunction" @@ -782,11 +860,14 @@ class DeterministicAlgorithmsVariable(ContextWrappingVariable): """represents torch.{are_deterministic_algorithms_enabled,use_deterministic_algorithms}()""" _guards_singleton = Guard( - GlobalStateSource(), GuardBuilder.DETERMINISTIC_ALGORITHMS + GlobalStateSource(), + GuardBuilder.DETERMINISTIC_ALGORITHMS, # type: ignore[arg-type] ) @staticmethod - def create(tx: "InstructionTranslator", target_value, **kwargs): + def create( + tx: "InstructionTranslator", target_value: bool, **kwargs: Any + ) -> "DeterministicAlgorithmsVariable": var = DeterministicAlgorithmsVariable( target_values=[target_value], initial_values=[torch.are_deterministic_algorithms_enabled()], @@ -796,16 +877,21 @@ def create(tx: "InstructionTranslator", target_value, **kwargs): var.set_cleanup_hook(tx) return var - def __init__(self, target_values, initial_values=None, **kwargs) -> None: + def __init__( + self, + target_values: Sequence[bool], + initial_values: Optional[Sequence[bool]] = None, + **kwargs: Any, + ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) install_guard(self._guards_singleton) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: return variables.ConstantVariable.create(None) - def _call_func(self, tx: "InstructionTranslator", values): + def _call_func(self, tx: "InstructionTranslator", values: Sequence[bool]) -> None: assert len(values) == 1 value = values[0] tx.output.create_node( @@ -813,10 +899,10 @@ def _call_func(self, tx: "InstructionTranslator", values): ) torch._C._set_deterministic_algorithms(value) - def module_name(self): + def module_name(self) -> str: return "torch" - def fn_name(self): + def fn_name(self) -> str: return "use_deterministic_algorithms" @@ -824,7 +910,9 @@ class DisabledSavedTensorsHooksVariable(ContextWrappingVariable): """represents torch.autograd.graph.disable_saved_tensors_hook.""" @staticmethod - def create(tx: "InstructionTranslator", target_value, **kwargs): + def create( + tx: "InstructionTranslator", target_value: Optional[str], **kwargs: Any + ) -> "DisabledSavedTensorsHooksVariable": var = DisabledSavedTensorsHooksVariable( target_values=[target_value], initial_values=[ @@ -836,15 +924,22 @@ def create(tx: "InstructionTranslator", target_value, **kwargs): var.set_cleanup_hook(tx) return var - def __init__(self, target_values, initial_values=None, **kwargs) -> None: + def __init__( + self, + target_values: Sequence[Optional[str]], + initial_values: Optional[Sequence[Optional[str]]] = None, + **kwargs: Any, + ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: return variables.ConstantVariable.create(None) - def _call_func(self, tx: "InstructionTranslator", values): + def _call_func( + self, tx: "InstructionTranslator", values: Sequence[Optional[str]] + ) -> None: assert len(values) == 1 value = values[0] if value is not None: @@ -865,16 +960,20 @@ def _call_func(self, tx: "InstructionTranslator", values): ) torch._C._autograd._saved_tensors_hooks_enable() - def module_name(self): + def module_name(self) -> str: return "torch.autograd.graph" - def fn_name(self): + def fn_name(self) -> str: return "disable_saved_tensors_hooks" class AutocastModeVariable(ContextWrappingVariable): @staticmethod - def create(func, args, kwargs): + def create( + func: torch.amp.autocast_mode.autocast, + args: Sequence[Any], + kwargs: dict[str, Any], + ) -> "AutocastModeVariable": assert func in [ torch.amp.autocast_mode.autocast, torch.cuda.amp.autocast, @@ -905,30 +1004,37 @@ def create(func, args, kwargs): var = AutocastModeVariable(target_values, initial_values=None, **kwargs) return var - def __init__(self, target_values, initial_values=None, **kwargs) -> None: + def __init__( + self, + target_values: Sequence[Any], + initial_values: Optional[Any] = None, + **kwargs: Any, + ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) - self.target_values = target_values - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup_assert() tx.output.create_node( "call_function", torch.amp._exit_autocast, (self.proxy,), {} ) return variables.ConstantVariable.create(None) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: ctx = torch.amp._enter_autocast(*self.target_values) self.set_cleanup_hook(tx, lambda: torch.amp._exit_autocast(ctx)) self.proxy = tx.output.create_node( "call_function", torch.amp._enter_autocast, (*self.target_values,), {} ) + return variables.ConstantVariable.create(None) - def module_name(self): + def module_name(self) -> str: return "torch.amp.autocast_mode" - def fn_name(self): + def fn_name(self) -> str: return "autocast" @@ -937,20 +1043,22 @@ class NullContextVariable(ContextWrappingVariable): This class represents Python contextlib.nullcontext. """ - def __init__(self, target_values=None, **kwargs) -> None: + def __init__(self, target_values: Optional[Any] = None, **kwargs: Any) -> None: super().__init__(target_values=target_values, **kwargs) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: none = variables.ConstantVariable.create(None) return self.target_values if self.target_values else none - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: return variables.ConstantVariable.create(None) - def module_name(self): + def module_name(self) -> str: return "contextlib" - def fn_name(self): + def fn_name(self) -> str: return "nullcontext" @@ -963,22 +1071,24 @@ class ProfilerContextVariable(ContextWrappingVariable): than `None`, per implementation of the torch objects. """ - def __init__(self, **kwargs) -> None: + def __init__(self, **kwargs: Any) -> None: super().__init__(target_values=None, **kwargs) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: return self - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: return variables.ConstantVariable.create(None) - def module_name(self): + def module_name(self) -> str: return "contextlib" - def fn_name(self): + def fn_name(self) -> str: return "nullcontext" - def reconstruct(self, cg): + def reconstruct(self, cg: "PyCodegen") -> None: unimplemented_v2( gb_type="torch.profiler object escaped from compiled region", context=str(self), @@ -995,27 +1105,37 @@ class PreserveVersionContextVariable(ContextWrappingVariable): """ @staticmethod - def _create_lambda_from_tensors(tx, tensors): + def _create_lambda_from_tensors( + tx: "InstructionTranslator", + tensors: VariableTracker, + ) -> "PreserveVersionContextVariable": if isinstance(tensors, variables.TensorVariable): versions = variables.TupleVariable( [x.var_getattr(tx, "_version") for x in [tensors]] ) - tensors = variables.TupleVariable([tensors]) + tensors_tuple = variables.TupleVariable([tensors]) else: + assert isinstance(tensors, variables.TupleVariable) versions = variables.TupleVariable( [x.var_getattr(tx, "_version") for x in tensors.items] ) - return PreserveVersionContextVariable(tensors, versions) + tensors_tuple = tensors + return PreserveVersionContextVariable(tensors_tuple, versions) @staticmethod - def constructor(tx): + def constructor(tx: "InstructionTranslator") -> VariableTracker: return variables.LambdaVariable( lambda tensors: PreserveVersionContextVariable._create_lambda_from_tensors( tx, tensors ) ) - def __init__(self, tensors, prev_versions, **kwargs) -> None: + def __init__( + self, + tensors: VariableTracker, + prev_versions: VariableTracker, + **kwargs: Any, + ) -> None: kwargs.setdefault("target_values", None) super().__init__(**kwargs) self.tensors = tensors @@ -1028,17 +1148,19 @@ def __init__(self, tensors, prev_versions, **kwargs) -> None: ): self.prev_versions = variables.TupleVariable([self.prev_versions]) - def enter(self, tx): - pass + def enter(self, tx: "InstructionTranslator") -> VariableTracker: + return variables.ConstantVariable.create(None) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: from ..tensor_version_op import _unsafe_set_version_counter return variables.TorchInGraphFunctionVariable( _unsafe_set_version_counter ).call_function(tx, [self.tensors, self.prev_versions], {}) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: unimplemented_v2( gb_type="torch.autograd._unsafe_preserve_version_counter escaped from compiled region", context=str(self), @@ -1053,10 +1175,15 @@ def reconstruct(self, codegen: "PyCodegen"): class FSDPParamGroupUseTrainingStateVariable(ContextWrappingVariable): - _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FSDP_TRAINING_STATE) + _guards_singleton = Guard(GlobalStateSource(), GuardBuilder.FSDP_TRAINING_STATE) # type: ignore[arg-type] @staticmethod - def create(tx: "InstructionTranslator", param_group_var, target_value, **kwargs): + def create( + tx: "InstructionTranslator", + param_group_var: Any, + target_value: Any, + **kwargs: Any, + ) -> "FSDPParamGroupUseTrainingStateVariable": var = FSDPParamGroupUseTrainingStateVariable( param_group_var=param_group_var, target_values=[target_value], @@ -1066,7 +1193,11 @@ def create(tx: "InstructionTranslator", param_group_var, target_value, **kwargs) return var def __init__( - self, param_group_var, target_values, initial_values=None, **kwargs + self, + param_group_var: Any, + target_values: Sequence[Any], + initial_values: Optional[Sequence[Any]] = None, + **kwargs: Any, ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs @@ -1074,24 +1205,27 @@ def __init__( self.param_group_var = param_group_var install_guard(self._guards_singleton) - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: self._call_func(tx, self.target_values) return variables.ConstantVariable.create(None) - def exit(self, tx: "InstructionTranslator", *args): - self._call_func(tx, self.initial_values) + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: + self._call_func(tx, self.initial_values) # type: ignore[arg-type] return variables.ConstantVariable.create(None) def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ): - self._call_func(tx, self.initial_values) # undo eager initialization + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: + # undo eager initialization + self._call_func(tx, self.initial_values) # type: ignore[arg-type] return super().call_function(tx, args, kwargs) - def _call_func(self, tx: "InstructionTranslator", values): + def _call_func(self, tx: "InstructionTranslator", values: Sequence[Any]) -> None: assert len(values) == 1 value = values[0] if self.param_group_var.value._training_state != value: @@ -1106,10 +1240,10 @@ def _call_func(self, tx: "InstructionTranslator", values): ) self.param_group_var.value._training_state = value - def module_name(self): + def module_name(self) -> str: return "torch.distributed.fsdp._fully_shard._fsdp_param_group.FSDPParamGroup" - def fn_name(self): + def fn_name(self) -> str: return "use_training_state" @@ -1117,7 +1251,12 @@ class SDPAKernelVariable(ContextWrappingVariable): """represents torch.nn.attention.sdpa_kernel""" @staticmethod - def create(tx: "InstructionTranslator", backends, set_priority=False, **kwargs): + def create( + tx: "InstructionTranslator", + backends: Any, + set_priority: bool = False, + **kwargs: Any, + ) -> "SDPAKernelVariable": if isinstance(backends, torch.nn.attention.SDPBackend): backends = [backends] var = SDPAKernelVariable( @@ -1131,9 +1270,9 @@ def create(tx: "InstructionTranslator", backends, set_priority=False, **kwargs): def __init__( self, target_values: list[torch.nn.attention.SDPBackend], - initial_values=None, + initial_values: Any = None, set_priority: bool = False, - **kwargs, + **kwargs: Any, ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs @@ -1141,7 +1280,10 @@ def __init__( self.set_priority = set_priority @staticmethod - def _backends_to_nodes(tx, backends): + def _backends_to_nodes( + tx: "InstructionTranslator", + backends: list[Any], + ) -> list[Any]: # convert to/from string in order to bake the backend into FX graph nodes = [ tx.output.create_node( @@ -1154,7 +1296,7 @@ def _backends_to_nodes(tx, backends): ] return nodes - def enter(self, tx): + def enter(self, tx: "InstructionTranslator") -> VariableTracker: self.prev_backends = torch.nn.attention._cur_sdpa_kernel_backends( with_priority=self.set_priority ) @@ -1176,7 +1318,9 @@ def enter(self, tx): ) return variables.ConstantVariable.create(None) - def exit(self, tx: "InstructionTranslator", *args): + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: self.cleanup_assert() arg = self._backends_to_nodes(tx, self.prev_backends) tx.output.create_node( @@ -1187,12 +1331,12 @@ def exit(self, tx: "InstructionTranslator", *args): ) return variables.ConstantVariable.create(None) - def module_name(self): + def module_name(self) -> str: return "torch.nn.attention" # use a private version of sdpa_kernel that accepts variadic arguments # since dynamo reconstructs the contents of target_values one-by-one - def fn_name(self): + def fn_name(self) -> str: return "_sdpa_kernel_variadic" @@ -1206,12 +1350,16 @@ class FxTracebackAnnotateVariable(ContextWrappingVariable): __exit__ method (instead of tracing). """ - def __init__(self, target_values, initial_values=None, **kwargs) -> None: + def __init__( + self, target_values: Any, initial_values: Any = None, **kwargs: Any + ) -> None: super().__init__( target_values=target_values, initial_values=initial_values, **kwargs ) - def enter(self, tx, *args): + def enter( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: # Run the annotation ctx manager in eager. Also ensure that # preserve_node_meta context manager is setup. This is important to pass # on the metadata to the create_proxy nodes. @@ -1221,13 +1369,13 @@ def enter(self, tx, *args): self.set_cleanup_hook(tx, lambda: stack.close()) return variables.ConstantVariable.create(None) - def module_name(self): + def module_name(self) -> str: return "torch.fx.traceback" - def fn_name(self): + def fn_name(self) -> str: return "annotate" - def reconstruct_type(self, codegen: "PyCodegen"): + def reconstruct_type(self, codegen: "PyCodegen") -> None: unimplemented_v2( gb_type="torch.fx.traceback.annotate escaped from compiled region", context=str(self), @@ -1243,50 +1391,52 @@ class DynamoConfigPatchVariable(ContextWrappingVariable): # NOTE: no need to guard on dynamo config because dynamo config should not affect soundness # (though it may affect tracing behavior) - def __init__(self, target_values, **kwargs) -> None: - target_values = tuple(target_values.items()) - super().__init__(target_values=(target_values,), initial_values=None, **kwargs) - self.initial_values = {} - for key, _ in target_values: - self.initial_values[key] = torch._dynamo.config.__getattr__(key) - self.initial_values = (tuple(self.initial_values.items()),) - - def _call_func(self, tx: "InstructionTranslator", values): + def __init__(self, target_values: dict[str, Any], **kwargs: Any) -> None: + target_values_tuple = tuple(target_values.items()) + super().__init__( + target_values=(target_values_tuple,), initial_values=None, **kwargs + ) + initial_values_dict = {} + for key, _ in target_values_tuple: + initial_values_dict[key] = torch._dynamo.config.__getattr__(key) # type: ignore[attr-defined] + self.initial_values = (tuple(initial_values_dict.items()),) + + def _call_func(self, tx: "InstructionTranslator", values: Any) -> None: assert len(values) == 1 value = values[0] # manually patch dynamo config for key, val in value: - torch._dynamo.config.__setattr__(key, val) + torch._dynamo.config.__setattr__(key, val) # type: ignore[attr-defined] # No need to keep track of global side effects because # dynamo will properly restore this context manager for # unsupported instructions and continuation functions. # Dynamo config also should not affect the semantics of the compiled graph. - def module_name(self): + def module_name(self) -> str: return "torch._dynamo" - def fn_name(self): + def fn_name(self) -> str: return "patch_dynamo_config" class ErrorOnGraphBreakVariable(ContextWrappingVariable): """represents torch._dynamo.error_on_graph_break""" - def __init__(self, error_on_graph_break, **kwargs) -> None: + def __init__(self, error_on_graph_break: bool, **kwargs: Any) -> None: super().__init__( target_values=(error_on_graph_break,), initial_values=(_get_error_on_graph_break(),), **kwargs, ) - def _call_func(self, tx: "InstructionTranslator", values): + def _call_func(self, tx: "InstructionTranslator", values: Sequence[bool]) -> None: assert len(values) == 1 _set_error_on_graph_break(values[0]) - def module_name(self): + def module_name(self) -> str: return "torch._dynamo" - def fn_name(self): + def fn_name(self) -> str: return "error_on_graph_break" @@ -1294,7 +1444,7 @@ class WithEnterFunctionVariable(VariableTracker): def __init__( self, ctx: Union[ContextWrappingVariable, GenericContextWrappingVariable], - **kwargs, + **kwargs: Any, ) -> None: super().__init__(**kwargs) self.ctx = ctx @@ -1302,16 +1452,17 @@ def __init__( def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: assert not args assert not kwargs # NOTE: we assume that the instruction immediately after the current CALL instruction # is the first instruction of the block. + # pyrefly: ignore [bad-argument-type] return tx.enter_ctx(self.ctx, tx.current_instruction) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: try: type_str = f"{self.ctx.module_name()}.{self.ctx.fn_name()}" except NotImplementedError: @@ -1339,8 +1490,8 @@ class WithExitFunctionVariable(VariableTracker): def __init__( self, ctx: Union[ContextWrappingVariable, GenericContextWrappingVariable], - target, - **kwargs, + target: Any, + **kwargs: Any, ) -> None: super().__init__(**kwargs) assert isinstance( @@ -1352,27 +1503,29 @@ def __init__( def call_function( self, tx: "InstructionTranslator", - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + args: Sequence[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: assert not kwargs return self.ctx.exit(tx, *args) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: # Note here we reconstruct the context manager rather than the # exit function. The handler generated by BlockStackEntry # will re-enter the context in the resume function. - self.ctx.reconstruct_type(codegen) + self.ctx.reconstruct_type(codegen) # type: ignore[attr-defined] if codegen.tx.output.partial_convert: if sys.version_info >= (3, 11): codegen.append_output(create_instruction("PUSH_NULL")) if sys.version_info < (3, 13): codegen.append_output(create_instruction("SWAP", arg=2)) + # We rely on classes subtyping `GenericContextWrappingVariable` + # to implement these fns and have these attributes codegen.extend_output( - [codegen.create_load_const(val) for val in self.ctx.target_values] + [codegen.create_load_const(val) for val in self.ctx.target_values] # type: ignore[arg-type] ) codegen.extend_output( - create_call_function(len(self.ctx.target_values), False) + create_call_function(len(self.ctx.target_values), False) # type: ignore[arg-type] ) codegen.append_output(create_setup_with(self.target)) codegen.append_output(create_instruction("POP_TOP")) diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index fbc0eed3a99ff..c353181eb8029 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -116,11 +116,7 @@ def create( **kwargs, ) - def __init__( - self, - stream: Optional["StreamVariable"], - **kwargs: dict[str, Any], - ) -> None: + def __init__(self, stream: Optional["StreamVariable"], **kwargs: Any) -> None: self.stream = stream super().__init__( target_values={"stream": self.get_stream().user_object_index}, @@ -129,14 +125,16 @@ def __init__( ) def enter( - self, tx: "InstructionTranslator", *args: tuple[Any] - ) -> "VariableTracker": + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: # to stream, from stream is the order of the arguments # we are entering the target, and leaving the initial stream tx.symbolic_stream_state.enter_stream(self.get_stream()) return super().enter(tx) - def exit(self, tx: "InstructionTranslator", *args: tuple[Any]) -> "VariableTracker": + def exit( + self, tx: "InstructionTranslator", *args: VariableTracker + ) -> VariableTracker: # to stream, from stream is the order of the arguments # we are leaving the target, and entering the initial stream tx.symbolic_stream_state.exit_stream() @@ -182,7 +180,7 @@ def call_method( name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], - ) -> "VariableTracker": + ) -> VariableTracker: assert hasattr(self.value, name), f"no stream method found named {name}" from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs diff --git a/torch/_dynamo/variables/torch.py b/torch/_dynamo/variables/torch.py index c2e3df8e4adce..be28fe9269f44 100644 --- a/torch/_dynamo/variables/torch.py +++ b/torch/_dynamo/variables/torch.py @@ -408,6 +408,7 @@ def call_function( torch.cuda.amp.autocast, torch.cpu.amp.autocast, ): + # pyrefly: ignore [bad-argument-type] return AutocastModeVariable.create(self.value, args, kwargs) elif self.value in ( # NOTE any class added here must align with the semantic diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 378e9258459f5..fa8412146a427 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -164,7 +164,8 @@ def __init__( if value is not None: super().__init__(value, **kwargs) self.value = value - self.cm_obj = value # needed for BC with calling enter from CM code + # needed for BC with calling enter from CM code + self.cm_obj = value # type: ignore[assignment] self.source = source # type: ignore[assignment] def reconstruct(self, codegen: "PyCodegen") -> None: From 4271ffe91849335ffbcc2014c948694f8ec107fd Mon Sep 17 00:00:00 2001 From: Natalia Gimelshein Date: Wed, 5 Nov 2025 00:20:24 +0000 Subject: [PATCH 037/130] don't produce invalid grid configs (#166974) Proper fix for #164048, fixes gather too, reverts #164049 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166974 Approved by: https://github.com/eqy --- aten/src/ATen/native/cuda/IndexKernel.cu | 15 +++------------ aten/src/ATen/native/cuda/IndexKernelUtils.cu | 15 +++++++++------ test/test_cuda.py | 7 ------- test/test_scatter_gather_ops.py | 8 ++++++-- 4 files changed, 18 insertions(+), 27 deletions(-) diff --git a/aten/src/ATen/native/cuda/IndexKernel.cu b/aten/src/ATen/native/cuda/IndexKernel.cu index 927af661396cd..db85f62c8d124 100644 --- a/aten/src/ATen/native/cuda/IndexKernel.cu +++ b/aten/src/ATen/native/cuda/IndexKernel.cu @@ -5,7 +5,6 @@ #include #include #include -#include #include #include #include @@ -74,7 +73,6 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co char* const out_ptr = static_cast(iter.data_ptr(0)); char* const in_ptr = static_cast(iter.data_ptr(1)); - if (is_gather_like && num_indices==1) { const size_t element_size = iter.element_size(0); constexpr size_t alignment = 16; @@ -84,16 +82,9 @@ void gpu_index_kernel(TensorIteratorBase& iter, const IntArrayRef index_size, co auto ind_dim_size = index_size[0]; auto inp_stride_bytes = index_stride[0]; auto out_stride_bytes = iter.strides(0)[1]; - // avoid grid overflow in the fast kernel - const int64_t vec_chunks = ceil_div(slice_size, alignment); - const int64_t blocks_per_slice_upper = ceil_div(vec_chunks, (int64_t)launch_size_nd); - const int max_grid_y = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; - // if it's an eligible grid we use the fast path, otherwise default to slower path - if (blocks_per_slice_upper <= max_grid_y) { - at::native::vectorized_gather_kernel_launch(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind, - slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true); - return; - } + at::native::vectorized_gather_kernel_launch(out_ptr, in_ptr, (int64_t*)iter.data_ptr(2), num_ind, + slice_size, ind_dim_size, inp_stride_bytes, out_stride_bytes, /*allow_neg_indices*/true); + return; } } diff --git a/aten/src/ATen/native/cuda/IndexKernelUtils.cu b/aten/src/ATen/native/cuda/IndexKernelUtils.cu index 8343c60418952..1e998251dd7be 100644 --- a/aten/src/ATen/native/cuda/IndexKernelUtils.cu +++ b/aten/src/ATen/native/cuda/IndexKernelUtils.cu @@ -13,11 +13,12 @@ __global__ void vectorized_gather_kernel(char * out, char * inp, index_t * idx, if (allow_neg_indices) { ind = (ind < 0) ? ind + ind_dim_size : ind; } - CUDA_KERNEL_ASSERT_VERBOSE(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds", "Expected 0 <= index < ind_dim_size(%ld), but got index = %ld", ind_dim_size, ind); - int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; // off is guaranteed to be within int32 limits - if (off >= slice_size) return; - auto vec = at::native::memory::ld_vec(inp + ind * inp_stride + off); - at::native::memory::st_vec(out + blockIdx.x * (int32_t)out_stride + off, vec); // out offset is guaranteed to be within int32 limits + CUDA_KERNEL_ASSERT_VERBOSE(ind >=0 && ind < ind_dim_size && "vectorized gather kernel index out of bounds"); + // off is guaranteed to be within int32 limits + for (int32_t off = (blockDim.x * blockIdx.y + threadIdx.x) * Alignment; off < slice_size; off += blockDim.x * gridDim.y * Alignment) { + auto vec = at::native::memory::ld_vec(inp + ind * inp_stride + off); + at::native::memory::st_vec(out + blockIdx.x * (int32_t)out_stride + off, vec); // out offset is guaranteed to be within int32 limits + } } @@ -30,7 +31,9 @@ void vectorized_gather_kernel_launch(char * out, char * inp, index_t * idx, int auto num_threads = at::round_up( at::ceil_div(slice_size_in_bytes, Alignment), static_cast(C10_WARP_SIZE)); - dim3 grid = {static_cast(num_ind), static_cast(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), 1}; + uint32_t grid_y = at::cuda::getCurrentDeviceProperties()->maxGridSize[1]; + grid_y = std::min(static_cast(at::ceil_div(slice_size_in_bytes, max_num_threads * Alignment)), grid_y); + dim3 grid = {static_cast(num_ind), grid_y, 1}; auto block = std::min(max_num_threads, num_threads); vectorized_gather_kernel<<>>(out, inp, idx, num_ind, slice_size_in_bytes, ind_dim_size, inp_stride_bytes, out_stride_bytes, allow_neg_indices); diff --git a/test/test_cuda.py b/test/test_cuda.py index 00c3b00d6049c..329261fba7d3a 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -1474,13 +1474,6 @@ def test_huge_index(self): res_cpu = src.cpu()[idx.cpu()] self.assertEqual(res.cpu(), res_cpu) - def test_fast_index_overflow(self): - src = torch.randint(0, 20, (4, 87, 1056, 736), device="cuda") - indices = torch.tensor([True, False, False, True], device="cuda") - res = src[indices] - res_cpu = src.cpu()[indices.cpu()] - self.assertEqual(res.cpu(), res_cpu) - def test_randint_randomness_for_large_range(self) -> None: # For large ranges, randint generation is slightly different. This lead to a subtle bug where some Philox # offsets were not calculated correctly, resulting in reused random states. diff --git a/test/test_scatter_gather_ops.py b/test/test_scatter_gather_ops.py index ba967c142f1e7..96768f34affb0 100644 --- a/test/test_scatter_gather_ops.py +++ b/test/test_scatter_gather_ops.py @@ -6,7 +6,7 @@ from torch.testing import make_tensor from torch.testing._internal.common_utils import \ - (parametrize, run_tests, TestCase, DeterministicGuard, TEST_WITH_ROCM) + (parametrize, run_tests, TestCase, DeterministicGuard, TEST_WITH_ROCM, serialTest) from torch.testing._internal.common_device_type import \ (instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA, toleranceOverride, tol,) @@ -65,10 +65,12 @@ def test_gather(self, device, dtype): actual = torch.gather(src, 2, idx) self.assertEqual(actual, expected, atol=0, rtol=0) + @serialTest() @dtypes(torch.int8, torch.bfloat16) def test_gather_large(self, device, dtype): # test larger shapes to check vectorized implementation - for (m, n, k) in ((4096, 3072, 4096), (4096, 3072, 4100)): + for (m, n, k) in ((4096, 3072, 4096), (4096, 3072, 4100), (4, 4, 16384 * 8192)): + torch.cuda.empty_cache() src = make_tensor((m, k), device=device, dtype=dtype) alloc0 = torch.empty(src.nelement() * 2, device=device, dtype=dtype) discontig = alloc0.view(m, 2 * k)[:, ::2].copy_(src) @@ -111,6 +113,8 @@ def test_gather_large(self, device, dtype): self.assertEqual(res_ind, ref, atol=0, rtol=0) res_gather = torch.gather(misaligned1, dim=dim, index=ind) self.assertEqual(res_gather, ref, atol=0, rtol=0) + del src, alloc0, alloc1, alloc2 + del discontig, misaligned, misaligned1 # test gather along 1st dim that can accidentally trigger fast path # because due to index dimension in the gather dim being 1 # an unexpected squashing in tensorIterator happens From f2fbc81c506d4497d1505a7c27d949e4fc4da8d6 Mon Sep 17 00:00:00 2001 From: Oguz Ulgen Date: Tue, 4 Nov 2025 13:58:39 -0800 Subject: [PATCH 038/130] [RFC] Add experimental Pallas TorchInductor backend (#166822) Very simple Pallas TorchInductor backend Given ``` import torch def f(x, y): return x.sin() + y torch._inductor.config.cuda_backend="pallas" x = torch.randn(4).cuda() y = torch.randn(4).cuda() compiled = torch.compile(f, backend="inductor", fullgraph=True) torch.testing.assert_close(compiled(x, y), f(x, y)) ``` it outputs ``` import torch import jax import jax.numpy as jnp from jax.experimental import pallas as pl from torch.utils import dlpack as torch_dlpack def pallas_fused_add_sin_56b646d2_kernel(in_ptr0, in_ptr1, out_ptr0): tmp0 = in_ptr0[...] tmp1 = jnp.sin(tmp0) tmp2 = in_ptr1[...] tmp3 = tmp1 + tmp2 out_ptr0[...] = tmp3 def pallas_fused_add_sin_56b646d2_main(in_ptr0, in_ptr1, out_ptr0, stream=None): # Convert Torch -> JAX for inputs in_ptr0_jax = jax.dlpack.from_dlpack(torch_dlpack.to_dlpack(in_ptr0)) in_ptr1_jax = jax.dlpack.from_dlpack(torch_dlpack.to_dlpack(in_ptr1)) # Prepare output spec from PyTorch tensor # Map PyTorch dtype to JAX dtype string _torch_dtype_to_jax = { torch.float32: jnp.float32, torch.float64: jnp.float64, torch.float16: jnp.float16, torch.int32: jnp.int32, torch.int64: jnp.int64, torch.int16: jnp.int16, torch.int8: jnp.int8, torch.uint8: jnp.uint8, torch.bool: jnp.bool_, } out_spec = jax.ShapeDtypeStruct(out_ptr0.shape, _torch_dtype_to_jax[out_ptr0.dtype]) compiled = pl.pallas_call( lambda *refs: pallas_fused_add_sin_56b646d2_kernel(*refs), out_shape=out_spec, grid=(1,), ) res = compiled(in_ptr0_jax, in_ptr1_jax) # Copy result back into the provided torch output tensor res_t = torch_dlpack.from_dlpack(jax.dlpack.to_dlpack(res)) out_ptr0.copy_(res_t) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166822 Approved by: https://github.com/jansel ghstack dependencies: #166976, #166982 --- test/inductor/test_pallas.py | 354 ++++++++++++++++++ torch/_inductor/async_compile.py | 36 ++ torch/_inductor/codegen/common.py | 2 + torch/_inductor/codegen/pallas.py | 424 ++++++++++++++++++++++ torch/_inductor/config.py | 5 +- torch/testing/_internal/inductor_utils.py | 9 +- torch/utils/_pallas.py | 82 +++++ 7 files changed, 907 insertions(+), 5 deletions(-) create mode 100644 test/inductor/test_pallas.py create mode 100644 torch/_inductor/codegen/pallas.py create mode 100644 torch/utils/_pallas.py diff --git a/test/inductor/test_pallas.py b/test/inductor/test_pallas.py new file mode 100644 index 0000000000000..2d4e6af002ab0 --- /dev/null +++ b/test/inductor/test_pallas.py @@ -0,0 +1,354 @@ +# Owner(s): ["oncall: pt2"] +import functools +import sys +import unittest + +import torch +import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools +from torch._dynamo.testing import make_test_cls_with_patches +from torch._inductor import config +from torch._inductor.test_case import run_tests, TestCase +from torch._inductor.utils import run_and_get_code +from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS +from torch.testing._internal.inductor_utils import HAS_PALLAS +from torch.utils._triton import has_triton + + +if IS_WINDOWS and IS_CI: + sys.stderr.write( + "Windows CI does not have necessary dependencies for test_torchinductor yet\n" + ) + if __name__ == "__main__": + sys.exit(0) + raise unittest.SkipTest("requires sympy/functorch/filelock") + + +try: + from . import test_torchinductor +except ImportError: + import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library + + +test_classes = {} + + +def make_pallas(cls): + """Create a test class variant that uses Pallas backend.""" + suffix = "_pallas" + cls_prefix = "Pallas" + + test_class = make_test_cls_with_patches( + cls, + cls_prefix, + suffix, + (config, "cuda_backend", "pallas"), + xfail_prop="_expected_failure_pallas", + ) + + test_classes[test_class.__name__] = test_class + # REMOVING THIS LINE WILL STOP TESTS FROM RUNNING + globals()[test_class.__name__] = test_class + test_class.__module__ = __name__ + return test_class + + +@unittest.skipUnless(HAS_PALLAS, "requires jax and pallas") +class PallasTests(TestCase): + """Basic tests for Pallas backend functionality.""" + + def test_simple_add(self): + """Test basic element-wise addition.""" + + def fn(a, b): + return a + b + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + a = torch.randn(1024, device="cuda") + b = torch.randn(1024, device="cuda") + result = compiled(a, b) + expected = fn(a, b) + self.assertEqual(result, expected) + + def test_simple_mul(self): + """Test basic element-wise multiplication.""" + + def fn(a, b): + return a * b + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + a = torch.randn(1024, device="cuda") + b = torch.randn(1024, device="cuda") + result = compiled(a, b) + expected = fn(a, b) + self.assertEqual(result, expected) + + def test_sin(self): + """Test sin operation.""" + + def fn(x): + return torch.sin(x) + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(1024, device="cuda") + result = compiled(x) + expected = fn(x) + self.assertEqual(result, expected) + + def test_fused_ops(self): + """Test fused operations (sin + add).""" + + def fn(x, y): + return x.sin() + y + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(1024, device="cuda") + y = torch.randn(1024, device="cuda") + result = compiled(x, y) + expected = fn(x, y) + self.assertEqual(result, expected) + + def test_exp_log(self): + """Test exp and log operations.""" + + def fn(x): + return torch.log(torch.exp(x)) + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(1024, device="cuda") + result = compiled(x) + expected = fn(x) + self.assertEqual(result, expected) + + def test_sqrt(self): + """Test sqrt operation.""" + + def fn(x): + return torch.sqrt(x) + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(1024, device="cuda").abs() # Ensure positive for sqrt + result = compiled(x) + expected = fn(x) + self.assertEqual(result, expected) + + def test_tanh(self): + """Test tanh operation.""" + + def fn(x): + return torch.tanh(x) + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(1024, device="cuda") + result = compiled(x) + expected = fn(x) + self.assertEqual(result, expected) + + def test_abs_neg(self): + """Test abs and neg operations.""" + + def fn(x): + return torch.abs(-x) + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(1024, device="cuda") + result = compiled(x) + expected = fn(x) + self.assertEqual(result, expected) + + def test_maximum_minimum(self): + """Test maximum and minimum operations.""" + + def fn(a, b): + return torch.maximum(a, b) + torch.minimum(a, b) + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + a = torch.randn(1024, device="cuda") + b = torch.randn(1024, device="cuda") + result = compiled(a, b) + expected = fn(a, b) + self.assertEqual(result, expected) + + @unittest.skipUnless(has_triton(), "requires triton") + @unittest.skip("Random ops not yet implemented in Pallas backend") + def test_random_consistency(self): + """Test that random number generation is consistent across backends.""" + seed = 1234 + shape = (3, 3) + dtype = torch.float32 + + for rand_fn in [ + functools.partial(torch.rand, shape, dtype=dtype, device="cuda"), + functools.partial(torch.randn, shape, dtype=dtype, device="cuda"), + ]: + + @torch.compile(backend="inductor", options={"cuda_backend": "pallas"}) + def get_rand_pallas(): + return rand_fn() + + @torch.compile(backend="inductor", options={"cuda_backend": "triton"}) + def get_rand_triton(): + return rand_fn() + + torch.manual_seed(seed) + pallas_output = get_rand_pallas() + torch.manual_seed(seed) + triton_output = get_rand_triton() + + self.assertEqual(pallas_output, triton_output) + + def test_compile_options(self): + """Test that Pallas backend is properly configured.""" + + @torch.compile( + backend="inductor", + options={"cuda_backend": "pallas"}, + ) + def pallas_fn(a, b): + return a.sin() + b.cos() + + _, (code,) = run_and_get_code( + pallas_fn, + torch.randn(64, device="cuda"), + torch.randn(64, device="cuda"), + ) + # Verify Pallas-specific code generation + self.assertIn("import jax", code) + self.assertIn("import jax.numpy as jnp", code) + self.assertIn("from jax.experimental import pallas as pl", code) + + def test_2d_tensor(self): + """Test with 2D tensors (though current implementation flattens).""" + + def fn(x, y): + return x + y + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(32, 32, device="cuda") + y = torch.randn(32, 32, device="cuda") + result = compiled(x, y) + expected = fn(x, y) + self.assertEqual(result, expected) + + def test_different_shapes(self): + """Test with different tensor shapes.""" + + def fn(x): + return x * 2.0 + + compiled = torch.compile( + fn, backend="inductor", options={"cuda_backend": "pallas"} + ) + + for shape in [(64,), (128,), (256,), (1024,)]: + x = torch.randn(shape, device="cuda") + result = compiled(x) + expected = fn(x) + self.assertEqual(result, expected) + + def test_contiguous_index_validation(self): + """Test that contiguous index validation works correctly end-to-end.""" + + # Test 1: Contiguous operations should work + def contiguous_add(a, b): + return a + b + + compiled = torch.compile( + contiguous_add, backend="inductor", options={"cuda_backend": "pallas"} + ) + + a = torch.randn(1024, device="cuda") + b = torch.randn(1024, device="cuda") + result = compiled(a, b) + expected = contiguous_add(a, b) + self.assertEqual(result, expected) + + # Test 2: Operations on contiguous tensors should work + def contiguous_mul(x): + return x * 2.0 + + compiled = torch.compile( + contiguous_mul, backend="inductor", options={"cuda_backend": "pallas"} + ) + + x = torch.randn(128, 8, device="cuda") + result = compiled(x) + expected = contiguous_mul(x) + self.assertEqual(result, expected) + + # Test 3: Non-contiguous views will fail at runtime with JAX/Pallas + # This demonstrates that the Pallas backend requires contiguous memory layout + def operate_on_tensor(x): + return x.sin() + + compiled = torch.compile( + operate_on_tensor, backend="inductor", options={"cuda_backend": "pallas"} + ) + + # Create a transposed (non-contiguous) view + x = torch.randn(64, 32, device="cuda") + x_t = x.t() # Non-contiguous view + self.assertFalse(x_t.is_contiguous()) + + # This will fail because JAX/Pallas cannot handle non-contiguous layout via DLPack + # The error indicates that our contiguous-only approach is correct + with self.assertRaises((RuntimeError, Exception)) as cm: + result = compiled(x_t) + + # Verify the error is related to layout/contiguous issues + error_msg = str(cm.exception) + self.assertTrue( + "layout" in error_msg.lower() + or "contiguous" in error_msg.lower() + or "non-default" in error_msg.lower(), + f"Expected layout/contiguous error, got: {error_msg}", + ) + + # But if we make it contiguous first, it should work + x_t_contiguous = x_t.contiguous() + self.assertTrue(x_t_contiguous.is_contiguous()) + result = compiled(x_t_contiguous) + expected = operate_on_tensor(x_t_contiguous) + self.assertEqual(result, expected) + + +# Create test variants using the main test suite +# Note: Only enable GPU tests since Pallas primarily targets GPU +if test_torchinductor.HAS_GPU and HAS_PALLAS: + # Uncomment these to run full test suite with Pallas backend + # make_pallas(test_torchinductor.SweepInputsGPUTest) + # make_pallas(test_torchinductor.GPUTests) + pass + +if __name__ == "__main__": + if HAS_PALLAS: + run_tests(needs="filelock") diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index a2c80002eb928..5ede0cd085010 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -601,6 +601,42 @@ def task(): future = self.submit(task) return LambdaFuture(lambda: future.result()) + def pallas(self, kernel_name: str, source_code: str): + """ + Compile Pallas (JAX experimental) kernels. + + Args: + kernel_name: Name of the kernel to be defined + source_code: Source code of the Pallas kernel, as a string + + Note: + Pallas kernels are Python code that uses JAX and Pallas APIs. + We use the PyCodeCache to write the source code to a file and load it. + """ + from torch._inductor.codegen.pallas import MAIN_SUFFIX, PallasKernelWrapper + + kernel_code_log.info("Pallas Kernel:\n%s", source_code) + + def task(): + key, path = torch._inductor.codecache.PyCodeCache.write(source_code) + mod = torch._inductor.codecache.PyCodeCache.load_by_key_path(key, path) + + # Find our special entry point named function + main_func_name = f"{kernel_name}_{MAIN_SUFFIX}" + if not hasattr(mod, main_func_name): + available = [name for name in dir(mod) if callable(getattr(mod, name))] + raise RuntimeError( + f"Could not find Pallas main kernel function '{main_func_name}'. Available callables: {available}" + ) + + return PallasKernelWrapper(getattr(mod, main_func_name), kernel_path=path) + + if get_compile_threads() <= 1: + return task() + else: + future = self.submit(task) + return LambdaFuture(lambda: future.result()) + def wait(self, scope: dict[str, Any]) -> None: if get_compile_threads() > 1: with dynamo_timed( diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index e6a5c5e8ec176..730c03f1c813c 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -510,6 +510,7 @@ def init_backend_registration() -> None: from .cuda_combined_scheduling import CUDACombinedScheduling from .halide import HalideScheduling from .mps import MetalScheduling + from .pallas import PallasScheduling from .python_wrapper_mtia import PythonWrapperMtia from .triton import TritonScheduling from .wrapper import PythonWrapperCodegen @@ -536,6 +537,7 @@ def init_backend_registration() -> None: cuda_backends = { "triton": CUDACombinedScheduling, "halide": HalideScheduling, + "pallas": PallasScheduling, } register_backend_for_device( "cuda", diff --git a/torch/_inductor/codegen/pallas.py b/torch/_inductor/codegen/pallas.py new file mode 100644 index 0000000000000..1fc8e40724bc0 --- /dev/null +++ b/torch/_inductor/codegen/pallas.py @@ -0,0 +1,424 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +import hashlib +from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING + +import sympy # noqa: TC002 + +import torch # noqa: TC001 +from torch.utils._ordered_set import OrderedSet + +from .. import config +from ..utils import get_fused_kernel_name, get_kernel_metadata +from ..virtualized import V +from .common import BackendFeature, CSEVariable, IndentedBuffer, OpOverrides +from .simd import SIMDKernel, SIMDScheduling + + +if TYPE_CHECKING: + from ..ir import IRNode + from ..scheduler import BaseSchedulerNode + + +# Main function suffix used in generated Pallas code +MAIN_SUFFIX = "main" + +# Logger for Pallas kernel code +kernel_code_log = torch._logging.getArtifactLogger(__name__, "kernel_code") + + +class PallasKernelWrapper: + """Wrapper to provide .run() interface for Pallas kernels""" + + def __init__( + self, kernel_fn: Callable[..., Any], kernel_path: Optional[str] = None + ): + self.kernel_fn = kernel_fn + self.kernel_path = kernel_path + kernel_code_log.info("Pallas kernel path: %s", kernel_path) + + def run(self, *args, stream=None, **kwargs): + """ + Execute the Pallas kernel. + + Args: + *args: Arguments to pass to the kernel function + stream: CUDA stream to pass to the kernel function + **kwargs: Additional keyword arguments for the kernel + + Returns: + Result of the kernel execution + """ + return self.kernel_fn(*args, stream=stream, **kwargs) + + +class Unsupported(RuntimeError): + """Exception raised when an operation is not supported by the Pallas backend.""" + + +class PallasKernelOverrides(OpOverrides): + """ + Map element-wise ops to JAX/Pallas operations. + + For now, we use the default Python operators which are compatible + with JAX numpy broadcasting semantics. + """ + + @staticmethod + def sin(x: str) -> str: + return f"jnp.sin({x})" + + @staticmethod + def cos(x: str) -> str: + return f"jnp.cos({x})" + + @staticmethod + def tan(x: str) -> str: + return f"jnp.tan({x})" + + @staticmethod + def sinh(x: str) -> str: + return f"jnp.sinh({x})" + + @staticmethod + def cosh(x: str) -> str: + return f"jnp.cosh({x})" + + @staticmethod + def tanh(x: str) -> str: + return f"jnp.tanh({x})" + + @staticmethod + def asin(x: str) -> str: + return f"jnp.arcsin({x})" + + @staticmethod + def acos(x: str) -> str: + return f"jnp.arccos({x})" + + @staticmethod + def atan(x: str) -> str: + return f"jnp.arctan({x})" + + @staticmethod + def exp(x: str) -> str: + return f"jnp.exp({x})" + + @staticmethod + def exp2(x: str) -> str: + return f"jnp.exp2({x})" + + @staticmethod + def expm1(x: str) -> str: + return f"jnp.expm1({x})" + + @staticmethod + def log(x: str) -> str: + return f"jnp.log({x})" + + @staticmethod + def log10(x: str) -> str: + return f"jnp.log10({x})" + + @staticmethod + def log2(x: str) -> str: + return f"jnp.log2({x})" + + @staticmethod + def log1p(x: str) -> str: + return f"jnp.log1p({x})" + + @staticmethod + def sqrt(x: str) -> str: + return f"jnp.sqrt({x})" + + @staticmethod + def rsqrt(x: str) -> str: + return f"(1.0 / jnp.sqrt({x}))" + + @staticmethod + def abs(x: str) -> str: + return f"jnp.abs({x})" + + @staticmethod + def neg(x: str) -> str: + return f"(-{x})" + + @staticmethod + def floor(x: str) -> str: + return f"jnp.floor({x})" + + @staticmethod + def ceil(x: str) -> str: + return f"jnp.ceil({x})" + + @staticmethod + def trunc(x: str) -> str: + return f"jnp.trunc({x})" + + @staticmethod + def round(x: str) -> str: + return f"jnp.round({x})" + + @staticmethod + def sigmoid(x: str) -> str: + return f"(1.0 / (1.0 + jnp.exp(-{x})))" + + @staticmethod + def relu(x: str) -> str: + return f"jnp.maximum({x}, 0)" + + @staticmethod + def pow(a: str, b: str) -> str: + return f"jnp.power({a}, {b})" + + @staticmethod + def maximum(a: str, b: str) -> str: + return f"jnp.maximum({a}, {b})" + + @staticmethod + def minimum(a: str, b: str) -> str: + return f"jnp.minimum({a}, {b})" + + @staticmethod + def where(cond: str, a: str, b: str) -> str: + return f"jnp.where({cond}, {a}, {b})" + + +class PallasKernel(SIMDKernel): + """ + Minimal Pallas kernel for simple elementwise operations. + + Strategy: + - Treat loads as full-array refs: "in_ptrX[...]" + - Compute expression with Python operators (compatible with jax.numpy broadcasting) + - Store as full-array ref assignment: "out_ptrY[...] = " + - Generate Python code that defines a Pallas kernel and a host entrypoint. + - Use async_compile.pallas path to compile and load Python code. + """ + + overrides = PallasKernelOverrides # type: ignore[assignment] + + def _get_contiguous_index_str(self, index: sympy.Expr) -> str: + """ + Validate that the index represents contiguous access and return the indexing string. + + For Pallas, we only support simple contiguous access patterns where the index + is a single symbol (e.g., xindex) representing a flattened iteration. + This ensures the load/store order is contiguous. + + Args: + index: The indexing expression to validate + + Returns: + The indexing string to use (currently always "...") + + Raises: + Unsupported: If the index is not a simple contiguous pattern + """ + # Prepare and simplify the index + prepared_index = self.prepare_indexing(index) + + # For contiguous access, we expect a single symbol (like xindex) + # or a simple integer (for scalar operations) + if isinstance(prepared_index, sympy.Symbol): + # This is the expected case: a single symbol representing contiguous iteration + return "..." + elif prepared_index.is_Integer: + # Scalar case + return "..." + else: + # If there's any complex expression (ModularIndexing, FloorDiv, etc.), + # it's not a simple contiguous pattern + raise Unsupported( + f"Pallas backend only supports contiguous access patterns. " + f"Got complex index: {prepared_index}" + ) + + def load(self, name: str, index: sympy.Expr) -> CSEVariable: # type: ignore[override] + buf = self.args.input(name) + dtype = V.graph.get_dtype(name) + # Validate contiguous access and get index string + index_str = self._get_contiguous_index_str(index) + # Pallas refs must be unpacked with [...] to load the array + return self.cse.generate( + self.compute, + f"{buf}[{index_str}]", + dtype=dtype, + ) + + def store( + self, name: str, index: sympy.Expr, value: CSEVariable, mode: Any = None + ) -> None: # type: ignore[override] + if mode is not None: + raise Unsupported("pallas store mode not supported") + out = self.args.output(name) + self.store_buffer_names.add(name) + # Validate contiguous access and get index string + index_str = self._get_contiguous_index_str(index) + # Pallas refs must use [...] assignment to store back to the ref + self.stores.writeline(f"{out}[{index_str}] = {value}") + + def codegen_kernel(self, name: Optional[str] = None) -> str: # type: ignore[override] + """ + Generate the complete Pallas kernel code as a Python string. + + This includes: + - Import statements for JAX/Pallas + - The kernel function that operates on refs + - The main wrapper function that handles PyTorch<->JAX conversions via DLPack + + Args: + name: Optional kernel name (will use placeholder if not provided) + + Returns: + str: Complete Python source code for the Pallas kernel + """ + # Ensure one (1) output for now + live_outs = list(self.args.live_output_buffers()) + if len(live_outs) != 1: + raise Unsupported( + "Pallas backend currently supports single-output elementwise kernels only" + ) + + code = IndentedBuffer() + code.splice( + """ + import torch + import jax + import jax.numpy as jnp + from jax.experimental import pallas as pl + from torch.utils import dlpack as torch_dlpack + """, + strip=True, + ) + + # Define the Pallas kernel: accepts refs, uses broadcasted expressions + arg_defs, _, _, _ = self.args.python_argdefs() + # Order: inputs (in_ptr*), then outputs (out_ptr*), then sizes/workspaces + kernel_params = [a.name for a in arg_defs] + + kernel_name = name or "" + code.writeline(f"def {kernel_name}_kernel({', '.join(kernel_params)}):") + with code.indent(): + # Emit compute (CSE) and store lines; they reference *_ptr[...] directly + for line in self.compute._lines: + code.writeline(str(line)) + for line in self.stores._lines: + code.writeline(str(line)) + + # Host entry: convert torch tensors <-> jax, call pallas_call and copy back + main_name = f"{kernel_name}_main" + code.writeline(f"def {main_name}({', '.join(kernel_params)}, stream=None):") + with code.indent(): + # Identify inputs (in_ptr*) and output (out_ptr*) + input_params = [ + p for p in kernel_params if p.startswith(("in_ptr", "in_out_ptr")) + ] + output_params = [p for p in kernel_params if p.startswith("out_ptr")] + + if len(output_params) != 1: + raise RuntimeError( + f"Expected exactly 1 output, got {len(output_params)}" + ) + + output_param = output_params[0] + + # Convert inputs to JAX arrays + code.writeline("# Convert Torch -> JAX for inputs") + for inp in input_params: + code.writeline( + f"{inp}_jax = jax.dlpack.from_dlpack(torch_dlpack.to_dlpack({inp}))" + ) + + # Get output spec from PyTorch tensor + code.writeline("# Prepare output spec from PyTorch tensor") + code.writeline("# Map PyTorch dtype to JAX dtype string") + code.writeline("_torch_dtype_to_jax = {") + code.writeline( + " torch.float32: jnp.float32, torch.float64: jnp.float64, torch.float16: jnp.float16," + ) + code.writeline( + " torch.int32: jnp.int32, torch.int64: jnp.int64, torch.int16: jnp.int16, torch.int8: jnp.int8," + ) + code.writeline(" torch.uint8: jnp.uint8, torch.bool: jnp.bool_,") + code.writeline("}") + code.writeline( + f"out_spec = jax.ShapeDtypeStruct({output_param}.shape, _torch_dtype_to_jax[{output_param}.dtype])" + ) + + # Call pallas + code.writeline("compiled = pl.pallas_call(") + code.writeline(f" lambda *refs: {kernel_name}_kernel(*refs),") + code.writeline(" out_shape=out_spec,") + code.writeline(" grid=(1,),") + code.writeline(")") + + jax_input_args = ", ".join([f"{inp}_jax" for inp in input_params]) + code.writeline(f"res = compiled({jax_input_args})") + + # Copy result back + code.writeline("# Copy result back into the provided torch output tensor") + code.writeline( + "res_t = torch_dlpack.from_dlpack(jax.dlpack.to_dlpack(res))" + ) + code.writeline(f"{output_param}.copy_(res_t)") + + return code.getvalue() + + def call_kernel(self, name: str, node: Optional[IRNode] = None) -> None: # type: ignore[override] + """Generate the Python code that calls this Pallas kernel.""" + wrapper = V.graph.wrapper_code + _, call_args, _, arg_types = self.args.python_argdefs() + + # Generate kernel call: kernel_name.run(arg1, arg2, ...) + # Note: async_compile.pallas loads {name}_main function and wraps it in PallasKernelWrapper + # which exposes a run() method + kernel_call = f"{name}.run({', '.join(map(str, call_args))})" + wrapper.writeline(kernel_call) + + +class PallasScheduling(SIMDScheduling): + kernel_type = PallasKernel # type: ignore[assignment] + + @classmethod + def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]: + # Start minimal: no special features advertised + return OrderedSet() + + def define_kernel( + self, + src_code: str, + node_schedule: Sequence[BaseSchedulerNode], + kernel: PallasKernel, + ) -> str: # type: ignore[override] + wrapper = V.graph.wrapper_code + if src_code in wrapper.src_to_kernel: + return wrapper.src_to_kernel[src_code] + + fused_name = ( + get_fused_kernel_name(node_schedule, config.triton.descriptive_names) + if config.triton.descriptive_names + else "" + ) + kernel_hash = hashlib.sha256(src_code.encode("utf-8")).hexdigest()[:8] + if fused_name == "fused": + kernel_name = f"pallas_{kernel_hash}" + else: + kernel_name = f"pallas_{fused_name}_{kernel_hash}" + wrapper.src_to_kernel[src_code] = kernel_name + + # Replace placeholder if any + src_code = src_code.replace("", kernel_name) + + compile_wrapper = IndentedBuffer() + compile_wrapper.writeline(f"async_compile.pallas({kernel_name!r}, r'''") + compile_wrapper.splice(src_code, strip=True) + compile_wrapper.writeline("''')") + + origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper) + metadata_comment = f"{origins}\n{detailed_origins}" + wrapper.define_kernel(kernel_name, compile_wrapper.getvalue(), metadata_comment) + + return kernel_name diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 457f86fe7a77e..66eaf69dd59a8 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -1950,8 +1950,9 @@ class rocm: # Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental) cpu_backend: Literal["cpp", "triton", "halide"] = "cpp" -# Backend to use for CUDA codegen either "triton" or "halide" (experimental) -cuda_backend: Literal["triton", "halide"] = "triton" +# Backend to use for CUDA codegen either +# "triton", "halide" (experimental) or "pallas" (experimental) +cuda_backend: Literal["triton", "halide", "pallas"] = "triton" # Backend to use for XPU codegen either "triton" xpu_backend: Literal["triton"] = "triton" diff --git a/torch/testing/_internal/inductor_utils.py b/torch/testing/_internal/inductor_utils.py index bd11e01a80250..6bd34c812d641 100644 --- a/torch/testing/_internal/inductor_utils.py +++ b/torch/testing/_internal/inductor_utils.py @@ -33,6 +33,10 @@ OrderedSet, ) from torch.fx.experimental.proxy_tensor import make_fx +from torch.utils._helion import has_helion +from torch.utils._pallas import has_pallas +from torch.utils._triton import has_triton +from torch.utils._config_module import ConfigModule from torch.testing._internal.common_device_type import ( get_desired_device_type_test_bases, ) @@ -43,9 +47,6 @@ LazyVal, TestCase, ) -from torch.utils._config_module import ConfigModule -from torch.utils._helion import has_helion -from torch.utils._triton import has_triton log: logging.Logger = logging.getLogger(__name__) @@ -67,6 +68,8 @@ def test_cpu(): HAS_TRITON = has_triton() +HAS_PALLAS = has_pallas() + HAS_HELION = has_helion() if HAS_TRITON: diff --git a/torch/utils/_pallas.py b/torch/utils/_pallas.py new file mode 100644 index 0000000000000..25cc635dbb178 --- /dev/null +++ b/torch/utils/_pallas.py @@ -0,0 +1,82 @@ +import functools + +import torch + + +@functools.cache +def has_jax_package() -> bool: + """Check if JAX is installed.""" + try: + import jax # noqa: F401 # type: ignore[import-not-found] + + return True + except ImportError: + return False + + +@functools.cache +def has_pallas_package() -> bool: + """Check if Pallas (JAX experimental) is available.""" + if not has_jax_package(): + return False + try: + from jax.experimental import ( # noqa: F401 # type: ignore[import-not-found] + pallas as pl, + ) + + return True + except ImportError: + return False + + +@functools.cache +def get_jax_version(fallback: tuple[int, int, int] = (0, 0, 0)) -> tuple[int, int, int]: + """Get JAX version as (major, minor, patch) tuple.""" + try: + import jax # type: ignore[import-not-found] + + version_parts = jax.__version__.split(".") + major, minor, patch = (int(v) for v in version_parts[:3]) + return (major, minor, patch) + except (ImportError, ValueError, AttributeError): + return fallback + + +@functools.cache +def has_jax_cuda_backend() -> bool: + """Check if JAX has CUDA backend support.""" + if not has_jax_package(): + return False + try: + import jax # type: ignore[import-not-found] + + # Check if CUDA backend is available + devices = jax.devices("gpu") + return len(devices) > 0 + except Exception: + return False + + +@functools.cache +def has_pallas() -> bool: + """ + Check if Pallas backend is fully available for use. + + Requirements: + - JAX package installed + - Pallas (jax.experimental.pallas) available + - CUDA backend available (for GPU support) + """ + if not has_pallas_package(): + return False + + # Only enable Pallas if CUDA is available + # (Pallas primarily targets GPU workloads) + if not torch.cuda.is_available(): + return False + + # Check if JAX has GPU/CUDA backend + if not has_jax_cuda_backend(): + return False + + return True From 39160dba0c5120c65705a44e556c8c4af243e573 Mon Sep 17 00:00:00 2001 From: Bruce Chang Date: Wed, 5 Nov 2025 00:54:35 +0000 Subject: [PATCH 039/130] shrink_group implementation to expose ncclCommShrink API (#164518) Closes #164529 To expose the new [ncclCommShrink](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcommshrink) API to PyTorch. This is useful when you need to exclude certain GPUs or nodes from a collective operation, for example in fault tolerance scenarios or when dynamically adjusting resource utilization. For more info: [Shrinking a communicator](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/communicators.html#shrinking-a-communicator) Pull Request resolved: https://github.com/pytorch/pytorch/pull/164518 Approved by: https://github.com/kwen2501 --- docs/source/distributed.md | 4 + test/distributed/test_c10d_nccl.py | 676 +++++++++++++++++- torch/csrc/distributed/c10d/Backend.hpp | 17 + torch/csrc/distributed/c10d/NCCLUtils.cpp | 59 ++ torch/csrc/distributed/c10d/NCCLUtils.hpp | 12 + .../distributed/c10d/ProcessGroupNCCL.cpp | 135 +++- .../distributed/c10d/ProcessGroupNCCL.hpp | 21 + torch/csrc/distributed/c10d/init.cpp | 11 + torch/distributed/distributed_c10d.py | 519 ++++++++++++++ torch/testing/_internal/common_distributed.py | 48 ++ 10 files changed, 1500 insertions(+), 2 deletions(-) diff --git a/docs/source/distributed.md b/docs/source/distributed.md index 1c9d374b8ab02..ca1fe3b5e9099 100644 --- a/docs/source/distributed.md +++ b/docs/source/distributed.md @@ -394,6 +394,10 @@ an opaque group handle that can be given as a `group` argument to all collective .. autofunction:: new_group ``` +```{eval-rst} +.. autofunction:: torch.distributed.distributed_c10d.shrink_group +``` + ```{eval-rst} .. autofunction:: get_group_rank ``` diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index c117bc810b115..cf53896187c20 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -2,6 +2,7 @@ import copy import json +import logging import os import pickle import random @@ -21,6 +22,7 @@ import torch import torch.distributed as c10d import torch.distributed._functional_collectives as _functional_collectives +from torch.distributed.distributed_c10d import SHRINK_ABORT as NCCL_SHRINK_ABORT if not c10d.is_available() or not c10d.is_nccl_available(): @@ -47,12 +49,15 @@ from torch.nn.parallel import DistributedDataParallel from torch.testing._internal.common_cuda import _get_torch_rocm_version, TEST_MULTIGPU from torch.testing._internal.common_distributed import ( + get_required_world_size, get_timeout, init_multigpu_helper, MultiProcessTestCase, requires_multicast_support, requires_nccl, + requires_nccl_shrink, requires_nccl_version, + requires_world_size, skip_if_lt_x_gpu, skip_if_rocm_multiprocess, sm_is_or_higher_than, @@ -88,6 +93,53 @@ ) +_start_time = time.time() +_logger = logging.getLogger(__name__) + + +def _ts(): + return time.time() - _start_time + + +def configure(level=logging.INFO, force=False): + try: + logging.basicConfig( + level=level, + format="%(asctime)s %(name)s %(levelname)s: %(message)s", + force=force, + ) + except TypeError: + logging.basicConfig( + level=level, format="%(asctime)s %(name)s %(levelname)s: %(message)s" + ) + + +def log_test_info(rank, message): + _logger.info("[%7.3fs][Rank %s] %s", _ts(), rank, message) + + +def log_test_success(rank, message): + _logger.info("[%7.3fs][Rank %s] ✅ %s", _ts(), rank, message) + + +def log_test_validation(rank, message): + _logger.info("[%7.3fs][Rank %s] ✓ %s", _ts(), rank, message) + + +def log_test_warning(rank, message): + _logger.warning("[%7.3fs][Rank %s] ⚠️ %s", _ts(), rank, message) + + +def log_test_error(rank, message): + _logger.error("[%7.3fs][Rank %s] ✗ %s", _ts(), rank, message) + + +_log_configure = configure + + +_log_configure(level=logging.INFO, force=True) + + class RendezvousEnvTest(TestCase): @retry_on_connect_failures @requires_nccl() @@ -317,7 +369,7 @@ def tearDown(self): @property def world_size(self): - return 2 + return get_required_world_size(self, 2) @property def rank_to_GPU(self): @@ -1255,6 +1307,628 @@ def test_set_process_group_desc(self): pg_2 = c10d.new_group([0, 1]) self.assertEqual(pg_2.group_desc, "undefined") + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_basic(self): + """Test basic shrink_group functionality.""" + self._perform_shrink_test([1], "Basic shrink test") + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_validation(self): + """Test input validation in shrink_group.""" + device, pg = self._setup_shrink_test("validation") + + def _test_invalid_input(ranks, description, expected_exception): + """Helper to test invalid inputs.""" + try: + c10d.shrink_group(ranks) + self.fail(f"Expected {expected_exception.__name__} for {description}") + except expected_exception: + log_test_validation(self.rank, f"✓ {description}") + except Exception: + if expected_exception is Exception: # Accept any exception + log_test_validation(self.rank, f"✓ {description}") + else: + raise + + # Test cases + _test_invalid_input([], "Empty exclusion list", ValueError) + if self.world_size > 1: + _test_invalid_input([0, 0, 1], "Duplicate ranks", Exception) + _test_invalid_input([self.world_size + 1], "Out of bounds rank", Exception) + + log_test_success(self.rank, "All validation tests passed") + dist.destroy_process_group() + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_backend_properties(self): + """Test that backend properties are preserved after shrinking.""" + + test_name = "Backend Properties Test" + ranks_to_exclude = [0] + + # Reuse _setup_shrink_test for complete setup (device, environment, and process group) + device, pg = self._setup_shrink_test("backend_properties") + + # Follow _perform_shrink_test pattern from here + log_test_info(self.rank, f"{test_name} (world_size={self.world_size})") + + is_excluded = self.rank in ranks_to_exclude + log_test_info( + self.rank, + f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}", + ) + + # Store original backend property values (not references) before shrinking + original_timeout = None + original_high_priority = None + if not is_excluded: + original_backend = pg._get_backend(device) + original_timeout = original_backend.options._timeout + original_high_priority = original_backend.options.is_high_priority_stream + log_test_info( + self.rank, + f"Storing original backend properties: timeout={original_timeout}, high_priority={original_high_priority}", + ) + + if is_excluded: + log_test_info( + self.rank, + f"Excluded rank {self.rank} - setup complete, skipping shrink operation", + ) + dist.destroy_process_group() # hang without it + return + + # Only non-excluded ranks proceed with shrink (same as _perform_shrink_test) + log_test_info(self.rank, "Non-excluded rank calling shrink_group") + shrunk_pg = c10d.shrink_group(ranks_to_exclude) + + # Reuse _validate_shrunk_group helper (same as _perform_shrink_test) + expected_size = self.world_size - len(ranks_to_exclude) + _ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name) + + # Add custom backend properties validation + new_backend = shrunk_pg._get_backend(device) + log_test_info(self.rank, "Validating backend properties are preserved") + + new_timeout = new_backend.options._timeout + new_high_priority = new_backend.options.is_high_priority_stream + + log_test_info( + self.rank, + f"Timeout comparison - original: {original_timeout}, new: {new_timeout}", + ) + self.assertEqual( + original_timeout, new_timeout, f"{test_name}: timeout not preserved" + ) + + log_test_info( + self.rank, + f"High priority stream comparison - original: {original_high_priority}, new: {new_high_priority}", + ) + self.assertEqual( + original_high_priority, + new_high_priority, + f"{test_name}: high_priority_stream not preserved", + ) + + log_test_validation( + self.rank, f"{test_name}: Backend properties preserved successfully" + ) + log_test_success( + self.rank, f"{test_name} successful (shrink + backend validation)" + ) + + # Cleanup (same as _perform_shrink_test) + dist.destroy_process_group() + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_multiple_comms(self): + """Test shrink_group with multiple communicators and subgroup invalidation.""" + + device, pg = self._setup_shrink_test("multiple_comms") + + # Create subgroup [0, 1] and test shrinking it + subgroup = c10d.new_group([0, 1]) + if self.rank <= 1: + # Shrink subgroup: exclude rank 1 + if self.rank == 0: # Only rank 0 remains + shrunk_subgroup = c10d.shrink_group([1], group=subgroup) + self.assertEqual(shrunk_subgroup.size(), 1) + # Test communication on shrunk subgroup + tensor = torch.full((1,), self.rank).cuda(device) + c10d.all_reduce(tensor, group=shrunk_subgroup) + self.assertEqual(tensor.item(), 0) # Only rank 0 + log_test_success(self.rank, "Subgroup shrinking successful") + + dist.barrier() # Sync before default group test + + # Shrink default group: exclude last rank + ranks_to_exclude = [self.world_size - 1] + if self.rank not in ranks_to_exclude: + shrunk_default = c10d.shrink_group(ranks_to_exclude) + expected_size = self.world_size - 1 + self.assertEqual(shrunk_default.size(), expected_size) + + # Test collective on shrunk default group + tensor = torch.full((1,), self.rank).cuda(device) + c10d.all_reduce(tensor, group=shrunk_default) + expected_sum = sum( + range(self.world_size - 1) + ) # 0 + 1 + ... + (world_size-2) + self.assertEqual(tensor.item(), expected_sum) + log_test_success(self.rank, "Default group shrinking successful") + + # Note: After shrinking default group, the old subgroup is invalid + # due to global rank reassignment + + dist.destroy_process_group() + + def _test_shrink_group_with_flag(self, shrink_flag, flag_name, rank_to_exclude): + """Helper method to test shrink_group with a specific flag.""" + if self.world_size < 2: + log_test_info(self.rank, f"Skipping (needs ≥2 GPUs, got {self.world_size})") + return + ranks_to_exclude = [rank_to_exclude] + log_test_info(self.rank, f"Using {flag_name} flag (value: {shrink_flag})") + if flag_name == "NCCL_SHRINK_ABORT": + log_test_info( + self.rank, + "ABORT flag will terminate ongoing operations before shrinking", + ) + + self._perform_shrink_test( + ranks_to_exclude, f"{flag_name} flag test", shrink_flags=shrink_flag + ) + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_flags(self): + """Test shrink_group with different shrink flags.""" + # Test ABORT flags + log_test_info(self.rank, "Testing NCCL_SHRINK_ABORT flag") + self._test_shrink_group_with_flag(NCCL_SHRINK_ABORT, "NCCL_SHRINK_ABORT", 1) + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_nccl_config(self): + """Verify that passing NCCL config via pg_options influences the shrunk group's backend options.""" + device, pg = self._setup_shrink_test("config") + if self.rank == self.world_size - 1: + # excluded rank should not call shrink_group + dist.destroy_process_group() + return + + # Prepare pg_options with NCCL config overrides + # Capture parent's current backend options to ensure we can prove override vs inherit + parent_backend = pg._get_backend(torch.device("cuda")) + parent_hp = parent_backend.options.is_high_priority_stream + parent_blocking = parent_backend.options.config.blocking + + # Choose overrides that differ from the parent (flip where possible) + override_hp = not parent_hp + if parent_blocking in (0, 1): + override_blocking = 1 - parent_blocking + else: + # If undefined or unexpected, set to 1 which is a concrete value + override_blocking = 1 + + opts = c10d.ProcessGroupNCCL.Options() + opts.is_high_priority_stream = override_hp + opts.config.blocking = override_blocking + + shrunk_pg = c10d.shrink_group([self.world_size - 1], pg_options=opts) + + # Validate backend options propagated + backend = shrunk_pg._get_backend(torch.device("cuda")) + # is_high_priority_stream should exactly match our override and differ from parent + self.assertEqual(backend.options.is_high_priority_stream, override_hp) + self.assertNotEqual(backend.options.is_high_priority_stream, parent_hp) + # config is a struct; check representative field and difference from parent when meaningful + self.assertEqual(backend.options.config.blocking, override_blocking) + if parent_blocking in (0, 1): + self.assertNotEqual(backend.options.config.blocking, parent_blocking) + + dist.destroy_process_group() + + @requires_nccl_shrink() + @requires_world_size(2) + def test_shrink_group_performance(self): + """Test shrink_group performance and regression detection.""" + import time + + ranks_to_exclude = self._get_default_ranks_to_exclude() + is_excluded = self.rank in ranks_to_exclude + + if not ranks_to_exclude: + log_test_info(self.rank, "Skipping performance test (world_size=1)") + return + + log_test_info(self.rank, f"Performance test with {self.world_size} processes") + device, pg = self._setup_shrink_test("performance") + + if not is_excluded: + log_test_info(self.rank, "Measuring shrink_group performance") + start_time = time.time() + shrunk_pg = c10d.shrink_group(ranks_to_exclude) + end_time = time.time() + + elapsed_time = end_time - start_time + log_test_info(self.rank, f"shrink_group: {elapsed_time:.3f}s") + + # Regression check: should complete within reasonable time + self.assertLess( + elapsed_time, + 30.0, + f"shrink_group took {elapsed_time:.3f}s, possible regression", + ) + + # Test collective performance + expected_size = self.world_size - len(ranks_to_exclude) + self._validate_shrunk_group(shrunk_pg, expected_size, "performance") + + collective_start = time.time() + _ = self._test_collective_on_shrunk_group( + shrunk_pg, device, ranks_to_exclude, "performance" + ) + collective_time = time.time() - collective_start + + log_test_info(self.rank, f"all_reduce: {collective_time:.3f}s") + log_test_success(self.rank, "Performance test passed") + else: + log_test_info(self.rank, "Excluded rank - waiting") + + dist.destroy_process_group() + + @requires_nccl_shrink() + @requires_world_size(4) + def test_shrink_group_multiple_exclusions(self): + """Test shrink_group with multiple ranks excluded at once.""" + # Scale exclusions with world size + ranks_to_exclude = list(range(2, self.world_size, 2)) # Every other rank from 2 + + self._perform_shrink_test(ranks_to_exclude, "Multiple exclusions test") + + @requires_nccl_shrink() + @requires_world_size(3) + def test_shrink_group_multiple_iterations(self): + """Test multiple shrink operations in sequence.""" + log_test_info( + self.rank, + f"Starting test_shrink_group_multiple_iterations with world_size={self.world_size}", + ) + + store = c10d.FileStore(self.file_name, self.world_size) + device = torch.device(f"cuda:{self.rank}") + _ = self._create_process_group_nccl(store, self.opts(), device_id=device) + + # Track current effective world size throughout shrinking operations + current_world_size = self.world_size + log_test_info(self.rank, f"Initial world_size: {current_world_size}") + + # First shrinking: exclude the last rank(s) + first_exclusion = [self.world_size - 1] + if self.world_size >= 6: + first_exclusion.append( + self.world_size - 2 + ) # Exclude last two ranks for larger sizes + + log_test_info(self.rank, f"First shrinking: excluding ranks {first_exclusion}") + + if self.rank not in first_exclusion: + # Only non-excluded ranks should call shrink_group + first_pg = c10d.shrink_group(first_exclusion) + self.assertIsNotNone(first_pg) + # IMPORTANT: Update world size after first shrinking + current_world_size = first_pg.size() + expected_first_size = self.world_size - len(first_exclusion) + log_test_info( + self.rank, + f"After first shrinking: world_size {self.world_size} -> {current_world_size}", + ) + self.assertEqual(first_pg.size(), expected_first_size) + + # Second shrinking: exclude another rank from the remaining group + # Choose a rank that's in the middle range + if current_world_size >= 3: + second_exclusion = [ + current_world_size - 1 + ] # Exclude the new "last" rank + log_test_info( + self.rank, + f"Second shrinking from group of size {current_world_size}: excluding ranks {second_exclusion}", + ) + + if self.rank not in second_exclusion: + # Only non-excluded ranks should call shrink_group for second iteration + second_pg = c10d.shrink_group(second_exclusion, group=first_pg) + self.assertIsNotNone(second_pg) + # IMPORTANT: Update world size after second shrinking + final_world_size = second_pg.size() + expected_final_size = current_world_size - len(second_exclusion) + log_test_info( + self.rank, + f"After second shrinking: world_size {current_world_size} -> {final_world_size}", + ) + self.assertEqual(second_pg.size(), expected_final_size) + + # Test collective on final group + tensor = torch.full((1,), self.rank).cuda(device) + log_test_info( + self.rank, + f"Performing all_reduce on final group (size {final_world_size}) with tensor: {tensor.item()}", + ) + c10d.all_reduce(tensor, group=second_pg) + log_test_info( + self.rank, + f"Final all_reduce completed, result: {tensor.item()}", + ) + + # Calculate expected sum of remaining ranks + all_excluded = set(first_exclusion + second_exclusion) + remaining_ranks = [ + r for r in range(self.world_size) if r not in all_excluded + ] + expected_sum = sum(remaining_ranks) + log_test_info( + self.rank, + f"Remaining ranks: {remaining_ranks}, expected sum: {expected_sum}, actual: {tensor.item()}", + ) + self.assertEqual(tensor.item(), expected_sum) + log_test_info(self.rank, "Final verification passed") + else: + log_test_info( + self.rank, + "This rank excluded in second shrinking, not calling shrink_group", + ) + else: + log_test_info( + self.rank, "Skipping second shrinking (remaining group too small)" + ) + else: + log_test_info( + self.rank, + "This rank excluded in first shrinking, not calling shrink_group", + ) + + log_test_info(self.rank, "Destroying process group") + dist.destroy_process_group() + log_test_info(self.rank, "test_shrink_group_multiple_iterations completed") + + # Helper methods for optimized shrink group tests + def _setup_shrink_test(self, test_suffix, world_size=None, warmup=True): + """Common setup for shrink group tests.""" + os.environ["TORCH_NCCL_USE_COMM_NONBLOCKING"] = "1" + world_size = world_size or self.world_size + store = c10d.FileStore(self.file_name + f"_{test_suffix}", world_size) + device = torch.device(f"cuda:{self.rank}") + c10d.init_process_group( + "nccl", + world_size=world_size, + rank=self.rank, + store=store, + pg_options=self.opts(), + device_id=device, + ) + pg = c10d.distributed_c10d._get_default_group() + + if warmup: + c10d.all_reduce(torch.ones(1).cuda(device), group=pg) + + return device, pg + + def _validate_shrunk_group(self, shrunk_pg, expected_size, test_name=""): + """Validate properties of a shrunk process group.""" + self.assertIsNotNone(shrunk_pg, f"{test_name}: shrunk_pg should not be None") + actual_size = shrunk_pg.size() + self.assertEqual( + actual_size, expected_size, f"{test_name}: group size mismatch" + ) + + new_rank = shrunk_pg.rank() + self.assertTrue( + 0 <= new_rank < expected_size, f"{test_name}: invalid new rank {new_rank}" + ) + + log_test_info( + self.rank, + f"{test_name}: world_size {self.world_size} -> {actual_size}, rank {self.rank} -> {new_rank}", + ) + return new_rank + + def _test_collective_on_shrunk_group( + self, shrunk_pg, device, ranks_to_exclude, test_name="" + ): + """Test collective communication on shrunk group and verify correctness.""" + test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32) + c10d.all_reduce(test_tensor, group=shrunk_pg) + + result = test_tensor.item() + expected_sum = sum( + r for r in range(self.world_size) if r not in ranks_to_exclude + ) + + self.assertEqual( + result, expected_sum, f"{test_name}: collective result mismatch" + ) + log_test_info( + self.rank, f"{test_name}: collective passed ({result} == {expected_sum})" + ) + return result + + def _perform_shrink_test( + self, ranks_to_exclude, test_name, shrink_flags=0, with_collective=True + ): + """Complete shrink test flow: setup, shrink, validate, test collective, cleanup. + + Consistent API: All ranks perform setup to initialize distributed environment. + ONLY non-excluded ranks call shrink_group() for both default and non-default groups. + Excluded ranks perform setup, then exit without calling shrink_group() or waiting. + """ + log_test_info(self.rank, f"{test_name} (world_size={self.world_size})") + + is_excluded = self.rank in ranks_to_exclude + log_test_info( + self.rank, + f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}", + ) + + # All ranks (including excluded ones) perform setup to initialize distributed environment + device, pg = self._setup_shrink_test(test_name.lower().replace(" ", "_")) + is_default_group = pg == c10d.distributed_c10d._get_default_group() + + if is_excluded: + log_test_info( + self.rank, + f"Excluded rank {self.rank} - setup complete, skipping shrink operation", + ) + if shrink_flags & NCCL_SHRINK_ABORT: + log_test_info(self.rank, f"Using abort for excluded rank {self.rank}") + pg._get_backend(torch.device(device)).abort() + log_test_info( + self.rank, f"cleanup resources for excluded rank {self.rank}" + ) + dist.destroy_process_group() + log_test_info(self.rank, f"Excluded rank {self.rank} - exit") + else: + log_test_info( + self.rank, f"Using regular destroy for excluded rank {self.rank}" + ) + dist.destroy_process_group() + return None + + # Only non-excluded ranks proceed with shrink + log_test_info( + self.rank, + f"Non-excluded rank calling shrink_group (default_group={is_default_group})", + ) + shrunk_pg = c10d.shrink_group(ranks_to_exclude, shrink_flags=shrink_flags) + log_test_info( + self.rank, + f"Non-excluded rank calling shrink_group (default_group={is_default_group}) done", + ) + + # Non-excluded ranks: validate and test the new group + expected_size = self.world_size - len(ranks_to_exclude) + _ = self._validate_shrunk_group(shrunk_pg, expected_size, test_name) + + if with_collective: + _ = self._test_collective_on_shrunk_group( + shrunk_pg, device, ranks_to_exclude, test_name + ) + log_test_success(self.rank, f"{test_name} successful (shrink + collective)") + else: + log_test_success(self.rank, f"{test_name} successful (shrink only)") + + dist.destroy_process_group() + return shrunk_pg + + def _get_default_ranks_to_exclude(self): + """Get default ranks to exclude based on world size.""" + if self.world_size <= 1: + return [] + return [self.world_size - 1] # Exclude last rank by default + + @requires_nccl_shrink() + @requires_world_size(3) + def test_shrink_group_vs_abort_reinit_performance(self): + """Compare performance of shrink_group vs traditional abort+reinit (simplified for reliability).""" + log_test_info(self.rank, "=== TEST 1: abort+reinit ===") + + device, pg1 = self._setup_shrink_test("_perf_reinit") + torch.cuda.synchronize(device) + + # Test 1: Traditional abort + reinit + start_time = time.perf_counter() + dist.destroy_process_group() + + device, new_pg = self._setup_shrink_test("perf_shrink_test1") + reinit_time = time.perf_counter() - start_time + + # Test collective with original rank values for fair comparison (non-blocking mode) + test_tensor = torch.full((1,), self.rank, device=device, dtype=torch.float32) + work = c10d.all_reduce(test_tensor, group=new_pg, async_op=True) + work.wait() + + torch.cuda.synchronize(device) + + # Verify correctness + expected_sum = sum(r for r in range(self.world_size)) + self.assertEqual(test_tensor.item(), expected_sum, "Reinit collective failed") + + log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s") + dist.destroy_process_group(new_pg) + + # Test 2: shrink_group with NCCL_SHRINK_ABORT + log_test_info(self.rank, "=== TEST 2: shrink_group ===") + + ranks_to_exclude = [self.world_size - 1] + is_excluded = self.rank in ranks_to_exclude + log_test_info( + self.rank, + f"Excluding ranks: {ranks_to_exclude}, am_excluded: {is_excluded}", + ) + + device, pg1 = self._setup_shrink_test("perf_shrink_test2") # Unique suffix + + shrink_time = 0 + if not is_excluded: + torch.cuda.synchronize(device) # Ensure accurate timing + start_time = time.perf_counter() + shrunk_pg = c10d.shrink_group( + ranks_to_exclude, shrink_flags=NCCL_SHRINK_ABORT + ) + c10d.all_reduce(torch.ones(1).cuda(device), group=shrunk_pg) + shrink_time = time.perf_counter() - start_time + + # Test collective communication on shrunk group (non-blocking mode) + test_tensor = torch.full( + (1,), self.rank, device=device, dtype=torch.float32 + ) + work = c10d.all_reduce(test_tensor, group=shrunk_pg, async_op=True) + work.wait() + + # Verify correctness + expected_sum = sum( + r for r in range(self.world_size) if r not in ranks_to_exclude + ) + self.assertEqual( + test_tensor.item(), + expected_sum, + "shrink_test: collective result mismatch", + ) + + torch.cuda.synchronize(device) # Ensure operations complete + log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s") + dist.destroy_process_group() + else: + log_test_info(self.rank, "Excluded from shrink test - exiting immediately") + dist.destroy_process_group() + return + + # Performance analysis (only for participating ranks) + if shrink_time > 0 and reinit_time > 0: + speedup = reinit_time / shrink_time + time_saved = reinit_time - shrink_time + + log_test_info(self.rank, "=== PERFORMANCE RESULTS ===") + log_test_info(self.rank, f"shrink_group: {shrink_time:.4f}s") + log_test_info(self.rank, f"abort+reinit: {reinit_time:.4f}s") + log_test_info(self.rank, f"time_saved: {time_saved:+.4f}s") + log_test_info(self.rank, f"speedup: {speedup:.2f}x") + + if speedup > 1.1: + log_test_success(self.rank, "shrink_group significantly faster") + elif speedup > 0.9: + log_test_info(self.rank, "≈ comparable performance") + else: + log_test_warning(self.rank, "abort+reinit faster") + + log_test_info(self.rank, "Performance test completed") + @requires_nccl() @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") def test_deterministic_mode_no_break(self): diff --git a/torch/csrc/distributed/c10d/Backend.hpp b/torch/csrc/distributed/c10d/Backend.hpp index 6ffa1529a4de0..72e35e3fc9dd3 100644 --- a/torch/csrc/distributed/c10d/Backend.hpp +++ b/torch/csrc/distributed/c10d/Backend.hpp @@ -79,6 +79,23 @@ class TORCH_API Backend : public torch::CustomClassHolder { return false; } + virtual bool supportsShrinking() const { + return false; + } + + // Shrink the backend by excluding specified ranks. Backends that support + // communicator shrinking should override this and return a new backend + // instance representing the shrunken group. Backends may use opts_override + // to supply backend-specific options for the new group. + virtual c10::intrusive_ptr shrink( + const std::vector& /*ranks_to_exclude*/, + int /*shrink_flags*/ = 0, + const c10::intrusive_ptr& /*opts_override*/ = nullptr) { + TORCH_CHECK( + false, + c10::str("Backend ", getBackendName(), " does not support shrink")); + } + virtual void setTimeout(std::chrono::milliseconds timeout) { TORCH_CHECK( false, diff --git a/torch/csrc/distributed/c10d/NCCLUtils.cpp b/torch/csrc/distributed/c10d/NCCLUtils.cpp index 8074cc98a04f1..a41f654b9ae20 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.cpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.cpp @@ -259,6 +259,65 @@ std::shared_ptr NCCLComm::split( } #endif +#ifdef NCCL_HAS_COMM_SHRINK +std::shared_ptr NCCLComm::shrink( + NCCLComm* source, + std::vector& ranks_to_exclude, + ncclConfig_t* config, + int shrinkFlags) { + // Preconditions are validated in ProcessGroupNCCL::shrink + + LOG(INFO) << "Rank " << source->rank_ << ": shrinking comm " << source->repr() + << " excluding " << ranks_to_exclude.size() << " ranks"; + + at::cuda::OptionalCUDAGuard gpuGuard(source->deviceIndex_); + auto comm = std::make_shared(); + + // This call will block until the source communicator is initialized + auto sourceComm = source->getNcclComm(); + + C10D_NCCL_CHECK_NONBLOCKING( + ncclCommShrink( + sourceComm, + ranks_to_exclude.data(), + ranks_to_exclude.size(), + reinterpret_cast(&(comm->ncclComm_)), + config, + shrinkFlags), + source->getNcclCommFailureReason()); + + // Wait for the child communicator to be ready + source->waitReady(true); + comm->initialized_ = true; + + // NCCL automatically assigns rank during shrink - query it efficiently + int assigned_rank; + try { + C10D_NCCL_CHECK( + ncclCommUserRank(comm->ncclComm_, &assigned_rank), std::nullopt); + comm->rank_ = assigned_rank; + } catch (const std::exception& e) { + // Fallback: if ncclCommUserRank fails, we can't determine the rank + LOG(ERROR) << "Failed to query NCCL-assigned rank: " << e.what(); + throw; + } + + // Child comm should be on the same device as parent comm + comm->deviceIndex_ = source->deviceIndex_; + if (config != nullptr) { + comm->nonBlocking_ = config->blocking == 0; + } else { + // Inherit parent behavior if no config provided + comm->nonBlocking_ = source->nonBlocking_; + } + + LOG(INFO) << "Rank " << source->rank_ << ": created shrunken comm " + << comm->repr() << " with NCCL-assigned rank " << assigned_rank; + + return comm; +} +#endif + void NCCLComm::finalize() { LockType lock(mutex_); if (aborted_) { diff --git a/torch/csrc/distributed/c10d/NCCLUtils.hpp b/torch/csrc/distributed/c10d/NCCLUtils.hpp index fdd50f69ef3d7..142633b823744 100644 --- a/torch/csrc/distributed/c10d/NCCLUtils.hpp +++ b/torch/csrc/distributed/c10d/NCCLUtils.hpp @@ -90,6 +90,10 @@ static_assert( #define NCCL_HAS_NVLS_CTAS #endif +#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 27, 0) +#define NCCL_HAS_COMM_SHRINK +#endif + // Macro to throw on a non-successful NCCL return value. #define C10D_NCCL_CHECK(cmd, failureReason) \ do { \ @@ -294,6 +298,14 @@ class NCCLComm { ncclConfig_t& config); #endif // NCCL_HAS_COMM_SPLIT +#ifdef NCCL_HAS_COMM_SHRINK + static std::shared_ptr shrink( + NCCLComm* source, + std::vector& ranks_to_exclude, + ncclConfig_t* config, + int shrinkFlags = 0); +#endif // NCCL_HAS_COMM_SHRINK + #if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP) std::unordered_map ncclCommDump(); #endif diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index fd7f0b4246517..d051803aa7376 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -165,7 +165,7 @@ ncclRedOpRAII getNcclReduceOp( } // Get a key string from device -inline std::string getKeyFromDevice(at::Device& device) { +inline std::string getKeyFromDevice(const at::Device& device) { return std::to_string(device.index()); } @@ -5842,6 +5842,139 @@ at::Tensor ProcessGroupNCCL::allocateTensor( return tensor; } +#ifdef NCCL_HAS_COMM_SHRINK +c10::intrusive_ptr ProcessGroupNCCL::shrink( + const std::vector& ranks_to_exclude, + int shrink_flags, + const c10::intrusive_ptr& opts_override) { + // Runtime version check with better error message + auto runtime_version = torch::cuda::nccl::version(); + TORCH_CHECK( + runtime_version >= NCCL_VERSION(2, 27, 0), + "ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later. " + "Found version: ", + runtime_version); + + // Early validation with detailed error messages + TORCH_CHECK_VALUE( + !ranks_to_exclude.empty(), "ranks_to_exclude cannot be empty"); + TORCH_CHECK_VALUE( + static_cast(ranks_to_exclude.size()) < size_, + "Cannot exclude all ranks (", + ranks_to_exclude.size(), + " >= ", + size_, + ")"); + + // Validate ranks and convert to int efficiently + std::vector int_ranks_to_exclude; + int_ranks_to_exclude.reserve(ranks_to_exclude.size()); + for (int64_t rank : ranks_to_exclude) { + TORCH_CHECK_VALUE( + rank >= 0 && rank < size_, + "Invalid rank ", + rank, + " for group size ", + size_); + int_ranks_to_exclude.push_back(static_cast(rank)); + } + + // Get primary communicator with better error context + auto primary_device_index = guessDeviceId(); + auto primary_device = at::Device(at::kCUDA, primary_device_index); + const auto primary_key = getKeyFromDevice(primary_device); + + std::shared_ptr primary_comm = getNCCLComm(primary_key); + TORCH_CHECK( + primary_comm, + "Primary NCCL communicator for device ", + primary_device, + " (key: ", + primary_key, + ") is not initialized"); + + // Cache device index before shrink operation + at::DeviceIndex parent_device_index = primary_comm->getDeviceIndex(); + + ncclConfig_t* config = nullptr; + // Default to inheriting from parent options + bool high_priority_stream = options_->is_high_priority_stream; + if (opts_override) { + auto nccl_opts = + c10::static_intrusive_pointer_cast( + opts_override); + config = &nccl_opts->config; + // If user provided override options, honor is_high_priority_stream as well + high_priority_stream = nccl_opts->is_high_priority_stream; + } + + std::shared_ptr shrunk_comm = NCCLComm::shrink( + primary_comm.get(), + int_ranks_to_exclude, + (config != nullptr ? config : &options_->config), + shrink_flags); + + // Calculate new size and get NCCL-assigned rank + int new_size = size_ - static_cast(ranks_to_exclude.size()); + int new_rank = shrunk_comm->rank_; + + // Create new ProcessGroupNCCL with optimized options cloning + auto new_store = store_->clone(); + auto new_opts = ProcessGroupNCCL::Options::create(high_priority_stream); + new_opts->timeout = options_->timeout; + if (config != nullptr) { + new_opts->config = *config; + } else { + new_opts->config = options_->config; + } + + auto new_pg = c10::make_intrusive( + new_store, new_rank, new_size, new_opts); + + // Set up the new process group with optimized device setup + new_pg->initializeDeviceStateForComm( + at::Device(at::kCUDA, parent_device_index), shrunk_comm); + + return c10::static_intrusive_pointer_cast(new_pg); +} + +#else // !NCCL_HAS_COMM_SHRINK +// Backend interface override: raise consistent error when shrink is +// unsupported. +c10::intrusive_ptr ProcessGroupNCCL::shrink( + const std::vector& /*ranks_to_exclude*/, + int /*shrink_flags*/, + const c10::intrusive_ptr& /*opts_override*/) { + TORCH_CHECK( + false, + "ProcessGroupNCCL::shrink requires NCCL version 2.27.0 or later, " + "but PyTorch was built with an older version or without NCCL shrink support."); +} + +#endif // NCCL_HAS_COMM_SHRINK + +void ProcessGroupNCCL::initializeDeviceStateForComm( + const at::Device& device, + std::shared_ptr comm) { + const auto key = getKeyFromDevice(device); + std::unique_lock lock(mutex_); + at::cuda::OptionalCUDAGuard gpuGuard(device); + + bool force_high = getCvarBool(TORCH_NCCL_HIGH_PRIORITY, false); + auto stream = at::cuda::getStreamFromPool( + options_->is_high_priority_stream || force_high); + + devNCCLCommMap_[key] = comm; + ncclStreams_.emplace(key, stream); + ncclEvents_.emplace(key, at::cuda::CUDAEvent(cudaEventDisableTiming)); + usedDeviceIdxs_.insert(device.index()); + + if (shouldAllCommunicatorsRegisterAllTensors()) { + std::lock_guard map_lock(ncclCommMemPoolMapMutex); + ncclCommMemPoolMap.emplace(std::move(comm), MemPoolSet{}); + } +} + } // namespace c10d #endif // USE_C10D_NCCL diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 286eab14d1a86..2ead1a107394d 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -997,6 +997,21 @@ class TORCH_API ProcessGroupNCCL : public Backend { ErrorType getError() override; + bool supportsShrinking() const override { +#ifdef NCCL_HAS_COMM_SHRINK + return true; +#else + return false; +#endif + } + + // Backend-style shrink override that returns a Backend instance. + c10::intrusive_ptr shrink( + const std::vector& ranks_to_exclude, + int shrink_flags = 0, + const c10::intrusive_ptr& opts_override = + nullptr) override; + std::shared_ptr getMemAllocator() override; // Allocate tensor from communication-optimized memory pool @@ -1065,6 +1080,12 @@ class TORCH_API ProcessGroupNCCL : public Backend { int p2pRank = 0, bool isSendRecvSelf = false); + // Initialize device-specific state (comm, stream, event, bookkeeping) for a + // given communicator on this process group instance. + void initializeDeviceStateForComm( + const at::Device& device, + std::shared_ptr comm); + // Wrapper method which can be overridden for tests. virtual std::exception_ptr checkForNCCLErrors( std::shared_ptr& ncclComm); diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index a6c6c6f8c4744..4c6bdbe2ce70f 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -2734,12 +2734,23 @@ The hook must have the following signature: "supports_time_estimate", &::c10d::Backend::supportsTimeEstimation, "(test whether the backend supports collective time estimation)") + .def_property_readonly( + "supports_shrinking", + &::c10d::Backend::supportsShrinking, + "(test whether the backend supports communicator shrinking)") .def( "set_timeout", &::c10d::Backend::setTimeout, py::arg("timeout"), py::call_guard(), R"(Sets the default timeout for all future operations.)") + .def( + "shrink", + &::c10d::Backend::shrink, + py::arg("ranks_to_exclude"), + py::arg("shrink_flags") = 0, + py::arg("opts_override") = nullptr, + py::call_guard()) .def( "broadcast", &::c10d::Backend::broadcast, diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index bc79408a32ff9..9e4ec1483e960 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -130,6 +130,7 @@ "reduce_scatter_tensor", "get_node_local_rank", "split_group", + "shrink_group", ] _MPI_AVAILABLE = True @@ -5753,3 +5754,521 @@ def _get_process_group_name(pg: ProcessGroup) -> str: def _get_process_group_store(pg: ProcessGroup) -> Store: return _world.pg_map[pg][1] + + +# Shrink flags for process group backends +SHRINK_DEFAULT = 0x00 +SHRINK_ABORT = 0x01 + + +@_time_logger +def shrink_group( + ranks_to_exclude: list[int], + group: Optional[ProcessGroup] = None, + shrink_flags: int = SHRINK_DEFAULT, + pg_options: Optional[Any] = None, +) -> ProcessGroup: + """ + Shrinks a process group by excluding specified ranks. + + Creates and returns a new, smaller process group comprising only the ranks + from the original group that were not in the ``ranks_to_exclude`` list. + + Args: + ranks_to_exclude (List[int]): A list of ranks from the original + ``group`` to exclude from the new group. + group (ProcessGroup, optional): The process group to shrink. If ``None``, + the default process group is used. Defaults to ``None``. + shrink_flags (int, optional): Flags to control the shrinking behavior. + Can be ``SHRINK_DEFAULT`` (default) or ``SHRINK_ABORT``. + ``SHRINK_ABORT`` will attempt to terminate ongoing operations + in the parent communicator before shrinking. + Defaults to ``SHRINK_DEFAULT``. + pg_options (ProcessGroupOptions, optional): Backend-specific options to apply + to the shrunken process group. If provided, the backend will use + these options when creating the new group. If omitted, the new group + inherits defaults from the parent. + + Returns: + ProcessGroup: a new group comprised of the remaining ranks. If the + default group was shrunk, the returned group becomes the new default group. + + Raises: + TypeError: if the group’s backend does not support shrinking. + ValueError: if ``ranks_to_exclude`` is invalid (empty, out of bounds, + duplicates, or excludes all ranks). + RuntimeError: if an excluded rank calls this function or the backend + fails the operation. + + Notes: + - Only non-excluded ranks should call this function; excluded ranks + must not participate in the shrink operation. + - Shrinking the default group destroys all other process groups since + rank reassignment makes them inconsistent. + """ + # Step 1: Validate input parameters with comprehensive error checking + _validate_shrink_inputs(ranks_to_exclude, shrink_flags) + + # Step 2: Get target group and essential properties + target_group_info = _prepare_shrink_target_group(group) + + # Step 3: Validate backend requirements and availability + backend_impl = _validate_shrink_backend_requirements(target_group_info) + + # Step 4: Validate ranks against group and check for duplicates + excluded_ranks_set = _validate_and_process_excluded_ranks( + ranks_to_exclude, target_group_info + ) + + # Step 5: Execute the actual shrink operation (backend-specific) + new_backend = backend_impl.shrink( + sorted(excluded_ranks_set), + shrink_flags, + pg_options if pg_options is not None else None, + ) + + # Step 6: Handle cleanup and creation of new process group + target_group_info["pg_options_override"] = pg_options + return _finalize_shrunk_group(target_group_info, excluded_ranks_set, new_backend) + + +def _validate_shrink_inputs(ranks_to_exclude: list[int], shrink_flags: int) -> None: + """Validate input parameters for shrink_group.""" + if not isinstance(ranks_to_exclude, list): + raise TypeError( + f"ranks_to_exclude must be a list, but got {type(ranks_to_exclude).__name__}. " + f"Example: [1, 3, 5] to exclude ranks 1, 3, and 5." + ) + + if not ranks_to_exclude: + raise ValueError( + "ranks_to_exclude cannot be empty. To shrink a group, you must specify at least " + "one rank to exclude. Example: [failed_rank_id]" + ) + + # Validate shrink_flags with clear explanation of valid values + valid_flags = [SHRINK_DEFAULT, SHRINK_ABORT] + if not isinstance(shrink_flags, int) or shrink_flags not in valid_flags: + raise ValueError( + f"Invalid shrink_flags value: {shrink_flags}. Must be one of: " + f"SHRINK_DEFAULT ({SHRINK_DEFAULT}) or SHRINK_ABORT ({SHRINK_ABORT}). " + f"Use SHRINK_ABORT to abort ongoing operations before shrinking." + ) + + +def _prepare_shrink_target_group(group: Optional[ProcessGroup]) -> dict: + """Prepare and validate the target group for shrinking.""" + target_pg = group if group is not None else _get_default_group() + + # Cache frequently accessed properties to avoid repeated calls + group_size = int(target_pg.size()) + group_info = { + "process_group": target_pg, + "is_default_group": (target_pg == _get_default_group()), + "group_size": group_size, + "current_rank": target_pg.rank(), + "group_name": _get_process_group_name(target_pg), + } + + # Validate that we have a valid process group + if group_size <= 1: + raise ValueError( + f"Cannot shrink a process group with size {group_size}. " + f"Group must have at least 2 ranks to support shrinking." + ) + + return group_info + + +def _validate_shrink_backend_requirements(group_info: dict) -> Any: + """Return the backend implementation for the target group or raise if unsupported.""" + target_pg = group_info["process_group"] + group_name = group_info["group_name"] + + # Get the group's backend directly via ProcessGroup API. Prefer a bound device if present, + # otherwise try CUDA then fall back to CPU. + try: + preferred_device = getattr(target_pg, "bound_device_id", None) + if preferred_device is not None: + backend_impl = target_pg._get_backend(preferred_device) + else: + # Try CUDA first if available, else CPU + try: + backend_impl = target_pg._get_backend(torch.device("cuda")) + except Exception: + backend_impl = target_pg._get_backend(torch.device("cpu")) + except RuntimeError as e: + raise RuntimeError( + f"Cannot access device backend for process group '{group_name}'. " + f"Ensure the process group was initialized with a compatible device backend and devices are available." + ) from e + + try: + supports = bool(backend_impl.supports_shrinking) + except Exception: + supports = False + if not supports: + raise TypeError( + f"Process group backend for '{group_name}' does not support shrinking operations." + ) + + return backend_impl + + +def _validate_and_process_excluded_ranks( + ranks_to_exclude: list[int], group_info: dict +) -> set: + """Validate excluded ranks and convert to set for efficient operations.""" + group_size = group_info["group_size"] + current_rank = group_info["current_rank"] + + # Use set for O(1) duplicate detection and membership testing + excluded_ranks_set = set() + + # Validate each rank with detailed error messages + for i, rank in enumerate(ranks_to_exclude): + if not isinstance(rank, int): + raise TypeError( + f"All elements in ranks_to_exclude must be integers. " + f"Element at index {i} is {type(rank).__name__}: {rank}" + ) + + if not (0 <= rank < group_size): + raise ValueError( + f"Rank {rank} at index {i} is out of bounds for group size {group_size}. " + f"Valid ranks are in range [0, {group_size - 1}]." + ) + + if rank in excluded_ranks_set: + raise ValueError( + f"Duplicate rank {rank} found in ranks_to_exclude at index {i}. " + f"Each rank can only be excluded once." + ) + + excluded_ranks_set.add(rank) + + # Ensure we don't exclude all ranks + if len(excluded_ranks_set) >= group_size: + raise ValueError( + f"Cannot exclude all {group_size} ranks from process group. " + f"At least one rank must remain. Excluding {len(excluded_ranks_set)} ranks." + ) + + # Critical check: current rank should not be in excluded list + if current_rank in excluded_ranks_set: + raise RuntimeError( + f"Current rank {current_rank} is in the exclusion list and should not call shrink_group(). " + f"Only non-excluded ranks should participate in the shrinking operation. " + f"Excluded ranks should terminate their processes instead." + ) + + return excluded_ranks_set + + +def _finalize_shrunk_group( + group_info: dict, excluded_ranks_set: set, new_backend +) -> ProcessGroup: + """Clean up old group and create new shrunk process group.""" + target_pg = group_info["process_group"] + is_default_group = group_info["is_default_group"] + + # Handle default group dependencies - destroy other groups first + if is_default_group: + _destroy_all_other_groups(exclude_group=target_pg) + + # Gather original group metadata before cleanup + original_group_metadata = _extract_group_metadata(target_pg) + + # Calculate remaining ranks efficiently + original_ranks = get_process_group_ranks(target_pg) + remaining_ranks = [ + rank for rank in original_ranks if rank not in excluded_ranks_set + ] + + # Clean up the original group + _cleanup_original_group(target_pg, is_default_group) + + # Create and configure the new process group + new_pg = _create_shrunk_process_group( + new_backend, remaining_ranks, original_group_metadata, is_default_group + ) + + # Register the new group in global state + if is_default_group: + _update_default_pg(new_pg) + + # Update global state with new group information + rank_mapping = { + global_rank: group_rank + for group_rank, global_rank in enumerate(remaining_ranks) + } + _update_process_group_global_state( + pg=new_pg, + backend_name=original_group_metadata["backend_name"], + store=original_group_metadata["store"], + group_name=original_group_metadata["new_group_name"], + backend_config=original_group_metadata["backend_config"], + rank_mapping=rank_mapping, + ) + + return new_pg + + +def _extract_group_metadata(target_pg: ProcessGroup) -> dict: + """Extract metadata from the original group before cleanup.""" + original_backend_name, original_store = _world.pg_map[target_pg] + original_backend_config = _world.pg_backend_config.get(target_pg, "") + original_group_name = _get_process_group_name(target_pg) + + # Extract device binding information before cleanup to avoid accessing destroyed group + bound_device_id = None + if hasattr(target_pg, "bound_device_id"): + bound_device_id = target_pg.bound_device_id + + # Generate new group name for the shrunk group; hash for uniqueness across backends + remaining_ranks = list(get_process_group_ranks(target_pg)) + new_group_name = _process_group_name(remaining_ranks, use_hashed_name=True) + + return { + "backend_name": original_backend_name, + "store": original_store, + "backend_config": original_backend_config, + "original_group_name": original_group_name, + "new_group_name": new_group_name, + "bound_device_id": bound_device_id, # Safe to access after cleanup + } + + +def _cleanup_original_group(target_pg: ProcessGroup, is_default_group: bool) -> None: + """Clean up the original process group safely.""" + try: + destroy_process_group(target_pg) + except Exception: + group_type = "default" if is_default_group else "non-default" + logger.warning( + "Failed to destroy %s group during shrinking", group_type, exc_info=True + ) + + # Ensure global state cleanup even if destroy_process_group fails + _cleanup_process_group_global_state(target_pg) + + +def _create_shrunk_process_group( + new_backend, remaining_ranks: list[int], metadata: dict, is_default_group: bool +) -> ProcessGroup: + """Create and configure the new shrunk process group.""" + # Create new group properties + new_group_rank = new_backend.rank() + new_group_size = new_backend.size() + group_name = metadata["new_group_name"] + + # Generate descriptive group description + if is_default_group: + group_desc = "default:shrunken" + else: + group_desc = f"{metadata['original_group_name']}:shrunk" + + # Create process group with new communicator (clone the parent store like split does) + prefix_store = PrefixStore(f"{group_name}/", metadata["store"].clone()) + new_pg = ProcessGroup(prefix_store, new_group_rank, new_group_size) + + # Configure backend using the device type of the new backend's bound device if available, + # otherwise derive from the original group's bound device or fall back to CPU. + backend_device = metadata.get("bound_device_id") + if backend_device is None: + # Default to CPU if no bound device is present + backend_device = torch.device("cpu") + + # Choose backend enum based on device type + if backend_device.type == "cuda": + backend_type = ProcessGroup.BackendType.NCCL + else: + backend_type = ProcessGroup.BackendType.GLOO + + new_pg._register_backend(backend_device, backend_type, new_backend) + new_pg._set_default_backend(backend_type) + + # Inherit device binding from original group if it was bound + bound_device_id = metadata.get("bound_device_id") + if bound_device_id is not None: + new_pg.bound_device_id = bound_device_id + + # Set group metadata + new_pg._set_group_name(group_name) + new_pg._set_group_desc(group_desc) + + # Persist backend configuration overrides (if provided via shrink_group) + backend_config_override = metadata.get("backend_config") + if backend_config_override is not None: + # Store for introspection/debugging and potential backend hooks + _world.pg_backend_config[new_pg] = backend_config_override + + return new_pg + + +def _destroy_all_other_groups(exclude_group: Optional[ProcessGroup] = None) -> None: + """ + Destroy all process groups except the excluded group and clean up all global state. + + This is necessary when shrinking the default group because global ranks + are reassigned by NCCL, making all existing process groups inconsistent. + + Note: Uses abort for non-collective cleanup since excluded ranks may not + participate in collective operations. Backend cleanup is handled independently per group. + + Args: + exclude_group (ProcessGroup, optional): Process group to exclude from destruction. + If None, destroys all process groups. + """ + # Get list of groups to destroy (avoid modifying dict while iterating) + groups_to_destroy = [] + for pg in list(_world.pg_group_ranks.keys()): + if exclude_group is not None and pg == exclude_group: + continue + groups_to_destroy.append(pg) + + # Warn user about automatic destruction + if groups_to_destroy: + group_names = [_get_process_group_name(pg) for pg in groups_to_destroy] + logger.warning( + "Shrinking default group will destroy %d other process groups: %s. " + "This is necessary because shrinking the default group reassigns global ranks, " + "making existing groups inconsistent.", + len(groups_to_destroy), + ", ".join(group_names), + ) + + # Destroy each group and clean up global state + for pg in groups_to_destroy: + try: + # First call abort_process_group which handles the C++ cleanup non-collectively + _abort_process_group(pg) + except Exception: + # Log but don't fail - some groups might already be destroyed + logger.warning( + "Failed to abort process group %s", + _get_process_group_name(pg), + exc_info=True, + ) + + # Ensure all global state is cleaned up even if _abort_process_group fails + # or doesn't clean up everything + _cleanup_process_group_global_state(pg) + + +def _cleanup_process_group_global_state(pg: ProcessGroup) -> None: + """ + Clean up all global state associated with a process group. + + This function ensures complete cleanup of process group state from all + global dictionaries and registries, even if destroy_process_group fails + or doesn't clean up everything. This is critical when destroying multiple + groups to prevent inconsistent state. + + The cleanup removes the process group from: + - _world.pg_map (backend and store mapping) + - _world.pg_names (group name mapping) + - _world.pg_group_ranks (rank mappings) + - _world.pg_backend_config (backend configuration) + - _world.tags_to_pg and _world.pg_to_tag (tag mappings) + - _world.pg_coalesce_state (coalescing state) + - C++ internal registries via _unregister_process_group + + Args: + pg (ProcessGroup): The process group to clean up. + """ + try: + # Clean up main process group mappings + _world.pg_map.pop(pg, None) + _world.pg_group_ranks.pop(pg, None) + _world.pg_backend_config.pop(pg, None) + + # Clean up process group name mapping + group_name = _world.pg_names.pop(pg, None) + + # Clean up tag mappings + pg_tag = _world.pg_to_tag.pop(pg, None) + if pg_tag is not None and pg_tag in _world.tags_to_pg: + try: + _world.tags_to_pg[pg_tag].remove(pg) + # Remove the tag entry if list is empty + if not _world.tags_to_pg[pg_tag]: + _world.tags_to_pg.pop(pg_tag, None) + except (ValueError, KeyError): + # Process group was already removed from the list + pass + + # Clean up any registered process group names using C++ unregister function + if group_name is not None: + try: + _unregister_process_group(group_name) + except Exception: + # Process group name might not be registered or already unregistered + pass + + # Clean up coalesce state if present + _world.pg_coalesce_state.pop(pg, None) + + except Exception: + # Log cleanup failures but don't propagate - we want to continue with other cleanups + logger.warning( + "Failed to fully clean up global state for process group", exc_info=True + ) + + +def _update_process_group_global_state( + pg: ProcessGroup, + backend_name: str, + store: Store, + group_name: str, + backend_config: str, + rank_mapping: Optional[dict[int, int]] = None, + pg_tag: Optional[str] = None, + user_tag: Optional[str] = None, +) -> None: + """ + Update all global state dictionaries for a process group. + + This helper function consolidates the common pattern of updating multiple + global state dictionaries when creating or modifying process groups. + + Args: + pg (ProcessGroup): The process group to update state for. + backend_name (str): Backend name for pg_map. + store (Store): Store instance for pg_map. + group_name (str): Group name for pg_names and registration. + backend_config (str): Backend configuration string. + rank_mapping (Dict[int, int], optional): Global rank to group rank mapping. + If None, skips updating pg_group_ranks. + pg_tag (str, optional): Process group tag. If None, defaults to f"ptd:{group_name}". + user_tag (str, optional): User-provided tag for special tag handling. + If provided, creates "user:{user_tag}" tag and also adds to default "". + """ + # Update main process group mappings + _world.pg_map[pg] = (backend_name, store) + _world.pg_names[pg] = group_name + _world.pg_backend_config[pg] = backend_config + + # Register the process group name + _register_process_group(group_name, pg) + + # Update rank mapping if provided + if rank_mapping is not None: + _world.pg_group_ranks[pg] = rank_mapping + + # Handle tag management + if pg_tag is None: + pg_tag = f"ptd:{group_name}" + + if user_tag is not None: + # Special handling for user-provided tags + # Add to default "" tag first + _world.tags_to_pg.setdefault("", []).append(pg) + # Then create user-specific tag + user_pg_tag = f"user:{user_tag}" + _world.tags_to_pg.setdefault(user_pg_tag, []).append(pg) + _world.pg_to_tag[pg] = user_pg_tag + else: + # Standard process group tag + _world.tags_to_pg.setdefault(pg_tag, []).append(pg) + _world.pg_to_tag[pg] = pg_tag diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 18384b311b936..91f09adf9e816 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -238,6 +238,47 @@ def wrapper(*args, **kwargs): return decorator +def requires_world_size(n: int): + """ + Decorator to request a specific world size for a test. The test harness can + read this attribute to set the number of ranks to spawn. If there are fewer + than `n` CUDA devices available, the test should be skipped by the harness. + + Usage: + @require_world_size(3) + def test_something(self): + ... + """ + + def decorator(func): + func._required_world_size = n + available = torch.cuda.device_count() + return unittest.skipUnless( + available >= n, f"requires {n} GPUs, found {available}" + )(func) + + return decorator + + +def get_required_world_size(obj: Any, default: int) -> int: + """ + Returns the requested world size for the currently running unittest method on `obj` + if annotated via `@require_world_size(n)`, else returns `default`. + """ + try: + # Try MultiProcessTestCase helper first, then unittest fallback + test_name = ( + obj._current_test_name() # type: ignore[attr-defined] + if hasattr(obj, "_current_test_name") and callable(obj._current_test_name) + else obj._testMethodName + ) + fn = getattr(obj, test_name) + value = fn._required_world_size + return int(value) + except Exception: + return default + + # This decorator helps avoiding initializing cuda while testing other backends def nccl_skip_if_lt_x_gpu(backend, x): def decorator(func): @@ -367,6 +408,13 @@ def requires_nccl_version(version, msg): ) +def requires_nccl_shrink(): + """ + Require NCCL shrink support (NCCL available and version >= 2.27). + """ + return requires_nccl_version((2, 27), "Need NCCL 2.27+ for shrink_group") + + def requires_nccl(): return skip_but_pass_in_sandcastle_if( not c10d.is_nccl_available(), From 45da6e1fe17dc7fb4d96526e907ed9c9bf002f70 Mon Sep 17 00:00:00 2001 From: "Wang, Chuanqi" Date: Wed, 5 Nov 2025 01:02:53 +0000 Subject: [PATCH 040/130] [CD] Upload XPU inductor benchmark test reports to s3 (#166954) As the title Pull Request resolved: https://github.com/pytorch/pytorch/pull/166954 Approved by: https://github.com/atalman --- .github/workflows/_xpu-test.yml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/.github/workflows/_xpu-test.yml b/.github/workflows/_xpu-test.yml index e68bc6ead3a26..d27325b8a63dc 100644 --- a/.github/workflows/_xpu-test.yml +++ b/.github/workflows/_xpu-test.yml @@ -344,5 +344,21 @@ jobs: if-no-files-found: ignore path: ./**/core.[1-9]* + - name: Authenticate with AWS + uses: aws-actions/configure-aws-credentials@ececac1a45f3b08a01d2dd070d28d111c5fe6722 # v4.1.0 + with: + role-to-assume: arn:aws:iam::308535385114:role/gha_workflow_upload-benchmark-results + # The max duration enforced by the server side + role-duration-seconds: 18000 + aws-region: us-east-1 + + - name: Upload the benchmark results + uses: pytorch/test-infra/.github/actions/upload-benchmark-results@main + with: + benchmark-results-dir: test/test-reports + dry-run: false + schema-version: v3 + github-token: ${{ secrets.GITHUB_TOKEN }} + - name: Teardown XPU uses: ./.github/actions/teardown-xpu From 64ae31c5d36255d16c832204acb7709e0762f1b3 Mon Sep 17 00:00:00 2001 From: Xiao Fu Date: Tue, 4 Nov 2025 10:58:04 -0800 Subject: [PATCH 041/130] [HOP][print] Add HOP subclass for printing (#166660) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166660 Approved by: https://github.com/angelayi, https://github.com/anijain2305 Co-authored-by: Angela Yi --- test/higher_order_ops/test_print.py | 44 +++++++++++++++++++++++++++++ torch/_higher_order_ops/__init__.py | 2 ++ torch/_higher_order_ops/print.py | 44 +++++++++++++++++++++++++++++ torch/testing/_internal/hop_db.py | 25 +++++++++++----- 4 files changed, 108 insertions(+), 7 deletions(-) create mode 100644 test/higher_order_ops/test_print.py create mode 100644 torch/_higher_order_ops/print.py diff --git a/test/higher_order_ops/test_print.py b/test/higher_order_ops/test_print.py new file mode 100644 index 0000000000000..aef538854864f --- /dev/null +++ b/test/higher_order_ops/test_print.py @@ -0,0 +1,44 @@ +# Owner(s): ["module: higher order operators"] +import io +from unittest.mock import patch + +import torch +from torch._dynamo.utils import counters +from torch.testing._internal.common_utils import run_tests, TestCase + + +class TestHopPrint(TestCase): + def test_base_print(self): + def f(x): + x = x + x + torch._higher_order_ops.print("moo") + x = x * x + torch._higher_order_ops.print("moo") + return x + + counters.clear() + x = torch.randn(3, 3) + with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: + f(x) + printed_output = mock_stdout.getvalue().strip() + + self.assertEqual(printed_output, "moo\nmoo") + + def test_para_print(self): + def f(x): + x = x + x + torch._higher_order_ops.print("moo {x} {y}", x=1, y=2) + x = x * x + return x + + counters.clear() + x = torch.randn(3, 3) + with patch("sys.stdout", new_callable=io.StringIO) as mock_stdout: + f(x) + printed_output = mock_stdout.getvalue().strip() + + self.assertEqual(printed_output, "moo 1 2") + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_higher_order_ops/__init__.py b/torch/_higher_order_ops/__init__.py index 516d58bdf314e..452a080570ebe 100644 --- a/torch/_higher_order_ops/__init__.py +++ b/torch/_higher_order_ops/__init__.py @@ -24,6 +24,7 @@ from torch._higher_order_ops.local_map import local_map_hop from torch._higher_order_ops.map import map from torch._higher_order_ops.out_dtype import out_dtype +from torch._higher_order_ops.print import print from torch._higher_order_ops.run_const_graph import run_const_graph from torch._higher_order_ops.scan import scan from torch._higher_order_ops.strict_mode import strict_mode @@ -75,4 +76,5 @@ "map", "while_loop_stack_output", "local_map_hop", + "print", ] diff --git a/torch/_higher_order_ops/print.py b/torch/_higher_order_ops/print.py new file mode 100644 index 0000000000000..5a14ef23aa24e --- /dev/null +++ b/torch/_higher_order_ops/print.py @@ -0,0 +1,44 @@ +import builtins + +import torch +import torch.utils._pytree as pytree +from torch._ops import HigherOrderOperator + + +class Print(HigherOrderOperator): + """ + print(format_str, **kwargs) -> None + + This Higher Order Operator (HOP) provides a functional version of print for use in PyTorch graphs. + It enables format printing with named arguments, e.g., torch._higher_order_ops.print("moo {x} {y}", x=1, y=2). + + This HOP enables printing without causing graph break. + """ + + def __init__(self) -> None: + super().__init__("print") + + def __call__(self, format_str: str, **kwargs: object) -> object: + assert isinstance(format_str, str) + return super().__call__(format_str, **kwargs) + + +print = Print() + + +@print.py_impl(torch._C.DispatchKey.CompositeExplicitAutograd) +# pyre-ignore +def print_cpu(format_str: str, **kwargs: object) -> None: + # Ensure all immutable_dict/list in kwargs are converted to regular dict/list + map_types: dict[type, type] = { + torch.fx.immutable_collections.immutable_dict: dict, + torch.fx.immutable_collections.immutable_list: list, + } + new_kwargs = pytree.tree_map_only( + tuple(map_types.keys()), + lambda a: map_types[type(a)](a), + kwargs, + lambda a: isinstance(a, tuple(map_types.keys())), + ) + # Use built-in print to avoid recursion with the HOP print + builtins.print(format_str.format(**new_kwargs)) diff --git a/torch/testing/_internal/hop_db.py b/torch/testing/_internal/hop_db.py index fc6cfa8cf7f4e..3b38661c69b8c 100644 --- a/torch/testing/_internal/hop_db.py +++ b/torch/testing/_internal/hop_db.py @@ -103,6 +103,7 @@ def f2(x, y0, y1): "dynamo_bypassing_wrapper", # TODO(soulitzer) "foreach_map", "aoti_call_delegate", + "print", ] torch.library.define( @@ -153,6 +154,7 @@ def sample_inputs_invoke_subgraph(opinfo, device, dtype, requires_grad, **kwargs def fn_for_invoke_subgraph(x): return torch.sin(x) + def simple_invoke_subgraph(x): return fn_for_invoke_subgraph(x) @@ -202,6 +204,7 @@ def body_fn(iter_t, x): return torch._higher_order_ops.while_loop(cond_fn, body_fn, (iter_t, x)) + def simple_while_loop_stack_output(iter_t, x): def cond_fn(iter_t, x): return iter_t > 0 @@ -209,7 +212,9 @@ def cond_fn(iter_t, x): def body_fn(iter_t, x): return iter_t - 1, x.cos() - return torch._higher_order_ops.while_loop_stack_output(cond_fn, body_fn, (iter_t, x), tuple()) + return torch._higher_order_ops.while_loop_stack_output( + cond_fn, body_fn, (iter_t, x), tuple() + ) def sample_inputs_local_map_hop(opinfo, device, dtype, requires_grad, **kwargs): @@ -226,18 +231,21 @@ def sample_inputs_local_map_hop(opinfo, device, dtype, requires_grad, **kwargs): def simple_local_map_hop(inp1, inp2): def body_gm(inp1, inp2): return inp1.cos() + inp2.sin() + gm = torch.fx.symbolic_trace(body_gm) assert torch.distributed.is_available() from torch.distributed.tensor.placement_types import Replicate + gm.meta["local_map_kwargs"] = { "in_placements": (Replicate(), Replicate(), Replicate()), - "out_placements": ((Replicate(), Replicate(), Replicate()),) + "out_placements": ((Replicate(), Replicate(), Replicate()),), } # TODO: Dynamo would rewrite this op differently return torch._higher_order_ops.local_map_hop(gm, inp1, inp2) + def sample_inputs_scan(opinfo, device, dtype, requires_grad, **kwargs): make_arg = functools.partial( make_tensor, device=device, dtype=dtype, requires_grad=requires_grad @@ -249,7 +257,6 @@ def sample_inputs_scan(opinfo, device, dtype, requires_grad, **kwargs): def simple_scan(init, xs): - def combine_fn(carry, x): result = carry @ x + x return result, carry.clone() @@ -264,15 +271,14 @@ def simple_invoke_quant(x): def fn(x, y): return (torch.sin(x) * y,) - return quant_tracer(fn, x, x)[0] * 2. + return quant_tracer(fn, x, x)[0] * 2.0 def simple_invoke_quant_packed(x): def fn(x): return (torch.sin(x),) - return invoke_quant_packed(fn, x)[0] * 2. - + return invoke_quant_packed(fn, x)[0] * 2.0 hop_db = [ @@ -496,6 +502,11 @@ def fn(x): DecorateInfo(unittest.expectedFailure, "TestHOP", "test_serialize_export"), DecorateInfo(unittest.expectedFailure, "TestHOP", "test_retrace_export"), ), - decorators=[onlyCUDA, unittest.skipIf(not torch.distributed.is_available(), "requires distributed build")], + decorators=[ + onlyCUDA, + unittest.skipIf( + not torch.distributed.is_available(), "requires distributed build" + ), + ], ), ] From bcd159bcddf477fe38fd020af403f7d1004c6c2b Mon Sep 17 00:00:00 2001 From: Haifeng Jin Date: Wed, 5 Nov 2025 01:16:54 +0000 Subject: [PATCH 042/130] Fix the vmap op fallback bug (#166032) ## The bug In some environments, if run: ```py def inner_func(x): return x.to(torch.float32, memory_format=torch.channels_last) x = torch.randn(2, 2, 3, 4, device="cpu", dtype=torch.float64) torch.vmap(inner_func)(x) ``` we get: ``` E RuntimeError: Batching rule not implemented for aten::to.dtype_layout; the fallback path doesn't work on out= or view ops. ``` Otherwise, it would always fallback and result in an error for ops like `to.dtype` and `to.dtype_layout` even the kernels are registered. ## The cause The alias key of `FuncTorchBatchedDecomposition` is not properly translated to runtime dispatch keys when updating the dispatch table of `OperatorEntry::dispatchTable_`. [[link](https://github.com/pytorch/pytorch/blob/984b096d10398a615a791fd11296d6d51fdd55a4/aten/src/ATen/core/dispatch/OperatorEntry.cpp#L500-L501)] The [`getRuntimeDispatchKeySet`](https://github.com/pytorch/pytorch/blob/f3fa560dec727380b3e9c074efe05f0ce715a5ca/c10/core/DispatchKeySet.cpp#L62) use if-else to translate all other alias keys but `FuncTorchBatchedDecomposition`. This would result in not finding the kernel in many cases. ## The fix This PR adds one more `if` statement to `getRuntimeDispatchKeySet` to map `FuncTorchBatchedDecomposition` to the corresponding runtime dispatch key, `FuncTorchBatched`. So, that the dispatch table can be properly updated. This fix allows people to use ops inside vmaps in more environments and across more compilers. ## Why does it work without the PR As long as the `FuncTorchBatchedDecomposition` [[link](https://github.com/pytorch/pytorch/blob/51319ca090bc7458168a8451c04ca7e021a72693/aten/src/ATen/functorch/BatchRulesDecompositions.cpp#L35)] is registered before the fallback method of `FuncTorchBatched` [[link](https://github.com/pytorch/pytorch/blob/d311a3d1dca4bfc01ae44fc5d1f8d7ff22bc551f/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp#L759)], everything runs fine. In this case, it relies on the registration of the fallback method to update the dispatch table, which flushes all the kernels in `OperatorEntry::kernels_` into `dispatchTable_`, among which there are kernels registered with `FuncTorchBatchedDecomposition`. ## When does it fail However, the order of the op registration and the fallback registration is not garanteed at all. It relies on the C++ static initialization order, which varies from environment to environment. On our compiler, it the fallback registration goes first and the alias key kernels under `FuncTorchBatchedDecomposition` comes later and not get flushed into the dispatch table by the fallback registration. Therefore, it cannot find the kernel for it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166032 Approved by: https://github.com/albanD --- c10/core/DispatchKeySet.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp index 72e72f49a5e40..107530e9e28a2 100644 --- a/c10/core/DispatchKeySet.cpp +++ b/c10/core/DispatchKeySet.cpp @@ -59,6 +59,9 @@ constexpr DispatchKeySet nested_dispatch_keyset = {DispatchKey::AutogradNestedTensor, DispatchKey::NestedTensor}) | DispatchKeySet(DispatchKeySet::RAW, full_backend_mask); +constexpr DispatchKeySet functorch_batched_dispatch_keyset = + DispatchKeySet(DispatchKey::FuncTorchBatched); + DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) { TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined); switch (t) { @@ -77,6 +80,8 @@ DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) { return backend_dispatch_keyset; case DispatchKey::CompositeExplicitAutogradNonFunctional: return non_functional_backend_dispatch_keyset; + case DispatchKey::FuncTorchBatchedDecomposition: + return functorch_batched_dispatch_keyset; default: return DispatchKeySet(t); } From 01e6e35c7faf913c3a85c7a64d2939cfa768358a Mon Sep 17 00:00:00 2001 From: Dzmitry Huba Date: Tue, 4 Nov 2025 13:45:20 -0800 Subject: [PATCH 043/130] Send / recv support in local tensor (#166595) This change introduces LocalRunnerMode that allows you to run multiple SPMD functions concurrently. SMPD functions are executing one at a time, yielding execution capability while waiting for send or receive operations to complete. Send and receive peer operations only supported while running under LocalRunnerMode. The example test in this change demonstrates how ranks are sending data to the next peer and receiving data from the previous peer (ring). Pull Request resolved: https://github.com/pytorch/pytorch/pull/166595 Approved by: https://github.com/wconstab, https://github.com/ezyang --- build_variables.bzl | 1 + test/distributed/test_local_tensor.py | 77 ++++++++++ torch/_C/_distributed_c10d.pyi | 7 +- torch/csrc/distributed/c10d/init.cpp | 28 ++++ .../distributed/c10d/python_callback_work.cpp | 64 +++++++++ .../distributed/c10d/python_callback_work.hpp | 28 ++++ torch/distributed/_local_tensor/__init__.py | 134 ++++++++++++++++++ torch/distributed/_local_tensor/_c10d.py | 104 ++++++++++++-- .../distributed/_tensor/common_dtensor.py | 4 + 9 files changed, 434 insertions(+), 13 deletions(-) create mode 100644 torch/csrc/distributed/c10d/python_callback_work.cpp create mode 100644 torch/csrc/distributed/c10d/python_callback_work.hpp diff --git a/build_variables.bzl b/build_variables.bzl index 70121e19d8099..258e739300c1e 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -1025,6 +1025,7 @@ libtorch_python_core_sources = [ libtorch_python_distributed_core_sources = [ "torch/csrc/distributed/c10d/init.cpp", "torch/csrc/distributed/c10d/python_comm_hook.cpp", + "torch/csrc/distributed/c10d/python_callback_work.cpp", ] libtorch_python_distributed_sources = libtorch_python_distributed_core_sources + [ diff --git a/test/distributed/test_local_tensor.py b/test/distributed/test_local_tensor.py index 114780627e334..c58ddf0f82ba7 100644 --- a/test/distributed/test_local_tensor.py +++ b/test/distributed/test_local_tensor.py @@ -7,6 +7,8 @@ import torch.distributed as dist from torch.distributed._local_tensor import ( local_tensor_mode, + LocalIntNode, + LocalRunnerMode, LocalTensor, LocalTensorMode, ) @@ -17,8 +19,10 @@ Partial, Replicate, Shard, + zeros, ) from torch.testing._internal.common_utils import run_tests, TestCase +from torch.testing._internal.distributed._tensor.common_dtensor import reduce_local_int class LocalTensorTestBase(TestCase): @@ -411,5 +415,78 @@ def test_dtensor_addmm(self): self.assertEqual(full_tensor, local_res) +from torch.distributed._local_tensor._c10d import local_p2p_op, wait_all + + +class TestLocalRunner(LocalTensorTestBase): + world_size = 6 + + @staticmethod + def _get_pp_peer(pp_index, mesh, dim, dir): + pp_meshes = mesh._get_all_submeshes(dim) + pp_ret = {} + for pp_mesh in pp_meshes: + global_rank = pp_mesh.mesh[pp_index].item() + global_peer = pp_mesh.mesh[(pp_index + dir) % pp_mesh.size()].item() + pp_ret[global_rank] = global_peer + + return torch.SymInt(LocalIntNode(pp_ret)) + + def _run_dp_pp( + self, + mesh: DeviceMesh, + pp_index: int, + actual: list[torch.Tensor | None], + expected: list[torch.Tensor | None], + ) -> None: + ltm = LocalTensorMode(mesh.size()) + with ltm: + dp_mesh = mesh["dp"] + pp_mesh = mesh["pp"] + + x = torch.rand(2, 4) + xd = distribute_tensor(x, dp_mesh, [Shard(0)]) + xd = xd * 2 + x = x * 2 + + yd = zeros(*xd.shape, device_mesh=dp_mesh, placements=[Shard(0)]) + + if pp_index != pp_mesh.size(0) - 1: + # Send to next pp rank + pp_next_rank = TestLocalRunner._get_pp_peer(pp_index, mesh, "pp", +1) + local_p2p_op(pp_next_rank, xd, dist.isend) + expected[pp_index + 1] = ltm.tensor_map( + x, + lambda r, t: t + if reduce_local_int(pp_next_rank, lambda vals: r in vals.values()) + else torch.zeros_like(t), + ) + + if pp_index != 0: + # Receive from prev pp rank + pp_prev_rank = TestLocalRunner._get_pp_peer(pp_index, mesh, "pp", -1) + rw = local_p2p_op(pp_prev_rank, yd, dist.irecv) + wait_all(rw) + + y = yd.full_tensor() + actual[pp_index] = y + + def test_dp_pp(self): + pp_size = 3 + mesh = init_device_mesh( + "cpu", (self.world_size // pp_size, pp_size), mesh_dim_names=("dp", "pp") + ) + actual: list[torch.Tensor | None] = [None] * pp_size + expected: list[torch.Tensor | None] = [None] * pp_size + with LocalRunnerMode( + self.world_size, + pp_size, + lambda pp_index: self._run_dp_pp(mesh, pp_index, actual, expected), + ): + pass + + self.assertEqual(actual, expected) + + if __name__ == "__main__": run_tests() diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index 737362be62b48..f3d96860f5584 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -2,7 +2,7 @@ # mypy: disable-error-code="type-arg" from datetime import timedelta from enum import Enum -from typing import Any, Optional, overload, Union +from typing import Any, Callable, Optional, overload, Union import torch from torch import Tensor @@ -616,6 +616,11 @@ class FakeWork(Work): def wait(self, timeout: timedelta = ...) -> bool: ... def getFuture(self) -> Future: ... +class PythonCallbackWork(Work): + def __init__(self, callback: Callable[[timedelta], bool]) -> None: ... + def wait(self, timeout: timedelta = ...) -> bool: ... + def get_future(self) -> Future: ... + class ProcessGroupGloo(Backend): class Device: ... diff --git a/torch/csrc/distributed/c10d/init.cpp b/torch/csrc/distributed/c10d/init.cpp index 4c6bdbe2ce70f..91bb3469e3e85 100644 --- a/torch/csrc/distributed/c10d/init.cpp +++ b/torch/csrc/distributed/c10d/init.cpp @@ -19,6 +19,7 @@ #include #include #include +#include #ifdef USE_C10D_GLOO #include @@ -3887,6 +3888,33 @@ such as `dist.all_reduce(tensor, async_op=True)`. .def("wait", &::c10d::FakeWork::wait, py::arg("timeout") = kNoTimeout) .def("getFuture", &::c10d::FakeWork::getFuture); + auto pythonCallbackWork = + intrusive_ptr_no_gil_destructor_class_<::c10d::PythonCallbackWork>( + module, "PythonCallbackWork", work) + .def(py::init(), py::arg("callback")) + .def( + "wait", + &::c10d::PythonCallbackWork::wait, + py::arg("timeout") = kNoTimeout, + R"( + Waits until the callback completes. Blocking operation. + The callback is invoked with the timeout parameter and should return a boolean. + Throws if the callback completes with an exception. + Returns the boolean value returned by the callback. + )") + .def( + "get_future", + [](::c10d::PythonCallbackWork& work) + -> std::shared_ptr { + return std::make_shared( + work.getFuture()); + }, + R"( + Returns: + A ``torch.futures.Future`` object which is associated with the completion of + the ``PythonCallbackWork``. + )"); + py::class_(module, "DDPLoggingData") .def(py::init<>()) .def_readwrite("strs_map", &c10::DDPLoggingData::strs_map) diff --git a/torch/csrc/distributed/c10d/python_callback_work.cpp b/torch/csrc/distributed/c10d/python_callback_work.cpp new file mode 100644 index 0000000000000..47bef1831a480 --- /dev/null +++ b/torch/csrc/distributed/c10d/python_callback_work.cpp @@ -0,0 +1,64 @@ +#include +#include + +namespace c10d { + +PythonCallbackWork::PythonCallbackWork(py::function callback) + : callback_(std::move(callback)) { + // Create a future that will be marked as complete when wait() is called + future_ = c10::make_intrusive( + c10::ListType::create(c10::TensorType::get())); +} + +// NOLINTNEXTLINE(bugprone-exception-escape) +PythonCallbackWork::~PythonCallbackWork() { + py::gil_scoped_acquire ag; + callback_.dec_ref(); + // Explicitly set callback_ to nullptr to prevent py::object's dtor + // to decref on the PyObject again. + // See Note [Destructing py::object] in python_ivalue.h + callback_.ptr() = nullptr; +} + +bool PythonCallbackWork::wait(std::chrono::milliseconds timeout) { + py::gil_scoped_acquire ag; + + try { + // Call the Python callback with timeout + py::object result = callback_(timeout); + + // Extract the boolean result + bool success = result.cast(); + + // Mark the work as completed if successful + if (success) { + finish(); + // Mark the future as complete with an empty list + if (!future_->completed()) { + future_->markCompleted(c10::IValue(c10::List())); + } + } + + return success; + } catch (py::error_already_set& e) { + // Capture the Python exception and store it + finish(std::current_exception()); + if (!future_->completed()) { + future_->setErrorIfNeeded(std::current_exception()); + } + throw; + } catch (const std::exception& e) { + // Capture any C++ exception and store it + finish(std::current_exception()); + if (!future_->completed()) { + future_->setErrorIfNeeded(std::current_exception()); + } + throw; + } +} + +c10::intrusive_ptr PythonCallbackWork::getFuture() { + return future_; +} + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/python_callback_work.hpp b/torch/csrc/distributed/c10d/python_callback_work.hpp new file mode 100644 index 0000000000000..48966e785ad60 --- /dev/null +++ b/torch/csrc/distributed/c10d/python_callback_work.hpp @@ -0,0 +1,28 @@ +#pragma once + +#include +#include +#include + +namespace c10d { + +// PythonCallbackWork is a subclass of Work that wraps a Python callback +// function that implements wait(). This allows asynchronous work to +// be integrated with Python code, enabling custom completion logic or +// post-processing in Python. +class PythonCallbackWork : public Work { + public: + explicit PythonCallbackWork(py::function callback); + + ~PythonCallbackWork() override; + + bool wait(std::chrono::milliseconds timeout) override; + + c10::intrusive_ptr getFuture() override; + + private: + py::function callback_; + c10::intrusive_ptr future_; +}; + +} // namespace c10d diff --git a/torch/distributed/_local_tensor/__init__.py b/torch/distributed/_local_tensor/__init__.py index ea9707b2e1e85..c186694df94e7 100644 --- a/torch/distributed/_local_tensor/__init__.py +++ b/torch/distributed/_local_tensor/__init__.py @@ -64,6 +64,7 @@ np = None # type: ignore[assignment] import torch +import torch.distributed as dist from torch import Size, SymBool, SymInt, Tensor from torch._C import DispatchKey, DispatchKeySet, ScriptObject from torch._export.wrappers import mark_subclass_constructor_exportable_experimental @@ -921,6 +922,22 @@ def rank_map(self, cb: Callable[[int], Tensor]) -> LocalTensor: # pyrefly: ignore [bad-argument-type, bad-argument-count] return LocalTensor({r: cb(r) for r in self.ranks}) + def tensor_map( + self, tensor: LocalTensor, cb: Callable[[int, Tensor], Tensor | None] + ) -> LocalTensor: + """ + Creates a LocalTensor instance by mapping rank id to ids local shard. + """ + + with self.disable(): + results = {} + for r in self.ranks: + if r in tensor._local_tensors: + m = cb(r, tensor._local_tensors[r]) + if m is not None: + results[r] = m + return LocalTensor(results) + def _patch_device_mesh(self) -> None: assert self._old_get_coordinate is None self._old_get_coordinate = DeviceMesh.get_coordinate # type: ignore[assignment] @@ -1049,3 +1066,120 @@ def maybe_disable_local_tensor_mode() -> contextlib.AbstractContextManager: """ lm = local_tensor_mode() return lm.disable() if lm is not None else contextlib.nullcontext() + + +import threading +from queue import Queue + + +_LOCAL_RUNNER_MODE: "LocalRunnerMode | None" = None + + +class LocalRunnerMode: + """ + A class for running multiple SPMD functions concurrently, however at any point + in time only one function can be running. The main use case for the local runner + mode is to enable SPMD functions to be able to use send and recv to communicate + with each other. Without local runner mode send and recv are not supported. + """ + + runner_context = threading.local() + + def __init__( + self, ranks: frozenset[int] | int, concurrency: int, fn: Callable[[int], None] + ): + if isinstance(ranks, int): + ranks = frozenset(range(ranks)) + self._ranks = ranks + self._fn = fn + self._run_lock = threading.Lock() + self._run_id = -1 + self._run_cond = threading.Condition(self._run_lock) + + self._recv_objects: dict[int, dict[int, Queue]] = { + dst: {src: Queue() for src in ranks} for dst in ranks + } + self._runners = [ + threading.Thread(target=self._run, args=(i,), name="LocalRunnerMode") + for i in range(concurrency) + ] + + def __enter__(self) -> "LocalRunnerMode": + global _LOCAL_RUNNER_MODE + assert _LOCAL_RUNNER_MODE is None, "LocalRunnerMode is already running" + _LOCAL_RUNNER_MODE = self + + for r in self._runners: + r.start() + return self + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + for r in self._runners: + r.join() + global _LOCAL_RUNNER_MODE + _LOCAL_RUNNER_MODE = None + + def _run(self, id: int) -> None: + LocalRunnerMode.runner_context.id = id + # Only one thread can run at a time, hence must acquire the lock + try: + self._acquire_run_lock() + self._fn(id) + finally: + self._release_run_lock() + + def _acquire_run_lock(self) -> None: + self._run_lock.acquire() + self._run_id = LocalRunnerMode.runner_context.id + + def _release_run_lock(self) -> None: + self._run_id = -1 + self._run_lock.release() + + def _assert_holds_run_lock(self) -> None: + assert self._run_id == LocalRunnerMode.runner_context.id, ( + "Calling thread does not hold the run lock" + ) + + def _get_recv_object(self, src: int, dst: int) -> object | None: + peers = [src] if src != -1 else list(self._ranks) + recv_objects = self._recv_objects[dst] + + for p in peers: + if not recv_objects[p].empty(): + return recv_objects[p].get() + + return None + + def _signal_send(self, src: int, dst: int, obj: object) -> None: + assert obj is not None, "Cannot signal None" + self._assert_holds_run_lock() + # Only a single thread a time executes so it is safe to mutate + # read objects queue (executing thread is already holding the lock) + self._recv_objects[dst][src].put(obj) + # Signal directly condition variable since the calling thread is already + # holding the lock + self._run_cond.notify_all() + + def _wait_recv(self, src: int, dst: int, post: Callable[[object], None]) -> None: + self._assert_holds_run_lock() + # Wait for the object to be available + while True: + obj = self._get_recv_object(src, dst) + if obj is not None: + post(obj) + # Note that we are not releasing the lock here, since the thread + # will continue to run and therefore must hold the lock + return + self._run_cond.wait() + + @staticmethod + def current() -> "LocalRunnerMode": + global _LOCAL_RUNNER_MODE + assert _LOCAL_RUNNER_MODE is not None, "LocalRunnerMode is not enabled" + return _LOCAL_RUNNER_MODE diff --git a/torch/distributed/_local_tensor/_c10d.py b/torch/distributed/_local_tensor/_c10d.py index 30b99931f2514..c9256543e8977 100644 --- a/torch/distributed/_local_tensor/_c10d.py +++ b/torch/distributed/_local_tensor/_c10d.py @@ -2,12 +2,15 @@ import math import operator from collections.abc import Sequence +from datetime import timedelta +from typing import Callable import torch from torch._C import ScriptObject -from torch._C._distributed_c10d import FakeWork +from torch._C._distributed_c10d import FakeWork, PythonCallbackWork from torch.distributed._mesh_layout import _MeshLayout from torch.distributed.distributed_c10d import ( + _check_op, _get_default_group, _resolve_process_group, ProcessGroup, @@ -765,10 +768,19 @@ def _local_send( # "send(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, " # "int dst, int tag) -> __torch__.torch.classes.c10d.Work"; - raise NotImplementedError( - "LocalTensor does not support MPMD operations like send. " - "Use SPMD collective operations instead." - ) + from . import LocalRunnerMode, LocalTensor + + assert len(tensors) == 1 + tensor = tensors[0] + + assert isinstance(tensor, LocalTensor), "Input tensor must be a Tensor" + src = int(tensor.__src_rank__) + + LocalRunnerMode.current()._signal_send(src, dst, tensor._local_tensors[src]) + + work = FakeWork() + work_so = Work.boxed(work) + return work_so def _local_recv_( @@ -779,11 +791,26 @@ def _local_recv_( ) -> ScriptObject: # "recv_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, " # "int src, int tag) -> __torch__.torch.classes.c10d.Work"; + from . import LocalRunnerMode, LocalTensor - raise NotImplementedError( - "LocalTensor does not support MPMD operations like recv. " - "Use SPMD collective operations instead." - ) + assert len(tensors) == 1 + tensor = tensors[0] + + assert isinstance(tensor, LocalTensor), "Input tensor must be a Tensor" + dst = int(tensor.__src_rank__) + + def _recv_and_store(timeout: timedelta) -> bool: + def _wait_and_store(obj: object) -> None: + assert isinstance(obj, torch.Tensor), "Expected to receive a Tensor" + assert isinstance(tensor, LocalTensor), "Input tensor must be a Tensor" + tensor._local_tensors[dst] = obj + + LocalRunnerMode.current()._wait_recv(src, dst, _wait_and_store) + return True + + work = PythonCallbackWork(_recv_and_store) + work_so = Work.boxed(work) + return work_so def _local_recv_any_source_( @@ -792,7 +819,60 @@ def _local_recv_any_source_( # "recv_any_source_(Tensor[] tensors, __torch__.torch.classes.c10d.ProcessGroup process_group, " # "int tag) -> __torch__.torch.classes.c10d.Work"; - raise NotImplementedError( - "LocalTensor does not support MPMD operations like recv_any_source. " - "Use SPMD collective operations instead." + return _local_recv_(tensors, process_group_so, -1, tag) + + +def _attach_rank(tensor: torch.Tensor, rank: int) -> torch.Tensor: + """ + Attaches rank as an attribute to given tensor so that the send or recv implementation + knows which rank initiates the operation (note under local tensor mode ). + """ + from torch.distributed.tensor import DTensor + + if isinstance(tensor, DTensor): + tensor = tensor._local_tensor + + tensor.__src_rank__ = rank # type: ignore[attr-defined] + return tensor + + +def local_p2p_op( + dst: torch.SymInt, + tensor: torch.Tensor, + op: Callable[[torch.Tensor, int], Work | None], +) -> Work | None | list[Work | None]: + """ + Runs a point-to-point (P2P) operation for all combinations of source and destination ranks. + """ + _check_op(op) + + from . import LocalIntNode + + assert isinstance(dst.node, LocalIntNode), ( + "Expected 'dst' to be a LocalIntNode where the value is the destination rank and key is the source rank" ) + + w = [] + for s, d in dst.node._local_ints.items(): + tensor = _attach_rank(tensor, s) + w.append(op(tensor, d)) + return w + + +def wait_all(work: Work | None | list[Work | None]) -> None: + """ + Waits for all work objects in the input to complete. + + A single Work object, None, or a list of Work objects (possibly containing None). + If None, does nothing. If a single Work, waits for it to complete. If a list, waits + for each non-None Work in the list to complete. + """ + + if work is None: + return + if isinstance(work, Work): + work = [work] + for w in work: + if w is None: + continue + w.wait() diff --git a/torch/testing/_internal/distributed/_tensor/common_dtensor.py b/torch/testing/_internal/distributed/_tensor/common_dtensor.py index 17140f40684dd..f4afca4bd1803 100644 --- a/torch/testing/_internal/distributed/_tensor/common_dtensor.py +++ b/torch/testing/_internal/distributed/_tensor/common_dtensor.py @@ -814,3 +814,7 @@ def map_local_tensor_for_rank(tensor, rank, func): @maybe_run_for_local_tensor def map_local_for_rank(rank, func): return func(rank) + + +def reduce_local_int(val, func): + return func(val.node._local_ints) From cd5d810c3aefa6bcc6a71e849e5eaa9db90f8d35 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Wed, 5 Nov 2025 02:22:29 +0000 Subject: [PATCH 044/130] Annotation should be deepcopied (#167017) The annotation should be deepcopied. Otherwise all nodes with the same `seq_nr` share the same underlying dict Pull Request resolved: https://github.com/pytorch/pytorch/pull/167017 Approved by: https://github.com/yiming0416 --- torch/_functorch/_aot_autograd/utils.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torch/_functorch/_aot_autograd/utils.py b/torch/_functorch/_aot_autograd/utils.py index 844f34bb576da..9fbb5e5fe9841 100644 --- a/torch/_functorch/_aot_autograd/utils.py +++ b/torch/_functorch/_aot_autograd/utils.py @@ -3,6 +3,7 @@ Contains various utils for AOTAutograd, including those for handling collections. """ +import copy import dataclasses import logging import operator @@ -459,7 +460,9 @@ def _copy_metadata_to_bw_nodes_in_subgraph( node.meta["fwd_nn_module_stack"] = fwd_node.meta.get("nn_module_stack") node.meta["fwd_source_fn_stack"] = fwd_node.meta.get("source_fn_stack") # TODO: better to change to a specific field of custom? - node.meta["custom"] = fwd_node.meta.get("custom") + custom = fwd_node.meta.get("custom") + if custom is not None: + node.meta["custom"] = copy.deepcopy(custom) def copy_fwd_metadata_to_bw_nodes(fx_g: torch.fx.GraphModule) -> None: From 53b03f1a2b4c8fc0e20bdff4cfbb43aad01bb978 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 5 Nov 2025 02:36:46 +0000 Subject: [PATCH 045/130] Revert "make narrow_tensor_symint DDE-free (#166379)" This reverts commit d7e2d0ad301b5d0db049bf5d2a2fc7ff9c89c58c. Reverted https://github.com/pytorch/pytorch/pull/166379 on behalf of https://github.com/malfet due to Need to revert previous PR in the stack ([comment](https://github.com/pytorch/pytorch/pull/166379#issuecomment-3488910172)) --- aten/src/ATen/native/TensorShape.cpp | 4 ++-- test/functorch/test_aotdispatch.py | 2 +- test/test_dynamic_shapes.py | 13 ------------- test/test_proxy_tensor.py | 1 + 4 files changed, 4 insertions(+), 16 deletions(-) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index b3fff5a4bb42f..6136a6aa8c520 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1764,8 +1764,8 @@ Tensor narrow_tensor_symint( start.dim() == 0 && isIntegralType(start.scalar_type(), /*includeBool=*/false), "start must be an 0-dim integral Tensor."); - c10::SymInt st = start.item().toSymInt(); - return at::narrow_symint(self, dim, std::move(st), std::move(length)); + int64_t st = start.item(); + return at::narrow_symint(self, dim, c10::SymInt(st), std::move(length)); } std:: diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 6cae42d8929da..b0dd1ff8fa75d 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -8126,7 +8126,7 @@ def fn(x): xfail("corrcoef"), xfail("quantile"), xfail("nanquantile"), - skip("narrow"), + xfail("narrow"), xfail("istft"), xfail("linalg.eig"), skip("as_strided_scatter"), diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index d3f9e415ff944..b63e0427c26c3 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -4452,19 +4452,6 @@ def test_narrow_unbacked_start_cpp_wrapper(self): """Test narrow with unbacked start with cpp_wrapper""" self.test_narrow_unbacked_start() - @torch._dynamo.config.patch(capture_scalar_outputs=True) - def test_narrow_with_tensor_start(self): - @torch.compile(backend="inductor", fullgraph=True) - def f(x, start, end): - return torch.narrow(x, 0, start, end) - - x = torch.tensor( - [False], device="cuda:0" if torch.cuda.is_available() else "cpu" - ) - start = torch.tensor(0) - res = f(x, start, 0) - self.assertEqual(res.shape, torch.Size([0])) - instantiate_parametrized_tests(TestUnbacked) diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 0487995a2d1c5..b76895a0a91f3 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1987,6 +1987,7 @@ def f(t): } only_fake_tensor_failures = { + xfail('narrow'), xfail('tensor_split'), } From a743f9eeb57255f800d4c91ba29da6e0d9c4a229 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 5 Nov 2025 02:39:55 +0000 Subject: [PATCH 046/130] Revert "Avoid DDE in narrow with unbacked start (#166361)" This reverts commit ed45c5f38df6aa419c67d139d932c2c94404223a. Reverted https://github.com/pytorch/pytorch/pull/166361 on behalf of https://github.com/malfet due to Looks like it broke test_torchfuzz subtests, see https://hud.pytorch.org/hud/pytorch/pytorch/01e6e35c7faf913c3a85c7a64d2939cfa768358a/1?per_page=50&name_filter=trunk&mergeEphemeralLF=true ([comment](https://github.com/pytorch/pytorch/pull/166361#issuecomment-3488916766)) --- aten/src/ATen/native/TensorShape.cpp | 38 +++--------------- c10/core/SymBool.cpp | 14 ------- c10/core/SymBool.h | 6 --- test/export/test_export.py | 31 +++++--------- test/test_dynamic_shapes.py | 51 ------------------------ test/test_torchfuzz_repros.py | 5 +-- torch/_inductor/codegen/wrapper.py | 3 +- torch/fx/experimental/symbolic_shapes.py | 19 +-------- torch/utils/_sympy/printers.py | 36 ----------------- 9 files changed, 19 insertions(+), 184 deletions(-) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 6136a6aa8c520..6df7761d822db 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1,6 +1,5 @@ #include #include -#include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include @@ -1711,14 +1710,11 @@ Tensor narrow_symint( "], but got ", start, ")") - // Bounds check without converting start: - // - If start < 0: need (start + cur_size) + length <= cur_size, i.e., start + - // length <= 0 - // - If start >= 0: need start + length <= cur_size - auto end = start + length; + if (start < 0) { + start = start + cur_size; + } TORCH_SYM_CHECK( - (start.sym_lt(0).sym_and((end).sym_le(0))) - .sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))), + start.sym_le(cur_size - length), "start (", start, ") + length (", @@ -1726,31 +1722,7 @@ Tensor narrow_symint( ") exceeds dimension size (", cur_size, ")."); - - if (TORCH_GUARD_OR_FALSE(start.sym_ge(0).sym_or(end.sym_ne(0)))) { - return at::slice_symint(self, dim, start, end, 1); - } else if (TORCH_GUARD_OR_FALSE(start.sym_lt(0))) { - // Avoid the complex symbolic expressions path for non-unbacked. - return at::slice_symint(self, dim, start + cur_size, end + cur_size, 1); - } else { - // Cannot statically determine the condition due to unbacked. - // This is an interesting situation; when start is negative and - // start + length == 0, slice and narrow do different things. - // i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case, we want to - // pass curr_size instead of 0. Otherwise, they would do the same thing. - // This says at runtime: if start < 0 and end == 0, then pass curr_size - // instead of 0. - - auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt(); - auto result = - at::slice_symint(self, dim, start, end + use_different * cur_size, 1); - - // Ensure slice allocated unbacked size is specialized to length. - SymInt new_size = result.sym_size(dim); - TORCH_SYM_CHECK(new_size.sym_eq(length), "") - - return result; - } + return at::slice_symint(self, dim, start, start + length, 1); } // This overload exists purely for XLA, because they wanted to pass in diff --git a/c10/core/SymBool.cpp b/c10/core/SymBool.cpp index 48c407b8b069c..d804eb9d27409 100644 --- a/c10/core/SymBool.cpp +++ b/c10/core/SymBool.cpp @@ -1,5 +1,4 @@ #include -#include #include namespace c10 { @@ -112,17 +111,4 @@ bool SymBool::has_hint() const { return toSymNodeImpl()->has_hint(); } -SymInt SymBool::toSymInt() const { - // If concrete bool, return concrete SymInt - if (auto ma = maybe_as_bool()) { - return SymInt(*ma ? 1 : 0); - } - - // Symbolic case: use sym_ite to convert bool to int (0 or 1) - auto node = toSymNodeImpl(); - auto one_node = node->wrap_int(1); - auto zero_node = node->wrap_int(0); - return SymInt(node->sym_ite(one_node, zero_node)); -} - } // namespace c10 diff --git a/c10/core/SymBool.h b/c10/core/SymBool.h index a27a28a5bf8a3..d5d509e239b1d 100644 --- a/c10/core/SymBool.h +++ b/c10/core/SymBool.h @@ -12,8 +12,6 @@ namespace c10 { -class SymInt; - class C10_API SymBool { public: /*implicit*/ SymBool(bool b) : data_(b) {} @@ -82,10 +80,6 @@ class C10_API SymBool { return toSymNodeImplUnowned()->constant_bool(); } - // Convert SymBool to SymInt (0 or 1) - // This is the C++ equivalent of Python's cast_symbool_to_symint_guardless - SymInt toSymInt() const; - bool is_heap_allocated() const { return ptr_; } diff --git a/test/export/test_export.py b/test/export/test_export.py index cdc18b1d4c564..3908f03b11e55 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -6093,19 +6093,26 @@ def forward(self, x, y, fixes): retry_export( cf_implicitsize(), (torch.tensor(2), torch.randn(10)), - fixes=[], + fixes=[ + # Could not guard on data-dependent expression u0 < 0 + "torch._check(i >= 0)", + ], ) class cf_stacklist(torch.nn.Module): def forward(self, xs, y, fixes): i = y.item() eval(fixes) + # instead of xs[i] return torch.stack(xs, 0).narrow(0, i, 1).squeeze() retry_export( cf_stacklist(), ([torch.ones(5) * i for i in range(10)], torch.tensor(2)), - fixes=[], + fixes=[ + # Could not guard on data-dependent expression u0 < 0 + "torch._check(i >= 0)", + ], ) class cf_tensorsplit(torch.nn.Module): @@ -6159,12 +6166,7 @@ def test_no_suggested_fixes_for_data_dependent_errors(self): class cf_stacklist(torch.nn.Module): def forward(self, xs, y): # y.item() is not a local, so we can't suggest a fix - if y.item() < 0: - return ( - torch.stack(xs, 0).narrow(0, y.item() + xs.size(), 1).squeeze() - ) - else: - return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze() + return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze() with self.assertRaisesRegex( error_type, @@ -6194,18 +6196,7 @@ class cf_stacklist_udd(torch.nn.Module): def forward(self, xs, y): box = Box(y.item()) # box.content is not a local, so we can't suggest a fix - if box.content < 0: - return ( - torch.stack(xs, 0) - .narrow(0, box.content + xs.size(), 1) - .squeeze() - ) - else: - return ( - torch.stack(xs, 0) - .narrow(0, box.content + xs.size(), 1) - .squeeze() - ) + return torch.stack(xs, 0).narrow(0, box.content, 1).squeeze() with self.assertRaisesRegex( error_type, diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index b63e0427c26c3..fb1d22805d50a 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -4401,57 +4401,6 @@ def func(x, y): self.assertEqual(compiled(a, b), func(a, b)) - @fresh_cache() - @torch._dynamo.config.patch("capture_scalar_outputs", True) - def test_narrow_unbacked_start(self): - def func(x, start, length): - # unbacked start - u0 = start.item() - return torch.narrow(x, 0, u0, length) - - compiled_func = torch.compile(func, fullgraph=True, backend="inductor") - - x = torch.tensor([1, 2, 3, 4, 5, 6]) - - # Test cases: (start, length) - test_cases = [ - # Negative starts - (-2, 2), # Start from second-to-last element - (-1, 1), # Start from last element - (-3, 3), # Start from third-to-last element - (-6, 2), # Start from beginning (negative) - (-4, 1), # Start from fourth-to-last element - # Positive starts - (0, 2), # Start from beginning - (1, 3), # Start from second element - (2, 2), # Start from third element - (4, 2), # Start near end - # Edge cases - (0, 6), # Full tensor - (0, 1), # Single element from start - (5, 1), # Single element from end - ] - - for start_val, length in test_cases: - with self.subTest(start=start_val, length=length): - start = torch.tensor([start_val]) - - # Test with compiled function - result_compiled = compiled_func(x, start, length) - - # Test with eager function (expected behavior) - result_eager = func(x, start, length) - - # Compare results - self.assertEqual(result_compiled, result_eager) - - @fresh_cache() - @torch._dynamo.config.patch("capture_scalar_outputs", True) - @torch._inductor.config.patch("cpp_wrapper", True) - def test_narrow_unbacked_start_cpp_wrapper(self): - """Test narrow with unbacked start with cpp_wrapper""" - self.test_narrow_unbacked_start() - instantiate_parametrized_tests(TestUnbacked) diff --git a/test/test_torchfuzz_repros.py b/test/test_torchfuzz_repros.py index 84a00430420cf..3b864aae4f477 100644 --- a/test/test_torchfuzz_repros.py +++ b/test/test_torchfuzz_repros.py @@ -16,10 +16,6 @@ from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON -# Skip all tests in this file if CUDA is not available -pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") - - class TestFuzzerCompileIssues(TestCase): """Test cases for fuzzer-discovered eager/compile divergence issues.""" @@ -261,6 +257,7 @@ def foo(arg0, arg1): out_compiled.sum().backward() print("Compile Success! ✅") + @pytest.mark.xfail(reason="Issue #163971") def test_fuzzer_issue_163971(self): torch.manual_seed(0) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index 947166cf216cd..e629d9c7bdebd 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -2063,8 +2063,7 @@ def clamp_index(x): neg = self.codegen_sizevar( sympy.Max(0, sympy.Min(x + node.size, node.size)) ) - x_cond = self.codegen_sizevar(x) - return f"{pos} if {x_cond} >= 0 else {neg}" + return f"{pos} if {x} >= 0 else {neg}" def codegen_with_step(start_var, end_var, step): if step == 1: diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 693d25aea6130..aeccdfbe000db 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -547,7 +547,6 @@ def rebind_unbacked( assert shape_env is not None for raw_u0, path in bindings.items(): u1 = pytree.key_get(result, path) - # Sometimes, things were previously unbacked bindings become constants. # There are two situations this can happen. # @@ -603,23 +602,7 @@ def rebind_unbacked( if u1.node.hint is not None: continue - # unbacked symbols bindings might be replaced to other backed or - # unbacked replacements. - # - # Example: - # u = x.item() - # torch._check(u == 5) - # - # The safest approach is to retrieve raw_u1 from u1.node._expr - # and perform the rebinding on the original unbacked symbol, - # even if it’s no longer directly referenced. - # - # In other words, we should always rebind the original symbol - # before any replacements are applied. - # u0 -> u0 == s1 - raw_u1 = u1.node._expr - - # TODO Do we still need this logic below? + raw_u1 = u1.node.expr # Simplify SymBool binding if ( isinstance(raw_u1, sympy.Piecewise) diff --git a/torch/utils/_sympy/printers.py b/torch/utils/_sympy/printers.py index 915d0e5461f1e..526443577b3f8 100644 --- a/torch/utils/_sympy/printers.py +++ b/torch/utils/_sympy/printers.py @@ -306,24 +306,6 @@ def _print_RoundDecimal(self, expr: sympy.Expr) -> str: raise TypeError("ndigits must be an instance of sympy.Integer") return f"round({self._print(number)}, {ndigits})" - def _print_Piecewise(self, expr: sympy.Expr) -> str: - # Convert Piecewise(expr_cond_pairs) to nested ternary expressions - # Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) - # becomes: e1 if c1 else (e2 if c2 else (... else eN)) - result: Optional[str] = None - for expr_i, cond_i in reversed(expr.args): - expr_str = self._print(expr_i) - if cond_i == True: # noqa: E712 - # This is the default case - result = expr_str - else: - cond_str = self._print(cond_i) - if result is None: - result = expr_str - else: - result = f"({expr_str} if {cond_str} else {result})" - return result if result else "0" - class CppPrinter(ExprPrinter): def _print_Integer(self, expr: sympy.Expr) -> str: @@ -345,24 +327,6 @@ def _print_Where(self, expr: sympy.Expr) -> str: ) return f"{c} ? {p} : {q}" - def _print_Piecewise(self, expr: sympy.Expr) -> str: - # Convert Piecewise(expr_cond_pairs) to nested ternary operators - # Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) - # becomes: c1 ? e1 : (c2 ? e2 : (... : eN)) - result: Optional[str] = None - for expr_i, cond_i in reversed(expr.args): - expr_str = self.parenthesize(expr_i, PRECEDENCE["Atom"] - 0.5) - if cond_i == True: # noqa: E712 - # This is the default case - result = expr_str - else: - cond_str = self.parenthesize(cond_i, PRECEDENCE["Atom"] - 0.5) - if result is None: - result = expr_str - else: - result = f"{cond_str} ? {expr_str} : {result}" - return f"({result})" if result else "0" - def _print_ModularIndexing(self, expr: sympy.Expr) -> str: x, div, mod = expr.args x = self.doprint(x) From 5863ba1b2e4de9ea0ae16a663465ec5d3d6f9f52 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Wed, 5 Nov 2025 03:03:41 +0000 Subject: [PATCH 047/130] [12/N] Apply ruff UP035 rule (#166929) This PR continues to apply ruff UP035 rule to test code and some remaining torch files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166929 Approved by: https://github.com/Lucaskabela --- test/distributed/tensor/test_attention.py | 3 ++- test/higher_order_ops/test_local_map.py | 3 ++- test/inductor/test_caching.py | 3 ++- test/inductor/test_fx_fusion.py | 3 ++- test/inductor/test_native_matmul.py | 2 +- test/quantization/fx/test_quantize_fx.py | 3 ++- test/test_matmul_cuda.py | 2 +- torch/_dynamo/eval_frame.py | 3 ++- torch/_dynamo/graph_bytecode_inputs.py | 3 ++- torch/_dynamo/variables/distributed.py | 3 ++- torch/_dynamo/variables/iter.py | 4 ++-- torch/_dynamo/variables/optimizer.py | 3 ++- torch/_dynamo/variables/script_object.py | 4 ++-- torch/_dynamo/variables/sdpa.py | 3 ++- torch/_dynamo/variables/streams.py | 3 ++- torch/_dynamo/variables/torch_function.py | 4 ++-- torch/_functorch/_aot_autograd/aot_autograd_result.py | 3 ++- torch/_inductor/compile_worker/timer.py | 3 ++- torch/_inductor/fx_passes/bucketing.py | 3 ++- torch/_inductor/fx_passes/ddp_fusion.py | 4 ++-- torch/_inductor/fx_passes/fsdp.py | 2 +- torch/_inductor/fx_passes/memory_estimator.py | 2 +- torch/_inductor/fx_passes/mkldnn_fusion.py | 6 +++++- torch/_inductor/fx_passes/overlap_scheduling.py | 4 ++-- torch/_inductor/fx_passes/pad_mm.py | 4 ++-- torch/_inductor/fx_passes/post_grad.py | 3 ++- torch/_inductor/fx_passes/reinplace.py | 4 ++-- torch/_inductor/fx_passes/split_cat.py | 5 ++--- torch/_inductor/kernel/custom_op.py | 3 ++- torch/_inductor/kernel/flex/flex_flash_attention.py | 3 ++- torch/_inductor/runtime/benchmarking.py | 4 ++-- torch/_inductor/runtime/caching/interfaces.py | 6 ++++-- torch/_inductor/runtime/caching/locks.py | 5 +++-- torch/distributed/elastic/multiprocessing/tail_log.py | 3 ++- torch/utils/_cxx_pytree.py | 4 ++-- torch/utils/_debug_mode.py | 3 ++- torch/utils/_pytree.py | 3 ++- 37 files changed, 76 insertions(+), 50 deletions(-) diff --git a/test/distributed/tensor/test_attention.py b/test/distributed/tensor/test_attention.py index eaf3a4042060d..6c3485f9d7025 100644 --- a/test/distributed/tensor/test_attention.py +++ b/test/distributed/tensor/test_attention.py @@ -3,7 +3,8 @@ import itertools import random import unittest -from typing import Any, Callable, ClassVar, Optional +from collections.abc import Callable +from typing import Any, ClassVar, Optional import torch import torch.distributed as dist diff --git a/test/higher_order_ops/test_local_map.py b/test/higher_order_ops/test_local_map.py index 9d2870d3b5fdd..fbb21633260e7 100644 --- a/test/higher_order_ops/test_local_map.py +++ b/test/higher_order_ops/test_local_map.py @@ -4,8 +4,9 @@ import functools import unittest +from collections.abc import Callable from contextlib import contextmanager, ExitStack -from typing import Any, Callable, Optional +from typing import Any, Optional import torch import torch._dynamo diff --git a/test/inductor/test_caching.py b/test/inductor/test_caching.py index bcb66beea700c..aa4c3a1f229f1 100644 --- a/test/inductor/test_caching.py +++ b/test/inductor/test_caching.py @@ -13,7 +13,7 @@ from shutil import rmtree from threading import Lock from time import sleep, time -from typing import Any, Generator, Sequence, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING, Union from typing_extensions import TypeVar from unittest.mock import patch @@ -37,6 +37,7 @@ if TYPE_CHECKING: + from collections.abc import Generator, Sequence from pathlib import Path diff --git a/test/inductor/test_fx_fusion.py b/test/inductor/test_fx_fusion.py index ebe98373e622a..63342502d3cd9 100644 --- a/test/inductor/test_fx_fusion.py +++ b/test/inductor/test_fx_fusion.py @@ -1,5 +1,6 @@ # Owner(s): ["module: inductor"] -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import torch from torch._inductor.fx_passes.pre_grad import ( diff --git a/test/inductor/test_native_matmul.py b/test/inductor/test_native_matmul.py index 1870a0e373be0..c37f844e41eae 100644 --- a/test/inductor/test_native_matmul.py +++ b/test/inductor/test_native_matmul.py @@ -1,7 +1,7 @@ # Owner(s): ["module: inductor"] -from typing import Callable +from collections.abc import Callable import torch from torch._dynamo.testing import rand_strided diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index cd922d94c60c3..faba2f5edc6a7 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -204,7 +204,8 @@ import operator import unittest import io -from typing import Callable, Optional +from typing import Optional +from collections.abc import Callable class BinaryOp(torch.nn.Module): def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar): diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 5e54a851812e0..1ba947befd9e7 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -5,7 +5,7 @@ import unittest from itertools import product from functools import partial -from typing import Callable +from collections.abc import Callable import torch diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index e23e049e3bbb1..222647eeae9ab 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -39,10 +39,11 @@ import unittest import warnings import weakref +from collections.abc import Sized from dataclasses import dataclass from enum import Enum from os.path import dirname, join -from typing import Any, NamedTuple, Optional, Sized, TYPE_CHECKING, Union +from typing import Any, NamedTuple, Optional, TYPE_CHECKING, Union from unittest.mock import patch import sympy diff --git a/torch/_dynamo/graph_bytecode_inputs.py b/torch/_dynamo/graph_bytecode_inputs.py index 979950cf3bd1b..16583b89201ec 100644 --- a/torch/_dynamo/graph_bytecode_inputs.py +++ b/torch/_dynamo/graph_bytecode_inputs.py @@ -1,5 +1,6 @@ import weakref -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from torch._dynamo.source import Source diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py index eb39dd8fa3e07..187055c26cd00 100644 --- a/torch/_dynamo/variables/distributed.py +++ b/torch/_dynamo/variables/distributed.py @@ -20,7 +20,8 @@ import functools import inspect -from typing import Any, Sequence, TYPE_CHECKING +from collections.abc import Sequence +from typing import Any, TYPE_CHECKING import torch from torch.fx.experimental._backward_state import BackwardState diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 5970ba0e1dda7..be765cbbc8bf9 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -14,8 +14,8 @@ """ import itertools -from collections.abc import Callable -from typing import Any, Sequence, TYPE_CHECKING, Union +from collections.abc import Callable, Sequence +from typing import Any, TYPE_CHECKING, Union from .. import graph_break_hints, polyfills, variables from ..bytecode_transformation import ( diff --git a/torch/_dynamo/variables/optimizer.py b/torch/_dynamo/variables/optimizer.py index 289cebbe8129b..c09cc2163a5f4 100644 --- a/torch/_dynamo/variables/optimizer.py +++ b/torch/_dynamo/variables/optimizer.py @@ -22,7 +22,8 @@ import logging import weakref -from typing import Any, Iterable, Optional, TYPE_CHECKING +from collections.abc import Iterable +from typing import Any, Optional, TYPE_CHECKING import torch from torch._dynamo.variables.tensor import TensorVariable diff --git a/torch/_dynamo/variables/script_object.py b/torch/_dynamo/variables/script_object.py index 85977104977fb..644c269a23a34 100644 --- a/torch/_dynamo/variables/script_object.py +++ b/torch/_dynamo/variables/script_object.py @@ -19,8 +19,8 @@ """ import functools -from collections.abc import Callable -from typing import Any, Iterable, TYPE_CHECKING, TypeVar +from collections.abc import Callable, Iterable +from typing import Any, TYPE_CHECKING, TypeVar from typing_extensions import ParamSpec import torch diff --git a/torch/_dynamo/variables/sdpa.py b/torch/_dynamo/variables/sdpa.py index 75928842cf297..629bf094dc951 100644 --- a/torch/_dynamo/variables/sdpa.py +++ b/torch/_dynamo/variables/sdpa.py @@ -1,5 +1,6 @@ +from collections.abc import Sequence from inspect import getattr_static -from typing import Any, Sequence, TYPE_CHECKING, TypeGuard +from typing import Any, TYPE_CHECKING, TypeGuard from torch._guards import Source from torch.backends.cuda import SDPAParams diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index c353181eb8029..fb5dd775bd636 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -1,5 +1,6 @@ import collections -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any, Optional import torch from torch._dynamo.variables.dicts import ConstDictVariable diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index fa8412146a427..4d0f0b4fae8ab 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -29,9 +29,9 @@ import functools import inspect import operator -from collections.abc import Sequence +from collections.abc import Generator, Iterable, Sequence from types import TracebackType -from typing import Any, Generator, Iterable, Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING import torch._C import torch.utils._pytree as pytree diff --git a/torch/_functorch/_aot_autograd/aot_autograd_result.py b/torch/_functorch/_aot_autograd/aot_autograd_result.py index ce01e37f03243..7e608933b34c3 100644 --- a/torch/_functorch/_aot_autograd/aot_autograd_result.py +++ b/torch/_functorch/_aot_autograd/aot_autograd_result.py @@ -22,9 +22,10 @@ import json import logging from abc import ABC, abstractmethod +from collections.abc import Callable from copy import copy from dataclasses import dataclass -from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar +from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar import torch from torch._dynamo.precompile_context import BackendCacheArtifact diff --git a/torch/_inductor/compile_worker/timer.py b/torch/_inductor/compile_worker/timer.py index 7cfeb4217e26b..7c495403b3a55 100644 --- a/torch/_inductor/compile_worker/timer.py +++ b/torch/_inductor/compile_worker/timer.py @@ -1,6 +1,7 @@ +from collections.abc import Callable from threading import Lock, Thread from time import monotonic, sleep -from typing import Callable, Optional, Union +from typing import Optional, Union class Timer: diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py index ab831c96c94ba..29f070564349c 100644 --- a/torch/_inductor/fx_passes/bucketing.py +++ b/torch/_inductor/fx_passes/bucketing.py @@ -2,7 +2,8 @@ import logging import operator from collections import defaultdict -from typing import Any, Callable, Literal, TypeAlias +from collections.abc import Callable +from typing import Any, Literal, TypeAlias import torch import torch.distributed as dist diff --git a/torch/_inductor/fx_passes/ddp_fusion.py b/torch/_inductor/fx_passes/ddp_fusion.py index 8a4de1a604869..44314b912786f 100644 --- a/torch/_inductor/fx_passes/ddp_fusion.py +++ b/torch/_inductor/fx_passes/ddp_fusion.py @@ -4,10 +4,10 @@ import logging import math import operator -from collections.abc import Generator +from collections.abc import Callable, Generator from dataclasses import dataclass from functools import partial -from typing import Any, Callable, cast +from typing import Any, cast import torch import torch.fx as fx diff --git a/torch/_inductor/fx_passes/fsdp.py b/torch/_inductor/fx_passes/fsdp.py index 6b0c2ad2c94a7..1e71c350ed7b6 100644 --- a/torch/_inductor/fx_passes/fsdp.py +++ b/torch/_inductor/fx_passes/fsdp.py @@ -1,5 +1,5 @@ import logging -from typing import Callable +from collections.abc import Callable import torch from torch._inductor.fx_passes.bucketing import ( diff --git a/torch/_inductor/fx_passes/memory_estimator.py b/torch/_inductor/fx_passes/memory_estimator.py index c6b7c51b948e5..e887d4bf62c8e 100644 --- a/torch/_inductor/fx_passes/memory_estimator.py +++ b/torch/_inductor/fx_passes/memory_estimator.py @@ -1,8 +1,8 @@ import itertools import logging from collections import defaultdict +from collections.abc import Callable from dataclasses import dataclass -from typing import Callable import torch import torch.fx as fx diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index 70b3a3c355dde..214d3bf02f7f4 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -2,7 +2,7 @@ import functools import operator from functools import reduce -from typing import Any, Callable +from typing import Any, TYPE_CHECKING import torch from torch._dynamo.utils import counters @@ -35,6 +35,10 @@ ) +if TYPE_CHECKING: + from collections.abc import Callable + + if torch._C._has_mkldnn: aten = torch.ops.aten mkldnn = torch.ops.mkldnn diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py index a47aa960e58c5..f383ab63dc261 100644 --- a/torch/_inductor/fx_passes/overlap_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_scheduling.py @@ -4,9 +4,9 @@ import logging import sys from collections import Counter, defaultdict -from collections.abc import Iterable +from collections.abc import Callable, Iterable from dataclasses import dataclass -from typing import Any, Callable +from typing import Any import torch import torch.fx as fx diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index 30768fda9bb72..b511403d4874c 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -2,8 +2,8 @@ import itertools import operator import typing -from collections.abc import Sequence -from typing import Any, Callable +from collections.abc import Callable, Sequence +from typing import Any import torch import torch._inductor.runtime.runtime_utils diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 7d995adec04ef..91b4e10bf7238 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -5,7 +5,8 @@ import logging import operator from collections import Counter, defaultdict -from typing import Any, Callable, TypeVar +from collections.abc import Callable +from typing import Any, TypeVar from typing_extensions import ParamSpec import torch diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py index 52222f3da8344..e42e8a1139770 100644 --- a/torch/_inductor/fx_passes/reinplace.py +++ b/torch/_inductor/fx_passes/reinplace.py @@ -3,10 +3,10 @@ import logging import operator from collections import defaultdict -from collections.abc import Sequence +from collections.abc import Callable, Sequence from contextlib import nullcontext from dataclasses import dataclass -from typing import Any, Callable, cast +from typing import Any, cast import torch import torch.fx.node diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index 92e1e6f375f44..0bad4fa7cc635 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -4,9 +4,8 @@ import operator import os from collections import defaultdict -from collections.abc import Sequence -from typing import Any, Callable -from typing_extensions import TypeAlias +from collections.abc import Callable, Sequence +from typing import Any, TypeAlias import torch from torch._dynamo.utils import counters diff --git a/torch/_inductor/kernel/custom_op.py b/torch/_inductor/kernel/custom_op.py index 303110a561b5e..d35309c01d07c 100644 --- a/torch/_inductor/kernel/custom_op.py +++ b/torch/_inductor/kernel/custom_op.py @@ -2,7 +2,8 @@ import functools import logging -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, Optional, Union import torch from torch._inductor.codegen.subgraph import SubgraphTemplate diff --git a/torch/_inductor/kernel/flex/flex_flash_attention.py b/torch/_inductor/kernel/flex/flex_flash_attention.py index c100df84d5a73..0d3721aa730a4 100644 --- a/torch/_inductor/kernel/flex/flex_flash_attention.py +++ b/torch/_inductor/kernel/flex/flex_flash_attention.py @@ -3,8 +3,9 @@ import functools import importlib +from collections.abc import Callable, Sequence from contextlib import contextmanager -from typing import Any, Callable, Optional, Sequence +from typing import Any, Optional import sympy from sympy import Expr, Integer diff --git a/torch/_inductor/runtime/benchmarking.py b/torch/_inductor/runtime/benchmarking.py index d592a8c8c00f9..d9d92e363879d 100644 --- a/torch/_inductor/runtime/benchmarking.py +++ b/torch/_inductor/runtime/benchmarking.py @@ -5,8 +5,8 @@ from functools import cached_property, wraps from itertools import chain from statistics import median -from typing import Any, Optional, Union -from typing_extensions import Concatenate, ParamSpec, Self, TypeVar +from typing import Any, Concatenate, Optional, Union +from typing_extensions import ParamSpec, Self, TypeVar import torch import torch.utils._pytree as pytree diff --git a/torch/_inductor/runtime/caching/interfaces.py b/torch/_inductor/runtime/caching/interfaces.py index 0758e11134018..03d2957493679 100644 --- a/torch/_inductor/runtime/caching/interfaces.py +++ b/torch/_inductor/runtime/caching/interfaces.py @@ -12,8 +12,8 @@ from pathlib import Path from threading import Lock from time import time -from typing import Any, Callable, TYPE_CHECKING -from typing_extensions import override, TypeAlias +from typing import Any, TYPE_CHECKING, TypeAlias +from typing_extensions import override from filelock import FileLock @@ -21,6 +21,8 @@ if TYPE_CHECKING: + from collections.abc import Callable + from .utils import P, R diff --git a/torch/_inductor/runtime/caching/locks.py b/torch/_inductor/runtime/caching/locks.py index e7e1f1adc3622..8e8cd011e2d44 100644 --- a/torch/_inductor/runtime/caching/locks.py +++ b/torch/_inductor/runtime/caching/locks.py @@ -12,8 +12,8 @@ from __future__ import annotations from contextlib import _GeneratorContextManager, contextmanager, ExitStack -from typing import Generator, TYPE_CHECKING -from typing_extensions import Protocol, TypeAlias +from typing import TYPE_CHECKING, TypeAlias +from typing_extensions import Protocol from filelock import FileLock, Timeout @@ -21,6 +21,7 @@ if TYPE_CHECKING: + from collections.abc import Generator from threading import Lock diff --git a/torch/distributed/elastic/multiprocessing/tail_log.py b/torch/distributed/elastic/multiprocessing/tail_log.py index 7ad35115cd34a..034740810dcdd 100644 --- a/torch/distributed/elastic/multiprocessing/tail_log.py +++ b/torch/distributed/elastic/multiprocessing/tail_log.py @@ -10,9 +10,10 @@ import logging import os import time +from collections.abc import Callable from concurrent.futures.thread import ThreadPoolExecutor from threading import Event -from typing import Callable, Optional, TextIO, TYPE_CHECKING, Union +from typing import Optional, TextIO, TYPE_CHECKING, Union if TYPE_CHECKING: diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 603625ed97c12..897279bd39b1e 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -15,8 +15,8 @@ import functools import types from collections.abc import Callable, Iterable, Mapping -from typing import Any, Optional, overload, TypeVar, Union -from typing_extensions import deprecated, Self, TypeAlias, TypeIs +from typing import Any, Optional, overload, TypeAlias, TypeVar, Union +from typing_extensions import deprecated, Self, TypeIs import torch.utils._pytree as python_pytree from torch.torch_version import TorchVersion as _TorchVersion diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 5e24ce086e1aa..5a6ee246abf7e 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -3,7 +3,8 @@ import functools import traceback import weakref -from typing import Any, Callable, Optional, TYPE_CHECKING +from collections.abc import Callable +from typing import Any, Optional, TYPE_CHECKING import torch from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 56704bb3f8024..147340f58d66e 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -36,10 +36,11 @@ Optional, overload, Protocol, + TypeAlias, TypeVar, Union, ) -from typing_extensions import deprecated, NamedTuple, Self, TypeAlias +from typing_extensions import deprecated, NamedTuple, Self from torch.torch_version import TorchVersion as _TorchVersion From 56fc99915b7e5c653c30460052644204caaddbcf Mon Sep 17 00:00:00 2001 From: Michael Klamkin Date: Wed, 5 Nov 2025 03:05:04 +0000 Subject: [PATCH 048/130] Fix typos in complex numbers docs (#166671) This PR fixes two small typos in the complex numbers docs: 1. "numbercial" -> "numerical" 2. "easily to switch" -> "easily switch to" Pull Request resolved: https://github.com/pytorch/pytorch/pull/166671 Approved by: https://github.com/jcaip, https://github.com/Arpitha781, https://github.com/mlazos, https://github.com/cyyever --- docs/source/complex_numbers.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/complex_numbers.md b/docs/source/complex_numbers.md index 610f9a06615a1..095401879f09b 100644 --- a/docs/source/complex_numbers.md +++ b/docs/source/complex_numbers.md @@ -45,7 +45,7 @@ supported for complex tensors. ## Transition from the old representation Users who currently worked around the lack of complex tensors with real tensors of shape {math}`(..., 2)` -can easily to switch using the complex tensors in their code using {func}`torch.view_as_complex` +can easily switch to using the complex tensors in their code using {func}`torch.view_as_complex` and {func}`torch.view_as_real`. Note that these functions don’t perform any copy and return a view of the input tensor. @@ -140,7 +140,7 @@ through the same optimizer on the {func}`torch.view_as_real` equivalent of the c `real_optim` and `complex_optim` will compute the same updates on the parameters, though there may be slight numerical discrepancies between the two optimizers, similar to numerical discrepancies between foreach vs forloop optimizers -and capturable vs default optimizers. For more details, see [numbercial accuracy](https://pytorch.org/docs/stable/notes/numerical_accuracy.html). +and capturable vs default optimizers. For more details, see [numerical accuracy](https://pytorch.org/docs/stable/notes/numerical_accuracy.html). Specifically, while you can think of our optimizer's handling of complex tensors as the same as optimizing over their `p.real` and `p.imag` pieces separately, the implementation details are not precisely that. Note that the From 08ef852a4b0f8cab0d35c30be33dfde812bfc6d8 Mon Sep 17 00:00:00 2001 From: Michael Lee Date: Wed, 5 Nov 2025 03:09:52 +0000 Subject: [PATCH 049/130] [unified v2][apple] Clean up `APPLETVOS` from caffe2 (#166953) Summary: This is not used, so delete it Test Plan: ``` $ buck targets xplat/... > /dev/null ``` Reviewed By: dtolnay Differential Revision: D86125712 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166953 Approved by: https://github.com/seemethere --- .../quantized/cpu/qnnpack/buckbuild.bzl | 20 +-- buckbuild.bzl | 4 +- third_party/xnnpack.buck.bzl | 114 +++++++++--------- 3 files changed, 69 insertions(+), 69 deletions(-) diff --git a/aten/src/ATen/native/quantized/cpu/qnnpack/buckbuild.bzl b/aten/src/ATen/native/quantized/cpu/qnnpack/buckbuild.bzl index 180442b4b09a4..fecce634ec08c 100644 --- a/aten/src/ATen/native/quantized/cpu/qnnpack/buckbuild.bzl +++ b/aten/src/ATen/native/quantized/cpu/qnnpack/buckbuild.bzl @@ -1,7 +1,7 @@ load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library") load("//tools/build_defs:fb_xplat_cxx_test.bzl", "fb_xplat_cxx_test") load("//tools/build_defs:glob_defs.bzl", "subdir_glob") -load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "APPLETVOS", "CXX", "IOS", "MACOSX") +load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "CXX", "IOS", "MACOSX") # Shared by internal and OSS BUCK def define_qnnpack(third_party, labels = []): @@ -21,7 +21,7 @@ def define_qnnpack(third_party, labels = []): ("src", "requantization/*.h"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", @@ -82,7 +82,7 @@ def define_qnnpack(third_party, labels = []): ("src", "requantization/*.h"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O3", "-ffast-math", @@ -129,7 +129,7 @@ def define_qnnpack(third_party, labels = []): ("src", "requantization/*.h"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O3", "-ffast-math", @@ -184,7 +184,7 @@ def define_qnnpack(third_party, labels = []): ("src", "requantization/*.h"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O3", "-ffast-math", @@ -236,7 +236,7 @@ def define_qnnpack(third_party, labels = []): ], ), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", ], @@ -291,7 +291,7 @@ def define_qnnpack(third_party, labels = []): ("src", "qnnpack/*.h"), ("include", "*.h"), ]), - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", @@ -398,7 +398,7 @@ def define_qnnpack(third_party, labels = []): ("src", "requantization/*.h"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O3", "-ffast-math", @@ -465,7 +465,7 @@ def define_qnnpack(third_party, labels = []): ("src", "requantization/*.h"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-DPYTORCH_QNNPACK_RUNTIME_QUANTIZATION", "-Wno-unused-command-line-argument", @@ -525,7 +525,7 @@ def define_qnnpack(third_party, labels = []): ("src", "qnnpack/*.h"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O3", "-ffast-math", diff --git a/buckbuild.bzl b/buckbuild.bzl index 4c1affd10e1bc..9f18ad4849dde 100644 --- a/buckbuild.bzl +++ b/buckbuild.bzl @@ -8,7 +8,7 @@ load("//tools/build_defs:fb_xplat_genrule.bzl", "fb_xplat_genrule") load("//tools/build_defs/windows:windows_flag_map.bzl", "windows_convert_gcc_clang_flags") load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode") load("//tools/build_defs:glob_defs.bzl", "subdir_glob") -load("//tools/build_defs:platform_defs.bzl", "APPLETVOS", "IOS", "MACOSX") +load("//tools/build_defs:platform_defs.bzl", "IOS", "MACOSX") load("//tools/build_defs:type_defs.bzl", "is_list", "is_string") load("//tools/build_defs/android:build_mode_defs.bzl", is_production_build_android = "is_production_build") load("//tools/build_defs/apple:build_mode_defs.bzl", is_production_build_ios = "is_production_build", is_profile_build_ios = "is_profile_build") @@ -1090,7 +1090,7 @@ def define_buck_targets( srcs = [ "caffe2/core/common.cc", ], - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = get_pt_compiler_flags(), labels = labels, # @lint-ignore BUCKLINT link_whole diff --git a/third_party/xnnpack.buck.bzl b/third_party/xnnpack.buck.bzl index b353d5d0d5982..217cc8db68864 100644 --- a/third_party/xnnpack.buck.bzl +++ b/third_party/xnnpack.buck.bzl @@ -1,7 +1,7 @@ load("//tools/build_defs:fb_xplat_cxx_library.bzl", "fb_xplat_cxx_library") load("//tools/build_defs:fbsource_utils.bzl", "is_arvr_mode") load("//tools/build_defs:glob_defs.bzl", "subdir_glob") -load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "APPLETVOS", "CXX", "IOS", "MACOSX", "WINDOWS") +load("//tools/build_defs:platform_defs.bzl", "ANDROID", "APPLE", "CXX", "IOS", "MACOSX", "WINDOWS") load( "@fbsource//xplat/caffe2/third_party:xnnpack_buck_shim.bzl", "LOGGING_SRCS", @@ -55,7 +55,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F exported_headers = { "xnnpack.h": "XNNPACK/include/xnnpack.h", }, - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, preprocessor_flags = XNN_COMMON_PREPROCESSOR_FLAGS, visibility = ["PUBLIC"], @@ -70,7 +70,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F srcs = SUBGRAPH_SRCS + ["XNNPACK/src/datatype.c"], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -97,7 +97,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F srcs = TABLE_SRCS, headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -121,7 +121,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F srcs = prod_srcs_for_arch_wrapper("scalar"), headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-fno-fast-math", @@ -147,7 +147,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -179,7 +179,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_sse_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -211,7 +211,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -243,7 +243,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_sse2_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -275,7 +275,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -307,7 +307,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_ssse3_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -339,7 +339,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -371,7 +371,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_sse41_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -403,7 +403,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -443,7 +443,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-mavx", @@ -476,7 +476,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -531,7 +531,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx512vnnigfni_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -568,7 +568,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -625,7 +625,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx512vnni_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -660,7 +660,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F srcs = prod_srcs_for_arch_wrapper("avxvnni") if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-mavxvnni", @@ -697,7 +697,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avxvnni_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -729,7 +729,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -770,7 +770,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_f16c_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-mf16c", @@ -804,7 +804,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -853,7 +853,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_fma3_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-mfma", @@ -894,7 +894,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -948,7 +948,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx2_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-mavx2", @@ -994,7 +994,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1039,7 +1039,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1108,7 +1108,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx512_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-mavx512f", @@ -1141,7 +1141,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1206,7 +1206,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F name = "ukernels_avx512skx_ovr_win32", headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-mavx512f", @@ -1259,7 +1259,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-fno-fast-math", @@ -1301,7 +1301,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1350,7 +1350,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -1378,7 +1378,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1430,7 +1430,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ], @@ -1460,7 +1460,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", "-Wno-error=missing-braces", # required since the SGX toolchain does not have this by default @@ -1532,7 +1532,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1582,7 +1582,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1645,7 +1645,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1690,7 +1690,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1729,7 +1729,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1774,7 +1774,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1815,7 +1815,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1860,7 +1860,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1900,7 +1900,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -1959,7 +1959,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F }) if is_arvr_mode() else [], headers = get_xnnpack_headers(), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -2004,7 +2004,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ("XNNPACK/src", "**/*.S"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -2053,7 +2053,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F ("XNNPACK/src", "**/*.S"), ]), header_namespace = "", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), compiler_flags = [ "-O2", ] + select({ @@ -2088,7 +2088,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "arm64_lib", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, fbandroid_link_whole = True, preferred_linkage = "static", @@ -2114,7 +2114,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "x86_and_x86_64_lib", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, preferred_linkage = "static", visibility = ["PUBLIC"], @@ -2138,7 +2138,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "x86_and_x86_64_lib_ovr_win32", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, preferred_linkage = "static", visibility = ["PUBLIC"], @@ -2165,7 +2165,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "arm_lib", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, preferred_linkage = "static", visibility = ["PUBLIC"], @@ -2193,7 +2193,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "armv7_lib", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, fbandroid_link_whole = True, preferred_linkage = "static", @@ -2209,7 +2209,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "prod_ukernels", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, fbandroid_link_whole = True, preferred_linkage = "static", @@ -2234,7 +2234,7 @@ def define_xnnpack(third_party, labels = [], XNNPACK_WINDOWS_AVX512F_ENABLED = F fb_xplat_cxx_library( name = "XNNPACK", - apple_sdks = (IOS, MACOSX, APPLETVOS), + apple_sdks = (IOS, MACOSX), labels = labels, deps = [ ":tables", From 066c5c57a97ca1876e58040caa7a23b4d3d00065 Mon Sep 17 00:00:00 2001 From: "Sv. Lockal" Date: Wed, 5 Nov 2025 04:13:57 +0000 Subject: [PATCH 050/130] Fix typo in gloo_hip library name (#166502) The typo was never noticed; conditions to enable it require system gloo: `-DUSE_SYSTEM_GLOO=ON -DUSE_GLOO=ON -DUSE_DISTRIBUTED=ON -DUSE_ROCM=ON`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166502 Approved by: https://github.com/jerryzh168, https://github.com/cyyever --- cmake/Modules/FindGloo.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/Modules/FindGloo.cmake b/cmake/Modules/FindGloo.cmake index 944cd4d8d2573..0bdfe275d9c06 100644 --- a/cmake/Modules/FindGloo.cmake +++ b/cmake/Modules/FindGloo.cmake @@ -26,7 +26,7 @@ find_library(Gloo_CUDA_LIBRARY # if Gloo + HIP is desired, Gloo_HIP_LIBRARY # needs to be linked to desired target find_library(Gloo_HIP_LIBRARY - NAMES gloo_hiop + NAMES gloo_hip DOC "Gloo's HIP support/code" ) From 14956eaef4a14901a95a6d0779d99db11fd7406b Mon Sep 17 00:00:00 2001 From: Jeff Daily Date: Wed, 5 Nov 2025 04:18:04 +0000 Subject: [PATCH 051/130] [ROCm][CI] revert ROCm magma commit hash to last known good (#167044) PR https://github.com/pytorch/pytorch/pull/166693 updated the magma commit hash but this has been linked to ROCm 7.1 CI failures. Go back to last known working magma version. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167044 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily --- .ci/magma-rocm/build_magma.sh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.ci/magma-rocm/build_magma.sh b/.ci/magma-rocm/build_magma.sh index 7d95fed873dc0..c7c7780227ea5 100755 --- a/.ci/magma-rocm/build_magma.sh +++ b/.ci/magma-rocm/build_magma.sh @@ -6,8 +6,8 @@ set -eou pipefail # The script expects DESIRED_CUDA and PACKAGE_NAME to be set ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)" -# post merge of https://github.com/icl-utk-edu/magma/pull/65 -MAGMA_VERSION=c0792ae825fb36872784892ea643dd6f3456bc5f +# https://github.com/icl-utk-edu/magma/pull/65 +MAGMA_VERSION=d6e4117bc88e73f06d26c6c2e14f064e8fc3d1ec # Folders for the build PACKAGE_FILES=${ROOT_DIR}/magma-rocm/package_files # metadata @@ -20,7 +20,7 @@ mkdir -p ${PACKAGE_DIR} ${PACKAGE_OUTPUT}/linux-64 ${PACKAGE_BUILD} ${PACKAGE_RE # Fetch magma sources and verify checksum pushd ${PACKAGE_DIR} -git clone https://github.com/icl-utk-edu/magma +git clone https://github.com/jeffdaily/magma pushd magma git checkout ${MAGMA_VERSION} popd From 9ffc480c5a928eaccb4ac0e1755a1c596674d884 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Tue, 4 Nov 2025 06:46:06 -0800 Subject: [PATCH 052/130] Add min/max support for barebones uint types (#166813) Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/166813 Approved by: https://github.com/Skylion007 --- aten/src/ATen/cuda/NumericLimits.cuh | 31 +++++++++++++++++++ .../ATen/native/cpu/ReduceAllOpsKernel.cpp | 13 ++++---- aten/src/ATen/native/cpu/ReduceOpsKernel.cpp | 28 +++++++++-------- .../ATen/native/cpu/TensorCompareKernel.cpp | 13 ++++---- .../ATen/native/cuda/ReduceAMinMaxKernel.cu | 13 ++++---- .../ATen/native/cuda/ReduceMaxValuesKernel.cu | 17 +++++----- .../ATen/native/cuda/ReduceMinValuesKernel.cu | 13 ++++---- .../_internal/common_methods_invocations.py | 14 ++++----- 8 files changed, 90 insertions(+), 52 deletions(-) diff --git a/aten/src/ATen/cuda/NumericLimits.cuh b/aten/src/ATen/cuda/NumericLimits.cuh index 7081e94837caa..ebbc004382380 100644 --- a/aten/src/ATen/cuda/NumericLimits.cuh +++ b/aten/src/ATen/cuda/NumericLimits.cuh @@ -55,6 +55,14 @@ struct numeric_limits { static inline __host__ __device__ int8_t upper_bound() { return INT8_MAX; } }; +template <> +struct numeric_limits { + static inline __host__ __device__ uint16_t lowest() { return 0; } + static inline __host__ __device__ uint16_t max() { return UINT16_MAX; } + static inline __host__ __device__ uint16_t lower_bound() { return 0; } + static inline __host__ __device__ uint16_t upper_bound() { return UINT16_MAX; } +}; + template <> struct numeric_limits { static inline __host__ __device__ int16_t lowest() { return INT16_MIN; } @@ -63,6 +71,14 @@ struct numeric_limits { static inline __host__ __device__ int16_t upper_bound() { return INT16_MAX; } }; +template <> +struct numeric_limits { + static inline __host__ __device__ uint32_t lowest() { return 0; } + static inline __host__ __device__ uint32_t max() { return UINT32_MAX; } + static inline __host__ __device__ uint32_t lower_bound() { return 0; } + static inline __host__ __device__ uint32_t upper_bound() { return UINT32_MAX; } +}; + template <> struct numeric_limits { static inline __host__ __device__ int32_t lowest() { return INT32_MIN; } @@ -71,6 +87,21 @@ struct numeric_limits { static inline __host__ __device__ int32_t upper_bound() { return INT32_MAX; } }; +template <> +struct numeric_limits { +#ifdef _MSC_VER + static inline __host__ __device__ uint64_t lowest() { return 0; } + static inline __host__ __device__ uint64_t max() { return _UI64_MAX; } + static inline __host__ __device__ uint64_t lower_bound() { return 0; } + static inline __host__ __device__ uint64_t upper_bound() { return _UI64_MAX; } +#else + static inline __host__ __device__ uint64_t lowest() { return 0; } + static inline __host__ __device__ uint64_t max() { return UINT64_MAX; } + static inline __host__ __device__ uint64_t lower_bound() { return 0; } + static inline __host__ __device__ uint64_t upper_bound() { return UINT64_MAX; } +#endif +}; + template <> struct numeric_limits { #ifdef _MSC_VER diff --git a/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp index c7eaa802af125..c5dbf05039eb1 100644 --- a/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceAllOpsKernel.cpp @@ -5,6 +5,7 @@ #include #include +#include #include #include #include @@ -78,12 +79,12 @@ void min_all_kernel_impl(Tensor& result, const Tensor& input) { reduce_all_impl(result, input, upper_bound(), [=](int64_t a, int64_t b) -> int64_t { return min_impl(a, b); }); } else { - AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "min_all", [&] { + AT_DISPATCH_V2(input.scalar_type(), "min_all", AT_WRAP([&] { using Vec = Vectorized>; reduce_all_impl_vec(result, input, upper_bound(), [=] (scalar_t a , scalar_t b) -> scalar_t { return min_impl(a, b); }, [=](Vec a, Vec b) -> Vec { return minimum(a, b); }); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16); } } @@ -103,12 +104,12 @@ void max_all_kernel_impl(Tensor& result, const Tensor& input) { reduce_all_impl(result, input, lower_bound(), [=](int64_t a, int64_t b) -> int64_t { return max_impl(a, b); }); } else { - AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, input.scalar_type(), "max_all", [&] { + AT_DISPATCH_V2(input.scalar_type(), "max_all", AT_WRAP([&] { using Vec = Vectorized>; reduce_all_impl_vec(result, input, lower_bound(), [=] (scalar_t a , scalar_t b) -> scalar_t { return max_impl(a, b); }, [=](Vec a, Vec b) -> Vec { return maximum(a, b); }); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kHalf, kBFloat16); } } @@ -199,7 +200,7 @@ void aminmax_allreduce_kernel( } ); } else { - AT_DISPATCH_ALL_TYPES_AND2(kBFloat16, kHalf, input.scalar_type(), "aminmax_cpu", [&] { + AT_DISPATCH_V2(input.scalar_type(), "aminmax_cpu", AT_WRAP([&] { using Vec = Vectorized>; using scalar_t_pair = std::pair; reduce_all_impl_vec_two_outputs( @@ -214,7 +215,7 @@ void aminmax_allreduce_kernel( [=](Vec a, Vec b) -> Vec { return minimum(a, b); }, [=](Vec a, Vec b) -> Vec { return maximum(a, b); } ); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf); } } diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index 2e62936501948..3bad49a32d98c 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -347,34 +348,35 @@ struct MinValuesOps: public at::native::MinOps { }; void min_values_kernel_impl(TensorIterator& iter) { - if (iter.dtype() == kLong) { - // This case is special because of Vectorized does not - // handle upper_bound(). - // See: https://github.com/pytorch/pytorch/issues/43254 - using scalar_t = int64_t; - binary_kernel_reduce( - iter, - MinValuesOps{}, - std::pair(upper_bound(), -1)); + // This case is special because of Vectorized does not + // handle upper_bound(). + // See: https://github.com/pytorch/pytorch/issues/43254 + if (iter.dtype() == kLong || iter.dtype() == kUInt64) { + AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] { + binary_kernel_reduce( + iter, + MinValuesOps{}, + std::pair(upper_bound(), -1)); + }), kLong, kUInt64); return; } - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "min_values_cpu", [&iter] { + AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] { binary_kernel_reduce_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); }, [](Vectorized a, Vectorized b) { return minimum(a, b); }, static_cast(upper_bound())); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } void max_values_kernel_impl(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "max_values_cpu", [&iter] { + AT_DISPATCH_V2(iter.dtype(), "max_values_cpu", AT_WRAP([&iter] { binary_kernel_reduce_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { return max_impl(a, b); }, [](Vectorized a, Vectorized b) { return maximum(a, b); }, lower_bound()); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } void argmax_kernel_impl(TensorIterator &iter) { diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp index c479e1610cbeb..22c85735ad6ab 100644 --- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp +++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -106,7 +107,7 @@ void min_kernel_impl( bool keepdim) { int64_t self_dim_size = ensure_nonempty_size(self, dim); - AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "min_cpu", [&] { + AT_DISPATCH_V2(self.scalar_type(), "min_cpu", AT_WRAP([&] { compare_base_kernel(result, indice, self, dim, keepdim, [&] ( scalar_t* result_data, int64_t* indice_data, const scalar_t* self_data, auto self_dim_stride) { @@ -128,7 +129,7 @@ void min_kernel_impl( *indice_data = index; } ); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool); } void max_kernel_impl( @@ -139,7 +140,7 @@ void max_kernel_impl( bool keepdim) { int64_t self_dim_size = ensure_nonempty_size(self, dim); - AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "max_cpu", [&] { + AT_DISPATCH_V2(self.scalar_type(), "max_cpu", AT_WRAP([&] { compare_base_kernel(result, indice, self, dim, keepdim, [&] ( scalar_t* result_data, int64_t* indice_data, const scalar_t* self_data, auto self_dim_stride) { @@ -161,7 +162,7 @@ void max_kernel_impl( *indice_data = index; } ); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool); } void aminmax_kernel( @@ -186,7 +187,7 @@ void aminmax_kernel( return; } - AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half, self.scalar_type(), "aminmax_cpu", [&] { + AT_DISPATCH_V2(self.scalar_type(), "aminmax_cpu", AT_WRAP([&] { compare_base_kernel(min_result, max_result, self, wrap_dim, keepdim, [&] ( scalar_t* min_result_data, scalar_t* max_result_data, const scalar_t* self_data, auto self_dim_stride) { @@ -209,7 +210,7 @@ void aminmax_kernel( *max_result_data = max_number; } ); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), ScalarType::Bool, ScalarType::BFloat16, ScalarType::Half); } void where_kernel_impl(TensorIterator &iter) { diff --git a/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu b/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu index cdd5daab2d983..0b7823863047a 100644 --- a/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceAMinMaxKernel.cu @@ -1,5 +1,6 @@ #define TORCH_ASSERT_NO_OPERATORS #include +#include #include #include #include @@ -28,22 +29,22 @@ void _min_max_values_kernel_cuda_impl(TensorIterator& iter) { } void aminmax_allreduce_launch_kernel(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND3( - kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_all_cuda", [&] { + AT_DISPATCH_V2( + iter.input_dtype(), "aminmax_all_cuda", AT_WRAP([&] { _min_max_values_kernel_cuda_impl(iter); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } void aminmax_launch_kernel(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND3( - kBFloat16, kHalf, kBool, iter.input_dtype(), "aminmax_cuda", [&]() { + AT_DISPATCH_V2( + iter.input_dtype(), "aminmax_cuda", AT_WRAP([&]() { gpu_reduce_kernel( iter, MinMaxOps{}, thrust::pair( at::numeric_limits::upper_bound(), at::numeric_limits::lower_bound())); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } } // namespace at::native diff --git a/aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu b/aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu index e8d1e88ebb3ec..bcbc4c0359943 100644 --- a/aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceMaxValuesKernel.cu @@ -1,5 +1,6 @@ #define TORCH_ASSERT_NO_OPERATORS #include +#include #include #include #include @@ -33,27 +34,27 @@ void max_values_kernel_cuda_impl(TensorIterator& iter) { } void max_values_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND3( - kBFloat16, kHalf, kBool, iter.dtype(), "max_values_cuda", [&]() { + AT_DISPATCH_V2( + iter.dtype(), "max_values_cuda", AT_WRAP([&]() { max_values_kernel_cuda_impl(iter); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } void max_launch_kernel(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND3( - kBFloat16, kHalf, kBool, iter.input_dtype(), "max_cuda", [&]() { + AT_DISPATCH_V2( + iter.input_dtype(), "max_cuda", AT_WRAP([&]() { gpu_reduce_kernel( iter, MaxOps{}, thrust::pair( at::numeric_limits::lower_bound(), 0)); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } void max_all_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "max_all_cuda", [&] { + AT_DISPATCH_V2(iter.input_dtype(), "max_all_cuda", AT_WRAP([&] { max_values_kernel_cuda_impl(iter); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } REGISTER_DISPATCH(max_values_stub, &max_values_kernel_cuda) diff --git a/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu b/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu index e01ca6c88ebc8..0006a24dbc466 100644 --- a/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu +++ b/aten/src/ATen/native/cuda/ReduceMinValuesKernel.cu @@ -12,6 +12,7 @@ #include #include +#include #include #include @@ -33,24 +34,24 @@ void min_values_kernel_cuda_impl(TensorIterator& iter) { } void min_values_kernel_cuda(TensorIterator& iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(), "min_values_cuda", [&]() { + AT_DISPATCH_V2(iter.dtype(), "min_values_cuda", AT_WRAP([&]() { min_values_kernel_cuda_impl(iter); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } void min_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_cuda", [&]() { + AT_DISPATCH_V2(iter.input_dtype(), "min_cuda", AT_WRAP([&]() { gpu_reduce_kernel( iter, MinOps{}, thrust::pair(at::numeric_limits::upper_bound(), 0)); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } void min_all_launch_kernel(TensorIterator &iter) { - AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.input_dtype(), "min_all_cuda", [&] { + AT_DISPATCH_V2(iter.input_dtype(), "min_all_cuda", AT_WRAP([&] { min_values_kernel_cuda_impl(iter); - }); + }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } REGISTER_DISPATCH(min_values_stub, &min_values_kernel_cuda) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 92f212a3c650e..0413c9bf6b6e0 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -14311,7 +14311,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): )), OpInfo('max', variant_test_name='reduction_with_dim', - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), sample_inputs_func=sample_inputs_max_min_reduction_with_dim, supports_fwgrad_bwgrad=True, @@ -14320,7 +14320,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_forward_ad=True), OpInfo('max', variant_test_name='reduction_no_dim', - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), supports_out=True, supports_forward_ad=True, @@ -14465,7 +14465,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): check_batched_forward_grad=False,), OpInfo('min', variant_test_name='reduction_with_dim', - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32), sample_inputs_func=sample_inputs_max_min_reduction_with_dim, supports_fwgrad_bwgrad=True, @@ -14474,7 +14474,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): )), OpInfo('min', variant_test_name='reduction_no_dim', - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), supports_out=True, supports_forward_ad=True, supports_fwgrad_bwgrad=True, @@ -14784,7 +14784,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_fwgrad_bwgrad=True), OpInfo('aminmax', ref=lambda x, dim=None, keepdim=False: (np.amin(x, axis=dim, keepdims=keepdim), np.amax(x, axis=dim, keepdims=keepdim)), - dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16), + dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16, torch.uint16, torch.uint32, torch.uint64), dtypesIfHpu=custom_types(torch.float32, torch.bfloat16, torch.int32, torch.int8), decorators=(onlyNativeDeviceTypes,), supports_autograd=False, @@ -21126,7 +21126,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_forward_ad=True, check_batched_forward_grad=False, supports_fwgrad_bwgrad=True, - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), ref=reference_reduction_numpy(np.amax), skips=( # FIXME: reduces all dimensions when dim=[] @@ -21141,7 +21141,7 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs): supports_forward_ad=True, check_batched_forward_grad=False, supports_fwgrad_bwgrad=True, - dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool), + dtypes=all_types_and(torch.float16, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64), ref=reference_reduction_numpy(np.amin), skips=( # FIXME: reduces all dimensions when dim=[] From c00696144dae1f02e04ce345480b55e46c7d32a8 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Tue, 4 Nov 2025 16:09:28 -0800 Subject: [PATCH 053/130] Add model code stack trace to torch.profile (#166677) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ```python python test/test_fx.py -k profiler ``` Insert `torch._C._profiler._RecordFunctionFast` to fx graph codegen. We post-process the profiler dump using `map_recorded_events_to_aten_ops_with_stack_trace` to add the stack trace to the dump'd trace. `map_recorded_events_to_aten_ops_with_stack_trace` queries `fx.traceback._FX_METADATA_REGISTRY` for node metadata. Each graph module has a hash'd fake file name (e.g. `fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py`), which is the key to the registry. One can do `fx_g.enrich_profiler_metadata()` to add debugging info. Or `fx_g.enrich_profiler_metadata(enable=False)` to remove. `aot_eager` makes calls `fx_g.enrich_profiler_metadata()` if TORCH_ENRICH_RPOFILER_STACK_TRACE is set or _dynamo.config.enrich_profiler_metadata=True. Screenshot 2025-10-31 at 4 40 52 PM Example code gen'd. ``` def forward(self, args_list): args_iter = iter(args_list) arg0_1 = next(args_iter) arg1_1 = next(args_iter) args_list.clear() _rf = torch._C._profiler._RecordFunctionFast('## fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py ##'); _rf.__enter__() repeated_subgraph0 = self.repeated_subgraph0 _rf_invoke_subgraph = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_invoke_subgraph.__enter__() invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = arg0_1 = arg1_1 = None _rf_invoke_subgraph.__exit__(None, None, None) _rf_getitem = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_getitem.__enter__() getitem = invoke_subgraph[0]; invoke_subgraph = None _rf_getitem.__exit__(None, None, None) return (getitem,) _rf.__exit__(None, None, None) def forward(self, arg0_1, arg1_1): _rf = torch._C._profiler._RecordFunctionFast('## fx_generated__ozpadpj5cxoalxeyopej33g2vvtvhxg4xsk7bhx7ldmcibtybyn.py ##'); _rf.__enter__() _rf_mul = torch._C._profiler._RecordFunctionFast('## 2 ##'); _rf_mul.__enter__() mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None _rf_mul.__exit__(None, None, None) _rf_sin = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_sin.__enter__() sin = torch.ops.aten.sin.default(mul); mul = None _rf_sin.__exit__(None, None, None) _rf_add = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_add.__enter__() add = torch.ops.aten.add.Tensor(sin, 5); sin = None _rf_add.__exit__(None, None, None) return (add,) _rf.__exit__(None, None, None) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166677 Approved by: https://github.com/ezyang --- ...t-fx_backcompat_function_signatures.expect | 2 +- test/test_fx.py | 180 ++++++++++++++++++ torch/autograd/profiler_util.py | 40 ++++ torch/fx/graph.py | 23 +++ torch/fx/graph_module.py | 16 +- torch/profiler/_utils.py | 169 +++++++++++++++- 6 files changed, 425 insertions(+), 5 deletions(-) diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect index a404e15a977ee..12f6ba2228db8 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect @@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None) torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node torch.fx.graph.Graph.print_tabular(self) -torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode +torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False, record_func: bool = False) -> torch.fx.graph.PythonCode torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule') torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None diff --git a/test/test_fx.py b/test/test_fx.py index d6f33d426aee7..c16c42805b921 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -75,6 +75,12 @@ ) from torch.testing._internal.jit_utils import JitTestCase +import json +import tempfile +from torch.profiler import profile, ProfilerActivity +from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace +from torch.autograd.profiler_util import _canonicalize_profiler_events + try: from torchvision import models as torchvision_models @@ -201,6 +207,36 @@ def side_effect_func(x: torch.Tensor): print(x) +def _enrich_profiler_traces(prof): + """ + Helper function to extract and augment profiler events with stack traces. + + Args: + prof: A torch.profiler.profile object + + Returns: + A string representing enriched events + """ + with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f: + trace_file = f.name + prof.export_chrome_trace(trace_file) + + with open(trace_file) as f: + trace_data = json.load(f) + + map_recorded_events_to_aten_ops_with_stack_trace( + trace_data + ) + + events = [] + for event in trace_data["traceEvents"]: + if "args" in event and "stack_trace" in event["args"]: + events.append(event) + + actual_traces = _canonicalize_profiler_events(events) + return actual_traces + + class TestFX(JitTestCase): def setUp(self): super().setUp() @@ -4187,6 +4223,150 @@ def fn(a, b, c, d): # recorver mutable checking flag torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_profiler_stack_trace_augmentation(self): + """ + Test that map_recorded_events_to_aten_ops_with_stack_trace correctly + augments profiler events with stack traces from FX metadata registry. + """ + + # Simple test model + class TestModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(10, 16) + self.relu = torch.nn.ReLU() + self.linear2 = torch.nn.Linear(16, 10) + + def forward(self, x): + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x + + model = TestModel().cuda() + + # Compile the model + compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True) + + # Warmup + for _ in range(3): + _ = compiled_model(torch.randn(10, 10, device="cuda")) + + # Profile with the compiled model + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) as prof: + result = compiled_model(torch.randn(10, 10, device="cuda")) + + actual_traces = _enrich_profiler_traces(prof) + + self.assertExpectedInline(actual_traces, """\ +event=aten::t node=t stack_trace=x = self.linear1(x) +event=aten::transpose node=t stack_trace=x = self.linear1(x) +event=aten::as_strided node=t stack_trace=x = self.linear1(x) +event=aten::addmm node=addmm stack_trace=x = self.linear1(x) +event=cudaLaunchKernel node=addmm stack_trace=x = self.linear1(x) +event=aten::relu node=relu stack_trace=x = self.relu(x) +event=aten::clamp_min node=relu stack_trace=x = self.relu(x) +event=cudaLaunchKernel node=relu stack_trace=x = self.relu(x) +event=aten::t node=t_1 stack_trace=x = self.linear2(x) +event=aten::transpose node=t_1 stack_trace=x = self.linear2(x) +event=aten::as_strided node=t_1 stack_trace=x = self.linear2(x) +event=aten::addmm node=addmm_1 stack_trace=x = self.linear2(x) +event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)""" + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_profiler_multiple_modules(self): + """ + Test that multiple compiled modules under the same profiler session + have their events correctly augmented with stack traces. + """ + + class ModelA(torch.nn.Module): + def forward(self, x): + return x + 1 + + class ModelB(torch.nn.Module): + def forward(self, x): + return x - 1 + + model_a = ModelA().cuda() + model_b = ModelB().cuda() + + # Compile both models + compiled_a = torch.compile(model_a, backend="aot_eager", fullgraph=True) + compiled_b = torch.compile(model_b, backend="aot_eager", fullgraph=True) + + # Warmup + for _ in range(3): + _ = compiled_a(torch.randn(10, 10, device="cuda")) + _ = compiled_b(torch.randn(1, 3, 8, 8, device="cuda")) + + # Profile both models in the same session + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) as prof: + result_a = compiled_a(torch.randn(10, 10, device="cuda")) + result_b = compiled_b(torch.randn(1, 3, 8, 8, device="cuda")) + + actual_traces = _enrich_profiler_traces(prof) + self.assertExpectedInline(actual_traces, """\ +event=aten::add node=add stack_trace=return x + 1 +event=cudaLaunchKernel node=add stack_trace=return x + 1 +event=aten::sub node=sub stack_trace=return x - 1 +event=cudaLaunchKernel node=sub stack_trace=return x - 1""" + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_profiler_nested_graph_modules(self): + """ + Test that nested graph modules (e.g., graph modules calling subgraphs) + have their events correctly augmented with stack traces. + """ + + # Model with nested structure + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.c = 5 + + @torch.compiler.nested_compile_region + def forward(self, x, y): + m = torch.mul(x, y) + s = m.sin() + a = s + self.c + return a + + model = Mod().cuda() + + # Compile the model (this may create nested graph modules) + compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True) + + # Warmup + for _ in range(3): + _ = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) + + # Profile + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) as prof: + result = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) + + actual_traces = _enrich_profiler_traces(prof) + self.assertExpectedInline(actual_traces, """\ +event=aten::mul node=mul stack_trace=m = torch.mul(x, y) +event=cudaLaunchKernel node=mul stack_trace=m = torch.mul(x, y) +event=aten::sin node=sin stack_trace=s = m.sin() +event=cudaLaunchKernel node=sin stack_trace=s = m.sin() +event=aten::add node=add stack_trace=a = s + self.c +event=cudaLaunchKernel node=add stack_trace=a = s + self.c""" + ) + def run_getitem_target(): from torch.fx._symbolic_trace import _wrapped_methods_to_patch diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index b2d6530049e61..a61aee321fcff 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -1224,3 +1224,43 @@ def override_time_unit(time_us, default_str, time_unit): f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}" ) return "".join(result) + + +# Collect all events with stack traces and format them canonically +def _canonicalize_profiler_events(events): + """ + Extract and format all events with stack traces in a canonical way + for deterministic testing. + """ + events_with_traces = [] + + for event in events: + # Extract relevant fields + event_name = event.get("name", "") + node_name = event["args"].get("node_name", "") + stack_trace = event["args"].get("stack_trace", "") + + # Get the last non-empty line of the stack trace + lines = [s.strip() for s in stack_trace.split("\n") if s.strip()] + stack_trace = lines[-1] if lines else "" + + events_with_traces.append( + { + "event_name": event_name[:20], + "node_name": node_name, + "stack_trace": stack_trace, + "start_time": event.get("ts", 0), + } + ) + + # Sort by node_name for deterministic ordering + events_with_traces.sort(key=lambda x: x["start_time"]) + + # Format as a string + lines: list[str] = [] + for evt in events_with_traces: + lines.append( + f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}" + ) + + return "\n".join(lines) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 697b2f4084ca5..fd6835d2b301b 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -443,6 +443,7 @@ def _gen_python_code( colored: bool = False, # Render each argument on its own line expanded_def: bool = False, + record_func: bool = False, ) -> PythonCode: free_vars: list[str] = [] body: list[str] = [] @@ -798,6 +799,10 @@ def _tensor_annotation(t: torch.Tensor) -> str: return raise NotImplementedError(f"node: {node.op} {node.target}") + if record_func: + body.append( + "_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()\n" + ) for i, node in enumerate(nodes): # NOTE: emit_node does not emit a string with newline. It depends # on delete_unused_values to append one @@ -807,8 +812,22 @@ def _tensor_annotation(t: torch.Tensor) -> str: # node index, which will be deleted later # after going through _body_transformer body.append(f"# COUNTER: {i}\n") + do_record = record_func and node.op in ( + "call_function", + "call_method", + "call_module", + ) + if do_record: + # The double hash ## convention is used by post-processing to find the fx markers + body.append( + f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {i} ##'); _rf_{node.name}.__enter__()\n" + ) emit_node(node) delete_unused_values(node) + if do_record: + body.append(f"_rf_{node.name}.__exit__(None, None, None)\n") + if record_func: + body.append("_rf.__exit__(None, None, None)\n") if len(body) == 0: # If the Graph has no non-placeholder nodes, no lines for the body @@ -1760,6 +1779,7 @@ def python_code( include_device: bool = False, colored: bool = False, expanded_def: bool = False, + record_func: bool = False, ) -> PythonCode: """ Turn this ``Graph`` into valid Python code. @@ -1827,6 +1847,7 @@ def override_node_repr(graph: Graph): include_device=include_device, colored=colored, expanded_def=expanded_def, + record_func=record_func, ) def _python_code( @@ -1839,6 +1860,7 @@ def _python_code( include_device: bool = False, colored: bool = False, expanded_def: bool = False, + record_func: bool = False, ) -> PythonCode: return self._codegen._gen_python_code( self.nodes, @@ -1849,6 +1871,7 @@ def _python_code( include_device=include_device, colored=colored, expanded_def=expanded_def, + record_func=record_func, ) def __str__(self) -> str: diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 297f76732584f..8360c96630d6c 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -861,14 +861,18 @@ def recompile(self) -> PythonCode: if isinstance(self._graph._codegen, _PyTreeCodeGen): self._in_spec = self._graph._codegen.pytree_info.in_spec self._out_spec = self._graph._codegen.pytree_info.out_spec - python_code = self._graph.python_code(root_module="self") + + from torch._dynamo import config as dynamo_config + + python_code = self._graph.python_code( + root_module="self", record_func=dynamo_config.enrich_profiler_metadata + ) self._code = python_code.src self._lineno_map = python_code._lineno_map self._prologue_start = python_code._prologue_start cls = type(self) co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} - from torch._dynamo import config as dynamo_config if dynamo_config.enrich_profiler_metadata: # Generate metadata and register for profiler augmentation @@ -885,7 +889,6 @@ def recompile(self) -> PythonCode: # This ensures the same code+metadata always generates the same filename hash_value = _metadata_hash(self._code, node_metadata) file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}" - filename = f"{file_stem}.py" # Only include co_filename to use it directly as the cache key @@ -905,6 +908,13 @@ def recompile(self) -> PythonCode: _register_fx_metadata(filename, metadata) + # Replace the placeholder in generated code with actual filename + # The double hash ## convention is used by post-processing to find the fx markers + self._code = self._code.replace( + "torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')", + f"torch._C._profiler._RecordFunctionFast('## {filename} ##')", + ) + cls.forward = _forward_from_src(self._code, python_code.globals, co_fields) # Determine whether this class explicitly defines a __call__ implementation diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index 2c6e06b2cb3c9..47df87ce1678d 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -4,7 +4,7 @@ import re from collections import deque from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import Any, Literal, Optional, TYPE_CHECKING from torch.autograd.profiler import profile from torch.profiler import DeviceType @@ -400,3 +400,170 @@ def _init_for_cuda_graphs() -> None: with profile(): pass + + +@dataclass +class TimelineEvent: + """Represents an event in the profiler timeline.""" + + timestamp: int + event_type: Literal["start", "end", "regular"] + marker_type: Optional[Literal["filename", "node"]] + identifier: Optional[str | int] + event: dict[str, Any] + + +@dataclass +class ContextStackEntry: + """Represents a context (filename or node) in the stack.""" + + context_type: Literal["filename", "node"] + identifier: str | int + metadata: Optional[dict] + tid: Optional[int] = None # Thread ID associated with this context + + +def map_recorded_events_to_aten_ops_with_stack_trace(traced_data): + """ + Maps recorded profiler events to their corresponding fx nodes and adds stack traces. + + Builds a timeline of all events (regular ops and FX markers for filenames/nodes), + sorts by timestamp, then processes chronologically while maintaining a context stack of active + filename/node scopes. Regular events are augmented with stack traces and node names from the + innermost active context. Runtime is O(n log n) for n events. + + Args: + traced_data: Json of profiler events from Chrome trace + + Returns: + Dict mapping recorded event names to their aten operations with added stack traces + """ + from torch.fx.traceback import _FX_METADATA_REGISTRY + + trace_events = traced_data.get("traceEvents", []) + + # Create event timeline + event_timeline: list[TimelineEvent] = [] + + def is_fx_marker_event(event): + return ( + event.get("cat") == "cpu_op" + and event.get("name", "").startswith("## ") + and event.get("name", "").endswith(" ##") + ) + + def append_fx_marker_event(event_type, identifier, event): + start_ts = event["ts"] + end_ts = start_ts + event["dur"] + event_timeline.append( + TimelineEvent(start_ts, "start", event_type, identifier, event) + ) + event_timeline.append( + TimelineEvent(end_ts, "end", event_type, identifier, event) + ) + + for event in trace_events: + if "ts" not in event or "dur" not in event: + continue + + if is_fx_marker_event(event): + content = event["name"][3:-3] + + if content.endswith(".py"): + append_fx_marker_event("filename", content, event) + else: + try: + node_index = int(content) + except ValueError: + pass + append_fx_marker_event("node", node_index, event) # type: ignore[possibly-undefined] + + else: + # Regular event that needs augmentation + start_ts = event["ts"] + event_timeline.append(TimelineEvent(start_ts, "regular", None, None, event)) + + # Sort by timestamp + event_timeline.sort(key=lambda x: x.timestamp) + + # Process events in chronological order with a stack + context_stack: list[ContextStackEntry] = [] + + # Invariant: all start event has a corresponding end event + for timeline_event in event_timeline: + match timeline_event.event_type: + case "start": + assert timeline_event.identifier is not None + + if timeline_event.marker_type == "filename": + assert isinstance(timeline_event.identifier, str) + # Push filename context - query metadata registry on-demand + metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier) + tid = timeline_event.event.get("tid") + context_stack.append( + ContextStackEntry( + "filename", timeline_event.identifier, metadata, tid + ) + ) + elif timeline_event.marker_type == "node": + # Find the current filename from stack + current_file_metadata = None + tid = timeline_event.event.get("tid") + for ctx_entry in reversed(context_stack): + if ( + ctx_entry.context_type == "filename" + and ctx_entry.tid == tid + ): + current_file_metadata = ctx_entry.metadata + break + + if current_file_metadata: + node_metadata = current_file_metadata.get("node_metadata", {}) + if timeline_event.identifier in node_metadata: + node_meta: Optional[dict] = node_metadata[ + timeline_event.identifier + ] + context_stack.append( + ContextStackEntry( + "node", timeline_event.identifier, node_meta, tid + ) + ) + + case "end": + # Pop from stack - search backwards to find matching context + for i in range(len(context_stack) - 1, -1, -1): + ctx_entry = context_stack[i] + if ( + timeline_event.marker_type == ctx_entry.context_type + and timeline_event.identifier == ctx_entry.identifier + ): + context_stack.pop(i) + break + + case "regular": + # Apply metadata from current context stack + # Find the most specific context (node takes precedence over filename) + # Only augment events with the same tid as the file/node event matched + current_stack_trace = None + current_node_name = None + event_tid = timeline_event.event.get("tid") + + for ctx_entry in reversed(context_stack): + # Only apply metadata from contexts with matching tid + if ctx_entry.tid == event_tid: + if ctx_entry.context_type == "node" and ctx_entry.metadata: + current_stack_trace = ctx_entry.metadata.get( + "stack_trace", "No model stack trace available" + ) + current_node_name = ctx_entry.metadata.get("name", "") + # Do we want to only attach the stack trace of the lowest node or stack trace of all nodes + # if nodes are nested, e.g. in nested graph modules + break + + # Augment the event + if current_stack_trace or current_node_name: + args = timeline_event.event.setdefault("args", {}) + if current_stack_trace: + args["stack_trace"] = current_stack_trace + if current_node_name: + args["node_name"] = current_node_name From 431dfe8692f3f927c19c739884054d7f1d42a33d Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Tue, 4 Nov 2025 12:18:14 +0800 Subject: [PATCH 054/130] [dynamo] extend `collections.defaultdict` support with `*args`, `**kwargs` and custom `default_factory` (#166793) Fixes #166238 Extend `collections.defaultdict` to accept `*args` and `**kwargs` in the constructor. And also support custom `default_factory`, such as `dd.default_factory` (a `GetAttrVariable`). Pull Request resolved: https://github.com/pytorch/pytorch/pull/166793 Approved by: https://github.com/guilhermeleobas --- test/dynamo/test_dicts.py | 69 +++++++++++++++++-- ...tDefaultDict.test_keyerror_without_factory | 0 ...13-test_dict-DictTest.test_dict_copy_order | 0 ...redDictSubclassTests.test_sorted_iterators | 0 ...thonOrderedDictTests.test_sorted_iterators | 0 ...313-test_set-TestGraphs.test_cuboctahedron | 0 torch/_dynamo/polyfills/__init__.py | 62 ++++++++++------- torch/_dynamo/variables/builtin.py | 8 ++- torch/_dynamo/variables/dicts.py | 17 ++++- torch/_dynamo/variables/user_defined.py | 21 +++--- 10 files changed, 135 insertions(+), 42 deletions(-) delete mode 100644 test/dynamo_expected_failures/CPython313-test_defaultdict-TestDefaultDict.test_keyerror_without_factory delete mode 100644 test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dict_copy_order delete mode 100644 test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_sorted_iterators delete mode 100644 test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_sorted_iterators delete mode 100644 test/dynamo_expected_failures/CPython313-test_set-TestGraphs.test_cuboctahedron diff --git a/test/dynamo/test_dicts.py b/test/dynamo/test_dicts.py index 966acd1d81394..4a4d2ff87718f 100644 --- a/test/dynamo/test_dicts.py +++ b/test/dynamo/test_dicts.py @@ -36,6 +36,15 @@ class DummyUserDict(UserDict): pass +class FakeMapping: + def __init__(self, value: Any) -> None: + self._value = value + self.keys = lambda: ["a", "b", "c"] # not required to be a method + + def __getitem__(self, key: str) -> Any: + return self._value + + class DictTests(torch._dynamo.test_case.TestCase): def test_dict_subclass_instantiation(self): def fn(x): @@ -666,6 +675,18 @@ def fn(): for k1, m2 in zip(modules, module_dict.children()): self.assertTrue(modules[k1] is m2) + # FIXME: see comment in torch/_dynamo/polyfills/__init__.py:mutable_mapping_update + @unittest.expectedFailure + def test_dict_construct_from_mapping_like(self): + def fn(x): + fm = FakeMapping(x) + d = dict(fm, x=x) + return d + + x = torch.randn(4) + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + self.assertEqual(fn(x), opt_fn(x)) + def test_dict_subclass_initialization_in_graph(self): for super_class in ( OrderedDict, @@ -1087,12 +1108,52 @@ def f(x): self.assertEqual(ref, res) - @unittest.expectedFailure + def test_newly_constructed_default_dict_no_default_factory(self): + def f1(x): + d = defaultdict() + try: + d[1] += 42 + except KeyError: + d[1] = 1 + return x + 1, d + + x = torch.ones(2) + ref = f1(x) + res = torch.compile(f1, backend="eager", fullgraph=True)(x) + + self.assertEqual(ref, res) + + def f2(x): + d = defaultdict(None) + try: + d[1] += 42 + except KeyError: + d[1] = 1 + return x + 1, d + + ref = f2(x) + res = torch.compile(f2, backend="eager", fullgraph=True)(x) + self.assertEqual(ref, res) + + def f3(x): + d = defaultdict(None, {1: 10}) + d[1] += 42 + try: + d[2] += 24 + except KeyError: + d[2] = 1 + return x + 1, d + + ref = f3(x) + res = torch.compile(f3, backend="eager", fullgraph=True)(x) + self.assertEqual(ref, res) + def test_newly_constructed_default_dict_with_dict(self): def f(x): - d = defaultdict(dict, {2: {"a": 1}}) - d[0] = {"b": 2} - return x + 1, d + d = dict([("a", 1), ("b", 2)], c=3) # noqa: C406 + dd = defaultdict(list, d, d=4, e=5) + dd["x"].append(42) + return x + 1, d, dd x = torch.ones(2) ref = f(x) diff --git a/test/dynamo_expected_failures/CPython313-test_defaultdict-TestDefaultDict.test_keyerror_without_factory b/test/dynamo_expected_failures/CPython313-test_defaultdict-TestDefaultDict.test_keyerror_without_factory deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dict_copy_order b/test/dynamo_expected_failures/CPython313-test_dict-DictTest.test_dict_copy_order deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_sorted_iterators b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictSubclassTests.test_sorted_iterators deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_sorted_iterators b/test/dynamo_expected_failures/CPython313-test_ordered_dict-CPythonOrderedDictTests.test_sorted_iterators deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/test/dynamo_expected_failures/CPython313-test_set-TestGraphs.test_cuboctahedron b/test/dynamo_expected_failures/CPython313-test_set-TestGraphs.test_cuboctahedron deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index a8dcf3e00c166..59f6f76317e6d 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -10,7 +10,7 @@ import types from collections import OrderedDict -from collections.abc import Callable, Hashable, Iterable, MutableMapping, Sequence +from collections.abc import Callable, Hashable, Iterable, Mapping, Sequence from itertools import repeat as _repeat from operator import eq, ne from typing import Any, TYPE_CHECKING @@ -276,7 +276,7 @@ def getattr_and_trace(*args, **kwargs): return fn(*args[2:], **kwargs) -def mapping_get(obj, key, value=None): +def mapping_get(obj, key, value=None, /): try: return obj.__getitem__(key) except KeyError: @@ -293,31 +293,45 @@ def instantiate_user_defined_class_object(cls, /, *args, **kwargs): return obj -# Used with something like dict(obj) -def construct_dict(cls, /, *args, **kwargs): - dst = cls.__new__(cls) - - if args: - src = args[0] - - if not isinstance(src, Iterable): - raise TypeError(f"{type(src)} object is not iterable") - - # Ensure that the overridden __iter__ method is invoked - if isinstance(src, (dict, MutableMapping, types.MappingProxyType)): - for key in src: - # This will inline the __getitem__ of the src object - dst[key] = src[key] - else: - # likely a sequence like tuple of pairs - for key, value in src: - dst[key] = value +def mutable_mapping_update(self, data=(), /, **kwargs): + if isinstance(data, Mapping): + # Merge standard mapping with PyMapping_Items + for key, value in data.items(): + self[key] = value + # FIXME: Enabling the `elif`-branch below needs too many `VariableClass.call_obj_hasattr` changes. + # >>> class Foo: + # ... def __init__(self): + # ... self.keys = lambda: ['a', 'b', 'c'] # not required to be a method + # ... + # ... def __getitem__(self, key): + # ... return 0 + # ... + # >>> dict(Foo()) + # {'a': 0, 'b': 0, 'c': 0} + # + # > This is a rare case, so we comment it out for now. + # + # elif hasattr(data, "keys"): + # # Merge mapping-like object with PyMapping_Keys + PyObject_GetItem + # for key in data.keys(): + # self[key] = data[key] + else: + if not isinstance(data, Iterable): + raise TypeError(f"{type(data).__name__!r} object is not iterable") + # Likely a sequence of pairs + for key, value in data: + self[key] = value if kwargs: - for key in kwargs: - dst[key] = kwargs[key] + for key, value in kwargs.items(): + self[key] = value - return dst + +# Used with something like dict(obj) +def construct_dict(cls, data=(), /, **kwargs): + self = cls.__new__(cls) + mutable_mapping_update(self, data, **kwargs) + return self def foreach_map_fn(*args): diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 1817a5f3c7ed1..0f198377605ec 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -2061,7 +2061,11 @@ def call_dir( return None def call_dict( - self, tx: "InstructionTranslator", *args: Any, **kwargs: Any + self, + tx: "InstructionTranslator", + /, + *args: VariableTracker, + **kwargs: VariableTracker, ) -> VariableTracker: return BuiltinVariable.call_custom_dict(tx, dict, *args, **kwargs) @@ -2069,6 +2073,7 @@ def call_dict( def call_custom_dict( tx: "InstructionTranslator", user_cls: type, + /, *args: VariableTracker, **kwargs: VariableTracker, ) -> VariableTracker: @@ -2093,6 +2098,7 @@ def call_custom_dict( def call_custom_dict_fromkeys( tx: "InstructionTranslator", user_cls: type, + /, *args: VariableTracker, **kwargs: VariableTracker, ) -> VariableTracker: diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index 4f1f84a55b0b0..f70ba99c0c93d 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -911,6 +911,8 @@ class DefaultDictVariable(ConstDictVariable): def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None: super().__init__(items, user_cls, **kwargs) assert user_cls is collections.defaultdict + if default_factory is None: + default_factory = ConstantVariable.create(None) self.default_factory = default_factory def is_python_constant(self): @@ -930,7 +932,13 @@ def is_supported_arg(arg): if isinstance(arg, variables.BuiltinVariable): return arg.fn in (list, tuple, dict, set) else: - return isinstance(arg, variables.functions.BaseUserFunctionVariable) + return isinstance( + arg, + ( + variables.functions.BaseUserFunctionVariable, + variables.functions.PolyfilledFunctionVariable, + ), + ) def call_method( self, @@ -946,8 +954,11 @@ def call_method( if args[0] in self: return self.getitem_const(tx, args[0]) else: - if self.default_factory is None: - raise KeyError(f"{args[0]}") + if ( + istype(self.default_factory, ConstantVariable) + and self.default_factory.value is None + ): + raise_observed_exception(KeyError, tx, args=[args[0]]) else: default_var = self.default_factory.call_function(tx, [], {}) super().call_method( diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 085b5e0c648c5..9dd154dacbb9e 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -419,9 +419,7 @@ def call_method( self.value in {collections.OrderedDict, collections.defaultdict} and name == "fromkeys" ): - from .builtin import BuiltinVariable - - return BuiltinVariable.call_custom_dict_fromkeys( + return variables.BuiltinVariable.call_custom_dict_fromkeys( tx, self.value, *args, **kwargs ) elif self.value is collections.OrderedDict and name == "move_to_end": @@ -501,15 +499,18 @@ def call_function( [self, *args], kwargs, ) - elif ( - self.value is collections.defaultdict - and len(args) <= 1 - and DefaultDictVariable.is_supported_arg(args[0]) - ): + elif self.value is collections.defaultdict: + if len(args) == 0: + default_factory = variables.ConstantVariable.create(None) + else: + default_factory, *args = args + dict_vt = variables.BuiltinVariable.call_custom_dict( + tx, dict, *args, **kwargs + ) return DefaultDictVariable( - {}, + dict_vt.items, collections.defaultdict, - args[0], + default_factory, mutation_type=ValueMutationNew(), ) elif is_typeddict(self.value): From 59a6c83dfe9d88d44d0e5440aa61d2e883a88122 Mon Sep 17 00:00:00 2001 From: Scott Lee Date: Wed, 5 Nov 2025 06:39:26 +0000 Subject: [PATCH 055/130] [fx] Add strict argument validation to Interpreter.boxed_run (#166784) # Summary This PR fixes an issue where `torch.fx.Interpreter.boxed_run` would silently ignore extra input arguments instead of validating the argument count. Previously, `boxed_run` would only consume as many inputs as there were placeholder nodes and then clear the entire `args_list`, hiding potential bugs. This change introduces a strict check to ensure `len(args_list)` matches the number of placeholder nodes, raising a `RuntimeError` on a mismatch. Fixes #166583. # Changes * Validate `len(args_list)` against the number of placeholder nodes at the beginning of `boxed_run`. * Raise a `RuntimeError` with a clear message ("extra arguments" or "missing arguments") if the counts do not match. * Move `args_list.clear()` to only execute after successful validation and environment setup. If an error is raised, `args_list` is preserved for debugging. # Testing * Added `test_interpreter_boxed_run_argument_validation` to `test/test_fx.py`. * This test covers three scenarios: 1. Correct number of arguments (succeeds, `args_list` is cleared). 2. Extra arguments (raises `RuntimeError`, `args_list` is preserved). 3. Missing arguments (raises `RuntimeError`, `args_list` is preserved). # User-facing impact / BC notes This is a bug fix. Code that was incorrectly passing the wrong number of arguments to `boxed_run` will now fail fast with a `RuntimeError` instead of executing silently with unintended inputs. Correctly written code is unaffected. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166784 Approved by: https://github.com/ezyang, https://github.com/xmfan --- test/test_fx.py | 25 +++++++++++++++++++++++++ torch/fx/interpreter.py | 22 +++++++++++++++++----- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/test/test_fx.py b/test/test_fx.py index c16c42805b921..e12189dfea461 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -2070,6 +2070,31 @@ def forward(self, x): self.assertEqual(interpreter.run(input), gm(input)) self.assertEqual(interpreter.run(input), m(input)) + def test_interpreter_boxed_run_argument_validation(self): + class AddModule(torch.nn.Module): + def forward(self, lhs, rhs): + return lhs + rhs + + gm = torch.fx.symbolic_trace(AddModule()) + interpreter = Interpreter(gm) + + lhs = torch.tensor(1.0) + rhs = torch.tensor(2.0) + good_args = [lhs.clone(), rhs.clone()] + result = interpreter.boxed_run(good_args) + torch.testing.assert_close(result, lhs + rhs) + self.assertEqual(good_args, []) + + extra_args = [lhs.clone(), rhs.clone(), torch.tensor(3.0)] + with self.assertRaisesRegex(RuntimeError, "extra arguments"): + interpreter.boxed_run(extra_args) + self.assertEqual(len(extra_args), 3) + + missing_args = [lhs.clone()] + with self.assertRaisesRegex(RuntimeError, "missing arguments"): + interpreter.boxed_run(missing_args) + self.assertEqual(len(missing_args), 1) + def test_interpreter_other_graph(self): class MyModule(torch.nn.Module): def __init__(self) -> None: diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py index a3114a14a657e..5ad1424c4e489 100644 --- a/torch/fx/interpreter.py +++ b/torch/fx/interpreter.py @@ -220,11 +220,23 @@ def boxed_run(self, args_list): calling convention, where you pass a list of arguments, which will be cleared by the interpreter. This ensures that input tensors are promptly deallocated. """ - args_iter = iter(args_list) - env = {} - for n in self.graph.nodes: - if n.op == "placeholder": - env[n] = next(args_iter) + # Collect placeholder nodes first + placeholder_nodes = [n for n in self.graph.nodes if n.op == "placeholder"] + + # Check argument count + if len(args_list) != len(placeholder_nodes): + detail = ( + "extra arguments" + if len(args_list) > len(placeholder_nodes) + else "missing arguments" + ) + raise RuntimeError( + f"Interpreter.boxed_run expected {len(placeholder_nodes)} arguments for placeholders " + f"but received {len(args_list)} ({detail})" + ) + + # Assign arguments to placeholders + env = dict(zip(placeholder_nodes, args_list)) args_list.clear() return self.run(initial_env=env) From 658c5f879c37142b1df51c7eb6c5a5bb06318597 Mon Sep 17 00:00:00 2001 From: Nikhil Patel Date: Wed, 5 Nov 2025 06:51:30 +0000 Subject: [PATCH 056/130] [Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel (#167003) Summary: This is a reland of https://github.com/pytorch/pytorch/pull/165036?fbclid=IwY2xjawN3RL1leHRuA2FlbQIxMQBicmlkETExOEcxcnVhNVA1TzRSVmhiAR63GOEpJbZA-JhQ0CSj9ji8H_RHBUhDwYNDtxjOYfDol56OGqmC4r7jPP96Fw_aem_bWvtMfVifLQrnpv1YB_fJA, which previously contained a minor bug in the logic that determined whether the kernel should be enabled. As a result, it was incorrectly activated on non-Blackwell GPUs. Test Plan: Inductor test (fbcode): `INDUCTOR_TEST_DISABLE_FRESH_CACHE=1 TORCHINDUCTOR_CACHE_DIR=~/cutetest buck2 run mode/opt //caffe2/test/inductor:cutedsl_grouped_mm -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8 -m "ovr_config//third-party/pypi/nvidia-cutlass-dsl/constraints:4.2.1"` Tritonbench (fbcode): `clear; CUDA_VISIBLE_DEVICES=7 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 buck2 run mode/opt //pytorch/tritonbench:run -c fbcode.nvcc_arch=b200a -c fbcode.enable_gpu_sections=true -c fbcode.platform010_cuda_version=12.8 -m "ovr_config//third-party/pypi/nvidia-cutlass-dsl/constraints:4.2.1" -- --op grouped_gemm --only aten_grouped_mm,preprocessed_pt2_cute_grouped_mm --precision bf16 --num-inputs 1 --metrics tflops,accuracy` Tritonbench(oss): `clear; CUDA_VISIBLE_DEVICES=2 TRITON_PRINT_AUTOTUNING=1 TRITON_ALWAYS_COMPILE=1 TORCH_LOGS=+inductor TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 python run.py --op grouped_gemm --only aten_grouped_mm,preprocessed_pt2_triton_grouped_mm --precision bf16 --num-inputs 1 --metrics tflops,accuracy` Unit Tests(oss): `clear; python test/inductor/test_cutedsl_grouped_mm.py` Differential Revision: D86231180 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167003 Approved by: https://github.com/jananisriram --- .ci/pytorch/test.sh | 2 +- .gitignore | 1 + setup.py | 34 ++ test/inductor/test_cutedsl_grouped_mm.py | 154 ++++++++ torch/_inductor/config.py | 4 + torch/_inductor/kernel/mm_common.py | 7 + torch/_inductor/kernel/mm_grouped.py | 93 +++-- .../templates/cutedsl_mm_grouped.py.jinja | 333 ++++++++++++++++++ .../_inductor/template_heuristics/cutedsl.py | 141 ++++++++ torch/_inductor/utils.py | 78 ++++ 10 files changed, 814 insertions(+), 33 deletions(-) create mode 100644 test/inductor/test_cutedsl_grouped_mm.py create mode 100644 torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja create mode 100644 torch/_inductor/template_heuristics/cutedsl.py diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 26996b5a32d56..9ae2578758939 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -337,7 +337,7 @@ test_python() { test_python_smoke() { # Smoke tests for H100/B200 - time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running + time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running assert_git_not_dirty } diff --git a/.gitignore b/.gitignore index d1b3b17445dac..3b4323051073a 100644 --- a/.gitignore +++ b/.gitignore @@ -127,6 +127,7 @@ torch/test/ torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h torch/version.py +torch/_inductor/kernel/vendored_templates/* minifier_launcher.py aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d* aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d* diff --git a/setup.py b/setup.py index 31e78d0245d93..dd8a52cbeb7c7 100644 --- a/setup.py +++ b/setup.py @@ -630,6 +630,37 @@ def mirror_files_into_torchgen() -> None: raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`") +def mirror_inductor_external_kernels() -> None: + """ + Copy external kernels into Inductor so they are importable. + """ + paths = [ + ( + CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py", + CWD + / "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py", + ), + ] + for new_path, orig_path in paths: + # Create the dirs involved in new_path if they don't exist + if not new_path.exists(): + new_path.parent.mkdir(parents=True, exist_ok=True) + + # Copy the files from the orig location to the new location + if orig_path.is_file(): + shutil.copyfile(orig_path, new_path) + continue + if orig_path.is_dir(): + if new_path.exists(): + # copytree fails if the tree exists already, so remove it. + shutil.rmtree(new_path) + shutil.copytree(orig_path, new_path) + continue + raise RuntimeError( + "Check the file paths in `mirror_inductor_external_kernels()`" + ) + + # ATTENTION: THIS IS AI SLOP def extract_variant_from_version(version: str) -> str: """Extract variant from version string, defaulting to 'cpu'.""" @@ -1616,6 +1647,8 @@ def main() -> None: if RUN_BUILD_DEPS: build_deps() + mirror_inductor_external_kernels() + ( ext_modules, cmdclass, @@ -1649,6 +1682,7 @@ def main() -> None: "_inductor/codegen/aoti_runtime/*.cpp", "_inductor/script.ld", "_inductor/kernel/flex/templates/*.jinja", + "_inductor/kernel/templates/*.jinja", "_export/serde/*.yaml", "_export/serde/*.thrift", "share/cmake/ATen/*.cmake", diff --git a/test/inductor/test_cutedsl_grouped_mm.py b/test/inductor/test_cutedsl_grouped_mm.py new file mode 100644 index 0000000000000..c26def3a54099 --- /dev/null +++ b/test/inductor/test_cutedsl_grouped_mm.py @@ -0,0 +1,154 @@ +# Owner(s): ["module: inductor"] + + +import unittest + +import torch +from torch import Tensor +from torch._inductor import config +from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch +from torch._inductor.test_case import run_tests, TestCase as InductorTestCase +from torch._inductor.utils import ensure_cute_available +from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + parametrize, +) + + +@unittest.skipIf( + not (ensure_cute_available() and is_datacenter_blackwell_arch()), + "CuTeDSL library or Blackwell device not available", +) +@instantiate_parametrized_tests +class TestCuTeDSLGroupedGemm(InductorTestCase): + def _get_inputs( + self, + group_size: int, + M_hint: int, + K: int, + N: int, + device: str, + dtype: torch.dtype, + alignment: int = 16, + ) -> tuple[Tensor, Tensor, Tensor]: + # --- Random, tile-aligned M sizes --- + M_sizes = ( + torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int) + * alignment + ) + + M_total = torch.sum(M_sizes).item() + + # --- Construct input tensors --- + A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1 + B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01 + + # --- Build offsets (no leading zero, strictly increasing) --- + offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device) + + return (A, B, offsets) + + @parametrize("group_size", (2, 8)) + @parametrize("M_hint", (256, 1024)) + @parametrize("K", (64, 128)) + @parametrize("N", (128, 256)) + def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int): + device = "cuda" + dtype = torch.bfloat16 + + A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype) + + def grouped_gemm_fn(A_packed, B_batched, offs): + return torch._grouped_mm(A_packed, B_batched, offs=offs) + + # Eager execution + c_eager = grouped_gemm_fn(A, B, offsets) + + # Test with Cute backend + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "CUTEDSL", + "test_configs.autotune_choice_name_regex": "cutedsl", + "autotune_fallback_to_aten": False, + } + ): + grouped_gemm_compiled = torch.compile( + grouped_gemm_fn, backend="inductor", dynamic=False + ) + c_compiled = grouped_gemm_compiled(A, B, offsets) + + self.assertEqual(c_eager.dtype, dtype) + self.assertEqual(c_compiled.dtype, dtype) + torch.testing.assert_close(c_eager, c_compiled) + + @parametrize("layout_A", ("contiguous", "offset", "padded", "view")) + @parametrize("layout_B", ("contiguous", "broadcasted")) + def test_grouped_gemm_assorted_layouts( + self, + layout_A: str, + layout_B: str, + ): + device = "cuda" + dtype = torch.bfloat16 + + G, K, N = 8, 64, 128 + M_sizes = [128] * G + sum_M = sum(M_sizes) + offsets = torch.tensor( + [sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device + ) + + A_base = torch.randn(sum_M, K, device=device, dtype=dtype) + A = A_base + + if layout_A == "offset": + # allocate bigger buffer than needed, use nonzero storage offset + storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype) + offset = 128 # skip first 128 elements + A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1)) + elif layout_A == "padded": + # simulate row pitch > K (row_stride = K + pad) + row_pitch = K + 8 + storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype) + A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1)) + elif layout_A == "view": + A_storage = torch.randn(sum_M * K, device=device, dtype=dtype) + A = A_storage.view(sum_M, K) + assert A._base is not None + assert A.shape == (sum_M, K) + + B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01 + + if layout_B == "broadcasted": + # Broadcast B across groups (zero stride along G) + B = B[0].expand(G, K, N) + assert B.stride(0) == 0 + + def grouped_gemm_fn(A_packed, B_batched, offs): + return torch._grouped_mm(A_packed, B_batched, offs=offs) + + # --- eager --- + c_eager = grouped_gemm_fn(A, B, offsets) + + # --- compiled (CUTE backend) --- + with config.patch( + { + "max_autotune": True, + "max_autotune_gemm_backends": "CUTEDSL", + "test_configs.autotune_choice_name_regex": "cutedsl", + "autotune_fallback_to_aten": False, + } + ): + grouped_gemm_compiled = torch.compile( + grouped_gemm_fn, backend="inductor", dynamic=False + ) + c_compiled = grouped_gemm_compiled(A, B, offsets) + + self.assertEqual(c_eager.dtype, dtype) + self.assertEqual(c_compiled.dtype, dtype) + torch.testing.assert_close(c_eager, c_compiled) + + +if __name__ == "__main__": + run_tests() diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 66eaf69dd59a8..bd1fa7710b06c 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -546,6 +546,10 @@ def prologue_fusion_enabled() -> bool: "TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT" ).upper() # type: ignore[assignment] +cutedsl_enable_autotuning: bool = ( + os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1" +) + # DEPRECATED. This setting is ignored. autotune_fallback_to_aten = False diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index b95073e769f31..eb22b95af2afc 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -1,6 +1,8 @@ # mypy: allow-untyped-defs import logging from collections.abc import Sequence +from functools import partial +from pathlib import Path from typing import Any import torch @@ -12,6 +14,7 @@ from .. import config from ..codegen.wrapper import PythonWrapperCodegen from ..ir import _IntLike, Layout, TensorBox +from ..utils import load_template log = logging.getLogger(__name__) @@ -254,3 +257,7 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool: return False return True + + +_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates" +load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR) diff --git a/torch/_inductor/kernel/mm_grouped.py b/torch/_inductor/kernel/mm_grouped.py index 881c14fd43d0d..0a44b728a5a93 100644 --- a/torch/_inductor/kernel/mm_grouped.py +++ b/torch/_inductor/kernel/mm_grouped.py @@ -1,10 +1,11 @@ # mypy: allow-untyped-defs import logging -from dataclasses import dataclass +from dataclasses import asdict, dataclass from typing import Any, Optional import torch from torch._dynamo.utils import counters +from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate from torch._inductor.runtime.triton_compat import tl from torch._inductor.virtualized import V from torch.utils._triton import has_triton @@ -18,19 +19,25 @@ TritonTemplate, ) from ..utils import ( + ensure_cute_available, get_gpu_shared_memory, get_num_sms, has_free_symbols, use_aten_gemm_kernels, + use_blackwell_cutedsl_grouped_mm, use_triton_template, ) from .mm_common import ( _is_static_problem, check_supported_striding, + load_kernel_template, persistent_grouped_mm_grid, ) +if ensure_cute_available(): + from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs + log = logging.getLogger(__name__) aten = torch.ops.aten @@ -513,6 +520,11 @@ def do_mma(a, b, accumulator): source=triton_grouped_mm_source, ) +cutedsl_grouped_mm_template = CuteDSLTemplate( + name="grouped_gemm_cutedsl", + source=load_kernel_template("cutedsl_mm_grouped"), +) + def grouped_mm_args( mat1: TensorBox, @@ -714,43 +726,44 @@ def _tuned_grouped_mm_common( # Checking only for the equality of corresponding dims of # multiplicands here, relying on meta function checks for # everything else. + if len(m1_size) == 2: + if len(m2_size) == 2: + m, k1 = m1_size + k2, _ = m2_size + # pyrefly: ignore [missing-attribute] + g = offs.get_size()[0] + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, True + else: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, False + else: + if len(m2_size) == 2: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + g2, m, k1 = m1_size + k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, True + else: + g1, m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, False + if ( is_nonzero and use_triton_template(layout) and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result) ): scaled = scale_a is not None - if len(m1_size) == 2: - if len(m2_size) == 2: - m, k1 = m1_size - k2, _ = m2_size - # pyrefly: ignore [missing-attribute] - g = offs.get_size()[0] - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = True, True - else: - # pyrefly: ignore [missing-attribute] - g1 = offs.layout.size[0] - m, k1 = m1_size - g2, k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = True, False - else: - if len(m2_size) == 2: - # pyrefly: ignore [missing-attribute] - g1 = offs.layout.size[0] - g2, m, k1 = m1_size - k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = False, True - else: - g1, m, k1 = m1_size - g2, k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = False, False a_is_k_major = mat_a.get_stride()[-1] == 1 b_is_k_major = mat_b.get_stride()[-2] == 1 @@ -788,6 +801,22 @@ def _tuned_grouped_mm_common( **config.kwargs, ) + if use_blackwell_cutedsl_grouped_mm( + mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result + ): + for config in get_groupgemm_configs(): + kwargs = dict( + ACC_DTYPE="cutlass.Float32", + ) + + cutedsl_grouped_mm_template.maybe_append_choice( + choices, + input_nodes=input_nodes, + layout=layout, + **kwargs, + **asdict(config), + ) + input_gen_fns = { 4: lambda x: create_offsets( x, m1_size, m2_size, offs.get_size() if offs is not None else None diff --git a/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja b/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja new file mode 100644 index 0000000000000..989f297c5f80f --- /dev/null +++ b/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja @@ -0,0 +1,333 @@ +import functools +from torch._inductor.runtime.runtime_utils import ceildiv +from cutlass.utils import TensorMapUpdateMode +{{gen_defines()}} +# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ---- +from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import ( + GroupedGemmKernel, +) + + +# Note about caching: +# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor +# maintains its own local caching system. At this stage, all compile-time +# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel +# name itself ({{kernel_name}}) are permanently baked into the file, so they +# do not need to be included in any cache key. +# +# The caching mechanism is split into two levels: +# +# 1. prep_cache +# Caches the compiled executor for build_group_ptrs_from_bases(). This +# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C, +# and can therefore be safely reused across runs with different group +# partitioning (`offs`). +# +# 2. gemm_cache +# Caches the compiled Grouped GEMM executor. Its key extends the prep +# cache key with hardware- and grid-specific parameters: +# (prep_cache_key, max_active_clusters, total_num_clusters). +# This is necessary because different `offs` tensors can change the +# per-group problem sizes and thus alter `total_num_clusters`, which in +# turn changes the grid shape and persistent scheduler configuration. +# Kernels compiled for one grid cannot be safely reused for another. +# +# +# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically, +# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead, +# despite depending only on the GPU type. We cache this function to mitigate +# redundant recompiles even when shape/stride/dtype cache misses force kernel +# regeneration. A follow-up study will investigate the root cause. + +prep_cache = {} +gemm_cache = {} + + +@functools.lru_cache +def get_hardware_info(): + hw = cutlass.utils.HardwareInfo() + sm_count = hw.get_max_active_clusters(1) + max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N) + + return (sm_count, max_active_clusters) + + +def get_prep_cache_key(input_a, input_b, output): + """ + Returns a tuple key for caching the preprocessing kernel executor based on kernel name, + shapes, strides, and dtypes of input/output tensors. + """ + return ( + tuple(input_a.shape), + tuple(input_a.stride()), + input_a.dtype, + tuple(input_b.shape), + tuple(input_b.stride()), + input_b.dtype, + tuple(output.shape), + tuple(output.stride()), + output.dtype, + ) + + +def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters): + """ + Returns a tuple key for caching the gemm kernel executor by extending the + prep cache key with hardware- and grid-specific parameters. + """ + return ( + prep_cache_key, + max_active_clusters, + total_num_clusters, + ) + + +@cute.kernel +def build_group_ptrs_from_bases_kernel( + base_A_u64: cutlass.Int64, # device addr of input_a (bytes) + base_B_u64: cutlass.Int64, # device addr of input_b (bytes) + base_C_u64: cutlass.Int64, # device addr of Output (bytes) + offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative + K: cutlass.Constexpr, + N: cutlass.Constexpr, + sizeof_element: cutlass.Int32, # bytes + # -------- STRIDES (in ELEMENTS) -------- + stride_A_m_elems: cutlass.Constexpr, # A.stride(0) + stride_A_k_elems: cutlass.Constexpr, # A.stride(1) + stride_B0_elems: cutlass.Constexpr, # B.stride(0) + stride_Bk_elems: cutlass.Constexpr, # B.stride(1) + stride_Bn_elems: cutlass.Constexpr, # B.stride(2) + stride_C_m_elems: cutlass.Constexpr, # C.stride(0) + stride_C_n_elems: cutlass.Constexpr, # C.stride(1) + # -------- OUTPUTS -------- + out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr) + out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1) + out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]] +): + tidx, _, _ = cute.arch.thread_idx() + g = tidx + + m_beg_i32 = 0 + if g > 0: + m_beg_i32 = offs[g - 1] + m_end_i32 = offs[g] + m_g_i32 = m_end_i32 - m_beg_i32 + + a_byte_off = ( + cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element) + ) + c_byte_off = ( + cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element) + ) + b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element) + + # ---- pointers ---- + out_ptrs[g, 0] = base_A_u64 + a_byte_off + out_ptrs[g, 1] = base_B_u64 + b_byte_off + out_ptrs[g, 2] = base_C_u64 + c_byte_off + + # ---- (m, n, k, 1) ---- + out_problem[g, 0] = m_g_i32 + out_problem[g, 1] = N + out_problem[g, 2] = K + out_problem[g, 3] = cutlass.Int32(1) + + # ---- strides ---- + out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems) + out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems) + out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems) + out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems) + out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems) + out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems) + + +@cute.jit +def launch_build_group_ptrs_from_bases( + base_A_u64: cutlass.Int64, + base_B_u64: cutlass.Int64, + base_C_u64: cutlass.Int64, + offs: cute.Tensor, + G: cutlass.Constexpr, + K: cutlass.Constexpr, + N: cutlass.Constexpr, + sizeof_element: cutlass.Constexpr, + stride_A_m_elems: cutlass.Constexpr, + stride_A_k_elems: cutlass.Constexpr, + stride_B0_elems: cutlass.Constexpr, + stride_Bk_elems: cutlass.Constexpr, + stride_Bn_elems: cutlass.Constexpr, + stride_C_m_elems: cutlass.Constexpr, + stride_C_n_elems: cutlass.Constexpr, + out_ptrs: cute.Tensor, # [G,3] cutlass.Int64 + out_problem: cute.Tensor, # [G,4] cutlass.Int32 + out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32 + stream: cuda.CUstream, +): + build_group_ptrs_from_bases_kernel( + base_A_u64, + base_B_u64, + base_C_u64, + offs, + K, + N, + sizeof_element, + stride_A_m_elems, + stride_A_k_elems, + stride_B0_elems, + stride_Bk_elems, + stride_Bn_elems, + stride_C_m_elems, + stride_C_n_elems, + out_ptrs, + out_problem, + out_strides_abc, + ).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream) + + +{{def_kernel("input_a", "input_b", "input_a_offs")}} + stream = cuda.CUstream(stream) + + input_b = input_b.transpose(1, 2) + + sumM, K = input_a.shape + G, N, Kb = input_b.shape + + dev = input_a.device + + base_A_u64 = int(input_a.data_ptr()) + base_B_u64 = int(input_b.data_ptr()) + base_C_u64 = int({{get_output()}}.data_ptr()) + + ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64) + probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32) + strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32) + ptrs = from_dlpack(ptrs_t) + probs = from_dlpack(probs_t) + strides = from_dlpack(strides_t) + + prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}}) + prep_executor = prep_cache.get(prep_cache_key) + + if prep_executor is None: + sizeof_element = int(input_a.element_size()) + sA_m, sA_k = map(int, input_a.stride()) + sB_0, sB_n, sB_k = map(int, input_b.stride()) + sC_m, sC_n = map(int, {{get_output()}}.stride()) + + prep_executor = cute.compile( + launch_build_group_ptrs_from_bases, + base_A_u64=base_A_u64, + base_B_u64=base_B_u64, + base_C_u64=base_C_u64, + offs=from_dlpack(input_a_offs), + G=int(G), + K=int(K), + N=int(N), + sizeof_element=sizeof_element, + stride_A_m_elems=sA_m, + stride_A_k_elems=sA_k, + stride_B0_elems=sB_0, + stride_Bk_elems=sB_k, + stride_Bn_elems=sB_n, + stride_C_m_elems=sC_m, + stride_C_n_elems=sC_n, + out_ptrs=ptrs, + out_problem=probs, + out_strides_abc=strides, + stream=stream, + ) + + prep_cache[prep_cache_key] = prep_executor + + prep_executor( + base_A_u64=base_A_u64, + base_B_u64=base_B_u64, + base_C_u64=base_C_u64, + offs=from_dlpack(input_a_offs), + out_ptrs=ptrs, + out_problem=probs, + out_strides_abc=strides, + stream=stream, + ) + + # --- Tensormap workspace per SM --- + num_tensormap_buffers, max_active_clusters = get_hardware_info() + tensormap_shape = ( + num_tensormap_buffers, + GroupedGemmKernel.num_tensormaps, + GroupedGemmKernel.bytes_per_tensormap // 8, + ) + tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64) + tensormap_workspace = from_dlpack(tensormap_workspace_t) + + # --- Total clusters --- + def compute_total_num_clusters( + problem_sizes_mnkl, + cluster_tile_shape_mn, + ): + total_num_clusters = 0 + for m, n, _, _ in problem_sizes_mnkl: + num_clusters_mn = tuple( + ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn) + ) + total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) + return total_num_clusters + + # Compute cluster tile shape + def compute_cluster_tile_shape( + mma_tiler_mn, + cluster_shape_mn, + use_2cta_instrs, + ): + cta_tile_shape_mn = list(mma_tiler_mn) + if use_2cta_instrs: + cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2 + return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn)) + + cluster_tile_shape_mn = compute_cluster_tile_shape( + (TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA) + ) + + total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn)) + + gemm_cache_key = get_gemm_cache_key( + prep_cache_key, max_active_clusters, total_num_clusters + ) + gemm_executor = gemm_cache.get(gemm_cache_key) + + if gemm_executor is None: + grouped_gemm = GroupedGemmKernel( + acc_dtype=ACC_DTYPE, + use_2cta_instrs=USE_2_CTA, + mma_tiler_mn=(TILE_M, TILE_N), + cluster_shape_mn=(CLUSTER_M, CLUSTER_N), + tensormap_update_mode=TENSORMAP_UPDATE_MODE, + ) + + gemm_executor = cute.compile( + grouped_gemm, + from_dlpack(input_a.unsqueeze(-1), assumed_align=16), + from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), + from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), + G, + probs, + strides, + ptrs, + total_num_clusters, + tensormap_workspace, + max_active_clusters, + stream, + ) + + gemm_cache[gemm_cache_key] = gemm_executor + + gemm_executor( + from_dlpack(input_a.unsqueeze(-1), assumed_align=16), + from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), + from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), + probs, + strides, + ptrs, + tensormap_workspace, + stream, + ) diff --git a/torch/_inductor/template_heuristics/cutedsl.py b/torch/_inductor/template_heuristics/cutedsl.py new file mode 100644 index 0000000000000..db337b9d8a271 --- /dev/null +++ b/torch/_inductor/template_heuristics/cutedsl.py @@ -0,0 +1,141 @@ +from dataclasses import dataclass +from enum import auto, Enum +from itertools import product + +import torch._inductor.config as config + + +class TensorMapUpdateMode(Enum): + """Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency.""" + + SMEM = auto() + GMEM = auto() + + +@dataclass(frozen=True) +class CuTeGemmConfig: + TILE_M: int = 128 + TILE_N: int = 192 + CLUSTER_M: int = 2 + CLUSTER_N: int = 1 + USE_2_CTA: bool = False + TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM + + +def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]: + """ + Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. + For information regarding valid config sets, see: + https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py + """ + + # Tile_n is always the same regardless of 2cta + tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256] + + # Valid clusters + clusters_no_2cta = [ + (1, 1), + (1, 2), + (1, 4), + (1, 8), + (1, 16), + (2, 1), + (2, 2), + (2, 4), + (2, 8), + (4, 1), + (4, 2), + (4, 4), + (8, 1), + (8, 2), + (16, 1), + ] + clusters_2cta = [ + (2, 1), + (2, 2), + (2, 4), + (2, 8), + (4, 1), + (4, 2), + (4, 4), + (8, 1), + (8, 2), + (16, 1), + ] + + configs: list[CuTeGemmConfig] = [] + + for use_2cta, cluster_set, tile_m_range in [ + (False, clusters_no_2cta, [64, 128]), + (True, clusters_2cta, [128, 256]), + ]: + for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product( + [TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM], + tile_m_range, + tile_n_vals, + cluster_set, + ): + configs.append( + CuTeGemmConfig( + tile_m, + tile_n, + cluster_m, + cluster_n, + USE_2_CTA=use_2cta, + TENSORMAP_UPDATE_MODE=tensormap_update_mode, + ) + ) + + return configs + + +def get_default_groupgemm_configs() -> list[CuTeGemmConfig]: + """ + Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. + """ + + config_tuples = [ + (128, 256, 2, 1, False, TensorMapUpdateMode.SMEM), + (256, 160, 2, 1, True, TensorMapUpdateMode.GMEM), + (256, 256, 2, 1, True, TensorMapUpdateMode.GMEM), + (64, 32, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 256, 1, 2, False, TensorMapUpdateMode.SMEM), + (128, 256, 1, 2, False, TensorMapUpdateMode.SMEM), + (256, 256, 2, 2, True, TensorMapUpdateMode.GMEM), + (128, 256, 1, 2, False, TensorMapUpdateMode.GMEM), + (64, 32, 1, 1, False, TensorMapUpdateMode.SMEM), + (256, 256, 2, 1, True, TensorMapUpdateMode.SMEM), + (128, 256, 1, 1, False, TensorMapUpdateMode.GMEM), + (256, 256, 8, 1, True, TensorMapUpdateMode.GMEM), + (64, 32, 1, 2, False, TensorMapUpdateMode.SMEM), + (256, 192, 2, 1, True, TensorMapUpdateMode.GMEM), + (256, 256, 2, 2, True, TensorMapUpdateMode.SMEM), + (128, 96, 1, 2, False, TensorMapUpdateMode.SMEM), + (64, 192, 1, 1, False, TensorMapUpdateMode.SMEM), + (64, 64, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 192, 1, 1, False, TensorMapUpdateMode.GMEM), + (128, 64, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 160, 1, 1, False, TensorMapUpdateMode.GMEM), + (64, 256, 1, 1, False, TensorMapUpdateMode.GMEM), + ] + + return [CuTeGemmConfig(*args) for args in config_tuples] + + +def get_groupgemm_configs() -> list[CuTeGemmConfig]: + """ + Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. + + Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures + or unstable results. By default, autotuning is disabled and we return only + a single baseline config. + """ + if ( + config.cutedsl_enable_autotuning + and config.max_autotune_gemm_search_space == "EXHAUSTIVE" + ): + return get_exhaustive_groupgemm_configs() + elif config.cutedsl_enable_autotuning: + return get_default_groupgemm_configs() + else: + return [get_default_groupgemm_configs()[0]] diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 3f8652882af79..efdb4a9a58912 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1975,6 +1975,84 @@ def use_triton_blackwell_tma_template( return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch() +@functools.lru_cache(maxsize=1) +def ensure_cute_available() -> bool: + """Check if CuTeDSL is importable; cache the result for reuse. + + Call ensure_cute_available.cache_clear() after installing CuTeDSL + in the same interpreter to retry the import. + """ + try: + return importlib.util.find_spec("cutlass.cute") is not None + except ImportError: + return False + + +def use_blackwell_cutedsl_grouped_mm( + mat_a: Any, + mat_b: Any, + layout: Layout, + a_is_2d: bool, + b_is_2d: bool, + offs: Optional[Any], + bias: Optional[Any], + scale_result: Optional[Any], +) -> bool: + """ + Returns True if we can use the blackwell kernel for grouped mm. + Required conditions: + 1. CuTeDSL backend is enabled + 2. CuTeDSL is available + 3. We are on a blackwell arch + 4. The dtype is bf16 + 5. Max autotune or max autotune gemm is enabled + 6. A, B, and the output are 16B aligned + 7. We are not using dynamic shapes + 8. A is 2d + 9. B is 3d + 10. Offsets are provided + 11. Bias and Scale are not provided + """ + if not ensure_cute_available(): + return False + + if not _use_autotune_backend("CUTEDSL"): + return False + + from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch + + if not is_gpu(layout.device.type): + return False + + if not is_datacenter_blackwell_arch(): + return False + + layout_dtypes = [torch.bfloat16] + if not _use_template_for_gpu(layout, layout_dtypes): + return False + + if not (config.max_autotune or config.max_autotune_gemm): + return False + + # Checks for 16B ptr and stride alignment + if not can_use_tma(mat_a, mat_b, output_layout=layout): + return False + + if any(is_dynamic(x) for x in [mat_a, mat_b]): + return False + + if not a_is_2d or b_is_2d: + return False + + if offs is None: + return False + + if bias is not None or scale_result is not None: + return False + + return True + + def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: from .virtualized import V From edd8d356b6d9a00cfa34fa323578e5cf1c7e0463 Mon Sep 17 00:00:00 2001 From: arkadip-maitra Date: Wed, 5 Nov 2025 08:07:42 +0000 Subject: [PATCH 057/130] fixes keyerror when loading parameter with unsaved optimizer state (#165228) Fixes #164257 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165228 Approved by: https://github.com/fegin --- torch/distributed/checkpoint/state_dict.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index 16d988a79103e..9202851537fba 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -1007,7 +1007,14 @@ def _split_optim_state_dict( raise AssertionError(f"Expected list, got {type(params)}") params.append(fqn) if param.requires_grad: - state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn] + if fqn in cast(DictValueType, optim_state_dict[_STATE]): + state[fqn] = cast(DictValueType, optim_state_dict[_STATE])[fqn] + elif info.strict: + raise RuntimeError( + f"Missing optimizer state for parameter '{fqn}' in checkpoint. " + "The parameter requires gradients but has no saved optimizer state. " + "To load anyway, use StateDictOptions(strict=False)." + ) for loaded_param_group in cast( ListDictValueType, optim_state_dict[_PG] ): From 0b4dd08e047bda63e1e8dc78f52bcda51562caa5 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 4 Nov 2025 19:59:19 -0800 Subject: [PATCH 058/130] [dynamo] Introduce _set_lru_cache (#167038) Addresses the short-term plan for https://github.com/pytorch/pytorch/issues/166926. This PR can't be defaulted on, that would be terrible for cache look up times. There's a proper fix in the works by @williamwen42. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167038 Approved by: https://github.com/williamwen42 --- test/dynamo/test_repros.py | 88 +++++++++++++++++++++++++++++++ torch/csrc/dynamo/extra_state.cpp | 27 ++++++++-- torch/csrc/dynamo/extra_state.h | 1 + torch/csrc/dynamo/init.cpp | 1 + 4 files changed, 114 insertions(+), 3 deletions(-) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index c6138f7574fd4..f3766fe0c973e 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -48,6 +48,7 @@ CompileCounter, CompileCounterWithBackend, EagerAndRecordGraphs, + expectedFailureDynamic, rand_strided, same, skipIfNotPy312, @@ -7455,6 +7456,93 @@ def forward(self, x): msg, ) + @expectedFailureDynamic + def test_dynamo_default_lru_cache_behavior(self): + @torch.compile(backend="eager") + def fn(x): + return x + 10 + + torch._dynamo.reset() + assert not torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + + # Step 1: Compile a static shapes graph + x = torch.randn(10, 10) + fn(x) + a = torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + self.assertEqual(len(a), 1) + static_shapes_cache_entry = a[0] + + # Step 2: Compile a dynamic shapes graph + y = torch.randn(20, 20) + fn(y) + b = torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + self.assertEqual(len(b), 2) + self.assertEqual(b[1], static_shapes_cache_entry) + dynamic_shapes_cache_entry = b[0] + + # Step 3: Run with Step 1's inputs + # LRU cache will match against dynamic shape graph first + fn(x) + c = torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + self.assertEqual(len(c), 2) + self.assertEqual(c[0], dynamic_shapes_cache_entry) + self.assertEqual(c[1], static_shapes_cache_entry) + + @expectedFailureDynamic + def test_dynamo_disable_lru_cache_behavior(self): + @torch.compile(backend="eager") + def fn(x): + return x + 10 + + def run(): + torch._dynamo.reset() + assert not torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + + # Step 1: Compile a static shapes graph + x = torch.randn(10, 10) + fn(x) + a = torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + self.assertEqual(len(a), 1) + static_shapes_cache_entry = a[0] + + # Step 2: Compile a dynamic shapes graph + y = torch.randn(20, 20) + fn(y) + b = torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + self.assertEqual(len(b), 2) + self.assertEqual(b[0], static_shapes_cache_entry) + dynamic_shapes_cache_entry = b[1] + + # Step 3: Run with Step 1's inputs + # LRU cache is disabled, we should still have static entry first + fn(x) + c = torch._C._dynamo.eval_frame._debug_get_cache_entry_list( + fn._torchdynamo_orig_callable.__code__ + ) + self.assertEqual(len(c), 2) + self.assertEqual(c[0], static_shapes_cache_entry) + self.assertEqual(c[1], dynamic_shapes_cache_entry) + + try: + torch._C._dynamo.eval_frame._set_lru_cache(False) + run() + finally: + torch._C._dynamo.eval_frame._set_lru_cache(True) + class ReproTestsDevice(torch._dynamo.test_case.TestCase): def test_sub_alpha_scalar_repro(self, device): diff --git a/torch/csrc/dynamo/extra_state.cpp b/torch/csrc/dynamo/extra_state.cpp index b9dccb456fd65..8dc316b98e63c 100644 --- a/torch/csrc/dynamo/extra_state.cpp +++ b/torch/csrc/dynamo/extra_state.cpp @@ -13,6 +13,11 @@ #define _PyCode_SetExtra PyUnstable_Code_SetExtra #endif +namespace { +// Short-term fix for: https://github.com/pytorch/pytorch/issues/166926 +bool use_lru = true; +} // namespace + Py_ssize_t extra_index = -1; CacheEntry* ExtraState::get_first_entry() { @@ -190,7 +195,9 @@ void lookup( ++index; } if (found) { - extra_state->move_to_front(found); + if (use_lru) { + extra_state->move_to_front(found); + } *maybe_cached_code = found->code.ptr(); *trace_annotation = found->trace_annotation.c_str(); return; @@ -202,8 +209,14 @@ CacheEntry* create_cache_entry( ExtraState* extra_state, PyObject* guarded_code, PyObject* backend) { - extra_state->cache_entry_list.emplace_front(guarded_code, backend); - auto new_iter = extra_state->cache_entry_list.begin(); + std::list::iterator new_iter; + if (use_lru) { + extra_state->cache_entry_list.emplace_front(guarded_code, backend); + new_iter = extra_state->cache_entry_list.begin(); + } else { + extra_state->cache_entry_list.emplace_back(guarded_code, backend); + new_iter = std::prev(extra_state->cache_entry_list.end()); + } new_iter->_owner = extra_state; new_iter->_owner_loc = new_iter; // Set guard_manager references to extra_state and CacheEntry @@ -269,6 +282,14 @@ void _load_precompile_entry( extra->precompile_entries.push_back(std::move(entry)); } +void _set_lru_cache(py::object boolean) { + if (py::cast(boolean)) { + use_lru = true; + } else { + use_lru = false; + } +} + py::list _debug_get_precompile_entries(const py::handle& code_obj) { if (!py::isinstance(code_obj, py::module::import("types").attr("CodeType"))) { throw py::type_error("expected a code object!"); diff --git a/torch/csrc/dynamo/extra_state.h b/torch/csrc/dynamo/extra_state.h index 1630ac90b21dd..bc62e93bf3f1d 100644 --- a/torch/csrc/dynamo/extra_state.h +++ b/torch/csrc/dynamo/extra_state.h @@ -203,5 +203,6 @@ void _load_precompile_entry( py::object guard_manager, py::object dynamo_code); py::list _debug_get_precompile_entries(const py::handle& code_obj); +void _set_lru_cache(py::object boolean); #endif diff --git a/torch/csrc/dynamo/init.cpp b/torch/csrc/dynamo/init.cpp index f1590e19d49cf..790ff9acff3a1 100644 --- a/torch/csrc/dynamo/init.cpp +++ b/torch/csrc/dynamo/init.cpp @@ -254,6 +254,7 @@ void initDynamoBindings(PyObject* torch) { m.def("_reset_precompile_entries", &_reset_precompile_entries); m.def("_load_precompile_entry", &_load_precompile_entry); m.def("_debug_get_precompile_entries", &_debug_get_precompile_entries); + m.def("_set_lru_cache", &_set_lru_cache); py::bind_vector>(m, "VectorUInt8"); init_THPCaches(); if (THP_PyOpcode_Caches != nullptr) { From 5c639466f7b1f9453c2a9c0e25b41c3774a12af8 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 5 Nov 2025 14:30:15 +0000 Subject: [PATCH 059/130] Revert "[Inductor][Grouped Gemm] Add Blackwell CuTeDSL Kernel (#167003)" This reverts commit 658c5f879c37142b1df51c7eb6c5a5bb06318597. Reverted https://github.com/pytorch/pytorch/pull/167003 on behalf of https://github.com/atalman due to regressed vllm signal: [GH job link](https://github.com/pytorch/pytorch/actions/runs/19093785744/job/54553796743) [HUD commit link](https://hud.pytorch.org/pytorch/pytorch/commit/658c5f879c37142b1df51c7eb6c5a5bb06318597) ([comment](https://github.com/pytorch/pytorch/pull/167003#issuecomment-3491527704)) --- .ci/pytorch/test.sh | 2 +- .gitignore | 1 - setup.py | 34 -- test/inductor/test_cutedsl_grouped_mm.py | 154 -------- torch/_inductor/config.py | 4 - torch/_inductor/kernel/mm_common.py | 7 - torch/_inductor/kernel/mm_grouped.py | 93 ++--- .../templates/cutedsl_mm_grouped.py.jinja | 333 ------------------ .../_inductor/template_heuristics/cutedsl.py | 141 -------- torch/_inductor/utils.py | 78 ---- 10 files changed, 33 insertions(+), 814 deletions(-) delete mode 100644 test/inductor/test_cutedsl_grouped_mm.py delete mode 100644 torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja delete mode 100644 torch/_inductor/template_heuristics/cutedsl.py diff --git a/.ci/pytorch/test.sh b/.ci/pytorch/test.sh index 9ae2578758939..26996b5a32d56 100755 --- a/.ci/pytorch/test.sh +++ b/.ci/pytorch/test.sh @@ -337,7 +337,7 @@ test_python() { test_python_smoke() { # Smoke tests for H100/B200 - time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune inductor/test_cutedsl_grouped_mm $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running + time python test/run_test.py --include test_matmul_cuda test_scaled_matmul_cuda inductor/test_fp8 inductor/test_max_autotune $PYTHON_TEST_EXTRA_OPTION --upload-artifacts-while-running assert_git_not_dirty } diff --git a/.gitignore b/.gitignore index 3b4323051073a..d1b3b17445dac 100644 --- a/.gitignore +++ b/.gitignore @@ -127,7 +127,6 @@ torch/test/ torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h torch/version.py -torch/_inductor/kernel/vendored_templates/* minifier_launcher.py aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_fwd_d* aten/src/ATen/native/transformers/hip/flash_attn/ck/fmha_bwd_d* diff --git a/setup.py b/setup.py index dd8a52cbeb7c7..31e78d0245d93 100644 --- a/setup.py +++ b/setup.py @@ -630,37 +630,6 @@ def mirror_files_into_torchgen() -> None: raise RuntimeError("Check the file paths in `mirror_files_into_torchgen()`") -def mirror_inductor_external_kernels() -> None: - """ - Copy external kernels into Inductor so they are importable. - """ - paths = [ - ( - CWD / "torch/_inductor/kernel/vendored_templates/cutedsl_grouped_gemm.py", - CWD - / "third_party/cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py", - ), - ] - for new_path, orig_path in paths: - # Create the dirs involved in new_path if they don't exist - if not new_path.exists(): - new_path.parent.mkdir(parents=True, exist_ok=True) - - # Copy the files from the orig location to the new location - if orig_path.is_file(): - shutil.copyfile(orig_path, new_path) - continue - if orig_path.is_dir(): - if new_path.exists(): - # copytree fails if the tree exists already, so remove it. - shutil.rmtree(new_path) - shutil.copytree(orig_path, new_path) - continue - raise RuntimeError( - "Check the file paths in `mirror_inductor_external_kernels()`" - ) - - # ATTENTION: THIS IS AI SLOP def extract_variant_from_version(version: str) -> str: """Extract variant from version string, defaulting to 'cpu'.""" @@ -1647,8 +1616,6 @@ def main() -> None: if RUN_BUILD_DEPS: build_deps() - mirror_inductor_external_kernels() - ( ext_modules, cmdclass, @@ -1682,7 +1649,6 @@ def main() -> None: "_inductor/codegen/aoti_runtime/*.cpp", "_inductor/script.ld", "_inductor/kernel/flex/templates/*.jinja", - "_inductor/kernel/templates/*.jinja", "_export/serde/*.yaml", "_export/serde/*.thrift", "share/cmake/ATen/*.cmake", diff --git a/test/inductor/test_cutedsl_grouped_mm.py b/test/inductor/test_cutedsl_grouped_mm.py deleted file mode 100644 index c26def3a54099..0000000000000 --- a/test/inductor/test_cutedsl_grouped_mm.py +++ /dev/null @@ -1,154 +0,0 @@ -# Owner(s): ["module: inductor"] - - -import unittest - -import torch -from torch import Tensor -from torch._inductor import config -from torch._inductor.codegen.cuda.cuda_env import is_datacenter_blackwell_arch -from torch._inductor.test_case import run_tests, TestCase as InductorTestCase -from torch._inductor.utils import ensure_cute_available -from torch.testing._internal.common_utils import ( - instantiate_parametrized_tests, - parametrize, -) - - -@unittest.skipIf( - not (ensure_cute_available() and is_datacenter_blackwell_arch()), - "CuTeDSL library or Blackwell device not available", -) -@instantiate_parametrized_tests -class TestCuTeDSLGroupedGemm(InductorTestCase): - def _get_inputs( - self, - group_size: int, - M_hint: int, - K: int, - N: int, - device: str, - dtype: torch.dtype, - alignment: int = 16, - ) -> tuple[Tensor, Tensor, Tensor]: - # --- Random, tile-aligned M sizes --- - M_sizes = ( - torch.randint(1, (M_hint // alignment) + 1, (group_size,), dtype=torch.int) - * alignment - ) - - M_total = torch.sum(M_sizes).item() - - # --- Construct input tensors --- - A = torch.randn(int(M_total), K, dtype=dtype, device=device) * 0.1 - B = torch.randn((group_size, K, N), dtype=dtype, device=device) * 0.01 - - # --- Build offsets (no leading zero, strictly increasing) --- - offsets = torch.cumsum(M_sizes, dim=0).to(dtype=torch.int32, device=device) - - return (A, B, offsets) - - @parametrize("group_size", (2, 8)) - @parametrize("M_hint", (256, 1024)) - @parametrize("K", (64, 128)) - @parametrize("N", (128, 256)) - def test_grouped_gemm_basic(self, group_size: int, M_hint: int, K: int, N: int): - device = "cuda" - dtype = torch.bfloat16 - - A, B, offsets = self._get_inputs(group_size, M_hint, K, N, device, dtype) - - def grouped_gemm_fn(A_packed, B_batched, offs): - return torch._grouped_mm(A_packed, B_batched, offs=offs) - - # Eager execution - c_eager = grouped_gemm_fn(A, B, offsets) - - # Test with Cute backend - with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "CUTEDSL", - "test_configs.autotune_choice_name_regex": "cutedsl", - "autotune_fallback_to_aten": False, - } - ): - grouped_gemm_compiled = torch.compile( - grouped_gemm_fn, backend="inductor", dynamic=False - ) - c_compiled = grouped_gemm_compiled(A, B, offsets) - - self.assertEqual(c_eager.dtype, dtype) - self.assertEqual(c_compiled.dtype, dtype) - torch.testing.assert_close(c_eager, c_compiled) - - @parametrize("layout_A", ("contiguous", "offset", "padded", "view")) - @parametrize("layout_B", ("contiguous", "broadcasted")) - def test_grouped_gemm_assorted_layouts( - self, - layout_A: str, - layout_B: str, - ): - device = "cuda" - dtype = torch.bfloat16 - - G, K, N = 8, 64, 128 - M_sizes = [128] * G - sum_M = sum(M_sizes) - offsets = torch.tensor( - [sum(M_sizes[: i + 1]) for i in range(G)], dtype=torch.int32, device=device - ) - - A_base = torch.randn(sum_M, K, device=device, dtype=dtype) - A = A_base - - if layout_A == "offset": - # allocate bigger buffer than needed, use nonzero storage offset - storage = torch.randn(sum_M * K + 512, device=device, dtype=dtype) - offset = 128 # skip first 128 elements - A = torch.as_strided(storage[offset:], (sum_M, K), (K, 1)) - elif layout_A == "padded": - # simulate row pitch > K (row_stride = K + pad) - row_pitch = K + 8 - storage = torch.randn(sum_M * row_pitch, device=device, dtype=dtype) - A = torch.as_strided(storage, (sum_M, K), (row_pitch, 1)) - elif layout_A == "view": - A_storage = torch.randn(sum_M * K, device=device, dtype=dtype) - A = A_storage.view(sum_M, K) - assert A._base is not None - assert A.shape == (sum_M, K) - - B = torch.randn((G, K, N), dtype=dtype, device=device) * 0.01 - - if layout_B == "broadcasted": - # Broadcast B across groups (zero stride along G) - B = B[0].expand(G, K, N) - assert B.stride(0) == 0 - - def grouped_gemm_fn(A_packed, B_batched, offs): - return torch._grouped_mm(A_packed, B_batched, offs=offs) - - # --- eager --- - c_eager = grouped_gemm_fn(A, B, offsets) - - # --- compiled (CUTE backend) --- - with config.patch( - { - "max_autotune": True, - "max_autotune_gemm_backends": "CUTEDSL", - "test_configs.autotune_choice_name_regex": "cutedsl", - "autotune_fallback_to_aten": False, - } - ): - grouped_gemm_compiled = torch.compile( - grouped_gemm_fn, backend="inductor", dynamic=False - ) - c_compiled = grouped_gemm_compiled(A, B, offsets) - - self.assertEqual(c_eager.dtype, dtype) - self.assertEqual(c_compiled.dtype, dtype) - torch.testing.assert_close(c_eager, c_compiled) - - -if __name__ == "__main__": - run_tests() diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index bd1fa7710b06c..66eaf69dd59a8 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -546,10 +546,6 @@ def prologue_fusion_enabled() -> bool: "TORCHINDUCTOR_MAX_AUTOTUNE_FLEX_SEARCH_SPACE", "DEFAULT" ).upper() # type: ignore[assignment] -cutedsl_enable_autotuning: bool = ( - os.environ.get("CUTEDSL_ENABLE_AUTOTUNING", "0") == "1" -) - # DEPRECATED. This setting is ignored. autotune_fallback_to_aten = False diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index eb22b95af2afc..b95073e769f31 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -1,8 +1,6 @@ # mypy: allow-untyped-defs import logging from collections.abc import Sequence -from functools import partial -from pathlib import Path from typing import Any import torch @@ -14,7 +12,6 @@ from .. import config from ..codegen.wrapper import PythonWrapperCodegen from ..ir import _IntLike, Layout, TensorBox -from ..utils import load_template log = logging.getLogger(__name__) @@ -257,7 +254,3 @@ def is_batch_stride_largest_or_zero(mat1, mat2, layout) -> bool: return False return True - - -_KERNEL_TEMPLATE_DIR = Path(__file__).parent / "templates" -load_kernel_template = partial(load_template, template_dir=_KERNEL_TEMPLATE_DIR) diff --git a/torch/_inductor/kernel/mm_grouped.py b/torch/_inductor/kernel/mm_grouped.py index 0a44b728a5a93..881c14fd43d0d 100644 --- a/torch/_inductor/kernel/mm_grouped.py +++ b/torch/_inductor/kernel/mm_grouped.py @@ -1,11 +1,10 @@ # mypy: allow-untyped-defs import logging -from dataclasses import asdict, dataclass +from dataclasses import dataclass from typing import Any, Optional import torch from torch._dynamo.utils import counters -from torch._inductor.codegen.cutedsl.cutedsl_template import CuteDSLTemplate from torch._inductor.runtime.triton_compat import tl from torch._inductor.virtualized import V from torch.utils._triton import has_triton @@ -19,25 +18,19 @@ TritonTemplate, ) from ..utils import ( - ensure_cute_available, get_gpu_shared_memory, get_num_sms, has_free_symbols, use_aten_gemm_kernels, - use_blackwell_cutedsl_grouped_mm, use_triton_template, ) from .mm_common import ( _is_static_problem, check_supported_striding, - load_kernel_template, persistent_grouped_mm_grid, ) -if ensure_cute_available(): - from torch._inductor.template_heuristics.cutedsl import get_groupgemm_configs - log = logging.getLogger(__name__) aten = torch.ops.aten @@ -520,11 +513,6 @@ def do_mma(a, b, accumulator): source=triton_grouped_mm_source, ) -cutedsl_grouped_mm_template = CuteDSLTemplate( - name="grouped_gemm_cutedsl", - source=load_kernel_template("cutedsl_mm_grouped"), -) - def grouped_mm_args( mat1: TensorBox, @@ -726,44 +714,43 @@ def _tuned_grouped_mm_common( # Checking only for the equality of corresponding dims of # multiplicands here, relying on meta function checks for # everything else. - if len(m1_size) == 2: - if len(m2_size) == 2: - m, k1 = m1_size - k2, _ = m2_size - # pyrefly: ignore [missing-attribute] - g = offs.get_size()[0] - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = True, True - else: - # pyrefly: ignore [missing-attribute] - g1 = offs.layout.size[0] - m, k1 = m1_size - g2, k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = True, False - else: - if len(m2_size) == 2: - # pyrefly: ignore [missing-attribute] - g1 = offs.layout.size[0] - g2, m, k1 = m1_size - k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = False, True - else: - g1, m, k1 = m1_size - g2, k2, _ = m2_size - g = V.graph.sizevars.check_equals_and_simplify(g1, g2) - V.graph.sizevars.check_equals(k1, k2) - a_is_2d, b_is_2d = False, False - if ( is_nonzero and use_triton_template(layout) and can_use_triton_kernel(mat_a, mat_b, offs, bias, scale_result) ): scaled = scale_a is not None + if len(m1_size) == 2: + if len(m2_size) == 2: + m, k1 = m1_size + k2, _ = m2_size + # pyrefly: ignore [missing-attribute] + g = offs.get_size()[0] + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, True + else: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = True, False + else: + if len(m2_size) == 2: + # pyrefly: ignore [missing-attribute] + g1 = offs.layout.size[0] + g2, m, k1 = m1_size + k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, True + else: + g1, m, k1 = m1_size + g2, k2, _ = m2_size + g = V.graph.sizevars.check_equals_and_simplify(g1, g2) + V.graph.sizevars.check_equals(k1, k2) + a_is_2d, b_is_2d = False, False a_is_k_major = mat_a.get_stride()[-1] == 1 b_is_k_major = mat_b.get_stride()[-2] == 1 @@ -801,22 +788,6 @@ def _tuned_grouped_mm_common( **config.kwargs, ) - if use_blackwell_cutedsl_grouped_mm( - mat_a, mat_b, layout, a_is_2d, b_is_2d, offs, bias, scale_result - ): - for config in get_groupgemm_configs(): - kwargs = dict( - ACC_DTYPE="cutlass.Float32", - ) - - cutedsl_grouped_mm_template.maybe_append_choice( - choices, - input_nodes=input_nodes, - layout=layout, - **kwargs, - **asdict(config), - ) - input_gen_fns = { 4: lambda x: create_offsets( x, m1_size, m2_size, offs.get_size() if offs is not None else None diff --git a/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja b/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja deleted file mode 100644 index 989f297c5f80f..0000000000000 --- a/torch/_inductor/kernel/templates/cutedsl_mm_grouped.py.jinja +++ /dev/null @@ -1,333 +0,0 @@ -import functools -from torch._inductor.runtime.runtime_utils import ceildiv -from cutlass.utils import TensorMapUpdateMode -{{gen_defines()}} -# ---- Import GroupedGemm implementation, copied on PyTorch build from Cutlass repository: cutlass/examples/python/CuTeDSL/blackwell/grouped_gemm.py ---- -from torch._inductor.kernel.vendored_templates.cutedsl_grouped_gemm import ( - GroupedGemmKernel, -) - - -# Note about caching: -# Each instantiated CuTeDSL grouped GEMM kernel file generated by Inductor -# maintains its own local caching system. At this stage, all compile-time -# constexprs (e.g., TILE_M, TILE_N, CLUSTER_M/N, USE_2_CTA) and the kernel -# name itself ({{kernel_name}}) are permanently baked into the file, so they -# do not need to be included in any cache key. -# -# The caching mechanism is split into two levels: -# -# 1. prep_cache -# Caches the compiled executor for build_group_ptrs_from_bases(). This -# kernel depends only on the tensor shapes, strides, and dtypes of A/B/C, -# and can therefore be safely reused across runs with different group -# partitioning (`offs`). -# -# 2. gemm_cache -# Caches the compiled Grouped GEMM executor. Its key extends the prep -# cache key with hardware- and grid-specific parameters: -# (prep_cache_key, max_active_clusters, total_num_clusters). -# This is necessary because different `offs` tensors can change the -# per-group problem sizes and thus alter `total_num_clusters`, which in -# turn changes the grid shape and persistent scheduler configuration. -# Kernels compiled for one grid cannot be safely reused for another. -# -# -# Additionally, note the @lru_cache decorator on get_hardware_info(). Empirically, -# hw.get_max_active_clusters() triggers significant MLIR recompilation overhead, -# despite depending only on the GPU type. We cache this function to mitigate -# redundant recompiles even when shape/stride/dtype cache misses force kernel -# regeneration. A follow-up study will investigate the root cause. - -prep_cache = {} -gemm_cache = {} - - -@functools.lru_cache -def get_hardware_info(): - hw = cutlass.utils.HardwareInfo() - sm_count = hw.get_max_active_clusters(1) - max_active_clusters = hw.get_max_active_clusters(CLUSTER_M * CLUSTER_N) - - return (sm_count, max_active_clusters) - - -def get_prep_cache_key(input_a, input_b, output): - """ - Returns a tuple key for caching the preprocessing kernel executor based on kernel name, - shapes, strides, and dtypes of input/output tensors. - """ - return ( - tuple(input_a.shape), - tuple(input_a.stride()), - input_a.dtype, - tuple(input_b.shape), - tuple(input_b.stride()), - input_b.dtype, - tuple(output.shape), - tuple(output.stride()), - output.dtype, - ) - - -def get_gemm_cache_key(prep_cache_key, max_active_clusters, total_num_clusters): - """ - Returns a tuple key for caching the gemm kernel executor by extending the - prep cache key with hardware- and grid-specific parameters. - """ - return ( - prep_cache_key, - max_active_clusters, - total_num_clusters, - ) - - -@cute.kernel -def build_group_ptrs_from_bases_kernel( - base_A_u64: cutlass.Int64, # device addr of input_a (bytes) - base_B_u64: cutlass.Int64, # device addr of input_b (bytes) - base_C_u64: cutlass.Int64, # device addr of Output (bytes) - offs: cute.Tensor, # [G], cutlass.Int32/64 cumulative - K: cutlass.Constexpr, - N: cutlass.Constexpr, - sizeof_element: cutlass.Int32, # bytes - # -------- STRIDES (in ELEMENTS) -------- - stride_A_m_elems: cutlass.Constexpr, # A.stride(0) - stride_A_k_elems: cutlass.Constexpr, # A.stride(1) - stride_B0_elems: cutlass.Constexpr, # B.stride(0) - stride_Bk_elems: cutlass.Constexpr, # B.stride(1) - stride_Bn_elems: cutlass.Constexpr, # B.stride(2) - stride_C_m_elems: cutlass.Constexpr, # C.stride(0) - stride_C_n_elems: cutlass.Constexpr, # C.stride(1) - # -------- OUTPUTS -------- - out_ptrs: cute.Tensor, # [G,3] cutlass.Int64: (A_ptr, B_ptr, C_ptr) - out_problem: cute.Tensor, # [G,4] cutlass.Int32: (m_g, n, k, 1) - out_strides_abc: cute.Tensor, # [G,3,2] cutlass.Int32 [[A_m,A_k],[B_n,B_k],[C_m,C_n]] -): - tidx, _, _ = cute.arch.thread_idx() - g = tidx - - m_beg_i32 = 0 - if g > 0: - m_beg_i32 = offs[g - 1] - m_end_i32 = offs[g] - m_g_i32 = m_end_i32 - m_beg_i32 - - a_byte_off = ( - cutlass.Int64(m_beg_i32) * stride_A_m_elems * cutlass.Int64(sizeof_element) - ) - c_byte_off = ( - cutlass.Int64(m_beg_i32) * stride_C_m_elems * cutlass.Int64(sizeof_element) - ) - b_byte_off = cutlass.Int64(g) * stride_B0_elems * cutlass.Int64(sizeof_element) - - # ---- pointers ---- - out_ptrs[g, 0] = base_A_u64 + a_byte_off - out_ptrs[g, 1] = base_B_u64 + b_byte_off - out_ptrs[g, 2] = base_C_u64 + c_byte_off - - # ---- (m, n, k, 1) ---- - out_problem[g, 0] = m_g_i32 - out_problem[g, 1] = N - out_problem[g, 2] = K - out_problem[g, 3] = cutlass.Int32(1) - - # ---- strides ---- - out_strides_abc[g, 0, 0] = cutlass.Int32(stride_A_m_elems) - out_strides_abc[g, 0, 1] = cutlass.Int32(stride_A_k_elems) - out_strides_abc[g, 1, 0] = cutlass.Int32(stride_Bn_elems) - out_strides_abc[g, 1, 1] = cutlass.Int32(stride_Bk_elems) - out_strides_abc[g, 2, 0] = cutlass.Int32(stride_C_m_elems) - out_strides_abc[g, 2, 1] = cutlass.Int32(stride_C_n_elems) - - -@cute.jit -def launch_build_group_ptrs_from_bases( - base_A_u64: cutlass.Int64, - base_B_u64: cutlass.Int64, - base_C_u64: cutlass.Int64, - offs: cute.Tensor, - G: cutlass.Constexpr, - K: cutlass.Constexpr, - N: cutlass.Constexpr, - sizeof_element: cutlass.Constexpr, - stride_A_m_elems: cutlass.Constexpr, - stride_A_k_elems: cutlass.Constexpr, - stride_B0_elems: cutlass.Constexpr, - stride_Bk_elems: cutlass.Constexpr, - stride_Bn_elems: cutlass.Constexpr, - stride_C_m_elems: cutlass.Constexpr, - stride_C_n_elems: cutlass.Constexpr, - out_ptrs: cute.Tensor, # [G,3] cutlass.Int64 - out_problem: cute.Tensor, # [G,4] cutlass.Int32 - out_strides_abc: cute.Tensor, # [3,2] cutlass.Int32 - stream: cuda.CUstream, -): - build_group_ptrs_from_bases_kernel( - base_A_u64, - base_B_u64, - base_C_u64, - offs, - K, - N, - sizeof_element, - stride_A_m_elems, - stride_A_k_elems, - stride_B0_elems, - stride_Bk_elems, - stride_Bn_elems, - stride_C_m_elems, - stride_C_n_elems, - out_ptrs, - out_problem, - out_strides_abc, - ).launch(grid=(1, 1, 1), block=(G, 1, 1), stream=stream) - - -{{def_kernel("input_a", "input_b", "input_a_offs")}} - stream = cuda.CUstream(stream) - - input_b = input_b.transpose(1, 2) - - sumM, K = input_a.shape - G, N, Kb = input_b.shape - - dev = input_a.device - - base_A_u64 = int(input_a.data_ptr()) - base_B_u64 = int(input_b.data_ptr()) - base_C_u64 = int({{get_output()}}.data_ptr()) - - ptrs_t = torch.empty((G, 3), device=dev, dtype=torch.int64) - probs_t = torch.empty((G, 4), device=dev, dtype=torch.int32) - strides_t = torch.empty((G, 3, 2), device=dev, dtype=torch.int32) - ptrs = from_dlpack(ptrs_t) - probs = from_dlpack(probs_t) - strides = from_dlpack(strides_t) - - prep_cache_key = get_prep_cache_key(input_a, input_b, {{get_output()}}) - prep_executor = prep_cache.get(prep_cache_key) - - if prep_executor is None: - sizeof_element = int(input_a.element_size()) - sA_m, sA_k = map(int, input_a.stride()) - sB_0, sB_n, sB_k = map(int, input_b.stride()) - sC_m, sC_n = map(int, {{get_output()}}.stride()) - - prep_executor = cute.compile( - launch_build_group_ptrs_from_bases, - base_A_u64=base_A_u64, - base_B_u64=base_B_u64, - base_C_u64=base_C_u64, - offs=from_dlpack(input_a_offs), - G=int(G), - K=int(K), - N=int(N), - sizeof_element=sizeof_element, - stride_A_m_elems=sA_m, - stride_A_k_elems=sA_k, - stride_B0_elems=sB_0, - stride_Bk_elems=sB_k, - stride_Bn_elems=sB_n, - stride_C_m_elems=sC_m, - stride_C_n_elems=sC_n, - out_ptrs=ptrs, - out_problem=probs, - out_strides_abc=strides, - stream=stream, - ) - - prep_cache[prep_cache_key] = prep_executor - - prep_executor( - base_A_u64=base_A_u64, - base_B_u64=base_B_u64, - base_C_u64=base_C_u64, - offs=from_dlpack(input_a_offs), - out_ptrs=ptrs, - out_problem=probs, - out_strides_abc=strides, - stream=stream, - ) - - # --- Tensormap workspace per SM --- - num_tensormap_buffers, max_active_clusters = get_hardware_info() - tensormap_shape = ( - num_tensormap_buffers, - GroupedGemmKernel.num_tensormaps, - GroupedGemmKernel.bytes_per_tensormap // 8, - ) - tensormap_workspace_t = torch.empty(tensormap_shape, device=dev, dtype=torch.int64) - tensormap_workspace = from_dlpack(tensormap_workspace_t) - - # --- Total clusters --- - def compute_total_num_clusters( - problem_sizes_mnkl, - cluster_tile_shape_mn, - ): - total_num_clusters = 0 - for m, n, _, _ in problem_sizes_mnkl: - num_clusters_mn = tuple( - ceildiv(x, y) for x, y in zip((m, n), cluster_tile_shape_mn) - ) - total_num_clusters += functools.reduce(lambda x, y: x * y, num_clusters_mn) - return total_num_clusters - - # Compute cluster tile shape - def compute_cluster_tile_shape( - mma_tiler_mn, - cluster_shape_mn, - use_2cta_instrs, - ): - cta_tile_shape_mn = list(mma_tiler_mn) - if use_2cta_instrs: - cta_tile_shape_mn[0] = cta_tile_shape_mn[0] // 2 - return tuple(x * y for x, y in zip(cta_tile_shape_mn, cluster_shape_mn)) - - cluster_tile_shape_mn = compute_cluster_tile_shape( - (TILE_M, TILE_N), (CLUSTER_M, CLUSTER_N), bool(USE_2_CTA) - ) - - total_num_clusters = int(compute_total_num_clusters(probs_t, cluster_tile_shape_mn)) - - gemm_cache_key = get_gemm_cache_key( - prep_cache_key, max_active_clusters, total_num_clusters - ) - gemm_executor = gemm_cache.get(gemm_cache_key) - - if gemm_executor is None: - grouped_gemm = GroupedGemmKernel( - acc_dtype=ACC_DTYPE, - use_2cta_instrs=USE_2_CTA, - mma_tiler_mn=(TILE_M, TILE_N), - cluster_shape_mn=(CLUSTER_M, CLUSTER_N), - tensormap_update_mode=TENSORMAP_UPDATE_MODE, - ) - - gemm_executor = cute.compile( - grouped_gemm, - from_dlpack(input_a.unsqueeze(-1), assumed_align=16), - from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), - from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), - G, - probs, - strides, - ptrs, - total_num_clusters, - tensormap_workspace, - max_active_clusters, - stream, - ) - - gemm_cache[gemm_cache_key] = gemm_executor - - gemm_executor( - from_dlpack(input_a.unsqueeze(-1), assumed_align=16), - from_dlpack(input_b[0].unsqueeze(-1), assumed_align=16), - from_dlpack({{get_output()}}.unsqueeze(-1), assumed_align=16), - probs, - strides, - ptrs, - tensormap_workspace, - stream, - ) diff --git a/torch/_inductor/template_heuristics/cutedsl.py b/torch/_inductor/template_heuristics/cutedsl.py deleted file mode 100644 index db337b9d8a271..0000000000000 --- a/torch/_inductor/template_heuristics/cutedsl.py +++ /dev/null @@ -1,141 +0,0 @@ -from dataclasses import dataclass -from enum import auto, Enum -from itertools import product - -import torch._inductor.config as config - - -class TensorMapUpdateMode(Enum): - """Enum mirroring cutlass.utils.TensorMapUpdateMode to decouple this file from a cutlass dependency.""" - - SMEM = auto() - GMEM = auto() - - -@dataclass(frozen=True) -class CuTeGemmConfig: - TILE_M: int = 128 - TILE_N: int = 192 - CLUSTER_M: int = 2 - CLUSTER_N: int = 1 - USE_2_CTA: bool = False - TENSORMAP_UPDATE_MODE: TensorMapUpdateMode = TensorMapUpdateMode.SMEM - - -def get_exhaustive_groupgemm_configs() -> list[CuTeGemmConfig]: - """ - Returns the exhaustive configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. - For information regarding valid config sets, see: - https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/grouped_gemm.py - """ - - # Tile_n is always the same regardless of 2cta - tile_n_vals = [32, 64, 96, 128, 160, 192, 224, 256] - - # Valid clusters - clusters_no_2cta = [ - (1, 1), - (1, 2), - (1, 4), - (1, 8), - (1, 16), - (2, 1), - (2, 2), - (2, 4), - (2, 8), - (4, 1), - (4, 2), - (4, 4), - (8, 1), - (8, 2), - (16, 1), - ] - clusters_2cta = [ - (2, 1), - (2, 2), - (2, 4), - (2, 8), - (4, 1), - (4, 2), - (4, 4), - (8, 1), - (8, 2), - (16, 1), - ] - - configs: list[CuTeGemmConfig] = [] - - for use_2cta, cluster_set, tile_m_range in [ - (False, clusters_no_2cta, [64, 128]), - (True, clusters_2cta, [128, 256]), - ]: - for tensormap_update_mode, tile_m, tile_n, (cluster_m, cluster_n) in product( - [TensorMapUpdateMode.SMEM, TensorMapUpdateMode.GMEM], - tile_m_range, - tile_n_vals, - cluster_set, - ): - configs.append( - CuTeGemmConfig( - tile_m, - tile_n, - cluster_m, - cluster_n, - USE_2_CTA=use_2cta, - TENSORMAP_UPDATE_MODE=tensormap_update_mode, - ) - ) - - return configs - - -def get_default_groupgemm_configs() -> list[CuTeGemmConfig]: - """ - Returns the default configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. - """ - - config_tuples = [ - (128, 256, 2, 1, False, TensorMapUpdateMode.SMEM), - (256, 160, 2, 1, True, TensorMapUpdateMode.GMEM), - (256, 256, 2, 1, True, TensorMapUpdateMode.GMEM), - (64, 32, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 256, 1, 2, False, TensorMapUpdateMode.SMEM), - (128, 256, 1, 2, False, TensorMapUpdateMode.SMEM), - (256, 256, 2, 2, True, TensorMapUpdateMode.GMEM), - (128, 256, 1, 2, False, TensorMapUpdateMode.GMEM), - (64, 32, 1, 1, False, TensorMapUpdateMode.SMEM), - (256, 256, 2, 1, True, TensorMapUpdateMode.SMEM), - (128, 256, 1, 1, False, TensorMapUpdateMode.GMEM), - (256, 256, 8, 1, True, TensorMapUpdateMode.GMEM), - (64, 32, 1, 2, False, TensorMapUpdateMode.SMEM), - (256, 192, 2, 1, True, TensorMapUpdateMode.GMEM), - (256, 256, 2, 2, True, TensorMapUpdateMode.SMEM), - (128, 96, 1, 2, False, TensorMapUpdateMode.SMEM), - (64, 192, 1, 1, False, TensorMapUpdateMode.SMEM), - (64, 64, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 192, 1, 1, False, TensorMapUpdateMode.GMEM), - (128, 64, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 160, 1, 1, False, TensorMapUpdateMode.GMEM), - (64, 256, 1, 1, False, TensorMapUpdateMode.GMEM), - ] - - return [CuTeGemmConfig(*args) for args in config_tuples] - - -def get_groupgemm_configs() -> list[CuTeGemmConfig]: - """ - Returns the configuration set for the Blackwell CuTeDSL Grouped GEMM kernel. - - Note: CuTeDSL autotuning is still experimental — enabling it may trigger kernel launch failures - or unstable results. By default, autotuning is disabled and we return only - a single baseline config. - """ - if ( - config.cutedsl_enable_autotuning - and config.max_autotune_gemm_search_space == "EXHAUSTIVE" - ): - return get_exhaustive_groupgemm_configs() - elif config.cutedsl_enable_autotuning: - return get_default_groupgemm_configs() - else: - return [get_default_groupgemm_configs()[0]] diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index efdb4a9a58912..3f8652882af79 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1975,84 +1975,6 @@ def use_triton_blackwell_tma_template( return has_triton_tensor_descriptor_host_tma() and is_datacenter_blackwell_arch() -@functools.lru_cache(maxsize=1) -def ensure_cute_available() -> bool: - """Check if CuTeDSL is importable; cache the result for reuse. - - Call ensure_cute_available.cache_clear() after installing CuTeDSL - in the same interpreter to retry the import. - """ - try: - return importlib.util.find_spec("cutlass.cute") is not None - except ImportError: - return False - - -def use_blackwell_cutedsl_grouped_mm( - mat_a: Any, - mat_b: Any, - layout: Layout, - a_is_2d: bool, - b_is_2d: bool, - offs: Optional[Any], - bias: Optional[Any], - scale_result: Optional[Any], -) -> bool: - """ - Returns True if we can use the blackwell kernel for grouped mm. - Required conditions: - 1. CuTeDSL backend is enabled - 2. CuTeDSL is available - 3. We are on a blackwell arch - 4. The dtype is bf16 - 5. Max autotune or max autotune gemm is enabled - 6. A, B, and the output are 16B aligned - 7. We are not using dynamic shapes - 8. A is 2d - 9. B is 3d - 10. Offsets are provided - 11. Bias and Scale are not provided - """ - if not ensure_cute_available(): - return False - - if not _use_autotune_backend("CUTEDSL"): - return False - - from .codegen.cuda.cuda_env import is_datacenter_blackwell_arch - - if not is_gpu(layout.device.type): - return False - - if not is_datacenter_blackwell_arch(): - return False - - layout_dtypes = [torch.bfloat16] - if not _use_template_for_gpu(layout, layout_dtypes): - return False - - if not (config.max_autotune or config.max_autotune_gemm): - return False - - # Checks for 16B ptr and stride alignment - if not can_use_tma(mat_a, mat_b, output_layout=layout): - return False - - if any(is_dynamic(x) for x in [mat_a, mat_b]): - return False - - if not a_is_2d or b_is_2d: - return False - - if offs is None: - return False - - if bias is not None or scale_result is not None: - return False - - return True - - def use_cutlass_template(layout: Layout, m: int, n: int, k: int) -> bool: from .virtualized import V From 59563dfe56a086a4a95025f0ccfe373bc1fd3759 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Tue, 4 Nov 2025 15:36:00 -0800 Subject: [PATCH 060/130] Refactor out headeronly ArrayRef (#164991) Differential Revision: [D85091961](https://our.internmc.facebook.com/intern/diff/D85091961) Pull Request resolved: https://github.com/pytorch/pytorch/pull/164991 Approved by: https://github.com/swolchok --- c10/util/ArrayRef.h | 238 ++++++----------- test/cpp/aoti_abi_check/CMakeLists.txt | 1 + .../test_headeronlyarrayref.cpp | 52 ++++ torch/header_only_apis.txt | 3 + torch/headeronly/util/HeaderOnlyArrayRef.h | 247 ++++++++++++++++++ 5 files changed, 388 insertions(+), 153 deletions(-) create mode 100644 test/cpp/aoti_abi_check/test_headeronlyarrayref.cpp create mode 100644 torch/headeronly/util/HeaderOnlyArrayRef.h diff --git a/c10/util/ArrayRef.h b/c10/util/ArrayRef.h index 64605f5153595..1311867ef797e 100644 --- a/c10/util/ArrayRef.h +++ b/c10/util/ArrayRef.h @@ -18,6 +18,7 @@ #include #include #include +#include #include #include @@ -40,200 +41,99 @@ namespace c10 { /// /// This is intended to be trivially copyable, so it should be passed by /// value. +/// +/// NOTE: We have refactored out the headeronly parts of the ArrayRef struct +/// into HeaderOnlyArrayRef. As adding `virtual` would change the performance of +/// the underlying constexpr calls, we rely on apparent-type dispatch for +/// inheritance. This should be fine because their memory format is the same, +/// and it is never incorrect for ArrayRef to call HeaderOnlyArrayRef methods. +/// However, you should prefer to use ArrayRef when possible, because its use +/// of TORCH_CHECK will lead to better user-facing error messages. template -class ArrayRef final { +class ArrayRef final : public HeaderOnlyArrayRef { public: - using iterator = const T*; - using const_iterator = const T*; - using size_type = size_t; - using value_type = T; - - using reverse_iterator = std::reverse_iterator; - - private: - /// The start of the array, in an external buffer. - const T* Data; - - /// The number of elements. - size_type Length; - - void debugCheckNullptrInvariant() { - TORCH_INTERNAL_ASSERT_DEBUG_ONLY( - Data != nullptr || Length == 0, - "created ArrayRef with nullptr and non-zero length! std::optional relies on this being illegal"); - } - - public: - /// @name Constructors + /// @name Constructors, all inherited from HeaderOnlyArrayRef except for + /// SmallVector. As inherited constructors won't work with class template + /// argument deduction (CTAD) until C++23, we add deduction guides after + /// the class definition to enable CTAD. /// @{ - /// Construct an empty ArrayRef. - /* implicit */ constexpr ArrayRef() : Data(nullptr), Length(0) {} - - /// Construct an ArrayRef from a single element. - // TODO Make this explicit - constexpr ArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {} - - /// Construct an ArrayRef from a pointer and length. - constexpr ArrayRef(const T* data, size_t length) - : Data(data), Length(length) { - debugCheckNullptrInvariant(); - } - - /// Construct an ArrayRef from a range. - constexpr ArrayRef(const T* begin, const T* end) - : Data(begin), Length(end - begin) { - debugCheckNullptrInvariant(); - } + using HeaderOnlyArrayRef::HeaderOnlyArrayRef; /// Construct an ArrayRef from a SmallVector. This is templated in order to /// avoid instantiating SmallVectorTemplateCommon whenever we /// copy-construct an ArrayRef. + /// NOTE: this is the only constructor that is not inherited from + /// HeaderOnlyArrayRef. template /* implicit */ ArrayRef(const SmallVectorTemplateCommon& Vec) - : Data(Vec.data()), Length(Vec.size()) { - debugCheckNullptrInvariant(); - } - - template < - typename Container, - typename U = decltype(std::declval().data()), - typename = std::enable_if_t< - (std::is_same_v || std::is_same_v)>> - /* implicit */ ArrayRef(const Container& container) - : Data(container.data()), Length(container.size()) { - debugCheckNullptrInvariant(); - } - - /// Construct an ArrayRef from a std::vector. - // The enable_if stuff here makes sure that this isn't used for - // std::vector, because ArrayRef can't work on a std::vector - // bitfield. - template - /* implicit */ ArrayRef(const std::vector& Vec) - : Data(Vec.data()), Length(Vec.size()) { - static_assert( - !std::is_same_v, - "ArrayRef cannot be constructed from a std::vector bitfield."); - } - - /// Construct an ArrayRef from a std::array - template - /* implicit */ constexpr ArrayRef(const std::array& Arr) - : Data(Arr.data()), Length(N) {} - - /// Construct an ArrayRef from a C array. - template - // NOLINTNEXTLINE(*c-arrays*) - /* implicit */ constexpr ArrayRef(const T (&Arr)[N]) : Data(Arr), Length(N) {} - - /// Construct an ArrayRef from a std::initializer_list. - /* implicit */ constexpr ArrayRef(const std::initializer_list& Vec) - : Data( - std::begin(Vec) == std::end(Vec) ? static_cast(nullptr) - : std::begin(Vec)), - Length(Vec.size()) {} + : HeaderOnlyArrayRef(Vec.data(), Vec.size()) {} /// @} - /// @name Simple Operations + /// @name Simple Operations, mostly inherited from HeaderOnlyArrayRef /// @{ - constexpr iterator begin() const { - return Data; - } - constexpr iterator end() const { - return Data + Length; - } - - // These are actually the same as iterator, since ArrayRef only - // gives you const iterators. - constexpr const_iterator cbegin() const { - return Data; - } - constexpr const_iterator cend() const { - return Data + Length; - } - - constexpr reverse_iterator rbegin() const { - return reverse_iterator(end()); - } - constexpr reverse_iterator rend() const { - return reverse_iterator(begin()); - } - - /// Check if all elements in the array satisfy the given expression - constexpr bool allMatch(const std::function& pred) const { - return std::all_of(cbegin(), cend(), pred); - } - - /// empty - Check if the array is empty. - constexpr bool empty() const { - return Length == 0; - } - - constexpr const T* data() const { - return Data; - } - - /// size - Get the array size. - constexpr size_t size() const { - return Length; - } - /// front - Get the first element. + /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of + /// STD_TORCH_CHECK constexpr const T& front() const { TORCH_CHECK( - !empty(), "ArrayRef: attempted to access front() of empty list"); - return Data[0]; + !this->empty(), "ArrayRef: attempted to access front() of empty list"); + return this->Data[0]; } /// back - Get the last element. + /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of + /// STD_TORCH_CHECK constexpr const T& back() const { - TORCH_CHECK(!empty(), "ArrayRef: attempted to access back() of empty list"); - return Data[Length - 1]; - } - - /// equals - Check for element-wise equality. - constexpr bool equals(ArrayRef RHS) const { - return Length == RHS.Length && std::equal(begin(), end(), RHS.begin()); + TORCH_CHECK( + !this->empty(), "ArrayRef: attempted to access back() of empty list"); + return this->Data[this->Length - 1]; } /// slice(n, m) - Take M elements of the array starting at element N + /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of + /// STD_TORCH_CHECK constexpr ArrayRef slice(size_t N, size_t M) const { TORCH_CHECK( - N + M <= size(), + N + M <= this->size(), "ArrayRef: invalid slice, N = ", N, "; M = ", M, "; size = ", - size()); - return ArrayRef(data() + N, M); + this->size()); + return ArrayRef(this->data() + N, M); } /// slice(n) - Chop off the first N elements of the array. + /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of + /// STD_TORCH_CHECK constexpr ArrayRef slice(size_t N) const { TORCH_CHECK( - N <= size(), "ArrayRef: invalid slice, N = ", N, "; size = ", size()); - return slice(N, size() - N); + N <= this->size(), + "ArrayRef: invalid slice, N = ", + N, + "; size = ", + this->size()); + return slice(N, this->size() - N); // should this slice be this->slice? } /// @} /// @name Operator Overloads /// @{ - constexpr const T& operator[](size_t Index) const { - return Data[Index]; - } /// Vector compatibility + /// We deviate from HeaderOnlyArrayRef by using TORCH_CHECK instead of + /// STD_TORCH_CHECK constexpr const T& at(size_t Index) const { TORCH_CHECK( - Index < Length, + Index < this->Length, "ArrayRef: invalid index Index = ", Index, "; Length = ", - Length); - return Data[Index]; + this->Length); + return this->Data[Index]; } /// Disallow accidental assignment from a temporary. @@ -253,16 +153,48 @@ class ArrayRef final { std::enable_if_t, ArrayRef>& operator=( std::initializer_list) = delete; - /// @} - /// @name Expensive Operations - /// @{ - std::vector vec() const { - return std::vector(Data, Data + Length); - } - /// @} }; +/// Deduction guides for ArrayRef to support CTAD with inherited constructors +/// These mirror the constructors inherited from HeaderOnlyArrayRef +/// @{ + +// Single element constructor +template +ArrayRef(const T&) -> ArrayRef; + +// Pointer and length constructor +template +ArrayRef(const T*, size_t) -> ArrayRef; + +// Range constructor (begin, end) +template +ArrayRef(const T*, const T*) -> ArrayRef; + +// Generic container constructor (anything with .data() and .size()) +template +ArrayRef(const Container&) -> ArrayRef< + std::remove_pointer_t().data())>>; + +// std::vector constructor +template +ArrayRef(const std::vector&) -> ArrayRef; + +// std::array constructor +template +ArrayRef(const std::array&) -> ArrayRef; + +// C array constructor +template +ArrayRef(const T (&)[N]) -> ArrayRef; + +// std::initializer_list constructor +template +ArrayRef(const std::initializer_list&) -> ArrayRef; + +/// @} + template std::ostream& operator<<(std::ostream& out, ArrayRef list) { int i = 0; diff --git a/test/cpp/aoti_abi_check/CMakeLists.txt b/test/cpp/aoti_abi_check/CMakeLists.txt index 4763621f60394..1695e65cb4a1b 100644 --- a/test/cpp/aoti_abi_check/CMakeLists.txt +++ b/test/cpp/aoti_abi_check/CMakeLists.txt @@ -12,6 +12,7 @@ set(AOTI_ABI_CHECK_TEST_SRCS ${AOTI_ABI_CHECK_TEST_ROOT}/test_devicetype.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_dtype.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_exception.cpp + ${AOTI_ABI_CHECK_TEST_ROOT}/test_headeronlyarrayref.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_macros.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_math.cpp ${AOTI_ABI_CHECK_TEST_ROOT}/test_rand.cpp diff --git a/test/cpp/aoti_abi_check/test_headeronlyarrayref.cpp b/test/cpp/aoti_abi_check/test_headeronlyarrayref.cpp new file mode 100644 index 0000000000000..184c0ade8360e --- /dev/null +++ b/test/cpp/aoti_abi_check/test_headeronlyarrayref.cpp @@ -0,0 +1,52 @@ +#include + +#include + +#include + +using torch::headeronly::HeaderOnlyArrayRef; + +TEST(TestHeaderOnlyArrayRef, TestEmpty) { + HeaderOnlyArrayRef arr; + ASSERT_TRUE(arr.empty()); +} + +TEST(TestHeaderOnlyArrayRef, TestSingleton) { + float val = 5.0f; + HeaderOnlyArrayRef arr(val); + ASSERT_FALSE(arr.empty()); + EXPECT_EQ(arr.size(), 1); + EXPECT_EQ(arr[0], val); +} + +TEST(TestHeaderOnlyArrayRef, TestAPIs) { + std::vector vec = {1, 2, 3, 4, 5, 6, 7}; + HeaderOnlyArrayRef arr(vec); + ASSERT_FALSE(arr.empty()); + EXPECT_EQ(arr.size(), 7); + for (size_t i = 0; i < arr.size(); i++) { + EXPECT_EQ(arr[i], i + 1); + EXPECT_EQ(arr.at(i), i + 1); + } + EXPECT_EQ(arr.front(), 1); + EXPECT_EQ(arr.back(), 7); + ASSERT_TRUE(arr.slice(3, 4).equals(arr.slice(3))); +} + +TEST(TestHeaderOnlyArrayRef, TestFromInitializerList) { + std::vector vec = {1, 2, 3, 4, 5, 6, 7}; + HeaderOnlyArrayRef arr({1, 2, 3, 4, 5, 6, 7}); + auto res_vec = arr.vec(); + for (size_t i = 0; i < vec.size(); i++) { + EXPECT_EQ(vec[i], res_vec[i]); + } +} + +TEST(TestHeaderOnlyArrayRef, TestFromRange) { + std::vector vec = {1, 2, 3, 4, 5, 6, 7}; + HeaderOnlyArrayRef arr(vec.data() + 3, vec.data() + 7); + auto res_vec = arr.vec(); + for (size_t i = 0; i < res_vec.size(); i++) { + EXPECT_EQ(vec[i + 3], res_vec[i]); + } +} diff --git a/torch/header_only_apis.txt b/torch/header_only_apis.txt index 70165a7493e59..c0cd5d9a2c689 100644 --- a/torch/header_only_apis.txt +++ b/torch/header_only_apis.txt @@ -42,6 +42,9 @@ fp16_ieee_to_fp32_value # fp32_from_bits called from fp16_ieee_to_fp32_value # fp32_to_bits called from fp16_ieee_from_fp32_value +# torch/headeronly/util/HeaderOnlyArrayRef.h +HeaderOnlyArrayRef + # c10/util/complex.h, torch/headeronly/util/complex.h complex diff --git a/torch/headeronly/util/HeaderOnlyArrayRef.h b/torch/headeronly/util/HeaderOnlyArrayRef.h new file mode 100644 index 0000000000000..2387578ab8f5f --- /dev/null +++ b/torch/headeronly/util/HeaderOnlyArrayRef.h @@ -0,0 +1,247 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { + +/// HeaderOnlyArrayRef - A subset of ArrayRef that is implemented only +/// in headers. This will be a base class from which ArrayRef inherits, so that +/// we can keep much of the implementation shared. +/// +/// [HeaderOnlyArrayRef vs ArrayRef note] +/// As HeaderOnlyArrayRef is a subset of ArrayRef, it has slightly less +/// functionality than ArrayRef. We document the minor differences below: +/// 1. ArrayRef has an extra convenience constructor for SmallVector. +/// 2. ArrayRef uses TORCH_CHECK. HeaderOnlyArrayRef uses header-only +/// STD_TORCH_CHECK, which will output a std::runtime_error vs a +/// c10::Error. Consequently, you should use ArrayRef when possible +/// and HeaderOnlyArrayRef only when necessary to support headeronly code. +/// In all other aspects, HeaderOnlyArrayRef is identical to ArrayRef, with the +/// positive benefit of being header-only and thus independent of libtorch.so. +template +class HeaderOnlyArrayRef { + public: + using iterator = const T*; + using const_iterator = const T*; + using size_type = size_t; + using value_type = T; + + using reverse_iterator = std::reverse_iterator; + + protected: + /// The start of the array, in an external buffer. + const T* Data; + + /// The number of elements. + size_type Length; + + public: + /// @name Constructors + /// @{ + + /// Construct an empty HeaderOnlyArrayRef. + /* implicit */ constexpr HeaderOnlyArrayRef() : Data(nullptr), Length(0) {} + + /// Construct a HeaderOnlyArrayRef from a single element. + // TODO Make this explicit + constexpr HeaderOnlyArrayRef(const T& OneElt) : Data(&OneElt), Length(1) {} + + /// Construct a HeaderOnlyArrayRef from a pointer and length. + constexpr HeaderOnlyArrayRef(const T* data, size_t length) + : Data(data), Length(length) {} + + /// Construct a HeaderOnlyArrayRef from a range. + constexpr HeaderOnlyArrayRef(const T* begin, const T* end) + : Data(begin), Length(end - begin) {} + + template < + typename Container, + typename U = decltype(std::declval().data()), + typename = std::enable_if_t< + (std::is_same_v || std::is_same_v)>> + /* implicit */ HeaderOnlyArrayRef(const Container& container) + : Data(container.data()), Length(container.size()) {} + + /// Construct a HeaderOnlyArrayRef from a std::vector. + // The enable_if stuff here makes sure that this isn't used for + // std::vector, because ArrayRef can't work on a std::vector + // bitfield. + template + /* implicit */ HeaderOnlyArrayRef(const std::vector& Vec) + : Data(Vec.data()), Length(Vec.size()) { + static_assert( + !std::is_same_v, + "HeaderOnlyArrayRef cannot be constructed from a std::vector bitfield."); + } + + /// Construct a HeaderOnlyArrayRef from a std::array + template + /* implicit */ constexpr HeaderOnlyArrayRef(const std::array& Arr) + : Data(Arr.data()), Length(N) {} + + /// Construct a HeaderOnlyArrayRef from a C array. + template + // NOLINTNEXTLINE(*c-arrays*) + /* implicit */ constexpr HeaderOnlyArrayRef(const T (&Arr)[N]) + : Data(Arr), Length(N) {} + + /// Construct a HeaderOnlyArrayRef from a std::initializer_list. + /* implicit */ constexpr HeaderOnlyArrayRef( + const std::initializer_list& Vec) + : Data( + std::begin(Vec) == std::end(Vec) ? static_cast(nullptr) + : std::begin(Vec)), + Length(Vec.size()) {} + + /// @} + /// @name Simple Operations + /// @{ + + constexpr iterator begin() const { + return this->Data; + } + constexpr iterator end() const { + return this->Data + this->Length; + } + + // These are actually the same as iterator, since ArrayRef only + // gives you const iterators. + constexpr const_iterator cbegin() const { + return this->Data; + } + constexpr const_iterator cend() const { + return this->Data + this->Length; + } + + constexpr reverse_iterator rbegin() const { + return reverse_iterator(end()); + } + constexpr reverse_iterator rend() const { + return reverse_iterator(begin()); + } + + /// Check if all elements in the array satisfy the given expression + constexpr bool allMatch(const std::function& pred) const { + return std::all_of(cbegin(), cend(), pred); + } + + /// empty - Check if the array is empty. + constexpr bool empty() const { + return this->Length == 0; + } + + constexpr const T* data() const { + return this->Data; + } + + /// size - Get the array size. + constexpr size_t size() const { + return this->Length; + } + + /// front - Get the first element. + constexpr const T& front() const { + STD_TORCH_CHECK( + !this->empty(), + "HeaderOnlyArrayRef: attempted to access front() of empty list"); + return this->Data[0]; + } + + /// back - Get the last element. + constexpr const T& back() const { + STD_TORCH_CHECK( + !this->empty(), + "HeaderOnlyArrayRef: attempted to access back() of empty list"); + return this->Data[this->Length - 1]; + } + + /// equals - Check for element-wise equality. + constexpr bool equals(HeaderOnlyArrayRef RHS) const { + return this->Length == RHS.Length && + std::equal(begin(), end(), RHS.begin()); + } + + /// slice(n, m) - Take M elements of the array starting at element N + constexpr HeaderOnlyArrayRef slice(size_t N, size_t M) const { + STD_TORCH_CHECK( + N + M <= this->size(), + "HeaderOnlyArrayRef: invalid slice, N = ", + N, + "; M = ", + M, + "; size = ", + this->size()); + return HeaderOnlyArrayRef(this->data() + N, M); + } + + /// slice(n) - Chop off the first N elements of the array. + constexpr HeaderOnlyArrayRef slice(size_t N) const { + STD_TORCH_CHECK( + N <= this->size(), + "HeaderOnlyArrayRef: invalid slice, N = ", + N, + "; size = ", + this->size()); + return slice(N, this->size() - N); + } + + /// @} + /// @name Operator Overloads + /// @{ + constexpr const T& operator[](size_t Index) const { + return this->Data[Index]; + } + + /// Vector compatibility + constexpr const T& at(size_t Index) const { + STD_TORCH_CHECK( + Index < this->Length, + "HeaderOnlyArrayRef: invalid index Index = ", + Index, + "; Length = ", + this->Length); + return this->Data[Index]; + } + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t, HeaderOnlyArrayRef>& operator=( + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + U&& Temporary) = delete; + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t, HeaderOnlyArrayRef>& operator=( + std::initializer_list) = delete; + + /// @} + /// @name Expensive Operations + /// @{ + std::vector vec() const { + return std::vector(this->Data, this->Data + this->Length); + } + + /// @} +}; + +} // namespace c10 + +namespace torch::headeronly { +using c10::HeaderOnlyArrayRef; +using IntHeaderOnlyArrayRef = HeaderOnlyArrayRef; +} // namespace torch::headeronly From 7a6ff88196e12f9eebc8769d5fcbb8225a047e28 Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Tue, 4 Nov 2025 15:36:01 -0800 Subject: [PATCH 061/130] Widen ops support to take in IntHOArrayRef vs only std::vec (#165152) Pull Request resolved: https://github.com/pytorch/pytorch/pull/165152 Approved by: https://github.com/mikaylagawarecki ghstack dependencies: #164991 --- .../libtorch_agnostic/csrc/kernel.cpp | 12 ++++++------ torch/csrc/stable/ops.h | 17 +++++++---------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 58c812b08cccb..87aaa46e64c95 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -311,10 +311,9 @@ void boxed_fill_infinity( } Tensor my_pad(Tensor t) { - std::vector padding = {1, 2, 2, 1}; std::string mode = "constant"; double value = 0.0; - return pad(t, padding, mode, value); + return pad(t, {1, 2, 2, 1}, mode, value); } void boxed_my_pad( @@ -342,6 +341,9 @@ void boxed_my_narrow( } Tensor my_new_empty_dtype_variant(Tensor t) { + // Still using a std::vector below even though people can just pass in an + // initializer list (which will be implicitly converted to an HeaderOnlyArrayRef) + // directly. std::vector sizes = {2, 5}; auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16); return new_empty(t, sizes, dtype); @@ -353,9 +355,8 @@ void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, ui } Tensor my_new_zeros_dtype_variant(Tensor t) { - std::vector sizes = {2, 5}; auto dtype = std::make_optional(at::ScalarType::Float); - return new_zeros(t, sizes, dtype); + return new_zeros(t, {2, 5}, dtype); } void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { @@ -429,8 +430,7 @@ void boxed_my_amax(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) } Tensor my_amax_vec(Tensor t) { - std::vector v = {0,1}; - return amax(t, v, false); + return amax(t, {0,1}, false); } void boxed_my_amax_vec(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) { diff --git a/torch/csrc/stable/ops.h b/torch/csrc/stable/ops.h index 5c2959e69ae0b..d5fbba9fbbfd7 100644 --- a/torch/csrc/stable/ops.h +++ b/torch/csrc/stable/ops.h @@ -5,13 +5,13 @@ #include #include #include -#include #include #include #include #include #include +#include HIDDEN_NAMESPACE_BEGIN(torch, stable) @@ -68,7 +68,7 @@ inline torch::stable::Tensor narrow( // only dtype information. inline torch::stable::Tensor new_empty( const torch::stable::Tensor& self, - std::vector size, + torch::headeronly::IntHeaderOnlyArrayRef size, std::optional dtype = std::nullopt) { int32_t device_type; TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type)); @@ -107,7 +107,7 @@ inline torch::stable::Tensor new_empty( // only dtype information. inline torch::stable::Tensor new_zeros( const torch::stable::Tensor& self, - std::vector size, + torch::headeronly::IntHeaderOnlyArrayRef size, std::optional dtype = std::nullopt) { int32_t device_type; TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type)); @@ -144,12 +144,10 @@ inline torch::stable::Tensor new_zeros( // We expect this to be the stable version of the pad.default op. // pad.default takes in a SymInt[] as the pad argument however pad is typed as -// use std::vector because -// (1) IntArrayRef is not yet header-only -// (2) SymInt is not yet header-only +// torch::headeronly::IntHeaderOnlyArrayRef as SymInt is not yet header-only. inline torch::stable::Tensor pad( const torch::stable::Tensor& self, - std::vector pad, + torch::headeronly::IntHeaderOnlyArrayRef pad, const std::string& mode = "constant", double value = 0.0) { AtenTensorHandle ret0 = nullptr; @@ -181,11 +179,10 @@ inline torch::stable::Tensor amax( // This function is an overload to compute the maximum value along each slice of // `self` reducing over all the dimensions in the vector `dims`. The // amax.default op takes in a SymInt[] as the dims argument, however dims is -// typed as use std::vector here because (1) IntArrayRef is not yet -// header-only (2) SymInt is not yet header-only +// typed as use IntHeaderOnlyArrayRef here because SymInt is not yet header-only inline torch::stable::Tensor amax( const torch::stable::Tensor& self, - std::vector dims, + torch::headeronly::IntHeaderOnlyArrayRef dims, bool keepdim = false) { AtenTensorHandle ret = nullptr; TORCH_ERROR_CODE_CHECK(aoti_torch_aten_amax( From d2d13bf62dc848348196f91d3f104f84ac1e47e7 Mon Sep 17 00:00:00 2001 From: eellison Date: Wed, 5 Nov 2025 05:54:27 -0800 Subject: [PATCH 062/130] Invert unary read and write for fusion (#161404) For [this repro](https://gist.github.com/eellison/75a99616a0fcca0436316bbfd8987fae) enables fusion of `to_blocked` with the prior `to_mx` calculation, so that there is only a single kernel per tensor, resulting in a 10% speedup of the non conversion code (need to update my local devserver to 12.9 to time the matmul as well). The `to_mx` kernel has a contiguous write: ```Py op6_op7: FusedSchedulerNode(SchedulerNode,SchedulerNode) op6_op7.writes = [MemoryDep('buf6', c0, {c0: 2097152}), MemoryDep('buf7', c0, {c0: 67108864})] op6_op7.unmet_dependencies = [] op6_op7.met_dependencies = [MemoryDep('arg1_1', c0, {c0: 67108864})] op6_op7.outputs = [ buf6: ComputedBuffer buf6.layout = FixedLayout('cuda:0', torch.float32, size=[8192, 256], stride=[256, 1]) buf6.users = [ NodeUser(node=SchedulerNode(name='op7'), can_inplace=False, is_weak=False), NodeUser(node=SchedulerNode(name='op9'), can_inplace=False, is_weak=False), ] buf7: ComputedBuffer buf7.layout = FixedLayout('cuda:0', torch.float8_e4m3fn, size=[8192, 256, 32], stride=[8192, 32, 1]) buf7.users = [NodeUser(node=ExternKernelSchedulerNode(name='op10'), can_inplace=False, is_weak=False)] ] ``` While the `to_blocked` has a single discontiguous read and a single contiguous write. ```Py op9: SchedulerNode(ComputedBuffer) op9.writes = [MemoryDep('buf9', c0, {c0: 2097152})] op9.unmet_dependencies = [ MemoryDep('buf6', 32768*((c0//32768)) + 8192*(((ModularIndexing(c0, 1, 16))//4)) + 256*(ModularIndexing(c0, 16, 32)) + 4*(ModularIndexing(c0, 512, 64)) + (ModularIndexing(ModularIndexing(c0, 1, 16), 1, 4)), {c0: 2097152})] op9.met_dependencies = [] op9.outputs = [ buf9: ComputedBuffer buf9.layout = FixedLayout('cuda:0', torch.float8_e8m0fnu, size=[2097152], stride=[1]) buf9.users = [NodeUser(node=ExternKernelSchedulerNode(name='op10'), can_inplace=False, is_weak=False)] ] ``` To enable fusion, we invert the read, giving op9 and contiguous read and discontiguous write. More explanation here: https://gist.github.com/eellison/6f9f4a7ec10a860150b15b719f9285a9 [Tlparse with this optimization](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/eellison/custom/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000). [Tlparse without this optimization](https://manifold.edge.x2p.facebook.net/v0/read/tree/logs/eellison/custom/index.html?bucketName=tlparse_reports&apiKey=tlparse_reports-key&withPayload=1&timeoutMsec=10000). Pull Request resolved: https://github.com/pytorch/pytorch/pull/161404 Approved by: https://github.com/shunting314 --- test/inductor/test_fp8.py | 236 +++++++++++++++++++++++- test/inductor/test_loop_ordering.py | 108 +++++++++++ torch/_inductor/config.py | 11 ++ torch/_inductor/invert_expr_analysis.py | 208 +++++++++++++++++++++ torch/_inductor/loop_body.py | 4 +- torch/_inductor/scheduler.py | 158 +++++++++++++++- 6 files changed, 722 insertions(+), 3 deletions(-) create mode 100644 torch/_inductor/invert_expr_analysis.py diff --git a/test/inductor/test_fp8.py b/test/inductor/test_fp8.py index f26a2347e4e86..f1067b8ffebb3 100644 --- a/test/inductor/test_fp8.py +++ b/test/inductor/test_fp8.py @@ -6,6 +6,7 @@ import torch from torch import Tensor +from torch._C import FileCheck from torch._inductor import config, utils from torch._inductor.pattern_matcher import PatternMatcherPass from torch._inductor.test_case import run_tests, TestCase @@ -29,7 +30,6 @@ HAS_CPU, HAS_CUDA_AND_TRITON, ) -from torch.testing._internal.jit_utils import FileCheck from torch.utils._triton import has_triton_tma_device @@ -953,6 +953,240 @@ def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): self.assertEqual(y_compiled.dtype, dtype) torch.testing.assert_close(y_eager, y_compiled, rtol=5e-2, atol=0.07) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) + @torch._inductor.config.patch("emulate_precision_casts", True) + def test_mx_fusion(self): + # Register fake_scaled_mm custom op scoped to this test + with torch.library._scoped_library("test_fp8", "FRAGMENT") as lib: + # Define the op schema + lib.define( + "fake_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scale_a, Tensor scale_b, " + "Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, " + "bool use_fast_accum=False) -> Tensor" + ) + input_values = [] + + # Register CUDA implementation + @torch.library.impl(lib, "fake_scaled_mm", "CUDA") + def fake_scaled_mm_impl( + mat_a, + mat_b, + scale_a, + scale_b, + bias=None, + scale_result=None, + out_dtype=None, + use_fast_accum=False, + ): + """Software-emulated scaled_mm for testing without CUDA 12.8""" + out_dtype = out_dtype or torch.bfloat16 + # just using add, because without real dtypes, + # was seeing overflow/instability + nonlocal input_values + input_values.append((mat_a, mat_b, scale_a, scale_b)) + result = mat_a.to(torch.float32) + mat_b.to(torch.float32) + if bias is not None: + result = result + bias.to(torch.float32) + return result.to(out_dtype) + + # Register fake implementation + @torch.library.impl(lib, "fake_scaled_mm", "Meta") + def fake_scaled_mm_meta( + mat_a, + mat_b, + scale_a, + scale_b, + bias=None, + scale_result=None, + out_dtype=None, + use_fast_accum=False, + ): + """FakeTensor implementation""" + out_dtype = out_dtype or torch.bfloat16 + M, K = mat_a.shape + K2, N = mat_b.shape + torch._check( + K == K2, + lambda: f"Incompatible shapes: {mat_a.shape} @ {mat_b.shape}", + ) + return torch.empty((M, N), dtype=out_dtype, device=mat_a.device) + + def forward( + arg0_1, + arg1_1, + ): + view = torch.ops.aten.reshape.default(arg0_1, [8192, 256, 32]) + abs_1 = torch.ops.aten.abs.default(view) + amax = torch.ops.aten.amax.default(abs_1, [-1]) + unsqueeze = torch.ops.aten.unsqueeze.default(amax, -1) + view_1 = torch.ops.aten.view.dtype(unsqueeze, torch.int32) + bitwise_right_shift = torch.ops.aten.bitwise_right_shift.Tensor_Scalar( + view_1, 23 + ) + bitwise_and = torch.ops.aten.bitwise_and.Scalar( + bitwise_right_shift, 255 + ) + sub = torch.ops.aten.sub.Tensor(bitwise_and, 127) + sub_1 = torch.ops.aten.sub.Tensor(sub, 8) + clamp_min = torch.ops.aten.clamp_min.default(sub_1, -127) + clamp_max = torch.ops.aten.clamp_max.default(clamp_min, 128) + add = torch.ops.aten.add.Tensor(clamp_max, 127) + convert_element_type = torch.ops.prims.convert_element_type.default( + add, torch.uint8 + ) + isnan = torch.ops.aten.isnan.default(unsqueeze) + scalar_tensor = torch.ops.aten.scalar_tensor.default( + 255, dtype=torch.uint8, layout=torch.strided, device="cuda" + ) + where = torch.ops.aten.where.self( + isnan, scalar_tensor, convert_element_type + ) + convert_element_type_1 = torch.ops.prims.convert_element_type.default( + where, torch.int32 + ) + bitwise_left_shift = torch.ops.aten.bitwise_left_shift.Tensor_Scalar( + convert_element_type_1, 23 + ) + view_2 = torch.ops.aten.view.dtype(bitwise_left_shift, torch.float32) + clamp_min_1 = torch.ops.aten.clamp_min.default( + view_2, 1.1754943508222875e-38 + ) + div = torch.ops.aten.div.Tensor(view, clamp_min_1) + clamp_min_2 = torch.ops.aten.clamp_min.default(div, -448.0) + clamp_max_1 = torch.ops.aten.clamp_max.default(clamp_min_2, 448.0) + convert_element_type_2 = torch.ops.prims.convert_element_type.default( + clamp_max_1, torch.float8_e4m3fn + ) + view_3 = torch.ops.aten.reshape.default( + convert_element_type_2, [8192, 8192] + ) + convert_element_type_2 = None + view_4 = torch.ops.aten.view.dtype(where, torch.float8_e8m0fnu) + squeeze = torch.ops.aten.squeeze.dim(view_4, -1) + + view_5 = torch.ops.aten.reshape.default(arg1_1, [8192, 256, 32]) + abs_2 = torch.ops.aten.abs.default(view_5) + amax_1 = torch.ops.aten.amax.default(abs_2, [-1]) + unsqueeze_1 = torch.ops.aten.unsqueeze.default(amax_1, -1) + view_6 = torch.ops.aten.view.dtype(unsqueeze_1, torch.int32) + bitwise_right_shift_1 = ( + torch.ops.aten.bitwise_right_shift.Tensor_Scalar(view_6, 23) + ) + bitwise_and_1 = torch.ops.aten.bitwise_and.Scalar( + bitwise_right_shift_1, 255 + ) + sub_2 = torch.ops.aten.sub.Tensor(bitwise_and_1, 127) + sub_3 = torch.ops.aten.sub.Tensor(sub_2, 8) + clamp_min_3 = torch.ops.aten.clamp_min.default(sub_3, -127) + clamp_max_2 = torch.ops.aten.clamp_max.default(clamp_min_3, 128) + add_1 = torch.ops.aten.add.Tensor(clamp_max_2, 127) + convert_element_type_3 = torch.ops.prims.convert_element_type.default( + add_1, torch.uint8 + ) + isnan_1 = torch.ops.aten.isnan.default(unsqueeze_1) + unsqueeze_1 = None + scalar_tensor_1 = torch.ops.aten.scalar_tensor.default( + 255, dtype=torch.uint8, layout=torch.strided, device="cuda" + ) + where_1 = torch.ops.aten.where.self( + isnan_1, scalar_tensor_1, convert_element_type_3 + ) + convert_element_type_4 = torch.ops.prims.convert_element_type.default( + where_1, torch.int32 + ) + bitwise_left_shift_1 = torch.ops.aten.bitwise_left_shift.Tensor_Scalar( + convert_element_type_4, 23 + ) + convert_element_type_4 = None + view_7 = torch.ops.aten.view.dtype(bitwise_left_shift_1, torch.float32) + bitwise_left_shift_1 = None + clamp_min_4 = torch.ops.aten.clamp_min.default( + view_7, 1.1754943508222875e-38 + ) + div_1 = torch.ops.aten.div.Tensor(view_5, clamp_min_4) + clamp_min_5 = torch.ops.aten.clamp_min.default(div_1, -448.0) + clamp_max_3 = torch.ops.aten.clamp_max.default(clamp_min_5, 448.0) + convert_element_type_5 = torch.ops.prims.convert_element_type.default( + clamp_max_3, torch.float8_e4m3fn + ) + view_8 = torch.ops.aten.reshape.default( + convert_element_type_5, [8192, 8192] + ) + view_9 = torch.ops.aten.view.dtype(where_1, torch.float8_e8m0fnu) + squeeze_1 = torch.ops.aten.squeeze.dim(view_9, -1) + + permute = torch.ops.aten.permute.default(view_8, [1, 0]) + + view_13 = torch.ops.aten.reshape.default(squeeze, [64, 128, 64, 4]) + permute_2 = torch.ops.aten.permute.default(view_13, [0, 2, 1, 3]) + clone = torch.ops.aten.clone.default( + permute_2, memory_format=torch.contiguous_format + ) + view_14 = torch.ops.aten.reshape.default(clone, [4096, 4, 32, 4]) + permute_3 = torch.ops.aten.permute.default(view_14, [0, 2, 1, 3]) + clone_1 = torch.ops.aten.clone.default( + permute_3, memory_format=torch.contiguous_format + ) + view_15 = torch.ops.aten.reshape.default(clone_1, [4096, 32, 16]) + + view_16 = torch.ops.aten.reshape.default(view_15, [2097152]) + + view_18 = torch.ops.aten.reshape.default(squeeze_1, [64, 128, 64, 4]) + permute_5 = torch.ops.aten.permute.default(view_18, [0, 2, 1, 3]) + clone_2 = torch.ops.aten.clone.default( + permute_5, memory_format=torch.contiguous_format + ) + view_19 = torch.ops.aten.reshape.default(clone_2, [4096, 4, 32, 4]) + permute_6 = torch.ops.aten.permute.default(view_19, [0, 2, 1, 3]) + clone_3 = torch.ops.aten.clone.default( + permute_6, memory_format=torch.contiguous_format + ) + view_20 = torch.ops.aten.reshape.default(clone_3, [4096, 32, 16]) + + view_21 = torch.ops.aten.reshape.default(view_20, [2097152]) + + _scaled_mm = torch.ops.test_fp8.fake_scaled_mm.default( + view_3, permute, view_16, view_21, None, None, torch.float32 + ) + return (_scaled_mm,) + + # Run with largest shape + M, K, N = 8192, 8192, 8192 + device = "cuda" + + A = torch.randn(M, K, dtype=torch.float32, device=device) + B = torch.randn(K, N, dtype=torch.float32, device=device) + f_c = torch.compile(fullgraph=True)(forward) + + _, code = run_and_get_code(f_c, A, B) + + FileCheck().check(".run(").check(".run(").check("fake_scaled_mm").run( + code[0] + ) + + for seed in range(5): + input_values.clear() + torch.manual_seed(seed) + # without dividing, outputs get way too large + A = torch.randn(M, K, dtype=torch.float32, device=device) + B = torch.randn(K, N, dtype=torch.float32, device=device) + + # Uses fake_scaled_mm custom op (no CUDA 12.8 needed!) + torch._dynamo.reset() + torch.compile(forward)(A, B) + + torch._dynamo.reset() + with config.patch({"loop_index_inversion_in_fusion": False}): + torch.compile(forward)(A, B) + + assert len(input_values) == 2 + for i in range(4): + self.assertEqual( + input_values[0][i], + input_values[1][i], + msg=f"idx {i} seed {seed}", + ) + @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) @parametrize("M", (1, 3, 33, 257, 1024)) @parametrize("K", (16, 32, 1024)) diff --git a/test/inductor/test_loop_ordering.py b/test/inductor/test_loop_ordering.py index c77b3574b2227..051a5f5905997 100644 --- a/test/inductor/test_loop_ordering.py +++ b/test/inductor/test_loop_ordering.py @@ -16,6 +16,7 @@ from torch._inductor import config as inductor_config, ir, metrics from torch._inductor.codegen.triton import TritonScheduling from torch._inductor.graph import GraphLowering +from torch._inductor.invert_expr_analysis import generate_inverse_formula from torch._inductor.scheduler import SchedulerNode from torch._inductor.test_case import run_tests, TestCase from torch._inductor.test_operators import realize @@ -1188,6 +1189,113 @@ def fn(nodes): torch.compile(f)(x) +class TestIndexInversion(TestCase): + @classmethod + def setUpClass(cls): + super().setUpClass() + + gm = torch.fx.symbolic_trace(lambda: 0) + graph = GraphLowering(gm) + graph.scheduler = MockScheduler + cls._exit_stack = contextlib.ExitStack() + cls._exit_stack.enter_context(V.set_graph_handler(graph)) + + def _check_expr(self, expr, reconstruction, val_range): + import numpy as np + from sympy import lambdify + + assert len(expr.free_symbols) == 1 + p0 = next(iter(expr.free_symbols)) + + def floordiv_replacement(a, b): + """Replace FloorDiv(a, b) with a // b""" + return a // b + + def modularindexing_replacement(x, base, divisor): + """Replace ModularIndexing(x, base, divisor) with (x // base) % divisor""" + return (x // base) % divisor + + # Replace custom functions with sympy equivalents + expr_numpy_ready = expr.replace(FloorDiv, floordiv_replacement).replace( + ModularIndexing, modularindexing_replacement + ) + reconstruction_numpy_ready = reconstruction.replace( + FloorDiv, floordiv_replacement + ).replace(ModularIndexing, modularindexing_replacement) + + # Now lambdify with standard numpy + forward_func = lambdify(p0, expr_numpy_ready, modules="numpy") + inverse_func = lambdify(p0, reconstruction_numpy_ready, modules="numpy") + + test_values = np.arange(0, val_range, dtype=np.int64) + forward_values = forward_func(test_values).astype(np.int64) + recovered_values = inverse_func(forward_values).astype(np.int64) + torch.testing.assert_close(test_values, recovered_values) + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + cls._exit_stack.close() + + def test_original_complex_expression(self): + """Test the original motivating complex expression.""" + p0 = sympy.Symbol("p0") + expr = ( + 32768 * FloorDiv(p0, 32768) + + 8192 * FloorDiv(ModularIndexing(p0, 1, 16), 4) + + ModularIndexing(p0, 1, 4) + + 256 * ModularIndexing(p0, 16, 32) + + 4 * ModularIndexing(p0, 512, 64) + ) + + reconstruction = generate_inverse_formula(expr, p0) + self.assertIsNotNone(reconstruction) + self._check_expr(expr, reconstruction, 2097152) + + def test_inversion_cases(self): + """Test various expressions for correct inversion behavior.""" + p = sympy.Symbol("p") + + cases = [ + # (expression, should_be_invertible, test_range) + # Simple 2-term base-10 style: 10 = 1 × 10 ✓ + (10 * ModularIndexing(p, 10, 10) + ModularIndexing(p, 1, 10), True, 100), + # Simple 2-term base-2 style: 2 = 1 × 2 ✓ + (2 * ModularIndexing(p, 2, 2) + ModularIndexing(p, 1, 2), True, 4), + # 3-term decimal: 100 = 10×10, 10 = 1×10 ✓ + ( + 100 * FloorDiv(p, 100) + + 10 * FloorDiv(ModularIndexing(p, 1, 100), 10) + + ModularIndexing(p, 1, 10), + True, + 1000, + ), + (4 * p, False, 64), # expr and inverse not bijections + # when sorted, invertible + (ModularIndexing(p, 1, 10) + 10 * ModularIndexing(p, 10, 10), True, None), + # Wrong coefficient ratios: 4 ≠ 1×2 + (4 * ModularIndexing(p, 1, 8) + ModularIndexing(p, 8, 2), False, None), + ( + 100 * FloorDiv(p, 100) + 7 * ModularIndexing(p, 1, 100), + False, + None, + ), # Wrong ratios + (FloorDiv(p, 100) + FloorDiv(p, 10) + p, False, None), # Overlapping ranges + (p**2 + 10 * p + 1, False, None), # Quadratic + (sympy.sin(p) + sympy.cos(p), False, None), # Trigonometric + ] + + for expr, should_invert, test_range in cases: + reconstruction = generate_inverse_formula(expr, p) + + if should_invert: + self.assertIsNotNone(reconstruction, f"Expected invertible: {expr}") + # Test correctness on sample values + self._check_expr(expr, reconstruction, test_range) + else: + self.assertIsNone(reconstruction, f"Expected non-invertible: {expr}") + + if __name__ == "__main__": if HAS_GPU: run_tests() diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index 66eaf69dd59a8..aaf7fbd2f7f54 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -674,6 +674,17 @@ def use_autoheuristic(name: str) -> bool: == "1" ) + +# When trying to fuse two nodes, one with: +# a[contiguous_writes] = fn(...) +# and another node: +# b[contiguous_writes] = a[discontiguous_reads] +# If b is unary, and we can figure out an inverse formula for +# discontiguous writes, invert b as : +# b[inverse(discontiguous_writes)] = a[contiguous_reads] +# so that the nodes can fuse. for more details: https://gist.github.com/eellison/6f9f4a7ec10a860150b15b719f9285a9 +loop_index_inversion_in_fusion: bool = True + # If fusing two nodes only save less then score_fusion_memory_threshold memory, # we should not bother fusing the nodes. # diff --git a/torch/_inductor/invert_expr_analysis.py b/torch/_inductor/invert_expr_analysis.py new file mode 100644 index 0000000000000..816482dba020c --- /dev/null +++ b/torch/_inductor/invert_expr_analysis.py @@ -0,0 +1,208 @@ +from dataclasses import dataclass +from typing import Optional + +import sympy + +from torch._inductor.utils import _IntLike, argsort_sym +from torch.utils._sympy.functions import FloorDiv, ModularIndexing + +from .virtualized import V + + +def static_eq(a: _IntLike, b: _IntLike) -> bool: + return V.graph.sizevars.statically_known_equals(a, b) + + +@dataclass +class Term: + coefficient: _IntLike + range: Optional[_IntLike] # None for unbounded + original_expr: sympy.Expr + reconstruction_multiplier: _IntLike # The multiplier needed for reconstruction + + +def generate_inverse_formula( + expr: sympy.Expr, var: sympy.Symbol +) -> Optional[sympy.Expr]: + """ + Analyze an expression to see if it matches a specific invertible pattern that we + know how to reverse. + + We're looking for expressions that are sums of terms where each term extracts a + distinct bounded range from the input variable, like: + + y = c₀*a₀ + c₁*a₁ + c₂*a₂ + ... + cₙ*aₙ + + where each aᵢ must be one of these specific patterns: + - ModularIndexing(var, divisor, modulo) + - FloorDiv(ModularIndexing(var, 1, modulo), divisor) + - FloorDiv(var, divisor) + - var (the variable itself) + + The key pattern we need is: + - Coefficients are strictly decreasing: c₀ > c₁ > c₂ > ... > cₙ + - Each coefficient matches the product of ranges of later terms (mixed-radix property) + - Each term extracts a bounded range, creating non-overlapping "slots" + + If we find this pattern, we can generate the reconstruction transformation that + decomposes the variable and rebuilds it using the correct multipliers. + + EXAMPLE: + Input: 100*((p//100)) + 10*((p%100)//10) + (p%10) + + Returns the reconstruction expression: + remainder₀ = p + component₀ = remainder₀ // 100 # hundreds digit + remainder₁ = remainder₀ % 100 + component₁ = remainder₁ // 10 # tens digit + remainder₂ = remainder₁ % 10 + component₂ = remainder₂ # ones digit + result = component₀*100 + component₁*10 + component₂*1 + + This decomposes p into its components and rebuilds it using the original + multipliers, which should equal the input expression. + + Args: + expr: Expression to analyze (sum of terms with ModularIndexing, FloorDiv, etc.) + var: The variable being decomposed + + Returns: + None if not invertible, or the reconstruction expression + + References: + Mixed-radix systems: https://en.wikipedia.org/wiki/Mixed_radix + """ + # Step 1: Parse all terms + terms = parse_terms(expr, var) + if not terms: + return None + + # Step 2: Sort by coefficient (descending) + coeffs = [t.coefficient for t in terms] + idxs = reversed(argsort_sym(V.graph.sizevars.shape_env, coeffs)) + terms = [terms[i] for i in idxs] + + # Step 3: Check invertibility conditions + if not check_invertibility(terms): + return None + + return generate_reconstruction_expr(terms, var) + + +def parse_terms(expr: sympy.Expr, var: sympy.Symbol) -> Optional[list[Term]]: + """Parse expression into terms.""" + if not isinstance(expr, sympy.Add): + # Single term + term = parse_single_term(expr, var) + return [term] if term else [] + + terms = [] + for arg in expr.args: + term = parse_single_term(arg, var) + if term: + terms.append(term) + else: + return None # If any term fails to parse, fail completely + + return terms + + +def parse_single_term(term: sympy.Expr, var: sympy.Symbol) -> Optional[Term]: + """Parse a single term and extract coefficient, range, and reconstruction multiplier.""" + # Extract coefficient and expression parts + coefficient, expr_parts = term.as_coeff_mul() + + if len(expr_parts) == 0: + # Pure constant term + return Term( + coefficient=coefficient, + range=1, + original_expr=1, + reconstruction_multiplier=0, + ) + elif len(expr_parts) == 1: + expr = expr_parts[0] + else: + # Multiple non-constant factors, too complex + return None + + # Now determine the range and reconstruction multiplier + range_val, reconstruction_multiplier = analyze_expression_properties(expr, var) + if reconstruction_multiplier is None: + return None + + return Term( + coefficient=coefficient, + range=range_val, + original_expr=expr, + reconstruction_multiplier=reconstruction_multiplier, + ) + + +def analyze_expression_properties( + expr: sympy.Expr, var: sympy.Symbol +) -> tuple[Optional[_IntLike], Optional[_IntLike]]: + """Analyze an expression to determine its range and reconstruction multiplier.""" + # ModularIndexing(var, divisor, modulo) = (var // divisor) % modulo + if isinstance(expr, ModularIndexing): + x, div, mod = expr.args + if static_eq(x, var): + return mod, div # Range is mod, multiplier is div + + # FloorDiv cases + if isinstance(expr, FloorDiv): + base, divisor = expr.args + + # FloorDiv(ModularIndexing(var, 1, mod), div) = (var % mod) // div + if isinstance(base, ModularIndexing): + x, inner_div, mod = base.args + if static_eq(x, var) and static_eq(inner_div, 1): + range_val = FloorDiv(mod, divisor) + return range_val, divisor # Range is mod//div, multiplier is div + + # FloorDiv(var, divisor) = var // divisor (unbounded) + elif static_eq(base, var): + return None, divisor # Unbounded range, multiplier is div + + return None, None + + +def check_invertibility(terms: list[Term]) -> bool: + """Check if the terms represent an invertible transformation.""" + if not terms: + return False + + # Coefficients must be strictly decreasing + coeffs = [t.coefficient for t in terms] + if argsort_sym(V.graph.sizevars.shape_env, coeffs) != list( + reversed(range(len(coeffs))) + ): + return False + + # Check mixed-radix property: each coeff[i] = coeff[i+1] * range[i+1] + expected_coeff = 1 + for term in reversed(terms): + if not static_eq(term.coefficient, expected_coeff): + return False + if term.range is not None: + expected_coeff *= term.range + + return True + + +def generate_reconstruction_expr(terms: list[Term], var: sympy.Symbol) -> sympy.Expr: + y = var + reconstruction = sympy.S.Zero + remainder = y + + for i, term in enumerate(terms): + if i < len(terms) - 1: + component = FloorDiv(remainder, term.coefficient) + remainder = ModularIndexing(remainder, 1, term.coefficient) + else: + # Last term should also divide by its coefficient + component = FloorDiv(remainder, term.coefficient) + + reconstruction += component * term.reconstruction_multiplier + + return reconstruction diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index 53ae1d8f63f6b..3921aa955a836 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -95,7 +95,6 @@ class LoopBody: """ indexing_exprs: dict[str, sympy.Expr] - indexing_exprs_name: dict[sympy.Expr, str] submodules: dict[str, Any] subblocks: dict[str, LoopBodyBlock] indirect_vars: list[sympy.Symbol] @@ -104,6 +103,9 @@ class LoopBody: memory_usage: dict[MemoryUsageType, list[MemoryEntry]] op_counts: collections.Counter[str] + # defined only temporarily + indexing_exprs_name: dict[sympy.Expr, str] + def __init__( self, fn, diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index df1d2f729b34a..2930a33b465a6 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -3345,7 +3345,10 @@ def fuse_nodes(self, nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]: ) break - if config.loop_ordering_after_fusion: + if ( + config.loop_ordering_after_fusion + or config.loop_index_inversion_in_fusion + ): nodes = self.fuse_nodes_once(nodes, is_reorder_round=True) return nodes @@ -4302,6 +4305,148 @@ def decide_fusion_fail_reason( return str(reasons) + def shared_data_after_inverting_indexing( + self, node1: BaseSchedulerNode, node2: BaseSchedulerNode + ) -> int: + """ + Attempts to enable fusion between two nodes by inverting indexing patterns. + + This optimization targets cases where node1 has a contiguous write and + node2 has a contiguous write but discontiguous read. By inverting the + indexing in node2's read and write operations, we can make them compatible + with node1 for potential fusion. + + Args: + node1: First scheduler node (source) + node2: Second scheduler node (target for inversion) + + Returns: + int: Fusion score if successful, 0 if optimization not applicable + """ + + if not config.loop_index_inversion_in_fusion: + return -1 + + if any(n.is_cpu() for n in [node1, node2]): + return -1 + + # Check for shared buffers between nodes + node1_buffer_names = node1.read_writes.buffer_names() + node2_buffer_names = node2.read_writes.buffer_names() + common_buffer_names = node1_buffer_names & node2_buffer_names + + if not common_buffer_names: + return -1 + + # only invert if node1 is single unmet dep + node2_unmet_dependencies = OrderedSet( + dep.name for dep in node2.unmet_dependencies + ) + if node2_unmet_dependencies - node1_buffer_names: + return -1 + + if len(node2_unmet_dependencies) > 1: + return -1 + + # Currently only handle single read/write operations + if len(node2.read_writes.reads) > 1 or len(node2.read_writes.writes) > 1: + return -1 + + node2_read = next(iter(node2.read_writes.reads)) + node2_write = next(iter(node2.read_writes.writes)) + + if not isinstance(node2_read, MemoryDep) or not isinstance( + node2_write, MemoryDep + ): + return -1 + + node1_writes = {dep.name: dep for dep in node1.read_writes.writes} + if node2_read.name not in node1_writes: + return -1 + + node1_write = node1_writes[node2_read.name] + + if not isinstance(node1_write, MemoryDep): + return -1 + + # We are checking for compatibility with the normalized node1 write + # then modifying node2 reads/writes. since the node1 write will be just used + # for compatibility, while node2 will be used in actual modification, just + # normalize node1 not node2. + node1_write = node1_write.normalize() + + if ( + node1_write.index != node2_write.index + and node1_write.size != node2_write.size + ): + return -1 + + if node2_read.size != node2_write.size or len(node2_read.var_names) != 1: + return -1 + + # Verify we have exactly two indexing expressions (one read, one write) + if len(node2._body.indexing_exprs) != 2: # type: ignore[attr-defined] + return -1 + + # No subblocks allowed for this optimization + if node2._body.subblocks: # type: ignore[attr-defined] + return -1 + + assert ( + "index0" in node2._body.indexing_exprs # type: ignore[attr-defined] + and "index1" in node2._body.indexing_exprs # type: ignore[attr-defined] + ) + + # Extract and verify single read expression + node2_read_exprs = OrderedSet(expr for expr in node2._body.get_read_exprs()) # type: ignore[attr-defined] + if len(node2_read_exprs) != 1: + return -1 + + read_expr = next(iter(node2_read_exprs)) + + # Determine which index is for reading vs writing + if read_expr == node2._body.indexing_exprs["index0"]: # type: ignore[attr-defined] + read_expr_index = "index0" + write_expr_index = "index1" + else: + assert read_expr == node2._body.indexing_exprs["index1"] # type: ignore[attr-defined] + read_expr_index = "index1" + write_expr_index = "index0" + + from torch._inductor.invert_expr_analysis import generate_inverse_formula + + index_vars = node2._body.vars[0] # type: ignore[attr-defined] + if len(index_vars) != 1: + return -1 + + simplified_terms = [] + for term in sympy.Add.make_args(read_expr): + simplified_terms.append( + V.graph.sizevars.combine_modular_indexing_pairs(term) + ) + simplified_read_expr = sum(simplified_terms) + + inverse_formula = generate_inverse_formula(simplified_read_expr, index_vars[0]) + + # formula is not invertible + if inverse_formula is None: + return -1 + + # === Apply Inversion === + + # Swap the indexing expressions using the inverse formula + node2._body.indexing_exprs[read_expr_index] = node2._body.indexing_exprs[ # type: ignore[attr-defined] + write_expr_index + ] + node2._body.indexing_exprs[write_expr_index] = inverse_formula # type: ignore[attr-defined] + + # Refresh dependencies and calculate fusion score + node2.refresh_dependencies(True, False) # type: ignore[attr-defined] + score = self.score_fusion_memory(node1, node2) + + fusion_log.info("Shared memory after inversion: %d", score) + return score + def shared_data_after_reordering_loop( self, node1: BaseSchedulerNode, node2: BaseSchedulerNode ) -> int: @@ -4686,6 +4831,7 @@ def can_fuse( del device2 shared_data_score = self.score_fusion_memory(node1, node2) + if ( can_reorder and shared_data_score < config.score_fusion_memory_threshold @@ -4702,6 +4848,16 @@ def can_fuse( smaller_node.expand_dimension_for_pointwise_node(expand_dim, expand_size) shared_data_score = self.score_fusion_memory(node1, node2) + if ( + config.loop_index_inversion_in_fusion + and shared_data_score < config.score_fusion_memory_threshold + ): + new_shared_data_score = self.shared_data_after_inverting_indexing( + node1, node2 + ) + if new_shared_data_score >= 0: + shared_data_score = new_shared_data_score + if loop_ordering_log.isEnabledFor(logging.DEBUG): loop_ordering_log.debug( "%s and %s has %s shared data", From aba2fa32593c6d7cfa55d488814984c421eaafb7 Mon Sep 17 00:00:00 2001 From: Siddhartha Menon Date: Wed, 5 Nov 2025 16:55:51 +0000 Subject: [PATCH 063/130] Fix clang-21 warnings (#166859) Fixes compiler warnings thrown by Clang-21 Fixes #166755 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166859 Approved by: https://github.com/aditew01, https://github.com/fadara01, https://github.com/malfet --- aten/src/ATen/cpu/vec/sve/vec_bfloat16.h | 84 +++++++++---------- .../src/ATen/native/cpu/GridSamplerKernel.cpp | 4 +- caffe2/CMakeLists.txt | 2 +- 3 files changed, 45 insertions(+), 45 deletions(-) diff --git a/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h b/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h index 9e0b189bdac89..757ef839f965a 100644 --- a/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h +++ b/aten/src/ATen/cpu/vec/sve/vec_bfloat16.h @@ -191,7 +191,7 @@ class Vectorized { auto vals = svreinterpret_u16_bf16(values); vals = sveor_u16_x(ptrue, vals, mask); return svreinterpret_bf16_u16(vals); - }; + } Vectorized round() const; Vectorized tan() const; Vectorized tanh() const; @@ -349,47 +349,47 @@ Vectorized inline Vectorized::frac() const { return convert_float_bfloat16(v1, v2); \ } -DEFINE_BF16_FUNC_VIA_FLOAT(isnan); -DEFINE_BF16_FUNC_VIA_FLOAT(angle); -DEFINE_BF16_FUNC_VIA_FLOAT(acos); -DEFINE_BF16_FUNC_VIA_FLOAT(acosh); -DEFINE_BF16_FUNC_VIA_FLOAT(asin); -DEFINE_BF16_FUNC_VIA_FLOAT(atan); -DEFINE_BF16_FUNC_VIA_FLOAT(atanh); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign); -DEFINE_BF16_FUNC_VIA_FLOAT(erf); -DEFINE_BF16_FUNC_VIA_FLOAT(erfc); -DEFINE_BF16_FUNC_VIA_FLOAT(exp); -DEFINE_BF16_FUNC_VIA_FLOAT(exp2); -DEFINE_BF16_FUNC_VIA_FLOAT(expm1); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot); -DEFINE_BF16_FUNC_VIA_FLOAT(i0); -DEFINE_BF16_FUNC_VIA_FLOAT(i0e); -DEFINE_BF16_FUNC_VIA_FLOAT(digamma); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter); -DEFINE_BF16_FUNC_VIA_FLOAT(log); -DEFINE_BF16_FUNC_VIA_FLOAT(log2); -DEFINE_BF16_FUNC_VIA_FLOAT(log10); -DEFINE_BF16_FUNC_VIA_FLOAT(log1p); -DEFINE_BF16_FUNC_VIA_FLOAT(sin); -DEFINE_BF16_FUNC_VIA_FLOAT(sinh); -DEFINE_BF16_FUNC_VIA_FLOAT(cos); -DEFINE_BF16_FUNC_VIA_FLOAT(cosh); -DEFINE_BF16_FUNC_VIA_FLOAT(ceil); -DEFINE_BF16_FUNC_VIA_FLOAT(floor); -DEFINE_BF16_FUNC_VIA_FLOAT(round); -DEFINE_BF16_FUNC_VIA_FLOAT(tan); -DEFINE_BF16_FUNC_VIA_FLOAT(tanh); -DEFINE_BF16_FUNC_VIA_FLOAT(trunc); -DEFINE_BF16_FUNC_VIA_FLOAT(lgamma); -DEFINE_BF16_FUNC_VIA_FLOAT(sqrt); -DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal); -DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt); -DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow); +DEFINE_BF16_FUNC_VIA_FLOAT(isnan) +DEFINE_BF16_FUNC_VIA_FLOAT(angle) +DEFINE_BF16_FUNC_VIA_FLOAT(acos) +DEFINE_BF16_FUNC_VIA_FLOAT(acosh) +DEFINE_BF16_FUNC_VIA_FLOAT(asin) +DEFINE_BF16_FUNC_VIA_FLOAT(atan) +DEFINE_BF16_FUNC_VIA_FLOAT(atanh) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(atan2) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(copysign) +DEFINE_BF16_FUNC_VIA_FLOAT(erf) +DEFINE_BF16_FUNC_VIA_FLOAT(erfc) +DEFINE_BF16_FUNC_VIA_FLOAT(exp) +DEFINE_BF16_FUNC_VIA_FLOAT(exp2) +DEFINE_BF16_FUNC_VIA_FLOAT(expm1) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(fmod) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(hypot) +DEFINE_BF16_FUNC_VIA_FLOAT(i0) +DEFINE_BF16_FUNC_VIA_FLOAT(i0e) +DEFINE_BF16_FUNC_VIA_FLOAT(digamma) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igamma) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(igammac) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(nextafter) +DEFINE_BF16_FUNC_VIA_FLOAT(log) +DEFINE_BF16_FUNC_VIA_FLOAT(log2) +DEFINE_BF16_FUNC_VIA_FLOAT(log10) +DEFINE_BF16_FUNC_VIA_FLOAT(log1p) +DEFINE_BF16_FUNC_VIA_FLOAT(sin) +DEFINE_BF16_FUNC_VIA_FLOAT(sinh) +DEFINE_BF16_FUNC_VIA_FLOAT(cos) +DEFINE_BF16_FUNC_VIA_FLOAT(cosh) +DEFINE_BF16_FUNC_VIA_FLOAT(ceil) +DEFINE_BF16_FUNC_VIA_FLOAT(floor) +DEFINE_BF16_FUNC_VIA_FLOAT(round) +DEFINE_BF16_FUNC_VIA_FLOAT(tan) +DEFINE_BF16_FUNC_VIA_FLOAT(tanh) +DEFINE_BF16_FUNC_VIA_FLOAT(trunc) +DEFINE_BF16_FUNC_VIA_FLOAT(lgamma) +DEFINE_BF16_FUNC_VIA_FLOAT(sqrt) +DEFINE_BF16_FUNC_VIA_FLOAT(reciprocal) +DEFINE_BF16_FUNC_VIA_FLOAT(rsqrt) +DEFINE_BF16_FUNC_VIA_FLOAT_W_ARG(pow) Vectorized inline Vectorized::operator==( const Vectorized& other) const { diff --git a/aten/src/ATen/native/cpu/GridSamplerKernel.cpp b/aten/src/ATen/native/cpu/GridSamplerKernel.cpp index 7587988528ebb..73f8c136794ce 100644 --- a/aten/src/ATen/native/cpu/GridSamplerKernel.cpp +++ b/aten/src/ATen/native/cpu/GridSamplerKernel.cpp @@ -293,7 +293,7 @@ struct ComputeLocationBase { , empty(size <= 0) {} inline Vec unnormalize(const Vec &in) const { - return (in + Vec(1)) * Vec(scaling_factor) - Vec(0.5); + return (in + Vec(static_cast(1))) * Vec(scaling_factor) - Vec(static_cast(0.5)); } inline Vec clip_coordinates(const Vec &in) const { @@ -831,7 +831,7 @@ struct ApplyGridSample(-0.75)); ApplyGridSample(const TensorAccessor& input) : inp_H(input.size(2)) diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 0e86e826405c6..e1cc43350b2b6 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1307,7 +1307,7 @@ endif() if(USE_MKLDNN_ACL) find_package(ACL REQUIRED) - target_include_directories(torch_cpu PRIVATE ${ACL_INCLUDE_DIRS}) + target_include_directories(torch_cpu SYSTEM PRIVATE ${ACL_INCLUDE_DIRS}) endif() target_include_directories(torch_cpu PRIVATE ${ATen_CPU_INCLUDE}) From d4dcd0354c4affcd90417f213785fc762e1b2b2f Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Wed, 5 Nov 2025 19:43:40 +0800 Subject: [PATCH 064/130] [pytree][dynamo] add test to ensure `tree_map` preserves `dict` order (#166236) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166236 Approved by: https://github.com/mlazos --- test/dynamo/test_misc.py | 24 ++++++++++++++++ test/test_pytree.py | 18 ++++++++++++ torch/_dynamo/polyfills/pytree.py | 47 ++++++++++++++++++++++++++++--- 3 files changed, 85 insertions(+), 4 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index b8727208a5bfa..b3e9df6a25cf3 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -13194,6 +13194,30 @@ def fn(x, y): self.assertEqual(actual, expected) + @parametrize_pytree_module + def test_pytree_tree_map_dict_order(self, pytree): + def fn(tree): + new_tree = pytree.tree_map(lambda x: x, tree) + return list(new_tree.keys()), list(new_tree.values()) + + x = torch.randn(3, 2) + fn_opt = torch.compile(fullgraph=True)(fn) + + tree1 = {"b": x + 2, "a": x, "c": x - 1} + expected1 = fn(tree1) + actual1 = fn_opt(tree1) + self.assertEqual(actual1, expected1) + + tree2 = collections.OrderedDict([("b", x + 2), ("a", x), ("c", x - 1)]) + expected2 = fn(tree2) + actual2 = fn_opt(tree2) + self.assertEqual(actual2, expected2) + + tree3 = collections.defaultdict(int, {"b": x + 2, "a": x, "c": x - 1}) + expected3 = fn(tree3) + actual3 = fn_opt(tree3) + self.assertEqual(actual3, expected3) + @parametrize_pytree_module def test_pytree_tree_map_only(self, pytree): if not callable(getattr(pytree, "tree_map_only", None)): diff --git a/test/test_pytree.py b/test/test_pytree.py index 7cc3b8affc0ef..09cf0bbd47a43 100644 --- a/test/test_pytree.py +++ b/test/test_pytree.py @@ -601,6 +601,24 @@ def f(x, y, z): for case in cases: run_test(case) + @parametrize_pytree_module + def test_tree_map_dict_order(self, pytree): + d = {"b": 2, "a": 1, "c": 3} + od = OrderedDict([("b", 2), ("a", 1), ("c", 3)]) + dd = defaultdict(int, {"b": 2, "a": 1, "c": 3}) + for tree in (d, od, dd): + result = pytree.tree_map(lambda x: x, tree) + self.assertEqual( + list(result.keys()), + list(tree.keys()), + msg=f"Dictionary keys order changed in tree_map: {tree!r} vs. {result!r}", + ) + self.assertEqual( + list(result.values()), + list(tree.values()), + msg=f"Dictionary keys order changed in tree_map: {tree!r} vs. {result!r}", + ) + @parametrize_pytree_module def test_tree_map_only(self, pytree): self.assertEqual(pytree.tree_map_only(int, lambda x: x + 2, [0, "a"]), [2, "a"]) diff --git a/torch/_dynamo/polyfills/pytree.py b/torch/_dynamo/polyfills/pytree.py index d86fe054b2ebc..b4de3200e2960 100644 --- a/torch/_dynamo/polyfills/pytree.py +++ b/torch/_dynamo/polyfills/pytree.py @@ -6,7 +6,7 @@ from collections import deque from dataclasses import dataclass, field -from typing import Any, TYPE_CHECKING +from typing import Any, TYPE_CHECKING, TypeVar from typing_extensions import TypeIs import torch.utils._pytree as python_pytree @@ -24,9 +24,15 @@ __all__: list[str] = [] +_T = TypeVar("_T") +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + + if python_pytree._cxx_pytree_dynamo_traceable: import optree import optree._C + import optree.utils import torch.utils._cxx_pytree as cxx_pytree # noqa: F401 @@ -600,14 +606,47 @@ def tree_map_( __all__ += ["tree_map_"] - _none_unflatten = optree.register_pytree_node.get(type(None)).unflatten_func # type: ignore[union-attr, attr-defined] + _none_registration = optree.register_pytree_node.get(type(None)) + assert _none_registration is not None @substitute_in_graph( # type: ignore[arg-type] - _none_unflatten, + _none_registration.unflatten_func, can_constant_fold_through=True, skip_signature_check=True, ) - def none_unflatten(_: None, children: Iterable[Any], /) -> None: + def none_unflatten(_: None, children: Iterable[_T], /) -> None: if len(list(children)) != 0: raise ValueError("Expected no children.") return None + + with optree.dict_insertion_ordered(False, namespace="torch"): + _dict_registration = optree.register_pytree_node.get(dict) + assert _dict_registration is not None + + @substitute_in_graph( # type: ignore[arg-type] + _dict_registration.flatten_func, + can_constant_fold_through=True, + skip_signature_check=True, + ) + def dict_flatten( + dct: dict[_KT, _VT], / + ) -> tuple[list[_VT], tuple[list[_KT], list[_KT]], tuple[_KT, ...]]: + sorted_keys = optree.utils.total_order_sorted(dct) + values = [dct[key] for key in sorted_keys] + original_keys = list(dct) + return values, (original_keys, sorted_keys), tuple(sorted_keys) + + @substitute_in_graph( # type: ignore[arg-type] + _dict_registration.unflatten_func, + can_constant_fold_through=True, + skip_signature_check=True, + ) + def dict_unflatten( + metadata: tuple[list[_KT], list[_KT]], + values: Iterable[_VT], + /, + ) -> dict[_KT, _VT]: + original_keys, sorted_keys = metadata + d = dict.fromkeys(original_keys) + d.update(zip(sorted_keys, values)) + return d # type: ignore[return-value] From 9c2c3dbc156a0eae1212ec3e51109d83a4922c9b Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 5 Nov 2025 17:12:30 +0000 Subject: [PATCH 065/130] Revert "Update triton to 3.5.1 release (#166968)" This reverts commit b4e4ee81d386db922d8f63359f9870eff1f44052. Reverted https://github.com/pytorch/pytorch/pull/166968 on behalf of https://github.com/malfet due to It might have caused deadlock/test timeouts, see https://hud.pytorch.org/hud/pytorch/pytorch/d4dcd0354c4affcd90417f213785fc762e1b2b2f/1?per_page=50&name_filter=trunk%20%2F%20linux-jammy-cuda12.8-py3.10-gcc11%20%2F%20test&mergeEphemeralLF=true ([comment](https://github.com/pytorch/pytorch/pull/166968#issuecomment-3492399396)) --- .ci/docker/ci_commit_pins/triton.txt | 2 +- .ci/docker/triton_version.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 7aab8bed1c108..10f1207e60e6c 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7 +7416ffcb92cdbe98d9f97e4e6f95247e46dfc9fd diff --git a/.ci/docker/triton_version.txt b/.ci/docker/triton_version.txt index d5c0c99142898..1545d966571dc 100644 --- a/.ci/docker/triton_version.txt +++ b/.ci/docker/triton_version.txt @@ -1 +1 @@ -3.5.1 +3.5.0 From f93ee16fb68b480a183348b99445fb089f4a5c30 Mon Sep 17 00:00:00 2001 From: Catherine Lee Date: Wed, 5 Nov 2025 17:19:24 +0000 Subject: [PATCH 066/130] [CI] Parse xml and upload json while running (#166988) Then we can point an ClickHouse ingestor at this s3 path and get them into ClickHouse while the job is running. use filelock to make sure each json is uploaded once so we don't end up with dups in ClickHouse Pull Request resolved: https://github.com/pytorch/pytorch/pull/166988 Approved by: https://github.com/izaitsevfb --- test/run_test.py | 18 +++++++- tools/stats/upload_test_stats.py | 6 ++- tools/testing/upload_artifacts.py | 70 ++++++++++++++++++++++++++++++- 3 files changed, 90 insertions(+), 4 deletions(-) diff --git a/test/run_test.py b/test/run_test.py index 448fbc28751f3..aa6a6d04cde3e 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -73,7 +73,22 @@ ShardedTest, THRESHOLD, ) -from tools.testing.upload_artifacts import zip_and_upload_artifacts + + +try: + from tools.testing.upload_artifacts import ( + parse_xml_and_upload_json, + zip_and_upload_artifacts, + ) +except ImportError: + # some imports in those files might fail, e.g., boto3 not installed. These + # functions are only needed under specific circumstances (CI) so we can + # define dummy functions here. + def parse_xml_and_upload_json(): + pass + + def zip_and_upload_artifacts(failed: bool): + pass # Make sure to remove REPO_ROOT after import is done @@ -1887,6 +1902,7 @@ def run_tests( def handle_complete(failure: Optional[TestFailure]): failed = failure is not None if IS_CI and options.upload_artifacts_while_running: + parse_xml_and_upload_json() zip_and_upload_artifacts(failed) if not failed: return False diff --git a/tools/stats/upload_test_stats.py b/tools/stats/upload_test_stats.py index 6c0232c5e5a17..b2b0869d48350 100644 --- a/tools/stats/upload_test_stats.py +++ b/tools/stats/upload_test_stats.py @@ -38,12 +38,14 @@ def parse_xml_report( report: Path, workflow_id: int, workflow_run_attempt: int, + job_id: int | None = None, ) -> list[dict[str, Any]]: """Convert a test report xml file into a JSON-serializable list of test cases.""" print(f"Parsing {tag}s for test report: {report}") - job_id = get_job_id(report) - print(f"Found job id: {job_id}") + if job_id is None: + job_id = get_job_id(report) + print(f"Found job id: {job_id}") test_cases: list[dict[str, Any]] = [] diff --git a/tools/testing/upload_artifacts.py b/tools/testing/upload_artifacts.py index 07b62ec9a1b74..49d68fe9959ae 100644 --- a/tools/testing/upload_artifacts.py +++ b/tools/testing/upload_artifacts.py @@ -1,11 +1,16 @@ import glob import gzip +import json import os import time import zipfile from functools import lru_cache from pathlib import Path -from typing import Any +from typing import Any, Optional + +from filelock import FileLock, Timeout + +from tools.stats.upload_test_stats import parse_xml_report REPO_ROOT = Path(__file__).resolve().parent.parent.parent @@ -140,3 +145,66 @@ def trigger_upload_test_stats_intermediate_workflow() -> None: }, ) print(x.text) + + +def parse_xml_and_upload_json() -> None: + """ + Parse xml test reports that do not yet have a corresponding json report + uploaded to s3, and upload the json reports to s3. Use filelock to avoid + uploading the same file from multiple processes. + """ + try: + job_id: Optional[int] = int(os.environ.get("JOB_ID", 0)) + if job_id == 0: + job_id = None + except (ValueError, TypeError): + job_id = None + + try: + for xml_file in glob.glob( + f"{REPO_ROOT}/test/test-reports/**/*.xml", recursive=True + ): + xml_path = Path(xml_file) + json_file = xml_path.with_suffix(".json") + lock = FileLock(str(json_file) + ".lock") + + try: + lock.acquire(timeout=0) # immediately fails if already locked + if json_file.exists(): + continue # already uploaded + test_cases = parse_xml_report( + "testcase", + xml_path, + int(os.environ.get("GITHUB_RUN_ID", "0")), + int(os.environ.get("GITHUB_RUN_ATTEMPT", "0")), + job_id, + ) + line_by_line_jsons = "\n".join([json.dumps(tc) for tc in test_cases]) + + gzipped = gzip.compress(line_by_line_jsons.encode("utf-8")) + s3_key = ( + json_file.relative_to(REPO_ROOT / "test/test-reports") + .as_posix() + .replace("/", "_") + ) + + get_s3_resource().put_object( + Body=gzipped, + Bucket="gha-artifacts", + Key=f"test_jsons_while_running/{os.environ.get('GITHUB_RUN_ID')}/{job_id}/{s3_key}", + ContentType="application/json", + ContentEncoding="gzip", + ) + + # We don't need to save the json file locally, but doing so lets us + # track which ones have been uploaded already. We could probably also + # check S3 + with open(json_file, "w") as f: + f.write(line_by_line_jsons) + except Timeout: + continue # another process is working on this file + finally: + if lock.is_locked: + lock.release() + except Exception as e: + print(f"Failed to parse and upload json test reports: {e}") From 0c7a4a6b48d49306eae8d0a9ee8d32b1899e5e23 Mon Sep 17 00:00:00 2001 From: karthickai Date: Mon, 3 Nov 2025 13:25:27 -0800 Subject: [PATCH 067/130] [Inductor] Fix unbacked float symbol handling in kernel codegen (#166890) When a fn compiled with `torch.compile` calls `.item()` on a float tensor arg (e.g., for thresholds in `torch.clamp`), the generated triton kernel references an unbacked float symbol (e.g., `zuf0`) that was never added to the kernel's parameter list, causing a compilation error. Fixes: #166888 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166890 Approved by: https://github.com/eellison --- test/inductor/test_torchinductor.py | 14 ++++++++++++++ torch/_inductor/codecache.py | 6 ++++++ torch/_inductor/codegen/common.py | 11 +++++++++-- torch/_inductor/codegen/triton_utils.py | 5 +++++ 4 files changed, 34 insertions(+), 2 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index ed8993a1c9a39..d0ff5799ac417 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -14424,6 +14424,20 @@ def fn(x): self.common(fn, (torch.randn(6, 4, device=GPU_TYPE).t().contiguous().t(),)) + @skip_if_halide + @requires_cuda_and_triton + def test_unbacked_float_item(self): + def fn(x, max_val): + return torch.clamp(x, 0, max_val.item()) + + self.common( + fn, + ( + torch.randn(10, 20, 30, device=self.device), + torch.tensor(5.0, device=self.device), + ), + ) + # end of class CommonTemplate - add new tests here diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index cf17bf2e9478b..85702057cbb43 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -2970,6 +2970,12 @@ class CppPythonBindingsCodeCache(CppCodeCache): throw std::runtime_error("expected int arg"); return reinterpret_cast(result); }} + template <> inline float parse_arg(PyObject* args, size_t n) {{ + auto result = PyFloat_AsDouble(PyTuple_GET_ITEM(args, n)); + if(unlikely(result == -1.0 && PyErr_Occurred())) + throw std::runtime_error("expected float arg"); + return static_cast(result); + }} {extra_parse_arg} diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 730c03f1c813c..3e9f174c810c5 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -1732,9 +1732,15 @@ def cpp_argdefs( call_args.append(self.wrap_ptr_arg(outer, dtype)) arg_types.append(f"{cpp_dtype}*") for outer, inner in self.sizevars.items(): - arg_defs.append(f"const {INDEX_TYPE} {inner}") + if isinstance(outer, sympy.Symbol) and symbol_is_type( + outer, (SymT.UNBACKED_FLOAT) + ): + arg_defs.append(f"const float {inner}") + arg_types.append("const float") + else: + arg_defs.append(f"const {INDEX_TYPE} {inner}") + arg_types.append(f"const {INDEX_TYPE}") call_args.append(self.wrap_size_arg(outer)) - arg_types.append(f"const {INDEX_TYPE}") if V.graph.wrapper_code: V.graph.wrapper_code.ensure_size_computed(outer) assert not self.workspace_args, "Workspace not supported on CPU " @@ -2353,6 +2359,7 @@ def rename_indexing( SymT.UNBACKED_INT, SymT.SIZE, SymT.PRECOMPUTED_SIZE, + SymT.UNBACKED_FLOAT, ), ) } diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py index 2a2706ad5720b..75a34813c876b 100644 --- a/torch/_inductor/codegen/triton_utils.py +++ b/torch/_inductor/codegen/triton_utils.py @@ -4,6 +4,7 @@ import sympy import torch +from torch.utils._sympy.symbol import symbol_is_type, SymT from .. import config from ..runtime.hints import AttrsDescriptorWrapper @@ -71,6 +72,10 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str: return "constexpr" elif isinstance(arg.expr, (float, sympy.Float)): return "fp32" + elif isinstance(arg.expr, sympy.Symbol) and symbol_is_type( + arg.expr, (SymT.UNBACKED_FLOAT) + ): + return "fp32" elif isinstance(arg.expr, bool): return "i1" From 4ff068c33a0beda5df88cd373a4fb70b5a68e554 Mon Sep 17 00:00:00 2001 From: KarhouTam Date: Wed, 5 Nov 2025 17:59:12 +0000 Subject: [PATCH 068/130] [Code Clean] Replace `assert` with if statement and raise `AssertionError` (#166935) Including: - `torch/profiler/profiler.py` Fixes part of #164878 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166935 Approved by: https://github.com/fffrog, https://github.com/albanD --- torch/profiler/profiler.py | 43 ++++++++++++++++++++++++++------------ 1 file changed, 30 insertions(+), 13 deletions(-) diff --git a/torch/profiler/profiler.py b/torch/profiler/profiler.py index ee0ea85e1694b..893b4078cb9ce 100644 --- a/torch/profiler/profiler.py +++ b/torch/profiler/profiler.py @@ -210,7 +210,8 @@ def prepare_trace(self) -> None: def start_trace(self) -> None: if self.execution_trace_observer: self.execution_trace_observer.start() - assert self.profiler is not None + if self.profiler is None: + raise AssertionError("Profiler must be initialized before starting trace") self.profiler._start_trace() if self.profile_memory: @@ -256,7 +257,8 @@ def start_trace(self) -> None: def stop_trace(self) -> None: if self.execution_trace_observer: self.execution_trace_observer.stop() - assert self.profiler is not None + if self.profiler is None: + raise AssertionError("Profiler must be initialized before stopping trace") self.profiler.__exit__(None, None, None) def export_chrome_trace(self, path: str): @@ -264,7 +266,10 @@ def export_chrome_trace(self, path: str): Exports the collected trace in Chrome JSON format. If kineto is enabled, only last cycle in schedule is exported. """ - assert self.profiler + if self.profiler is None: + raise AssertionError( + "Profiler must be initialized before exporting chrome trace" + ) if path.endswith(".gz"): fp = tempfile.NamedTemporaryFile("w+b", suffix=".json", delete=False) fp.close() @@ -284,7 +289,8 @@ def export_stacks(self, path: str, metric: str = "self_cpu_time_total"): path (str): save stacks file to this location; metric (str): metric to use: "self_cpu_time_total" or "self_cuda_time_total" """ - assert self.profiler + if self.profiler is None: + raise AssertionError("Profiler must be initialized before exporting stacks") return self.profiler.export_stacks(path, metric) def toggle_collection_dynamic( @@ -316,7 +322,7 @@ def toggle_collection_dynamic( print(p.key_averages().table( sort_by="self_cuda_time_total", row_limit=-1)) """ - if not self.profiler: + if self.profiler is None: return self.profiler.toggle_collection_dynamic(enable, activities) @@ -333,7 +339,10 @@ def key_averages( To use shape/stack functionality make sure to set record_shapes/with_stack when creating profiler context manager. """ - assert self.profiler + if self.profiler is None: + raise AssertionError( + "Profiler must be initialized before getting key averages" + ) return self.profiler.key_averages( group_by_input_shape, group_by_stack_n, group_by_overload_name ) @@ -343,7 +352,8 @@ def events(self): Returns the list of unaggregated profiler events, to be used in the trace callback or after the profiling is finished """ - assert self.profiler + if self.profiler is None: + raise AssertionError("Profiler must be initialized before accessing events") return self.profiler.function_events def add_metadata(self, key: str, value: str) -> None: @@ -395,7 +405,10 @@ def _memory_profile(self) -> MemoryProfile: if missing: raise ValueError(f"{', '.join(missing)} required for memory profiling.") - assert self.profiler is not None and self.profiler.kineto_results is not None + if self.profiler is None or self.profiler.kineto_results is None: + raise AssertionError( + "Profiler and kineto_results must be initialized for memory profiling" + ) return MemoryProfile(self.profiler.kineto_results) def export_memory_timeline(self, path: str, device: Optional[str] = None) -> None: @@ -485,7 +498,8 @@ def schedule( """ def schedule_fn(step: int) -> ProfilerAction: - assert step >= 0 + if step < 0: + raise AssertionError(f"Step must be non-negative. Got {step}.") if step < skip_first: return ProfilerAction.NONE else: @@ -508,9 +522,11 @@ def schedule_fn(step: int) -> ProfilerAction: else ProfilerAction.RECORD_AND_SAVE ) - assert ( - wait >= 0 and warmup >= 0 and active > 0 and repeat >= 0 and skip_first >= 0 - ), "Invalid profiler schedule arguments" + if wait < 0 or warmup < 0 or active <= 0 or repeat < 0 or skip_first < 0: + raise AssertionError( + f"Invalid profiler schedule arguments. Got wait={wait} (need >= 0), warmup={warmup} (need >= 0), " + f"active={active} (need > 0), repeat={repeat} (need >= 0), skip_first={skip_first} (need >= 0)." + ) if warmup == 0: warn( "Profiler won't be using warmup, this can skew profiler results", @@ -717,7 +733,8 @@ def __init__( activities_set.add(ProfilerActivity.CUDA) elif ProfilerActivity.CUDA in activities_set: activities_set.remove(ProfilerActivity.CUDA) - assert len(activities_set) > 0, "No valid profiler activities found" + if len(activities_set) == 0: + raise AssertionError("No valid profiler activities found") super().__init__( activities=activities, From c17aa0f11303bcd2cf617efd0cda6f3d38a1a34b Mon Sep 17 00:00:00 2001 From: Jagadish Krishnamoorthy Date: Wed, 5 Nov 2025 18:03:59 +0000 Subject: [PATCH 069/130] [ROCm] Enable group gemm through CK (#166334) Fixes #161366 All the 4 types of dimension matrix are supported. 2d-2d, 2d-3d, 3d-3d, 3d-2d. The corresponding test cases in test_matmul_cuda are working for both forward and backward pass. The CK path is enabled for gfx942, gfx950. ToDo: Need to enable support on gfx90a since the ck kernel used in this commit produces gpu error, might require a different CK kernel config, based on the profiler result on gfx90a. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166334 Approved by: https://github.com/atalman --- aten/src/ATen/native/cuda/GroupedBlas.cpp | 10 + aten/src/ATen/native/hip/ck_group_gemm.h | 19 + aten/src/ATen/native/hip/ck_group_gemm.hip | 462 +++++++++++++++++++++ test/test_matmul_cuda.py | 2 - 4 files changed, 491 insertions(+), 2 deletions(-) create mode 100644 aten/src/ATen/native/hip/ck_group_gemm.h create mode 100644 aten/src/ATen/native/hip/ck_group_gemm.hip diff --git a/aten/src/ATen/native/cuda/GroupedBlas.cpp b/aten/src/ATen/native/cuda/GroupedBlas.cpp index f64eb317d0cca..18ae048cfc968 100644 --- a/aten/src/ATen/native/cuda/GroupedBlas.cpp +++ b/aten/src/ATen/native/cuda/GroupedBlas.cpp @@ -22,6 +22,9 @@ #include #include #include +#ifdef USE_ROCM +#include +#endif #include #ifdef USE_FBGEMM_GENAI @@ -666,12 +669,19 @@ std::optional out_dtype) { // _scaled_mm_allowed_device is used here within _grouped_mm_cuda which seems incorrect since scale is not used. // the _grouped_mm_fallback should be safe for any ROCm GPU since it's just calling typical mm/bmm bool use_fast_path = false; + if (at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950"})) { + use_fast_path = true; + } #endif const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype); Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_); if (use_fast_path) { // fast path, no d2h sync needed +#ifndef USE_ROCM at::cuda::detail::bf16bf16_grouped_mm(mat_a, mat_b, offs, bias, out); +#else + at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out); +#endif } else { _grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out); } diff --git a/aten/src/ATen/native/hip/ck_group_gemm.h b/aten/src/ATen/native/hip/ck_group_gemm.h new file mode 100644 index 0000000000000..c50307c9f8ea3 --- /dev/null +++ b/aten/src/ATen/native/hip/ck_group_gemm.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include +#include + +namespace at { +namespace hip { +namespace detail { +void group_gemm_ck( + const at::Tensor& mat_a, + const at::Tensor& mat_b, + const std::optional& offs, + const std::optional& bias, + at::Tensor& out); + +} // namespace detail +} // namespace hip +} // namespace at diff --git a/aten/src/ATen/native/hip/ck_group_gemm.hip b/aten/src/ATen/native/hip/ck_group_gemm.hip new file mode 100644 index 0000000000000..c436ad660c1c7 --- /dev/null +++ b/aten/src/ATen/native/hip/ck_group_gemm.hip @@ -0,0 +1,462 @@ +#undef __HIP_NO_HALF_CONVERSIONS__ +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +template +using S = ck::Sequence; + +namespace at { +namespace hip { +namespace detail { + +namespace CkTypes { + using BF16 = ck::bhalf_t; + using F16 = ck::half_t; + using F32 = float; + using PassThrough = ck::tensor_operation::element_wise::PassThrough; +} + +template +using GroupedGemmKernel = ck::tensor_operation::device::DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage< + ALayout, BLayout, ck::Tuple<>, ck::tensor_layout::gemm::RowMajor, + DataType, DataType, CkTypes::F32, DataType, ck::Tuple<>, DataType, + CkTypes::PassThrough, CkTypes::PassThrough, CkTypes::PassThrough, + ck::tensor_operation::device::GemmSpecialization::MNKPadding, + 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, + S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>, + 3, 8, 8, 1, + S<1,4,64,1>, S<0,2,1,3>, S<0,2,1,3>, + 3, 8, 8, 1, + 1, 1, + S<1,32,1,8>, 4 +>; + +template +void launch_grouped_bgemm_ck_impl_dispatch( + const at::Tensor& mat_a, + const at::Tensor& mat_b, + const std::optional& offs, + at::Tensor& out) +{ + using DeviceOp = GroupedGemmKernel; + using PassThrough = CkTypes::PassThrough; + + std::vector gemm_descs; + std::vector p_a_ptrs, p_b_ptrs; + std::vector p_e_ptrs; + // Note: d_ptrs will be resized after we populate the other vectors + + const int mat_a_dim = mat_a.dim(); + const int mat_b_dim = mat_b.dim(); + + const char* a_ptr_base = reinterpret_cast(mat_a.data_ptr()); + const char* b_ptr_base = reinterpret_cast(mat_b.data_ptr()); + char* out_ptr_base = reinterpret_cast(out.data_ptr()); + const size_t a_element_size = mat_a.element_size(); + const size_t b_element_size = mat_b.element_size(); + const size_t out_element_size = out.element_size(); + + // for each group, calculate m,n,k,lda,ldb,ldc and A,B,out pointer base addresses. + if (mat_a_dim == 2 && mat_b_dim == 2) { + // 2D*2D case requires offset tensor + auto offs_accessor = offs->accessor(); + int num_groups = offs_accessor.size(0); + const int M = mat_a.size(0); // number of rows in A + const int N = mat_b.size(1); // number of columns in B + const int K = mat_a.size(1); // columns in A == rows in B + // for 2d*2d input, output is 3d. + // for each group, A columns (K) are sliced. M and N dimensions are not sliced. + for (int i = 0; i < num_groups; ++i) { + int start_k = (i == 0) ? 0 : offs_accessor[i-1]; + int end_k = offs_accessor[i]; + int k = end_k - start_k; + + //K dimension are sliced, hence select stride(1) always. + //K dimension is always dimension 1, regardless of memory layout (row/column major) + const void* group_a_ptr = a_ptr_base + start_k * mat_a.stride(1) * a_element_size; + const void* group_b_ptr; + int ldb; + + if (std::is_same::value) { + // Row-major B [K,N]: K values are horizontally adjacent, use stride(1) for K offset + group_b_ptr = b_ptr_base + start_k * mat_b.stride(1) * b_element_size; + // Leading dimension = distance between rows = stride(0) + ldb = mat_b.stride(0); + } else { + // Column-major B [K,N]: K values are vertically adjacent, use stride(0) for K offset + group_b_ptr = b_ptr_base + start_k * mat_b.stride(0) * b_element_size; + // Leading dimension = distance between columns = stride(1) + ldb = mat_b.stride(1); + } + + // Calculate output pointer for group i in 3D tensor [num_groups, M, N] + // stride(0) = M*N elements between groups, so skip i*stride(0) elements to reach group i + void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size; + int lda, ldc; + if (std::is_same::value) { + // Row-major A [M,K]: leading dimension = distance between rows = stride(0) + lda = mat_a.stride(0); + } else { + // Column-major A [M,K]: leading dimension = distance between columns = stride(1) + lda = mat_a.stride(1); + } + // Output is always row-major in 3D tensor [num_groups, M, N] + // Leading dimension for each group's [M,N] slice = stride(1) = N + ldc = out.stride(1); + size_t output_group_bytes = M * N * out_element_size; + void* group_e_ptr_end = (char*)group_e_ptr + output_group_bytes; + + gemm_descs.push_back({ + static_cast(M), + static_cast(N), + static_cast(k), + static_cast(lda), + static_cast(ldb), + static_cast(ldc), + {} // --> stride_Ds_ + }); + p_a_ptrs.push_back(group_a_ptr); + p_b_ptrs.push_back(group_b_ptr); + p_e_ptrs.push_back(group_e_ptr); + } + } else if (mat_a_dim == 2 && mat_b_dim == 3) { + // 2D*3D case requires offset tensor + auto offs_accessor = offs->accessor(); + int num_groups = offs_accessor.size(0); + + // 2d*3d input, output is 2d. + // A: [m * n_groups, k], B: [n_groups, n, k] or [n_groups, k, n], Output: [m * n_groups, n] + // Offset divides M dimension (rows of A), each group gets different rows of A and different batch of B + const int K = mat_a.size(1); // columns in A + // For 2D-3D case: The output determines N (result width) + const int N = out.size(1); // N is the width of the output tensor + + for (int i = 0; i < num_groups; ++i) { + int start_m = (i == 0) ? 0 : offs_accessor[i - 1]; + int end_m = offs_accessor[i]; + int m = end_m - start_m; + + // Skip zero-sized groups but continue processing subsequent groups + if (m <= 0) { + continue; + } + + // Select A rows for group i: skip start_m rows + const void* group_a_ptr; + int lda; + if (std::is_same::value) { + // Row-major A [total_m, K]: skip start_m rows, each row is stride(0) elements apart + group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size; + lda = mat_a.stride(0); // distance between rows + } else { + // Column-major A [total_m, K]: skip start_m elements in the first dimension (stride(0) is between rows) + group_a_ptr = a_ptr_base + start_m * mat_a.stride(0) * a_element_size; + + // Detect stride pattern for A tensor to determine appropriate lda calculation + bool a_is_strided_tensor = (mat_a.stride(0) > mat_a.size(0)); + + if (a_is_strided_tensor) { + // For strided A tensors: stride(0) gives the actual leading dimension + lda = mat_a.stride(0); + } else { + // For non-strided A tensors: use the M dimension (total rows) + lda = mat_a.size(0); // Total M dimension for column-major layout + } + } + + // Select B batch for group i: B[i, :, :] + const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size; + int ldb; + + if (std::is_same::value) { + // Row-major GEMM: expecting B as [K, N] but we have [N, K], so transpose needed + ldb = mat_b.stride(2); // Leading dimension for accessing as [K, N] + } else { + // Detect stride pattern to determine appropriate ldb calculation + bool is_strided_tensor = (mat_b.stride(2) > mat_b.size(2)); + + if (is_strided_tensor) { + // For strided tensors: stride(2) gives the actual leading dimension + ldb = mat_b.stride(2); + } else { + // For non-strided tensors: use the N dimension + ldb = mat_b.size(1); + } + } + + // Output for this group: rows [start_m:end_m, :] in 2D output [total_m, N] + void* group_e_ptr = out_ptr_base + start_m * out.stride(0) * out_element_size; + int ldc = out.stride(0); // distance between rows in output (should be N for 2D case) + + gemm_descs.push_back({ + static_cast(m), + static_cast(N), + static_cast(K), + static_cast(lda), + static_cast(ldb), + static_cast(ldc), + {} // --> stride_Ds_ + }); + p_a_ptrs.push_back(group_a_ptr); + p_b_ptrs.push_back(group_b_ptr); + p_e_ptrs.push_back(group_e_ptr); + } + } else if (mat_a_dim == 3 && mat_b_dim == 3) { + // 3d*3d input, output is 3d - batched matrix multiplication + // A: [batch, m, k], B: [batch, k, n] or [batch, n, k] (depending on transpose), Output: [batch, m, n] + // Each batch is processed as a separate GEMM operation + const int batch_size = mat_a.size(0); + const int M = mat_a.size(1); // rows in each A matrix + const int K = mat_a.size(2); // columns in A == rows in B (or columns if B is transposed) + + // Determine N from B tensor - it could be B.size(1) or B.size(2) depending on layout + int N; + if (mat_b.size(1) == K) { + // B is [batch, k, n] - normal layout + N = mat_b.size(2); + } else if (mat_b.size(2) == K) { + // B is [batch, n, k] - transposed layout + N = mat_b.size(1); + } else { + TORCH_CHECK(false, "CK Group GEMM 3D-3D: B tensor dimensions incompatible with A. A=[", + batch_size, ",", M, ",", K, "], B=[", mat_b.size(0), ",", mat_b.size(1), ",", mat_b.size(2), "]"); + } + + for (int i = 0; i < batch_size; ++i) { + // Select A batch for group i: A[i, :, :] + const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size; + + // Select B batch for group i: B[i, :, :] + const void* group_b_ptr = b_ptr_base + i * mat_b.stride(0) * b_element_size; + + // Select output batch for group i: Output[i, :, :] + void* group_e_ptr = out_ptr_base + i * out.stride(0) * out_element_size; + + int lda, ldb, ldc; + + if (std::is_same::value) { + // Row-major A: leading dimension = distance between rows = stride(1) + lda = mat_a.stride(1); + } else { + // Column-major A: leading dimension = distance between columns = stride(2) + lda = mat_a.stride(2); + } + + if (std::is_same::value) { + // Row-major B: leading dimension = distance between rows + if (mat_b.size(1) == K) { + // B is [batch, k, n] - normal layout + ldb = mat_b.stride(1); // stride between K rows + } else { + // B is [batch, n, k] - transposed layout, treat as [k, n] for GEMM + ldb = mat_b.stride(2); // stride between N rows (since we're accessing as [k,n]) + } + } else { + // Column-major B: leading dimension = distance between columns + if (mat_b.size(1) == K) { + // B is [batch, k, n] - normal layout + ldb = mat_b.stride(2); // stride between N columns + } else { + // B is [batch, n, k] - transposed layout + ldb = mat_b.stride(1); // stride between K columns (since we're accessing as [n,k]→[k,n]) + } + } + + // Output is typically row-major: leading dimension = distance between rows = stride(1) + ldc = out.stride(1); + + gemm_descs.push_back({ + static_cast(M), + static_cast(N), + static_cast(K), + static_cast(lda), + static_cast(ldb), + static_cast(ldc), + {} // --> stride_Ds_ + }); + p_a_ptrs.push_back(group_a_ptr); + p_b_ptrs.push_back(group_b_ptr); + p_e_ptrs.push_back(group_e_ptr); + } + } else if (mat_a_dim == 3 && mat_b_dim == 2) { + // 3D*2D case requires offset tensor + auto offs_accessor = offs->accessor(); + int num_groups = offs_accessor.size(0); + // 3d*2d input, output is 3d. + // A: [n_groups, m, k], B: [k, total_n] (assuming row-major for both) + // Offset divides N dimension of B, each group gets different slice of B and different batch of A + const int batch_size = mat_a.size(0); // n_groups + const int M = mat_a.size(1); // rows in each A matrix + const int K = mat_a.size(2); // columns in A + + // For row-major A and B case: B should be [K, total_N] + const int total_N = mat_b.size(1); // B is [K, total_N] for row-major + + for (int i = 0; i < num_groups; ++i) { + int start_n = (i == 0) ? 0 : offs_accessor[i - 1]; + int end_n = offs_accessor[i]; + int n = end_n - start_n; + + // Skip zero-sized groups but continue processing subsequent groups + if (n <= 0) { + continue; + } + + // Select A batch for group i: A[i, :, :] + const void* group_a_ptr = a_ptr_base + i * mat_a.stride(0) * a_element_size; + + // Select B slice for group i: B[:, start_n:end_n] (B[K, total_N]) + const void* group_b_ptr; + int ldb; + + // Check if B is row-major or column-major + if (std::is_same::value) { + // Row-major B [K, total_N]: slice columns [start_n:end_n] + group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size; + ldb = mat_b.stride(0); // distance between rows (should be total_N) + } else { + // Column-major B [K, total_N]: slice columns [start_n:end_n] + group_b_ptr = b_ptr_base + start_n * mat_b.stride(1) * b_element_size; + ldb = mat_b.stride(1); // distance between columns (should be K) + } + + // Select output slice for group i: Output[:, start_n:end_n] + void* group_e_ptr = out_ptr_base + start_n * out.stride(1) * out_element_size; + + int lda, ldc; + + // Row-major A: leading dimension = distance between rows = stride(1) + lda = mat_a.stride(1); + // Output is row-major: leading dimension = distance between rows = stride(0) + ldc = out.stride(0); + + gemm_descs.push_back({ + static_cast(M), + static_cast(n), + static_cast(K), + static_cast(lda), + static_cast(ldb), + static_cast(ldc), + {} // --> stride_Ds_ + }); + p_a_ptrs.push_back(group_a_ptr); + p_b_ptrs.push_back(group_b_ptr); + p_e_ptrs.push_back(group_e_ptr); + } + } else { + TORCH_CHECK(false, "CK Group GEMM: Unsupported dimensions, mat A dim is ", mat_a_dim, ", mat B dim is ", mat_b_dim); + } + + TORCH_INTERNAL_ASSERT(p_a_ptrs.size() > 0, "CK Group GEMM: No valid groups"); + + // Initialize d_ptrs with the correct size + std::vector> d_ptrs(p_a_ptrs.size()); + + static DeviceOp gemm_instance; + auto argument = gemm_instance.MakeArgument( + p_a_ptrs, p_b_ptrs, d_ptrs, p_e_ptrs, + gemm_descs, PassThrough{}, PassThrough{}, PassThrough{} + ); + TORCH_INTERNAL_ASSERT(gemm_instance.IsSupportedArgument(argument), + "CK Group GEMM: argument unsupported (shape/strides/type config)"); + size_t arg_buf_size = gemm_instance.GetDeviceKernelArgSize(&argument); + size_t ws_size = gemm_instance.GetWorkSpaceSize(&argument); + + void* gemm_arg_buf = nullptr; + void* ws_buf = nullptr; + + hipMalloc(&gemm_arg_buf, arg_buf_size); + hipMalloc(&ws_buf, ws_size); + + gemm_instance.SetDeviceKernelArgs(&argument, gemm_arg_buf); + gemm_instance.SetWorkSpacePointer(&argument, ws_buf); + + auto invoker = gemm_instance.MakeInvoker(); + hipStream_t stream = c10::hip::getCurrentHIPStream(); + invoker.Run(argument, {stream}); + hipFree(gemm_arg_buf); + hipFree(ws_buf); +} + +void group_gemm_ck( + const at::Tensor& input_a, + const at::Tensor& input_b_colmajor, + const std::optional& offs, + const std::optional& /*bias*/, + at::Tensor& out) +{ + // Detect if input_a is row-major based on stride pattern + bool a_row_major = (input_a.dim() == 3) ? (input_a.stride(2) == 1) : (input_a.stride(1) == 1); + bool b_col_major = (input_b_colmajor.dim() == 3) ? (input_b_colmajor.stride(1) == 1) : (input_b_colmajor.stride(0) == 1); + // Ensure tensor A is row-major and contiguous if not already + at::Tensor mat_a = input_a; + if (!a_row_major) { + // If A is not row-major, make it contiguous (row-major) + mat_a = input_a.contiguous(); + } + // Force tensor B to be column-major using double transpose trick + // This guarantees stride(0) == 1 and stride(1) == K for [K, N] shape + at::Tensor mat_b = input_b_colmajor; + if (!b_col_major) { + mat_b = input_b_colmajor.transpose(-2, -1).contiguous().transpose(-2, -1); + } + + // For 3D tensors, check the last dimension stride for row-major detection + a_row_major = (mat_a.dim() == 3) ? (mat_a.stride(2) == 1) : (mat_a.stride(1) == 1); + bool b_row_major = (mat_b.dim() == 3) ? (mat_b.stride(2) == 1) : (mat_b.stride(1) == 1); + + if (mat_a.dtype() == at::kBFloat16) { + // bf16 path + if (a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (a_row_major && !b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (!a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } + } else if (mat_a.dtype() == at::kHalf) { + // fp16 path + if (a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (a_row_major && !b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (!a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } + } else if (mat_a.dtype() == at::kFloat) { + // fp32 path + if (a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (a_row_major && !b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else if (!a_row_major && b_row_major) { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } else { + launch_grouped_bgemm_ck_impl_dispatch(mat_a, mat_b, offs, out); + } + } else { + TORCH_CHECK(false, "CK Group GEMM: Unsupported mat_a dtype"); + } + +} + +} // namespace detail +} // namespace hip +} // namespace at diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 1ba947befd9e7..10611d4f24673 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -490,8 +490,6 @@ def test_grouped_gemm_3d_3d(self, strided, a_row_major, b_row_major, dtype): @parametrize("b_row_major", [False, True]) @dtypes(torch.bfloat16, torch.float32, torch.float16) def test_grouped_gemm_3d_2d(self, strided, a_row_major, b_row_major, dtype): - if TEST_WITH_ROCM and a_row_major and b_row_major and dtype in [torch.bfloat16, torch.float16]: - self.skipTest("failed using hipblaslt on rocm 6.4.2") device = "cuda" s_int = int(strided) m, n, k, n_groups = 16, 32, 64, 4 From c86540f12038ffc4a3c9ecdbecb01ce73e0967c9 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 5 Nov 2025 18:11:11 +0000 Subject: [PATCH 070/130] Revert "Add model code stack trace to torch.profile (#166677)" This reverts commit c00696144dae1f02e04ce345480b55e46c7d32a8. Reverted https://github.com/pytorch/pytorch/pull/166677 on behalf of https://github.com/jeffdaily due to broke rocm ([comment](https://github.com/pytorch/pytorch/pull/166677#issuecomment-3492658160)) --- ...t-fx_backcompat_function_signatures.expect | 2 +- test/test_fx.py | 180 ------------------ torch/autograd/profiler_util.py | 40 ---- torch/fx/graph.py | 23 --- torch/fx/graph_module.py | 16 +- torch/profiler/_utils.py | 169 +--------------- 6 files changed, 5 insertions(+), 425 deletions(-) diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect index 12f6ba2228db8..a404e15a977ee 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect @@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None) torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node torch.fx.graph.Graph.print_tabular(self) -torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False, record_func: bool = False) -> torch.fx.graph.PythonCode +torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule') torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None diff --git a/test/test_fx.py b/test/test_fx.py index e12189dfea461..92d35fd8f49ad 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -75,12 +75,6 @@ ) from torch.testing._internal.jit_utils import JitTestCase -import json -import tempfile -from torch.profiler import profile, ProfilerActivity -from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace -from torch.autograd.profiler_util import _canonicalize_profiler_events - try: from torchvision import models as torchvision_models @@ -207,36 +201,6 @@ def side_effect_func(x: torch.Tensor): print(x) -def _enrich_profiler_traces(prof): - """ - Helper function to extract and augment profiler events with stack traces. - - Args: - prof: A torch.profiler.profile object - - Returns: - A string representing enriched events - """ - with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f: - trace_file = f.name - prof.export_chrome_trace(trace_file) - - with open(trace_file) as f: - trace_data = json.load(f) - - map_recorded_events_to_aten_ops_with_stack_trace( - trace_data - ) - - events = [] - for event in trace_data["traceEvents"]: - if "args" in event and "stack_trace" in event["args"]: - events.append(event) - - actual_traces = _canonicalize_profiler_events(events) - return actual_traces - - class TestFX(JitTestCase): def setUp(self): super().setUp() @@ -4248,150 +4212,6 @@ def fn(a, b, c, d): # recorver mutable checking flag torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - @torch._dynamo.config.patch("enrich_profiler_metadata", True) - def test_profiler_stack_trace_augmentation(self): - """ - Test that map_recorded_events_to_aten_ops_with_stack_trace correctly - augments profiler events with stack traces from FX metadata registry. - """ - - # Simple test model - class TestModel(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear1 = torch.nn.Linear(10, 16) - self.relu = torch.nn.ReLU() - self.linear2 = torch.nn.Linear(16, 10) - - def forward(self, x): - x = self.linear1(x) - x = self.relu(x) - x = self.linear2(x) - return x - - model = TestModel().cuda() - - # Compile the model - compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True) - - # Warmup - for _ in range(3): - _ = compiled_model(torch.randn(10, 10, device="cuda")) - - # Profile with the compiled model - with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - ) as prof: - result = compiled_model(torch.randn(10, 10, device="cuda")) - - actual_traces = _enrich_profiler_traces(prof) - - self.assertExpectedInline(actual_traces, """\ -event=aten::t node=t stack_trace=x = self.linear1(x) -event=aten::transpose node=t stack_trace=x = self.linear1(x) -event=aten::as_strided node=t stack_trace=x = self.linear1(x) -event=aten::addmm node=addmm stack_trace=x = self.linear1(x) -event=cudaLaunchKernel node=addmm stack_trace=x = self.linear1(x) -event=aten::relu node=relu stack_trace=x = self.relu(x) -event=aten::clamp_min node=relu stack_trace=x = self.relu(x) -event=cudaLaunchKernel node=relu stack_trace=x = self.relu(x) -event=aten::t node=t_1 stack_trace=x = self.linear2(x) -event=aten::transpose node=t_1 stack_trace=x = self.linear2(x) -event=aten::as_strided node=t_1 stack_trace=x = self.linear2(x) -event=aten::addmm node=addmm_1 stack_trace=x = self.linear2(x) -event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)""" - ) - - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - @torch._dynamo.config.patch("enrich_profiler_metadata", True) - def test_profiler_multiple_modules(self): - """ - Test that multiple compiled modules under the same profiler session - have their events correctly augmented with stack traces. - """ - - class ModelA(torch.nn.Module): - def forward(self, x): - return x + 1 - - class ModelB(torch.nn.Module): - def forward(self, x): - return x - 1 - - model_a = ModelA().cuda() - model_b = ModelB().cuda() - - # Compile both models - compiled_a = torch.compile(model_a, backend="aot_eager", fullgraph=True) - compiled_b = torch.compile(model_b, backend="aot_eager", fullgraph=True) - - # Warmup - for _ in range(3): - _ = compiled_a(torch.randn(10, 10, device="cuda")) - _ = compiled_b(torch.randn(1, 3, 8, 8, device="cuda")) - - # Profile both models in the same session - with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - ) as prof: - result_a = compiled_a(torch.randn(10, 10, device="cuda")) - result_b = compiled_b(torch.randn(1, 3, 8, 8, device="cuda")) - - actual_traces = _enrich_profiler_traces(prof) - self.assertExpectedInline(actual_traces, """\ -event=aten::add node=add stack_trace=return x + 1 -event=cudaLaunchKernel node=add stack_trace=return x + 1 -event=aten::sub node=sub stack_trace=return x - 1 -event=cudaLaunchKernel node=sub stack_trace=return x - 1""" - ) - - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - @torch._dynamo.config.patch("enrich_profiler_metadata", True) - def test_profiler_nested_graph_modules(self): - """ - Test that nested graph modules (e.g., graph modules calling subgraphs) - have their events correctly augmented with stack traces. - """ - - # Model with nested structure - class Mod(torch.nn.Module): - def __init__(self): - super().__init__() - self.c = 5 - - @torch.compiler.nested_compile_region - def forward(self, x, y): - m = torch.mul(x, y) - s = m.sin() - a = s + self.c - return a - - model = Mod().cuda() - - # Compile the model (this may create nested graph modules) - compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True) - - # Warmup - for _ in range(3): - _ = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) - - # Profile - with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - ) as prof: - result = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) - - actual_traces = _enrich_profiler_traces(prof) - self.assertExpectedInline(actual_traces, """\ -event=aten::mul node=mul stack_trace=m = torch.mul(x, y) -event=cudaLaunchKernel node=mul stack_trace=m = torch.mul(x, y) -event=aten::sin node=sin stack_trace=s = m.sin() -event=cudaLaunchKernel node=sin stack_trace=s = m.sin() -event=aten::add node=add stack_trace=a = s + self.c -event=cudaLaunchKernel node=add stack_trace=a = s + self.c""" - ) - def run_getitem_target(): from torch.fx._symbolic_trace import _wrapped_methods_to_patch diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index a61aee321fcff..b2d6530049e61 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -1224,43 +1224,3 @@ def override_time_unit(time_us, default_str, time_unit): f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}" ) return "".join(result) - - -# Collect all events with stack traces and format them canonically -def _canonicalize_profiler_events(events): - """ - Extract and format all events with stack traces in a canonical way - for deterministic testing. - """ - events_with_traces = [] - - for event in events: - # Extract relevant fields - event_name = event.get("name", "") - node_name = event["args"].get("node_name", "") - stack_trace = event["args"].get("stack_trace", "") - - # Get the last non-empty line of the stack trace - lines = [s.strip() for s in stack_trace.split("\n") if s.strip()] - stack_trace = lines[-1] if lines else "" - - events_with_traces.append( - { - "event_name": event_name[:20], - "node_name": node_name, - "stack_trace": stack_trace, - "start_time": event.get("ts", 0), - } - ) - - # Sort by node_name for deterministic ordering - events_with_traces.sort(key=lambda x: x["start_time"]) - - # Format as a string - lines: list[str] = [] - for evt in events_with_traces: - lines.append( - f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}" - ) - - return "\n".join(lines) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index fd6835d2b301b..697b2f4084ca5 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -443,7 +443,6 @@ def _gen_python_code( colored: bool = False, # Render each argument on its own line expanded_def: bool = False, - record_func: bool = False, ) -> PythonCode: free_vars: list[str] = [] body: list[str] = [] @@ -799,10 +798,6 @@ def _tensor_annotation(t: torch.Tensor) -> str: return raise NotImplementedError(f"node: {node.op} {node.target}") - if record_func: - body.append( - "_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()\n" - ) for i, node in enumerate(nodes): # NOTE: emit_node does not emit a string with newline. It depends # on delete_unused_values to append one @@ -812,22 +807,8 @@ def _tensor_annotation(t: torch.Tensor) -> str: # node index, which will be deleted later # after going through _body_transformer body.append(f"# COUNTER: {i}\n") - do_record = record_func and node.op in ( - "call_function", - "call_method", - "call_module", - ) - if do_record: - # The double hash ## convention is used by post-processing to find the fx markers - body.append( - f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {i} ##'); _rf_{node.name}.__enter__()\n" - ) emit_node(node) delete_unused_values(node) - if do_record: - body.append(f"_rf_{node.name}.__exit__(None, None, None)\n") - if record_func: - body.append("_rf.__exit__(None, None, None)\n") if len(body) == 0: # If the Graph has no non-placeholder nodes, no lines for the body @@ -1779,7 +1760,6 @@ def python_code( include_device: bool = False, colored: bool = False, expanded_def: bool = False, - record_func: bool = False, ) -> PythonCode: """ Turn this ``Graph`` into valid Python code. @@ -1847,7 +1827,6 @@ def override_node_repr(graph: Graph): include_device=include_device, colored=colored, expanded_def=expanded_def, - record_func=record_func, ) def _python_code( @@ -1860,7 +1839,6 @@ def _python_code( include_device: bool = False, colored: bool = False, expanded_def: bool = False, - record_func: bool = False, ) -> PythonCode: return self._codegen._gen_python_code( self.nodes, @@ -1871,7 +1849,6 @@ def _python_code( include_device=include_device, colored=colored, expanded_def=expanded_def, - record_func=record_func, ) def __str__(self) -> str: diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 8360c96630d6c..297f76732584f 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -861,18 +861,14 @@ def recompile(self) -> PythonCode: if isinstance(self._graph._codegen, _PyTreeCodeGen): self._in_spec = self._graph._codegen.pytree_info.in_spec self._out_spec = self._graph._codegen.pytree_info.out_spec - - from torch._dynamo import config as dynamo_config - - python_code = self._graph.python_code( - root_module="self", record_func=dynamo_config.enrich_profiler_metadata - ) + python_code = self._graph.python_code(root_module="self") self._code = python_code.src self._lineno_map = python_code._lineno_map self._prologue_start = python_code._prologue_start cls = type(self) co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} + from torch._dynamo import config as dynamo_config if dynamo_config.enrich_profiler_metadata: # Generate metadata and register for profiler augmentation @@ -889,6 +885,7 @@ def recompile(self) -> PythonCode: # This ensures the same code+metadata always generates the same filename hash_value = _metadata_hash(self._code, node_metadata) file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}" + filename = f"{file_stem}.py" # Only include co_filename to use it directly as the cache key @@ -908,13 +905,6 @@ def recompile(self) -> PythonCode: _register_fx_metadata(filename, metadata) - # Replace the placeholder in generated code with actual filename - # The double hash ## convention is used by post-processing to find the fx markers - self._code = self._code.replace( - "torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')", - f"torch._C._profiler._RecordFunctionFast('## {filename} ##')", - ) - cls.forward = _forward_from_src(self._code, python_code.globals, co_fields) # Determine whether this class explicitly defines a __call__ implementation diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index 47df87ce1678d..2c6e06b2cb3c9 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -4,7 +4,7 @@ import re from collections import deque from dataclasses import dataclass -from typing import Any, Literal, Optional, TYPE_CHECKING +from typing import TYPE_CHECKING from torch.autograd.profiler import profile from torch.profiler import DeviceType @@ -400,170 +400,3 @@ def _init_for_cuda_graphs() -> None: with profile(): pass - - -@dataclass -class TimelineEvent: - """Represents an event in the profiler timeline.""" - - timestamp: int - event_type: Literal["start", "end", "regular"] - marker_type: Optional[Literal["filename", "node"]] - identifier: Optional[str | int] - event: dict[str, Any] - - -@dataclass -class ContextStackEntry: - """Represents a context (filename or node) in the stack.""" - - context_type: Literal["filename", "node"] - identifier: str | int - metadata: Optional[dict] - tid: Optional[int] = None # Thread ID associated with this context - - -def map_recorded_events_to_aten_ops_with_stack_trace(traced_data): - """ - Maps recorded profiler events to their corresponding fx nodes and adds stack traces. - - Builds a timeline of all events (regular ops and FX markers for filenames/nodes), - sorts by timestamp, then processes chronologically while maintaining a context stack of active - filename/node scopes. Regular events are augmented with stack traces and node names from the - innermost active context. Runtime is O(n log n) for n events. - - Args: - traced_data: Json of profiler events from Chrome trace - - Returns: - Dict mapping recorded event names to their aten operations with added stack traces - """ - from torch.fx.traceback import _FX_METADATA_REGISTRY - - trace_events = traced_data.get("traceEvents", []) - - # Create event timeline - event_timeline: list[TimelineEvent] = [] - - def is_fx_marker_event(event): - return ( - event.get("cat") == "cpu_op" - and event.get("name", "").startswith("## ") - and event.get("name", "").endswith(" ##") - ) - - def append_fx_marker_event(event_type, identifier, event): - start_ts = event["ts"] - end_ts = start_ts + event["dur"] - event_timeline.append( - TimelineEvent(start_ts, "start", event_type, identifier, event) - ) - event_timeline.append( - TimelineEvent(end_ts, "end", event_type, identifier, event) - ) - - for event in trace_events: - if "ts" not in event or "dur" not in event: - continue - - if is_fx_marker_event(event): - content = event["name"][3:-3] - - if content.endswith(".py"): - append_fx_marker_event("filename", content, event) - else: - try: - node_index = int(content) - except ValueError: - pass - append_fx_marker_event("node", node_index, event) # type: ignore[possibly-undefined] - - else: - # Regular event that needs augmentation - start_ts = event["ts"] - event_timeline.append(TimelineEvent(start_ts, "regular", None, None, event)) - - # Sort by timestamp - event_timeline.sort(key=lambda x: x.timestamp) - - # Process events in chronological order with a stack - context_stack: list[ContextStackEntry] = [] - - # Invariant: all start event has a corresponding end event - for timeline_event in event_timeline: - match timeline_event.event_type: - case "start": - assert timeline_event.identifier is not None - - if timeline_event.marker_type == "filename": - assert isinstance(timeline_event.identifier, str) - # Push filename context - query metadata registry on-demand - metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier) - tid = timeline_event.event.get("tid") - context_stack.append( - ContextStackEntry( - "filename", timeline_event.identifier, metadata, tid - ) - ) - elif timeline_event.marker_type == "node": - # Find the current filename from stack - current_file_metadata = None - tid = timeline_event.event.get("tid") - for ctx_entry in reversed(context_stack): - if ( - ctx_entry.context_type == "filename" - and ctx_entry.tid == tid - ): - current_file_metadata = ctx_entry.metadata - break - - if current_file_metadata: - node_metadata = current_file_metadata.get("node_metadata", {}) - if timeline_event.identifier in node_metadata: - node_meta: Optional[dict] = node_metadata[ - timeline_event.identifier - ] - context_stack.append( - ContextStackEntry( - "node", timeline_event.identifier, node_meta, tid - ) - ) - - case "end": - # Pop from stack - search backwards to find matching context - for i in range(len(context_stack) - 1, -1, -1): - ctx_entry = context_stack[i] - if ( - timeline_event.marker_type == ctx_entry.context_type - and timeline_event.identifier == ctx_entry.identifier - ): - context_stack.pop(i) - break - - case "regular": - # Apply metadata from current context stack - # Find the most specific context (node takes precedence over filename) - # Only augment events with the same tid as the file/node event matched - current_stack_trace = None - current_node_name = None - event_tid = timeline_event.event.get("tid") - - for ctx_entry in reversed(context_stack): - # Only apply metadata from contexts with matching tid - if ctx_entry.tid == event_tid: - if ctx_entry.context_type == "node" and ctx_entry.metadata: - current_stack_trace = ctx_entry.metadata.get( - "stack_trace", "No model stack trace available" - ) - current_node_name = ctx_entry.metadata.get("name", "") - # Do we want to only attach the stack trace of the lowest node or stack trace of all nodes - # if nodes are nested, e.g. in nested graph modules - break - - # Augment the event - if current_stack_trace or current_node_name: - args = timeline_event.event.setdefault("args", {}) - if current_stack_trace: - args["stack_trace"] = current_stack_trace - if current_node_name: - args["node_name"] = current_node_name From ad5c7c20e0dd55baa23a597cf10ffe7422b5cabf Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 5 Nov 2025 18:13:57 +0000 Subject: [PATCH 071/130] Revert "[cuDNN] Smoke-test runtime cuDNN version matches compile time version in CI (#165922)" This reverts commit 1d3f5e19da068ec1340db041b7105b287a513578. Reverted https://github.com/pytorch/pytorch/pull/165922 on behalf of https://github.com/atalman due to Introduces Segfault in linux-jammy-cuda12.8-py3.10-gcc11 ([comment](https://github.com/pytorch/pytorch/pull/165922#issuecomment-3492667312)) --- .ci/docker/common/install_cuda.sh | 2 +- .ci/pytorch/smoke_test/smoke_test.py | 12 ------------ 2 files changed, 1 insertion(+), 13 deletions(-) diff --git a/.ci/docker/common/install_cuda.sh b/.ci/docker/common/install_cuda.sh index fe0cb8cc79c4f..fe2f9ae3185a3 100644 --- a/.ci/docker/common/install_cuda.sh +++ b/.ci/docker/common/install_cuda.sh @@ -129,7 +129,7 @@ function install_129 { } function install_128 { - CUDNN_VERSION=9.10.2.21 + CUDNN_VERSION=9.8.0.87 echo "Installing CUDA 12.8.1 and cuDNN ${CUDNN_VERSION} and NVSHMEM and NCCL and cuSparseLt-0.7.1" # install CUDA 12.8.1 in the same container install_cuda 12.8.1 cuda_12.8.1_570.124.06_linux diff --git a/.ci/pytorch/smoke_test/smoke_test.py b/.ci/pytorch/smoke_test/smoke_test.py index 3642f29684cf0..675d58a3e283d 100644 --- a/.ci/pytorch/smoke_test/smoke_test.py +++ b/.ci/pytorch/smoke_test/smoke_test.py @@ -272,18 +272,6 @@ def smoke_test_cuda( torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version()) print(f"Torch cuDNN version: {torch_cudnn_version}") - torch_cudnn_compile_version = torch._C._cudnn.getCompileVersion() - print(f"Torch cuDNN compile-time version: {torch_cudnn_compile_version}") - torch_cudnn_runtime_version = tuple( - [int(x) for x in torch_cudnn_version.split(".")] - ) - if torch_cudnn_runtime_version != torch_cudnn_compile_version: - raise RuntimeError( - "cuDNN runtime version doesn't match comple version. " - f"Loaded: {torch_cudnn_runtime_version} " - f"Expected: {torch_cudnn_compile_version}" - ) - if sys.platform in ["linux", "linux2"]: torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version()) print(f"Torch nccl; version: {torch_nccl_version}") From dcc2ba4ca48512968e027e765695490476d717dc Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Wed, 5 Nov 2025 06:52:49 -0800 Subject: [PATCH 072/130] Add some code for exploring the space of accessible size/stride configs via plain views (#167076) We are working on a translation from as_strided to view operations, but only when the as_strided is representable as a plain view. A useful testing utility in this situation is the ability to enumerate all valid views on an original tensor. So we have a small test here that shows it is possible. To avoid an explosion of states, we don't handle permutes and size=1, which are degenerate cases (you can always do a single permute and a series of unsqueezes to get to the final desired state.) Authored with claude code assistance. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/167076 Approved by: https://github.com/albanD ghstack dependencies: #166868, #166867 --- test/test_as_strided.py | 176 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 176 insertions(+) create mode 100644 test/test_as_strided.py diff --git a/test/test_as_strided.py b/test/test_as_strided.py new file mode 100644 index 0000000000000..a5bcb8e279247 --- /dev/null +++ b/test/test_as_strided.py @@ -0,0 +1,176 @@ +# Owner(s): ["oncall: pt2"] + +from collections import deque +from typing import Optional + +import torch +from torch.testing._internal.common_utils import run_tests, TestCase + + +def get_state(t: torch.Tensor) -> tuple[tuple[int, ...], tuple[int, ...]]: + """Extract (sizes, strides) tuple from a tensor.""" + return (tuple(t.size()), tuple(t.stride())) + + +def enumerate_reachable_states( + initial_size: int, +) -> set[tuple[tuple[int, ...], tuple[int, ...]]]: + """ + Use BFS with DP to enumerate all reachable (size, stride) states from + a 1D contiguous tensor via valid view operations. + + We only explore states with offset=0 (you can retroactively change the offset). + We reject states with size=0 or size=1 dimensions as they are degenerate. + """ + # Create initial 1D contiguous tensor + initial_tensor = torch.arange(initial_size) + + initial_state = get_state(initial_tensor) + + # Map from state to tensor for that state + state_to_tensor: dict[tuple[tuple[int, ...], tuple[int, ...]], torch.Tensor] = { + initial_state: initial_tensor + } + visited: set[tuple[tuple[int, ...], tuple[int, ...]]] = {initial_state} + queue: deque[tuple[tuple[int, ...], tuple[int, ...]]] = deque([initial_state]) + + while queue: + state = queue.popleft() + t = state_to_tensor[state] + sizes, strides = state + ndim = len(sizes) + + def add_state(new_t: torch.Tensor) -> None: + new_state = get_state(new_t) + sizes, strides = new_state + # Skip if has size-0 or size-1 dimensions + if any(s == 0 or s == 1 for s in sizes): + return + # Only accept states where strides are in descending order + if list(strides) != sorted(strides, reverse=True): + return + if new_state not in visited: + visited.add(new_state) + queue.append(new_state) + state_to_tensor[new_state] = new_t + + # 1. Unflatten: try factoring each dimension + for dim in range(ndim): + size = sizes[dim] + assert size > 1 + # Try all factorizations x * y = size where both x, y >= 2 + # We only need to check x up to size // 2 since when x > size // 2, + # y = size // x < 2, which we reject + for x in range(2, size // 2 + 1): + if size % x == 0: + y = size // x + add_state(t.unflatten(dim, (x, y))) + + # 2. Slice: exhaustively check all possible slicing parameters + for dim in range(ndim): + size = sizes[dim] + for start in range(size): + for stop in range(start + 1, size + 1): + for step in range(1, size + 1): + slices = [slice(None)] * ndim + slices[dim] = slice(start, stop, step) + add_state(t[tuple(slices)]) + + # 3. Flatten: merge adjacent dimensions + for dim in range(ndim - 1): + add_state(t.flatten(dim, dim + 1)) + + return visited + + +class TestAsStrided(TestCase): + def test_size_10_exhaustive(self) -> None: + """Test that size 10 produces exactly the expected 54 states.""" + expected_states = { + ((2,), (1,)), + ((2,), (2,)), + ((2,), (3,)), + ((2,), (4,)), + ((2,), (5,)), + ((2,), (6,)), + ((2,), (7,)), + ((2,), (8,)), + ((2,), (9,)), + ((2, 2), (2, 1)), + ((2, 2), (3, 1)), + ((2, 2), (3, 2)), + ((2, 2), (4, 1)), + ((2, 2), (4, 2)), + ((2, 2), (4, 3)), + ((2, 2), (5, 1)), + ((2, 2), (5, 2)), + ((2, 2), (5, 3)), + ((2, 2), (5, 4)), + ((2, 2), (6, 1)), + ((2, 2), (6, 2)), + ((2, 2), (6, 3)), + ((2, 2), (8, 1)), + ((2, 2, 2), (4, 2, 1)), + ((2, 2, 2), (5, 2, 1)), + ((2, 3), (3, 1)), + ((2, 3), (4, 1)), + ((2, 3), (5, 1)), + ((2, 3), (5, 2)), + ((2, 3), (6, 1)), + ((2, 4), (4, 1)), + ((2, 4), (5, 1)), + ((2, 5), (5, 1)), + ((3,), (1,)), + ((3,), (2,)), + ((3,), (3,)), + ((3,), (4,)), + ((3, 2), (2, 1)), + ((3, 2), (3, 1)), + ((3, 2), (3, 2)), + ((3, 2), (4, 1)), + ((3, 3), (3, 1)), + ((4,), (1,)), + ((4,), (2,)), + ((4,), (3,)), + ((4, 2), (2, 1)), + ((5,), (1,)), + ((5,), (2,)), + ((5, 2), (2, 1)), + ((6,), (1,)), + ((7,), (1,)), + ((8,), (1,)), + ((9,), (1,)), + ((10,), (1,)), + } + + actual_states = enumerate_reachable_states(10) + + self.assertEqual(len(actual_states), 54) + self.assertEqual(actual_states, expected_states) + + def test_subset_property(self) -> None: + """ + Test that for sizes 2..10, each smaller tensor results in a strict + subset of possible states compared to the next one. + """ + prev_states: Optional[set[tuple[tuple[int, ...], tuple[int, ...]]]] = None + for size in range(2, 11): + current_states = enumerate_reachable_states(size) + + if prev_states is not None: + # Check that prev_states is a strict subset of current_states + self.assertTrue( + prev_states.issubset(current_states), + f"States from size {size - 1} are not a subset of size {size}", + ) + # Check that it's a strict subset (not equal) + self.assertTrue( + len(prev_states) < len(current_states), + f"States from size {size - 1} should be strictly fewer than size {size}", + ) + + prev_states = current_states + + +if __name__ == "__main__": + run_tests() From 89165c0a2b5d3c147c19a492437291c8ff18aa7f Mon Sep 17 00:00:00 2001 From: Andrey Talman Date: Wed, 5 Nov 2025 18:26:31 +0000 Subject: [PATCH 073/130] Update triton to 3.5.1 release (#166968) This includes sm103 https://github.com/triton-lang/triton/pull/8485 fix Pull Request resolved: https://github.com/pytorch/pytorch/pull/166968 Approved by: https://github.com/Lucaskabela, https://github.com/njriasan --- .ci/docker/ci_commit_pins/triton.txt | 2 +- .ci/docker/triton_version.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index 10f1207e60e6c..7aab8bed1c108 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -7416ffcb92cdbe98d9f97e4e6f95247e46dfc9fd +bfeb066872bc1e8b2d2bc0a3b295b99dd77206e7 diff --git a/.ci/docker/triton_version.txt b/.ci/docker/triton_version.txt index 1545d966571dc..d5c0c99142898 100644 --- a/.ci/docker/triton_version.txt +++ b/.ci/docker/triton_version.txt @@ -1 +1 @@ -3.5.0 +3.5.1 From 641de23c96e2c0d2848a7aa2aacb2f77843177a5 Mon Sep 17 00:00:00 2001 From: Eli Uriegas Date: Wed, 5 Nov 2025 17:05:14 +0000 Subject: [PATCH 074/130] ci: Add aarch64 docker builds for modern clang (#166416) Should enable us to build using some arm optimizations that are only available on the newest versions of clang. Signed-off-by: Eli Uriegas Pull Request resolved: https://github.com/pytorch/pytorch/pull/166416 Approved by: https://github.com/malfet --- .ci/docker/build.sh | 10 ++++++++++ .ci/docker/common/install_clang.sh | 4 ++-- .ci/docker/common/install_openblas.sh | 1 + .github/workflows/docker-builds.yml | 2 ++ 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index d0500b89780ce..5257decb9d4d5 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -271,6 +271,16 @@ case "$tag" in # from pytorch/llvm:9.0.1 is x86 specific SKIP_LLVM_SRC_BUILD_INSTALL=yes ;; + pytorch-linux-jammy-aarch64-py3.10-clang21) + ANACONDA_PYTHON_VERSION=3.10 + CLANG_VERSION=21 + ACL=yes + VISION=yes + OPENBLAS=yes + # snadampal: skipping llvm src build install because the current version + # from pytorch/llvm:9.0.1 is x86 specific + SKIP_LLVM_SRC_BUILD_INSTALL=yes + ;; pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks) ANACONDA_PYTHON_VERSION=3.10 GCC_VERSION=11 diff --git a/.ci/docker/common/install_clang.sh b/.ci/docker/common/install_clang.sh index 1cb216edf1b38..93daeee919b3d 100755 --- a/.ci/docker/common/install_clang.sh +++ b/.ci/docker/common/install_clang.sh @@ -8,8 +8,8 @@ if [ -n "$CLANG_VERSION" ]; then # work around ubuntu apt-get conflicts sudo apt-get -y -f install wget --no-check-certificate -O - https://apt.llvm.org/llvm-snapshot.gpg.key | sudo apt-key add - - if [[ $CLANG_VERSION == 18 ]]; then - apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-18 main" + if [[ $CLANG_VERSION -ge 18 ]]; then + apt-add-repository "deb http://apt.llvm.org/jammy/ llvm-toolchain-jammy-${CLANG_VERSION} main" fi fi diff --git a/.ci/docker/common/install_openblas.sh b/.ci/docker/common/install_openblas.sh index 2f386c6bd523a..5a28068781245 100755 --- a/.ci/docker/common/install_openblas.sh +++ b/.ci/docker/common/install_openblas.sh @@ -10,6 +10,7 @@ git clone https://github.com/OpenMathLib/OpenBLAS.git -b "${OPENBLAS_VERSION}" - OPENBLAS_CHECKOUT_DIR="OpenBLAS" OPENBLAS_BUILD_FLAGS=" +CC=gcc NUM_THREADS=128 USE_OPENMP=1 NO_SHARED=0 diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 6fbe2e846d40b..4d0940094f541 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -79,6 +79,8 @@ jobs: include: - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11 runner: linux.arm64.m7g.4xlarge + - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-clang21 + runner: linux.arm64.m7g.4xlarge - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks runner: linux.arm64.m7g.4xlarge timeout-minutes: 600 From 14b153bcf28efa7056f8b0ecf2e8c7def97aa2ea Mon Sep 17 00:00:00 2001 From: Brian Hirsh Date: Wed, 5 Nov 2025 08:00:51 -0800 Subject: [PATCH 075/130] include DTensor metadata when pretty-printing fx.Graphs (#166750) Example below. You need to trace your function with DTensor inputs in order for the graph proxies to run on DTensor (and not the inner local tensor). You also need to run with `tracing_mode="fake"`, or with your own `FakeTensorMode`, to see the nice DTensor printing. If this doesn't feel very ergonomic then maybe we can find some better UX for printing a graph with DTensor in it: image ``` import torch from torch.testing._internal.distributed.fake_pg import FakeStore from torch.distributed.tensor import distribute_tensor, Shard, Replicate from torch.utils._debug_mode import DebugMode from torch.fx.experimental.proxy_tensor import make_fx from torch.utils._python_dispatch import TorchDispatchMode from torch.utils import _pytree as pytree world_size = 8 device_type = "cpu" fake_store = FakeStore() torch.distributed.init_process_group("fake", store=fake_store, rank=0, world_size=world_size) device_mesh = torch.distributed.init_device_mesh(device_type, (world_size,)) dim = 128 A = torch.randn(8, dim) B = torch.randn(dim, dim) dA = distribute_tensor(A, device_mesh, [Shard(0)]).requires_grad_() dB = distribute_tensor(B, device_mesh, [Replicate()]).requires_grad_() def f(dA, dB): dy = dA @ dB loss = dy.sum() loss.backward() return dA.grad, dB.grad # We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode. # make_fx has some logic to ensure we don't accidentally stash real tensors in the graph # so we won't stash our DTensors properly if they don't hold Fake inner tensors gm = make_fx(f, tracing_mode='fake')(dA, dB) # DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph gm.graph.eliminate_dead_code() gm.recompile() gm.print_readable(colored=True) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166750 Approved by: https://github.com/ezyang, https://github.com/wconstab, https://github.com/Skylion007 --- .../tensor/debug/test_debug_mode.py | 35 ++++++++++++++++++- torch/fx/graph.py | 19 ++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py index 9acfcb15804e5..abc37f17a74de 100644 --- a/test/distributed/tensor/debug/test_debug_mode.py +++ b/test/distributed/tensor/debug/test_debug_mode.py @@ -5,8 +5,16 @@ import torch import torch.distributed as dist from torch._subclasses.fake_tensor import FakeTensorMode -from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard +from torch.distributed.tensor import ( + DeviceMesh, + distribute_tensor, + DTensor, + Partial, + Replicate, + Shard, +) from torch.distributed.tensor._dtensor_spec import ShardOrderEntry +from torch.fx.experimental.proxy_tensor import make_fx from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, @@ -426,6 +434,31 @@ def forward(self, x): ][-1] self.assertTrue("self.l2(self.l1(x))" in sum_op.fwd_stack_trace) + def test_pretty_print_dtensor_make_fx(self): + mesh = DeviceMesh(self.device_type, list(range(self.world_size))) + + A = torch.randn(8, 32) + B = torch.randn(32, 32) + dA = distribute_tensor(A, mesh, [Shard(0)]).requires_grad_() + dB = distribute_tensor(B, mesh, [Replicate()]).requires_grad_() + + def f(dA, dB): + dy = dA @ dB + loss = dy.sum() + loss.backward() + return dA.grad, dB.grad + + # We actually need the tracing_mode='fake' here, or to trace under a FakeTensorMode. + # make_fx has some logic to ensure we don't accidentally stash real tensors in the graph + # so we won't stash our DTensors properly if they don't hold Fake inner tensors + gm = make_fx(f, tracing_mode="fake")(dA, dB) + # DCE isn't necessary here, there were just a lot of dead detach() nodes that spammed the graph + gm.graph.eliminate_dead_code() + gm.recompile() + # Colored is nice for actual viewing, not using in this test though + gm_str = gm.print_readable(colored=False, print_output=False) + self.assertTrue('"DTensor(f32[8, 32], S(0))" = torch.ops.aten.mm' in gm_str) + instantiate_parametrized_tests(TestDTensorDebugMode) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 697b2f4084ca5..899a50f0f4142 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -647,6 +647,15 @@ def emit_node(node: Node): if verbose: # override annotation with more detailed information + try: + from torch.distributed.tensor._api import DTensor, DTensorSpec + + dtensorspec_format_shard_order_str = ( + DTensorSpec.format_shard_order_str + ) + except ModuleNotFoundError: + DTensor = None # type: ignore[assignment,misc] + dtensorspec_format_shard_order_str = None from torch.fx.experimental.proxy_tensor import py_sym_types from torch.fx.passes.shape_prop import TensorMetadata @@ -677,6 +686,16 @@ def _tensor_annotation(t: torch.Tensor) -> str: core = _tensor_annotation(meta_val) if is_plain: maybe_type_annotation = f': "{core}"' + elif type(meta_val) is DTensor: + assert dtensorspec_format_shard_order_str is not None + dtensor_meta = dtensorspec_format_shard_order_str( + meta_val._spec.placements, # type: ignore[attr-defined] + meta_val._spec.shard_order, # type: ignore[attr-defined] + ) + cls = meta_val.__class__.__name__ + maybe_type_annotation = ( + f': "{cls}({core}, {dim_green(dtensor_meta)})"' + ) else: cls = meta_val.__class__.__name__ maybe_type_annotation = f': "{cls}({core})"' From 6052a01b71277eb767d87daf47d109f8e0edd5c0 Mon Sep 17 00:00:00 2001 From: Lucas Kabela Date: Wed, 5 Nov 2025 19:18:35 +0000 Subject: [PATCH 076/130] [BE][Typing][Dynamo] Type torch/_dynamo/variables/dicts.py (#167022) Provides type coverage to torch/_dynamo/variables/dicts.py Coverage report: `mypy torch/_dynamo/variables/dicts.py --linecount-report /tmp/coverage_log` Compare before to after - we go from 0 lines and 0 funcs covered to 1547 lines and 89 funcs covered Pull Request resolved: https://github.com/pytorch/pytorch/pull/167022 Approved by: https://github.com/Skylion007 --- torch/_dynamo/symbolic_convert.py | 6 +- torch/_dynamo/variables/dicts.py | 358 +++++++++++++++++------------- 2 files changed, 208 insertions(+), 156 deletions(-) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 53ec0ee412849..3943f90b0020a 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -3320,7 +3320,7 @@ def SET_ADD(self, inst: Instruction) -> None: obj = self.stack[-inst.arg] assert isinstance(obj, SetVariable) assert obj.is_mutable() - obj.call_method(self, "add", [v], {}) + obj.call_method(self, "add", [v], {}) # type: ignore[arg-type] def SET_UPDATE(self, inst: Instruction) -> None: v = self.pop() @@ -3329,7 +3329,7 @@ def SET_UPDATE(self, inst: Instruction) -> None: obj = self.stack[-inst.arg] assert isinstance(obj, SetVariable) assert obj.is_mutable() - obj.call_method(self, "update", [v], {}) + obj.call_method(self, "update", [v], {}) # type: ignore[arg-type] def LIST_APPEND(self, inst: Instruction) -> None: v = self.pop() @@ -3637,7 +3637,7 @@ def DICT_MERGE(self, inst: Instruction) -> None: obj = self.stack[-inst.arg].realize() assert isinstance(obj, ConstDictVariable) assert obj.is_mutable() - obj.call_method(self, "update", [v], {}) + obj.call_method(self, "update", [v], {}) # type: ignore[arg-type] DICT_UPDATE = DICT_MERGE diff --git a/torch/_dynamo/variables/dicts.py b/torch/_dynamo/variables/dicts.py index f70ba99c0c93d..fb212c3326222 100644 --- a/torch/_dynamo/variables/dicts.py +++ b/torch/_dynamo/variables/dicts.py @@ -1,5 +1,3 @@ -# mypy: ignore-errors - """ Dictionary-related variable tracking classes for PyTorch Dynamo. @@ -26,7 +24,7 @@ import operator import types from collections.abc import Hashable as py_Hashable -from typing import Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING, Union from torch._subclasses.fake_tensor import is_fake @@ -59,11 +57,13 @@ # - (perhaps) Define how it is compared in _HashableTracker._eq_impl -def was_instancecheck_override(obj): +def was_instancecheck_override(obj: Any) -> bool: return type(obj).__dict__.get("__instancecheck__", False) -def raise_unhashable(arg, tx=None): +def raise_unhashable( + arg: VariableTracker, tx: Optional["InstructionTranslator"] = None +) -> None: if tx is None: from torch._dynamo.symbolic_convert import InstructionTranslator @@ -75,7 +75,7 @@ def raise_unhashable(arg, tx=None): ) -def is_hashable(x): +def is_hashable(x: VariableTracker) -> bool: # NB - performing isinstance check on a LazVT realizes the VT, accidentally # inserting the guard. To avoid this, lazyVT `is_hashable` methods looks at # the underlying value without realizing the VT. Consider updating the @@ -143,7 +143,7 @@ class _HashableTracker: Note that it's also fine to put VTs into dictionaries and sets, but doing so does not take into account aliasing """ - def __init__(self, vt) -> None: + def __init__(self, vt: VariableTracker) -> None: # We specialize SymNodes vt = specialize_symnode(vt) # TODO Temporarily remove to figure out what keys are we breaking on @@ -153,7 +153,7 @@ def __init__(self, vt) -> None: self.vt = vt @property - def underlying_value(self): + def underlying_value(self) -> Any: if ( isinstance(self.vt, variables.LazyVariableTracker) and not self.vt.is_realized() @@ -178,7 +178,8 @@ def underlying_value(self): elif isinstance(self.vt, variables.FrozenDataClassVariable): Hashable = ConstDictVariable._HashableTracker fields_values = { - k: Hashable(v).underlying_value for k, v in self.vt.fields.items() + k: Hashable(v).underlying_value + for k, v in self.vt.fields.items() # type: ignore[attr-defined] } return variables.FrozenDataClassVariable.HashWrapper( self.vt.python_type(), fields_values @@ -187,16 +188,16 @@ def underlying_value(self): # The re module in Python 3.13+ has a dictionary (_cache2) with # an object as key (`class _ZeroSentinel(int): ...`): # python test/dynamo/test_unittest.py CPythonTestLongMessage.test_baseAssertEqual - return self.vt.value + return self.vt.value # type: ignore[attr-defined,union-attr] else: x = self.vt.as_python_constant() return x - def __hash__(self): + def __hash__(self) -> int: return hash(self.underlying_value) @staticmethod - def _eq_impl(a, b): + def _eq_impl(a: Any, b: Any) -> bool: # TODO: Put this in utils and share it between variables/builtin.py and here type_a, type_b = type(a), type(b) if not (issubclass(type_a, type_b) or issubclass(type_b, type_a)): @@ -212,7 +213,7 @@ def _eq_impl(a, b): else: return a == b - def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool: + def __eq__(self, other: object) -> bool: Hashable = ConstDictVariable._HashableTracker assert isinstance(other, Hashable) or ConstantVariable.is_literal(other), ( type(other) @@ -226,8 +227,8 @@ def __eq__(self, other: "ConstDictVariable._HashableTracker") -> bool: def __init__( self, items: dict[VariableTracker, VariableTracker], - user_cls=dict, - **kwargs, + user_cls: type = dict, + **kwargs: Any, ) -> None: # .clone() pass these arguments in kwargs but they're recreated a few # lines below @@ -247,18 +248,22 @@ def __init__( for x, v in items.items() ) - def make_hashable(key): + def make_hashable( + key: Union[VariableTracker, "ConstDictVariable._HashableTracker"], + ) -> "ConstDictVariable._HashableTracker": return key if isinstance(key, Hashable) else Hashable(key) dict_cls = self._get_dict_cls_from_user_cls(user_cls) self.items = dict_cls({make_hashable(x): v for x, v in items.items()}) # need to reconstruct everything if the dictionary is an intermediate value # or if a pop/delitem was executed - self.should_reconstruct_all = not is_from_local_source(self.source) + self.should_reconstruct_all = ( + not is_from_local_source(self.source) if self.source else True + ) self.original_items = items.copy() self.user_cls = user_cls - def _get_dict_cls_from_user_cls(self, user_cls): + def _get_dict_cls_from_user_cls(self, user_cls: type) -> type: accepted_dict_types = (dict, collections.OrderedDict, collections.defaultdict) # avoid executing user code if user_cls is a dict subclass @@ -277,10 +282,10 @@ def _get_dict_cls_from_user_cls(self, user_cls): dict_cls = dict return dict_cls - def as_proxy(self): + def as_proxy(self) -> dict[Any, Any]: return {k.vt.as_proxy(): v.as_proxy() for k, v in self.items.items()} - def debug_repr(self): + def debug_repr(self) -> str: return ( "{" + ", ".join( @@ -289,20 +294,20 @@ def debug_repr(self): + "}" ) - def as_python_constant(self): + def as_python_constant(self) -> dict[Any, Any]: return { k.vt.as_python_constant(): v.as_python_constant() for k, v in self.items.items() } - def keys_as_python_constant(self): + def keys_as_python_constant(self) -> dict[Any, VariableTracker]: self.install_dict_keys_match_guard() return {k.vt.as_python_constant(): v for k, v in self.items.items()} - def python_type(self): + def python_type(self) -> type: return self.user_cls - def __contains__(self, vt) -> bool: + def __contains__(self, vt: VariableTracker) -> bool: assert isinstance(vt, VariableTracker) Hashable = ConstDictVariable._HashableTracker return ( @@ -322,13 +327,15 @@ def has_new_items(self) -> bool: for key, value in self.items.items() ) - def is_new_item(self, value, other): + def is_new_item( + self, value: Optional[VariableTracker], other: VariableTracker + ) -> bool: # compare the id of the realized values if both values are not lazy VTs if value and value.is_realized() and other.is_realized(): return id(value.realize()) != id(other.realize()) return id(value) != id(other) - def reconstruct_kvs_into_new_dict(self, codegen): + def reconstruct_kvs_into_new_dict(self, codegen: "PyCodegen") -> None: # Build a dictionary that contains the keys and values. num_args = 0 for key, value in self.items.items(): @@ -340,7 +347,7 @@ def reconstruct_kvs_into_new_dict(self, codegen): num_args += 1 codegen.append_output(create_instruction("BUILD_MAP", arg=num_args)) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: if self.user_cls is collections.OrderedDict: # emit `OrderedDict(constructed_dict)` codegen.add_push_null( @@ -358,19 +365,21 @@ def reconstruct(self, codegen: "PyCodegen"): def getitem_const_raise_exception_if_absent( self, tx: "InstructionTranslator", arg: VariableTracker - ): + ) -> VariableTracker: key = ConstDictVariable._HashableTracker(arg) if key not in self.items: raise_observed_exception(KeyError, tx) return self.items[key] - def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: key = ConstDictVariable._HashableTracker(arg) if key not in self.items: - msg = f"Dictionary key {arg.value} not found during tracing" + msg = f"Dictionary key {arg.value} not found during tracing" # type: ignore[attr-defined] unimplemented_v2( gb_type="key not found in dict", - context=f"Key {arg.value}", + context=f"Key {arg.value}", # type: ignore[attr-defined] explanation=msg, hints=[ "Check if the key exists in the dictionary before accessing it.", @@ -379,13 +388,13 @@ def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): ) return self.items[key] - def maybe_getitem_const(self, arg: VariableTracker): + def maybe_getitem_const(self, arg: VariableTracker) -> Optional[VariableTracker]: key = ConstDictVariable._HashableTracker(arg) if key not in self.items: return None return self.items[key] - def realize_key_vt(self, arg: VariableTracker): + def realize_key_vt(self, arg: VariableTracker) -> None: # Realize the LazyVT on a particular index assert arg in self key = ConstDictVariable._HashableTracker(arg) @@ -394,11 +403,13 @@ def realize_key_vt(self, arg: VariableTracker): if isinstance(original_key_vt, variables.LazyVariableTracker): original_key_vt.realize() - def install_dict_keys_match_guard(self): + def install_dict_keys_match_guard(self) -> None: if self.source: install_guard(self.make_guard(GuardBuilder.DICT_KEYS_MATCH)) - def install_dict_contains_guard(self, tx, args): + def install_dict_contains_guard( + self, tx: "InstructionTranslator", args: list[VariableTracker] + ) -> None: # Key guarding - These are the cases to consider # 1) The dict has been mutated. In this case, we would have already # inserted a DICT_KEYS_MATCH guard, so we can skip. @@ -439,11 +450,11 @@ def install_dict_contains_guard(self, tx, args): def call_method( self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: # NB - Both key and value are LazyVariableTrackers in the beginning. So, # we have to insert guards when a dict method is accessed. For this to # be simple, we are conservative and overguard. We skip guard only for @@ -462,7 +473,7 @@ def call_method( tx, *args, **kwargs ) tx.output.side_effects.mutation(self) - self.items.update(temp_dict_vt.items) + self.items.update(temp_dict_vt.items) # type: ignore[attr-defined] return ConstantVariable.create(None) elif name == "__getitem__": # Key guarding - Nothing to do. LazyVT for value will take care. @@ -526,7 +537,7 @@ def call_method( return ConstantVariable.create(len(self.items)) elif name == "__setitem__" and self.is_mutable(): if not arg_hashable: - raise_unhashable(args[0]) + raise_unhashable(args[0], tx) self.install_dict_keys_match_guard() if kwargs or len(args) != 2: @@ -550,7 +561,7 @@ def call_method( raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args") if not arg_hashable: - raise_unhashable(args[0]) + raise_unhashable(args[0], tx) if args[0] not in self: self.install_dict_contains_guard(tx, args) @@ -565,7 +576,7 @@ def call_method( raise_args_mismatch(tx, name, "1 or 2 args", f"{len(args)} args") if not arg_hashable: - raise_unhashable(args[0]) + raise_unhashable(args[0], tx) if args[0] not in self: # missing item, return the default value. Install no DICT_CONTAINS guard. @@ -599,7 +610,7 @@ def call_method( last = v.value else: raise_args_mismatch(tx, name) - k, v = self.items.popitem(last=last) + k, v = self.items.popitem(last=last) # type: ignore[possibly-undefined] else: k, v = self.items.popitem() @@ -632,17 +643,17 @@ def call_method( # NB - Guard on all the keys of the other dict to ensure # correctness. args[0].install_dict_keys_match_guard() - dict_vt = args[0] + dict_vt: ConstDictVariable = args[0] else: - dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) - self.items.update(dict_vt.items) + dict_vt = BuiltinVariable.call_custom_dict(tx, dict, args[0]) # type: ignore[assignment] + self.items.update(dict_vt.items) # type: ignore[attr-defined] if has_kwargs: # Handle kwargs - kwargs = { + kwargs_hashable = { Hashable(ConstantVariable.create(k)): v for k, v in kwargs.items() } - self.items.update(kwargs) + self.items.update(kwargs_hashable) return ConstantVariable.create(None) else: return super().call_method(tx, name, args, kwargs) @@ -656,7 +667,7 @@ def call_method( ) if not arg_hashable: - raise_unhashable(args[0]) + raise_unhashable(args[0], tx) self.install_dict_contains_guard(tx, args) contains = args[0] in self @@ -671,7 +682,7 @@ def call_method( ) if not arg_hashable: - raise_unhashable(args[0]) + raise_unhashable(args[0], tx) self.install_dict_keys_match_guard() if kwargs or len(args) > 2: @@ -707,7 +718,7 @@ def call_method( and "last" in kwargs and isinstance(kwargs["last"], ConstantVariable) ): - last = kwargs.get("last").value + last = kwargs.get("last").value # type: ignore[union-attr] key = Hashable(args[0]) self.items.move_to_end(key, last=last) @@ -723,7 +734,7 @@ def call_method( ) elif name == "__ne__": return ConstantVariable.create( - not self.call_method(tx, "__eq__", args, kwargs).value + not self.call_method(tx, "__eq__", args, kwargs).value # type: ignore[attr-defined] ) elif name == "__or__": if len(args) != 1: @@ -750,14 +761,14 @@ def call_method( if not istype( other, (ConstDictVariable, variables.UserDefinedDictVariable) ): - msg = ( + err_msg = ( f"unsupported operand type(s) for |: '{self.python_type().__name__}'" f"and '{other.python_type().__name__}'" ) - raise_observed_exception(TypeError, tx, args=[msg]) + raise_observed_exception(TypeError, tx, args=[err_msg]) # OrderedDict overloads __ror__ - ts = {self.user_cls, other.user_cls} + ts = {self.user_cls, other.user_cls} # type: ignore[attr-defined] user_cls = ( collections.OrderedDict if any(issubclass(t, collections.OrderedDict) for t in ts) @@ -774,8 +785,8 @@ def call_method( # NB - Guard on all the keys of the other dict to ensure # correctness. - args[0].install_dict_keys_match_guard() - new_dict_vt.items.update(args[0].items) + args[0].install_dict_keys_match_guard() # type: ignore[attr-defined] + new_dict_vt.items.update(args[0].items) # type: ignore[attr-defined] return new_dict_vt elif name == "__ior__": self.call_method(tx, "update", args, kwargs) @@ -789,11 +800,13 @@ def call_method( else: return super().call_method(tx, name, args, kwargs) - def unpack_var_sequence(self, tx): + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: self.install_dict_keys_match_guard() return [x.vt for x in self.items.keys()] - def call_obj_hasattr(self, tx, name): + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> VariableTracker: # dict not allow setting arbitrary attributes. OrderedDict and # defaultdict allow arbitrary setattr, but not deletion of default attrs if any( @@ -816,25 +829,25 @@ def call_obj_hasattr(self, tx, name): ], ) - def clone(self, **kwargs): + def clone(self, **kwargs: Any) -> VariableTracker: self.install_dict_keys_match_guard() return super().clone(**kwargs) class MappingProxyVariable(VariableTracker): # proxies to the original dict_vt - def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None: + def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None: super().__init__(**kwargs) assert isinstance(dv_dict, ConstDictVariable) self.dv_dict = dv_dict - def python_type(self): + def python_type(self) -> type: return types.MappingProxyType - def unpack_var_sequence(self, tx): + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: return self.dv_dict.unpack_var_sequence(tx) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: # load types.MappingProxyType if self.source: msg = ( @@ -863,11 +876,11 @@ def reconstruct(self, codegen: "PyCodegen"): def call_method( self, - tx, - name, - args: list["VariableTracker"], - kwargs: dict[str, "VariableTracker"], - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if self.source and tx.output.side_effects.has_existing_dict_mutation(): msg = ( "A dict has been modified while we have an existing mappingproxy object. " @@ -892,7 +905,7 @@ def call_method( def call_obj_hasattr( self, tx: "InstructionTranslator", name: str - ) -> "VariableTracker": + ) -> VariableTracker: if self.python_type() is types.MappingProxyType: return ConstantVariable.create(name in types.MappingProxyType.__dict__) return super().call_obj_hasattr(tx, name) @@ -900,35 +913,44 @@ def call_obj_hasattr( class NNModuleHooksDictVariable(ConstDictVariable): # Special class to avoid adding any guards on the nn module hook ids. - def install_dict_keys_match_guard(self): + def install_dict_keys_match_guard(self) -> None: pass - def install_dict_contains_guard(self, tx, args): + def install_dict_contains_guard( + self, tx: "InstructionTranslator", args: list[VariableTracker] + ) -> None: pass class DefaultDictVariable(ConstDictVariable): - def __init__(self, items, user_cls, default_factory=None, **kwargs) -> None: + def __init__( + self, + items: dict[VariableTracker, VariableTracker], + user_cls: type, + default_factory: Optional[VariableTracker] = None, + **kwargs: Any, + ) -> None: super().__init__(items, user_cls, **kwargs) assert user_cls is collections.defaultdict if default_factory is None: default_factory = ConstantVariable.create(None) self.default_factory = default_factory - def is_python_constant(self): + def is_python_constant(self) -> bool: # Return false for unsupported defaults. This ensures that a bad handler # path is not taken in BuiltinVariable for getitem. if self.default_factory not in [list, tuple, dict] and not self.items: return False return super().is_python_constant() - def debug_repr(self): + def debug_repr(self) -> str: + assert self.default_factory is not None return ( f"defaultdict({self.default_factory.debug_repr()}, {super().debug_repr()})" ) @staticmethod - def is_supported_arg(arg): + def is_supported_arg(arg: VariableTracker) -> bool: if isinstance(arg, variables.BuiltinVariable): return arg.fn in (list, tuple, dict, set) else: @@ -942,11 +964,11 @@ def is_supported_arg(arg): def call_method( self, - tx, - name, - args: "list[VariableTracker]", - kwargs: "dict[str, VariableTracker]", - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "__getitem__": if len(args) != 1: raise_args_mismatch(tx, name, "1 args", f"{len(args)} args") @@ -962,13 +984,13 @@ def call_method( else: default_var = self.default_factory.call_function(tx, [], {}) super().call_method( - tx, "__setitem__", (args[0], default_var), kwargs + tx, "__setitem__", [args[0], default_var], kwargs ) return default_var else: return super().call_method(tx, name, args, kwargs) - def reconstruct(self, codegen): + def reconstruct(self, codegen: "PyCodegen") -> None: # emit `defaultdict(default_factory, new_dict)` codegen.add_push_null( lambda: codegen.extend_output( @@ -994,40 +1016,48 @@ class SetVariable(ConstDictVariable): def __init__( self, items: list[VariableTracker], - **kwargs, + **kwargs: Any, ) -> None: + # pyrefly: ignore[bad-assignment] items = dict.fromkeys(items, SetVariable._default_value()) + # pyrefly: ignore[bad-argument-type] super().__init__(items, **kwargs) - def debug_repr(self): + def debug_repr(self) -> str: if not self.items: return "set()" else: return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}" @property - def set_items(self): + def set_items(self) -> set["ConstDictVariable._HashableTracker"]: return set(self.items.keys()) @staticmethod - def _default_value(): + def _default_value() -> VariableTracker: # Variable to fill in he keys of the dictionary return ConstantVariable.create(None) - def as_proxy(self): + def as_proxy(self) -> Any: return {k.vt.as_proxy() for k in self.set_items} - def python_type(self): + def python_type(self) -> type: return set - def as_python_constant(self): + def as_python_constant(self) -> Any: return {k.vt.as_python_constant() for k in self.set_items} - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.foreach([x.vt for x in self.set_items]) codegen.append_output(create_instruction("BUILD_SET", arg=len(self.set_items))) - def _fast_set_method(self, tx, fn, args, kwargs): + def _fast_set_method( + self, + tx: "InstructionTranslator", + fn: Any, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: try: res = fn( *[x.as_python_constant() for x in [self, *args]], @@ -1037,15 +1067,16 @@ def _fast_set_method(self, tx, fn, args, kwargs): raise_observed_exception( type(exc), tx, args=list(map(ConstantVariable.create, exc.args)) ) + # pyrefly: ignore[unbound-name] return VariableTracker.build(tx, res) def call_method( self, - tx, - name, + tx: "InstructionTranslator", + name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], - ) -> "VariableTracker": + ) -> VariableTracker: # We forward the calls to the dictionary model from ..utils import check_constant_args @@ -1065,10 +1096,10 @@ def call_method( return self._fast_set_method(tx, getattr(py_type, name), args, kwargs) if name == "__init__": - temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, *kwargs) + temp_set_vt = variables.BuiltinVariable(set).call_set(tx, *args, **kwargs) tx.output.side_effects.mutation(self) self.items.clear() - self.items.update(temp_set_vt.items) + self.items.update(temp_set_vt.items) # type: ignore[attr-defined] return ConstantVariable.create(None) elif name == "add": if kwargs or len(args) != 1: @@ -1079,7 +1110,7 @@ def call_method( f"{len(args)} args and {len(kwargs)} kwargs", ) name = "__setitem__" - args = (args[0], SetVariable._default_value()) + args = [args[0], SetVariable._default_value()] elif name == "pop": if kwargs or args: raise_args_mismatch( @@ -1090,12 +1121,14 @@ def call_method( ) # Choose an item at random and pop it via the Dict.pop method try: - result = self.set_items.pop().vt + result: VariableTracker = self.set_items.pop().vt # type: ignore[assignment] except KeyError as e: raise_observed_exception( KeyError, tx, args=list(map(ConstantVariable.create, e.args)) ) - super().call_method(tx, name, (result,), kwargs) + # pyrefly: ignore[unbound-name] + super().call_method(tx, name, [result], kwargs) + # pyrefly: ignore[unbound-name] return result elif name == "isdisjoint": if kwargs or len(args) != 1: @@ -1217,6 +1250,7 @@ def call_method( f"unsupported operand type(s) for {name}: '{self.python_type_name()}' and '{args[0].python_type_name()}'" ) raise_observed_exception(TypeError, tx, args=[msg]) + assert m is not None return self.call_method(tx, m, args, kwargs) elif name in ("__iand__", "__ior__", "__ixor__", "__isub__"): if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): @@ -1230,29 +1264,34 @@ def call_method( "__ixor__": "symmetric_difference_update", "__isub__": "difference_update", }.get(name) + assert m is not None self.call_method(tx, m, args, kwargs) return self elif name == "__eq__": if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): return ConstantVariable.create(False) r = self.call_method(tx, "symmetric_difference", args, kwargs) - return ConstantVariable.create(len(r.set_items) == 0) + return ConstantVariable.create(len(r.set_items) == 0) # type: ignore[attr-defined] elif name in cmp_name_to_op_mapping: if not isinstance(args[0], (SetVariable, variables.UserDefinedSetVariable)): return ConstantVariable.create(NotImplemented) return ConstantVariable.create( - cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) + cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined] ) return super().call_method(tx, name, args, kwargs) - def getitem_const(self, tx: "InstructionTranslator", arg: VariableTracker): + def getitem_const( + self, tx: "InstructionTranslator", arg: VariableTracker + ) -> VariableTracker: raise RuntimeError("Illegal to getitem on a set") - def install_dict_keys_match_guard(self): + def install_dict_keys_match_guard(self) -> None: # Already EQUALS_MATCH guarded pass - def install_dict_contains_guard(self, tx, args): + def install_dict_contains_guard( + self, tx: "InstructionTranslator", args: list[VariableTracker] + ) -> None: super().install_dict_contains_guard(tx, args) @@ -1260,27 +1299,27 @@ class FrozensetVariable(SetVariable): def __init__( self, items: list[VariableTracker], - **kwargs, + **kwargs: Any, ) -> None: super().__init__(items, **kwargs) - def debug_repr(self): + def debug_repr(self) -> str: if not self.items: return "frozenset()" else: return "{" + ",".join(k.vt.debug_repr() for k in self.items.keys()) + "}" @property - def set_items(self): + def set_items(self) -> set["ConstDictVariable._HashableTracker"]: return self.items.keys() - def python_type(self): + def python_type(self) -> type: return frozenset - def as_python_constant(self): + def as_python_constant(self) -> Any: return frozenset({k.vt.as_python_constant() for k in self.set_items}) - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: codegen.foreach([x.vt for x in self.set_items]) codegen.add_push_null( lambda: codegen.extend_output( @@ -1293,11 +1332,11 @@ def reconstruct(self, codegen: "PyCodegen"): def call_method( self, - tx, - name, + tx: "InstructionTranslator", + name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], - ) -> "VariableTracker": + ) -> VariableTracker: if name in ["add", "pop", "update", "remove", "discard", "clear"]: raise RuntimeError(f"Illegal call_method {name} on a frozenset") elif name == "__init__": @@ -1316,7 +1355,7 @@ def call_method( "symmetric_difference", ): r = super().call_method(tx, name, args, kwargs) - return FrozensetVariable(r.items) + return FrozensetVariable(r.items) # type: ignore[attr-defined] return super().call_method(tx, name, args, kwargs) @@ -1324,11 +1363,11 @@ class DictKeySetVariable(SetVariable): def __init__( self, items: list[VariableTracker], - **kwargs, + **kwargs: Any, ) -> None: super().__init__(items, **kwargs) - def debug_repr(self): + def debug_repr(self) -> str: if not self.items: return "dict_keys([])" else: @@ -1338,33 +1377,35 @@ def debug_repr(self): + "])" ) - def install_dict_keys_match_guard(self): + def install_dict_keys_match_guard(self) -> None: # Already EQUALS_MATCH guarded pass - def install_dict_contains_guard(self, tx, args): + def install_dict_contains_guard( + self, tx: "InstructionTranslator", args: list[VariableTracker] + ) -> None: # Already EQUALS_MATCH guarded pass @property - def set_items(self): + def set_items(self) -> Any: return self.items - def python_type(self): + def python_type(self) -> type: return dict_keys - def as_python_constant(self): + def as_python_constant(self) -> Any: return dict.fromkeys( {k.vt.as_python_constant() for k in self.set_items}, None ).keys() def call_method( self, - tx, - name, + tx: "InstructionTranslator", + name: str, args: list[VariableTracker], kwargs: dict[str, VariableTracker], - ) -> "VariableTracker": + ) -> VariableTracker: if name in ["add", "pop", "update", "remove", "discard", "clear"]: raise RuntimeError(f"Illegal call_method {name} on a dict_keys") return super().call_method(tx, name, args, kwargs) @@ -1379,42 +1420,47 @@ class DictViewVariable(VariableTracker): kv: Optional[str] = None - def __init__(self, dv_dict: ConstDictVariable, **kwargs) -> None: + def __init__(self, dv_dict: ConstDictVariable, **kwargs: Any) -> None: super().__init__(**kwargs) assert self.kv in ("keys", "values", "items") assert isinstance(dv_dict, ConstDictVariable) self.dv_dict = dv_dict @property - def view_items(self): + def view_items(self) -> Any: + assert self.kv is not None return getattr(self.dv_dict.items, self.kv)() @property - def view_items_vt(self): + def view_items_vt(self) -> list[VariableTracker]: # Returns an iterable of the unpacked items # Implement in the subclasses raise NotImplementedError - def unpack_var_sequence(self, tx): + def unpack_var_sequence(self, tx: "InstructionTranslator") -> list[VariableTracker]: return self.view_items_vt - def reconstruct(self, codegen: "PyCodegen"): + def reconstruct(self, codegen: "PyCodegen") -> None: + assert self.kv is not None codegen(self.dv_dict) codegen.load_method(self.kv) codegen.call_method(0) - def call_obj_hasattr(self, tx, name): + def call_obj_hasattr( + self, tx: "InstructionTranslator", name: str + ) -> VariableTracker: + assert self.kv is not None if name in self.python_type().__dict__: return ConstantVariable.create(True) return ConstantVariable.create(False) def call_method( self, - tx, - name, - args: list["VariableTracker"], - kwargs: dict[str, "VariableTracker"], - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "__len__": return self.dv_dict.call_method(tx, name, args, kwargs) elif name == "__iter__": @@ -1428,24 +1474,24 @@ class DictKeysVariable(DictViewVariable): kv = "keys" @property - def set_items(self): + def set_items(self) -> set[VariableTracker]: return set(self.view_items) @property - def view_items_vt(self): + def view_items_vt(self) -> list[VariableTracker]: # Returns an iterable of the unpacked items return [x.vt for x in self.view_items] - def python_type(self): + def python_type(self) -> type: return dict_keys def call_method( self, - tx, - name, - args: list["VariableTracker"], - kwargs: dict[str, "VariableTracker"], - ) -> "VariableTracker": + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: if name == "__contains__": return self.dv_dict.call_method(tx, name, args, kwargs) elif name in ( @@ -1460,13 +1506,13 @@ def call_method( ): # These methods always returns a set m = getattr(self.set_items, name) - r = m(args[0].set_items) + r = m(args[0].set_items) # type: ignore[attr-defined] return SetVariable(r) if name in cmp_name_to_op_mapping: if not isinstance(args[0], (SetVariable, DictKeysVariable)): return ConstantVariable.create(NotImplemented) return ConstantVariable.create( - cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) + cmp_name_to_op_mapping[name](self.set_items, args[0].set_items) # type: ignore[attr-defined] ) return super().call_method(tx, name, args, kwargs) @@ -1476,10 +1522,10 @@ class DictValuesVariable(DictViewVariable): kv = "values" @property - def view_items_vt(self): + def view_items_vt(self) -> list[VariableTracker]: return list(self.view_items) - def python_type(self): + def python_type(self) -> type: return dict_values @@ -1487,14 +1533,20 @@ class DictItemsVariable(DictViewVariable): kv = "items" @property - def view_items_vt(self): + def view_items_vt(self) -> list[VariableTracker]: # Returns an iterable of the unpacked items return [variables.TupleVariable([k.vt, v]) for k, v in self.view_items] - def python_type(self): + def python_type(self) -> type: return dict_items - def call_method(self, tx, name, args, kwargs): + def call_method( + self, + tx: "InstructionTranslator", + name: str, + args: list[VariableTracker], + kwargs: dict[str, VariableTracker], + ) -> VariableTracker: # TODO(guilhermeleobas): This should actually check if args[0] # implements the mapping protocol. if name == "__eq__": From 6c5db82584bf71f5b1db3b598bbd00f44140c28d Mon Sep 17 00:00:00 2001 From: Jack Taylor <108682042+jataylo@users.noreply.github.com> Date: Wed, 5 Nov 2025 19:27:23 +0000 Subject: [PATCH 077/130] [Inductor] Naive foreach autotune support (#162053) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Initial autotuning support for foreach kernels, 4x improvement for some kernels in internal workload. More improvements can surely be made here in the future. Removing num_warps for definition to enable autotune support in generated wrapper code. Before: triton_for_fused_18.kd 🔍 | 4.986 ms | 4.986 ms | 2.493 ms | 2 | triton_for_fused_6.kd 🔍 | 0.098 ms | 0.098 ms | 0.049 ms | 2 | triton_for_fused_7.kd 🔍 | 0.036 ms | 0.036 ms | 0.018 ms | 2 | After: triton_for_fused_18.kd 🔍 | 1.273 ms | 1.273 ms | 0.636 ms | 2 | triton_for_fused_6.kd 🔍 | 0.044 ms | 0.044 ms | 0.022 ms | 2 | triton_for_fused_7.kd 🔍 | 0.024 ms | 0.024 ms | 0.012 ms | 2 | num_warps=8 default due to https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/triton_combo_kernel.py#L374 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162053 Approved by: https://github.com/mlazos, https://github.com/naromero77amd, https://github.com/jeffdaily Co-authored-by: Nichols A. Romero --- torch/_inductor/codegen/triton_combo_kernel.py | 2 +- torch/_inductor/runtime/triton_heuristics.py | 15 +++++++++++++-- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index 3e58e95ef9e9c..1f531a5d99ef5 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -627,7 +627,7 @@ def jit_line( if heuristics == "foreach": heuristics_line = f""" @triton_heuristics.foreach( - num_warps={self.num_warps}, + filename=__file__, triton_meta={triton_meta!r}, inductor_meta={inductor_meta!r}, ) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index cb43d55bc86b3..cdecd50927024 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -3586,13 +3586,24 @@ def user_autotune( ) -def foreach(triton_meta, num_warps, filename=None, inductor_meta=None): +def foreach(triton_meta, filename=None, inductor_meta=None): """ Compile a triton foreach kernel """ + configs = [] + + # Naive autotuning path for num_warps + if not ( + inductor_meta.get("max_autotune") or inductor_meta.get("max_autotune_pointwise") + ): + configs.append(triton.Config({}, num_stages=1, num_warps=8)) + else: + for warps in [1, 2, 4, 8]: + configs.append(triton.Config({}, num_stages=1, num_warps=warps)) + return cached_autotune( None, - [triton.Config({}, num_stages=1, num_warps=num_warps)], + configs, triton_meta=triton_meta, inductor_meta=inductor_meta, heuristic_type=HeuristicType.TEMPLATE, From fbd70fb84e347b45db79eb24cc2c53e447a04147 Mon Sep 17 00:00:00 2001 From: Maggie Moss Date: Wed, 5 Nov 2025 19:35:34 +0000 Subject: [PATCH 078/130] Update typing docs to reference pyrefly (#166883) Replacing mypy codumentation in the CONTRIBUTING.MD file with pyrefly references. I have made initial changes to https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch documentation, and will replace the script at the bottom with one tailored to the pyrefly tool as a follow-up. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166883 Approved by: https://github.com/malfet --- CONTRIBUTING.md | 31 ++++++++++++++++++++++++------- 1 file changed, 24 insertions(+), 7 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9df55ca6acd5c..bc0b0fc9bb00f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -18,7 +18,7 @@ aspects of contributing to PyTorch. - [Python Unit Testing](#python-unit-testing) - [Better local unit tests with `pytest`](#better-local-unit-tests-with-pytest) - [Local linting](#local-linting) - - [Running `mypy`](#running-mypy) + - [Running `pyrefly`](#running-pyrefly) - [C++ Unit Testing](#c-unit-testing) - [Run Specific CI Jobs](#run-specific-ci-jobs) - [Merging your Change](#merging-your-change) @@ -281,7 +281,7 @@ dependencies as well as the nightly binaries into the repo directory. **Prerequisites**: The following packages should be installed with `pip`: - `expecttest` and `hypothesis` - required to run tests -- `mypy` - recommended for linting +- `pyrefly` - recommended for type checking. [Pyrefly](https://pyrefly.org/) - `pytest` - recommended to run tests more selectively Running ``` @@ -350,15 +350,32 @@ make lint Learn more about the linter on the [lintrunner wiki page](https://github.com/pytorch/pytorch/wiki/lintrunner) -#### Running `mypy` +#### Running `pyrefly` -`mypy` is an optional static type checker for Python. We have multiple `mypy` -configs for the PyTorch codebase that are automatically validated against whenever the linter is run. +[Pyrefly](https://pyrefly.org/) is a high-performance static type checker for Python. It provides fast type checking along with IDE features like autocomplete and instant error feedback. + +PyTorch uses Pyrefly for type checking across the codebase. The configuration is managed in `pyrefly.toml` at the root of the repository. + +**Getting Started with Pyrefly:** + +To run type checking on the PyTorch codebase: +```bash +pyrefly check +``` + +For more detailed error information with summaries: +```bash +pyrefly check --summarize-errors +``` + +**Learn More:** +- [Pyrefly Configuration](https://pyrefly.org/en/docs/configuration/) - Detailed configuration options +- [Pyrefly IDE Features](https://pyrefly.org/en/docs/IDE-features/) - Set up Pyrefly in your editor for real-time type checking +- [Python Typing Tutorial](https://pyrefly.org/en/docs/typing-for-python-developers/) - Learn about Python type annotations See [Guide for adding type annotations to PyTorch](https://github.com/pytorch/pytorch/wiki/Guide-for-adding-type-annotations-to-PyTorch) -for more information on how to set up `mypy` and tackle type annotation -tasks. +for PyTorch-specific guidance on how to set up `pyrefly` and tackle type annotation tasks in this codebase. ### C++ Unit Testing From 8e8cbb85ee927776210f7872e3d0286d5d40dc14 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 5 Nov 2025 19:42:39 +0000 Subject: [PATCH 079/130] Revert "[Inductor] Fix unbacked float symbol handling in kernel codegen (#166890)" This reverts commit 0c7a4a6b48d49306eae8d0a9ee8d32b1899e5e23. Reverted https://github.com/pytorch/pytorch/pull/166890 on behalf of https://github.com/malfet due to Looks like it broke torchfuzz tests, see https://hud.pytorch.org/hud/pytorch/pytorch/fbd70fb84e347b45db79eb24cc2c53e447a04147/1?per_page=50&name_filter=trunk%20%2F%20linux-jammy-cuda12&mergeEphemeralLF=true and same test on slow ([comment](https://github.com/pytorch/pytorch/pull/166890#issuecomment-3493011038)) --- test/inductor/test_torchinductor.py | 14 -------------- torch/_inductor/codecache.py | 6 ------ torch/_inductor/codegen/common.py | 11 ++--------- torch/_inductor/codegen/triton_utils.py | 5 ----- 4 files changed, 2 insertions(+), 34 deletions(-) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index d0ff5799ac417..ed8993a1c9a39 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -14424,20 +14424,6 @@ def fn(x): self.common(fn, (torch.randn(6, 4, device=GPU_TYPE).t().contiguous().t(),)) - @skip_if_halide - @requires_cuda_and_triton - def test_unbacked_float_item(self): - def fn(x, max_val): - return torch.clamp(x, 0, max_val.item()) - - self.common( - fn, - ( - torch.randn(10, 20, 30, device=self.device), - torch.tensor(5.0, device=self.device), - ), - ) - # end of class CommonTemplate - add new tests here diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 85702057cbb43..cf17bf2e9478b 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -2970,12 +2970,6 @@ class CppPythonBindingsCodeCache(CppCodeCache): throw std::runtime_error("expected int arg"); return reinterpret_cast(result); }} - template <> inline float parse_arg(PyObject* args, size_t n) {{ - auto result = PyFloat_AsDouble(PyTuple_GET_ITEM(args, n)); - if(unlikely(result == -1.0 && PyErr_Occurred())) - throw std::runtime_error("expected float arg"); - return static_cast(result); - }} {extra_parse_arg} diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 3e9f174c810c5..730c03f1c813c 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -1732,15 +1732,9 @@ def cpp_argdefs( call_args.append(self.wrap_ptr_arg(outer, dtype)) arg_types.append(f"{cpp_dtype}*") for outer, inner in self.sizevars.items(): - if isinstance(outer, sympy.Symbol) and symbol_is_type( - outer, (SymT.UNBACKED_FLOAT) - ): - arg_defs.append(f"const float {inner}") - arg_types.append("const float") - else: - arg_defs.append(f"const {INDEX_TYPE} {inner}") - arg_types.append(f"const {INDEX_TYPE}") + arg_defs.append(f"const {INDEX_TYPE} {inner}") call_args.append(self.wrap_size_arg(outer)) + arg_types.append(f"const {INDEX_TYPE}") if V.graph.wrapper_code: V.graph.wrapper_code.ensure_size_computed(outer) assert not self.workspace_args, "Workspace not supported on CPU " @@ -2359,7 +2353,6 @@ def rename_indexing( SymT.UNBACKED_INT, SymT.SIZE, SymT.PRECOMPUTED_SIZE, - SymT.UNBACKED_FLOAT, ), ) } diff --git a/torch/_inductor/codegen/triton_utils.py b/torch/_inductor/codegen/triton_utils.py index 75a34813c876b..2a2706ad5720b 100644 --- a/torch/_inductor/codegen/triton_utils.py +++ b/torch/_inductor/codegen/triton_utils.py @@ -4,7 +4,6 @@ import sympy import torch -from torch.utils._sympy.symbol import symbol_is_type, SymT from .. import config from ..runtime.hints import AttrsDescriptorWrapper @@ -72,10 +71,6 @@ def signature_of(arg: KernelArgType, *, size_dtype: Optional[str]) -> str: return "constexpr" elif isinstance(arg.expr, (float, sympy.Float)): return "fp32" - elif isinstance(arg.expr, sympy.Symbol) and symbol_is_type( - arg.expr, (SymT.UNBACKED_FLOAT) - ): - return "fp32" elif isinstance(arg.expr, bool): return "i1" From 6d30666bc1cad94295f708f74ebaf161e291c273 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 5 Nov 2025 20:02:47 +0000 Subject: [PATCH 080/130] Revert "[12/N] Apply ruff UP035 rule (#166929)" This reverts commit 5863ba1b2e4de9ea0ae16a663465ec5d3d6f9f52. Reverted https://github.com/pytorch/pytorch/pull/166929 on behalf of https://github.com/donigian due to Temporarily need to revert this to continue a revert for #165076. @cyyever Please re-merge after revert of #165076. ([comment](https://github.com/pytorch/pytorch/pull/166929#issuecomment-3493090596)) --- test/distributed/tensor/test_attention.py | 3 +-- test/higher_order_ops/test_local_map.py | 3 +-- test/inductor/test_caching.py | 3 +-- test/inductor/test_fx_fusion.py | 3 +-- test/inductor/test_native_matmul.py | 2 +- test/quantization/fx/test_quantize_fx.py | 3 +-- test/test_matmul_cuda.py | 2 +- torch/_dynamo/eval_frame.py | 3 +-- torch/_dynamo/graph_bytecode_inputs.py | 3 +-- torch/_dynamo/variables/distributed.py | 3 +-- torch/_dynamo/variables/iter.py | 4 ++-- torch/_dynamo/variables/optimizer.py | 3 +-- torch/_dynamo/variables/script_object.py | 4 ++-- torch/_dynamo/variables/sdpa.py | 3 +-- torch/_dynamo/variables/streams.py | 3 +-- torch/_dynamo/variables/torch_function.py | 4 ++-- torch/_functorch/_aot_autograd/aot_autograd_result.py | 3 +-- torch/_inductor/compile_worker/timer.py | 3 +-- torch/_inductor/fx_passes/bucketing.py | 3 +-- torch/_inductor/fx_passes/ddp_fusion.py | 4 ++-- torch/_inductor/fx_passes/fsdp.py | 2 +- torch/_inductor/fx_passes/memory_estimator.py | 2 +- torch/_inductor/fx_passes/mkldnn_fusion.py | 6 +----- torch/_inductor/fx_passes/overlap_scheduling.py | 4 ++-- torch/_inductor/fx_passes/pad_mm.py | 4 ++-- torch/_inductor/fx_passes/post_grad.py | 3 +-- torch/_inductor/fx_passes/reinplace.py | 4 ++-- torch/_inductor/fx_passes/split_cat.py | 5 +++-- torch/_inductor/kernel/custom_op.py | 3 +-- torch/_inductor/kernel/flex/flex_flash_attention.py | 3 +-- torch/_inductor/runtime/benchmarking.py | 4 ++-- torch/_inductor/runtime/caching/interfaces.py | 6 ++---- torch/_inductor/runtime/caching/locks.py | 5 ++--- torch/distributed/elastic/multiprocessing/tail_log.py | 3 +-- torch/utils/_cxx_pytree.py | 4 ++-- torch/utils/_debug_mode.py | 3 +-- torch/utils/_pytree.py | 3 +-- 37 files changed, 50 insertions(+), 76 deletions(-) diff --git a/test/distributed/tensor/test_attention.py b/test/distributed/tensor/test_attention.py index 6c3485f9d7025..eaf3a4042060d 100644 --- a/test/distributed/tensor/test_attention.py +++ b/test/distributed/tensor/test_attention.py @@ -3,8 +3,7 @@ import itertools import random import unittest -from collections.abc import Callable -from typing import Any, ClassVar, Optional +from typing import Any, Callable, ClassVar, Optional import torch import torch.distributed as dist diff --git a/test/higher_order_ops/test_local_map.py b/test/higher_order_ops/test_local_map.py index fbb21633260e7..9d2870d3b5fdd 100644 --- a/test/higher_order_ops/test_local_map.py +++ b/test/higher_order_ops/test_local_map.py @@ -4,9 +4,8 @@ import functools import unittest -from collections.abc import Callable from contextlib import contextmanager, ExitStack -from typing import Any, Optional +from typing import Any, Callable, Optional import torch import torch._dynamo diff --git a/test/inductor/test_caching.py b/test/inductor/test_caching.py index aa4c3a1f229f1..bcb66beea700c 100644 --- a/test/inductor/test_caching.py +++ b/test/inductor/test_caching.py @@ -13,7 +13,7 @@ from shutil import rmtree from threading import Lock from time import sleep, time -from typing import Any, TYPE_CHECKING, Union +from typing import Any, Generator, Sequence, TYPE_CHECKING, Union from typing_extensions import TypeVar from unittest.mock import patch @@ -37,7 +37,6 @@ if TYPE_CHECKING: - from collections.abc import Generator, Sequence from pathlib import Path diff --git a/test/inductor/test_fx_fusion.py b/test/inductor/test_fx_fusion.py index 63342502d3cd9..ebe98373e622a 100644 --- a/test/inductor/test_fx_fusion.py +++ b/test/inductor/test_fx_fusion.py @@ -1,6 +1,5 @@ # Owner(s): ["module: inductor"] -from collections.abc import Callable -from typing import Any +from typing import Any, Callable import torch from torch._inductor.fx_passes.pre_grad import ( diff --git a/test/inductor/test_native_matmul.py b/test/inductor/test_native_matmul.py index c37f844e41eae..1870a0e373be0 100644 --- a/test/inductor/test_native_matmul.py +++ b/test/inductor/test_native_matmul.py @@ -1,7 +1,7 @@ # Owner(s): ["module: inductor"] -from collections.abc import Callable +from typing import Callable import torch from torch._dynamo.testing import rand_strided diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index faba2f5edc6a7..cd922d94c60c3 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -204,8 +204,7 @@ import operator import unittest import io -from typing import Optional -from collections.abc import Callable +from typing import Callable, Optional class BinaryOp(torch.nn.Module): def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar): diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 10611d4f24673..002c34c450756 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -5,7 +5,7 @@ import unittest from itertools import product from functools import partial -from collections.abc import Callable +from typing import Callable import torch diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 222647eeae9ab..e23e049e3bbb1 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -39,11 +39,10 @@ import unittest import warnings import weakref -from collections.abc import Sized from dataclasses import dataclass from enum import Enum from os.path import dirname, join -from typing import Any, NamedTuple, Optional, TYPE_CHECKING, Union +from typing import Any, NamedTuple, Optional, Sized, TYPE_CHECKING, Union from unittest.mock import patch import sympy diff --git a/torch/_dynamo/graph_bytecode_inputs.py b/torch/_dynamo/graph_bytecode_inputs.py index 16583b89201ec..979950cf3bd1b 100644 --- a/torch/_dynamo/graph_bytecode_inputs.py +++ b/torch/_dynamo/graph_bytecode_inputs.py @@ -1,6 +1,5 @@ import weakref -from collections.abc import Callable -from typing import Any +from typing import Any, Callable from torch._dynamo.source import Source diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py index 187055c26cd00..eb39dd8fa3e07 100644 --- a/torch/_dynamo/variables/distributed.py +++ b/torch/_dynamo/variables/distributed.py @@ -20,8 +20,7 @@ import functools import inspect -from collections.abc import Sequence -from typing import Any, TYPE_CHECKING +from typing import Any, Sequence, TYPE_CHECKING import torch from torch.fx.experimental._backward_state import BackwardState diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index be765cbbc8bf9..5970ba0e1dda7 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -14,8 +14,8 @@ """ import itertools -from collections.abc import Callable, Sequence -from typing import Any, TYPE_CHECKING, Union +from collections.abc import Callable +from typing import Any, Sequence, TYPE_CHECKING, Union from .. import graph_break_hints, polyfills, variables from ..bytecode_transformation import ( diff --git a/torch/_dynamo/variables/optimizer.py b/torch/_dynamo/variables/optimizer.py index c09cc2163a5f4..289cebbe8129b 100644 --- a/torch/_dynamo/variables/optimizer.py +++ b/torch/_dynamo/variables/optimizer.py @@ -22,8 +22,7 @@ import logging import weakref -from collections.abc import Iterable -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, Iterable, Optional, TYPE_CHECKING import torch from torch._dynamo.variables.tensor import TensorVariable diff --git a/torch/_dynamo/variables/script_object.py b/torch/_dynamo/variables/script_object.py index 644c269a23a34..85977104977fb 100644 --- a/torch/_dynamo/variables/script_object.py +++ b/torch/_dynamo/variables/script_object.py @@ -19,8 +19,8 @@ """ import functools -from collections.abc import Callable, Iterable -from typing import Any, TYPE_CHECKING, TypeVar +from collections.abc import Callable +from typing import Any, Iterable, TYPE_CHECKING, TypeVar from typing_extensions import ParamSpec import torch diff --git a/torch/_dynamo/variables/sdpa.py b/torch/_dynamo/variables/sdpa.py index 629bf094dc951..75928842cf297 100644 --- a/torch/_dynamo/variables/sdpa.py +++ b/torch/_dynamo/variables/sdpa.py @@ -1,6 +1,5 @@ -from collections.abc import Sequence from inspect import getattr_static -from typing import Any, TYPE_CHECKING, TypeGuard +from typing import Any, Sequence, TYPE_CHECKING, TypeGuard from torch._guards import Source from torch.backends.cuda import SDPAParams diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index fb5dd775bd636..c353181eb8029 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -1,6 +1,5 @@ import collections -from collections.abc import Callable -from typing import Any, Optional +from typing import Any, Callable, Optional import torch from torch._dynamo.variables.dicts import ConstDictVariable diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index 4d0f0b4fae8ab..fa8412146a427 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -29,9 +29,9 @@ import functools import inspect import operator -from collections.abc import Generator, Iterable, Sequence +from collections.abc import Sequence from types import TracebackType -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, Generator, Iterable, Optional, TYPE_CHECKING import torch._C import torch.utils._pytree as pytree diff --git a/torch/_functorch/_aot_autograd/aot_autograd_result.py b/torch/_functorch/_aot_autograd/aot_autograd_result.py index 7e608933b34c3..ce01e37f03243 100644 --- a/torch/_functorch/_aot_autograd/aot_autograd_result.py +++ b/torch/_functorch/_aot_autograd/aot_autograd_result.py @@ -22,10 +22,9 @@ import json import logging from abc import ABC, abstractmethod -from collections.abc import Callable from copy import copy from dataclasses import dataclass -from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar +from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar import torch from torch._dynamo.precompile_context import BackendCacheArtifact diff --git a/torch/_inductor/compile_worker/timer.py b/torch/_inductor/compile_worker/timer.py index 7c495403b3a55..7cfeb4217e26b 100644 --- a/torch/_inductor/compile_worker/timer.py +++ b/torch/_inductor/compile_worker/timer.py @@ -1,7 +1,6 @@ -from collections.abc import Callable from threading import Lock, Thread from time import monotonic, sleep -from typing import Optional, Union +from typing import Callable, Optional, Union class Timer: diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py index 29f070564349c..ab831c96c94ba 100644 --- a/torch/_inductor/fx_passes/bucketing.py +++ b/torch/_inductor/fx_passes/bucketing.py @@ -2,8 +2,7 @@ import logging import operator from collections import defaultdict -from collections.abc import Callable -from typing import Any, Literal, TypeAlias +from typing import Any, Callable, Literal, TypeAlias import torch import torch.distributed as dist diff --git a/torch/_inductor/fx_passes/ddp_fusion.py b/torch/_inductor/fx_passes/ddp_fusion.py index 44314b912786f..8a4de1a604869 100644 --- a/torch/_inductor/fx_passes/ddp_fusion.py +++ b/torch/_inductor/fx_passes/ddp_fusion.py @@ -4,10 +4,10 @@ import logging import math import operator -from collections.abc import Callable, Generator +from collections.abc import Generator from dataclasses import dataclass from functools import partial -from typing import Any, cast +from typing import Any, Callable, cast import torch import torch.fx as fx diff --git a/torch/_inductor/fx_passes/fsdp.py b/torch/_inductor/fx_passes/fsdp.py index 1e71c350ed7b6..6b0c2ad2c94a7 100644 --- a/torch/_inductor/fx_passes/fsdp.py +++ b/torch/_inductor/fx_passes/fsdp.py @@ -1,5 +1,5 @@ import logging -from collections.abc import Callable +from typing import Callable import torch from torch._inductor.fx_passes.bucketing import ( diff --git a/torch/_inductor/fx_passes/memory_estimator.py b/torch/_inductor/fx_passes/memory_estimator.py index e887d4bf62c8e..c6b7c51b948e5 100644 --- a/torch/_inductor/fx_passes/memory_estimator.py +++ b/torch/_inductor/fx_passes/memory_estimator.py @@ -1,8 +1,8 @@ import itertools import logging from collections import defaultdict -from collections.abc import Callable from dataclasses import dataclass +from typing import Callable import torch import torch.fx as fx diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index 214d3bf02f7f4..70b3a3c355dde 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -2,7 +2,7 @@ import functools import operator from functools import reduce -from typing import Any, TYPE_CHECKING +from typing import Any, Callable import torch from torch._dynamo.utils import counters @@ -35,10 +35,6 @@ ) -if TYPE_CHECKING: - from collections.abc import Callable - - if torch._C._has_mkldnn: aten = torch.ops.aten mkldnn = torch.ops.mkldnn diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py index f383ab63dc261..a47aa960e58c5 100644 --- a/torch/_inductor/fx_passes/overlap_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_scheduling.py @@ -4,9 +4,9 @@ import logging import sys from collections import Counter, defaultdict -from collections.abc import Callable, Iterable +from collections.abc import Iterable from dataclasses import dataclass -from typing import Any +from typing import Any, Callable import torch import torch.fx as fx diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index b511403d4874c..30768fda9bb72 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -2,8 +2,8 @@ import itertools import operator import typing -from collections.abc import Callable, Sequence -from typing import Any +from collections.abc import Sequence +from typing import Any, Callable import torch import torch._inductor.runtime.runtime_utils diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 91b4e10bf7238..7d995adec04ef 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -5,8 +5,7 @@ import logging import operator from collections import Counter, defaultdict -from collections.abc import Callable -from typing import Any, TypeVar +from typing import Any, Callable, TypeVar from typing_extensions import ParamSpec import torch diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py index e42e8a1139770..52222f3da8344 100644 --- a/torch/_inductor/fx_passes/reinplace.py +++ b/torch/_inductor/fx_passes/reinplace.py @@ -3,10 +3,10 @@ import logging import operator from collections import defaultdict -from collections.abc import Callable, Sequence +from collections.abc import Sequence from contextlib import nullcontext from dataclasses import dataclass -from typing import Any, cast +from typing import Any, Callable, cast import torch import torch.fx.node diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index 0bad4fa7cc635..92e1e6f375f44 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -4,8 +4,9 @@ import operator import os from collections import defaultdict -from collections.abc import Callable, Sequence -from typing import Any, TypeAlias +from collections.abc import Sequence +from typing import Any, Callable +from typing_extensions import TypeAlias import torch from torch._dynamo.utils import counters diff --git a/torch/_inductor/kernel/custom_op.py b/torch/_inductor/kernel/custom_op.py index d35309c01d07c..303110a561b5e 100644 --- a/torch/_inductor/kernel/custom_op.py +++ b/torch/_inductor/kernel/custom_op.py @@ -2,8 +2,7 @@ import functools import logging -from collections.abc import Callable -from typing import Any, Optional, Union +from typing import Any, Callable, Optional, Union import torch from torch._inductor.codegen.subgraph import SubgraphTemplate diff --git a/torch/_inductor/kernel/flex/flex_flash_attention.py b/torch/_inductor/kernel/flex/flex_flash_attention.py index 0d3721aa730a4..c100df84d5a73 100644 --- a/torch/_inductor/kernel/flex/flex_flash_attention.py +++ b/torch/_inductor/kernel/flex/flex_flash_attention.py @@ -3,9 +3,8 @@ import functools import importlib -from collections.abc import Callable, Sequence from contextlib import contextmanager -from typing import Any, Optional +from typing import Any, Callable, Optional, Sequence import sympy from sympy import Expr, Integer diff --git a/torch/_inductor/runtime/benchmarking.py b/torch/_inductor/runtime/benchmarking.py index d9d92e363879d..d592a8c8c00f9 100644 --- a/torch/_inductor/runtime/benchmarking.py +++ b/torch/_inductor/runtime/benchmarking.py @@ -5,8 +5,8 @@ from functools import cached_property, wraps from itertools import chain from statistics import median -from typing import Any, Concatenate, Optional, Union -from typing_extensions import ParamSpec, Self, TypeVar +from typing import Any, Optional, Union +from typing_extensions import Concatenate, ParamSpec, Self, TypeVar import torch import torch.utils._pytree as pytree diff --git a/torch/_inductor/runtime/caching/interfaces.py b/torch/_inductor/runtime/caching/interfaces.py index 03d2957493679..0758e11134018 100644 --- a/torch/_inductor/runtime/caching/interfaces.py +++ b/torch/_inductor/runtime/caching/interfaces.py @@ -12,8 +12,8 @@ from pathlib import Path from threading import Lock from time import time -from typing import Any, TYPE_CHECKING, TypeAlias -from typing_extensions import override +from typing import Any, Callable, TYPE_CHECKING +from typing_extensions import override, TypeAlias from filelock import FileLock @@ -21,8 +21,6 @@ if TYPE_CHECKING: - from collections.abc import Callable - from .utils import P, R diff --git a/torch/_inductor/runtime/caching/locks.py b/torch/_inductor/runtime/caching/locks.py index 8e8cd011e2d44..e7e1f1adc3622 100644 --- a/torch/_inductor/runtime/caching/locks.py +++ b/torch/_inductor/runtime/caching/locks.py @@ -12,8 +12,8 @@ from __future__ import annotations from contextlib import _GeneratorContextManager, contextmanager, ExitStack -from typing import TYPE_CHECKING, TypeAlias -from typing_extensions import Protocol +from typing import Generator, TYPE_CHECKING +from typing_extensions import Protocol, TypeAlias from filelock import FileLock, Timeout @@ -21,7 +21,6 @@ if TYPE_CHECKING: - from collections.abc import Generator from threading import Lock diff --git a/torch/distributed/elastic/multiprocessing/tail_log.py b/torch/distributed/elastic/multiprocessing/tail_log.py index 034740810dcdd..7ad35115cd34a 100644 --- a/torch/distributed/elastic/multiprocessing/tail_log.py +++ b/torch/distributed/elastic/multiprocessing/tail_log.py @@ -10,10 +10,9 @@ import logging import os import time -from collections.abc import Callable from concurrent.futures.thread import ThreadPoolExecutor from threading import Event -from typing import Optional, TextIO, TYPE_CHECKING, Union +from typing import Callable, Optional, TextIO, TYPE_CHECKING, Union if TYPE_CHECKING: diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 897279bd39b1e..603625ed97c12 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -15,8 +15,8 @@ import functools import types from collections.abc import Callable, Iterable, Mapping -from typing import Any, Optional, overload, TypeAlias, TypeVar, Union -from typing_extensions import deprecated, Self, TypeIs +from typing import Any, Optional, overload, TypeVar, Union +from typing_extensions import deprecated, Self, TypeAlias, TypeIs import torch.utils._pytree as python_pytree from torch.torch_version import TorchVersion as _TorchVersion diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 5a6ee246abf7e..5e24ce086e1aa 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -3,8 +3,7 @@ import functools import traceback import weakref -from collections.abc import Callable -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, Callable, Optional, TYPE_CHECKING import torch from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 147340f58d66e..56704bb3f8024 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -36,11 +36,10 @@ Optional, overload, Protocol, - TypeAlias, TypeVar, Union, ) -from typing_extensions import deprecated, NamedTuple, Self +from typing_extensions import deprecated, NamedTuple, Self, TypeAlias from torch.torch_version import TorchVersion as _TorchVersion From a74fe75c450277eb88a95c764e8b0a664a550a86 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Wed, 5 Nov 2025 08:21:40 -0800 Subject: [PATCH 081/130] Don't hardcode double argument for reduction base (#166951) Fixes https://github.com/pytorch/pytorch/issues/43254 Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/166951 Approved by: https://github.com/ngimel, https://github.com/Skylion007 ghstack dependencies: #166813 --- aten/src/ATen/native/cpu/Reduce.h | 4 ++-- aten/src/ATen/native/cpu/ReduceOpsKernel.cpp | 22 +------------------- 2 files changed, 3 insertions(+), 23 deletions(-) diff --git a/aten/src/ATen/native/cpu/Reduce.h b/aten/src/ATen/native/cpu/Reduce.h index 6c9efbb0f6e7f..ab9051ca8d2a2 100644 --- a/aten/src/ATen/native/cpu/Reduce.h +++ b/aten/src/ATen/native/cpu/Reduce.h @@ -247,8 +247,8 @@ void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) { }); } -template -void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) { +template +void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, ident_t ident = static_cast(0)) { using traits = binary_function_traits; static_assert( all_same< diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index 3bad49a32d98c..053db7b4eda00 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -339,33 +339,13 @@ void or_kernel_impl(TensorIterator& iter) { } } -template -struct MinValuesOps: public at::native::MinOps { - using arg_t = typename MinOps::arg_t; - static scalar_t project(arg_t arg) { - return arg.first; - } -}; - void min_values_kernel_impl(TensorIterator& iter) { - // This case is special because of Vectorized does not - // handle upper_bound(). - // See: https://github.com/pytorch/pytorch/issues/43254 - if (iter.dtype() == kLong || iter.dtype() == kUInt64) { - AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] { - binary_kernel_reduce( - iter, - MinValuesOps{}, - std::pair(upper_bound(), -1)); - }), kLong, kUInt64); - return; - } AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] { binary_kernel_reduce_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); }, [](Vectorized a, Vectorized b) { return minimum(a, b); }, - static_cast(upper_bound())); + upper_bound()); }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } From ea44f12bce3eb05eaa9fa34943a3ffae04647fa5 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Wed, 5 Nov 2025 20:51:47 +0000 Subject: [PATCH 082/130] [13/N] Apply ruff UP035 rule (#167048) This PR continues to apply ruff UP035 rule to test code and some remaining torch files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167048 Approved by: https://github.com/Skylion007 --- test/dynamo/test_install_free_tensors.py | 4 ++-- test/dynamo/test_python_autograd.py | 6 +++++- test/typing/pass/arithmetic_ops.py | 4 ++-- torch/_C/_distributed_c10d.pyi | 3 ++- torch/_dynamo/variables/ctx_manager.py | 4 ++-- torch/_inductor/codegen/pallas.py | 4 +++- torch/_inductor/runtime/caching/config.py | 2 +- torch/distributed/_local_tensor/_c10d.py | 3 +-- 8 files changed, 18 insertions(+), 12 deletions(-) diff --git a/test/dynamo/test_install_free_tensors.py b/test/dynamo/test_install_free_tensors.py index 3858b827bd598..fd9e14c4c3f76 100644 --- a/test/dynamo/test_install_free_tensors.py +++ b/test/dynamo/test_install_free_tensors.py @@ -1,7 +1,7 @@ # Owner(s): ["module: dynamo"] import unittest -from collections.abc import Sequence -from typing import Any, Callable, Union +from collections.abc import Callable, Sequence +from typing import Any, Union import torch import torch._dynamo diff --git a/test/dynamo/test_python_autograd.py b/test/dynamo/test_python_autograd.py index a615c653f56c3..a6117bb4093a7 100644 --- a/test/dynamo/test_python_autograd.py +++ b/test/dynamo/test_python_autograd.py @@ -1,5 +1,5 @@ # Owner(s): ["module: dynamo"] -from typing import Callable, NamedTuple, Optional +from typing import NamedTuple, Optional, TYPE_CHECKING import torch import torch._dynamo @@ -7,6 +7,10 @@ from torch._dynamo.testing import CompileCounter, same +if TYPE_CHECKING: + from collections.abc import Callable + + """ This is an example of a pure-python version of autograd implemented by @zdevito. It represents a rather challenging test case for TorchDynamo diff --git a/test/typing/pass/arithmetic_ops.py b/test/typing/pass/arithmetic_ops.py index f0d6cc6fd9f97..14dda1cf39772 100644 --- a/test/typing/pass/arithmetic_ops.py +++ b/test/typing/pass/arithmetic_ops.py @@ -1,5 +1,5 @@ -from typing import Union -from typing_extensions import assert_type, TypeAlias +from typing import TypeAlias, Union +from typing_extensions import assert_type from torch import randn, Tensor diff --git a/torch/_C/_distributed_c10d.pyi b/torch/_C/_distributed_c10d.pyi index f3d96860f5584..b659be9ee119e 100644 --- a/torch/_C/_distributed_c10d.pyi +++ b/torch/_C/_distributed_c10d.pyi @@ -1,8 +1,9 @@ # mypy: allow-untyped-defs # mypy: disable-error-code="type-arg" +from collections.abc import Callable from datetime import timedelta from enum import Enum -from typing import Any, Callable, Optional, overload, Union +from typing import Any, Optional, overload, Union import torch from torch import Tensor diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py index 4eac189b65fdd..3f52c19ff0a90 100644 --- a/torch/_dynamo/variables/ctx_manager.py +++ b/torch/_dynamo/variables/ctx_manager.py @@ -21,9 +21,9 @@ import inspect import sys import warnings -from collections.abc import Callable, Sequence +from collections.abc import Callable, Sequence, Sized from contextlib import ExitStack -from typing import Any, ContextManager, Optional, Sized, TYPE_CHECKING, Union +from typing import Any, ContextManager, Optional, TYPE_CHECKING, Union import torch._C from torch._guards import Guard diff --git a/torch/_inductor/codegen/pallas.py b/torch/_inductor/codegen/pallas.py index 1fc8e40724bc0..da437a4e8ee3c 100644 --- a/torch/_inductor/codegen/pallas.py +++ b/torch/_inductor/codegen/pallas.py @@ -2,7 +2,7 @@ from __future__ import annotations import hashlib -from typing import Any, Callable, Optional, Sequence, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING import sympy # noqa: TC002 @@ -17,6 +17,8 @@ if TYPE_CHECKING: + from collections.abc import Callable, Sequence + from ..ir import IRNode from ..scheduler import BaseSchedulerNode diff --git a/torch/_inductor/runtime/caching/config.py b/torch/_inductor/runtime/caching/config.py index 748715d1631ad..14e13f937dbb7 100644 --- a/torch/_inductor/runtime/caching/config.py +++ b/torch/_inductor/runtime/caching/config.py @@ -1,6 +1,6 @@ import os +from collections.abc import Callable from functools import cache, partial -from typing import Callable import torch from torch._environment import is_fbcode diff --git a/torch/distributed/_local_tensor/_c10d.py b/torch/distributed/_local_tensor/_c10d.py index c9256543e8977..0b63330dfafce 100644 --- a/torch/distributed/_local_tensor/_c10d.py +++ b/torch/distributed/_local_tensor/_c10d.py @@ -1,9 +1,8 @@ import functools import math import operator -from collections.abc import Sequence +from collections.abc import Callable, Sequence from datetime import timedelta -from typing import Callable import torch from torch._C import ScriptObject From ef3f953966d94ce11ced06f8e468b2fa69c1b3cb Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 5 Nov 2025 20:52:41 +0000 Subject: [PATCH 083/130] Revert "[DebugMode] output, tensor id annotations for DebugMode (#165076)" This reverts commit a64c7d740428010d700b4bcd395af8a7b2d5c21f. Reverted https://github.com/pytorch/pytorch/pull/165076 on behalf of https://github.com/wdvr due to Sorry but this is breaking internally. See diff [D86245252](https://l.workplace.com/l.php?u=https%3A%2F%2Fwww.internalfb.com%2Fdiff%2FD86245252&h=AT1oPbS1XTv6HjYeYdxmDMW1-jlT0pS8yBO2iSfbPfUB9ydsEjFXBNT56QhV1v5TKc4_QaQNxykNowSKmb4fgenjOyCv20NuL7oV_Id5fhh32hhv1IpjgsDJYK-PBFfSfv_miLIWfNgj902KcgXojbBgDcDzQeS9lNt0GQ) for details. To validate your fixes internally, you can follow the instructions here: https://fburl.com/fixing-ghfirst-reverts ([comment](https://github.com/pytorch/pytorch/pull/165076#issuecomment-3493358159)) --- .../tensor/debug/test_debug_mode.py | 22 ++- torch/utils/_debug_mode.py | 126 +++--------------- 2 files changed, 31 insertions(+), 117 deletions(-) diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py index abc37f17a74de..18cc702cbbc7a 100644 --- a/test/distributed/tensor/debug/test_debug_mode.py +++ b/test/distributed/tensor/debug/test_debug_mode.py @@ -50,24 +50,22 @@ def test_debug_mode_mm(self): x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False) y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False) - with DebugMode( - record_torchfunction=True, record_ids=True, record_output=True - ) as debug_mode: + with DebugMode(record_torchfunction=True) as debug_mode: torch.mm(x_dtensor, y_dtensor).sum() self.assertExpectedInline( debug_mode.debug_string(), """\ - torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) -> dt$6: f32[8, 32]| S(0) - aten::mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) + torch.mm(dt: f32[8, 8]| S(0), dt: f32[8, 32]| S(0)) + aten::mm(dt: f32[8, 8]| S(0), dt: f32[8, 32]| S(0)) redistribute_input(1, S(0) -> R) - redistribute_input(t$2: f32[1, 32], trace: S(0)->R) - _c10d_functional::all_gather_into_tensor(t$2: f32[1, 32], 8, 0) -> t$3: f32[8, 32] - _c10d_functional::wait_tensor(t$3: f32[8, 32]) -> t$3: f32[8, 32] - aten::mm(t$4: f32[1, 8], t$3: f32[8, 32]) -> t$5: f32[1, 32] - (dt$6: f32[8, 32]| S(0)) -> dt$8: f32[]| P - aten::sum(dt$6: f32[8, 32]| S(0)) - aten::sum(t$5: f32[1, 32]) -> t$7: f32[]""", + redistribute_input(t: f32[1, 32], trace: S(0)->R) + _c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0) + _c10d_functional::wait_tensor(t: f32[8, 32]) + aten::mm(t: f32[1, 8], t: f32[8, 32]) + (dt: f32[8, 32]| S(0)) + aten::sum(dt: f32[8, 32]| S(0)) + aten::sum(t: f32[1, 32])""", ) self.assertTrue(isinstance(debug_mode.operators[0], _OpCall)) diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 5e24ce086e1aa..09435aa07e68b 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -2,7 +2,6 @@ import contextlib import functools import traceback -import weakref from typing import Any, Callable, Optional, TYPE_CHECKING import torch @@ -15,7 +14,6 @@ ) from torch.utils._pytree import tree_all, tree_map from torch.utils._traceback import CapturedTraceback -from torch.utils.weak import WeakIdRef if TYPE_CHECKING: @@ -58,48 +56,29 @@ def _stringify_dtensor_spec(spec) -> str: return DTensorSpec.format_shard_order_str(spec.placements, spec.shard_order) -class TensorIdTracker: - def __init__(self): - self.tensor_memo: dict[WeakIdRef, int] = {} - self.next_tensor_id = 0 - - def _id(self, tensor) -> int: - with torch._C._DisablePythonDispatcher(): - o = WeakIdRef(tensor) - - def del_memo(): - self.tensor_memo.pop(o, None) - - weakref.finalize(tensor, del_memo) - if o not in self.tensor_memo: - self.tensor_memo[o] = self.next_tensor_id - self.next_tensor_id += 1 - return self.tensor_memo[o] - - -def _tensor_debug_string(tensor, attributes, tensor_memo=None) -> str: +def _tensor_debug_string(tensor, attributes) -> str: """Convert tensor to debug string representation.""" if isinstance(tensor, torch.Tensor): tensor_debug_str = f"{dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}{_stringify_attributes(tensor, attributes)}" - id_str = f"${tensor_memo._id(tensor)}" if tensor_memo is not None else "" + if isinstance(tensor, torch.distributed.tensor.DTensor): # omitted device mesh - return f"dt{id_str}: {tensor_debug_str}| {_stringify_dtensor_spec(tensor._spec)}" + return f"dt: {tensor_debug_str}| {_stringify_dtensor_spec(tensor._spec)}" elif isinstance(tensor, FakeTensor): - return f"ft{id_str}: {tensor_debug_str}" + return f"ft: {tensor_debug_str}" else: - return f"t{id_str}: {tensor_debug_str}" + return f"t: {tensor_debug_str}" else: raise RuntimeError(f"Unsupported tensor type: {type(tensor)}") -def _arg_to_str(arg, attributes, tensor_memo=None) -> str: +def _arg_to_str(arg, attributes) -> str: from torch.distributed.tensor._dtensor_spec import DTensorSpec def to_str(x): if isinstance(x, torch.Tensor): - return _tensor_debug_string(x, attributes, tensor_memo) + return _tensor_debug_string(x, attributes) elif isinstance(x, DTensorSpec): return _stringify_dtensor_spec(x) return x @@ -165,11 +144,8 @@ def __init__( # results from dispatch hooks self.record = record self.log = log - self.output_str: Optional[str] = None - def stringify_args( - self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None - ) -> None: + def stringify_args(self, attributes: list[str]) -> None: """ To reduce memory consumption, this method stringifies args/kwargs, stores the result, and deletes original args/kwargs. """ @@ -177,18 +153,6 @@ def stringify_args( "Subclasses must implement stringify_args(), even if no-op" ) - def stringify_output( - self, - output: Any, - attributes: list[str], - tensor_memo: Optional[TensorIdTracker] = None, - ) -> None: - """Store stringified version of call output in self.output_str""" - if tree_all(lambda x: x is None, output): - return - output_str = tree_map(lambda x: _arg_to_str(x, attributes, tensor_memo), output) - self.output_str = f" -> {str(output_str)}" - def render(self, attributes: list[str]) -> str: raise NotImplementedError("Subclasses must implement string render()") @@ -215,16 +179,11 @@ def __init__( self.args_str: Optional[str] = None self.kwargs_str: Optional[str] = None - def stringify_args( - self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None - ) -> None: - self.args_str = ", ".join( - _arg_to_str(arg, attributes, tensor_memo) for arg in self.args - ) + def stringify_args(self, attributes: list[str]) -> None: + self.args_str = ", ".join(_arg_to_str(arg, attributes) for arg in self.args) if self.kwargs: self.kwargs_str = ", " + ", ".join( - f"{k}={_arg_to_str(v, attributes, tensor_memo)}" - for k, v in self.kwargs.items() + f"{k}={_arg_to_str(v, attributes)}" for k, v in self.kwargs.items() ) else: self.kwargs_str = "" @@ -256,8 +215,6 @@ def render(self, attributes: list[str]) -> str: base_str = f"{op_name}({args_str}{kwargs_str})" - if self.output_str: - base_str += self.output_str if self.log: base_str += f" # {self.log}" return base_str @@ -290,10 +247,8 @@ def __init__( self.arg_str: Optional[str] = None - def stringify_args( - self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None - ) -> None: - self.arg_str = f"{_arg_to_str(self.arg, attributes, tensor_memo)}" + def stringify_args(self, attributes: list[str]) -> None: + self.arg_str = f"{_arg_to_str(self.arg, attributes)}" del self.arg def render(self, attributes: list[str]) -> str: @@ -308,11 +263,7 @@ def render(self, attributes: list[str]) -> str: src_placement_str = _arg_to_str(self.src_placement, attributes) dst_placement_str = _arg_to_str(self.dst_placement, attributes) placement_str = f"{src_placement_str} -> {dst_placement_str}" - - base_str = f"{REDISTRIBUTE_FUNC}({arg_str}, {placement_str})" - if self.output_str: - base_str += self.output_str - return base_str + return f"{REDISTRIBUTE_FUNC}({arg_str}, {placement_str})" def __iter__(self): # for BC; tuple(self) returns (op, placement info, kwargs, call_depth) @@ -337,9 +288,7 @@ def __init__(self, module_name: str, call_depth: int, stack: bool = False): super().__init__(call_depth, stack=stack) self.module_name = module_name - def stringify_args( - self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None - ) -> None: + def stringify_args(self, attributes: list[str]) -> None: pass # nothing to stringify def render(self, attributes: list[str]) -> str: @@ -392,8 +341,6 @@ def __init__( record_nn_module=False, store_original_args=False, record_stack_trace=False, - record_output=False, - record_ids=False, ): super().__init__() import torch.distributed.tensor # noqa: F401 @@ -431,24 +378,8 @@ def __init__( # e.g. via DebugMode(record_stack_trace=True), or torch.autograd.set_detect_anomaly(). self.record_stack_trace = record_stack_trace - # Records call outputs in logs (e.g. for __torch_dispatch__, __torch_function__, redistribute_input) - self.record_output: bool = record_output - - # Annotates string dumps with graph-style tensor ids, e.g. op($1, $2) -> $3. - self.record_ids: bool = record_ids - - self.reset() - - def reset(self): self.operators = [] self.call_depth = 0 - self._tensor_memo = TensorIdTracker() - self._output_info: dict[int, object] = {} - - def _track_op_output(self, op_index, result): - """Assign IDs to output tensors and store in output_info""" - # self._track_tensor_ids(result) - self._output_info[op_index] = result # Without this override, running torch.compile under DebugMode # will force torch.compile to always use the “eager” backend @@ -459,35 +390,20 @@ def ignore_compile_internals(cls): def _record_call(self, call): if not self.store_original_args: - call.stringify_args( - self.record_tensor_attributes, - self._tensor_memo if self.record_ids else None, - ) + call.stringify_args(self.record_tensor_attributes) self.operators.append(call) - def _record_call_output(self, call, output): - if not self.record_output: - return - call.stringify_output( - output, - self.record_tensor_attributes, - self._tensor_memo if self.record_ids else None, - ) - def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - call = _OpCall( - func, args, kwargs, self.call_depth, stack=self.record_stack_trace + self._record_call( + _OpCall(func, args, kwargs, self.call_depth, stack=self.record_stack_trace) ) - self._record_call(call) try: self.call_depth += 1 - result = func(*args, **kwargs) - self._record_call_output(call, result) - return result + return func(*args, **kwargs) finally: self.call_depth -= 1 @@ -529,13 +445,13 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): result = func(*args, **kwargs) if call: - self._record_call_output(call, result) _run_dispatch_hooks(call, func, types, args, kwargs, result) return result def __enter__(self): - self.reset() + self.operators = [] + self.call_depth = 0 if self.record_torchfunction: torch._C._push_on_torch_function_stack(self) From c6c913d18e8c40ade1523cc0dd08f095217a2fdf Mon Sep 17 00:00:00 2001 From: Jane Xu Date: Tue, 4 Nov 2025 15:36:02 -0800 Subject: [PATCH 084/130] Add torch::stable::Tensor sizes and strides (#165153) Pull Request resolved: https://github.com/pytorch/pytorch/pull/165153 Approved by: https://github.com/mikaylagawarecki ghstack dependencies: #164991, #165152 --- .../libtorch_agnostic/csrc/kernel.cpp | 18 ++++-------- torch/csrc/stable/tensor_struct.h | 28 +++++++++++++++++++ 2 files changed, 33 insertions(+), 13 deletions(-) diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp index 87aaa46e64c95..7154322641c32 100644 --- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp +++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp @@ -47,20 +47,10 @@ Tensor sgd_out_of_place( STD_TORCH_CHECK(param.get_device() == -1, "CPU device index = -1"); STD_TORCH_CHECK(param.get_device_index() == -1, "CPU device index = -1"); - int64_t *param_sizes; - int64_t *param_strides; - aoti_torch_get_sizes(param.get(), ¶m_sizes); - aoti_torch_get_strides(param.get(), ¶m_strides); + // testing Tensor strides + stride + STD_TORCH_CHECK(param.strides()[0] == param.stride(0)); - int32_t param_dtype; - aoti_torch_get_dtype(param.get(), ¶m_dtype); - - int32_t param_device_type; - aoti_torch_get_device_type(param.get(), ¶m_device_type); - - AtenTensorHandle out_ath; - aoti_torch_empty_strided(param.dim(), param_sizes, param_strides, param_dtype, param_device_type, param.get_device(), &out_ath); - auto out = Tensor(out_ath); + auto out = new_empty(param, param.sizes()); sgd_math( reinterpret_cast(param.data_ptr()), @@ -344,6 +334,8 @@ Tensor my_new_empty_dtype_variant(Tensor t) { // Still using a std::vector below even though people can just pass in an // initializer list (which will be implicitly converted to an HeaderOnlyArrayRef) // directly. + // This is to test that passing in a std::vector works for BC. (It gets + // implicitly converted to HeaderOnlyArrayRef too!) std::vector sizes = {2, 5}; auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16); return new_empty(t, sizes, dtype); diff --git a/torch/csrc/stable/tensor_struct.h b/torch/csrc/stable/tensor_struct.h index 88cc167e59770..0d44ffd075170 100644 --- a/torch/csrc/stable/tensor_struct.h +++ b/torch/csrc/stable/tensor_struct.h @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -13,6 +14,7 @@ HIDDEN_NAMESPACE_BEGIN(torch, stable) using accelerator::DeviceIndex; +using torch::headeronly::IntHeaderOnlyArrayRef; using torch::headeronly::ScalarType; // The torch::stable::Tensor class is a highlevel C++ wrapper around @@ -93,6 +95,32 @@ class Tensor { return numel; } + // note: this API is, for all intents and purposes, the same as the one in + // TensorBase.h: it returns a borrowed reference of the dimension sizes of + // a Tensor. + // + // The only difference is that it returns a header-only IntHeaderOnlyArrayRef, + // which has slightly less functionality than a regular IntArrayRef. See + // [HeaderOnlyArrayRef vs ArrayRef note] for more details. + IntHeaderOnlyArrayRef sizes() const { + int64_t* sizes; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(ath_.get(), &sizes)); + return IntHeaderOnlyArrayRef(sizes, dim()); + } + + // note: this API is, for all intents and purposes, the same as the one in + // TensorBase.h: it returns a borrowed reference of the strides of a + // Tensor. + // + // The only difference is that it returns a header-only IntHeaderOnlyArrayRef, + // which has slightly less functionality than a regular IntArrayRef. See + // [HeaderOnlyArrayRef vs ArrayRef note] for more details. + IntHeaderOnlyArrayRef strides() const { + int64_t* strides; + TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(ath_.get(), &strides)); + return IntHeaderOnlyArrayRef(strides, dim()); + } + // note: this is a subset of the original TensorBase API. It takes no // arguments whereas the original API takes in a kwarg of memory format. // Here, we assume the default contiguous memory format. From 13d2cc7bd26e32cafff0377dda1c5ddc8d04c4ce Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Wed, 5 Nov 2025 20:55:59 +0000 Subject: [PATCH 085/130] Remove python workaround for ContextDecorator (#167049) This PR removes the import workaround for ContextDecorator because the import always succeeds in Py 3.10+. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167049 Approved by: https://github.com/Skylion007 --- torch/autograd/profiler.py | 24 ++---------------------- 1 file changed, 2 insertions(+), 22 deletions(-) diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index fa43af2701171..9e2a7b5046dee 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -52,26 +52,7 @@ "MemRecordsAcc", ] -try: - # Available in Python >= 3.2 - from contextlib import ContextDecorator as _ContextDecorator -except ImportError: - import functools - - class _ContextDecorator: # type: ignore[no-redef] - def __enter__(self): - raise NotImplementedError - - def __exit__(self, exc_type, exc_val, exc_tb): - raise NotImplementedError - - def __call__(self, func): - @functools.wraps(func) - def wrapped(*args, **kwargs): - with self: - return func(*args, **kwargs) - - return wrapped +from contextlib import ContextDecorator # global python state - whether profiler is currently enabled @@ -744,8 +725,7 @@ def createFunctionEventForMemoryEvents(evt): return all_function_events -# pyrefly: ignore [invalid-inheritance] -class record_function(_ContextDecorator): +class record_function(ContextDecorator): """Context manager/function decorator that adds a label to a code block/function when running autograd profiler. Label will only appear if CPU activity tracing is enabled. From fd8f368d31d622355275cfe0283ab582cd2ee903 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Tue, 4 Nov 2025 17:45:11 -0800 Subject: [PATCH 086/130] [user-streams] Add graph annotation checks (#167019) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167019 Approved by: https://github.com/Lucaskabela --- test/dynamo/test_graph_deduplication.py | 14 +- test/dynamo/test_streams.py | 230 ++++++++++++++++++++++-- torch/_dynamo/testing.py | 6 + 3 files changed, 225 insertions(+), 25 deletions(-) diff --git a/test/dynamo/test_graph_deduplication.py b/test/dynamo/test_graph_deduplication.py index 004aee88a8633..fc9284a3c9542 100644 --- a/test/dynamo/test_graph_deduplication.py +++ b/test/dynamo/test_graph_deduplication.py @@ -8,21 +8,11 @@ from torch._dynamo.graph_utils import _detect_cycles from torch._dynamo.output_graph import FakeRootModule from torch._dynamo.test_case import TestCase -from torch._dynamo.testing import ( - AotEagerAndRecordGraphs, - extract_graph_and_tracker, - normalize_gm, -) +from torch._dynamo.testing import extract_graph, extract_graph_and_tracker, normalize_gm from torch.compiler import allow_in_graph from torch.utils._ordered_set import OrderedSet -def extract_graph(fn, *args, **kwargs): - backend = AotEagerAndRecordGraphs() - result = torch.compile(backend=backend)(fn)(*args, **kwargs) - return result, backend.graphs, backend.fw_graphs - - def graph_str(gm): return normalize_gm(gm.print_readable(print_output=False)) @@ -40,7 +30,7 @@ def tearDown(self): super().tearDown() def run_and_return_graphs(self, fn, *args, **kwargs): - return extract_graph(fn, *args, **kwargs) + return extract_graph(fn, *args, **kwargs)[0:3] def run_and_get_simple_graph(self): def fn(x, y): diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index e05e1304d2860..0a49a21cca42b 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -1,11 +1,13 @@ # Owner(s): ["module: dynamo"] import functools +import re import unittest import weakref import torch import torch._dynamo.test_case import torch._dynamo.testing +from torch._dynamo.testing import extract_graph, remove_trailing_space from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_utils import requires_cuda @@ -15,6 +17,14 @@ ) +def remove_file_comment(gm_str: str) -> str: + return remove_trailing_space(re.sub(r"File.*\n", "\n", gm_str)) + + +def print_graph(graph: torch.fx.GraphModule) -> str: + return remove_file_comment(graph.print_readable()) + + class TestStreams(torch._dynamo.test_case.TestCase): @classmethod def setUpClass(cls): @@ -36,9 +46,7 @@ def test_event_weakref(self): @requires_cuda def test_stream_enter_exit(self): - def fn(x, y): - s2 = torch.Stream() - s1 = torch.Stream() + def fn(x, y, s1, s2): with s1: z1 = torch.add(x, y) with s2: @@ -47,13 +55,36 @@ def fn(x, y): return y - inp = (torch.ones(2, 2) + 1, torch.ones(2, 2)) + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2), torch.Stream(), torch.Stream()) expected = fn(*inp) - fn_opt = torch.compile(fn, fullgraph=True) - actual = fn_opt(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) self.assertEqual(expected, actual) + self.assertExpectedInline( + print_graph(fw_graphs[0]), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): + # Annotation: {'stream': None} + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) + + # Annotation: {'stream': None} + add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None + + # Annotation: {'stream': None} + add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None + add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None + return (add_3,) +""", + ) @requires_cuda + @unittest.skip("Needs graph break support with annotation context") def test_stream_context_graph_break(self): def fn(x, y): s2 = torch.Stream() @@ -70,9 +101,16 @@ def fn(x, y): inp = (torch.ones(2, 2) + 1, torch.ones(2, 2)) expected = fn(*inp) - fn_opt = torch.compile(fn) - actual = fn_opt(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) self.assertEqual(expected, actual) + self.assertEqual(len(fw_graphs), 2) + self.assertExpectedInline(print_graph(fw_graphs[0]), """""") + self.assertExpectedInline(print_graph(fw_graphs[1]), """""") @requires_cuda def test_stream_input(self): @@ -155,22 +193,188 @@ def fn(x, s0, s1): self.assertEqual(s_act, s_exp) def test_nested_stream_enter_exit(self): - pass - + def fn(x, y, s0, s1, s2): + with s1: + with s2: + z1 = torch.add(x, y) + with s0: + z0 = torch.add(x, y) + with s2: + y = 2 + z1 + + return z0, y + + inp = ( + torch.ones(2, 2) + 1, + torch.ones(2, 2), + torch.Stream(), + torch.Stream(), + torch.Stream(), + ) + expected = fn(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) + self.assertEqual(expected, actual) + self.assertExpectedInline( + print_graph(fw_graphs[0]), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): + # Annotation: {'stream': None} + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) + + # Annotation: {'stream': None} + add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None + + # Annotation: {'stream': None} + add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None + return (add_1, add_2) +""", + ) + + @unittest.skip("Needs graph break support with annotation context") def test_stream_enter_exit_graph_break(self): pass + @unittest.skip("Needs graph break support with annotation context") def test_nested_stream_enter_exit_graph_break(self): pass def test_local_stream_enter_exit(self): - pass + def fn(x, y): + s2 = torch.Stream() + s1 = torch.Stream() + with s1: + z1 = torch.add(x, y) + with s2: + z = torch.add(x, y) + y = z + 2 + z1 + + return y + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2)) + expected = fn(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) + self.assertEqual(expected, actual) + self.assertExpectedInline( + print_graph(fw_graphs[0]), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): + # Annotation: {'stream': 1} + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) + + # Annotation: {'stream': 0} + add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None + + # Annotation: {'stream': 0} + add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None + add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_2, add); add_2 = add = None + return (add_3,) +""", + ) def test_local_stream_nested_enter_exit(self): - pass + def fn(x, y): + s2 = torch.Stream() + s1 = torch.Stream() + s0 = torch.Stream() + with s1: + with s2: + z1 = torch.add(x, y) + with s0: + z0 = torch.add(x, y) + with s2: + y = 2 + z1 + + return z0, y + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2)) + expected = fn(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) + self.assertEqual(expected, actual) + self.assertExpectedInline( + print_graph(fw_graphs[0]), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): + # Annotation: {'stream': 0} + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) + + # Annotation: {'stream': 2} + add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None + + # Annotation: {'stream': 0} + add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add, 2); add = None + return (add_1, add_2) +""", + ) def test_stream_with_mutation(self): - pass + def fn(x, y): + s2 = torch.Stream() + s1 = torch.Stream() + s0 = torch.Stream() + with s1: + with s2: + x.add_(y) + with s0: + z1 = torch.add(y, y) + z0 = torch.add(z1, y) + with s2: + y = 2 + z1 + + return z0, y + + inp = (torch.ones(2, 2) + 1, torch.ones(2, 2)) + expected = fn(*inp) + ( + actual, + _, + fw_graphs, + _, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) + self.assertEqual(expected, actual) + self.assertExpectedInline( + print_graph(fw_graphs[0]), + """\ +class (torch.nn.Module): + def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): + # Annotation: {'stream': 0} + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg0_1, arg1_1) + + # Annotation: {'stream': 2} + add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(arg1_1, arg1_1) + + # Annotation: {'stream': 2} + add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, arg1_1); arg1_1 = None + + # Annotation: {'stream': 0} + add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(add_1, 2); add_1 = None + + # + copy_: "f32[2, 2]" = torch.ops.aten.copy_.default(arg0_1, add); arg0_1 = add = copy_ = None + return (add_2, add_3) +""", + ) @requires_cuda def test_run_opcheck(self): diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index 9206f2598afc2..3eeedfb65da20 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -87,6 +87,12 @@ def extract_graph_backend(_gm, *args, **kwargs): # type: ignore[no-untyped-def] return gm.graph, region_tracker # type: ignore[union-attr] +def extract_graph(fn, *args, **kwargs): # type: ignore[no-untyped-def] + backend = AotEagerAndRecordGraphs() + result = torch.compile(backend=backend)(fn)(*args, **kwargs) + return result, backend.graphs, backend.fw_graphs, backend.bw_graphs + + def collect_results( model: torch.nn.Module, prediction: Any, loss: Any, example_inputs: Any ) -> list[Any]: From e69aaaf45a8018004aa91d58bef77199acbb888e Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Tue, 4 Nov 2025 17:45:12 -0800 Subject: [PATCH 087/130] [user-streams] Add backward test (#167021) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167021 Approved by: https://github.com/Lucaskabela ghstack dependencies: #167019 --- test/dynamo/test_streams.py | 60 +++++++++++++++++++++++++++++++++++++ 1 file changed, 60 insertions(+) diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index 0a49a21cca42b..b9a3855f6ddbb 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -376,6 +376,66 @@ def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2, 2]"): """, ) + def test_stream_backward(self) -> None: + def fn(x, y): + s2 = torch.Stream() + s0 = torch.Stream() + with s0: + y0 = 2 * x + y + with s2: + z = 2 * x + y + + return y0, z + + inp = ( + torch.ones(2, 2, requires_grad=True) + 1, + torch.ones(2, 2, requires_grad=True), + ) + expected = fn(*inp) + ( + actual, + _, + fw_graphs, + bw_graphs, + ) = extract_graph(fn, *inp) + self.assertEqual(len(fw_graphs), 1) + self.assertEqual(expected, actual) + self.assertExpectedInline( + print_graph(fw_graphs[0]), + """\ +class GraphModule(torch.nn.Module): + def forward(self, primals_1: "f32[2, 2]", primals_2: "f32[2, 2]"): + # Annotation: {'stream': 1} + mul: "f32[2, 2]" = torch.ops.aten.mul.Tensor(primals_1, 2); primals_1 = None + add: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2) + + # Annotation: {'stream': 0} + add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul, primals_2); mul = primals_2 = None + return (add, add_1) +""", + ) + + actual[1].sum().backward() + self.assertExpectedInline( + print_graph(bw_graphs[0]), + """\ +class GraphModule(torch.nn.Module): + def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"): + # Annotation: {'stream': 0} + mul_2: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_2, 2) + + # + add_2: "f32[2, 2]" = torch.ops.aten.add.Tensor(tangents_2, tangents_1); tangents_2 = None + + # Annotation: {'stream': 1} + mul_3: "f32[2, 2]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None + + # + add_3: "f32[2, 2]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None + return (add_3, add_2) +""", + ) + @requires_cuda def test_run_opcheck(self): from torch._dynamo.variables.streams import fork_stream, join_stream From e9a688f02ee742af2c1e24d7b2109beced35465f Mon Sep 17 00:00:00 2001 From: Pian Pawakapan Date: Wed, 5 Nov 2025 22:00:11 +0000 Subject: [PATCH 088/130] [DebugMode] output, tensor id annotations for DebugMode (#165076) Adds optional "node" id for tensors, output info annotations to DebugMode, with `DebugMode(record_output=True, record_ids=True)` Example output for `test_debug_mode_mm`, with both enabled: ``` torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) -> dt$12: f32[8, 32]| S(0) aten::mm(dt$2: f32[8, 8]| S(0), dt$3: f32[8, 32]| S(0)) redistribute_input(1, S(0) -> R) redistribute_input(t$4: f32[1, 32], trace: S(0)->R) _c10d_functional::all_gather_into_tensor(t$5: f32[1, 32], 8, 0) -> t$6: f32[8, 32] _c10d_functional::wait_tensor(t$7: f32[8, 32]) -> t$8: f32[8, 32] aten::mm(t$9: f32[1, 8], t$10: f32[8, 32]) -> t$11: f32[1, 32] (dt$13: f32[8, 32]| S(0)) -> dt$17: f32[]| P aten::sum(dt$14: f32[8, 32]| S(0)) aten::sum(t$15: f32[1, 32]) -> t$16: f32[]""" ``` Sadly the only way to get DTensor op outputs is to set `record_torchfunction=True`, as dispatch calls just defer to DTensor's dispatch logic. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165076 Approved by: https://github.com/zpcore --- .../tensor/debug/test_debug_mode.py | 22 +-- torch/utils/_debug_mode.py | 126 +++++++++++++++--- 2 files changed, 117 insertions(+), 31 deletions(-) diff --git a/test/distributed/tensor/debug/test_debug_mode.py b/test/distributed/tensor/debug/test_debug_mode.py index 18cc702cbbc7a..abc37f17a74de 100644 --- a/test/distributed/tensor/debug/test_debug_mode.py +++ b/test/distributed/tensor/debug/test_debug_mode.py @@ -50,22 +50,24 @@ def test_debug_mode_mm(self): x_dtensor = DTensor.from_local(x, mesh, [Shard(0)], run_check=False) y_dtensor = DTensor.from_local(y, mesh, [Shard(0)], run_check=False) - with DebugMode(record_torchfunction=True) as debug_mode: + with DebugMode( + record_torchfunction=True, record_ids=True, record_output=True + ) as debug_mode: torch.mm(x_dtensor, y_dtensor).sum() self.assertExpectedInline( debug_mode.debug_string(), """\ - torch.mm(dt: f32[8, 8]| S(0), dt: f32[8, 32]| S(0)) - aten::mm(dt: f32[8, 8]| S(0), dt: f32[8, 32]| S(0)) + torch.mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) -> dt$6: f32[8, 32]| S(0) + aten::mm(dt$0: f32[8, 8]| S(0), dt$1: f32[8, 32]| S(0)) redistribute_input(1, S(0) -> R) - redistribute_input(t: f32[1, 32], trace: S(0)->R) - _c10d_functional::all_gather_into_tensor(t: f32[1, 32], 8, 0) - _c10d_functional::wait_tensor(t: f32[8, 32]) - aten::mm(t: f32[1, 8], t: f32[8, 32]) - (dt: f32[8, 32]| S(0)) - aten::sum(dt: f32[8, 32]| S(0)) - aten::sum(t: f32[1, 32])""", + redistribute_input(t$2: f32[1, 32], trace: S(0)->R) + _c10d_functional::all_gather_into_tensor(t$2: f32[1, 32], 8, 0) -> t$3: f32[8, 32] + _c10d_functional::wait_tensor(t$3: f32[8, 32]) -> t$3: f32[8, 32] + aten::mm(t$4: f32[1, 8], t$3: f32[8, 32]) -> t$5: f32[1, 32] + (dt$6: f32[8, 32]| S(0)) -> dt$8: f32[]| P + aten::sum(dt$6: f32[8, 32]| S(0)) + aten::sum(t$5: f32[1, 32]) -> t$7: f32[]""", ) self.assertTrue(isinstance(debug_mode.operators[0], _OpCall)) diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 09435aa07e68b..5e24ce086e1aa 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -2,6 +2,7 @@ import contextlib import functools import traceback +import weakref from typing import Any, Callable, Optional, TYPE_CHECKING import torch @@ -14,6 +15,7 @@ ) from torch.utils._pytree import tree_all, tree_map from torch.utils._traceback import CapturedTraceback +from torch.utils.weak import WeakIdRef if TYPE_CHECKING: @@ -56,29 +58,48 @@ def _stringify_dtensor_spec(spec) -> str: return DTensorSpec.format_shard_order_str(spec.placements, spec.shard_order) -def _tensor_debug_string(tensor, attributes) -> str: +class TensorIdTracker: + def __init__(self): + self.tensor_memo: dict[WeakIdRef, int] = {} + self.next_tensor_id = 0 + + def _id(self, tensor) -> int: + with torch._C._DisablePythonDispatcher(): + o = WeakIdRef(tensor) + + def del_memo(): + self.tensor_memo.pop(o, None) + + weakref.finalize(tensor, del_memo) + if o not in self.tensor_memo: + self.tensor_memo[o] = self.next_tensor_id + self.next_tensor_id += 1 + return self.tensor_memo[o] + + +def _tensor_debug_string(tensor, attributes, tensor_memo=None) -> str: """Convert tensor to debug string representation.""" if isinstance(tensor, torch.Tensor): tensor_debug_str = f"{dtype_abbrs[tensor.dtype]}{_stringify_shape(tensor.shape)}{_stringify_attributes(tensor, attributes)}" - + id_str = f"${tensor_memo._id(tensor)}" if tensor_memo is not None else "" if isinstance(tensor, torch.distributed.tensor.DTensor): # omitted device mesh - return f"dt: {tensor_debug_str}| {_stringify_dtensor_spec(tensor._spec)}" + return f"dt{id_str}: {tensor_debug_str}| {_stringify_dtensor_spec(tensor._spec)}" elif isinstance(tensor, FakeTensor): - return f"ft: {tensor_debug_str}" + return f"ft{id_str}: {tensor_debug_str}" else: - return f"t: {tensor_debug_str}" + return f"t{id_str}: {tensor_debug_str}" else: raise RuntimeError(f"Unsupported tensor type: {type(tensor)}") -def _arg_to_str(arg, attributes) -> str: +def _arg_to_str(arg, attributes, tensor_memo=None) -> str: from torch.distributed.tensor._dtensor_spec import DTensorSpec def to_str(x): if isinstance(x, torch.Tensor): - return _tensor_debug_string(x, attributes) + return _tensor_debug_string(x, attributes, tensor_memo) elif isinstance(x, DTensorSpec): return _stringify_dtensor_spec(x) return x @@ -144,8 +165,11 @@ def __init__( # results from dispatch hooks self.record = record self.log = log + self.output_str: Optional[str] = None - def stringify_args(self, attributes: list[str]) -> None: + def stringify_args( + self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + ) -> None: """ To reduce memory consumption, this method stringifies args/kwargs, stores the result, and deletes original args/kwargs. """ @@ -153,6 +177,18 @@ def stringify_args(self, attributes: list[str]) -> None: "Subclasses must implement stringify_args(), even if no-op" ) + def stringify_output( + self, + output: Any, + attributes: list[str], + tensor_memo: Optional[TensorIdTracker] = None, + ) -> None: + """Store stringified version of call output in self.output_str""" + if tree_all(lambda x: x is None, output): + return + output_str = tree_map(lambda x: _arg_to_str(x, attributes, tensor_memo), output) + self.output_str = f" -> {str(output_str)}" + def render(self, attributes: list[str]) -> str: raise NotImplementedError("Subclasses must implement string render()") @@ -179,11 +215,16 @@ def __init__( self.args_str: Optional[str] = None self.kwargs_str: Optional[str] = None - def stringify_args(self, attributes: list[str]) -> None: - self.args_str = ", ".join(_arg_to_str(arg, attributes) for arg in self.args) + def stringify_args( + self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + ) -> None: + self.args_str = ", ".join( + _arg_to_str(arg, attributes, tensor_memo) for arg in self.args + ) if self.kwargs: self.kwargs_str = ", " + ", ".join( - f"{k}={_arg_to_str(v, attributes)}" for k, v in self.kwargs.items() + f"{k}={_arg_to_str(v, attributes, tensor_memo)}" + for k, v in self.kwargs.items() ) else: self.kwargs_str = "" @@ -215,6 +256,8 @@ def render(self, attributes: list[str]) -> str: base_str = f"{op_name}({args_str}{kwargs_str})" + if self.output_str: + base_str += self.output_str if self.log: base_str += f" # {self.log}" return base_str @@ -247,8 +290,10 @@ def __init__( self.arg_str: Optional[str] = None - def stringify_args(self, attributes: list[str]) -> None: - self.arg_str = f"{_arg_to_str(self.arg, attributes)}" + def stringify_args( + self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + ) -> None: + self.arg_str = f"{_arg_to_str(self.arg, attributes, tensor_memo)}" del self.arg def render(self, attributes: list[str]) -> str: @@ -263,7 +308,11 @@ def render(self, attributes: list[str]) -> str: src_placement_str = _arg_to_str(self.src_placement, attributes) dst_placement_str = _arg_to_str(self.dst_placement, attributes) placement_str = f"{src_placement_str} -> {dst_placement_str}" - return f"{REDISTRIBUTE_FUNC}({arg_str}, {placement_str})" + + base_str = f"{REDISTRIBUTE_FUNC}({arg_str}, {placement_str})" + if self.output_str: + base_str += self.output_str + return base_str def __iter__(self): # for BC; tuple(self) returns (op, placement info, kwargs, call_depth) @@ -288,7 +337,9 @@ def __init__(self, module_name: str, call_depth: int, stack: bool = False): super().__init__(call_depth, stack=stack) self.module_name = module_name - def stringify_args(self, attributes: list[str]) -> None: + def stringify_args( + self, attributes: list[str], tensor_memo: Optional[TensorIdTracker] = None + ) -> None: pass # nothing to stringify def render(self, attributes: list[str]) -> str: @@ -341,6 +392,8 @@ def __init__( record_nn_module=False, store_original_args=False, record_stack_trace=False, + record_output=False, + record_ids=False, ): super().__init__() import torch.distributed.tensor # noqa: F401 @@ -378,8 +431,24 @@ def __init__( # e.g. via DebugMode(record_stack_trace=True), or torch.autograd.set_detect_anomaly(). self.record_stack_trace = record_stack_trace + # Records call outputs in logs (e.g. for __torch_dispatch__, __torch_function__, redistribute_input) + self.record_output: bool = record_output + + # Annotates string dumps with graph-style tensor ids, e.g. op($1, $2) -> $3. + self.record_ids: bool = record_ids + + self.reset() + + def reset(self): self.operators = [] self.call_depth = 0 + self._tensor_memo = TensorIdTracker() + self._output_info: dict[int, object] = {} + + def _track_op_output(self, op_index, result): + """Assign IDs to output tensors and store in output_info""" + # self._track_tensor_ids(result) + self._output_info[op_index] = result # Without this override, running torch.compile under DebugMode # will force torch.compile to always use the “eager” backend @@ -390,20 +459,35 @@ def ignore_compile_internals(cls): def _record_call(self, call): if not self.store_original_args: - call.stringify_args(self.record_tensor_attributes) + call.stringify_args( + self.record_tensor_attributes, + self._tensor_memo if self.record_ids else None, + ) self.operators.append(call) + def _record_call_output(self, call, output): + if not self.record_output: + return + call.stringify_output( + output, + self.record_tensor_attributes, + self._tensor_memo if self.record_ids else None, + ) + def __torch_function__(self, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} - self._record_call( - _OpCall(func, args, kwargs, self.call_depth, stack=self.record_stack_trace) + call = _OpCall( + func, args, kwargs, self.call_depth, stack=self.record_stack_trace ) + self._record_call(call) try: self.call_depth += 1 - return func(*args, **kwargs) + result = func(*args, **kwargs) + self._record_call_output(call, result) + return result finally: self.call_depth -= 1 @@ -445,13 +529,13 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None): result = func(*args, **kwargs) if call: + self._record_call_output(call, result) _run_dispatch_hooks(call, func, types, args, kwargs, result) return result def __enter__(self): - self.operators = [] - self.call_depth = 0 + self.reset() if self.record_torchfunction: torch._C._push_on_torch_function_stack(self) From 711a7758788ccdbb85bc20e9dd8146f5a7bafb24 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Wed, 5 Nov 2025 08:54:22 -0800 Subject: [PATCH 089/130] fix nccl estimations (#167093) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167093 Approved by: https://github.com/kwen2501, https://github.com/eellison --- torch/_inductor/comm_analysis.py | 2 +- torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_inductor/comm_analysis.py b/torch/_inductor/comm_analysis.py index 61af576772c16..74a58acb84ff3 100644 --- a/torch/_inductor/comm_analysis.py +++ b/torch/_inductor/comm_analysis.py @@ -360,7 +360,7 @@ def estimate_nccl_collective_runtime_from_fx_node( fx_node: torch.fx.Node, override_size: Optional[int] = None, # TODO(ivankobzarev): NCCL estimator sometimes fail unexpectedly, enable back after fix. - use_nccl_estimator: bool = False, + use_nccl_estimator: bool = True, ) -> float: """ Returns estimated NCCL collective runtime in nanoseconds (ns). diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index d051803aa7376..3416bc336d34a 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -3593,6 +3593,7 @@ float ProcessGroupNCCL::endTimeEstimate() { #ifdef NCCL_SIM_INFO_INITIALIZER ncclSimInfo_t simInfo = NCCL_SIM_INFO_INITIALIZER; C10D_NCCL_CHECK(ncclGroupSimulateEnd(&simInfo), std::nullopt); + --ncclActiveGroupCounter_; return simInfo.estimatedTime; #else TORCH_CHECK( From ad7a57262c8f3ce6a2d724af533f09437495100f Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Wed, 5 Nov 2025 22:06:19 +0000 Subject: [PATCH 090/130] [12/N] Apply ruff UP035 rule (#166929) This PR continues to apply ruff UP035 rule to test code and some remaining torch files. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166929 Approved by: https://github.com/Lucaskabela --- test/distributed/tensor/test_attention.py | 3 ++- test/higher_order_ops/test_local_map.py | 3 ++- test/inductor/test_caching.py | 3 ++- test/inductor/test_fx_fusion.py | 3 ++- test/inductor/test_native_matmul.py | 2 +- test/quantization/fx/test_quantize_fx.py | 3 ++- test/test_matmul_cuda.py | 2 +- torch/_dynamo/eval_frame.py | 3 ++- torch/_dynamo/graph_bytecode_inputs.py | 3 ++- torch/_dynamo/variables/distributed.py | 3 ++- torch/_dynamo/variables/iter.py | 4 ++-- torch/_dynamo/variables/optimizer.py | 3 ++- torch/_dynamo/variables/script_object.py | 4 ++-- torch/_dynamo/variables/sdpa.py | 3 ++- torch/_dynamo/variables/streams.py | 3 ++- torch/_dynamo/variables/torch_function.py | 4 ++-- torch/_functorch/_aot_autograd/aot_autograd_result.py | 3 ++- torch/_inductor/compile_worker/timer.py | 3 ++- torch/_inductor/fx_passes/bucketing.py | 3 ++- torch/_inductor/fx_passes/ddp_fusion.py | 4 ++-- torch/_inductor/fx_passes/fsdp.py | 2 +- torch/_inductor/fx_passes/memory_estimator.py | 2 +- torch/_inductor/fx_passes/mkldnn_fusion.py | 6 +++++- torch/_inductor/fx_passes/overlap_scheduling.py | 4 ++-- torch/_inductor/fx_passes/pad_mm.py | 4 ++-- torch/_inductor/fx_passes/post_grad.py | 3 ++- torch/_inductor/fx_passes/reinplace.py | 4 ++-- torch/_inductor/fx_passes/split_cat.py | 5 ++--- torch/_inductor/kernel/custom_op.py | 3 ++- torch/_inductor/kernel/flex/flex_flash_attention.py | 3 ++- torch/_inductor/runtime/benchmarking.py | 4 ++-- torch/_inductor/runtime/caching/interfaces.py | 6 ++++-- torch/_inductor/runtime/caching/locks.py | 5 +++-- torch/distributed/elastic/multiprocessing/tail_log.py | 3 ++- torch/utils/_cxx_pytree.py | 4 ++-- torch/utils/_debug_mode.py | 3 ++- torch/utils/_pytree.py | 3 ++- 37 files changed, 76 insertions(+), 50 deletions(-) diff --git a/test/distributed/tensor/test_attention.py b/test/distributed/tensor/test_attention.py index eaf3a4042060d..6c3485f9d7025 100644 --- a/test/distributed/tensor/test_attention.py +++ b/test/distributed/tensor/test_attention.py @@ -3,7 +3,8 @@ import itertools import random import unittest -from typing import Any, Callable, ClassVar, Optional +from collections.abc import Callable +from typing import Any, ClassVar, Optional import torch import torch.distributed as dist diff --git a/test/higher_order_ops/test_local_map.py b/test/higher_order_ops/test_local_map.py index 9d2870d3b5fdd..fbb21633260e7 100644 --- a/test/higher_order_ops/test_local_map.py +++ b/test/higher_order_ops/test_local_map.py @@ -4,8 +4,9 @@ import functools import unittest +from collections.abc import Callable from contextlib import contextmanager, ExitStack -from typing import Any, Callable, Optional +from typing import Any, Optional import torch import torch._dynamo diff --git a/test/inductor/test_caching.py b/test/inductor/test_caching.py index bcb66beea700c..aa4c3a1f229f1 100644 --- a/test/inductor/test_caching.py +++ b/test/inductor/test_caching.py @@ -13,7 +13,7 @@ from shutil import rmtree from threading import Lock from time import sleep, time -from typing import Any, Generator, Sequence, TYPE_CHECKING, Union +from typing import Any, TYPE_CHECKING, Union from typing_extensions import TypeVar from unittest.mock import patch @@ -37,6 +37,7 @@ if TYPE_CHECKING: + from collections.abc import Generator, Sequence from pathlib import Path diff --git a/test/inductor/test_fx_fusion.py b/test/inductor/test_fx_fusion.py index ebe98373e622a..63342502d3cd9 100644 --- a/test/inductor/test_fx_fusion.py +++ b/test/inductor/test_fx_fusion.py @@ -1,5 +1,6 @@ # Owner(s): ["module: inductor"] -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import torch from torch._inductor.fx_passes.pre_grad import ( diff --git a/test/inductor/test_native_matmul.py b/test/inductor/test_native_matmul.py index 1870a0e373be0..c37f844e41eae 100644 --- a/test/inductor/test_native_matmul.py +++ b/test/inductor/test_native_matmul.py @@ -1,7 +1,7 @@ # Owner(s): ["module: inductor"] -from typing import Callable +from collections.abc import Callable import torch from torch._dynamo.testing import rand_strided diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index cd922d94c60c3..faba2f5edc6a7 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -204,7 +204,8 @@ import operator import unittest import io -from typing import Callable, Optional +from typing import Optional +from collections.abc import Callable class BinaryOp(torch.nn.Module): def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar): diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 002c34c450756..10611d4f24673 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -5,7 +5,7 @@ import unittest from itertools import product from functools import partial -from typing import Callable +from collections.abc import Callable import torch diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index e23e049e3bbb1..222647eeae9ab 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -39,10 +39,11 @@ import unittest import warnings import weakref +from collections.abc import Sized from dataclasses import dataclass from enum import Enum from os.path import dirname, join -from typing import Any, NamedTuple, Optional, Sized, TYPE_CHECKING, Union +from typing import Any, NamedTuple, Optional, TYPE_CHECKING, Union from unittest.mock import patch import sympy diff --git a/torch/_dynamo/graph_bytecode_inputs.py b/torch/_dynamo/graph_bytecode_inputs.py index 979950cf3bd1b..16583b89201ec 100644 --- a/torch/_dynamo/graph_bytecode_inputs.py +++ b/torch/_dynamo/graph_bytecode_inputs.py @@ -1,5 +1,6 @@ import weakref -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from torch._dynamo.source import Source diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py index eb39dd8fa3e07..187055c26cd00 100644 --- a/torch/_dynamo/variables/distributed.py +++ b/torch/_dynamo/variables/distributed.py @@ -20,7 +20,8 @@ import functools import inspect -from typing import Any, Sequence, TYPE_CHECKING +from collections.abc import Sequence +from typing import Any, TYPE_CHECKING import torch from torch.fx.experimental._backward_state import BackwardState diff --git a/torch/_dynamo/variables/iter.py b/torch/_dynamo/variables/iter.py index 5970ba0e1dda7..be765cbbc8bf9 100644 --- a/torch/_dynamo/variables/iter.py +++ b/torch/_dynamo/variables/iter.py @@ -14,8 +14,8 @@ """ import itertools -from collections.abc import Callable -from typing import Any, Sequence, TYPE_CHECKING, Union +from collections.abc import Callable, Sequence +from typing import Any, TYPE_CHECKING, Union from .. import graph_break_hints, polyfills, variables from ..bytecode_transformation import ( diff --git a/torch/_dynamo/variables/optimizer.py b/torch/_dynamo/variables/optimizer.py index 289cebbe8129b..c09cc2163a5f4 100644 --- a/torch/_dynamo/variables/optimizer.py +++ b/torch/_dynamo/variables/optimizer.py @@ -22,7 +22,8 @@ import logging import weakref -from typing import Any, Iterable, Optional, TYPE_CHECKING +from collections.abc import Iterable +from typing import Any, Optional, TYPE_CHECKING import torch from torch._dynamo.variables.tensor import TensorVariable diff --git a/torch/_dynamo/variables/script_object.py b/torch/_dynamo/variables/script_object.py index 85977104977fb..644c269a23a34 100644 --- a/torch/_dynamo/variables/script_object.py +++ b/torch/_dynamo/variables/script_object.py @@ -19,8 +19,8 @@ """ import functools -from collections.abc import Callable -from typing import Any, Iterable, TYPE_CHECKING, TypeVar +from collections.abc import Callable, Iterable +from typing import Any, TYPE_CHECKING, TypeVar from typing_extensions import ParamSpec import torch diff --git a/torch/_dynamo/variables/sdpa.py b/torch/_dynamo/variables/sdpa.py index 75928842cf297..629bf094dc951 100644 --- a/torch/_dynamo/variables/sdpa.py +++ b/torch/_dynamo/variables/sdpa.py @@ -1,5 +1,6 @@ +from collections.abc import Sequence from inspect import getattr_static -from typing import Any, Sequence, TYPE_CHECKING, TypeGuard +from typing import Any, TYPE_CHECKING, TypeGuard from torch._guards import Source from torch.backends.cuda import SDPAParams diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index c353181eb8029..fb5dd775bd636 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -1,5 +1,6 @@ import collections -from typing import Any, Callable, Optional +from collections.abc import Callable +from typing import Any, Optional import torch from torch._dynamo.variables.dicts import ConstDictVariable diff --git a/torch/_dynamo/variables/torch_function.py b/torch/_dynamo/variables/torch_function.py index fa8412146a427..4d0f0b4fae8ab 100644 --- a/torch/_dynamo/variables/torch_function.py +++ b/torch/_dynamo/variables/torch_function.py @@ -29,9 +29,9 @@ import functools import inspect import operator -from collections.abc import Sequence +from collections.abc import Generator, Iterable, Sequence from types import TracebackType -from typing import Any, Generator, Iterable, Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING import torch._C import torch.utils._pytree as pytree diff --git a/torch/_functorch/_aot_autograd/aot_autograd_result.py b/torch/_functorch/_aot_autograd/aot_autograd_result.py index ce01e37f03243..7e608933b34c3 100644 --- a/torch/_functorch/_aot_autograd/aot_autograd_result.py +++ b/torch/_functorch/_aot_autograd/aot_autograd_result.py @@ -22,9 +22,10 @@ import json import logging from abc import ABC, abstractmethod +from collections.abc import Callable from copy import copy from dataclasses import dataclass -from typing import Any, Callable, Generic, Optional, TYPE_CHECKING, TypeVar +from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar import torch from torch._dynamo.precompile_context import BackendCacheArtifact diff --git a/torch/_inductor/compile_worker/timer.py b/torch/_inductor/compile_worker/timer.py index 7cfeb4217e26b..7c495403b3a55 100644 --- a/torch/_inductor/compile_worker/timer.py +++ b/torch/_inductor/compile_worker/timer.py @@ -1,6 +1,7 @@ +from collections.abc import Callable from threading import Lock, Thread from time import monotonic, sleep -from typing import Callable, Optional, Union +from typing import Optional, Union class Timer: diff --git a/torch/_inductor/fx_passes/bucketing.py b/torch/_inductor/fx_passes/bucketing.py index ab831c96c94ba..29f070564349c 100644 --- a/torch/_inductor/fx_passes/bucketing.py +++ b/torch/_inductor/fx_passes/bucketing.py @@ -2,7 +2,8 @@ import logging import operator from collections import defaultdict -from typing import Any, Callable, Literal, TypeAlias +from collections.abc import Callable +from typing import Any, Literal, TypeAlias import torch import torch.distributed as dist diff --git a/torch/_inductor/fx_passes/ddp_fusion.py b/torch/_inductor/fx_passes/ddp_fusion.py index 8a4de1a604869..44314b912786f 100644 --- a/torch/_inductor/fx_passes/ddp_fusion.py +++ b/torch/_inductor/fx_passes/ddp_fusion.py @@ -4,10 +4,10 @@ import logging import math import operator -from collections.abc import Generator +from collections.abc import Callable, Generator from dataclasses import dataclass from functools import partial -from typing import Any, Callable, cast +from typing import Any, cast import torch import torch.fx as fx diff --git a/torch/_inductor/fx_passes/fsdp.py b/torch/_inductor/fx_passes/fsdp.py index 6b0c2ad2c94a7..1e71c350ed7b6 100644 --- a/torch/_inductor/fx_passes/fsdp.py +++ b/torch/_inductor/fx_passes/fsdp.py @@ -1,5 +1,5 @@ import logging -from typing import Callable +from collections.abc import Callable import torch from torch._inductor.fx_passes.bucketing import ( diff --git a/torch/_inductor/fx_passes/memory_estimator.py b/torch/_inductor/fx_passes/memory_estimator.py index c6b7c51b948e5..e887d4bf62c8e 100644 --- a/torch/_inductor/fx_passes/memory_estimator.py +++ b/torch/_inductor/fx_passes/memory_estimator.py @@ -1,8 +1,8 @@ import itertools import logging from collections import defaultdict +from collections.abc import Callable from dataclasses import dataclass -from typing import Callable import torch import torch.fx as fx diff --git a/torch/_inductor/fx_passes/mkldnn_fusion.py b/torch/_inductor/fx_passes/mkldnn_fusion.py index 70b3a3c355dde..214d3bf02f7f4 100644 --- a/torch/_inductor/fx_passes/mkldnn_fusion.py +++ b/torch/_inductor/fx_passes/mkldnn_fusion.py @@ -2,7 +2,7 @@ import functools import operator from functools import reduce -from typing import Any, Callable +from typing import Any, TYPE_CHECKING import torch from torch._dynamo.utils import counters @@ -35,6 +35,10 @@ ) +if TYPE_CHECKING: + from collections.abc import Callable + + if torch._C._has_mkldnn: aten = torch.ops.aten mkldnn = torch.ops.mkldnn diff --git a/torch/_inductor/fx_passes/overlap_scheduling.py b/torch/_inductor/fx_passes/overlap_scheduling.py index a47aa960e58c5..f383ab63dc261 100644 --- a/torch/_inductor/fx_passes/overlap_scheduling.py +++ b/torch/_inductor/fx_passes/overlap_scheduling.py @@ -4,9 +4,9 @@ import logging import sys from collections import Counter, defaultdict -from collections.abc import Iterable +from collections.abc import Callable, Iterable from dataclasses import dataclass -from typing import Any, Callable +from typing import Any import torch import torch.fx as fx diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index 30768fda9bb72..b511403d4874c 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -2,8 +2,8 @@ import itertools import operator import typing -from collections.abc import Sequence -from typing import Any, Callable +from collections.abc import Callable, Sequence +from typing import Any import torch import torch._inductor.runtime.runtime_utils diff --git a/torch/_inductor/fx_passes/post_grad.py b/torch/_inductor/fx_passes/post_grad.py index 7d995adec04ef..91b4e10bf7238 100644 --- a/torch/_inductor/fx_passes/post_grad.py +++ b/torch/_inductor/fx_passes/post_grad.py @@ -5,7 +5,8 @@ import logging import operator from collections import Counter, defaultdict -from typing import Any, Callable, TypeVar +from collections.abc import Callable +from typing import Any, TypeVar from typing_extensions import ParamSpec import torch diff --git a/torch/_inductor/fx_passes/reinplace.py b/torch/_inductor/fx_passes/reinplace.py index 52222f3da8344..e42e8a1139770 100644 --- a/torch/_inductor/fx_passes/reinplace.py +++ b/torch/_inductor/fx_passes/reinplace.py @@ -3,10 +3,10 @@ import logging import operator from collections import defaultdict -from collections.abc import Sequence +from collections.abc import Callable, Sequence from contextlib import nullcontext from dataclasses import dataclass -from typing import Any, Callable, cast +from typing import Any, cast import torch import torch.fx.node diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index 92e1e6f375f44..0bad4fa7cc635 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -4,9 +4,8 @@ import operator import os from collections import defaultdict -from collections.abc import Sequence -from typing import Any, Callable -from typing_extensions import TypeAlias +from collections.abc import Callable, Sequence +from typing import Any, TypeAlias import torch from torch._dynamo.utils import counters diff --git a/torch/_inductor/kernel/custom_op.py b/torch/_inductor/kernel/custom_op.py index 303110a561b5e..d35309c01d07c 100644 --- a/torch/_inductor/kernel/custom_op.py +++ b/torch/_inductor/kernel/custom_op.py @@ -2,7 +2,8 @@ import functools import logging -from typing import Any, Callable, Optional, Union +from collections.abc import Callable +from typing import Any, Optional, Union import torch from torch._inductor.codegen.subgraph import SubgraphTemplate diff --git a/torch/_inductor/kernel/flex/flex_flash_attention.py b/torch/_inductor/kernel/flex/flex_flash_attention.py index c100df84d5a73..0d3721aa730a4 100644 --- a/torch/_inductor/kernel/flex/flex_flash_attention.py +++ b/torch/_inductor/kernel/flex/flex_flash_attention.py @@ -3,8 +3,9 @@ import functools import importlib +from collections.abc import Callable, Sequence from contextlib import contextmanager -from typing import Any, Callable, Optional, Sequence +from typing import Any, Optional import sympy from sympy import Expr, Integer diff --git a/torch/_inductor/runtime/benchmarking.py b/torch/_inductor/runtime/benchmarking.py index d592a8c8c00f9..d9d92e363879d 100644 --- a/torch/_inductor/runtime/benchmarking.py +++ b/torch/_inductor/runtime/benchmarking.py @@ -5,8 +5,8 @@ from functools import cached_property, wraps from itertools import chain from statistics import median -from typing import Any, Optional, Union -from typing_extensions import Concatenate, ParamSpec, Self, TypeVar +from typing import Any, Concatenate, Optional, Union +from typing_extensions import ParamSpec, Self, TypeVar import torch import torch.utils._pytree as pytree diff --git a/torch/_inductor/runtime/caching/interfaces.py b/torch/_inductor/runtime/caching/interfaces.py index 0758e11134018..03d2957493679 100644 --- a/torch/_inductor/runtime/caching/interfaces.py +++ b/torch/_inductor/runtime/caching/interfaces.py @@ -12,8 +12,8 @@ from pathlib import Path from threading import Lock from time import time -from typing import Any, Callable, TYPE_CHECKING -from typing_extensions import override, TypeAlias +from typing import Any, TYPE_CHECKING, TypeAlias +from typing_extensions import override from filelock import FileLock @@ -21,6 +21,8 @@ if TYPE_CHECKING: + from collections.abc import Callable + from .utils import P, R diff --git a/torch/_inductor/runtime/caching/locks.py b/torch/_inductor/runtime/caching/locks.py index e7e1f1adc3622..8e8cd011e2d44 100644 --- a/torch/_inductor/runtime/caching/locks.py +++ b/torch/_inductor/runtime/caching/locks.py @@ -12,8 +12,8 @@ from __future__ import annotations from contextlib import _GeneratorContextManager, contextmanager, ExitStack -from typing import Generator, TYPE_CHECKING -from typing_extensions import Protocol, TypeAlias +from typing import TYPE_CHECKING, TypeAlias +from typing_extensions import Protocol from filelock import FileLock, Timeout @@ -21,6 +21,7 @@ if TYPE_CHECKING: + from collections.abc import Generator from threading import Lock diff --git a/torch/distributed/elastic/multiprocessing/tail_log.py b/torch/distributed/elastic/multiprocessing/tail_log.py index 7ad35115cd34a..034740810dcdd 100644 --- a/torch/distributed/elastic/multiprocessing/tail_log.py +++ b/torch/distributed/elastic/multiprocessing/tail_log.py @@ -10,9 +10,10 @@ import logging import os import time +from collections.abc import Callable from concurrent.futures.thread import ThreadPoolExecutor from threading import Event -from typing import Callable, Optional, TextIO, TYPE_CHECKING, Union +from typing import Optional, TextIO, TYPE_CHECKING, Union if TYPE_CHECKING: diff --git a/torch/utils/_cxx_pytree.py b/torch/utils/_cxx_pytree.py index 603625ed97c12..897279bd39b1e 100644 --- a/torch/utils/_cxx_pytree.py +++ b/torch/utils/_cxx_pytree.py @@ -15,8 +15,8 @@ import functools import types from collections.abc import Callable, Iterable, Mapping -from typing import Any, Optional, overload, TypeVar, Union -from typing_extensions import deprecated, Self, TypeAlias, TypeIs +from typing import Any, Optional, overload, TypeAlias, TypeVar, Union +from typing_extensions import deprecated, Self, TypeIs import torch.utils._pytree as python_pytree from torch.torch_version import TorchVersion as _TorchVersion diff --git a/torch/utils/_debug_mode.py b/torch/utils/_debug_mode.py index 5e24ce086e1aa..5a6ee246abf7e 100644 --- a/torch/utils/_debug_mode.py +++ b/torch/utils/_debug_mode.py @@ -3,7 +3,8 @@ import functools import traceback import weakref -from typing import Any, Callable, Optional, TYPE_CHECKING +from collections.abc import Callable +from typing import Any, Optional, TYPE_CHECKING import torch from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index 56704bb3f8024..147340f58d66e 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -36,10 +36,11 @@ Optional, overload, Protocol, + TypeAlias, TypeVar, Union, ) -from typing_extensions import deprecated, NamedTuple, Self, TypeAlias +from typing_extensions import deprecated, NamedTuple, Self from torch.torch_version import TorchVersion as _TorchVersion From 08200280ce3c7b5bfbf3997517254565b2d6f162 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 4 Nov 2025 14:37:54 -0800 Subject: [PATCH 091/130] [CP][BE][3/N] Add _templated_ring_attention to the backward compatility stub (#166991) While `_templated_ring_attention` is a private API, it is unfortunatelly used by some packages. Add it to __all__ so that people can still use it. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166991 Approved by: https://github.com/XilunWu ghstack dependencies: #166456, #166501 --- torch/distributed/tensor/experimental/_attention.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch/distributed/tensor/experimental/_attention.py b/torch/distributed/tensor/experimental/_attention.py index 2444467a3595f..f238739ddd5cf 100644 --- a/torch/distributed/tensor/experimental/_attention.py +++ b/torch/distributed/tensor/experimental/_attention.py @@ -10,6 +10,7 @@ _enable_context_parallel_dispatcher, _is_causal_behavior, _RotateMethod, + _templated_ring_attention, context_parallel, context_parallel_unshard, set_rotate_method, @@ -22,6 +23,7 @@ ) +# TODO(fegin): add deprecation message once the final interfaces are concluded. __all__ = [ "_CausalBehavior", "_context_parallel_shard", @@ -31,6 +33,7 @@ "_enable_context_parallel_dispatcher", "_is_causal_behavior", "_RotateMethod", + "_templated_ring_attention", "context_parallel", "context_parallel_unshard", "set_rotate_method", From 47eb34b7ac4359d281d1bfc3626feec184aec8b6 Mon Sep 17 00:00:00 2001 From: YyWangCS Date: Wed, 5 Nov 2025 22:34:16 +0000 Subject: [PATCH 092/130] [ATEN][CUDA] Reduce register pressure in radix_sort_pairs to improve torch.sort performance (#167094) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Summary This PR improves `torch.sort` and `torch.unique` performance by **15% to 50%** on NVIDIA GPUs by optimizing CUDA register allocation in radix sort operations. The key change: specialize `OpaqueType` to use native integer types (uint8_t, uint16_t, uint32_t, uint64_t) for common sizes (1, 2, 4, 8 bytes) instead of `char data[N]`. This enables more efficient register allocation while preserving the template deduplication strategy. The following table shows the speedup on various input shapes and GPUs. Sorting is performed on the last dimension, and baseline torch version is 2.9.0. | GPU | input shape | input dtype | **Before** **(ms)** | After (ms) | Speedup | | ---- | ----------- | ----------- | ------------------- | ---------- | ------- | | H100 | (16, 1e6) | int32 | 1.61 | 1.37 | 1.18× | | H100 | (1, 1e8) | int32 | 6.6 | 5.0 | 1.3× | | H20 | (16, 1e6) | int64 | 3.57 | 3.03 | 1.18× | | H20 | (1, 1e8) | int64 | 19.3 | 13.0 | 1.48× | # Analysis `torch.sort` and `torch.unique` use `radix_sort_pairs`, which internally calls `cub::DeviceRadixSort::SortPairs`. Since values are only copied (never compared), we cast them to `OpaqueType` to minimize template instantiations. For example, both `int32` and `float32` values map to the same `OpaqueType<4>.` ## The Problem The previous `char data[N]` implementation causes inefficient register allocation. Here is one reason I find from SASS code. For 8-byte types: - `char data[8]:` Compiler may allocate 8 registers (one per byte) - `uint64_t data`: Compiler allocates 2 registers (standard 64-bit handling) This happens because the compiler doesn't recognize char[8] as a cohesive 64-bit value, treating each byte independently, which increases register pressure and reduces GPU occupancy. From Nsight Compute, when using `char data[8]`, the registers per thread is 166, and corresponding theoretical occupancy is 18.75%. When using native `uint64_t`, the registers per thread is 80, and corresponding theoretical occupancy is 37.5%. ## The Solution Specialize `OpaqueType` for common sizes using native integer types: ``` // Before template struct alignas(N) OpaqueType { char data[N]; }; // After template struct alignas(N) OpaqueType { char data[N]; }; // fallback template <> struct alignas(1) OpaqueType<1> { uint8_t data; }; template <> struct alignas(2) OpaqueType<2> { uint16_t data; }; template <> struct alignas(4) OpaqueType<4> { uint32_t data; }; template <> struct alignas(8) OpaqueType<8> { uint64_t data; }; ``` This preserves the template deduplication strategy (all 8-byte types still use the same `OpaqueType<8>` instantiation) while enabling better register allocation. # Testing & Compatibility ## Testing: ✅ Correctness tests pass for various input types (bfloat16, int32, float32, int64), shapes, and dimensions (1, 2, 3) ✅ Register usage reduction verified with NSight Compute ✅ Linter passes ## Compatibility: ✅ No API/ABI changes ✅ Template instantiation count unchanged # Reference For detailed analysis, please refere to my previous blog: [Performance Optimization of torch.sort on GPU](https://yywangcs.notion.site/Performance-Optimization-of-torch-sort-on-GPU-192fc9f5d8058018a1bec1efa35da3f9) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167094 Approved by: https://github.com/ngimel, https://github.com/Skylion007 --- aten/src/ATen/cuda/cub.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/aten/src/ATen/cuda/cub.h b/aten/src/ATen/cuda/cub.h index 7430edaf8a3dc..bca9b1faff523 100644 --- a/aten/src/ATen/cuda/cub.h +++ b/aten/src/ATen/cuda/cub.h @@ -24,7 +24,13 @@ namespace detail { // radix_sort_pairs doesn't interact with value_t other than to copy // the data, so we can save template instantiations by reinterpreting // it as an opaque type. +// We use native integer types for 1/2/4/8-byte values to reduce +// register usage in CUDA kernels. For sizes > 8 fall back to char array. template struct alignas(N) OpaqueType { char data[N]; }; +template <> struct alignas(1) OpaqueType<1> { uint8_t data; }; +template <> struct alignas(2) OpaqueType<2> { uint16_t data; }; +template <> struct alignas(4) OpaqueType<4> { uint32_t data; }; +template <> struct alignas(8) OpaqueType<8> { uint64_t data; }; template void radix_sort_pairs_impl( From 3869aa115b1d513cb83ad89889f8c3af7921b0ce Mon Sep 17 00:00:00 2001 From: Tushar Jain Date: Wed, 5 Nov 2025 23:05:56 +0000 Subject: [PATCH 093/130] fix fr reset api (#166970) Summary: - there are various places that access fr's `entries_` field - if we empty the entries_ on reset, the accesses can result in an error - so we only perform a soft delete instead of clearing out the entries copletely - only reset id_ on the reset - keep track of a reset_epoch which increments everytime reset is called - dump_entries only returns entries from the latest epoch - api's that access entries also check if the reset epoch matches - make the `next_` always track the index in the circular buffer - this change was needed to make the soft delete's implementation easier --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/pytorch/pytorch/pull/166970). * #166972 * #166971 * __->__ #166970 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166970 Approved by: https://github.com/fduwjj --- test/distributed/test_c10d_nccl.py | 223 ++++++++++++++++++ .../csrc/distributed/c10d/FlightRecorder.hpp | 42 +++- .../distributed/c10d/FlightRecorderDetail.hpp | 136 ++++++++--- .../distributed/c10d/ProcessGroupGloo.cpp | 7 +- .../distributed/c10d/ProcessGroupGloo.hpp | 1 + .../distributed/c10d/ProcessGroupNCCL.cpp | 18 +- .../distributed/c10d/ProcessGroupNCCL.hpp | 1 + 7 files changed, 389 insertions(+), 39 deletions(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index cf53896187c20..d764dfbbebbb1 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -5789,6 +5789,229 @@ def test_coalescing_manager_collective(self, timing_enabled): else: self.assertTrue("duration_ms" not in t["entries"][0]) + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("timing_enabled", [True, False]) + def test_fr_record_reset_circular_buffer_full(self, timing_enabled): + """ + Test that when the circular buffer in entries_ is full and we call reset, + then fill the buffer with new entries, dump_entries returns only the new + entries and not the old ones. + """ + if self.rank == self.MAIN_PROCESS_RANK: + return + + # Override buffer size to 10 for faster testing + os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10" + + pg = self._create_process_group_nccl() + if timing_enabled: + pg._enable_collectives_timing() + device = self.local_device + self.set_thread_name("fr_test_thread") + a = torch.full((3, 4), float(self.rank), device=device) + + # Fill the buffer completely with 10 entries + for _ in range(10): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Verify buffer is full with 10 entries + t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + self.assertEqual(len(t["entries"]), 10) + + # Now reset the flight recorder + torch._C._distributed_c10d._reset_fr_recording_nccl() + + # Add new entries after reset - fill the buffer completely again + for _ in range(10): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Verify we get exactly 10 new entries, not 20 + t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + self.assertEqual(len(t["entries"]), 10) + + # Verify all entries have the expected properties (from after reset) + # After reset, record IDs should start from 0 again + for i, entry in enumerate(t["entries"]): + self.assertIn("profiling_name", entry) + self.assertEqual(entry["profiling_name"], "nccl:all_reduce") + self.assertIn("record_id", entry) + # Record IDs should be sequential starting from 0 after reset + self.assertEqual(entry["record_id"], i) + + dist.destroy_process_group() + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("timing_enabled", [True, False]) + def test_fr_record_reset_partial_overwrite(self, timing_enabled): + """ + Test that when the circular buffer is full, we reset, and then add fewer + entries than the buffer size, we only get the new entries. + This tests that old entries at the end of the circular buffer are properly + filtered out based on reset_epoch. + """ + if self.rank == self.MAIN_PROCESS_RANK: + return + + # Override buffer size to 10 for faster testing + os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10" + + pg = self._create_process_group_nccl() + if timing_enabled: + pg._enable_collectives_timing() + device = self.local_device + self.set_thread_name("fr_test_thread") + a = torch.full((3, 4), float(self.rank), device=device) + + # Fill the buffer completely + for _ in range(10): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Reset the flight recorder + torch._C._distributed_c10d._reset_fr_recording_nccl() + + # Add only 3 new entries (much less than buffer size) + for _ in range(3): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Verify we only get the 3 new entries, not 10 + t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + self.assertEqual(len(t["entries"]), 3) + + # Verify record IDs start from 0 after reset + for i, entry in enumerate(t["entries"]): + self.assertIn("record_id", entry) + self.assertEqual(entry["record_id"], i) + + dist.destroy_process_group() + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("timing_enabled", [True, False]) + def test_fr_record_reset_wraparound(self, timing_enabled): + """ + Test that when we reset in the middle of the circular buffer and then + wrap around, dump_entries correctly returns only entries from the current + epoch in the correct order. + """ + if self.rank == self.MAIN_PROCESS_RANK: + return + + # Override buffer size to 10 for faster testing + os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10" + + pg = self._create_process_group_nccl() + if timing_enabled: + pg._enable_collectives_timing() + device = self.local_device + self.set_thread_name("fr_test_thread") + a = torch.full((3, 4), float(self.rank), device=device) + + # Fill half the buffer + for _ in range(5): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Reset at this point (reset happens at index 5) + torch._C._distributed_c10d._reset_fr_recording_nccl() + + # Now add 8 entries, which will wrap around + # (5->9 fills rest of buffer, then 0->2 wraps around) + for _ in range(8): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Should get exactly 8 entries, properly ordered + t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + self.assertEqual(len(t["entries"]), 8) + + # Entries should be in chronological order + # The dump_entries() method returns entries from next_ to end, then 0 to next_ + # After filtering old entries, we should have 8 entries in order + # Verify record IDs start from 0 after reset (id_ is reset in reset_all()) + for i, entry in enumerate(t["entries"]): + self.assertIn("profiling_name", entry) + self.assertIn("record_id", entry) + self.assertEqual(entry["record_id"], i) + + dist.destroy_process_group() + + @requires_nccl() + @skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs") + @parametrize("timing_enabled", [True, False]) + def test_fr_record_multiple_resets(self, timing_enabled): + """ + Test multiple consecutive resets to ensure each reset properly increments + the epoch and filters out entries from previous epochs. + """ + if self.rank == self.MAIN_PROCESS_RANK: + return + + # Override buffer size to 10 for faster testing + os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "10" + + pg = self._create_process_group_nccl() + if timing_enabled: + pg._enable_collectives_timing() + device = self.local_device + self.set_thread_name("fr_test_thread") + a = torch.full((3, 4), float(self.rank), device=device) + + # First batch: 2 entries + for _ in range(2): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # First reset + torch._C._distributed_c10d._reset_fr_recording_nccl() + + # Second batch: 3 entries + for _ in range(3): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Second reset + torch._C._distributed_c10d._reset_fr_recording_nccl() + + # Third batch: 4 entries + for _ in range(4): + f = pg.allreduce(a) + f.wait() + torch.cuda.synchronize(device=device) + time.sleep(1) + + # Should only see the last 4 entries + t = pickle.loads(torch._C._distributed_c10d._dump_nccl_trace()) + self.assertEqual(len(t["entries"]), 4) + + # Verify record IDs start from 0 after the last reset + for i, entry in enumerate(t["entries"]): + self.assertIn("record_id", entry) + self.assertEqual(entry["record_id"], i) + + dist.destroy_process_group() + def check_if_test_is_skipped(fn): def wrapper(self, *args, **kwargs): diff --git a/torch/csrc/distributed/c10d/FlightRecorder.hpp b/torch/csrc/distributed/c10d/FlightRecorder.hpp index 23b8893c54f2c..bdb4ad045ff2a 100644 --- a/torch/csrc/distributed/c10d/FlightRecorder.hpp +++ b/torch/csrc/distributed/c10d/FlightRecorder.hpp @@ -108,12 +108,14 @@ struct FlightRecorder { capture_cpp_stack_ = getCvarBool( {"TORCH_FR_CPP_STACK", "TORCH_NCCL_TRACE_CPP_STACK"}, false); enabled_ = max_entries_ > 0; + reset_epoch_start_idx_[0] = 0; } struct Entry { size_t id_; // incremented id in the trace buffer // used to figure out where in the circular entries // buffer this entry will be located to // update state information + size_t reset_epoch_; // epoch when this entry was created size_t pg_id_; std::tuple pg_name_; // @@ -183,11 +185,34 @@ struct FlightRecorder { size_t max_entries_ = 0; size_t next_ = 0; size_t id_ = 0; + size_t reset_epoch_ = 0; + std::unordered_map + reset_epoch_start_idx_; // maps reset_epoch to the idx where it starts std::map> all_pg_status_; std::map, std::vector> pg_name_to_ranks_; std::string comm_lib_version_; + struct TraceIdentifier { + std::optional id; + std::optional reset_epoch; + }; + + TraceIdentifier recordWithResetEnabled( + size_t pg_id, + const std::tuple& pg_name, + size_t collective_seq_id, + size_t p2p_seq_id, + size_t op_id, + std::string profiling_name, + const std::vector& inputs, + const std::vector& outputs, + EventType* start, + EventType* end, + std::chrono::milliseconds timeout_ms, + std::shared_ptr pg_status, + bool isP2P); + std::optional record( size_t pg_id, const std::tuple& pg_name, @@ -213,8 +238,16 @@ struct FlightRecorder { std::vector dump_entries(); - // Returns the entry with the given id, if it exists. Otherwise, returns - // std::nullopt. + // Returns the index in entries_ for the given id and reset_epoch. + // Caller must hold mutex_lock before calling this method. + size_t getIdxFromId(size_t id, size_t reset_epoch) const; + + // Returns the entry with the given id and reset_epoch, if it exists. + // Otherwise, returns std::nullopt. + TORCH_API std::optional getEntry( + std::optional id, + std::optional reset_epoch); + TORCH_API std::optional getEntry(std::optional id); /* @@ -227,6 +260,11 @@ struct FlightRecorder { never hang. (timing must also be enabled for compute_duration - see TORCH_NCCL_ENABLE_TIMING). */ + TORCH_API void retire_id( + std::optional id, + std::optional reset_epoch, + bool compute_duration = true); + TORCH_API void retire_id( std::optional id, bool compute_duration = true); diff --git a/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp b/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp index 8813c95158460..88205c171941c 100644 --- a/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp +++ b/torch/csrc/distributed/c10d/FlightRecorderDetail.hpp @@ -53,8 +53,41 @@ std::optional FlightRecorder::record( std::chrono::milliseconds timeout_ms, std::shared_ptr pg_status, bool isP2P) { + auto result = recordWithResetEnabled( + pg_id, + pg_name, + collective_seq_id, + p2p_seq_id, + op_id, + std::move(profiling_name), + inputs, + outputs, + start, + end, + timeout_ms, + std::move(pg_status), + isP2P); + return result.id; +} + +template +typename FlightRecorder::TraceIdentifier FlightRecorder:: + recordWithResetEnabled( + size_t pg_id, + const std::tuple& pg_name, + size_t collective_seq_id, + size_t p2p_seq_id, + size_t op_id, + std::string profiling_name, + const std::vector& inputs, + const std::vector& outputs, + EventType* start, + EventType* end, + std::chrono::milliseconds timeout_ms, + std::shared_ptr pg_status, + bool isP2P) { if (!enabled_) { - return std::nullopt; + return TraceIdentifier{std::nullopt, std::nullopt}; } if (all_pg_status_.find(pg_id) == all_pg_status_.end()) { // Current pg_status is not in FR. @@ -64,8 +97,13 @@ std::optional FlightRecorder::record( torch::CapturedTraceback::gather(true, true, capture_cpp_stack_); std::lock_guard guard(mutex_); + TORCH_CHECK( + reset_epoch_start_idx_.find(reset_epoch_) != + reset_epoch_start_idx_.end()); + auto te = Entry{ id_, + reset_epoch_, pg_id, pg_name, collective_seq_id, @@ -104,15 +142,20 @@ std::optional FlightRecorder::record( te.sizes_.insert(te.sizes_.end(), sizes.begin(), sizes.end()); } + const auto next = next_++; + if (entries_.size() < max_entries_) { entries_.emplace_back(std::move(te)); } else { - entries_[next_++] = std::move(te); - if (next_ == max_entries_) { - next_ = 0; - } + entries_[next] = std::move(te); } - return id_++; + + if (next_ == max_entries_) { + next_ = 0; + } + + const auto id = id_++; + return TraceIdentifier{id, reset_epoch_}; } template @@ -163,15 +206,20 @@ std::vector::Entry> FlightRecorder< std::vector result; { std::lock_guard guard(mutex_); - result.reserve(entries_.size()); - result.insert( - result.end(), + // Filter entries during insertion - only keep entries from current epoch + auto filter = [this](const Entry& e) { + return e.reset_epoch_ == reset_epoch_; + }; + std::copy_if( entries_.begin() + static_cast(next_), - entries_.end()); - result.insert( - result.end(), + entries_.end(), + std::back_inserter(result), + filter); + std::copy_if( entries_.begin(), - entries_.begin() + static_cast(next_)); + entries_.begin() + static_cast(next_), + std::back_inserter(result), + filter); } // query any remaining events for (auto& r : result) { @@ -182,28 +230,47 @@ std::vector::Entry> FlightRecorder< } template -// Returns the entry with the given id, if it exists. Otherwise, returns -// std::nullopt. +// Returns the index in entries_ for the given id and reset_epoch. +// Caller must hold mutex_lock before calling this method. +size_t FlightRecorder::getIdxFromId(size_t id, size_t reset_epoch) + const { + // Look up the starting idx for the given reset epoch + auto it = reset_epoch_start_idx_.find(reset_epoch); + TORCH_CHECK(it != reset_epoch_start_idx_.end()); + // Calculate idx based on where the epoch started + return (it->second + id) % max_entries_; +} + +template +// Returns the entry with the given id and reset_epoch, if it exists. Otherwise, +// returns std::nullopt. std::optional::Entry> FlightRecorder< - EventType>::getEntry(std::optional id) { - if (!enabled_ || !id) { + EventType>:: + getEntry(std::optional id, std::optional reset_epoch) { + if (!enabled_ || !id || !reset_epoch) { return std::nullopt; } std::unique_lock guard(mutex_); - Entry entry = entries_.at(*id % max_entries_); - if (entry.id_ == *id) { + Entry entry = entries_.at(getIdxFromId(*id, *reset_epoch)); + if (entry.id_ == *id && entry.reset_epoch_ == *reset_epoch) { return entry; - } else { - return std::nullopt; } + return std::nullopt; +} + +template +std::optional::Entry> FlightRecorder< + EventType>::getEntry(std::optional id) { + return getEntry(id, 0); } template void FlightRecorder::retire_id( std::optional id, + std::optional reset_epoch, bool compute_duration) { - if (!enabled_ || !id) { + if (!enabled_ || !id || !reset_epoch) { return; } @@ -214,8 +281,8 @@ void FlightRecorder::retire_id( std::unique_lock guard(mutex_); - Entry* entry = &entries_.at(*id % max_entries_); - if (entry->id_ == *id) { + Entry* entry = &entries_.at(getIdxFromId(*id, *reset_epoch)); + if (entry->id_ == *id && entry->reset_epoch_ == *reset_epoch) { update_state(*entry); if (compute_duration) { @@ -237,8 +304,8 @@ void FlightRecorder::retire_id( guard.lock(); // Refresh the entry pointer, see if the entry has been overwritten - entry = &entries_.at(*id % max_entries_); - if (entry->id_ != *id) { + entry = &entries_.at(getIdxFromId(*id, *reset_epoch)); + if (!(entry->id_ == *id && entry->reset_epoch_ == *reset_epoch)) { LOG(INFO) << "retire_id abandoned for id " << *id << ", event was overwritten while waiting to compute duration."; return; @@ -249,12 +316,23 @@ void FlightRecorder::retire_id( } } +template +void FlightRecorder::retire_id( + std::optional id, + bool compute_duration) { + retire_id(id, 0, compute_duration); +} + template void FlightRecorder::reset_all() { std::lock_guard guard(mutex_); - next_ = 0; - id_ = 0; - entries_.clear(); + if (!entries_.empty()) { + // Soft delete: increment epoch to mark all existing entries as old + // Store where the new epoch starts in the circular buffer + reset_epoch_++; + reset_epoch_start_idx_[reset_epoch_] = next_; + id_ = 0; + } } template diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index a9612ce759733..c1d28b2787cda 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -708,7 +708,8 @@ void ProcessGroupGloo::runLoop(int workerIndex) { // TODO: We need to have numel of tensors for gloo as well. pgStatus_->lastCompletedNumelIn = 0; pgStatus_->lastCompletedNumelOut = 0; - FlightRecorder::get()->retire_id(work->trace_id_, false); + FlightRecorder::get()->retire_id( + work->trace_id_, work->trace_reset_epoch_, false); lock.lock(); workInProgress_[workerIndex].reset(); } @@ -780,7 +781,7 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr work) { pgStatus_->lastEnqueuedNumelOut = 0; // using c10d::FlightRecorder; // TODO: We need to have a way to use c10::Event inside gloo as well. - work->trace_id_ = FlightRecorder::get()->record( + auto traceId = FlightRecorder::get()->recordWithResetEnabled( local_id_, std::make_tuple(pg_uid_, pg_desc_), collectiveCounter_, @@ -795,6 +796,8 @@ void ProcessGroupGloo::enqueue(c10::intrusive_ptr work) { work->getTimeout(), pgStatus_, false); + work->trace_id_ = traceId.id; + work->trace_reset_epoch_ = traceId.reset_epoch; workQueue_.push_back(std::move(work)); lock.unlock(); diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp index b2cc6993528bf..1a0b7c41b3857 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.hpp @@ -99,6 +99,7 @@ class TORCH_API ProcessGroupGloo : public Backend { // unique id used to tell the trace buffer that this // work has completed std::optional trace_id_; + std::optional trace_reset_epoch_; std::shared_ptr context_; const std::chrono::milliseconds timeout_; diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index 3416bc336d34a..29ccc115cc94d 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -575,6 +575,7 @@ ProcessGroupNCCL::WorkNCCL::WorkNCCL(const WorkNCCL& w) futureWorkResult_(w.futureWorkResult_), timingEnabled_(w.timingEnabled_), trace_id_(w.trace_id_), + trace_reset_epoch_(w.trace_reset_epoch_), distDebugLevel_(w.distDebugLevel_) { exception_ = w.exception_; } @@ -704,9 +705,9 @@ bool ProcessGroupNCCL::WorkNCCL::checkTimeout( // Print the traceback of the collective at call time std::string ProcessGroupNCCL::WorkNCCL::getTraceback() const { // First step we get the corresponding record entry from FR, based on work's - // trace_id_ + // trace_id_ and trace_reset_epoch_ std::optional entry = - FlightRecorderCUDA::get()->getEntry(trace_id_); + FlightRecorderCUDA::get()->getEntry(trace_id_, trace_reset_epoch_); if (entry.has_value()) { auto entryVal = entry.value(); // Get stack trace from FR entry, in string format @@ -2394,7 +2395,8 @@ void ProcessGroupNCCL::Watchdog::runLoop() { pg_->pgStatus_->lastCompletedWorkName = opTypeToString(work.opType_); pg_->pgStatus_->lastCompletedNumelIn = work.numelIn_; pg_->pgStatus_->lastCompletedNumelOut = work.numelOut_; - FlightRecorderCUDA::get()->retire_id(work.trace_id_, true); + FlightRecorderCUDA::get()->retire_id( + work.trace_id_, work.trace_reset_epoch_, true); if (pg_->onCompletionHook_) { // Move Work object to completedWorkList_ to be consumed by the hook // thread @@ -3360,7 +3362,7 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( // these objects to the Work because it has implications for keeping those // tensors alive longer and adds overhead when copying Work objects // between threads - r->trace_id_ = FlightRecorderCUDA::get()->record( + auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled( local_id_, std::make_tuple(pg_uid_, pg_desc_), seqCollective_, @@ -3374,6 +3376,8 @@ c10::intrusive_ptr ProcessGroupNCCL::initWork( options_->timeout, pgStatus_, isP2P); + r->trace_id_ = traceId.id; + r->trace_reset_epoch_ = traceId.reset_epoch; } return r; } @@ -3677,7 +3681,7 @@ c10::intrusive_ptr ProcessGroupNCCL::collective( // later in endCoalescing we record a 'coalesced' Work which has // timing/state updates via watchdog thread, but lacks op metadata such as // input/output sizes and profilingTitle per-op in the group. - FlightRecorderCUDA::get()->record( + FlightRecorderCUDA::get()->recordWithResetEnabled( local_id_, std::make_tuple(pg_uid_, pg_desc_), seqCollective_, @@ -4169,7 +4173,7 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( // TODO(whc) because we don't pass output {tensor} to initWork, we tell // initWork to not record, and then we manually call record passing all the // information it wants. - work->trace_id_ = FlightRecorderCUDA::get()->record( + auto traceId = FlightRecorderCUDA::get()->recordWithResetEnabled( local_id_, std::make_tuple(pg_uid_, pg_desc_), seqCollective_, @@ -4183,6 +4187,8 @@ c10::intrusive_ptr ProcessGroupNCCL::pointToPoint( options_->timeout, pgStatus_, /*isP2P=*/true); + work->trace_id_ = traceId.id; + work->trace_reset_epoch_ = traceId.reset_epoch; } // Only check for NaN for send ops, for recv ops `tensor` can be a random diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp index 2ead1a107394d..d8f324dbd8edf 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp @@ -505,6 +505,7 @@ class TORCH_API ProcessGroupNCCL : public Backend { // unique id used to tell the trace buffer that this // work has completed std::optional trace_id_; + std::optional trace_reset_epoch_; DebugLevel distDebugLevel_; friend class ProcessGroupNCCL; }; From af829c0dade306762213d25506a83da850e30a3c Mon Sep 17 00:00:00 2001 From: Jagadish Krishnamoorthy Date: Wed, 5 Nov 2025 23:15:17 +0000 Subject: [PATCH 094/130] [ROCm] Skip nvfp4 tests on ROCm (#167066) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167066 Approved by: https://github.com/jeffdaily, https://github.com/slayton58 --- test/test_scaled_matmul_cuda.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_scaled_matmul_cuda.py b/test/test_scaled_matmul_cuda.py index 9738ac4ac6fbf..fd09afc11cecf 100644 --- a/test/test_scaled_matmul_cuda.py +++ b/test/test_scaled_matmul_cuda.py @@ -1864,6 +1864,8 @@ def test_blockwise_nvfp4_with_global_scale(self, mkn) -> None: ], name_fn=lambda mkn: f"{mkn[0]}_{mkn[1]}_{mkn[2]}") @parametrize("recipe", ["mxfp8", "mxfp4", "nvfp4"]) def test_blockwise_mxfp8_nvfp4_mxfp4_numerics(self, test_case_name, fast_accum, mkn, recipe) -> None: + if torch.version.hip and recipe == "nvfp4": + raise unittest.SkipTest("nvfp4 not supported on ROCm, skipping") if (recipe == "nvfp4" or recipe == "mxfp4") and fast_accum: raise unittest.SkipTest("fast_accum not supported in nvfp4/mxfp4 cublas gemm, skipping") From a344069f2aba6a87c0ab3fab488e83b80691457f Mon Sep 17 00:00:00 2001 From: Xinya Zhang Date: Wed, 5 Nov 2025 23:16:48 +0000 Subject: [PATCH 095/130] Add missing skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION) to test/test_transformers.py (#166969) This PR adds missing skips for efficient attention tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166969 Approved by: https://github.com/jeffdaily --- test/test_transformers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/test/test_transformers.py b/test/test_transformers.py index 56e1365d33c44..cc82cbff2a46f 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -1914,6 +1914,7 @@ def test_flash_attention_fail_with_non_square_causal_attention(self, device): q, k, v, None, 0.0, is_causal=True)) @onlyCUDA + @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention") def test_mem_eff_attention_fail_with_batch_size_geq_65536(self): batch_size = 2**16 query = torch.rand([batch_size, 2, 2, 8], device='cuda', dtype=torch.float16, requires_grad=True) @@ -1935,6 +1936,7 @@ def test_mem_eff_attention_fail_with_batch_size_geq_65536(self): self.assertEqual(value.grad, v_cpu.grad, atol=2e-3, rtol=1e-4) @onlyCUDA + @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention") def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self): query = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16) key = torch.rand([2**16, 2, 2, 8], device='cuda', dtype=torch.float16) @@ -1948,6 +1950,7 @@ def test_mem_eff_attention_fail_with_batch_size_geq_65536_error(self): @largeTensorTest("15GB", "cuda") @onlyCUDA + @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Does not support Efficient Attention") def test_mem_eff_attention_large_seq_len_uniform_attention(self): device = torch.device("cuda") dtype = torch.bfloat16 From d29efba8fa83215465c5dd8914769593b69ed304 Mon Sep 17 00:00:00 2001 From: atalman Date: Thu, 6 Nov 2025 00:34:40 +0000 Subject: [PATCH 096/130] Move almalinux docker image to DEVTOOLSET 13 (#167018) 1. Update general Almalinux image to Devtoolset 13. 2. Fix ROCm images, missing devtoolset-13 This image used by Linux Job in test-infra Pull Request resolved: https://github.com/pytorch/pytorch/pull/167018 Approved by: https://github.com/sudharssun, https://github.com/d4l3k --- .ci/docker/almalinux/Dockerfile | 25 +++++++++++++++++++++---- .ci/docker/almalinux/build.sh | 2 +- 2 files changed, 22 insertions(+), 5 deletions(-) diff --git a/.ci/docker/almalinux/Dockerfile b/.ci/docker/almalinux/Dockerfile index ce7803cf9acd2..3bc3fd8badc6d 100644 --- a/.ci/docker/almalinux/Dockerfile +++ b/.ci/docker/almalinux/Dockerfile @@ -7,13 +7,13 @@ ENV LC_ALL en_US.UTF-8 ENV LANG en_US.UTF-8 ENV LANGUAGE en_US.UTF-8 -ARG DEVTOOLSET_VERSION=11 +ARG DEVTOOLSET_VERSION=13 RUN yum -y update RUN yum -y install epel-release # install glibc-langpack-en make sure en_US.UTF-8 locale is available RUN yum -y install glibc-langpack-en -RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-toolchain +RUN yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb # Just add everything as a safe.directory for git since these will be used in multiple places with git RUN git config --global --add safe.directory '*' ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH @@ -41,6 +41,7 @@ RUN bash ./install_conda.sh && rm install_conda.sh # Install CUDA FROM base as cuda ARG CUDA_VERSION=12.6 +ARG DEVTOOLSET_VERSION=13 RUN rm -rf /usr/local/cuda-* ADD ./common/install_cuda.sh install_cuda.sh COPY ./common/install_nccl.sh install_nccl.sh @@ -50,7 +51,8 @@ ENV CUDA_HOME=/usr/local/cuda-${CUDA_VERSION} # Preserve CUDA_VERSION for the builds ENV CUDA_VERSION=${CUDA_VERSION} # Make things in our path by default -ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:$PATH +ENV PATH=/usr/local/cuda-${CUDA_VERSION}/bin:/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH + FROM cuda as cuda12.6 RUN bash ./install_cuda.sh 12.6 @@ -68,8 +70,22 @@ FROM cuda as cuda13.0 RUN bash ./install_cuda.sh 13.0 ENV DESIRED_CUDA=13.0 -FROM ${ROCM_IMAGE} as rocm +FROM ${ROCM_IMAGE} as rocm_base +ARG DEVTOOLSET_VERSION=13 +ENV LC_ALL en_US.UTF-8 +ENV LANG en_US.UTF-8 +ENV LANGUAGE en_US.UTF-8 +# Install devtoolset on ROCm base image +RUN yum -y update && \ + yum -y install epel-release && \ + yum -y install glibc-langpack-en && \ + yum install -y sudo wget curl perl util-linux xz bzip2 git patch which perl zlib-devel openssl-devel yum-utils autoconf automake make gcc-toolset-${DEVTOOLSET_VERSION}-gcc gcc-toolset-${DEVTOOLSET_VERSION}-gcc-c++ gcc-toolset-${DEVTOOLSET_VERSION}-gcc-gfortran gcc-toolset-${DEVTOOLSET_VERSION}-gdb +RUN git config --global --add safe.directory '*' +ENV PATH=/opt/rh/gcc-toolset-${DEVTOOLSET_VERSION}/root/usr/bin:$PATH + +FROM rocm_base as rocm ARG PYTORCH_ROCM_ARCH +ARG DEVTOOLSET_VERSION=13 ENV PYTORCH_ROCM_ARCH ${PYTORCH_ROCM_ARCH} ADD ./common/install_mkl.sh install_mkl.sh RUN bash ./install_mkl.sh && rm install_mkl.sh @@ -88,6 +104,7 @@ COPY --from=cuda13.0 /usr/local/cuda-13.0 /usr/local/cuda-13.0 # Final step FROM ${BASE_TARGET} as final +ARG DEVTOOLSET_VERSION=13 COPY --from=openssl /opt/openssl /opt/openssl COPY --from=patchelf /patchelf /usr/local/bin/patchelf COPY --from=conda /opt/conda /opt/conda diff --git a/.ci/docker/almalinux/build.sh b/.ci/docker/almalinux/build.sh index ad234ce1ffb93..885c4440e0e6f 100755 --- a/.ci/docker/almalinux/build.sh +++ b/.ci/docker/almalinux/build.sh @@ -63,7 +63,7 @@ docker build \ --target final \ --progress plain \ --build-arg "BASE_TARGET=${BASE_TARGET}" \ - --build-arg "DEVTOOLSET_VERSION=11" \ + --build-arg "DEVTOOLSET_VERSION=13" \ ${EXTRA_BUILD_ARGS} \ -t ${tmp_tag} \ $@ \ From 6cd57e6fc275e8d53665aab4d4fbaa71e29eb9ea Mon Sep 17 00:00:00 2001 From: eqy Date: Thu, 6 Nov 2025 00:50:42 +0000 Subject: [PATCH 097/130] [cuBLAS] Force tensor-core-no-reduction algo in `cuBLASLt` for `n=1` cases (#166735) Ostensibly useful for batch-invariance purposes Pull Request resolved: https://github.com/pytorch/pytorch/pull/166735 Approved by: https://github.com/ngimel --- aten/src/ATen/cuda/CUDABlas.cpp | 51 ++++++++++++++++++++++++--------- test/test_matmul_cuda.py | 23 +++++++++++++++ 2 files changed, 61 insertions(+), 13 deletions(-) diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index aaed431064611..20f235076220f 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -388,6 +388,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D #ifndef USE_ROCM at::Half halpha; at::Half hbeta; + uint32_t mask = -1; #endif void * alpha_ptr = α void * beta_ptr = β @@ -427,7 +428,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D auto fp16_reduction = at::globalContext().allowFP16ReductionCuBLAS(); if (fp16_reduction != at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) { - uint32_t mask = + mask = fp16_reduction == at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | @@ -444,7 +445,7 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D auto bf16_reduction = at::globalContext().allowBF16ReductionCuBLAS(); if (bf16_reduction != at::CuBLASReductionOption::AllowReducedPrecisionWithSplitK) { - uint32_t mask = + mask = bf16_reduction == at::CuBLASReductionOption::DisallowReducedPrecisionAllowSplitK ? (CUBLASLT_REDUCTION_SCHEME_COMPUTE_TYPE | @@ -511,17 +512,41 @@ static inline bool bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES_AND_C_DTYPE(D cublasStatus_t cublasStatus = CUBLAS_STATUS_SUCCESS; cublasLtMatmulHeuristicResult_t heuristicResult = {}; int returnedResult = 0; - TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( - ltHandle, - computeDesc.descriptor(), - Adesc.descriptor(), - Bdesc.descriptor(), - Cdesc.descriptor(), - Cdesc.descriptor(), - preference.descriptor(), - 1, - &heuristicResult, - &returnedResult)); + // on Blackwell+, we fake a n > 1 matmul when querying heuristics + // to prevent cuBLASLt from dispatching to a GEMV kernel for batch-invariance +#ifndef USE_ROCM + const bool lie_to_cublaslt = mask == CUBLASLT_REDUCTION_SCHEME_NONE && n == 1 && at::cuda::getCurrentDeviceProperties()->major >= 10; +#else + const bool lie_to_cublaslt = false; +#endif + if (lie_to_cublaslt) { + CuBlasLtMatrixLayout FakeBdesc(abType, k, 2, ldb, opb == CUBLAS_OP_T); + CuBlasLtMatrixLayout FakeCdesc(cType, m, 2, ldc); + + TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( + ltHandle, + computeDesc.descriptor(), + Adesc.descriptor(), + FakeBdesc.descriptor(), + FakeCdesc.descriptor(), + FakeCdesc.descriptor(), + preference.descriptor(), + 1, + &heuristicResult, + &returnedResult)); + } else { + TORCH_CUDABLAS_CHECK(cublasLtMatmulAlgoGetHeuristic( + ltHandle, + computeDesc.descriptor(), + Adesc.descriptor(), + Bdesc.descriptor(), + Cdesc.descriptor(), + Cdesc.descriptor(), + preference.descriptor(), + 1, + &heuristicResult, + &returnedResult)); + } if (returnedResult == 0) { cublasStatus = CUBLAS_STATUS_NOT_SUPPORTED; } diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 10611d4f24673..a8e9be4c972a1 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -359,6 +359,29 @@ def grouped_mm_helper(self, alist, blist, gOlist, agradlist, bgradlist, outlist) self.assertEqual(agrad, a.grad) self.assertEqual(bgrad, b.grad) + @onlyCUDA + @skipIfRocm + @dtypes(torch.half, torch.bfloat16) + @unittest.skipIf(not SM100OrLater, "cuBLAS integration for batch invariance is only on Blackwell") + @serialTest() + def test_cublas_batch_invariance_blackwell(self, device, dtype): + orig_bf16 = torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction + orig_fp16 = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = (False, False) + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = (False, False) + with blas_library_context('cublaslt'): + N = 2048 + K = 6144 + M_max = 32 + x = torch.randn(M_max, K, device="cuda", dtype=torch.bfloat16) + w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16).t() + full = x @ w + xx = x[:1] + out = xx @ w + self.assertEqual(full[:1], out, atol=0., rtol=0.) + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = orig_bf16 + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = orig_fp16 + @unittest.skipIf(not SM80OrLater, "Grouped gemm supported only on SM80 or greater") @parametrize("strided", [False, True]) @parametrize("a_row_major", [False, True]) From 872d1daec2726e8915a4d38427fa0d1c938e5905 Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Wed, 5 Nov 2025 12:50:02 -0800 Subject: [PATCH 098/130] Avoid DDE in narrow with unbacked start (#166361) Slice knows how to handle unbacked start, we do not need to offset start before calling slice, we can leave it for slice. The only edge case is when start<0 and start+length ==0 in that case slice and narrow would deviate, for that case we shall pass dim_size instead of start+length Pull Request resolved: https://github.com/pytorch/pytorch/pull/166361 Approved by: https://github.com/aorenste --- aten/src/ATen/native/TensorShape.cpp | 58 ++++++++++++++++++++++-- c10/core/SymBool.cpp | 14 ++++++ c10/core/SymBool.h | 6 +++ test/export/test_export.py | 31 ++++++++----- test/test_dynamic_shapes.py | 51 +++++++++++++++++++++ test/test_torchfuzz_repros.py | 28 ------------ torch/_inductor/codegen/wrapper.py | 3 +- torch/fx/experimental/symbolic_shapes.py | 19 +++++++- torch/utils/_sympy/printers.py | 36 +++++++++++++++ 9 files changed, 200 insertions(+), 46 deletions(-) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 6df7761d822db..daa8a86da253b 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1,5 +1,6 @@ #include #include +#include #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include @@ -1710,11 +1711,37 @@ Tensor narrow_symint( "], but got ", start, ")") - if (start < 0) { - start = start + cur_size; - } + + auto cond1 = TORCH_GUARD_OR_FALSE(start.sym_lt(0)); + auto cond2 = TORCH_GUARD_OR_FALSE(start.sym_ge(0)); + + if (cond1 || cond2) { + if (cond1) { + start = start + cur_size; + } + + TORCH_SYM_CHECK( + start.sym_le(cur_size - length), + "start (", + start, + ") + length (", + length, + ") exceeds dimension size (", + cur_size, + ")."); + return at::slice_symint(self, dim, start, start + length, 1); + } + + // Unbacked start handling! + + // Bounds check without converting start: + // - If start < 0: need (start + cur_size) + length <= cur_size, i.e., start + + // length <= 0 + // - If start >= 0: need start + length <= cur_size + auto end = start + length; TORCH_SYM_CHECK( - start.sym_le(cur_size - length), + (start.sym_lt(0).sym_and((end).sym_le(0))) + .sym_or(start.sym_ge(0).sym_and((end).sym_le(cur_size))), "start (", start, ") + length (", @@ -1722,7 +1749,28 @@ Tensor narrow_symint( ") exceeds dimension size (", cur_size, ")."); - return at::slice_symint(self, dim, start, start + length, 1); + + if (TORCH_GUARD_OR_FALSE(end.sym_ne(0))) { + return at::slice_symint(self, dim, start, end, 1); + } else { + // Cannot statically determine the condition due to unbacked. + // This is an interesting situation; when start is negative and + // start + length == 0, slice and narrow do different things. + // i.e., x.narrow(0, -2, 2) != x[-2:0]; in that case, we want to + // pass curr_size instead of 0. Otherwise, they would do the same thing. + // This says at runtime: if start < 0 and end == 0, then pass curr_size + // instead of 0. + + auto use_different = start.sym_lt(0).sym_and(end.sym_eq(0)).toSymInt(); + auto result = + at::slice_symint(self, dim, start, end + use_different * cur_size, 1); + + // Ensure slice allocated unbacked size is specialized to length. + SymInt new_size = result.sym_size(dim); + TORCH_SYM_CHECK(new_size.sym_eq(length), "") + + return result; + } } // This overload exists purely for XLA, because they wanted to pass in diff --git a/c10/core/SymBool.cpp b/c10/core/SymBool.cpp index d804eb9d27409..48c407b8b069c 100644 --- a/c10/core/SymBool.cpp +++ b/c10/core/SymBool.cpp @@ -1,4 +1,5 @@ #include +#include #include namespace c10 { @@ -111,4 +112,17 @@ bool SymBool::has_hint() const { return toSymNodeImpl()->has_hint(); } +SymInt SymBool::toSymInt() const { + // If concrete bool, return concrete SymInt + if (auto ma = maybe_as_bool()) { + return SymInt(*ma ? 1 : 0); + } + + // Symbolic case: use sym_ite to convert bool to int (0 or 1) + auto node = toSymNodeImpl(); + auto one_node = node->wrap_int(1); + auto zero_node = node->wrap_int(0); + return SymInt(node->sym_ite(one_node, zero_node)); +} + } // namespace c10 diff --git a/c10/core/SymBool.h b/c10/core/SymBool.h index d5d509e239b1d..a27a28a5bf8a3 100644 --- a/c10/core/SymBool.h +++ b/c10/core/SymBool.h @@ -12,6 +12,8 @@ namespace c10 { +class SymInt; + class C10_API SymBool { public: /*implicit*/ SymBool(bool b) : data_(b) {} @@ -80,6 +82,10 @@ class C10_API SymBool { return toSymNodeImplUnowned()->constant_bool(); } + // Convert SymBool to SymInt (0 or 1) + // This is the C++ equivalent of Python's cast_symbool_to_symint_guardless + SymInt toSymInt() const; + bool is_heap_allocated() const { return ptr_; } diff --git a/test/export/test_export.py b/test/export/test_export.py index 3908f03b11e55..cdc18b1d4c564 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -6093,26 +6093,19 @@ def forward(self, x, y, fixes): retry_export( cf_implicitsize(), (torch.tensor(2), torch.randn(10)), - fixes=[ - # Could not guard on data-dependent expression u0 < 0 - "torch._check(i >= 0)", - ], + fixes=[], ) class cf_stacklist(torch.nn.Module): def forward(self, xs, y, fixes): i = y.item() eval(fixes) - # instead of xs[i] return torch.stack(xs, 0).narrow(0, i, 1).squeeze() retry_export( cf_stacklist(), ([torch.ones(5) * i for i in range(10)], torch.tensor(2)), - fixes=[ - # Could not guard on data-dependent expression u0 < 0 - "torch._check(i >= 0)", - ], + fixes=[], ) class cf_tensorsplit(torch.nn.Module): @@ -6166,7 +6159,12 @@ def test_no_suggested_fixes_for_data_dependent_errors(self): class cf_stacklist(torch.nn.Module): def forward(self, xs, y): # y.item() is not a local, so we can't suggest a fix - return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze() + if y.item() < 0: + return ( + torch.stack(xs, 0).narrow(0, y.item() + xs.size(), 1).squeeze() + ) + else: + return torch.stack(xs, 0).narrow(0, y.item(), 1).squeeze() with self.assertRaisesRegex( error_type, @@ -6196,7 +6194,18 @@ class cf_stacklist_udd(torch.nn.Module): def forward(self, xs, y): box = Box(y.item()) # box.content is not a local, so we can't suggest a fix - return torch.stack(xs, 0).narrow(0, box.content, 1).squeeze() + if box.content < 0: + return ( + torch.stack(xs, 0) + .narrow(0, box.content + xs.size(), 1) + .squeeze() + ) + else: + return ( + torch.stack(xs, 0) + .narrow(0, box.content + xs.size(), 1) + .squeeze() + ) with self.assertRaisesRegex( error_type, diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index fb1d22805d50a..b63e0427c26c3 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -4401,6 +4401,57 @@ def func(x, y): self.assertEqual(compiled(a, b), func(a, b)) + @fresh_cache() + @torch._dynamo.config.patch("capture_scalar_outputs", True) + def test_narrow_unbacked_start(self): + def func(x, start, length): + # unbacked start + u0 = start.item() + return torch.narrow(x, 0, u0, length) + + compiled_func = torch.compile(func, fullgraph=True, backend="inductor") + + x = torch.tensor([1, 2, 3, 4, 5, 6]) + + # Test cases: (start, length) + test_cases = [ + # Negative starts + (-2, 2), # Start from second-to-last element + (-1, 1), # Start from last element + (-3, 3), # Start from third-to-last element + (-6, 2), # Start from beginning (negative) + (-4, 1), # Start from fourth-to-last element + # Positive starts + (0, 2), # Start from beginning + (1, 3), # Start from second element + (2, 2), # Start from third element + (4, 2), # Start near end + # Edge cases + (0, 6), # Full tensor + (0, 1), # Single element from start + (5, 1), # Single element from end + ] + + for start_val, length in test_cases: + with self.subTest(start=start_val, length=length): + start = torch.tensor([start_val]) + + # Test with compiled function + result_compiled = compiled_func(x, start, length) + + # Test with eager function (expected behavior) + result_eager = func(x, start, length) + + # Compare results + self.assertEqual(result_compiled, result_eager) + + @fresh_cache() + @torch._dynamo.config.patch("capture_scalar_outputs", True) + @torch._inductor.config.patch("cpp_wrapper", True) + def test_narrow_unbacked_start_cpp_wrapper(self): + """Test narrow with unbacked start with cpp_wrapper""" + self.test_narrow_unbacked_start() + instantiate_parametrized_tests(TestUnbacked) diff --git a/test/test_torchfuzz_repros.py b/test/test_torchfuzz_repros.py index 3b864aae4f477..988bcf8de273c 100644 --- a/test/test_torchfuzz_repros.py +++ b/test/test_torchfuzz_repros.py @@ -257,34 +257,6 @@ def foo(arg0, arg1): out_compiled.sum().backward() print("Compile Success! ✅") - @pytest.mark.xfail(reason="Issue #163971") - def test_fuzzer_issue_163971(self): - torch.manual_seed(0) - - def foo(arg0): - t0 = arg0 # size=(), stride=(), dtype=bfloat16, device=cuda - t1 = torch.softmax( - t0, dim=0 - ) # size=(), stride=(), dtype=bfloat16, device=cuda - t2 = torch.nn.functional.gelu( - t1 - ) # size=(), stride=(), dtype=bfloat16, device=cuda - t3 = torch.softmax( - t2, dim=0 - ) # size=(), stride=(), dtype=bfloat16, device=cuda - output = t3 - return output - - arg0 = torch.rand([], dtype=torch.bfloat16, device="cuda", requires_grad=True) - - out_eager = foo(arg0) - out_eager.sum().backward() - print("Eager Success! ✅") - compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True) - out_compiled = compiled_foo(arg0) - out_compiled.sum().backward() - print("Compile Success! ✅") - @pytest.mark.xfail(reason="Issue #164059") def test_fuzzer_issue_164059(self): torch.manual_seed(0) diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py index e629d9c7bdebd..947166cf216cd 100644 --- a/torch/_inductor/codegen/wrapper.py +++ b/torch/_inductor/codegen/wrapper.py @@ -2063,7 +2063,8 @@ def clamp_index(x): neg = self.codegen_sizevar( sympy.Max(0, sympy.Min(x + node.size, node.size)) ) - return f"{pos} if {x} >= 0 else {neg}" + x_cond = self.codegen_sizevar(x) + return f"{pos} if {x_cond} >= 0 else {neg}" def codegen_with_step(start_var, end_var, step): if step == 1: diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index aeccdfbe000db..693d25aea6130 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -547,6 +547,7 @@ def rebind_unbacked( assert shape_env is not None for raw_u0, path in bindings.items(): u1 = pytree.key_get(result, path) + # Sometimes, things were previously unbacked bindings become constants. # There are two situations this can happen. # @@ -602,7 +603,23 @@ def rebind_unbacked( if u1.node.hint is not None: continue - raw_u1 = u1.node.expr + # unbacked symbols bindings might be replaced to other backed or + # unbacked replacements. + # + # Example: + # u = x.item() + # torch._check(u == 5) + # + # The safest approach is to retrieve raw_u1 from u1.node._expr + # and perform the rebinding on the original unbacked symbol, + # even if it’s no longer directly referenced. + # + # In other words, we should always rebind the original symbol + # before any replacements are applied. + # u0 -> u0 == s1 + raw_u1 = u1.node._expr + + # TODO Do we still need this logic below? # Simplify SymBool binding if ( isinstance(raw_u1, sympy.Piecewise) diff --git a/torch/utils/_sympy/printers.py b/torch/utils/_sympy/printers.py index 526443577b3f8..915d0e5461f1e 100644 --- a/torch/utils/_sympy/printers.py +++ b/torch/utils/_sympy/printers.py @@ -306,6 +306,24 @@ def _print_RoundDecimal(self, expr: sympy.Expr) -> str: raise TypeError("ndigits must be an instance of sympy.Integer") return f"round({self._print(number)}, {ndigits})" + def _print_Piecewise(self, expr: sympy.Expr) -> str: + # Convert Piecewise(expr_cond_pairs) to nested ternary expressions + # Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) + # becomes: e1 if c1 else (e2 if c2 else (... else eN)) + result: Optional[str] = None + for expr_i, cond_i in reversed(expr.args): + expr_str = self._print(expr_i) + if cond_i == True: # noqa: E712 + # This is the default case + result = expr_str + else: + cond_str = self._print(cond_i) + if result is None: + result = expr_str + else: + result = f"({expr_str} if {cond_str} else {result})" + return result if result else "0" + class CppPrinter(ExprPrinter): def _print_Integer(self, expr: sympy.Expr) -> str: @@ -327,6 +345,24 @@ def _print_Where(self, expr: sympy.Expr) -> str: ) return f"{c} ? {p} : {q}" + def _print_Piecewise(self, expr: sympy.Expr) -> str: + # Convert Piecewise(expr_cond_pairs) to nested ternary operators + # Piecewise((e1, c1), (e2, c2), ..., (eN, cN)) + # becomes: c1 ? e1 : (c2 ? e2 : (... : eN)) + result: Optional[str] = None + for expr_i, cond_i in reversed(expr.args): + expr_str = self.parenthesize(expr_i, PRECEDENCE["Atom"] - 0.5) + if cond_i == True: # noqa: E712 + # This is the default case + result = expr_str + else: + cond_str = self.parenthesize(cond_i, PRECEDENCE["Atom"] - 0.5) + if result is None: + result = expr_str + else: + result = f"{cond_str} ? {expr_str} : {result}" + return f"({result})" if result else "0" + def _print_ModularIndexing(self, expr: sympy.Expr) -> str: x, div, mod = expr.args x = self.doprint(x) From fd5edda1edd3f2c7ad555a626351e359af164fb4 Mon Sep 17 00:00:00 2001 From: Shangdi Yu Date: Thu, 6 Nov 2025 01:14:25 +0000 Subject: [PATCH 099/130] Reland "Add model code stack trace to torch.profile (#166677)" (#167110) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ```python python test/test_fx.py -k profiler ``` Insert `torch._C._profiler._RecordFunctionFast` to fx graph codegen. We post-process the profiler dump using `map_recorded_events_to_aten_ops_with_stack_trace` to add the stack trace to the dump'd trace. `map_recorded_events_to_aten_ops_with_stack_trace` queries `fx.traceback._FX_METADATA_REGISTRY` for node metadata. Each graph module has a hash'd fake file name (e.g. `fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py`), which is the key to the registry. One can do `fx_g.enrich_profiler_metadata()` to add debugging info. Or `fx_g.enrich_profiler_metadata(enable=False)` to remove. `aot_eager` makes calls `fx_g.enrich_profiler_metadata()` if TORCH_ENRICH_RPOFILER_STACK_TRACE is set or _dynamo.config.enrich_profiler_metadata=True. Screenshot 2025-10-31 at 4 40 52 PM Example code gen'd. ``` def forward(self, args_list): args_iter = iter(args_list) arg0_1 = next(args_iter) arg1_1 = next(args_iter) args_list.clear() _rf = torch._C._profiler._RecordFunctionFast('## fx_generated__iv4zodvbcmdkhx77jrg7h2f2opebujhfmc6tf6nx7vioq244baw.py ##'); _rf.__enter__() repeated_subgraph0 = self.repeated_subgraph0 _rf_invoke_subgraph = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_invoke_subgraph.__enter__() invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', arg0_1, arg1_1); repeated_subgraph0 = arg0_1 = arg1_1 = None _rf_invoke_subgraph.__exit__(None, None, None) _rf_getitem = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_getitem.__enter__() getitem = invoke_subgraph[0]; invoke_subgraph = None _rf_getitem.__exit__(None, None, None) return (getitem,) _rf.__exit__(None, None, None) def forward(self, arg0_1, arg1_1): _rf = torch._C._profiler._RecordFunctionFast('## fx_generated__ozpadpj5cxoalxeyopej33g2vvtvhxg4xsk7bhx7ldmcibtybyn.py ##'); _rf.__enter__() _rf_mul = torch._C._profiler._RecordFunctionFast('## 2 ##'); _rf_mul.__enter__() mul = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None _rf_mul.__exit__(None, None, None) _rf_sin = torch._C._profiler._RecordFunctionFast('## 3 ##'); _rf_sin.__enter__() sin = torch.ops.aten.sin.default(mul); mul = None _rf_sin.__exit__(None, None, None) _rf_add = torch._C._profiler._RecordFunctionFast('## 4 ##'); _rf_add.__enter__() add = torch.ops.aten.add.Tensor(sin, 5); sin = None _rf_add.__exit__(None, None, None) return (add,) _rf.__exit__(None, None, None) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/167110 Approved by: https://github.com/pianpwk --- ...t-fx_backcompat_function_signatures.expect | 2 +- test/test_fx.py | 184 ++++++++++++++++++ torch/autograd/profiler_util.py | 40 ++++ torch/fx/graph.py | 23 +++ torch/fx/graph_module.py | 16 +- torch/profiler/_utils.py | 169 +++++++++++++++- 6 files changed, 429 insertions(+), 5 deletions(-) diff --git a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect index a404e15a977ee..12f6ba2228db8 100644 --- a/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect +++ b/test/expect/TestFXAPIBackwardCompatibility.test_function_back_compat-fx_backcompat_function_signatures.expect @@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None) torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node torch.fx.graph.Graph.print_tabular(self) -torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode +torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False, record_func: bool = False) -> torch.fx.graph.PythonCode torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule') torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None diff --git a/test/test_fx.py b/test/test_fx.py index 92d35fd8f49ad..f728187fd85f5 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -72,9 +72,16 @@ IS_WINDOWS, run_tests, skipIfTorchDynamo, + skipIfRocm, ) from torch.testing._internal.jit_utils import JitTestCase +import json +import tempfile +from torch.profiler import profile, ProfilerActivity +from torch.profiler._utils import map_recorded_events_to_aten_ops_with_stack_trace +from torch.autograd.profiler_util import _canonicalize_profiler_events + try: from torchvision import models as torchvision_models @@ -201,6 +208,36 @@ def side_effect_func(x: torch.Tensor): print(x) +def _enrich_profiler_traces(prof): + """ + Helper function to extract and augment profiler events with stack traces. + + Args: + prof: A torch.profiler.profile object + + Returns: + A string representing enriched events + """ + with tempfile.NamedTemporaryFile(mode='w', suffix='.json') as f: + trace_file = f.name + prof.export_chrome_trace(trace_file) + + with open(trace_file) as f: + trace_data = json.load(f) + + map_recorded_events_to_aten_ops_with_stack_trace( + trace_data + ) + + events = [] + for event in trace_data["traceEvents"]: + if "args" in event and "stack_trace" in event["args"]: + events.append(event) + + actual_traces = _canonicalize_profiler_events(events) + return actual_traces + + class TestFX(JitTestCase): def setUp(self): super().setUp() @@ -4212,6 +4249,153 @@ def fn(a, b, c, d): # recorver mutable checking flag torch.fx.proxy.TracerBase.check_mutable_operations = orig_tracer_mutable_flag + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @skipIfRocm + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_profiler_stack_trace_augmentation(self): + """ + Test that map_recorded_events_to_aten_ops_with_stack_trace correctly + augments profiler events with stack traces from FX metadata registry. + """ + + # Simple test model + class TestModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(10, 16) + self.relu = torch.nn.ReLU() + self.linear2 = torch.nn.Linear(16, 10) + + def forward(self, x): + x = self.linear1(x) + x = self.relu(x) + x = self.linear2(x) + return x + + model = TestModel().cuda() + + # Compile the model + compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True) + + # Warmup + for _ in range(3): + _ = compiled_model(torch.randn(10, 10, device="cuda")) + + # Profile with the compiled model + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) as prof: + result = compiled_model(torch.randn(10, 10, device="cuda")) + + actual_traces = _enrich_profiler_traces(prof) + + self.assertExpectedInline(actual_traces, """\ +event=aten::t node=t stack_trace=x = self.linear1(x) +event=aten::transpose node=t stack_trace=x = self.linear1(x) +event=aten::as_strided node=t stack_trace=x = self.linear1(x) +event=aten::addmm node=addmm stack_trace=x = self.linear1(x) +event=cudaLaunchKernel node=addmm stack_trace=x = self.linear1(x) +event=aten::relu node=relu stack_trace=x = self.relu(x) +event=aten::clamp_min node=relu stack_trace=x = self.relu(x) +event=cudaLaunchKernel node=relu stack_trace=x = self.relu(x) +event=aten::t node=t_1 stack_trace=x = self.linear2(x) +event=aten::transpose node=t_1 stack_trace=x = self.linear2(x) +event=aten::as_strided node=t_1 stack_trace=x = self.linear2(x) +event=aten::addmm node=addmm_1 stack_trace=x = self.linear2(x) +event=cudaLaunchKernel node=addmm_1 stack_trace=x = self.linear2(x)""" + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @skipIfRocm + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_profiler_multiple_modules(self): + """ + Test that multiple compiled modules under the same profiler session + have their events correctly augmented with stack traces. + """ + + class ModelA(torch.nn.Module): + def forward(self, x): + return x + 1 + + class ModelB(torch.nn.Module): + def forward(self, x): + return x - 1 + + model_a = ModelA().cuda() + model_b = ModelB().cuda() + + # Compile both models + compiled_a = torch.compile(model_a, backend="aot_eager", fullgraph=True) + compiled_b = torch.compile(model_b, backend="aot_eager", fullgraph=True) + + # Warmup + for _ in range(3): + _ = compiled_a(torch.randn(10, 10, device="cuda")) + _ = compiled_b(torch.randn(1, 3, 8, 8, device="cuda")) + + # Profile both models in the same session + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) as prof: + result_a = compiled_a(torch.randn(10, 10, device="cuda")) + result_b = compiled_b(torch.randn(1, 3, 8, 8, device="cuda")) + + actual_traces = _enrich_profiler_traces(prof) + self.assertExpectedInline(actual_traces, """\ +event=aten::add node=add stack_trace=return x + 1 +event=cudaLaunchKernel node=add stack_trace=return x + 1 +event=aten::sub node=sub stack_trace=return x - 1 +event=cudaLaunchKernel node=sub stack_trace=return x - 1""" + ) + + @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") + @skipIfRocm + @torch._dynamo.config.patch("enrich_profiler_metadata", True) + def test_profiler_nested_graph_modules(self): + """ + Test that nested graph modules (e.g., graph modules calling subgraphs) + have their events correctly augmented with stack traces. + """ + + # Model with nested structure + class Mod(torch.nn.Module): + def __init__(self): + super().__init__() + self.c = 5 + + @torch.compiler.nested_compile_region + def forward(self, x, y): + m = torch.mul(x, y) + s = m.sin() + a = s + self.c + return a + + model = Mod().cuda() + + # Compile the model (this may create nested graph modules) + compiled_model = torch.compile(model, backend="aot_eager", fullgraph=True) + + # Warmup + for _ in range(3): + _ = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) + + # Profile + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + ) as prof: + result = compiled_model(torch.randn(10, 10, device="cuda"), torch.randn(10, 10, device="cuda")) + + actual_traces = _enrich_profiler_traces(prof) + self.assertExpectedInline(actual_traces, """\ +event=aten::mul node=mul stack_trace=m = torch.mul(x, y) +event=cudaLaunchKernel node=mul stack_trace=m = torch.mul(x, y) +event=aten::sin node=sin stack_trace=s = m.sin() +event=cudaLaunchKernel node=sin stack_trace=s = m.sin() +event=aten::add node=add stack_trace=a = s + self.c +event=cudaLaunchKernel node=add stack_trace=a = s + self.c""" + ) + def run_getitem_target(): from torch.fx._symbolic_trace import _wrapped_methods_to_patch diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index b2d6530049e61..a61aee321fcff 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -1224,3 +1224,43 @@ def override_time_unit(time_us, default_str, time_unit): f"time total: {override_time_unit(sum_self_device_time_total, _format_time(sum_self_device_time_total), time_unit)}" ) return "".join(result) + + +# Collect all events with stack traces and format them canonically +def _canonicalize_profiler_events(events): + """ + Extract and format all events with stack traces in a canonical way + for deterministic testing. + """ + events_with_traces = [] + + for event in events: + # Extract relevant fields + event_name = event.get("name", "") + node_name = event["args"].get("node_name", "") + stack_trace = event["args"].get("stack_trace", "") + + # Get the last non-empty line of the stack trace + lines = [s.strip() for s in stack_trace.split("\n") if s.strip()] + stack_trace = lines[-1] if lines else "" + + events_with_traces.append( + { + "event_name": event_name[:20], + "node_name": node_name, + "stack_trace": stack_trace, + "start_time": event.get("ts", 0), + } + ) + + # Sort by node_name for deterministic ordering + events_with_traces.sort(key=lambda x: x["start_time"]) + + # Format as a string + lines: list[str] = [] + for evt in events_with_traces: + lines.append( + f"event={evt['event_name']} node={evt['node_name']} stack_trace={evt['stack_trace']}" + ) + + return "\n".join(lines) diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 899a50f0f4142..d924eac24d3c2 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -443,6 +443,7 @@ def _gen_python_code( colored: bool = False, # Render each argument on its own line expanded_def: bool = False, + record_func: bool = False, ) -> PythonCode: free_vars: list[str] = [] body: list[str] = [] @@ -817,6 +818,10 @@ def _tensor_annotation(t: torch.Tensor) -> str: return raise NotImplementedError(f"node: {node.op} {node.target}") + if record_func: + body.append( + "_rf = torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##'); _rf.__enter__()\n" + ) for i, node in enumerate(nodes): # NOTE: emit_node does not emit a string with newline. It depends # on delete_unused_values to append one @@ -826,8 +831,22 @@ def _tensor_annotation(t: torch.Tensor) -> str: # node index, which will be deleted later # after going through _body_transformer body.append(f"# COUNTER: {i}\n") + do_record = record_func and node.op in ( + "call_function", + "call_method", + "call_module", + ) + if do_record: + # The double hash ## convention is used by post-processing to find the fx markers + body.append( + f"_rf_{node.name} = torch._C._profiler._RecordFunctionFast('## {i} ##'); _rf_{node.name}.__enter__()\n" + ) emit_node(node) delete_unused_values(node) + if do_record: + body.append(f"_rf_{node.name}.__exit__(None, None, None)\n") + if record_func: + body.append("_rf.__exit__(None, None, None)\n") if len(body) == 0: # If the Graph has no non-placeholder nodes, no lines for the body @@ -1779,6 +1798,7 @@ def python_code( include_device: bool = False, colored: bool = False, expanded_def: bool = False, + record_func: bool = False, ) -> PythonCode: """ Turn this ``Graph`` into valid Python code. @@ -1846,6 +1866,7 @@ def override_node_repr(graph: Graph): include_device=include_device, colored=colored, expanded_def=expanded_def, + record_func=record_func, ) def _python_code( @@ -1858,6 +1879,7 @@ def _python_code( include_device: bool = False, colored: bool = False, expanded_def: bool = False, + record_func: bool = False, ) -> PythonCode: return self._codegen._gen_python_code( self.nodes, @@ -1868,6 +1890,7 @@ def _python_code( include_device=include_device, colored=colored, expanded_def=expanded_def, + record_func=record_func, ) def __str__(self) -> str: diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 297f76732584f..8360c96630d6c 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -861,14 +861,18 @@ def recompile(self) -> PythonCode: if isinstance(self._graph._codegen, _PyTreeCodeGen): self._in_spec = self._graph._codegen.pytree_info.in_spec self._out_spec = self._graph._codegen.pytree_info.out_spec - python_code = self._graph.python_code(root_module="self") + + from torch._dynamo import config as dynamo_config + + python_code = self._graph.python_code( + root_module="self", record_func=dynamo_config.enrich_profiler_metadata + ) self._code = python_code.src self._lineno_map = python_code._lineno_map self._prologue_start = python_code._prologue_start cls = type(self) co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} - from torch._dynamo import config as dynamo_config if dynamo_config.enrich_profiler_metadata: # Generate metadata and register for profiler augmentation @@ -885,7 +889,6 @@ def recompile(self) -> PythonCode: # This ensures the same code+metadata always generates the same filename hash_value = _metadata_hash(self._code, node_metadata) file_stem = f"{FX_GRAPH_MODULE_FILE_PREFIX}_{hash_value}" - filename = f"{file_stem}.py" # Only include co_filename to use it directly as the cache key @@ -905,6 +908,13 @@ def recompile(self) -> PythonCode: _register_fx_metadata(filename, metadata) + # Replace the placeholder in generated code with actual filename + # The double hash ## convention is used by post-processing to find the fx markers + self._code = self._code.replace( + "torch._C._profiler._RecordFunctionFast('## ENTER_GRAPH_PLACEHOLDER_KEY ##')", + f"torch._C._profiler._RecordFunctionFast('## {filename} ##')", + ) + cls.forward = _forward_from_src(self._code, python_code.globals, co_fields) # Determine whether this class explicitly defines a __call__ implementation diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index 2c6e06b2cb3c9..47df87ce1678d 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -4,7 +4,7 @@ import re from collections import deque from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import Any, Literal, Optional, TYPE_CHECKING from torch.autograd.profiler import profile from torch.profiler import DeviceType @@ -400,3 +400,170 @@ def _init_for_cuda_graphs() -> None: with profile(): pass + + +@dataclass +class TimelineEvent: + """Represents an event in the profiler timeline.""" + + timestamp: int + event_type: Literal["start", "end", "regular"] + marker_type: Optional[Literal["filename", "node"]] + identifier: Optional[str | int] + event: dict[str, Any] + + +@dataclass +class ContextStackEntry: + """Represents a context (filename or node) in the stack.""" + + context_type: Literal["filename", "node"] + identifier: str | int + metadata: Optional[dict] + tid: Optional[int] = None # Thread ID associated with this context + + +def map_recorded_events_to_aten_ops_with_stack_trace(traced_data): + """ + Maps recorded profiler events to their corresponding fx nodes and adds stack traces. + + Builds a timeline of all events (regular ops and FX markers for filenames/nodes), + sorts by timestamp, then processes chronologically while maintaining a context stack of active + filename/node scopes. Regular events are augmented with stack traces and node names from the + innermost active context. Runtime is O(n log n) for n events. + + Args: + traced_data: Json of profiler events from Chrome trace + + Returns: + Dict mapping recorded event names to their aten operations with added stack traces + """ + from torch.fx.traceback import _FX_METADATA_REGISTRY + + trace_events = traced_data.get("traceEvents", []) + + # Create event timeline + event_timeline: list[TimelineEvent] = [] + + def is_fx_marker_event(event): + return ( + event.get("cat") == "cpu_op" + and event.get("name", "").startswith("## ") + and event.get("name", "").endswith(" ##") + ) + + def append_fx_marker_event(event_type, identifier, event): + start_ts = event["ts"] + end_ts = start_ts + event["dur"] + event_timeline.append( + TimelineEvent(start_ts, "start", event_type, identifier, event) + ) + event_timeline.append( + TimelineEvent(end_ts, "end", event_type, identifier, event) + ) + + for event in trace_events: + if "ts" not in event or "dur" not in event: + continue + + if is_fx_marker_event(event): + content = event["name"][3:-3] + + if content.endswith(".py"): + append_fx_marker_event("filename", content, event) + else: + try: + node_index = int(content) + except ValueError: + pass + append_fx_marker_event("node", node_index, event) # type: ignore[possibly-undefined] + + else: + # Regular event that needs augmentation + start_ts = event["ts"] + event_timeline.append(TimelineEvent(start_ts, "regular", None, None, event)) + + # Sort by timestamp + event_timeline.sort(key=lambda x: x.timestamp) + + # Process events in chronological order with a stack + context_stack: list[ContextStackEntry] = [] + + # Invariant: all start event has a corresponding end event + for timeline_event in event_timeline: + match timeline_event.event_type: + case "start": + assert timeline_event.identifier is not None + + if timeline_event.marker_type == "filename": + assert isinstance(timeline_event.identifier, str) + # Push filename context - query metadata registry on-demand + metadata = _FX_METADATA_REGISTRY.get(timeline_event.identifier) + tid = timeline_event.event.get("tid") + context_stack.append( + ContextStackEntry( + "filename", timeline_event.identifier, metadata, tid + ) + ) + elif timeline_event.marker_type == "node": + # Find the current filename from stack + current_file_metadata = None + tid = timeline_event.event.get("tid") + for ctx_entry in reversed(context_stack): + if ( + ctx_entry.context_type == "filename" + and ctx_entry.tid == tid + ): + current_file_metadata = ctx_entry.metadata + break + + if current_file_metadata: + node_metadata = current_file_metadata.get("node_metadata", {}) + if timeline_event.identifier in node_metadata: + node_meta: Optional[dict] = node_metadata[ + timeline_event.identifier + ] + context_stack.append( + ContextStackEntry( + "node", timeline_event.identifier, node_meta, tid + ) + ) + + case "end": + # Pop from stack - search backwards to find matching context + for i in range(len(context_stack) - 1, -1, -1): + ctx_entry = context_stack[i] + if ( + timeline_event.marker_type == ctx_entry.context_type + and timeline_event.identifier == ctx_entry.identifier + ): + context_stack.pop(i) + break + + case "regular": + # Apply metadata from current context stack + # Find the most specific context (node takes precedence over filename) + # Only augment events with the same tid as the file/node event matched + current_stack_trace = None + current_node_name = None + event_tid = timeline_event.event.get("tid") + + for ctx_entry in reversed(context_stack): + # Only apply metadata from contexts with matching tid + if ctx_entry.tid == event_tid: + if ctx_entry.context_type == "node" and ctx_entry.metadata: + current_stack_trace = ctx_entry.metadata.get( + "stack_trace", "No model stack trace available" + ) + current_node_name = ctx_entry.metadata.get("name", "") + # Do we want to only attach the stack trace of the lowest node or stack trace of all nodes + # if nodes are nested, e.g. in nested graph modules + break + + # Augment the event + if current_stack_trace or current_node_name: + args = timeline_event.event.setdefault("args", {}) + if current_stack_trace: + args["stack_trace"] = current_stack_trace + if current_node_name: + args["node_name"] = current_node_name From 7432676187178fcdb41a0685b078e97e436fc561 Mon Sep 17 00:00:00 2001 From: inventshah <39803835+inventshah@users.noreply.github.com> Date: Thu, 6 Nov 2025 01:55:38 +0000 Subject: [PATCH 100/130] [MPS] Fix crash in BCELoss backwards with reduction="none" and inputs with trailing 1s in shape (#166786) Fixes #166746 by removing squeezes that caused shape mismatches when calling backwards through `BCELoss(reduction='none')`. Based on running these tests, it seems MPSGraph can handle inputs without squeezing. ``` python test/test_mps.py TestMPS -k test_bce python test/test_mps.py TestConsistency -k binary_cross ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166786 Approved by: https://github.com/malfet --- .../src/ATen/native/mps/operations/LossOps.mm | 19 +++++++------------ test/test_mps.py | 8 ++++++++ 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/aten/src/ATen/native/mps/operations/LossOps.mm b/aten/src/ATen/native/mps/operations/LossOps.mm index c995b8fc237f3..f0bbcdabfa5cd 100644 --- a/aten/src/ATen/native/mps/operations/LossOps.mm +++ b/aten/src/ATen/native/mps/operations/LossOps.mm @@ -212,17 +212,12 @@ loss.resize_((reduction == Reduction::None || grad_output.defined()) ? target.sizes() : IntArrayRef({})); TORCH_CHECK(loss.is_mps()); - Tensor loss_squeezed = loss.squeeze(); - Tensor input_squeezed = input.squeeze(); - Tensor target_squeezed = target.squeeze(); - @autoreleasepool { - std::string key = - op_name + reductionToString(reduction) + getTensorsStringKey({input_squeezed, target_squeezed, weight}); + std::string key = op_name + reductionToString(reduction) + getTensorsStringKey({input, target, weight}); auto cachedGraph = LookUpOrCreateCachedGraph(key, [&](auto mpsGraph, auto newCachedGraph) { - newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input_squeezed); - newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target_squeezed); + newCachedGraph->inputTensor = mpsGraphRankedPlaceHolder(mpsGraph, input); + newCachedGraph->targetTensor = mpsGraphRankedPlaceHolder(mpsGraph, target); MPSGraphTensor* bceLossUnweighted = nil; // if grad_output is defined, then it's a backward pass @@ -252,12 +247,12 @@ newCachedGraph->gradInputTensor = bceLoss; } } else { - newCachedGraph->lossTensor = reduceTensor(bceLoss, reduction, mpsGraph, input_squeezed.sizes().size()); + newCachedGraph->lossTensor = reduceTensor(bceLoss, reduction, mpsGraph, input.sizes().size()); } }); - Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input_squeezed); - Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor, target_squeezed); - Placeholder lossPlaceholder = Placeholder(cachedGraph->lossTensor, loss_squeezed); + Placeholder inputPlaceholder = Placeholder(cachedGraph->inputTensor, input); + Placeholder targetPlaceholder = Placeholder(cachedGraph->targetTensor, target); + Placeholder lossPlaceholder = Placeholder(cachedGraph->lossTensor, loss); NSMutableDictionary* feeds = [[NSMutableDictionary new] autorelease]; diff --git a/test/test_mps.py b/test/test_mps.py index fad09c2f5eb28..cb0db4d96d334 100644 --- a/test/test_mps.py +++ b/test/test_mps.py @@ -4470,6 +4470,14 @@ def test_bce_loss_broadcasts_weights(self): self.assertEqual(out1, out2) + def test_bce_backward_with_no_reduction_and_one_in_shape(self): + # Regression test for https://github.com/pytorch/pytorch/issues/166746 + output = torch.zeros(3, 2, 1, requires_grad=True, device='mps') + target = torch.zeros(3, 2, 1, device='mps') + torch.sum(nn.BCELoss(reduction='none')(output, target)).backward() + expected_grad = torch.zeros(3, 2, 1, device='mps') + self.assertEqual(output.grad, expected_grad) + def test_cross_entropy_loss(self): # Regression test for https://github.com/pytorch/pytorch/issues/116095 loss = nn.CrossEntropyLoss() From 69af74972b30d748266323c8099be5743b4c0b72 Mon Sep 17 00:00:00 2001 From: Samuel Park Date: Thu, 6 Nov 2025 01:59:48 +0000 Subject: [PATCH 101/130] Bugfix to forward autodiff causing different datatype 2 (#165784) Fixes #160513 ## The Problem Summary The issue boiled down to data type promotion logic. The code base has two different functions that deal with dtype promotion logic. If it is purely multi-dimensional tensor operations, the cpp code gets triggered and that follows the numpy dtype promotion logic. That is why in #160513 NDim tensors are fine as NDim dtypes gets precedence. The issue came with python scalars and 0Dim tensors. When it detects "scalars", a python implementation of dtype promotion logic gets triggered (torch/_prims_common/__init__.py:1544). Since this is in python, the implementation can't distinguish what is from a wrapped tensor and a 0Dim tensor and thus will just take the highest dtype which is the python double wrapped number. ## The Fix The python implementation for dtype promotion had to know where the scalar came from. Once the scalar can be distinguished then the appropriate dtype can be set. The first approach was to try and expose the `is_wrapped_number` method but this came with a big issue. During the `forward_ad` the derivative of those scalars turned out to be `ZeroTensor`s. The `ZeroTensor` internally uses a hack to initialize a meta dtype tensor which skips expensive dispatch operations. But the copy would not grab everything especially the `is_number_wrapped_` property. I thought about modifying the copy but that seemed to go away from the spirit of what the copy was intended for and plus the tests for `is_wrapped_number_` requires `dim > 0` and a scalar `ZeroTensor` is a meta dtype tensor which complicates things. So I chose the route of creating a new property called `was_wrapped_number` and exposed this property to the python tensor API. I had to modify the autograd code generation to set `was_wrapped_number` in the mul, add, and div operations in `VariableType.cpp`. Once this property was set, the dtype promotion logic could be updated to consider wrapped numbers and 0Dim numbers. Once that hierarchy was taken care of, the buggy behavior was fixed. I wrote a new ops testing module `TestForwardADWithScalars`. I saw that this bug was unique and required new testing paradigm. This only tests the multiply, add, and divide and I chose this because all operations boil down to these three operations. [edit]: Just used `efficientzerotensor` meta and converted that to a python number. Since wrapped number is converted back to a python number, dtype promotion is preserved. The constraint to achieve this happened by setting the forward grad zero tensor of a wrapped number with a wrapped number flag since the tangent of the wrapped number should still be a wrapped number. After that this specific zerotensor was then sent through as a meta type in the `BinaryOps.cpp` to get appropriate dtype for resulting arithmetic. @ezyang @OihanJoyot Pull Request resolved: https://github.com/pytorch/pytorch/pull/165784 Approved by: https://github.com/ezyang --- aten/src/ATen/native/BinaryOps.cpp | 24 +++++++++++++--- test/test_ops.py | 38 +++++++++++++++++++++++++ tools/autograd/gen_variable_type.py | 13 +++++++++ torch/csrc/autograd/FunctionsManual.cpp | 6 ++++ torch/csrc/autograd/FunctionsManual.h | 1 + torch/csrc/jit/python/pybind_utils.cpp | 26 ++++++++++++----- 6 files changed, 97 insertions(+), 11 deletions(-) diff --git a/aten/src/ATen/native/BinaryOps.cpp b/aten/src/ATen/native/BinaryOps.cpp index f5d5edb6439a6..2fa6bcc6dc9ac 100644 --- a/aten/src/ATen/native/BinaryOps.cpp +++ b/aten/src/ATen/native/BinaryOps.cpp @@ -1009,12 +1009,25 @@ static Device correct_out_device(const Tensor& self, const Tensor& other) { } } +static Tensor send_to_meta(const Tensor& self, const Device& device) { + Tensor out_meta; + if (self._is_zerotensor() && self.unsafeGetTensorImpl()->is_wrapped_number()) { + out_meta = at::_efficientzerotensor(self.sizes(), self.options().device(device)); + out_meta.unsafeGetTensorImpl()->set_wrapped_number(true); + } else { + out_meta = self.to(device); + } + return out_meta; +} + Tensor mul_zerotensor(const Tensor& self, const Tensor& other) { auto out_device = correct_out_device(self, other); // hack to use the TensorIterator to get the correct broadcasting and type promotion logic auto device_ = Device(DeviceType::Meta); constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta); - auto meta_out = at::_ops::mul_Tensor::redispatch(meta_dks, self.to(device_), other.to(device_)); + auto self_meta = send_to_meta(self, device_); + auto other_meta = send_to_meta(other, device_); + auto meta_out = at::_ops::mul_Tensor::redispatch(meta_dks, self_meta, other_meta); return at::_efficientzerotensor(meta_out.sizes(), meta_out.options().device(out_device)); } @@ -1023,7 +1036,9 @@ Tensor div_zerotensor(const Tensor& self, const Tensor& other) { // hack to use the TensorIterator to get the correct broadcasting and type promotion logic auto device_ = Device(DeviceType::Meta); constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta); - auto meta_out = at::_ops::div_Tensor::redispatch(meta_dks, self.to(device_), other.to(device_)); + auto self_meta = send_to_meta(self, device_); + auto other_meta = send_to_meta(other, device_); + auto meta_out = at::_ops::div_Tensor::redispatch(meta_dks, self_meta, other_meta); if (self._is_zerotensor()) { if (other._is_zerotensor()) { @@ -1052,8 +1067,9 @@ static Tensor maybe_add_maybe_sub(const Tensor& self, const Tensor& other, const // hack to use the TensorIterator to get the correct broadcasting and type promotion logic auto device_ = Device(DeviceType::Meta); constexpr c10::DispatchKeySet meta_dks(at::DispatchKey::Meta); - auto meta_out = at::_ops::add_Tensor::redispatch( - meta_dks, self.to(device_), other.to(device_), alpha); + auto self_meta = send_to_meta(self, device_); + auto other_meta = send_to_meta(other, device_); + auto meta_out = at::_ops::add_Tensor::redispatch(meta_dks, self_meta, other_meta, alpha); auto get_out_like = [&] (const Tensor& tensor) { diff --git a/test/test_ops.py b/test/test_ops.py index 165b284b76d5c..5f44a3ba0841b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -2992,12 +2992,50 @@ def test_strided_layout(self, device, dtype, op): self.assertEqual(strided_result.layout, torch.strided) +class TestForwardADWithScalars(TestCase): + @ops( + [op for op in op_db if op.name in ["mul", "add", "div"]], + allowed_dtypes=(torch.float32,), + ) + def test_0d_tensor_with_python_scalar(self, device, dtype, op): + """Test that forward AD preserves dtype when combining 0D tensors with Python scalars.""" + if torch.float not in op.supported_backward_dtypes(device): + raise unittest.SkipTest("Does not support autograd") + + # skip if operator doesnt support forward AD + if not op.supports_forward_ad: + raise unittest.SkipTest("Does not support forward_ad") + + # create 0D tensors + primal0d = torch.ones((), device=device, dtype=dtype) + tangent0d = torch.ones((), device=device, dtype=dtype) + + with torch.autograd.forward_ad.dual_level(): + dual0d = torch.autograd.forward_ad.make_dual(primal0d, tangent0d) + + # Test with scalar on RHS + if op.supports_rhs_python_scalar: + result = op(dual0d, 2.0) + p, t = torch.autograd.forward_ad.unpack_dual(result) + self.assertEqual( + p.dtype, t.dtype, f"{op.name} and scalar on RHS - dtype mismatch" + ) + # Test with scalar on LHS + if op.supports_one_python_scalar: + result = op(2.0, dual0d) + p, t = torch.autograd.forward_ad.unpack_dual(result) + self.assertEqual( + p.dtype, t.dtype, f"{op.name} and scalar on LHS - dtype mismatch" + ) + + instantiate_device_type_tests(TestCommon, globals(), allow_xpu=True) instantiate_device_type_tests(TestCompositeCompliance, globals()) instantiate_device_type_tests(TestMathBits, globals()) instantiate_device_type_tests(TestRefsOpsInfo, globals(), only_for="cpu") instantiate_device_type_tests(TestFakeTensor, globals()) instantiate_device_type_tests(TestTags, globals()) +instantiate_device_type_tests(TestForwardADWithScalars, globals()) if __name__ == "__main__": TestCase._default_dtype_check_enabled = True diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 13ca3e1389ac1..4796153f24f05 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -763,6 +763,12 @@ """ ) +FW_DERIVATIVE_UPDATE_WRAPPED_NUM_TEMPLATE = CodeTemplate( + """\ +update_wrapped_number(${inp_name}_tensor, ${inp_name}_t); +""" +) + FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE = CodeTemplate( """\ auto ${inp_name}_p = toNonOptPrimal(${inp}); @@ -1911,6 +1917,13 @@ def emit_fw_derivatives() -> list[str]: zeros_fn=zeros_fn, ) ) + if zeros_fn == "_efficientzerotensor_symint": + unpacked_arguments += ( + FW_DERIVATIVE_UPDATE_WRAPPED_NUM_TEMPLATE.substitute( + inp_name=inp.name + ) + ) + if inp.name in (derivative.required_inputs_primal or []): unpacked_arguments += ( FW_DERIVATIVE_DEFINED_PRIMAL_TEMPLATE.substitute( diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 42d701298b0d1..b3cb07ac1cf9f 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -79,6 +79,12 @@ Tensor toNonOptPrimal(const std::optional& t) { return Tensor(); } +void update_wrapped_number(Tensor& input, Tensor& output) { + if (input.unsafeGetTensorImpl()->is_wrapped_number()) { + output.unsafeGetTensorImpl()->set_wrapped_number(true); + } +} + void copy_range(variable_list& out, IndexRange range, const Tensor& t) { TORCH_CHECK(range.second <= out.size()); TORCH_CHECK( diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 4dc0425d426ec..ee0f919c44012 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -43,6 +43,7 @@ inline std::optional wrap_opt_if(const Tensor& t, const bool cond) { TORCH_API Tensor apply_loss_reduction(const Tensor& unreduced, int64_t reduction); TORCH_API bool any_variable_defined(const variable_list& variables); +TORCH_API void update_wrapped_number(Tensor& input, Tensor& output); TORCH_API void copy_range( variable_list& out, IndexRange range, diff --git a/torch/csrc/jit/python/pybind_utils.cpp b/torch/csrc/jit/python/pybind_utils.cpp index d60a6a0990082..9f7c2756d0d73 100644 --- a/torch/csrc/jit/python/pybind_utils.cpp +++ b/torch/csrc/jit/python/pybind_utils.cpp @@ -587,7 +587,9 @@ py::object toPyObject(IValue ivalue) { } else if (ivalue.isTensor()) { auto tensor = std::move(ivalue).toTensor(); if (tensor.unsafeGetTensorImpl()->is_wrapped_number()) { - TORCH_INTERNAL_ASSERT(tensor.device().is_cpu()); + TORCH_INTERNAL_ASSERT( + tensor.device().is_cpu() || + (tensor._is_zerotensor() && tensor.dim() == 0)); auto py_tensor = py::cast(tensor); if (PyObject_HasAttrString(py_tensor.ptr(), "_wrapped_number")) { return py_tensor.attr("_wrapped_number"); @@ -595,17 +597,27 @@ py::object toPyObject(IValue ivalue) { auto scalar_type = tensor.scalar_type(); switch (scalar_type) { case at::ScalarType::Bool: - return py::cast(*tensor.const_data_ptr()); + return (tensor._is_zerotensor()) + ? py::cast(false) + : py::cast(*tensor.const_data_ptr()); case at::ScalarType::Long: - return py::cast(*tensor.const_data_ptr()); + return (tensor._is_zerotensor()) + ? py::cast(int64_t(0)) + : py::cast(*tensor.const_data_ptr()); case at::ScalarType::UInt64: - return py::cast(*tensor.const_data_ptr()); + return (tensor._is_zerotensor()) + ? py::cast(uint64_t(0)) + : py::cast(*tensor.const_data_ptr()); case at::ScalarType::Double: - return py::cast(*tensor.const_data_ptr()); + return (tensor._is_zerotensor()) + ? py::cast(0.0) + : py::cast(*tensor.const_data_ptr()); case at::ScalarType::ComplexDouble: // TODO: https://github.com/pytorch/pytorch/issues/77134 - return py::cast(static_cast>( - *tensor.const_data_ptr>())); + return (tensor._is_zerotensor()) + ? py::cast(std::complex(0.0, 0.0)) + : py::cast(static_cast>( + *tensor.const_data_ptr>())); default: TORCH_CHECK( false, From 3a2d75a0869b3ef2344ab0501c19787924442c3e Mon Sep 17 00:00:00 2001 From: Andrey Talman Date: Thu, 6 Nov 2025 02:01:57 +0000 Subject: [PATCH 102/130] Change template 'Release highlight for proposed Feature'->'New Feature for Release' (#167145) Makes it simpler and more clear Pull Request resolved: https://github.com/pytorch/pytorch/pull/167145 Approved by: https://github.com/huydhn --- .github/ISSUE_TEMPLATE/release-feature-request.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/ISSUE_TEMPLATE/release-feature-request.yml b/.github/ISSUE_TEMPLATE/release-feature-request.yml index 80f10807ae56b..090a41d1942f6 100644 --- a/.github/ISSUE_TEMPLATE/release-feature-request.yml +++ b/.github/ISSUE_TEMPLATE/release-feature-request.yml @@ -1,11 +1,11 @@ -name: 🚀 Release highlight for proposed Feature +name: 🚀 New Feature for Release description: Submit a Release highlight for proposed Feature labels: ["release-feature-request"] body: - type: textarea attributes: - label: Release highlight for proposed Feature + label: New Feature for Release description: > Example: “A torch.special module, analogous to SciPy's special module.” - type: input From 943227f57bcd638ab288331442748769f907d8c1 Mon Sep 17 00:00:00 2001 From: "Junjie Wang (PyTorch)" Date: Thu, 6 Nov 2025 02:08:01 +0000 Subject: [PATCH 103/130] [c10d] Fix split_group bug by having the parent pg option deep copied (#167125) Summary: Inside group_split api, we share the reference of PG option with parent PG if a PG option is not explicitly specified. This is bad because if we split parent pg multiple times, we will run into errors. Test Plan: UT + internal test. Differential Revision: D86225394 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167125 Approved by: https://github.com/Skylion007 --- test/distributed/test_c10d_nccl.py | 8 ++++++++ torch/distributed/distributed_c10d.py | 5 ++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index d764dfbbebbb1..ef7ed5282816f 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -6343,6 +6343,14 @@ def test_comm_recursive_split_group(self): if self.rank == 6 or self.rank == 7: dist.broadcast(tensor2, 6, group=ng2) self.assertEqual(tensor2, torch.full((1,), 6)) + + # Test the case when the split changes the pg option of split group + # while the parent pg option is not changed. + new_pg = c10d.new_group([0, 1, 2, 3, 4, 5, 6, 7], device_id=device) + backend_new_pg = new_pg._get_backend(torch.device(device)) + self.assertEqual(len(backend_new_pg.options.global_ranks_in_group), 8) + c10d.split_group(new_pg, [[0, 2, 4, 6], [1, 3, 5, 7]]) + self.assertEqual(len(backend_new_pg.options.global_ranks_in_group), 8) # a barrier and a cuda sync before destroying all pgs. dist.barrier(pg) torch.cuda.synchronize() diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 9e4ec1483e960..415cbacc177a8 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -3,6 +3,7 @@ import collections.abc import contextlib +import copy import ctypes import hashlib import io @@ -5212,7 +5213,9 @@ def split_group( if pg_options is None: # default pg_options same as the parent process group - pg_options = parent_backend.options + # A deep copy is needed because if the option will be modified inside split + # and if we split parent pg multiple times, we will run into device out of bound error. + pg_options = copy.deepcopy(parent_backend.options) # this timeout defaulting/validation is used for all the new_groups/new_subgroups variants, # which may just pass their timeout value (or None) From e1a1aeaf5b951e4eb9ce49756311e8f59cf29eb8 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Thu, 6 Nov 2025 02:25:10 +0000 Subject: [PATCH 104/130] [1/N] Use `key in dict` for existence checks (#167035) This PR uses `key in dict` expressions for existence checks of dict elements in Python code. This operation is more efficient than `key in dict.keys()`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167035 Approved by: https://github.com/janeyx99 --- torch/_dynamo/backends/registry.py | 2 +- torch/_dynamo/output_graph.py | 2 +- torch/_export/converter.py | 2 +- torch/_guards.py | 2 +- torch/_inductor/augmented_graph_helper.py | 2 +- torch/_inductor/bounds.py | 2 +- torch/_inductor/codecache.py | 2 +- torch/_inductor/codegen/cpp_wrapper_gpu.py | 4 ++-- torch/_inductor/codegen/triton_combo_kernel.py | 2 +- torch/_inductor/cpu_vec_isa.py | 2 +- torch/_inductor/cudagraph_utils.py | 2 +- torch/_inductor/graph.py | 2 +- torch/_inductor/ir.py | 2 +- torch/_inductor/memory.py | 2 +- torch/_inductor/select_algorithm.py | 4 +--- torch/_inductor/tiling_utils.py | 2 +- torch/_library/infer_schema.py | 2 +- torch/_numpy/_dtypes.py | 2 +- torch/_numpy/_ndarray.py | 6 +++--- .../backend_config/_common_operator_config_utils.py | 2 +- torch/ao/quantization/pt2e/prepare.py | 4 ++-- torch/autograd/profiler_legacy.py | 4 ++-- torch/cuda/_device_limits.py | 4 ++-- torch/distributed/_serialization.py | 2 +- torch/distributed/checkpoint/_consolidate_hf_safetensors.py | 2 +- torch/distributed/checkpoint/quantized_hf_storage.py | 2 +- torch/distributed/checkpoint/state_dict.py | 2 +- torch/distributed/elastic/multiprocessing/tail_log.py | 2 +- torch/distributed/fsdp/_optim_utils.py | 2 +- torch/distributed/pipelining/schedules.py | 4 ++-- torch/distributed/tensor/_ops/_pointwise_ops.py | 2 +- torch/export/dynamic_shapes.py | 2 +- torch/fx/experimental/sym_node.py | 2 +- torch/fx/node.py | 2 +- torch/nested/_internal/ops.py | 2 +- torch/optim/swa_utils.py | 4 ++-- torch/serialization.py | 4 ++-- torch/testing/_internal/distributed/distributed_test.py | 6 +++--- torch/testing/_internal/distributed/rpc/rpc_test.py | 2 +- 39 files changed, 50 insertions(+), 52 deletions(-) diff --git a/torch/_dynamo/backends/registry.py b/torch/_dynamo/backends/registry.py index 706ec1768cd35..1469ca478a386 100644 --- a/torch/_dynamo/backends/registry.py +++ b/torch/_dynamo/backends/registry.py @@ -146,7 +146,7 @@ def list_backends(exclude_tags=("debug", "experimental")) -> list[str]: # type: backends = [ name - for name in _BACKENDS.keys() + for name in _BACKENDS if name not in _COMPILER_FNS or not exclude_tags_set.intersection(_COMPILER_FNS[name]._tags) # type: ignore[attr-defined] ] diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 77f5d6cb05a01..50a2667c12a25 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -2587,7 +2587,7 @@ def update_used_symbols( real_script_obj ): flat_dict = dict(real_script_obj.__obj_flatten__()) # type: ignore[attr-defined] - for attr in flat_dict.keys(): + for attr in flat_dict: fake_attr_val = getattr( fake_script_obj.wrapped_obj, attr ) diff --git a/torch/_export/converter.py b/torch/_export/converter.py index 89b6e3297933f..58de4fd20c953 100644 --- a/torch/_export/converter.py +++ b/torch/_export/converter.py @@ -443,7 +443,7 @@ def __init__( self.blocks_to_lifted_attrs = blocks_to_lifted_attrs # Populate methods for the standard operators. - for k in kind_to_standard_operators.keys(): + for k in kind_to_standard_operators: handler_func_name = ir_name_to_func_name(k) # Create an indirect function call: # convert__ --> lambda node: _convert_standard_operator(node) diff --git a/torch/_guards.py b/torch/_guards.py index b321c5f968b16..32b796d71eea7 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -904,7 +904,7 @@ def patch(**kwargs: Any) -> Generator[None, None, None]: prior = {} ctx = TracingContext.get() - for key in kwargs.keys(): + for key in kwargs: # KeyError on invalid entry prior[key] = getattr(ctx, key) for key, val in kwargs.items(): diff --git a/torch/_inductor/augmented_graph_helper.py b/torch/_inductor/augmented_graph_helper.py index 81dca605940e5..5a70a34f7b64b 100644 --- a/torch/_inductor/augmented_graph_helper.py +++ b/torch/_inductor/augmented_graph_helper.py @@ -164,7 +164,7 @@ def transfer_erased_node_deps(self, erased_to_new: dict[fx.Node, fx.Node]) -> No self.extra_uses[new_node].add(updated_use) # Clean up erased nodes - for old_node in erased_merge_sets.keys(): + for old_node in erased_merge_sets: self.extra_deps[old_node].clear() self.extra_uses[old_node].clear() del self.merge_sets[old_node] diff --git a/torch/_inductor/bounds.py b/torch/_inductor/bounds.py index a227239356a61..bc8dba5119252 100644 --- a/torch/_inductor/bounds.py +++ b/torch/_inductor/bounds.py @@ -86,7 +86,7 @@ def swap_submodules( self, submodules: dict[str, Callable[..., Any]] ) -> dict[str, Callable[..., ValueRanges[Expr]]]: result: dict[str, Callable[..., ValueRanges[Expr]]] = {} - for key in submodules.keys(): + for key in submodules: if key == "get_index": result[key] = self.get_index elif "masked_subblock" in key: diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index cf17bf2e9478b..9583494299265 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -1681,7 +1681,7 @@ def set( if config.aot_inductor.emit_multi_arch_kernel: bin_type_to_ext = {"cubin": ".fatbin", "spv": ".spv"} - assert bin_type in bin_type_to_ext.keys(), ( + assert bin_type in bin_type_to_ext, ( "multi_arch_kernel_binary only supported in CUDA/XPU" ) base_path, _ = os.path.splitext(bin_path) diff --git a/torch/_inductor/codegen/cpp_wrapper_gpu.py b/torch/_inductor/codegen/cpp_wrapper_gpu.py index 02129fff24160..fad4ce84f2971 100644 --- a/torch/_inductor/codegen/cpp_wrapper_gpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_gpu.py @@ -337,7 +337,7 @@ def process_args_for_input_shape(arg, arg_type, arg_signature=None): elif ( isinstance(arg_type, type(SymbolicCallArg)) and arg_signature is not None - and arg_signature in signature2dtype.keys() + and arg_signature in signature2dtype ) or arg_type in (sympy.Integer, int, sympy.Float, float): write_dummy_scalar_ivalue(arg_name) elif arg_signature and arg_signature.startswith("tensordesc<"): @@ -719,7 +719,7 @@ def process_args(arg, arg_type, arg_signature=None): elif ( isinstance(arg_type, type(SymbolicCallArg)) and arg_signature is not None - and arg_signature in signature2dtype.keys() + and arg_signature in signature2dtype ): code.writeline( f"{signature2dtype[arg_signature]} {var_name} = {cexpr(arg)};" diff --git a/torch/_inductor/codegen/triton_combo_kernel.py b/torch/_inductor/codegen/triton_combo_kernel.py index 1f531a5d99ef5..41b12d05cd32e 100644 --- a/torch/_inductor/codegen/triton_combo_kernel.py +++ b/torch/_inductor/codegen/triton_combo_kernel.py @@ -699,7 +699,7 @@ def get_block_args(self) -> list[ConstexprArg]: block_names[f"{tree.prefix.upper()}BLOCK"] = tree.prefix self.block_args = list(block_names.keys()) - return [ConstexprArg(x) for x in block_names.keys()] + return [ConstexprArg(x) for x in block_names] def add_numel_to_args( self, argdefs: list[ArgName], signature: list[Any] diff --git a/torch/_inductor/cpu_vec_isa.py b/torch/_inductor/cpu_vec_isa.py index 515f628c9938c..1c4a394d1eb28 100644 --- a/torch/_inductor/cpu_vec_isa.py +++ b/torch/_inductor/cpu_vec_isa.py @@ -430,7 +430,7 @@ def get_isa_from_cpu_capability( "avx2": "avx2", "avx512": "avx512", } - if capability in capability_to_isa_str.keys(): + if capability in capability_to_isa_str: # pyrefly: ignore [index-error] isa_str = capability_to_isa_str[capability] if isa_str == "INVALID_VEC_ISA": diff --git a/torch/_inductor/cudagraph_utils.py b/torch/_inductor/cudagraph_utils.py index 668becdded469..50d986d48e6c2 100644 --- a/torch/_inductor/cudagraph_utils.py +++ b/torch/_inductor/cudagraph_utils.py @@ -192,7 +192,7 @@ def check_multiple_devices_or_any_cpu_nodes( ): return None - keys_repr = (repr(key) for key in device_node_mapping.keys()) + keys_repr = (repr(key) for key in device_node_mapping) return format_default_skip_message(f"multiple devices: {', '.join(keys_repr)}") diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 2e89ea5ca461b..28e7f88d33986 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1590,7 +1590,7 @@ def maybe_propagate( schema_kwargs = {arg.name: arg for arg in schema.arguments} - for key in old_kwargs.keys(): + for key in old_kwargs: old_arg = old_kwargs[key] new_arg = new_kwargs[key] schema_arg = schema_kwargs[key] diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index b1a3071cb7ba4..53c12d0726044 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -1534,7 +1534,7 @@ def py_cnst(val: object) -> Union[bool, float, int]: # "all" is desugared to `!any(!val)` } - assert reduction_type in rtypes_to_inits.keys(), ( + assert reduction_type in rtypes_to_inits, ( f"{reduction_type} not supported for zero-dimension tensors!" ) diff --git a/torch/_inductor/memory.py b/torch/_inductor/memory.py index 6f58b683ac22b..ed223de71c079 100644 --- a/torch/_inductor/memory.py +++ b/torch/_inductor/memory.py @@ -229,7 +229,7 @@ def assign_memory_planning_info_for_scheduler_buffers( # populate the MemoryPlanningInfoForBuffer attribute to each scheduler buffer # note: there are scheduler buffers not in dep_name_to_succ_nodes (e.g., graph outputs) - for buf_name in name_to_buf.keys(): + for buf_name in name_to_buf: name_to_buf[buf_name].mpi_buffer = MemoryPlanningInfoForBuffer( size_alloc=sched_buf_to_size[buf_name][0], size_free=sched_buf_to_size[buf_name][1], diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index dc4be650eccb4..41021b0fc8ed1 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -3719,9 +3719,7 @@ def get_choice_info(choice): M, K = input_nodes[-2].get_size()[:2] N = input_nodes[-1].get_size()[-1] - out_dict = { - str((M, K, N)): [get_choice_info(choice) for choice in timings.keys()] - } + out_dict = {str((M, K, N)): [get_choice_info(choice) for choice in timings]} append_to_log(mm_filename, out_dict) diff --git a/torch/_inductor/tiling_utils.py b/torch/_inductor/tiling_utils.py index 0c9305dc721dd..5b394b9ea9914 100644 --- a/torch/_inductor/tiling_utils.py +++ b/torch/_inductor/tiling_utils.py @@ -165,7 +165,7 @@ def find_coalesced_var( variables[v] = get_hint(v) zero_index = sympy_subs(index, variables) - for v in var_ranges.keys(): + for v in var_ranges: variables[v] = 1 try: new_val = sympy_subs(index, variables) diff --git a/torch/_library/infer_schema.py b/torch/_library/infer_schema.py index 62bd70f65a510..cb3cfd1d6029f 100644 --- a/torch/_library/infer_schema.py +++ b/torch/_library/infer_schema.py @@ -291,7 +291,7 @@ def parse_return(annotation, error_fn): origin = typing.get_origin(annotation) if origin is not tuple: - if annotation not in SUPPORTED_RETURN_TYPES.keys(): + if annotation not in SUPPORTED_RETURN_TYPES: error_fn( f"Return has unsupported type {annotation}. " f"The valid types are: {SUPPORTED_RETURN_TYPES}." diff --git a/torch/_numpy/_dtypes.py b/torch/_numpy/_dtypes.py index a429d28f30cc3..134f7617b758a 100644 --- a/torch/_numpy/_dtypes.py +++ b/torch/_numpy/_dtypes.py @@ -248,7 +248,7 @@ def sctype_from_string(s): """Normalize a string value: a type 'name' or a typecode or a width alias.""" if s in _names: return _names[s] - if s in _name_aliases.keys(): + if s in _name_aliases: return _name_aliases[s] if s in _typecodes: return _typecodes[s] diff --git a/torch/_numpy/_ndarray.py b/torch/_numpy/_ndarray.py index f192a39dd0296..e3f3836754017 100644 --- a/torch/_numpy/_ndarray.py +++ b/torch/_numpy/_ndarray.py @@ -49,7 +49,7 @@ class Flags: def __init__(self, flag_to_value: dict): - assert all(k in FLAGS for k in flag_to_value.keys()) # sanity check + assert all(k in FLAGS for k in flag_to_value) # sanity check self._flag_to_value = flag_to_value def __getattr__(self, attr: str): @@ -59,7 +59,7 @@ def __getattr__(self, attr: str): raise AttributeError(f"No flag attribute '{attr}'") def __getitem__(self, key): - if key in SHORTHAND_TO_FLAGS.keys(): + if key in SHORTHAND_TO_FLAGS: key = SHORTHAND_TO_FLAGS[key] if key in FLAGS: try: @@ -76,7 +76,7 @@ def __setattr__(self, attr, value): super().__setattr__(attr, value) def __setitem__(self, key, value): - if key in FLAGS or key in SHORTHAND_TO_FLAGS.keys(): + if key in FLAGS or key in SHORTHAND_TO_FLAGS: raise NotImplementedError("Modifying flags is not implemented") else: raise KeyError(f"No flag key '{key}'") diff --git a/torch/ao/quantization/backend_config/_common_operator_config_utils.py b/torch/ao/quantization/backend_config/_common_operator_config_utils.py index ab44cfa09197d..25672e7e6ced9 100644 --- a/torch/ao/quantization/backend_config/_common_operator_config_utils.py +++ b/torch/ao/quantization/backend_config/_common_operator_config_utils.py @@ -678,7 +678,7 @@ def _get_bn_configs(dtype_configs: list[DTypeConfig]) -> list[BackendPatternConf torch.nn.BatchNorm2d: nni.BNReLU2d, torch.nn.BatchNorm3d: nni.BNReLU3d, } - for bn in bn_to_fused_bn.keys(): + for bn in bn_to_fused_bn: fused_bn = bn_to_fused_bn[bn] # bn module + relu module fusion config bn_configs.append( diff --git a/torch/ao/quantization/pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py index 6eac69a96ba42..9f7767101aba6 100644 --- a/torch/ao/quantization/pt2e/prepare.py +++ b/torch/ao/quantization/pt2e/prepare.py @@ -217,7 +217,7 @@ def _get_edge_or_node_to_group_id( # means the observer of key should be shared with observer with value, by default it will # be shared with itself shared_with_map: dict[EdgeOrNode, EdgeOrNode] = { - k: k for k in edge_or_node_to_qspec.keys() + k: k for k in edge_or_node_to_qspec } for edge_or_node, qspec in edge_or_node_to_qspec.items(): if isinstance(edge_or_node, torch.fx.Node): @@ -282,7 +282,7 @@ def _get_edge_or_node_to_group_id( # now that we get the sharing relations between all edges and nodes, we can assign group ids cur_group_id = 0 edge_or_node_to_group_id: dict[EdgeOrNode, int] = {} - for edge_or_node in shared_with_map.keys(): + for edge_or_node in shared_with_map: root = _find_root_edge_or_node(edge_or_node, shared_with_map) if root not in edge_or_node_to_group_id: edge_or_node_to_group_id[root] = cur_group_id diff --git a/torch/autograd/profiler_legacy.py b/torch/autograd/profiler_legacy.py index 9f60295655ddb..5dd26c0881370 100644 --- a/torch/autograd/profiler_legacy.py +++ b/torch/autograd/profiler_legacy.py @@ -296,9 +296,9 @@ def _get_record_key(record): f"Expected CPU and CUDA memory allocation handles to match, " f"but got {num_open_handles_cpu} CPU and {num_open_handles_cuda} CUDA" ) - for handle in cpu_memory_allocs.keys(): + for handle in cpu_memory_allocs: cpu_memory_allocs[handle] += record.cpu_memory_usage() - for handle in cuda_memory_allocs.keys(): + for handle in cuda_memory_allocs: cuda_memory_allocs[handle] += record.cuda_memory_usage() if num_open_handles_cpu == 0: # output event as a top-level memory event diff --git a/torch/cuda/_device_limits.py b/torch/cuda/_device_limits.py index 808d748c8f6eb..60aeedc8053ab 100644 --- a/torch/cuda/_device_limits.py +++ b/torch/cuda/_device_limits.py @@ -53,7 +53,7 @@ def get_fma_per_cycle_per_sm_cuda_cores(self, data_type: dtype) -> int: else: dict_key = "unknown" - if dict_key not in hardcoded_device_values.keys(): + if dict_key not in hardcoded_device_values: raise RuntimeError( f"No data for sm_{self.compute_capability} and {data_type}." ) @@ -96,7 +96,7 @@ def get_fma_per_cycle_per_sm_tensor_cores(self, data_type: dtype) -> int: else: dict_key = "unknown" - if dict_key not in hardcoded_device_values.keys(): + if dict_key not in hardcoded_device_values: raise RuntimeError( f"No data for sm_{self.compute_capability} and {data_type}." ) diff --git a/torch/distributed/_serialization.py b/torch/distributed/_serialization.py index c13ba46ba5757..8f7043453be76 100644 --- a/torch/distributed/_serialization.py +++ b/torch/distributed/_serialization.py @@ -145,7 +145,7 @@ def _streaming_load( if pickle_module is None: pickle_module = pickle - if "encoding" not in pickle_load_args.keys(): + if "encoding" not in pickle_load_args: pickle_load_args["encoding"] = "utf-8" zip_file = _PseudoZipFile() diff --git a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py index 9db89d038658a..9d70ab7c7400d 100644 --- a/torch/distributed/checkpoint/_consolidate_hf_safetensors.py +++ b/torch/distributed/checkpoint/_consolidate_hf_safetensors.py @@ -257,7 +257,7 @@ def _process_output_file( ) # Process each input safetensors file - for safetensors_file in input_files_data.keys(): + for safetensors_file in input_files_data: file_metadata = input_files_data[safetensors_file].metadata input_metadata_size = input_files_data[safetensors_file].metadata_size diff --git a/torch/distributed/checkpoint/quantized_hf_storage.py b/torch/distributed/checkpoint/quantized_hf_storage.py index 2cb189d515a8a..36f4ddf937fee 100644 --- a/torch/distributed/checkpoint/quantized_hf_storage.py +++ b/torch/distributed/checkpoint/quantized_hf_storage.py @@ -82,7 +82,7 @@ def _build_weight_scale_mapping(self, weight_map: dict[str, str]): # Store the complete weight map for file location lookups self._weight_map = weight_map - for tensor_name in weight_map.keys(): + for tensor_name in weight_map: if tensor_name.endswith(".weight_scale_inv"): weight_name = tensor_name.replace(".weight_scale_inv", ".weight") if weight_name in weight_map: diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index 9202851537fba..54a29c0bb3588 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -443,7 +443,7 @@ def _verify_state_dict( f"or load but optim state_dict is empty. {optim_state_dict}" ) - for key in model_state_dict.keys(): + for key in model_state_dict: if _FLAT_PARAM in key: raise RuntimeError( f"{key} contains {_FLAT_PARAM}. This can happen if the model " diff --git a/torch/distributed/elastic/multiprocessing/tail_log.py b/torch/distributed/elastic/multiprocessing/tail_log.py index 034740810dcdd..a34ec1408be57 100644 --- a/torch/distributed/elastic/multiprocessing/tail_log.py +++ b/torch/distributed/elastic/multiprocessing/tail_log.py @@ -130,7 +130,7 @@ def __init__( self._log_line_prefixes = log_line_prefixes self._log_line_filter = log_line_filter self._finished_events: dict[int, Event] = { - local_rank: Event() for local_rank in log_files.keys() + local_rank: Event() for local_rank in log_files } self._futs: list[Future] = [] self._interval_sec = interval_sec diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 60e3f37a99919..96657eeea4106 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -1549,7 +1549,7 @@ def _allgather_orig_param_states( fsdp_state._device_handle.memory_summary(), ) - output_states: dict[str, dict[str, Any]] = {fqn: {} for fqn in input_states.keys()} + output_states: dict[str, dict[str, Any]] = {fqn: {} for fqn in input_states} dtype, state_buffers = _convert_all_state_info( fsdp_param_info, gathered_state_info, input_states, output_states diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index 39da483fe002b..e60ae3b93ba63 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -1637,7 +1637,7 @@ def _step_microbatches( # the stages in the pipeline_order all_prev_ranks: set[int] = set() all_next_ranks: set[int] = set() - for stage_index in stage_index_to_stage.keys(): + for stage_index in stage_index_to_stage: # TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections) if stage_index > 0: all_prev_ranks.add(self.stage_index_to_group_rank[stage_index - 1]) @@ -3176,7 +3176,7 @@ def get_schedule_class(schedule_name: str): "ZBVZeroBubble": ScheduleZBVZeroBubble, "DualPipeV": ScheduleDualPipeV, } - lowercase_keys = {k.lower(): k for k in schedule_map.keys()} + lowercase_keys = {k.lower(): k for k in schedule_map} lowercase_schedule_name = schedule_name.lower() if lowercase_schedule_name not in lowercase_keys: raise ValueError( diff --git a/torch/distributed/tensor/_ops/_pointwise_ops.py b/torch/distributed/tensor/_ops/_pointwise_ops.py index 084fa62706e0d..53b759e993c0d 100644 --- a/torch/distributed/tensor/_ops/_pointwise_ops.py +++ b/torch/distributed/tensor/_ops/_pointwise_ops.py @@ -618,7 +618,7 @@ def common_pointwise_strategy( return pointwise_strategy -for op in linear_pointwise_ops.keys(): +for op in linear_pointwise_ops: register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))( linear_pointwise_strategy ) diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index 1e1f1f409857b..a9a018468cef1 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -1333,7 +1333,7 @@ def refine_dynamic_shapes_from_suggested_fixes( roots.add(c.root.__name__) # type: ignore[attr-defined] # check keys are existing dims or new roots - for k in shape_fixes.keys(): + for k in shape_fixes: assert k in name_to_dim or k in roots # cache so we don't produce multiple derived dim objects diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index d07d235e51321..e01cab57775c5 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -1871,7 +1871,7 @@ def round_magic_impl(self, ndigits=None): setattrs(user_type, f"__r{method_name}__", rbinary_magic_impl) -for method in magic_methods.keys(): # type: ignore[assignment] +for method in magic_methods: # type: ignore[assignment] if method in only_bool_magic_methods: _make_user_magic(method, SymBool) continue diff --git a/torch/fx/node.py b/torch/fx/node.py index 1d72a75a6ccf4..272676a4e3a94 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -496,7 +496,7 @@ def insert_arg(self, idx: int, arg: Argument) -> None: _new_input_nodes: dict[Node, None] = {} _fx_map_arg(arg, _new_input_nodes.setdefault) - for new_use in _new_input_nodes.keys(): + for new_use in _new_input_nodes: if new_use not in self._input_nodes: self._input_nodes.setdefault(new_use) new_use.users.setdefault(self) diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index a84a5b681d638..69c324ab726ec 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -143,7 +143,7 @@ def check_schema(schema_str: str, func, *args, **kwargs) -> None: name, arg_type = named_arg_type.split(": ") is_optional = arg_type.endswith("?") normalized_arg_type = arg_type[:-1] if is_optional else arg_type - if normalized_arg_type not in arg_type_check_fns.keys(): + if normalized_arg_type not in arg_type_check_fns: raise AssertionError(f"Unknown arg type: {normalized_arg_type}") if i >= len(args): diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py index 1ab915d27cd66..254560d8751ce 100644 --- a/torch/optim/swa_utils.py +++ b/torch/optim/swa_utils.py @@ -367,7 +367,7 @@ def update_bn( was_training = model.training model.train() - for module in momenta.keys(): + for module in momenta: module.momentum = None for input in loader: @@ -378,7 +378,7 @@ def update_bn( model(input) - for bn_module in momenta.keys(): + for bn_module in momenta: bn_module.momentum = momenta[bn_module] model.train(was_training) diff --git a/torch/serialization.py b/torch/serialization.py index ce5a74d92384e..ffa77cec732ed 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -1250,7 +1250,7 @@ def persistent_id(self, obj): zip_file.write_record("byteorder", sys.byteorder, len(sys.byteorder)) # Write each tensor to a file named tensor/the_tensor_key in the zip archive - for key in serialized_storages.keys(): + for key in serialized_storages: name = f"data/{key}" storage = serialized_storages[key] num_bytes = storage.nbytes() @@ -1494,7 +1494,7 @@ def _get_wo_message(message: str) -> str: _check_dill_version(pickle_module) - if "encoding" not in pickle_load_args.keys(): + if "encoding" not in pickle_load_args: pickle_load_args["encoding"] = "utf-8" with _open_file_like(f, "rb") as opened_file: diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index a14f670d788be..8cb9c929d8545 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -7050,8 +7050,8 @@ def _validate_execution_trace_nccl(self, et_file: str) -> None: self.assertGreaterEqual(attrs.get("in_msg_nelems", -1), 0) self.assertGreaterEqual(attrs.get("out_msg_nelems", -1), 0) - self.assertTrue("in_split_size" in attrs.keys()) - self.assertTrue("out_split_size" in attrs.keys()) + self.assertTrue("in_split_size" in attrs) + self.assertTrue("out_split_size" in attrs) self.assertEqual(attrs.get("global_rank_start", -1), 0) self.assertEqual(attrs.get("global_rank_stride", -1), 1) @@ -9306,7 +9306,7 @@ def get_loss(model_output): "tuple": tuple, "dict": dict, } - for output_type in type_mapping.keys(): + for output_type in type_mapping: for _ in range(6): out = model(inp, output_type=output_type) loss = get_loss(out) diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py index b7c0dd17a1164..21464e514742c 100644 --- a/torch/testing/_internal/distributed/rpc/rpc_test.py +++ b/torch/testing/_internal/distributed/rpc/rpc_test.py @@ -3282,7 +3282,7 @@ def test_debug_info(self): expected.update(autograd_info) # NB: Key ordering is only preserved in python 3.6+. So here, we # manually check keys are equal. - for key in expected.keys(): + for key in expected: self.assertIn(key, info.keys()) for key in info.keys(): From c08ce30d18303ff4e43d53ccb0c0c6e6b8bd1dae Mon Sep 17 00:00:00 2001 From: Fadi Arafeh Date: Wed, 5 Nov 2025 22:37:35 +0000 Subject: [PATCH 105/130] [ci][cpu] Update compiler to GCC-13 in jammy-aarch64 (#166849) This is needed because manylinux uses GCC-13 since #152825 As a result of the current compiler version mismatches, we've seen tests passing jammy-aarch64 pre-commit CI, but failing for wheels built in manylinux Related to: #166736 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166849 Approved by: https://github.com/robert-hardwick, https://github.com/malfet, https://github.com/Skylion007, https://github.com/atalman --- .ci/docker/build.sh | 8 ++++---- .ci/docker/common/install_gcc.sh | 4 ++-- .github/workflows/docker-builds.yml | 4 ++-- .github/workflows/inductor-perf-test-nightly-aarch64.yml | 2 +- .github/workflows/linux-aarch64.yml | 2 +- .github/workflows/operator_benchmark.yml | 2 +- test/cpp/aoti_abi_check/CMakeLists.txt | 4 ++++ test/cpp/api/CMakeLists.txt | 7 +++++++ 8 files changed, 22 insertions(+), 11 deletions(-) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 5257decb9d4d5..f0b9a788758ca 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -261,9 +261,9 @@ case "$tag" in PYTHON_VERSION=3.10 CUDA_VERSION=12.8.1 ;; - pytorch-linux-jammy-aarch64-py3.10-gcc11) + pytorch-linux-jammy-aarch64-py3.10-gcc13) ANACONDA_PYTHON_VERSION=3.10 - GCC_VERSION=11 + GCC_VERSION=13 ACL=yes VISION=yes OPENBLAS=yes @@ -281,9 +281,9 @@ case "$tag" in # from pytorch/llvm:9.0.1 is x86 specific SKIP_LLVM_SRC_BUILD_INSTALL=yes ;; - pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks) + pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks) ANACONDA_PYTHON_VERSION=3.10 - GCC_VERSION=11 + GCC_VERSION=13 ACL=yes VISION=yes OPENBLAS=yes diff --git a/.ci/docker/common/install_gcc.sh b/.ci/docker/common/install_gcc.sh index 3b96bf6e0ed2f..df1c059bc3869 100644 --- a/.ci/docker/common/install_gcc.sh +++ b/.ci/docker/common/install_gcc.sh @@ -7,11 +7,11 @@ if [ -n "$GCC_VERSION" ]; then # Need the official toolchain repo to get alternate packages add-apt-repository ppa:ubuntu-toolchain-r/test apt-get update - apt-get install -y g++-$GCC_VERSION + apt-get install -y g++-$GCC_VERSION gfortran-$GCC_VERSION update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-"$GCC_VERSION" 50 update-alternatives --install /usr/bin/g++ g++ /usr/bin/g++-"$GCC_VERSION" 50 update-alternatives --install /usr/bin/gcov gcov /usr/bin/gcov-"$GCC_VERSION" 50 - + update-alternatives --install /usr/bin/gfortran gfortran /usr/bin/gfortran-"$GCC_VERSION" 50 # Cleanup package manager apt-get autoclean && apt-get clean diff --git a/.github/workflows/docker-builds.yml b/.github/workflows/docker-builds.yml index 4d0940094f541..941a045649f3a 100644 --- a/.github/workflows/docker-builds.yml +++ b/.github/workflows/docker-builds.yml @@ -77,11 +77,11 @@ jobs: pytorch-linux-noble-riscv64-py3.12-gcc14 ] include: - - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11 + - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc13 runner: linux.arm64.m7g.4xlarge - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-clang21 runner: linux.arm64.m7g.4xlarge - - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks + - docker-image-name: pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks runner: linux.arm64.m7g.4xlarge timeout-minutes: 600 # Docker uploads fail from LF runners, see https://github.com/pytorch/pytorch/pull/137358 diff --git a/.github/workflows/inductor-perf-test-nightly-aarch64.yml b/.github/workflows/inductor-perf-test-nightly-aarch64.yml index e16c8be79130d..46a1966570c63 100644 --- a/.github/workflows/inductor-perf-test-nightly-aarch64.yml +++ b/.github/workflows/inductor-perf-test-nightly-aarch64.yml @@ -72,7 +72,7 @@ jobs: runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" runner: linux.arm64.m7g.4xlarge build-environment: linux-jammy-aarch64-py3.10 - docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11-inductor-benchmarks + docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13-inductor-benchmarks test-matrix: | { include: [ { config: "inductor_huggingface_perf_cpu_aarch64", shard: 1, num_shards: 9, runner: "linux.arm64.m7g.metal" }, diff --git a/.github/workflows/linux-aarch64.yml b/.github/workflows/linux-aarch64.yml index 2b840a39a5c21..e6690b1043006 100644 --- a/.github/workflows/linux-aarch64.yml +++ b/.github/workflows/linux-aarch64.yml @@ -33,7 +33,7 @@ jobs: with: runner_prefix: ${{ needs.get-label-type.outputs.label-type }} build-environment: linux-jammy-aarch64-py3.10 - docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13 runner: linux.arm64.m7g.4xlarge test-matrix: | { include: [ diff --git a/.github/workflows/operator_benchmark.yml b/.github/workflows/operator_benchmark.yml index 40fb3b8d0c85f..758147f5fe18e 100644 --- a/.github/workflows/operator_benchmark.yml +++ b/.github/workflows/operator_benchmark.yml @@ -60,7 +60,7 @@ jobs: with: build-environment: linux-jammy-aarch64-py3.10 runner: linux.arm64.m7g.4xlarge - docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc11 + docker-image-name: ci-image:pytorch-linux-jammy-aarch64-py3.10-gcc13 test-matrix: | { include: [ { config: "cpu_operator_benchmark_short", shard: 1, num_shards: 1, runner: "linux.arm64.m8g.4xlarge" }, diff --git a/test/cpp/aoti_abi_check/CMakeLists.txt b/test/cpp/aoti_abi_check/CMakeLists.txt index 1695e65cb4a1b..f1747acc31fc8 100644 --- a/test/cpp/aoti_abi_check/CMakeLists.txt +++ b/test/cpp/aoti_abi_check/CMakeLists.txt @@ -45,6 +45,10 @@ endif() # Disable unused-variable warnings for variables that are only used to test compilation target_compile_options_if_supported(test_aoti_abi_check -Wno-unused-variable) target_compile_options_if_supported(test_aoti_abi_check -Wno-unused-but-set-variable) +# Add -Wno-dangling-pointer for GCC 13 +if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 13) + target_compile_options_if_supported(test_aoti_abi_check -Wno-dangling-pointer) +endif() foreach(test_src ${AOTI_ABI_CHECK_VEC_TEST_SRCS}) foreach(i RANGE ${NUM_CPU_CAPABILITY_NAMES}) diff --git a/test/cpp/api/CMakeLists.txt b/test/cpp/api/CMakeLists.txt index 8261aae3b5607..a92832a4d04c9 100644 --- a/test/cpp/api/CMakeLists.txt +++ b/test/cpp/api/CMakeLists.txt @@ -70,6 +70,13 @@ if(NOT MSVC) if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12) target_compile_options_if_supported(test_api "-Wno-error=nonnull") endif() + + # Add -Wno-error=array-bounds for GCC 13+ + # See: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=113239 + if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 13) + target_compile_options_if_supported(test_api "-Wno-error=array-bounds") + endif() + endif() if(INSTALL_TEST) From 85fab6c9b00bb6ba3a0d5e72596dfa4bf39fc998 Mon Sep 17 00:00:00 2001 From: Apurva Jain Date: Thu, 6 Nov 2025 03:24:59 +0000 Subject: [PATCH 106/130] Fix duplicate benchmarking entries for addmm (#166652) There have been duplicate entries for addmm in dashboard. This PR fixes the duplicate entries issues Pull Request resolved: https://github.com/pytorch/pytorch/pull/166652 Approved by: https://github.com/yangw-dev --- benchmarks/operator_benchmark/pt/addmm_test.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/benchmarks/operator_benchmark/pt/addmm_test.py b/benchmarks/operator_benchmark/pt/addmm_test.py index a98628944b3e8..3e94a9cd7f3dc 100644 --- a/benchmarks/operator_benchmark/pt/addmm_test.py +++ b/benchmarks/operator_benchmark/pt/addmm_test.py @@ -53,10 +53,8 @@ def forward(self, input_one, mat1, mat2): return torch.addmm(input_one, mat1, mat2) -op_bench.generate_pt_test(addmm_long_configs + addmm_long_configs, AddmmBenchmark) -op_bench.generate_pt_gradient_test( - addmm_long_configs + addmm_long_configs, AddmmBenchmark -) +op_bench.generate_pt_test(addmm_short_configs + addmm_long_configs, AddmmBenchmark) +op_bench.generate_pt_gradient_test(addmm_long_configs, AddmmBenchmark) """Mircobenchmark for addbmm operator.""" @@ -107,9 +105,7 @@ def forward(self, input_one, batch1, batch2): ) op_bench.generate_pt_test(addbmm_long_configs + addbmm_short_configs, AddbmmBenchmark) -op_bench.generate_pt_gradient_test( - addbmm_long_configs + addbmm_short_configs, AddbmmBenchmark -) +op_bench.generate_pt_gradient_test(addbmm_long_configs, AddbmmBenchmark) if __name__ == "__main__": op_bench.benchmark_runner.main() From d31599f40bc580c170553ca6766163b41c427ed9 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Thu, 6 Nov 2025 03:36:56 +0000 Subject: [PATCH 107/130] [7/N] Fix unused loop variables in tests (#167043) This PR continues to fix or remove unused loop variables in tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167043 Approved by: https://github.com/Lucaskabela --- test/dynamo/test_compiler_bisector.py | 2 +- test/functorch/test_control_flow.py | 14 ++++---------- test/inductor/test_flex_attention.py | 4 ++-- .../eager/test_bias_correction_eager.py | 2 +- test/quantization/fx/test_quantize_fx.py | 6 +++--- test/test_nn.py | 2 +- test/test_transformers.py | 2 +- 7 files changed, 13 insertions(+), 19 deletions(-) diff --git a/test/dynamo/test_compiler_bisector.py b/test/dynamo/test_compiler_bisector.py index 8810a30aaf3b7..8ebf35f3f0d3f 100644 --- a/test/dynamo/test_compiler_bisector.py +++ b/test/dynamo/test_compiler_bisector.py @@ -283,7 +283,7 @@ def test_fn(): ) def test_bisect_pre_grad_graph(self): def f(x): - for i in range(5): + for _ in range(5): x = x + 1 return x.relu() diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 5034661fa3e05..f83f059663149 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -942,9 +942,7 @@ def false_fn(x): b = torch.randn(4, requires_grad=True) c = torch.randn(4, requires_grad=True) - for pred, fn in zip( - [torch.tensor(False), torch.tensor(True)], [false_fn, true_fn] - ): + for pred in [torch.tensor(False), torch.tensor(True)]: with self.assertRaisesRegex( torch._dynamo.exc.UncapturedHigherOrderOpError, "Cond doesn't work unless it is captured completely with torch.compile", @@ -3066,13 +3064,9 @@ def run_test_and_get_grads_loss(model, initial_hs, inputs): ).to(DEVICE) # Test 3 models: RNNScanList, RNNScanTensor, RNNLoop - models = [ - ("ScanList", RNNScanList), - ("ScanTensor", RNNScanTensor), - ("Loop", RNNLoop), - ] + models = [RNNScanList, RNNScanTensor, RNNLoop] - for model_name, model_class in models: + for model_class in models: # Create uncompiled model model_uc = model_class().to(DEVICE) uncompiled_grads, uncompiled_loss = run_test_and_get_grads_loss( @@ -7538,7 +7532,7 @@ def foo(x): inps = (torch.ones(3, 4), torch.ones(3, 5), torch.ones(5, 4), torch.ones(5, 3)) for inp in inps: - gm = make_fx(foo, tracing_mode="symbolic")(torch.ones(3, 4)) + gm = make_fx(foo, tracing_mode="symbolic")(inp) self.assertExpectedInline( gm.code.strip(), """\ diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index a1e5aa3cebc45..816d3b93ecfef 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -5807,11 +5807,11 @@ def causal_mask(b, h, q_idx, kv_idx): from torch.utils._pytree import GetAttrKey - for key, tensor in tensors_with_keys: + for key, _tensor in tensors_with_keys: self.assertIsInstance(key, GetAttrKey) self.assertIsNotNone(key) - for key, value in context_with_keys: + for key, _value in context_with_keys: self.assertIsInstance(key, GetAttrKey) self.assertIsNotNone(key) diff --git a/test/quantization/eager/test_bias_correction_eager.py b/test/quantization/eager/test_bias_correction_eager.py index 5f0c475f934dd..071ea6e2a768f 100644 --- a/test/quantization/eager/test_bias_correction_eager.py +++ b/test/quantization/eager/test_bias_correction_eager.py @@ -39,7 +39,7 @@ def correct_artificial_bias_quantize(self, float_model, img_data): torch.ao.quantization.convert(artificial_model, inplace=True) # manually changing bias - for name, submodule in artificial_model.named_modules(): + for submodule in artificial_model.modules(): if type(submodule) in _supported_modules: x = get_param(submodule, "bias") weight = get_param(submodule, "weight") diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index faba2f5edc6a7..b33afc7a80363 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -9663,10 +9663,10 @@ def forward(self, input: torch.Tensor, offsets: Optional[torch.Tensor] = None, .set_global(get_default_qat_qconfig(qengine)) \ .set_object_type(torch.nn.EmbeddingBag, default_embedding_qat_qconfig) - train_indices = [[torch.randint(0, 10, (12, 12)), torch.randn((12, 1))] for _ in range(2)] - eval_output = [[torch.randint(0, 10, (12, 1))]] + train_indices = [[torch.randint(0, 10, (12, 12), device=device), torch.randn((12, 1), device=device)] for _ in range(2)] + eval_output = [[torch.randint(0, 10, (12, 1), device=device)]] - model = EmbeddingBagLinear().train() + model = EmbeddingBagLinear().to(device).train() prepared_fx_model = prepare_qat_fx(model, qconfig_dict, example_inputs=(train_indices[0][0],)) test_only_train_fn(prepared_fx_model, train_indices) quant_model = convert_fx(prepared_fx_model, diff --git a/test/test_nn.py b/test/test_nn.py index 034cf51d49ff0..bedb4b22a01bd 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -13516,7 +13516,7 @@ def compare_scaling(grads): # Should warning when parameters generator exhausted params = l.parameters() - for p in params: + for _p in params: pass with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") diff --git a/test/test_transformers.py b/test/test_transformers.py index cc82cbff2a46f..ad7ae56307eb1 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -2857,7 +2857,7 @@ def test_cudnn_attention_broken_166211(self): # https://github.com/pytorch/pytorch/issues/166211#issue-3551350377 shape = (20, 4, 4, 32) scale = 10 - for i in range(100): + for _ in range(100): q = torch.randn(*shape, device='cuda', dtype=torch.bfloat16) * scale k = torch.randn(*shape, device='cuda', dtype=torch.bfloat16) * scale v = torch.randn(*shape, device='cuda', dtype=torch.bfloat16) * scale From 981dd718939ae2413c217c071e364715dbdbf8d6 Mon Sep 17 00:00:00 2001 From: Scott Wolchok Date: Sat, 1 Nov 2025 15:49:12 -0700 Subject: [PATCH 108/130] Refactor: extract OperatorArgsKwargsView from parseIValuesToPyArgsKwargs (#166368) Intended to make it easier to reuse this logic for processing operator arguments as IValues in following PR(s). Testing: python test/test_python_dispatch.py (broke during development, seems to work now) Pull Request resolved: https://github.com/pytorch/pytorch/pull/166368 Approved by: https://github.com/albanD --- torch/csrc/autograd/python_variable.cpp | 193 ++++++++++++++++++------ 1 file changed, 151 insertions(+), 42 deletions(-) diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 946a8d5f1d367..837ba93d1cc28 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -51,14 +51,101 @@ using namespace at; using namespace torch; using namespace torch::autograd; -std::pair parseIValuesToPyArgsKwargs( - const c10::OperatorHandle& op, - const std::vector& arguments) { - TORCH_CHECK( - PyGILState_Check(), - "GIL must be held before you call parseIValuesToPyArgsKwargs"); - const auto& schema = op.schema(); - py::dict kwargs; +namespace { +class OperatorArgsKwargsView { + public: + OperatorArgsKwargsView( + const c10::OperatorHandle& op, + const std::vector& arguments); + using args_iterator = const c10::IValue*; + + args_iterator args_begin() const { + return arguments_.data(); + } + + args_iterator args_end() const { + return arguments_.data() + positional_default_start_; + } + + auto num_positional_args() const { + return positional_default_start_; + } + + auto kwarg_start_index() const { + return first_non_default_kwarg_; + } + + struct kwargs_iterator { + kwargs_iterator() = default; + kwargs_iterator(const OperatorArgsKwargsView* parent, size_t current) + : parent_(parent), current_(current) {} + + kwargs_iterator(const kwargs_iterator&) = default; + kwargs_iterator& operator=(const kwargs_iterator&) = default; + + kwargs_iterator& operator++() { + do { + current_++; + } while (current_ < parent_->arguments_.size() && + parent_->is_default(current_)); + return *this; + } + + kwargs_iterator operator++(int) { + auto copy = *this; + ++(*this); + return copy; + } + + const c10::IValue& operator*() const { + return parent_->arguments_[current_]; + } + + const c10::IValue* operator->() const { + return &operator*(); + } + + int64_t underlying_index() const { + return current_; + } + + bool operator==(const kwargs_iterator& rhs) const { + return parent_ == rhs.parent_ && current_ == rhs.current_; + } + + bool operator!=(const kwargs_iterator& rhs) { + return !(*this == rhs); + } + + private: + const OperatorArgsKwargsView* parent_ = nullptr; + size_t current_ = 0; + }; + + kwargs_iterator kwargs_begin() const { + return kwargs_iterator(this, first_non_default_kwarg_); + } + + kwargs_iterator kwargs_end() const { + return kwargs_iterator(this, arguments_.size()); + } + + private: + bool is_default(size_t idx) const { + const auto& arg = op_.schema().arguments()[idx]; + if (!arg.default_value().has_value()) { + return false; + } + const auto& default_ivalue = *arg.default_value(); + const auto& ivalue = arguments_[idx]; + if (default_ivalue != ivalue) { + return false; + } + return true; + } + + const c10::OperatorHandle& op_; + c10::ArrayRef arguments_; // About all the pointers: // // f(int x, int y = 0, *, int z = 0) @@ -66,45 +153,63 @@ std::pair parseIValuesToPyArgsKwargs( // ^- kwarg_only_start // ^- positional_default_start // ^- 0 + int64_t positional_default_start_; + int64_t first_non_default_kwarg_; +}; +OperatorArgsKwargsView::OperatorArgsKwargsView( + const c10::OperatorHandle& op, + const std::vector& arguments) + : op_(op), arguments_(arguments) { // Find the split point between kwarg-only and regular. Since most functions // don't have kwarg-only arguments, it is more efficient to scan from the // right (but ideally, this would just be precomputed in FunctionSchema // itself). (NB: minus one in the loop is because we're testing if the // *next* argument is kwarg-only before we advance the starting index) - int64_t kwarg_only_start = static_cast(arguments.size()); + const int64_t signed_arguments_size = static_cast(arguments.size()); + int64_t kwarg_only_start = signed_arguments_size; for (; kwarg_only_start > 0; kwarg_only_start--) { - const auto& arg = schema.arguments()[kwarg_only_start - 1]; + const auto& arg = op.schema().arguments()[kwarg_only_start - 1]; if (!arg.kwarg_only()) { break; } } // Find the first positional argument that isn't defaulted - auto is_default = [&](size_t idx) -> bool { - const auto& arg = schema.arguments()[idx]; - if (!arg.default_value().has_value()) { - return false; - } - const auto& default_ivalue = *arg.default_value(); - const auto& ivalue = arguments[idx]; - if (default_ivalue != ivalue) { - return false; + positional_default_start_ = kwarg_only_start; + for (; positional_default_start_ > 0; positional_default_start_--) { + if (!is_default(positional_default_start_ - 1)) { + break; } - return true; - }; + } - int64_t positional_default_start = kwarg_only_start; - for (; positional_default_start > 0; positional_default_start--) { - if (!is_default(positional_default_start - 1)) { + // kwargs_iterator will skip default kwargs when incremented, but we + // need to skip any initial run of default kwargs ourselves. + first_non_default_kwarg_ = kwarg_only_start; + for (; first_non_default_kwarg_ < signed_arguments_size; + ++first_non_default_kwarg_) { + if (!is_default(first_non_default_kwarg_)) { break; } } +} +} // namespace - auto args = - py::reinterpret_steal(PyTuple_New(positional_default_start)); +std::pair parseIValuesToPyArgsKwargs( + const c10::OperatorHandle& op, + const std::vector& arguments) { + TORCH_CHECK( + PyGILState_Check(), + "GIL must be held before you call parseIValuesToPyArgsKwargs"); + const auto& schema = op.schema(); + py::dict kwargs; - auto schemaAwareToPyObject = [&](size_t idx) -> py::object { + OperatorArgsKwargsView args_kwargs(op, arguments); + auto args = py::reinterpret_steal( + PyTuple_New(args_kwargs.num_positional_args())); + + auto schemaAwareToPyObject = + [&schema](size_t idx, const c10::IValue& argument) -> py::object { const auto& arg = schema.arguments()[idx]; auto match = [&](c10::TypeKind kind) { const auto& t = arg.real_type(); @@ -116,38 +221,42 @@ std::pair parseIValuesToPyArgsKwargs( } return false; }; - if (arguments[idx].isNone()) { + if (argument.isNone()) { return py::none(); } else if (match(c10::ScalarTypeType::Kind)) { - auto* obj = - getTHPDtype(static_cast(arguments[idx].toInt())); + auto* obj = getTHPDtype(static_cast(argument.toInt())); return py::reinterpret_borrow( reinterpret_cast(obj)); } else if (match(c10::LayoutType::Kind)) { - auto* obj = - getTHPLayout(static_cast(arguments[idx].toInt())); + auto* obj = getTHPLayout(static_cast(argument.toInt())); return py::reinterpret_borrow( reinterpret_cast(obj)); } else if (match(c10::MemoryFormatType::Kind)) { - return py::cast(static_cast(arguments[idx].toInt())); + return py::cast(static_cast(argument.toInt())); } else { - return torch::jit::toPyObject(arguments[idx]); + return torch::jit::toPyObject(argument); } }; // Populate positional arguments - for (const auto idx : c10::irange(positional_default_start)) { + size_t idx = 0; + for (auto argument_it = args_kwargs.args_begin(); + argument_it != args_kwargs.args_end(); + ++argument_it) { PyTuple_SET_ITEM( - args.ptr(), idx, schemaAwareToPyObject(idx).release().ptr()); + args.ptr(), + idx, + schemaAwareToPyObject(idx, *argument_it).release().ptr()); + idx++; } // Populate keyword arguments - for (const auto idx : c10::irange(kwarg_only_start, arguments.size())) { - // But don't populate default keyword arguments - if (is_default(idx)) - continue; - const auto& arg = schema.arguments()[idx]; - kwargs[py::cast(arg.name())] = schemaAwareToPyObject(idx); + for (auto argument_it = args_kwargs.kwargs_begin(); + argument_it != args_kwargs.kwargs_end(); + ++argument_it) { + const auto& arg = schema.arguments()[argument_it.underlying_index()]; + kwargs[py::cast(arg.name())] = + schemaAwareToPyObject(argument_it.underlying_index(), *argument_it); } return std::make_pair(std::move(args), std::move(kwargs)); } From f72772b184ffbe82bba2412787955587a66233de Mon Sep 17 00:00:00 2001 From: Will Constable Date: Wed, 5 Nov 2025 11:41:31 -0800 Subject: [PATCH 109/130] [PP] make runtime dbg log print custom actions (#167113) Previously the log only printed if the default implementation for an action was used, now it prints before dispatching to custom registered actions. Tested by running on autoparallel graph runner and observing forward pass action logged Pull Request resolved: https://github.com/pytorch/pytorch/pull/167113 Approved by: https://github.com/sanketpurandare, https://github.com/Skylion007 --- torch/distributed/pipelining/schedules.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index e60ae3b93ba63..44569427f8db2 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -2033,12 +2033,6 @@ def _perform_action(action: _Action) -> None: is_next_stage_on_this_rank = stage_idx + 1 in stage_index_to_stage is_prev_stage_on_this_rank = stage_idx - 1 in stage_index_to_stage - logger.debug( - "_PipelineScheduleRuntime running time_step %d, action %s", - time_step, - action, - ) - # TODO(whc) it's not actually safe to use _batch_p2p here in the uncommon case the model has skip-connections, # since we do not want to batch up ops between more than a pair of ranks. _sorted_batch_p2p would be # safe to use instead. @@ -2191,6 +2185,11 @@ def _perform_action(action: _Action) -> None: # count either full_backward or backward_weight together, to determine when to sync DP grads self.backward_counter.clear() for time_step, action in enumerate(self.pipeline_order_with_comms[self.rank]): + logger.debug( + "_PipelineScheduleRuntime running time_step %d, action %s", + time_step, + action, + ) try: with record_function(_get_profiler_function_name(action)): if action.computation_type in self._comp_type_to_function_map: From c3c36534187d49da7ca2c680a641430eb9cfc404 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Thu, 6 Nov 2025 04:32:14 +0000 Subject: [PATCH 110/130] [1/N] Add return types of Python functions (#167162) This PR adds return types of some Python functions. Most of them return `None`. The types were added automatically by ruff `ANN` rules. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167162 Approved by: https://github.com/Lucaskabela --- torch/nn/attention/__init__.py | 2 +- torch/nn/attention/_utils.py | 2 +- torch/nn/attention/bias.py | 4 +-- .../experimental/_paged_attention.py | 2 +- torch/nn/attention/flex_attention.py | 12 +++---- torch/nn/backends/thnn.py | 2 +- torch/nn/cpp.py | 12 +++---- torch/nn/modules/module.py | 2 +- torch/nn/parallel/data_parallel.py | 2 +- torch/nn/parameter.py | 10 +++--- torch/nn/parameter.pyi | 4 +-- .../expanded_weights_impl.py | 14 ++++---- .../expanded_weights_utils.py | 2 +- torch/nn/utils/parametrizations.py | 2 +- torch/nn/utils/parametrize.py | 2 +- torch/nn/utils/prune.py | 32 ++++++++++--------- torch/optim/_adafactor.py | 10 +++--- torch/optim/_functional.py | 2 +- torch/optim/_muon.py | 4 +-- torch/optim/adadelta.py | 8 ++--- torch/optim/adagrad.py | 10 +++--- torch/optim/adam.py | 8 ++--- torch/optim/adamax.py | 8 ++--- torch/optim/adamw.py | 4 +-- torch/optim/asgd.py | 8 ++--- torch/optim/lbfgs.py | 6 ++-- torch/optim/lr_scheduler.py | 28 ++++++++-------- torch/optim/nadam.py | 8 ++--- torch/optim/optimizer.py | 2 +- torch/optim/radam.py | 8 ++--- torch/optim/rmsprop.py | 8 ++--- torch/optim/rprop.py | 8 ++--- torch/optim/sgd.py | 8 ++--- torch/optim/sparse_adam.py | 2 +- torch/optim/swa_utils.py | 16 ++++++---- 35 files changed, 134 insertions(+), 128 deletions(-) diff --git a/torch/nn/attention/__init__.py b/torch/nn/attention/__init__.py index 5e6e0fa5fae3b..a115d32c6e2c8 100644 --- a/torch/nn/attention/__init__.py +++ b/torch/nn/attention/__init__.py @@ -90,7 +90,7 @@ def _cur_sdpa_kernel_backends(with_priority: bool = False): return backends -def _sdpa_kernel(backends: Iterable, set_priority: bool = False): +def _sdpa_kernel(backends: Iterable, set_priority: bool = False) -> None: for name, val in _backend_names.items(): enabled = getattr(SDPBackend, val) in backends getattr(torch._C, f"_set_sdp_use_{name}")(enabled) diff --git a/torch/nn/attention/_utils.py b/torch/nn/attention/_utils.py index a91045b92c13e..86f7c29f5313a 100644 --- a/torch/nn/attention/_utils.py +++ b/torch/nn/attention/_utils.py @@ -40,7 +40,7 @@ def _validate_sdpa_input( dropout_p=0.0, is_causal=False, scale=None, -): +) -> None: if query.dtype != key.dtype or query.dtype != value.dtype: raise ValueError( f"Expected query, key, and value to have the same dtype, " diff --git a/torch/nn/attention/bias.py b/torch/nn/attention/bias.py index 551a57e6963e0..0cb256ad36f7f 100644 --- a/torch/nn/attention/bias.py +++ b/torch/nn/attention/bias.py @@ -117,7 +117,7 @@ class CausalBias(torch.Tensor): .. warning:: This class is a prototype and subject to change. """ - def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int): + def __init__(self, variant: CausalVariant, seq_len_q: int, seq_len_kv: int) -> None: """ Initializes the CausalBias instance with a specified variant and sequence lengths. @@ -296,7 +296,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): return cls._dispatch(*args, **kwargs) return super().__torch_function__(func, types, args, kwargs) - def __repr__(self): # type:ignore[override] + def __repr__(self) -> str: # type:ignore[override] return self._materialize().__repr__() diff --git a/torch/nn/attention/experimental/_paged_attention.py b/torch/nn/attention/experimental/_paged_attention.py index 70eadcdadfaa0..2e0ded6063aef 100644 --- a/torch/nn/attention/experimental/_paged_attention.py +++ b/torch/nn/attention/experimental/_paged_attention.py @@ -40,7 +40,7 @@ def __init__( page_size: int, max_batch_size: int, device: str = "cuda", - ): + ) -> None: # number of pages self.n_pages = n_pages diff --git a/torch/nn/attention/flex_attention.py b/torch/nn/attention/flex_attention.py index b79b86a29afb6..be49549e5740e 100644 --- a/torch/nn/attention/flex_attention.py +++ b/torch/nn/attention/flex_attention.py @@ -550,7 +550,7 @@ def __init__( full_q_indices: Optional[Tensor], BLOCK_SIZE: tuple[int, int], mask_mod: _mask_mod_signature, - ): + ) -> None: if kv_indices.dim() < 2: raise RuntimeError("BlockMask must have at least 2 dimensions") assert kv_num_blocks is not None, "kv_num_blocks must be provided" @@ -682,7 +682,7 @@ def shape(self): *batch_dims, _, _ = self.kv_indices.shape return tuple(batch_dims) + self.seq_lengths - def __str__(self): + def __str__(self) -> str: s = f"BlockMask(shape={self.shape}, sparsity={self.sparsity():.2f}%, \n" mask_str = self.to_string().strip() s += mask_str @@ -760,7 +760,7 @@ def causal_mask(b, h, q_idx, kv_idx): compute_q_blocks=self.q_indices is not None, ) - def __repr__(self): + def __repr__(self) -> str: def shape_or_none(x: Optional[torch.Tensor]): return x.shape if x is not None else None @@ -864,7 +864,7 @@ def create_block_vis(*batch_idx): vis = ", ".join(reversed(descriptors)) + "\n" - def summarize_section(section): + def summarize_section(section) -> str: percentage = section.float().mean().item() if percentage == 1: return "█" @@ -1289,7 +1289,7 @@ def _apply_kernel_options( return kernel_options -def _validate_embed_dim(query: Tensor, key: Tensor, value: Tensor): +def _validate_embed_dim(query: Tensor, key: Tensor, value: Tensor) -> None: if query.size(-1) != key.size(-1): raise ValueError( f"Expect query and key/value to have the same embedding dimension " @@ -1297,7 +1297,7 @@ def _validate_embed_dim(query: Tensor, key: Tensor, value: Tensor): ) -def _validate_device(query: Tensor, key: Tensor, value: Tensor): +def _validate_device(query: Tensor, key: Tensor, value: Tensor) -> None: """TODO: Remove once non cuda/cpu devices support is added We only need to check query since we have already that q,k,v are on the same device """ diff --git a/torch/nn/backends/thnn.py b/torch/nn/backends/thnn.py index 8564153ece233..c56e923a84383 100644 --- a/torch/nn/backends/thnn.py +++ b/torch/nn/backends/thnn.py @@ -2,5 +2,5 @@ # this is for historical pickle deserialization, it is not used otherwise -def _get_thnn_function_backend(): +def _get_thnn_function_backend() -> None: pass diff --git a/torch/nn/cpp.py b/torch/nn/cpp.py index e447284ad82ba..b4ffd188cd39a 100644 --- a/torch/nn/cpp.py +++ b/torch/nn/cpp.py @@ -14,7 +14,7 @@ class OrderedDictWrapper: so using properties does not work. """ - def __init__(self, cpp_module, attr): + def __init__(self, cpp_module, attr) -> None: self.cpp_module = cpp_module self.attr = attr @@ -37,10 +37,10 @@ def values(self): def __iter__(self): return self.cpp_dict.__iter__() - def __len__(self): + def __len__(self) -> int: return self.cpp_dict.__len__() - def __contains__(self, key): + def __contains__(self, key) -> bool: return self.cpp_dict.__contains__(key) def __getitem__(self, key): @@ -50,7 +50,7 @@ def __getitem__(self, key): class ModuleWrapper(nn.Module): """A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and delegates all access.""" - def __init__(self, cpp_module): + def __init__(self, cpp_module) -> None: # Assign before the super class constructor so ``self.training`` can be # assigned to in the super class constructor. self.cpp_module = cpp_module @@ -83,8 +83,8 @@ def training(self): return self.cpp_module.training @training.setter - def training(self, mode): + def training(self, mode) -> None: self.cpp_module.train(mode) - def __repr__(self): + def __repr__(self) -> str: return self.cpp_module.__repr__() diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index f7e3d2f262def..10a240e3a9cf7 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -3040,7 +3040,7 @@ def _replicate_for_data_parallel(self): return replica - def compile(self, *args, **kwargs): + def compile(self, *args, **kwargs) -> None: """ Compile this Module's forward using :func:`torch.compile`. diff --git a/torch/nn/parallel/data_parallel.py b/torch/nn/parallel/data_parallel.py index 9a0f4973d31b2..9aaa9b4a92e6d 100644 --- a/torch/nn/parallel/data_parallel.py +++ b/torch/nn/parallel/data_parallel.py @@ -30,7 +30,7 @@ def _check_balance(device_ids: Sequence[Union[int, torch.device]]) -> None: device_ids = [_get_device_index(x, True) for x in device_ids] dev_props = _get_devices_properties(device_ids) - def warn_imbalance(get_prop): + def warn_imbalance(get_prop) -> bool: values = [get_prop(props) for props in dev_props] min_pos, min_val = min(enumerate(values), key=operator.itemgetter(1)) max_pos, max_val = max(enumerate(values), key=operator.itemgetter(1)) diff --git a/torch/nn/parameter.py b/torch/nn/parameter.py index c03c85f48fc35..64e9d8c2d80f2 100644 --- a/torch/nn/parameter.py +++ b/torch/nn/parameter.py @@ -18,7 +18,7 @@ # Metaclass to combine _TensorMeta and the instance check override for Parameter. class _ParameterMeta(torch._C._TensorMeta): # Make `isinstance(t, Parameter)` return True for custom tensor instances that have the _is_param flag. - def __instancecheck__(self, instance): + def __instancecheck__(self, instance) -> bool: if self is Parameter: if isinstance(instance, torch.Tensor) and getattr( instance, "_is_param", False @@ -82,7 +82,7 @@ def __deepcopy__(self, memo): return result # pyrefly: ignore [bad-override] - def __repr__(self): + def __repr__(self) -> str: return "Parameter containing:\n" + super().__repr__() def __reduce_ex__(self, proto): @@ -125,7 +125,7 @@ class UninitializedTensorMixin: torch._has_compatible_shallow_copy_type, ] - def materialize(self, shape, device=None, dtype=None): + def materialize(self, shape, device=None, dtype=None) -> None: r"""Create a Parameter or Tensor with the same properties of the uninitialized one. Given a shape, it materializes a parameter in the same device @@ -163,7 +163,7 @@ def share_memory_(self): "`module.share_memory()`." ) - def __repr__(self): + def __repr__(self) -> str: return f"<{self.__class__.__name__}>" def __reduce_ex__(self, proto): @@ -235,7 +235,7 @@ def __deepcopy__(self, memo): # Metaclass to combine _TensorMeta and the instance check override for Buffer. class _BufferMeta(torch._C._TensorMeta): # Make `isinstance(t, Buffer)` return True for custom tensor instances that have the _is_buffer flag. - def __instancecheck__(self, instance): + def __instancecheck__(self, instance) -> bool: if self is Buffer: if isinstance(instance, torch.Tensor) and getattr( instance, "_is_buffer", False diff --git a/torch/nn/parameter.pyi b/torch/nn/parameter.pyi index a17821c2b16c1..3d1cddb7e8b8b 100644 --- a/torch/nn/parameter.pyi +++ b/torch/nn/parameter.pyi @@ -25,7 +25,7 @@ class Buffer(Tensor): data: Tensor = ..., requires_grad: bool = ..., persistent: bool = ..., - ): ... + ) -> None: ... class UninitializedBuffer(Tensor): persistent: bool @@ -34,7 +34,7 @@ class UninitializedBuffer(Tensor): data: Tensor = ..., requires_grad: bool = ..., persistent: bool = ..., - ): ... + ) -> None: ... def materialize( self, shape: tuple[int, ...], diff --git a/torch/nn/utils/_expanded_weights/expanded_weights_impl.py b/torch/nn/utils/_expanded_weights/expanded_weights_impl.py index cfb1d99ac30ec..58ef67e06148a 100644 --- a/torch/nn/utils/_expanded_weights/expanded_weights_impl.py +++ b/torch/nn/utils/_expanded_weights/expanded_weights_impl.py @@ -37,10 +37,10 @@ # all of the RNN decomps run linear with the batch dimension second, even if batch_first was set @contextmanager def batch_second(args, kwargs): - def set_batch_second(ew): + def set_batch_second(ew) -> None: ew.set_batch_first(False) - def reset_batch_first(ew): + def reset_batch_first(ew) -> None: ew.set_batch_first(True) tree_map_only(ExpandedWeight, set_batch_second, args) @@ -55,10 +55,10 @@ def reset_batch_first(ew): # to support packed sequences, we need to allow for smaller batches. Expanded weights represents the largest batch @contextmanager def allow_smaller_batches(args, kwargs): - def allow(ew): + def allow(ew) -> None: ew.set_allow_smaller_batches(True) - def reset(ew): + def reset(ew) -> None: ew.set_allow_smaller_batches(False) tree_map_only(ExpandedWeight, allow, args) @@ -102,7 +102,7 @@ def decorator(autograd_func): # # Needs to be a tensor subclass to allow reparameterization class ExpandedWeight(torch.Tensor): - def __init__(self, orig_weight, batch_size, loss_reduction): + def __init__(self, orig_weight, batch_size, loss_reduction) -> None: self.batch_size = batch_size self.batch_first = True self.allow_smaller_batches = False @@ -179,8 +179,8 @@ def data_ptr(self): def get_device(self): return self.orig_weight.get_device() - def set_allow_smaller_batches(self, is_allow_smaller_batches): + def set_allow_smaller_batches(self, is_allow_smaller_batches) -> None: self.allow_smaller_batches = is_allow_smaller_batches - def set_batch_first(self, is_batch_first=True): + def set_batch_first(self, is_batch_first=True) -> None: self.batch_first = is_batch_first diff --git a/torch/nn/utils/_expanded_weights/expanded_weights_utils.py b/torch/nn/utils/_expanded_weights/expanded_weights_utils.py index ec6d55305fb46..eacd717873ec2 100644 --- a/torch/nn/utils/_expanded_weights/expanded_weights_utils.py +++ b/torch/nn/utils/_expanded_weights/expanded_weights_utils.py @@ -123,7 +123,7 @@ def maybe_scale_by_batch_size(grad_sample, expanded_weight): return grad_sample -def set_grad_sample_if_exists(maybe_expanded_weight, per_sample_grad_fn): +def set_grad_sample_if_exists(maybe_expanded_weight, per_sample_grad_fn) -> None: unpacked = unpack_expanded_weight_or_tensor(maybe_expanded_weight) if isinstance(maybe_expanded_weight, ExpandedWeight): grad_sample_contribution = maybe_scale_by_batch_size( diff --git a/torch/nn/utils/parametrizations.py b/torch/nn/utils/parametrizations.py index 7706be61e39f1..59044b72b96cd 100644 --- a/torch/nn/utils/parametrizations.py +++ b/torch/nn/utils/parametrizations.py @@ -388,7 +388,7 @@ def _weight_norm_compat_hook( missing_keys, unexpected_keys, error_msgs, - ): + ) -> None: g_key = f"{prefix}{name}_g" v_key = f"{prefix}{name}_v" if g_key in state_dict and v_key in state_dict: diff --git a/torch/nn/utils/parametrize.py b/torch/nn/utils/parametrize.py index 88eeb3aaf50c3..b9a1140e43f71 100644 --- a/torch/nn/utils/parametrize.py +++ b/torch/nn/utils/parametrize.py @@ -72,7 +72,7 @@ def cached(): _cache = {} -def _register_parameter_or_buffer(module, name, X): +def _register_parameter_or_buffer(module, name, X) -> None: if isinstance(X, Parameter): module.register_parameter(name, X) else: diff --git a/torch/nn/utils/prune.py b/torch/nn/utils/prune.py index 3c1a800085951..827bf19ed4bea 100644 --- a/torch/nn/utils/prune.py +++ b/torch/nn/utils/prune.py @@ -231,7 +231,7 @@ def prune(self, t, default_mask=None, importance_scores=None): default_mask = default_mask if default_mask is not None else torch.ones_like(t) return t * self.compute_mask(importance_scores, default_mask=default_mask) - def remove(self, module): + def remove(self, module) -> None: r"""Remove the pruning reparameterization from a module. The pruned parameter named ``name`` remains permanently pruned, @@ -269,7 +269,7 @@ class PruningContainer(BasePruningMethod): them. """ - def __init__(self, *args): + def __init__(self, *args) -> None: self._pruning_methods: tuple[BasePruningMethod, ...] = () if not isinstance(args, Iterable): # only 1 item self._tensor_name = args._tensor_name @@ -284,7 +284,7 @@ def __init__(self, *args): for method in args: self.add_pruning_method(method) - def add_pruning_method(self, method): + def add_pruning_method(self, method) -> None: r"""Add a child pruning ``method`` to the container. Args: @@ -303,7 +303,7 @@ def add_pruning_method(self, method): # if all checks passed, add to _pruning_methods tuple self._pruning_methods += (method,) # type: ignore[operator] - def __len__(self): + def __len__(self) -> int: return len(self._pruning_methods) def __iter__(self): @@ -449,7 +449,7 @@ class RandomUnstructured(BasePruningMethod): PRUNING_TYPE = "unstructured" - def __init__(self, amount): + def __init__(self, amount) -> None: # Check range of validity of pruning amount _validate_pruning_amount_init(amount) self.amount = amount @@ -506,7 +506,7 @@ class L1Unstructured(BasePruningMethod): PRUNING_TYPE = "unstructured" - def __init__(self, amount): + def __init__(self, amount) -> None: # Check range of validity of pruning amount _validate_pruning_amount_init(amount) self.amount = amount @@ -574,7 +574,7 @@ class RandomStructured(BasePruningMethod): PRUNING_TYPE = "structured" - def __init__(self, amount, dim=-1): + def __init__(self, amount, dim=-1) -> None: # Check range of validity of amount _validate_pruning_amount_init(amount) self.amount = amount @@ -682,7 +682,7 @@ class LnStructured(BasePruningMethod): PRUNING_TYPE = "structured" - def __init__(self, amount, n, dim=-1): + def __init__(self, amount, n, dim=-1) -> None: # Check range of validity of amount _validate_pruning_amount_init(amount) self.amount = amount @@ -799,7 +799,7 @@ def apply(cls, module, name, amount, n, dim, importance_scores=None): # type: i class CustomFromMask(BasePruningMethod): PRUNING_TYPE = "global" - def __init__(self, mask): + def __init__(self, mask) -> None: self.mask = mask def compute_mask(self, t, default_mask): @@ -1025,7 +1025,9 @@ def ln_structured(module, name, amount, n, dim, importance_scores=None): return module -def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs): +def global_unstructured( + parameters, pruning_method, importance_scores=None, **kwargs +) -> None: r""" Globally prunes tensors corresponding to all parameters in ``parameters`` by applying the specified ``pruning_method``. @@ -1212,7 +1214,7 @@ def remove(module, name): ) -def is_pruned(module): +def is_pruned(module) -> bool: r"""Check if a module is pruned by looking for pruning pre-hooks. Check whether ``module`` is pruned by looking for @@ -1241,7 +1243,7 @@ def is_pruned(module): return False -def _validate_pruning_amount_init(amount): +def _validate_pruning_amount_init(amount) -> None: r"""Validate helper to check the range of amount at init. Args: @@ -1271,7 +1273,7 @@ def _validate_pruning_amount_init(amount): ) -def _validate_pruning_amount(amount, tensor_size): +def _validate_pruning_amount(amount, tensor_size) -> None: r"""Validate that the pruning amount is meaningful wrt to the size of the data. Validation helper to check that the amount of parameters to prune @@ -1295,7 +1297,7 @@ def _validate_pruning_amount(amount, tensor_size): ) -def _validate_structured_pruning(t): +def _validate_structured_pruning(t) -> None: r"""Validate that the tensor to be pruned is at least 2-Dimensional. Validation helper to check that the tensor to be pruned is multi- @@ -1342,7 +1344,7 @@ def _compute_nparams_toprune(amount, tensor_size): return round(amount * tensor_size) -def _validate_pruning_dim(t, dim): +def _validate_pruning_dim(t, dim) -> None: r"""Validate that the pruning dimension is within the bounds of the tensor dimension. Args: diff --git a/torch/optim/_adafactor.py b/torch/optim/_adafactor.py index 4def193daf190..c417b354429b5 100644 --- a/torch/optim/_adafactor.py +++ b/torch/optim/_adafactor.py @@ -32,7 +32,7 @@ def __init__( *, foreach: Optional[bool] = None, maximize: bool = False, - ): + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -77,7 +77,7 @@ def _init_group( col_vars, variances, state_steps, - ): + ) -> bool: for p in group["params"]: if p.grad is None: continue @@ -349,7 +349,7 @@ def _single_tensor_adafactor( eps2: float, maximize: bool, has_complex: bool, -): +) -> None: if grad_scale is not None or found_inf is not None: raise AssertionError("Grad scaling should occur outside of optimizer.step()") @@ -473,7 +473,7 @@ def _multi_tensor_adafactor( eps2: float, maximize: bool, has_complex: bool, -): +) -> None: if len(params) == 0: return @@ -624,7 +624,7 @@ def adafactor( eps1: float, eps2: float, maximize: bool, -): +) -> None: r"""Functional API that performs Adafactor algorithm computation. See :class:`~torch.optim.Adafactor` for details. diff --git a/torch/optim/_functional.py b/torch/optim/_functional.py index 9b2c76700b356..ba97bc9979378 100644 --- a/torch/optim/_functional.py +++ b/torch/optim/_functional.py @@ -33,7 +33,7 @@ def sparse_adam( beta2: float, lr: float, maximize: bool, -): +) -> None: r"""Functional API that performs Sparse Adam algorithm computation. See :class:`~torch.optim.SparseAdam` for details. diff --git a/torch/optim/_muon.py b/torch/optim/_muon.py index 7b7167a40fc1c..5b7b9892daf3a 100644 --- a/torch/optim/_muon.py +++ b/torch/optim/_muon.py @@ -141,7 +141,7 @@ def _init_group( params_with_grad: list[Tensor], grads: list[Tensor], muon_momentum_bufs: list[Tensor], - ): + ) -> bool: for p in group["params"]: if p.grad is None: continue @@ -337,7 +337,7 @@ def muon( eps: float, adjust_lr_fn: Optional[str], has_complex: bool, -): +) -> None: r"""Functional API that performs Muon algorithm computation. See :class:`~torch.optim.Muon` for details. diff --git a/torch/optim/adadelta.py b/torch/optim/adadelta.py index 4a893026451ae..75ac77790e309 100644 --- a/torch/optim/adadelta.py +++ b/torch/optim/adadelta.py @@ -38,7 +38,7 @@ def __init__( capturable: bool = False, maximize: bool = False, differentiable: bool = False, - ): + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -257,7 +257,7 @@ def _single_tensor_adadelta( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] if not torch.compiler.is_compiling() and capturable: capturable_supported_devices = _get_capturable_supported_devices( @@ -317,7 +317,7 @@ def _multi_tensor_adadelta( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: if differentiable: raise AssertionError("_foreach ops don't support autograd") @@ -427,7 +427,7 @@ def adadelta( eps: float, weight_decay: float, maximize: bool, -): +) -> None: r"""Functional API that performs Adadelta algorithm computation. See :class:`~torch.optim.Adadelta` for details. diff --git a/torch/optim/adagrad.py b/torch/optim/adagrad.py index 4d2523b2a16af..519900ab5da63 100644 --- a/torch/optim/adagrad.py +++ b/torch/optim/adagrad.py @@ -38,7 +38,7 @@ def __init__( maximize: bool = False, differentiable: bool = False, fused: Optional[bool] = None, - ): + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -116,7 +116,7 @@ def __setstate__(self, state): float(s["step"]), dtype=_get_scalar_dtype(is_fused=fused) ) - def share_memory(self): + def share_memory(self) -> None: """Calls tensor.share_memory_() on the state sum tensors.""" for group in self.param_groups: for p in group["params"]: @@ -261,7 +261,7 @@ def adagrad( lr_decay: float, eps: float, maximize: bool, -): +) -> None: r"""Functional API that performs Adagrad algorithm computation. See :class:`~torch.optim.Adagrad` for details. @@ -336,7 +336,7 @@ def _single_tensor_adagrad( maximize: bool, differentiable: bool, has_complex: bool, -): +) -> None: if grad_scale is not None or found_inf is not None: raise AssertionError("Expected grad_scale and found_inf to be None") @@ -404,7 +404,7 @@ def _multi_tensor_adagrad( maximize: bool, differentiable: bool, has_complex: bool, -): +) -> None: if differentiable: raise AssertionError("_foreach ops don't support autograd") if grad_scale is not None or found_inf is not None: diff --git a/torch/optim/adam.py b/torch/optim/adam.py index 5ceadccce86a5..6b8fd5b7e70f6 100644 --- a/torch/optim/adam.py +++ b/torch/optim/adam.py @@ -47,7 +47,7 @@ def __init__( differentiable: bool = False, fused: Optional[bool] = None, decoupled_weight_decay: bool = False, - ): + ) -> None: if isinstance(lr, Tensor): if foreach and not capturable: raise ValueError( @@ -365,7 +365,7 @@ def _single_tensor_adam( capturable: bool, differentiable: bool, decoupled_weight_decay: bool, -): +) -> None: if grad_scale is not None or found_inf is not None: raise AssertionError("Expected grad_scale and found_inf to be None") @@ -572,7 +572,7 @@ def _multi_tensor_adam( capturable: bool, differentiable: bool, decoupled_weight_decay: bool, -): +) -> None: if len(params) == 0: return @@ -925,7 +925,7 @@ def adam( weight_decay: float, eps: float, maximize: bool, -): +) -> None: r"""Functional API that performs Adam algorithm computation. See :class:`~torch.optim.Adam` for details. diff --git a/torch/optim/adamax.py b/torch/optim/adamax.py index 76d784d6ea764..264451dbb4091 100644 --- a/torch/optim/adamax.py +++ b/torch/optim/adamax.py @@ -39,7 +39,7 @@ def __init__( maximize: bool = False, differentiable: bool = False, capturable: bool = False, - ): + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -239,7 +239,7 @@ def _single_tensor_adamax( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: if not torch.jit.is_scripting(): lr = _to_scalar(lr) @@ -319,7 +319,7 @@ def _multi_tensor_adamax( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: if differentiable: raise AssertionError("_foreach ops don't support autograd") @@ -441,7 +441,7 @@ def adamax( beta2: float, lr: float, weight_decay: float, -): +) -> None: r"""Functional API that performs adamax algorithm computation. See :class:`~torch.optim.Adamax` for details. diff --git a/torch/optim/adamw.py b/torch/optim/adamw.py index 0558cbddd883b..2c968fabb698c 100644 --- a/torch/optim/adamw.py +++ b/torch/optim/adamw.py @@ -33,7 +33,7 @@ def __init__( capturable: bool = False, differentiable: bool = False, fused: Optional[bool] = None, - ): + ) -> None: super().__init__( params, lr, @@ -152,7 +152,7 @@ def adamw( weight_decay: float, eps: float, maximize: bool, -): +) -> None: r"""Functional API that performs AdamW algorithm computation. See :class:`~torch.optim.AdamW` for details. diff --git a/torch/optim/asgd.py b/torch/optim/asgd.py index 0008694bda18b..0af7f9b4e6f6d 100644 --- a/torch/optim/asgd.py +++ b/torch/optim/asgd.py @@ -39,7 +39,7 @@ def __init__( maximize: bool = False, differentiable: bool = False, capturable: bool = False, - ): + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -211,7 +211,7 @@ def _single_tensor_asgd( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: if not torch.jit.is_scripting(): lr = _to_scalar(lr) @@ -292,7 +292,7 @@ def _multi_tensor_asgd( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: if len(params) == 0: return @@ -442,7 +442,7 @@ def asgd( t0: float, alpha: float, weight_decay: float, -): +) -> None: r"""Functional API that performs asgd algorithm computation. See :class:`~torch.optim.ASGD` for details. diff --git a/torch/optim/lbfgs.py b/torch/optim/lbfgs.py index ae4b286ffa225..3d138f6a43f76 100644 --- a/torch/optim/lbfgs.py +++ b/torch/optim/lbfgs.py @@ -254,7 +254,7 @@ def __init__( tolerance_change: float = 1e-9, history_size: int = 100, line_search_fn: Optional[str] = None, - ): + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -304,7 +304,7 @@ def _gather_flat_grad(self): views.append(view) return torch.cat(views, 0) - def _add_grad(self, step_size, update): + def _add_grad(self, step_size, update) -> None: offset = 0 for p in self._params: if torch.is_complex(p): @@ -319,7 +319,7 @@ def _add_grad(self, step_size, update): def _clone_param(self): return [p.clone(memory_format=torch.contiguous_format) for p in self._params] - def _set_param(self, params_data): + def _set_param(self, params_data) -> None: for p, pdata in zip(self._params, params_data, strict=True): p.copy_(pdata) diff --git a/torch/optim/lr_scheduler.py b/torch/optim/lr_scheduler.py index 71dcb6129a8ec..6426283e6542c 100644 --- a/torch/optim/lr_scheduler.py +++ b/torch/optim/lr_scheduler.py @@ -89,7 +89,9 @@ def _param_groups_val_list(optimizer: Optimizer, key: str) -> list[Any]: ] -def _update_param_group_val(param_group: dict[str, Any], key: str, val: float | Tensor): +def _update_param_group_val( + param_group: dict[str, Any], key: str, val: float | Tensor +) -> None: """Set param_group[key] to val without aliasing or assignment when they're both tensors. Raises a KeyError if param_group[key] does not exist. """ @@ -196,7 +198,7 @@ def state_dict(self) -> dict[str, Any]: key: value for key, value in self.__dict__.items() if key != "optimizer" } - def load_state_dict(self, state_dict: dict[str, Any]): + def load_state_dict(self, state_dict: dict[str, Any]) -> None: """Load the scheduler's state. Args: @@ -288,7 +290,7 @@ def step(self, epoch: Optional[int] = None) -> None: warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning, stacklevel=2) self._update_lr(epoch) - def _update_lr(self, epoch: Optional[int] = None): + def _update_lr(self, epoch: Optional[int] = None) -> None: with _enable_get_lr_call(self): if epoch is None: self.last_epoch += 1 @@ -339,7 +341,7 @@ def __exit__(self, type, value, traceback) -> None: class _initial_mode: - def __init__(self, o: LRScheduler): + def __init__(self, o: LRScheduler) -> None: self.o = o def __enter__(self): @@ -1180,7 +1182,7 @@ def __init__( self._last_lr = schedulers[0].get_last_lr() - def recursive_undo(self, sched=None): + def recursive_undo(self, sched=None) -> None: """ Recursively undo any step performed by the initialisation of schedulers. @@ -1659,7 +1661,7 @@ def __init__( cooldown: int = 0, min_lr: Union[list[float], float] = 0, eps: float = 1e-8, - ): # noqa: D107 + ) -> None: # noqa: D107 if factor >= 1.0: raise ValueError("Factor should be < 1.0.") self.factor = factor @@ -1691,7 +1693,7 @@ def __init__( ) self._reset() - def _reset(self): + def _reset(self) -> None: """Reset num_bad_epochs counter and cooldown counter.""" self.best = self.mode_worse self.cooldown_counter = 0 @@ -1724,7 +1726,7 @@ def step(self, metrics: SupportsFloat, epoch=None) -> None: # type: ignore[over self._last_lr = _param_groups_val_list(self.optimizer, "lr") - def _reduce_lr(self, epoch): + def _reduce_lr(self, epoch) -> None: if len(self.optimizer.param_groups) != len(self.min_lrs): if self.default_min_lr is None: raise RuntimeError( @@ -1765,7 +1767,7 @@ def _is_better(self, a, best): # noqa: D102 else: # mode == 'max' and epsilon_mode == 'abs': return a > best + self.threshold - def _init_is_better(self, mode, threshold, threshold_mode): + def _init_is_better(self, mode, threshold, threshold_mode) -> None: if mode not in {"min", "max"}: raise ValueError("mode " + mode + " is unknown!") if threshold_mode not in {"rel", "abs"}: @@ -1904,7 +1906,7 @@ def __init__( base_momentum: float = 0.8, max_momentum: float = 0.9, last_epoch: int = -1, - ): # noqa: D107 + ) -> None: # noqa: D107 # Attach optimizer if not isinstance(optimizer, Optimizer): raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") @@ -1970,7 +1972,7 @@ def __init__( super().__init__(optimizer, last_epoch) self.base_lrs = base_lrs - def _init_scale_fn(self): + def _init_scale_fn(self) -> None: if self._scale_fn_custom is not None: return if self.mode == "triangular": @@ -2155,7 +2157,7 @@ def __init__( T_mult: int = 1, eta_min: float = 0.0, last_epoch: int = -1, - ): # noqa: D107 + ) -> None: # noqa: D107 if T_0 <= 0 or not isinstance(T_0, int): raise ValueError(f"Expected positive integer T_0, but got {T_0}") if T_mult < 1 or not isinstance(T_mult, int): @@ -2407,7 +2409,7 @@ def __init__( final_div_factor: float = 1e4, three_phase: bool = False, last_epoch: int = -1, - ): # noqa: D107 + ) -> None: # noqa: D107 # Validate optimizer if not isinstance(optimizer, Optimizer): raise TypeError(f"{type(optimizer).__name__} is not an Optimizer") diff --git a/torch/optim/nadam.py b/torch/optim/nadam.py index 508648a65c14a..f83cd4b85d02f 100644 --- a/torch/optim/nadam.py +++ b/torch/optim/nadam.py @@ -44,7 +44,7 @@ def __init__( maximize: bool = False, capturable: bool = False, differentiable: bool = False, - ): # noqa: D107 + ) -> None: # noqa: D107 if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -297,7 +297,7 @@ def _single_tensor_nadam( capturable: bool, differentiable: bool, has_complex: bool, -): +) -> None: if not torch.jit.is_scripting(): lr = _to_scalar(lr) @@ -397,7 +397,7 @@ def _multi_tensor_nadam( capturable: bool, differentiable: bool, has_complex: bool, -): +) -> None: if len(params) == 0: return @@ -624,7 +624,7 @@ def nadam( weight_decay: float, momentum_decay: float, eps: float, -): +) -> None: r"""Functional API that performs NAdam algorithm computation. See :class:`~torch.optim.NAdam` for details. diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py index 6a336fa5bab70..c42ea3cfb02d5 100644 --- a/torch/optim/optimizer.py +++ b/torch/optim/optimizer.py @@ -204,7 +204,7 @@ def _device_dtype_check_for_fused( ) -def _view_as_real(params, *state_and_grads): +def _view_as_real(params, *state_and_grads) -> None: for i, p in enumerate(params): if torch.is_complex(p): params[i] = torch.view_as_real(params[i]) diff --git a/torch/optim/radam.py b/torch/optim/radam.py index e13e6806e43a7..db69bbb01a042 100644 --- a/torch/optim/radam.py +++ b/torch/optim/radam.py @@ -42,7 +42,7 @@ def __init__( maximize: bool = False, capturable: bool = False, differentiable: bool = False, - ): # noqa: D107 + ) -> None: # noqa: D107 if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -270,7 +270,7 @@ def _single_tensor_radam( maximize: bool, capturable: bool, has_complex: bool, -): +) -> None: if not torch.jit.is_scripting(): lr = _to_scalar(lr) @@ -377,7 +377,7 @@ def _multi_tensor_radam( maximize: bool, capturable: bool, has_complex: bool, -): +) -> None: if len(params) == 0: return @@ -586,7 +586,7 @@ def radam( lr: float, weight_decay: float, eps: float, -): +) -> None: r"""Functional API that performs RAdam algorithm computation. See :class:`~torch.optim.RAdam` for details. diff --git a/torch/optim/rmsprop.py b/torch/optim/rmsprop.py index 04981d517d1ef..364068ecc9ab3 100644 --- a/torch/optim/rmsprop.py +++ b/torch/optim/rmsprop.py @@ -41,7 +41,7 @@ def __init__( foreach: Optional[bool] = None, maximize: bool = False, differentiable: bool = False, - ): # noqa: D107 + ) -> None: # noqa: D107 if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -280,7 +280,7 @@ def _single_tensor_rmsprop( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: if not torch.jit.is_scripting(): lr = _to_scalar(lr) @@ -357,7 +357,7 @@ def _multi_tensor_rmsprop( differentiable: bool, capturable: bool, has_complex: bool, -): +) -> None: if len(params) == 0: return @@ -495,7 +495,7 @@ def rmsprop( weight_decay: float, momentum: float, centered: bool, -): +) -> None: r"""Functional API that performs rmsprop algorithm computation. See :class:`~torch.optim.RMSProp` for details. diff --git a/torch/optim/rprop.py b/torch/optim/rprop.py index 8ad7faf130e39..c9e1d5eabaeee 100644 --- a/torch/optim/rprop.py +++ b/torch/optim/rprop.py @@ -39,7 +39,7 @@ def __init__( foreach: Optional[bool] = None, maximize: bool = False, differentiable: bool = False, - ): # noqa: D107 + ) -> None: # noqa: D107 if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 <= lr: @@ -235,7 +235,7 @@ def _single_tensor_rprop( capturable: bool, differentiable: bool, has_complex: bool, -): +) -> None: for i, param in enumerate(params): grad = grads[i] grad = grad if not maximize else -grad @@ -306,7 +306,7 @@ def _multi_tensor_rprop( capturable: bool, differentiable: bool, has_complex: bool, -): +) -> None: if len(params) == 0: return @@ -428,7 +428,7 @@ def rprop( step_size_max: float, etaminus: float, etaplus: float, -): +) -> None: r"""Functional API that performs rprop algorithm computation. See :class:`~torch.optim.Rprop` for details. diff --git a/torch/optim/sgd.py b/torch/optim/sgd.py index 9c2c5a0eab3d0..63c80d645cd08 100644 --- a/torch/optim/sgd.py +++ b/torch/optim/sgd.py @@ -39,7 +39,7 @@ def __init__( foreach: Optional[bool] = None, differentiable: bool = False, fused: Optional[bool] = None, - ): # noqa: D107 + ) -> None: # noqa: D107 if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if lr < 0.0: @@ -267,7 +267,7 @@ def sgd( dampening: float, nesterov: bool, maximize: bool, -): +) -> None: r"""Functional API that performs SGD algorithm computation. See :class:`~torch.optim.SGD` for details. @@ -333,7 +333,7 @@ def _single_tensor_sgd( nesterov: bool, maximize: bool, has_sparse_grad: bool, -): +) -> None: if grad_scale is not None or found_inf is not None: raise AssertionError("Expected grad_scale and found_inf to be None") @@ -394,7 +394,7 @@ def _multi_tensor_sgd( nesterov: bool, maximize: bool, has_sparse_grad: bool, -): +) -> None: if grad_scale is not None or found_inf is not None: raise AssertionError("Expected grad_scale and found_inf to be None") diff --git a/torch/optim/sparse_adam.py b/torch/optim/sparse_adam.py index ca87e87ce8674..ed58c93181ae2 100644 --- a/torch/optim/sparse_adam.py +++ b/torch/optim/sparse_adam.py @@ -19,7 +19,7 @@ def __init__( betas: tuple[float, float] = (0.9, 0.999), eps: float = 1e-8, maximize: bool = False, - ): + ) -> None: if isinstance(lr, Tensor) and lr.numel() != 1: raise ValueError("Tensor lr must be 1-element") if not 0.0 < lr: diff --git a/torch/optim/swa_utils.py b/torch/optim/swa_utils.py index 254560d8751ce..ebe3e07025957 100644 --- a/torch/optim/swa_utils.py +++ b/torch/optim/swa_utils.py @@ -43,7 +43,9 @@ def get_ema_multi_avg_fn(decay=0.999): ) @torch.no_grad() - def ema_update(ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _): + def ema_update( + ema_param_list: PARAM_LIST, current_param_list: PARAM_LIST, _ + ) -> None: # foreach lerp only handles float and complex if torch.is_floating_point(ema_param_list[0]) or torch.is_complex( ema_param_list[0] @@ -64,7 +66,7 @@ def swa_update( averaged_param_list: PARAM_LIST, current_param_list: PARAM_LIST, num_averaged: Union[Tensor, int], - ): + ) -> None: # foreach lerp only handles float and complex if torch.is_floating_point(averaged_param_list[0]) or torch.is_complex( averaged_param_list[0] @@ -227,7 +229,7 @@ def __init__( Callable[[PARAM_LIST, PARAM_LIST, Union[Tensor, int]], None] ] = None, use_buffers=False, - ): # noqa: D107 + ) -> None: # noqa: D107 super().__init__() if avg_fn is not None and multi_avg_fn is not None: raise AssertionError( @@ -247,7 +249,7 @@ def forward(self, *args, **kwargs): """Forward pass.""" return self.module(*args, **kwargs) - def update_parameters(self, model: Module): + def update_parameters(self, model: Module) -> None: """Update model parameters.""" self_param = ( # pyrefly: ignore [bad-argument-type] @@ -329,7 +331,7 @@ def update_bn( loader: Iterable[Any], model: Module, device: Optional[Union[int, torch.device]] = None, -): +) -> None: r"""Update BatchNorm running_mean, running_var buffers in the model. It performs one pass over data in `loader` to estimate the activation @@ -434,7 +436,7 @@ def __init__( anneal_epochs=10, anneal_strategy: Literal["cos", "linear"] = "cos", last_epoch=-1, - ): # noqa: D107 + ) -> None: # noqa: D107 swa_lrs = _format_param("swa_lr", optimizer, swa_lr) for swa_lr, group in zip(swa_lrs, optimizer.param_groups, strict=True): group["swa_lr"] = swa_lr @@ -516,7 +518,7 @@ def get_lr(self): for group, lr in zip(self.optimizer.param_groups, prev_lrs, strict=True) ] - def _set_anneal_func(self, anneal_strategy: Literal["cos", "linear"]): + def _set_anneal_func(self, anneal_strategy: Literal["cos", "linear"]) -> None: self._anneal_strategy = anneal_strategy if anneal_strategy == "cos": self.anneal_func = self._cosine_anneal From 3feea296a59c2dfc1d2f4b7e0e5d3f61fd4bf7ea Mon Sep 17 00:00:00 2001 From: Mark Barnes Date: Thu, 6 Nov 2025 04:33:05 +0000 Subject: [PATCH 111/130] torch.fx: add debug-level logging to Interpreter.run_node (#117351) (#166622) ### Summary Adds a debug-level logging statement to torch.fx.Interpreter.run_node, as proposed in [#117351](https://github.com/pytorch/pytorch/issues/117351), to make FX graph execution traceable when debugging or instrumenting model transformations. When debug logging is enabled, each executed node emits a single structured log line formatted via `LazyString(lambda: n.format_node())`, deferring string construction unless logging is active. ### Example Output With `logging.DEBUG` enabled: ``` run_node x = x() run_node add = _operator.add(x, 1) run_node clamp = torch.clamp(add, min=0.0, max=5.0) run_node output = output(clamp) ``` With `logging.DEBUG` disabled no additional output is produced (unchanged default behavior). ### Test Plan Verified locally with Python 3.11 on macOS using a PyTorch build from source. - With `logging.DEBUG` enabled: each node emits a debug log via LazyString. - With `logging.DEBUG` disabled: no additional output. - Confirmed all `Interpreter` tests pass locally: `pytest test/test_fx.py -k "Interpreter"` Updated the example output to reflect the new `_format_fx_node` helper and inclusion of `kwargs`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166622 Approved by: https://github.com/aorenste --- torch/fx/interpreter.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py index 5ad1424c4e489..5b40e8a66147f 100644 --- a/torch/fx/interpreter.py +++ b/torch/fx/interpreter.py @@ -1,11 +1,12 @@ # mypy: allow-untyped-defs import inspect +import logging from contextlib import contextmanager from typing import Any, Optional, TYPE_CHECKING, Union import torch import torch.fx.traceback as fx_traceback -from torch._logging import trace_structured +from torch._logging import LazyString, trace_structured from torch.hub import tqdm from . import config @@ -21,10 +22,35 @@ if TYPE_CHECKING: from collections.abc import Iterator +log = logging.getLogger(__name__) __all__ = ["Interpreter", "Transformer"] +def _format_fx_node(n): + """ + Format a torch.fx.Node into a human-readable string for debug logging. + + Args: + n (torch.fx.Node): The FX node being executed. + + Returns: + str: A formatted string describing the node operation, including its + name, target, positional arguments, and keyword arguments. + """ + module_prefix = getattr(n.target, "__module__", "") + module_prefix = f"{module_prefix}." if module_prefix else "" + + # Handle positional and keyword arguments + args = ", ".join(map(str, n.args)) + kwargs = ", ".join(f"{k}={v}" for k, v in n.kwargs.items()) + joined = ", ".join(filter(None, [args, kwargs])) + + return ( + f"{n.name} = {module_prefix}{getattr(n.target, '__name__', n.target)}({joined})" + ) + + @compatibility(is_backward_compatible=True) class Interpreter: """ @@ -261,6 +287,7 @@ def run_node(self, n: Node) -> Any: Returns: Any: The result of executing ``n`` """ + log.debug("run_node %s", LazyString(lambda: _format_fx_node(n))) with self._set_current_node(n): args, kwargs = self.fetch_args_kwargs_from_env(n) assert isinstance(args, tuple) From eea951758fcb71ed544bee9f83e67913dec26aaf Mon Sep 17 00:00:00 2001 From: William Wen Date: Tue, 4 Nov 2025 16:49:31 -0800 Subject: [PATCH 112/130] [dynamo, 3.14] disable dynamo cpython tests in 3.14 (again) (#167000) The previous PR was not enough to prevent errors caused by cpython dynamo tests in 3.14 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167000 Approved by: https://github.com/mlazos, https://github.com/guilhermeleobas --- test/run_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/run_test.py b/test/run_test.py index aa6a6d04cde3e..764b20dc9adc2 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -1687,7 +1687,7 @@ def get_selected_tests(options) -> list[str]: ] ) - if sys.version_info[:2] < (3, 13): + if sys.version_info[:2] < (3, 13) or sys.version_info[:2] >= (3, 14): # Skip tests for older Python versions as they may use syntax or features # not supported in those versions options.exclude.extend( From 91337ae3ffd5e3a5204c9e47aeaa2d093710a46c Mon Sep 17 00:00:00 2001 From: PyTorch UpdateBot Date: Thu, 6 Nov 2025 04:57:01 +0000 Subject: [PATCH 113/130] [audio hash update] update the pinned audio hash (#167031) This PR is auto-generated nightly by [this action](https://github.com/pytorch/pytorch/blob/main/.github/workflows/nightly.yml). Update the pinned audio hash. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167031 Approved by: https://github.com/pytorchbot --- .github/ci_commit_pins/audio.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/ci_commit_pins/audio.txt b/.github/ci_commit_pins/audio.txt index 966f6bcfc0d94..14144f3c11e2d 100644 --- a/.github/ci_commit_pins/audio.txt +++ b/.github/ci_commit_pins/audio.txt @@ -1 +1 @@ -3b0e7a6f192ca2715e7e6cbe5db007aea7165fe2 +ad5816f0eee1c873df1b7d371c69f1f811a89387 From f7b7f40a6fed52a7190301b8dfebc528b349c8d4 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Wed, 5 Nov 2025 17:09:12 -0800 Subject: [PATCH 114/130] [user-streams] Enable stream ops to work in eager (#167141) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167141 Approved by: https://github.com/Lucaskabela --- test/dynamo/test_streams.py | 27 ++++++++++++++++++++------- torch/_dynamo/variables/streams.py | 30 ++++++++++++++---------------- 2 files changed, 34 insertions(+), 23 deletions(-) diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index b9a3855f6ddbb..6b7ad5ce0ce96 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -7,6 +7,10 @@ import torch import torch._dynamo.test_case import torch._dynamo.testing +from torch._dynamo.graph_bytecode_inputs import ( + reset_user_object_tracking, + store_user_object_weakrefs, +) from torch._dynamo.testing import extract_graph, remove_trailing_space from torch.testing._internal.common_cuda import TEST_MULTIGPU from torch.testing._internal.common_utils import requires_cuda @@ -441,13 +445,22 @@ def test_run_opcheck(self): from torch._dynamo.variables.streams import fork_stream, join_stream from torch.library import opcheck - sample_inputs = [ - (0, torch.device("cuda:0"), 1, torch.device("cuda:1")), - (2, torch.device("cuda:2"), 3, torch.device("cuda:1")), - ] - for args in sample_inputs: - opcheck(fork_stream, args) - opcheck(join_stream, args) + original_stream = torch.accelerator.current_stream() + try: + s0 = torch.Stream() + s1 = torch.Stream() + store_user_object_weakrefs(s0, s1) + + sample_inputs = [ + (0, 1), + (1, 0), + ] + for args in sample_inputs: + opcheck(fork_stream, args) + opcheck(join_stream, args) + finally: + torch.accelerator.set_stream(original_stream) + reset_user_object_tracking() if __name__ == "__main__": diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index fb5dd775bd636..bb9552186da6d 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -10,6 +10,7 @@ from .. import graph_break_hints from ..bytecode_transformation import create_call_function from ..exc import TYPE_CHECKING, unimplemented_v2 +from ..graph_bytecode_inputs import get_external_object_by_index from .base import VariableTracker from .constant import ConstantVariable from .ctx_manager import FxTracebackAnnotateVariable @@ -29,40 +30,37 @@ @custom_op("streams::fork", mutates_args=()) def fork_stream( - from_index: int, - from_device: torch.device, + from_index: int, # kept to make stream transitions clearer to_index: int, - to_device: torch.device, ) -> None: - pass + stream = get_external_object_by_index(to_index) + assert isinstance(stream, torch.Stream), ( + f"fork_stream expects a stream object at index {to_index}" + ) + torch.accelerator.set_stream(stream) @fork_stream.register_fake def _( - from_index: int, - from_device: torch.device, + from_index: int, # kept to make stream transitions clearer to_index: int, - to_device: torch.device, ) -> None: pass @custom_op("streams::join", mutates_args=()) -def join_stream( - from_index: int, - from_device: torch.device, - to_index: int, - to_device: torch.device, -) -> None: - pass +def join_stream(from_index: int, to_index: int) -> None: + stream = get_external_object_by_index(to_index) + assert isinstance(stream, torch.Stream), ( + f"join_stream expects a stream object at index {to_index}" + ) + torch.accelerator.set_stream(stream) @join_stream.register_fake def _( from_index: int, - from_device: torch.device, to_index: int, - to_device: torch.device, ) -> None: pass From 46b3f913b351ccf3696932afa7f31c1b1b8bfee7 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Wed, 5 Nov 2025 17:09:13 -0800 Subject: [PATCH 115/130] [user-streams] Add record/wait ops (#167151) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167151 Approved by: https://github.com/Lucaskabela ghstack dependencies: #167141 --- test/dynamo/test_streams.py | 26 +++++++++++++- torch/_dynamo/variables/streams.py | 58 ++++++++++++++++++++++++------ 2 files changed, 73 insertions(+), 11 deletions(-) diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index 6b7ad5ce0ce96..c21ab934e5b45 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -441,7 +441,7 @@ def forward(self, tangents_1: "f32[2, 2]", tangents_2: "f32[2, 2]"): ) @requires_cuda - def test_run_opcheck(self): + def test_run_opcheck_fork_join(self): from torch._dynamo.variables.streams import fork_stream, join_stream from torch.library import opcheck @@ -462,6 +462,30 @@ def test_run_opcheck(self): torch.accelerator.set_stream(original_stream) reset_user_object_tracking() + @requires_cuda + def test_run_opcheck_wait_record(self): + from torch._dynamo.variables.streams import record_event, wait_event + from torch.library import opcheck + + original_stream = torch.accelerator.current_stream() + try: + s0 = torch.Stream() + s1 = torch.Stream() + e0 = torch.Event() + e1 = torch.Event() + store_user_object_weakrefs(s0, s1, e0, e1) + + sample_inputs = [ + (2, 0), + (3, 1), + ] + for args in sample_inputs: + opcheck(wait_event, args) + opcheck(record_event, args) + finally: + torch.accelerator.set_stream(original_stream) + reset_user_object_tracking() + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index bb9552186da6d..98084cce28b27 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -28,16 +28,28 @@ Tensor = torch.Tensor +def _get_stream_by_index(index: int) -> torch.Stream: + stream = get_external_object_by_index(index) + assert isinstance(stream, torch.Stream), ( + f"Fork/join stream expected a stream object at index {index}" + ) + return stream + + +def _get_event_by_index(index: int) -> torch.Event: + event = get_external_object_by_index(index) + assert isinstance(event, torch.Event), ( + f"Record/wait event expected an event object at index {index}" + ) + return event + + @custom_op("streams::fork", mutates_args=()) def fork_stream( from_index: int, # kept to make stream transitions clearer to_index: int, ) -> None: - stream = get_external_object_by_index(to_index) - assert isinstance(stream, torch.Stream), ( - f"fork_stream expects a stream object at index {to_index}" - ) - torch.accelerator.set_stream(stream) + torch.accelerator.set_stream(_get_stream_by_index(to_index)) @fork_stream.register_fake @@ -50,11 +62,7 @@ def _( @custom_op("streams::join", mutates_args=()) def join_stream(from_index: int, to_index: int) -> None: - stream = get_external_object_by_index(to_index) - assert isinstance(stream, torch.Stream), ( - f"join_stream expects a stream object at index {to_index}" - ) - torch.accelerator.set_stream(stream) + torch.accelerator.set_stream(_get_stream_by_index(to_index)) @join_stream.register_fake @@ -65,6 +73,36 @@ def _( pass +@custom_op("streams::record_event", mutates_args=()) +def record_event(event_index: int, stream_index: int) -> None: + event = _get_event_by_index(event_index) + stream = _get_stream_by_index(stream_index) + stream.record_event(event) + + +@record_event.register_fake +def _( + event_index: int, + stream_index: int, +) -> None: + pass + + +@custom_op("streams::wait_event", mutates_args=()) +def wait_event(event_index: int, stream_index: int) -> None: + event = _get_event_by_index(event_index) + stream = _get_stream_by_index(stream_index) + stream.wait_event(event) + + +@wait_event.register_fake +def _( + event_index: int, + stream_index: int, +) -> None: + pass + + class SymbolicStreamState: """Track the currently entered stream if any""" From 7b423c2d217452d7f65788dc3a9cb786f0b45769 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Wed, 5 Nov 2025 17:09:13 -0800 Subject: [PATCH 116/130] [user-streams] Mark stream ops as side effectful (#167152) Pull Request resolved: https://github.com/pytorch/pytorch/pull/167152 Approved by: https://github.com/Lucaskabela ghstack dependencies: #167141, #167151 --- test/dynamo/test_streams.py | 16 ++++++++++++++++ torch/_dynamo/variables/streams.py | 14 +++++++++++++- 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/test/dynamo/test_streams.py b/test/dynamo/test_streams.py index c21ab934e5b45..1b81597977d77 100644 --- a/test/dynamo/test_streams.py +++ b/test/dynamo/test_streams.py @@ -486,6 +486,22 @@ def test_run_opcheck_wait_record(self): torch.accelerator.set_stream(original_stream) reset_user_object_tracking() + def test_is_marked_side_effectful(self): + self.assertIn( + torch.ops.streams.fork.default, torch.fx.node._side_effectful_functions + ) + self.assertIn( + torch.ops.streams.join.default, torch.fx.node._side_effectful_functions + ) + self.assertIn( + torch.ops.streams.wait_event.default, + torch.fx.node._side_effectful_functions, + ) + self.assertIn( + torch.ops.streams.record_event.default, + torch.fx.node._side_effectful_functions, + ) + if __name__ == "__main__": from torch._dynamo.test_case import run_tests diff --git a/torch/_dynamo/variables/streams.py b/torch/_dynamo/variables/streams.py index 98084cce28b27..65b4add4232f6 100644 --- a/torch/_dynamo/variables/streams.py +++ b/torch/_dynamo/variables/streams.py @@ -5,7 +5,7 @@ import torch from torch._dynamo.variables.dicts import ConstDictVariable from torch._dynamo.variables.lists import TupleVariable -from torch.fx import Proxy +from torch.fx import has_side_effect, Proxy from .. import graph_break_hints from ..bytecode_transformation import create_call_function @@ -60,6 +60,9 @@ def _( pass +has_side_effect(torch.ops.streams.fork.default) + + @custom_op("streams::join", mutates_args=()) def join_stream(from_index: int, to_index: int) -> None: torch.accelerator.set_stream(_get_stream_by_index(to_index)) @@ -73,6 +76,9 @@ def _( pass +has_side_effect(torch.ops.streams.join.default) + + @custom_op("streams::record_event", mutates_args=()) def record_event(event_index: int, stream_index: int) -> None: event = _get_event_by_index(event_index) @@ -88,6 +94,9 @@ def _( pass +has_side_effect(torch.ops.streams.record_event.default) + + @custom_op("streams::wait_event", mutates_args=()) def wait_event(event_index: int, stream_index: int) -> None: event = _get_event_by_index(event_index) @@ -103,6 +112,9 @@ def _( pass +has_side_effect(torch.ops.streams.wait_event.default) + + class SymbolicStreamState: """Track the currently entered stream if any""" From 8b2365094dbb531f9122b05fdf89553f6ccee03b Mon Sep 17 00:00:00 2001 From: Yanan Cao Date: Thu, 6 Nov 2025 05:59:05 +0000 Subject: [PATCH 117/130] Expose torch.compiler.config.force_disable_caches as a public API (#166699) Exposing this flag as some upstream frameworks (like vLLM) could benefit from knowing whether torch.compile caches are enabled or not to adjust their own caching behavior. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166699 Approved by: https://github.com/oulgen, https://github.com/mlazos --- torch/compiler/config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/compiler/config.py b/torch/compiler/config.py index e7578a57f2c0b..e507ddc18052e 100644 --- a/torch/compiler/config.py +++ b/torch/compiler/config.py @@ -35,6 +35,7 @@ "enable_cpp_symbolic_shape_guards", "wrap_top_frame", "reorderable_logging_functions", + "force_disable_caches", ] From 09d8953fb47de9a9209e409f6e72c7c8fa0ac0aa Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Wed, 5 Nov 2025 11:19:35 -0800 Subject: [PATCH 118/130] Update `tensorpipe` submodule (#167108) To pick a single change https://github.com/pytorch/tensorpipe/commit/2b4cd91092d335a697416b2a3cb398283246849d that should fix compilation errors with clang-21 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167108 Approved by: https://github.com/Skylion007 --- third_party/tensorpipe | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/tensorpipe b/third_party/tensorpipe index af0118d13e52f..2b4cd91092d33 160000 --- a/third_party/tensorpipe +++ b/third_party/tensorpipe @@ -1 +1 @@ -Subproject commit af0118d13e52f5a08841464a768e01a0bf3e3075 +Subproject commit 2b4cd91092d335a697416b2a3cb398283246849d From 9eebda944df1bac3c1668e0bf041b85473c80aaa Mon Sep 17 00:00:00 2001 From: Laith Sakka Date: Wed, 5 Nov 2025 12:04:49 -0800 Subject: [PATCH 119/130] make narrow_tensor_symint DDE-free (#166379) https://github.com/pytorch/pytorch/issues/158081 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166379 Approved by: https://github.com/Lucaskabela ghstack dependencies: #166361 --- aten/src/ATen/native/TensorShape.cpp | 4 ++-- test/functorch/test_aotdispatch.py | 2 +- test/test_dynamic_shapes.py | 13 +++++++++++++ test/test_proxy_tensor.py | 1 - 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index daa8a86da253b..0079a530b3d0e 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -1784,8 +1784,8 @@ Tensor narrow_tensor_symint( start.dim() == 0 && isIntegralType(start.scalar_type(), /*includeBool=*/false), "start must be an 0-dim integral Tensor."); - int64_t st = start.item(); - return at::narrow_symint(self, dim, c10::SymInt(st), std::move(length)); + c10::SymInt st = start.item().toSymInt(); + return at::narrow_symint(self, dim, std::move(st), std::move(length)); } std:: diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index b0dd1ff8fa75d..6cae42d8929da 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -8126,7 +8126,7 @@ def fn(x): xfail("corrcoef"), xfail("quantile"), xfail("nanquantile"), - xfail("narrow"), + skip("narrow"), xfail("istft"), xfail("linalg.eig"), skip("as_strided_scatter"), diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index b63e0427c26c3..d3f9e415ff944 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -4452,6 +4452,19 @@ def test_narrow_unbacked_start_cpp_wrapper(self): """Test narrow with unbacked start with cpp_wrapper""" self.test_narrow_unbacked_start() + @torch._dynamo.config.patch(capture_scalar_outputs=True) + def test_narrow_with_tensor_start(self): + @torch.compile(backend="inductor", fullgraph=True) + def f(x, start, end): + return torch.narrow(x, 0, start, end) + + x = torch.tensor( + [False], device="cuda:0" if torch.cuda.is_available() else "cpu" + ) + start = torch.tensor(0) + res = f(x, start, 0) + self.assertEqual(res.shape, torch.Size([0])) + instantiate_parametrized_tests(TestUnbacked) diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index b76895a0a91f3..0487995a2d1c5 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1987,7 +1987,6 @@ def f(t): } only_fake_tensor_failures = { - xfail('narrow'), xfail('tensor_split'), } From ed4aa449b60f0b595e376575362efe739eae00a1 Mon Sep 17 00:00:00 2001 From: tianrengao Date: Thu, 6 Nov 2025 06:59:06 +0000 Subject: [PATCH 120/130] CustomOp Inline Fusion (#165952) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add Inline Fusion Support for Custom Op Autotuning -------------------------------------------------- This PR extends PyTorch Inductor's custom op autotuning with inline fusion capabilities, enabling the winning decomposition to be inlined directly into the computation graph for fusion with surrounding operations. ### Usage ```python def decompose_k_implementation( a: torch.Tensor, b: torch.Tensor, k_splits: int = 4 ) -> torch.Tensor: """Matrix multiply with k-way decomposition.""" ... @torch.library.custom_op("my_lib::matmul_relu", mutates_args={}) def custom_matmul_relu_dk( a: torch.Tensor, b: torch.Tensor, k_splits: int ) -> torch.Tensor: return torch.relu(decompose_k_implementation(a, b, k_splits)) register_custom_op_autotuning( custom_op=custom_matmul_relu_dk, configs=[ CustomOpConfig(k_splits=2), CustomOpConfig(k_splits=4), CustomOpConfig(k_splits=8), CustomOpConfig(k_splits=32), CustomOpConfig(k_splits=64), ], name="decompose_k_autotuned", input_gen_fns={ "a": lambda fake: torch.randn_like(fake, device='cuda'), "b": lambda fake: torch.randn_like(fake, device='cuda'), } ) ``` ### How It Works Enable optimizations from Inductor by inlining the best decomposition, allowing fusion with surrounding elementwise operations and other graph-level optimizations. This provide potentially better performance and memory efficiency. During customop autotuning phase, we still benchmarks all CustomOpConfigs to find the fastest implementation. Then during inline fusion, inductor inline the decompositions into the main graph, converting the winning choice to individual ComputedBuffer IR nodes (fusable). At the end, Inductor automatically fuses inlined operations with surrounding elementwise ops (e.g., bias add, ReLU, scaling). Note that the winning choice must be a SubgraphChoiceCaller (decomposition-based) rather than an ExternKernelChoice for inlining to work. If the ExternKernelChoice is returned, no inline happens. Performance Results Benchmarked on matmul+relu workload with decompose-k fusion (H100 GPU, 15 test shapes): Screenshot 2025-11-04 at 12 43 11 AM Metric | Result -- | -- Average Speedup vs ATen | 1.28x Max Speedup vs ATen | 1.41x
The performance comparison are detailed in the below plots. We spot that on most use cases, the inline fusion gains better performance compared to aten baseline and the current torch.compile. image **Test**: `test_decompose_k_with_fusion` demonstrates decompose-k with inline fusion enabled. -------------- ### Integration to mm.py decomposeK with a flag enable_inline_subgraph_fusion=True in config (deprecated to avoid breaking async compilation. removed from the PR already) FP32: Screenshot 2025-11-04 at 12 05 08 AM FP16: Screenshot 2025-11-04 at 12 13 49 AM The TCF column represents torch compile fusion, which is close to custom_op decomposek. The difference might due to different candidate k values. #### Usage: Note: this only happens when we don't benchmark_epilogue_fusion, i.e., not using multi_template_buffer. ```python # Define the matmul+relu function def matmul_relu(x, y): return torch.nn.functional.relu(torch.matmul(x, y)) # Compile with inline subgraph fusion enabled @torch.compile def compiled_matmul_relu(x, y): return matmul_relu(x, y) # Reset dynamo to ensure clean compilation torch._dynamo.reset() with config.patch( { "max_autotune": True, # CRITICAL: These two flags enable inline subgraph fusion "benchmark_epilogue_fusion": False, # Must be False for inline fusion! "enable_inline_subgraph_fusion": True, # Enable inline fusion } ): # Compile and run result = compiled_matmul_relu(a, b) torch.cuda.synchronize() ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/165952 Approved by: https://github.com/PaulZhang12, https://github.com/eellison --- test/inductor/test_custom_op_autotune.py | 178 +++++++---------------- torch/_inductor/codegen/subgraph.py | 25 +++- torch/_inductor/kernel/custom_op.py | 43 ++++-- torch/_inductor/lowering.py | 45 ++++-- torch/_inductor/select_algorithm.py | 11 ++ 5 files changed, 151 insertions(+), 151 deletions(-) diff --git a/test/inductor/test_custom_op_autotune.py b/test/inductor/test_custom_op_autotune.py index adc46a0f390a4..c148c69468902 100644 --- a/test/inductor/test_custom_op_autotune.py +++ b/test/inductor/test_custom_op_autotune.py @@ -216,115 +216,6 @@ def _(input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8): test_rmsnorm_op, (input_tensor, weight), expected, f"RMSNorm_{i}" ) - @skipIfXpu - def test_mlp_custom_op_autotune(self): - """Test MLP autotuning with method parameter controlling different decomposition variants. - - Validates parametric tuning where the same decomposition function uses different - algorithmic approaches based on a method parameter (standard matmul, batched mm, fused weights). - """ - test_op_name = f"test_lib::mlp_{id(self)}" - - def mlp_variants( - input_tensor: torch.Tensor, - gate_weight: torch.Tensor, - up_weight: torch.Tensor, - down_weight: torch.Tensor, - method: int = 0, - ) -> torch.Tensor: - """MLP implementation with different computational approaches controlled by method parameter.""" - - if method == 0: - gate_proj = torch.matmul(input_tensor, gate_weight) - up_proj = torch.matmul(input_tensor, up_weight) - gated = torch.relu(gate_proj) * up_proj - return torch.matmul(gated, down_weight) - - elif method == 1: - batch_shape = input_tensor.shape[:-1] - hidden_dim = input_tensor.shape[-1] - output_dim = down_weight.shape[-1] - - input_2d = input_tensor.view(-1, hidden_dim) - - gate_proj = torch.mm(input_2d, gate_weight) - up_proj = torch.mm(input_2d, up_weight) - - gated = torch.relu(gate_proj) * up_proj - output_2d = torch.mm(gated, down_weight) - - return output_2d.view(*batch_shape, output_dim) - - @torch.library.custom_op(test_op_name, mutates_args=()) - def test_mlp_op( - input_tensor: torch.Tensor, - gate_weight: torch.Tensor, - up_weight: torch.Tensor, - down_weight: torch.Tensor, - method: int = 0, - ) -> torch.Tensor: - return mlp_variants( - input_tensor, gate_weight, up_weight, down_weight, method=method - ) - - @test_mlp_op.register_fake - def _( - input_tensor: torch.Tensor, - gate_weight: torch.Tensor, - up_weight: torch.Tensor, - down_weight: torch.Tensor, - method: int = 0, - ): - return torch.empty( - input_tensor.shape[:-1] + (down_weight.shape[-1],), - device=input_tensor.device, - dtype=input_tensor.dtype, - ) - - # Use explicit config with method parameter as tuning knob - register_custom_op_autotuning( - test_mlp_op, - configs=[ - CustomOpConfig(method=0), - CustomOpConfig(method=1), - ], - name="test_mlp_autotuned", - input_gen_fns={ - "input_tensor": lambda fake_tensor: torch.randn_like( - fake_tensor, device=self.device - ) - * 0.1, - "gate_weight": lambda fake_tensor: torch.randn_like( - fake_tensor, device=self.device - ) - * 0.05, - "up_weight": lambda fake_tensor: torch.randn_like( - fake_tensor, device=self.device - ) - * 0.05, - "down_weight": lambda fake_tensor: torch.randn_like( - fake_tensor, device=self.device - ) - * 0.05, - }, - ) - - # Create test inputs - input_tensor, gate_weight, up_weight, down_weight = self._create_mlp_inputs() - - # Test that all method variants produce numerically equivalent results - expected = mlp_variants( - input_tensor, gate_weight, up_weight, down_weight, method=0 - ) - - # Test autotuning - self._run_autotune_test( - test_mlp_op, - (input_tensor, gate_weight, up_weight, down_weight), - expected, - "MLP", - ) - def _create_decompose_k_inputs(self, m=256, k=65536, n=1024): """Create test inputs for decompose_k matrix multiplication - divisible by all k_splits values.""" # Ensure k is divisible by all k_splits values: [2, 32, 64, 128, 256] @@ -335,12 +226,12 @@ def _create_decompose_k_inputs(self, m=256, k=65536, n=1024): @skipIfXpu def test_decompose_k_custom_op_autotune(self): - """Test decompose_k autotuning with parametric tuning for k_splits values. + """Test decompose_k autotuning with epilogue fusion (matmul + bias + relu + scale). - Validates numerical parameter sweep where k_splits controls how the K dimension - is decomposed for matrix multiplication (k_splits in [32, 64, 128, 256]). + Validates that the custom op encapsulates the entire fused operation with parametric + tuning for k_splits values controlling how the K dimension is decomposed. """ - test_op_name = f"test_lib::decompose_k_{id(self)}" + test_op_name = f"test_lib::matmul_relu_epilogue_{id(self)}" def decompose_k_implementation( a: torch.Tensor, b: torch.Tensor, k_splits: int = 4 @@ -363,19 +254,23 @@ def decompose_k_implementation( return torch.sum(result, dim=0) # [m, n] @torch.library.custom_op(test_op_name, mutates_args=()) - def test_decompose_k_op( - a: torch.Tensor, b: torch.Tensor, k_splits: int = 4 + def matmul_relu_epilogue_op( + a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor, k_splits: int = 4 ) -> torch.Tensor: - """Matrix multiply with k-way decomposition - custom op using the decomposition.""" - return decompose_k_implementation(a, b, k_splits) - - @test_decompose_k_op.register_fake - def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4): + """Matmul with decompose_k + bias + relu + scale (complete epilogue fusion).""" + matmul_result = decompose_k_implementation(a, b, k_splits) + biased = matmul_result + bias + activated = torch.relu(biased) + scaled = activated * 2.0 + return scaled + + @matmul_relu_epilogue_op.register_fake + def _(a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor, k_splits: int = 4): return torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=a.dtype) - # Register autotuning with different k_splits values using decomposition function + # Register autotuning with different k_splits values register_custom_op_autotuning( - test_decompose_k_op, + matmul_relu_epilogue_op, configs=[ CustomOpConfig(k_splits=2), CustomOpConfig(k_splits=4), @@ -385,7 +280,7 @@ def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4): CustomOpConfig(k_splits=64), CustomOpConfig(k_splits=128), ], - name="test_decompose_k_autotuned", + name="matmul_relu_epilogue_autotuned", input_gen_fns={ "a": lambda fake_tensor: torch.randn_like( fake_tensor, device=self.device @@ -395,12 +290,45 @@ def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4): fake_tensor, device=self.device ) * 0.1, + "bias": lambda fake_tensor: torch.randn_like( + fake_tensor, device=self.device + ) + * 0.1, }, ) + # Create test inputs a, b = self._create_decompose_k_inputs() - expected = a @ b - self._run_autotune_test(test_decompose_k_op, (a, b), expected, "DecomposeK") + bias = torch.randn(b.shape[1], device=self.device, dtype=self.dtype) * 0.1 + + # Compile the model using the custom op + @torch.compile + def test_model(a, b, bias): + return matmul_relu_epilogue_op(a, b, bias) + + torch._dynamo.reset() + + with config.patch( + max_autotune=True, + benchmark_fusion=True, + ): + compiled_result = test_model(a, b, bias) + + def reference_model(a, b, bias): + matmul_result = a @ b + biased = matmul_result + bias + activated = torch.relu(biased) + scaled = activated * 2.0 + return scaled + + expected = reference_model(a, b, bias) + + torch.testing.assert_close( + compiled_result, + expected, + rtol=2e-1, + atol=5e-1, + ) @skipIfXpu def test_multi_parameter_tuning(self): diff --git a/torch/_inductor/codegen/subgraph.py b/torch/_inductor/codegen/subgraph.py index 4cc3f0ef282a8..1c1f0f1c9cd2c 100644 --- a/torch/_inductor/codegen/subgraph.py +++ b/torch/_inductor/codegen/subgraph.py @@ -24,6 +24,22 @@ log = logging.getLogger(__name__) +def inline_subgraph_to_ir_nodes( + gm: torch.fx.GraphModule, inputs: list[Any], name: str +) -> Any: + """Inline a subgraph by converting its FX operations to individual IR nodes. + + This converts a subgraph to multiple ComputedBuffer nodes (fusable), + enabling epilogue fusion with subsequent operations. + + Returns: + TensorBox containing the final operation result as individual IR nodes + """ + from torch._inductor.lowering import process_subgraph_nodes + + return process_subgraph_nodes(gm, inputs) + + class SubgraphChoiceCaller(ir.ChoiceCaller): """ Represents a Subgraph Autotuning choice, and the subgraph can be any arbitrary @@ -261,7 +277,14 @@ def make_fx_graph( # decomp_kwargs contains all merged parameters: CustomOpConfig params + runtime kwargs from torch.fx.experimental.proxy_tensor import make_fx - return make_fx(functools.partial(decomp, **decomp_kwargs))(*args) + from ..decomposition import select_decomp_table + + decomposition_table = select_decomp_table() + + return make_fx( + functools.partial(decomp, **decomp_kwargs), + decomposition_table=decomposition_table, + )(*args) # Generate descriptive name for this variant variant_name = self._generate_variant_name(decomp, decomp_kwargs) diff --git a/torch/_inductor/kernel/custom_op.py b/torch/_inductor/kernel/custom_op.py index d35309c01d07c..23878f757cc5e 100644 --- a/torch/_inductor/kernel/custom_op.py +++ b/torch/_inductor/kernel/custom_op.py @@ -6,6 +6,7 @@ from typing import Any, Optional, Union import torch +from torch._inductor import config from torch._inductor.codegen.subgraph import SubgraphTemplate from torch._inductor.ir import Buffer, FixedLayout, ir_node_to_tensor, TensorBox from torch._inductor.lowering import lowerings, validate_ir @@ -158,7 +159,6 @@ def _adapt_user_input_gen_fns( Uses V.graph.sizevars.size_hints() to guess best for dynamic shapes. """ - from torch._inductor import config name_to_index = {name: i for i, name in enumerate(arg_names)} index_based_fns = {} @@ -238,6 +238,7 @@ def autotune_custom_op( This function generates multiple implementation choices for a custom operation and uses Inductor's autotuning system to select the best performing variant at runtime. + After selecting the best choice, applies inline fusion if the winning choice has a graph. Args: name: Unique identifier for the autotuning operation @@ -320,14 +321,34 @@ def autotune_custom_op( ) input_gen_fns = _adapt_user_input_gen_fns(inputs, arg_names, user_input_gen_fns) - return autotune_select_algorithm( + # Run autotuning and get both result and winning choice + selected_result, winning_choice = autotune_select_algorithm( name=name, choices=choices, input_nodes=list(inputs), layout=choices[0].layout, input_gen_fns=input_gen_fns, + return_choice=True, ) + # Apply inlining for fusion if winning_choice has graph; otherwise return result as-is(default fallback impl) + if winning_choice.gm is not None: + log.debug( + "Inlining winning choice: %s (name=%s)", + getattr(winning_choice, "name", type(winning_choice).__name__), + name, + ) + from torch._inductor.codegen.subgraph import inline_subgraph_to_ir_nodes + + return inline_subgraph_to_ir_nodes(winning_choice.gm, inputs, name) + + log.debug( + "Winning choice does not support inlining: %s (name=%s)", + getattr(winning_choice, "name", type(winning_choice).__name__), + name, + ) + return selected_result + def register_custom_op_autotuning( custom_op: torch._library.custom_ops.CustomOpDef, @@ -360,7 +381,7 @@ def my_attention(query, key, value, head_dim=32): "query": lambda fake: torch.randn_like(fake, device='cuda'), "key": lambda fake: torch.randn_like(fake, device='cuda'), "value": lambda fake: torch.randn_like(fake, device='cuda'), - } + }, ) """ from torch._library.custom_ops import CustomOpDef @@ -378,12 +399,12 @@ def my_attention(query, key, value, head_dim=32): raise TypeError(f"configs must be a list or tuple, got {type(configs)}") processed_configs = [] - for config in configs: - if isinstance(config, CustomOpConfig): - processed_configs.append(config) + for cfg in configs: + if isinstance(cfg, CustomOpConfig): + processed_configs.append(cfg) else: raise TypeError( - f"Each config must be a CustomOpConfig object, got {type(config)}" + f"Each config must be a CustomOpConfig object, got {type(cfg)}" ) if not processed_configs: @@ -402,14 +423,12 @@ def autotuning_lowering(*args: Any, **kwargs: Any) -> Any: decompositions = [] non_tensor_args = [] - for config in processed_configs: - decomp = config.get_decomposition(default_impl=default_impl) + for cfg in processed_configs: + decomp = cfg.get_decomposition(default_impl=default_impl) decompositions.append(decomp) # Merge config params with runtime kwargs (runtime takes precedence) - merged_kwargs = _merge_config_and_runtime_kwargs( - config.params, runtime_kwargs - ) + merged_kwargs = _merge_config_and_runtime_kwargs(cfg.params, runtime_kwargs) non_tensor_args.append(merged_kwargs) result = autotune_custom_op( diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index cc13f79909014..f6ad1028ca12d 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -7307,6 +7307,35 @@ def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, *operands): return list(map(TensorBox.create, result)) # type: ignore[call-overload] +def process_subgraph_nodes(graph_module: torch.fx.GraphModule, args: list[Any]): + """Process nodes from a FX graph by executing them through V.graph. + + This is a common pattern for executing a subgraph's nodes: + - Placeholder nodes are mapped to the provided args + - Output nodes return their result + - Other nodes are executed via V.graph.run_node + + """ + output = None + + for i, node in enumerate(graph_module.graph.nodes): + if node.op == "placeholder": + assert node not in V.graph.env + V.graph.env[node] = args[i] + continue + elif node.op == "output": + output_args, kwargs = V.graph.fetch_args_kwargs_from_env(node) + output = torch.fx.Interpreter.output(V.graph, node, output_args, kwargs) + else: + assert node not in V.graph.env + V.graph.env[node] = V.graph.run_node(node) + + if output is None: + raise RuntimeError("No output node found in graph") + + return output + + # Import the control_deps_op HOP for lowering from torch._inductor.fx_passes.control_dependencies import control_deps @@ -7334,21 +7363,11 @@ def control_deps_op_lowering(additional_deps, subgraph_fn, *args): arg_offset = 2 # first two args (additional_deps, subgraph) assert len(args) + arg_offset == len(original_args) - output = None - operation_len = len(V.graph.operations) assert len(subgraph_fn.graph_module.graph.find_nodes(op="placeholder")) == len(args) - for i, node in enumerate(subgraph_fn.graph_module.graph.nodes): - if node.op == "placeholder": - assert node not in V.graph.env - V.graph.env[node] = args[i] - continue - elif node.op == "output": - args, kwargs = V.graph.fetch_args_kwargs_from_env(node) - output = torch.fx.Interpreter.output(V.graph, node, args, kwargs) - else: - assert node not in V.graph.env - V.graph.env[node] = V.graph.run_node(node) + + # Process subgraph nodes using the shared helper + output = process_subgraph_nodes(subgraph_fn.graph_module, list(args)) assert output is not None and additional_deps diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 41021b0fc8ed1..e1d36d54e844a 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -2145,6 +2145,8 @@ def __init__( # There is no src hash for ExternKernelChoice in the traditional sense # so we indicate this by returning None self.src_hash = None + # By default GraphModule is None for extern kernels if not set + self.gm = None def to_callable(self): return getattr(extern_kernels, self.name) @@ -2317,6 +2319,7 @@ def __init__( self.choice = choice self.kwargs = kwargs or {} self.has_out_variant = has_out_variant + self.gm = choice.gm def __str__(self) -> str: return f"ExternKernelCaller({self.choice.call_name()})" @@ -2700,6 +2703,7 @@ def __call__( precompilation_timeout_seconds: int = 60 * 60, return_multi_template=False, best_config_future=None, + return_choice=False, # TODO: return_choice is temporary and will be refactored soon ): from .codegen.cuda.cuda_kernel import CUDATemplateCaller @@ -2973,18 +2977,25 @@ def get_timings(hint_override: Optional[int] = None): "Autotuning returned empty timings, falling back to first `ExternKernelCaller`: %s", node, ) + if return_choice: + return node, choice return node node = choices[0].output_node() + choice = choices[0] log.debug( "Autotuning returned empty timings, falling back to first choice: %s", node, ) + if return_choice: + return node, choice return node # if we got any timings at all, pick the best of those choice = min(timings, key=timings.__getitem__) node = choice.output_node() log.debug("Autotuning selected choice: %s", node) + if return_choice: + return node, choice return node def make_precompile_fn( From a51208c656fb3e9a8b091a4d181f9a9cda783c04 Mon Sep 17 00:00:00 2001 From: Will Feng Date: Thu, 6 Nov 2025 08:02:53 +0000 Subject: [PATCH 121/130] Check cluster_dims attribute exists before access (#167187) Error in Helion CI's AMD job: https://github.com/pytorch/helion/actions/runs/19118581048/job/54633730633 ``` > (binary.metadata.num_ctas, *binary.metadata.cluster_dims) ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ if hasattr(binary, "metadata") else () ) ), "function": get_first_attr(binary, "function", "cu_function"), "runner": get_first_attr(binary, "run", "c_wrapper"), "math": math_lib, "torch": torch_lib, "triton": triton_lib, } E torch._inductor.exc.InductorError: AttributeError: 'KernelMetadata' object has no attribute 'cluster_dims' ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/167187 Approved by: https://github.com/oulgen --- torch/_inductor/runtime/triton_heuristics.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index cdecd50927024..2e0a0dba9092e 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -1851,6 +1851,8 @@ def make_launcher(self) -> LauncherType: else ( (binary.metadata.num_ctas, *binary.metadata.cluster_dims) if hasattr(binary, "metadata") + and hasattr(binary.metadata, "num_ctas") + and hasattr(binary.metadata, "cluster_dims") else () ) ), From c724f0097ddcf2a1dffb928ad18eafed6005595e Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Thu, 6 Nov 2025 12:13:47 +0000 Subject: [PATCH 122/130] [2/N] Use `key in dict` for existence checks (#167174) This PR uses `key in dict` expressions for existence checks of dict elements in Python code. This operation is more efficient than `key in dict.keys()`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/167174 Approved by: https://github.com/mlazos --- test/cpp/api/init_baseline.py | 2 +- test/cpp/api/optim_baseline.py | 2 +- .../checkpoint/test_hf_safetensor_e2e.py | 4 ++-- .../fsdp/test_fsdp_mixed_precision.py | 2 +- test/distributed/test_c10d_common.py | 6 ++---- test/distributed/test_local_tensor.py | 20 +++++++++---------- test/dynamo/test_subclasses.py | 2 +- test/functorch/xfail_suggester.py | 2 +- test/inductor/test_compiled_optimizers.py | 2 +- test/profiler/test_profiler.py | 8 ++------ .../core/test_quantized_module.py | 4 ++-- .../quantization/core/test_workflow_module.py | 2 +- .../pt2e/test_x86inductor_quantizer.py | 2 +- test/test_fx.py | 2 +- test/test_testing.py | 2 +- torch/ao/ns/fx/pattern_utils.py | 2 +- .../ao/pruning/sparsifier/base_sparsifier.py | 2 +- torch/ao/quantization/_equalize.py | 2 +- torch/ao/quantization/fx/_equalize.py | 2 +- .../quantization/fx/_model_report/detector.py | 2 +- torch/ao/quantization/fx/convert.py | 2 +- torch/ao/quantization/fx/prepare.py | 4 ++-- .../quantization/fx/qconfig_mapping_utils.py | 4 ++-- torch/ao/quantization/fx/utils.py | 2 +- .../quantization/pt2e/port_metadata_pass.py | 2 +- torch/ao/quantization/pt2e/prepare.py | 2 +- torch/ao/quantization/qconfig_mapping.py | 2 +- torch/ao/quantization/quantize_jit.py | 4 ++-- .../quantizer/x86_inductor_quantizer.py | 4 +--- torch/fx/experimental/unify_refinements.py | 6 +++--- torch/fx/graph.py | 2 +- torch/fx/passes/runtime_assert.py | 6 ++---- torch/fx/passes/splitter_base.py | 4 ++-- torch/fx/passes/utils/source_matcher_utils.py | 4 ++-- torch/jit/_recursive.py | 2 +- torch/jit/_script.py | 2 +- torch/nn/modules/module.py | 2 +- .../_internal/exporter/_dynamic_shapes.py | 2 +- torch/profiler/_memory_profiler.py | 4 ++-- torch/profiler/_utils.py | 6 +++--- torch/utils/_config_module.py | 2 +- torch/utils/collect_env.py | 4 ++-- torch/utils/data/datapipes/iter/callable.py | 2 +- torch/utils/data/datapipes/iter/grouping.py | 2 +- torch/utils/tensorboard/summary.py | 2 +- torch/utils/tensorboard/writer.py | 2 +- torchgen/gen_backend_stubs.py | 3 +-- 47 files changed, 72 insertions(+), 83 deletions(-) diff --git a/test/cpp/api/init_baseline.py b/test/cpp/api/init_baseline.py index 47b202e86311d..4042657b4d5c3 100644 --- a/test/cpp/api/init_baseline.py +++ b/test/cpp/api/init_baseline.py @@ -64,7 +64,7 @@ def run(initializer): def main(): initializer_parameter_map = {} - for initializer in INITIALIZERS.keys(): + for initializer in INITIALIZERS: sys.stderr.write(f"Evaluating {initializer} ...\n") initializer_parameter_map[initializer] = run(initializer) diff --git a/test/cpp/api/optim_baseline.py b/test/cpp/api/optim_baseline.py index 7e278d4e42086..e1a3c91b7128f 100644 --- a/test/cpp/api/optim_baseline.py +++ b/test/cpp/api/optim_baseline.py @@ -130,7 +130,7 @@ def main(): options = parser.parse_args() optimizer_parameter_map = {} - for optimizer in OPTIMIZERS.keys(): + for optimizer in OPTIMIZERS: sys.stderr.write(f"Evaluating {optimizer} ...\n") optimizer_parameter_map[optimizer] = run( optimizer, options.iterations, options.sample_every diff --git a/test/distributed/checkpoint/test_hf_safetensor_e2e.py b/test/distributed/checkpoint/test_hf_safetensor_e2e.py index f0316fde9f2c5..1aaaf645c58df 100644 --- a/test/distributed/checkpoint/test_hf_safetensor_e2e.py +++ b/test/distributed/checkpoint/test_hf_safetensor_e2e.py @@ -208,7 +208,7 @@ def test_quantized_checkpoint_loading(self) -> None: # Create model.safetensors.index.json with weight mapping weight_map = {} - for key in quantized_checkpoint.keys(): + for key in quantized_checkpoint: weight_map[key] = "model.safetensors" index_data = { @@ -245,7 +245,7 @@ def test_quantized_checkpoint_loading(self) -> None: sorted(original_tensors.keys()), sorted(state_dict_to_load.keys()) ) - for tensor_name in original_tensors.keys(): + for tensor_name in original_tensors: original = original_tensors[tensor_name] loaded = state_dict_to_load[tensor_name] diff --git a/test/distributed/fsdp/test_fsdp_mixed_precision.py b/test/distributed/fsdp/test_fsdp_mixed_precision.py index dee38d0403467..b4532a86e3052 100644 --- a/test/distributed/fsdp/test_fsdp_mixed_precision.py +++ b/test/distributed/fsdp/test_fsdp_mixed_precision.py @@ -498,7 +498,7 @@ def _run_test_mixed_precision_e2e( for name, tensor in state_dict.items(): # Parameters and buffers are checkpointed in their # original dtypes, which may be different. - if name in named_buffers.keys(): + if name in named_buffers: self.assertEqual(tensor.dtype, _BUFFER_ORIG_DTYPE) else: self.assertEqual( diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 985e2d5f151a2..2a1cb2b5580cb 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -1189,9 +1189,7 @@ def _test_sequence_num_incremented(self, process_group, ranks): self.assertEqual(len(set(rank_to_seq_num.values())), 2) self.assertEqual(rank_to_seq_num[0], rank_to_seq_num[2]) expected_same = { - rank_to_seq_num[i] - for i in rank_to_seq_num.keys() - if i not in [0, 2] + rank_to_seq_num[i] for i in rank_to_seq_num if i not in [0, 2] } self.assertEqual(len(expected_same), 1) self.assertEqual(rank_to_seq_num[0] + 1, rank_to_seq_num[1]) @@ -1558,7 +1556,7 @@ def test_debug_level(self): } invalid_debug_modes = ["foo", 0, 1, -1] - for mode in mapping.keys(): + for mode in mapping: os.environ["TORCH_DISTRIBUTED_DEBUG"] = str(mode) dist.set_debug_level_from_env() set_debug_mode = dist.get_debug_level() diff --git a/test/distributed/test_local_tensor.py b/test/distributed/test_local_tensor.py index c58ddf0f82ba7..fa081243c2816 100644 --- a/test/distributed/test_local_tensor.py +++ b/test/distributed/test_local_tensor.py @@ -128,14 +128,14 @@ def test_basic_arithmetic_operations(self): self.assertEqual(len(result_add._local_tensors), 2) # Verify the operation was applied to each local tensor - for rank in identical_local_tensors.keys(): + for rank in identical_local_tensors: expected = identical_local_tensors[rank] + identical_local_tensors[rank] self.assertEqual(result_add._local_tensors[rank], expected) # Test multiplication result_mul = lt1 * 2.0 self.assertIsInstance(result_mul, LocalTensor) - for rank in identical_local_tensors.keys(): + for rank in identical_local_tensors: expected = identical_local_tensors[rank] * 2.0 self.assertEqual(result_mul._local_tensors[rank], expected) @@ -163,7 +163,7 @@ def test_mixed_operations_with_regular_tensors(self): result = lt + regular_tensor self.assertIsInstance(result, LocalTensor) - for rank in identical_local_tensors.keys(): + for rank in identical_local_tensors: expected = identical_local_tensors[rank] + regular_tensor self.assertEqual(result._local_tensors[rank], expected) @@ -212,14 +212,14 @@ def test_collectives_within_local_tensor_mode(self): dist.all_reduce(lt_sum, group=fake_pg) expected_sum = torch.tensor([[6.0, 8.0], [10.0, 12.0]]) - for rank in test_tensors.keys(): + for rank in test_tensors: self.assertEqual(lt_sum._local_tensors[rank], expected_sum) # Test broadcast within mode lt_broadcast = LocalTensor({k: v.clone() for k, v in test_tensors.items()}) dist.broadcast(lt_broadcast, src=0, group=fake_pg) - for rank in test_tensors.keys(): + for rank in test_tensors: self.assertEqual(lt_broadcast._local_tensors[rank], test_tensors[0]) # Test that regular operations still work @@ -293,21 +293,21 @@ def test_collective_reduction_operations(self): lt_sum = LocalTensor({k: v.clone() for k, v in test_tensors.items()}) dist.all_reduce(lt_sum, op=dist.ReduceOp.SUM, group=fake_pg) expected_sum = torch.tensor([[6.0, 7.0], [6.0, 15.0]]) # Sum of all tensors - for rank in test_tensors.keys(): + for rank in test_tensors: self.assertEqual(lt_sum._local_tensors[rank], expected_sum) # Test MAX reduction lt_max = LocalTensor({k: v.clone() for k, v in test_tensors.items()}) dist.all_reduce(lt_max, op=dist.ReduceOp.MAX, group=fake_pg) expected_max = torch.tensor([[3.0, 4.0], [3.0, 6.0]]) # Max across all tensors - for rank in test_tensors.keys(): + for rank in test_tensors: self.assertEqual(lt_max._local_tensors[rank], expected_max) # Test MIN reduction lt_min = LocalTensor({k: v.clone() for k, v in test_tensors.items()}) dist.all_reduce(lt_min, op=dist.ReduceOp.MIN, group=fake_pg) expected_min = torch.tensor([[1.0, 1.0], [1.0, 4.0]]) # Min across all tensors - for rank in test_tensors.keys(): + for rank in test_tensors: self.assertEqual(lt_min._local_tensors[rank], expected_min) def test_all_reduce_collective(self): @@ -328,7 +328,7 @@ def test_all_reduce_collective(self): # Verify all ranks have the sum of all tensors (after adding 1 to each) expected_sum = torch.tensor([[114.0, 225.0, 336.0], [447.0, 558.0, 669.0]]) - for rank in different_tensors.keys(): + for rank in different_tensors: self.assertEqual(lt_sum._local_tensors[rank], expected_sum) def test_broadcast_collective(self): @@ -348,7 +348,7 @@ def test_broadcast_collective(self): # Verify all ranks have rank 1's original tensor expected_broadcast = different_tensors[1] - for rank in different_tensors.keys(): + for rank in different_tensors: self.assertEqual(lt_broadcast._local_tensors[rank], expected_broadcast) def test_all_gather_collective(self): diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index 39a0dc628baec..5d31fa28880a6 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -4036,7 +4036,7 @@ def backend(gm, args): @parametrize( "nt_view_name", - [k for k in VIEW_TEST_CASES.keys() if k != "subclass_dense_subclass_dense"], + [k for k in VIEW_TEST_CASES if k != "subclass_dense_subclass_dense"], ) def test_inputs_to_compiled_fn_are_views(self, nt_view_name): self._input_view_test(nt_view_name) diff --git a/test/functorch/xfail_suggester.py b/test/functorch/xfail_suggester.py index cab6b018d5782..8efd8dfe398f2 100644 --- a/test/functorch/xfail_suggester.py +++ b/test/functorch/xfail_suggester.py @@ -73,7 +73,7 @@ def parse_namespace(base): "sparse_": "sparse", "special_": "special", } - for heading in mappings.keys(): + for heading in mappings: if base.startswith(heading): return mappings[heading], base[len(heading) :] return None, base diff --git a/test/inductor/test_compiled_optimizers.py b/test/inductor/test_compiled_optimizers.py index df93e7e1e4d61..ebee5149476b8 100644 --- a/test/inductor/test_compiled_optimizers.py +++ b/test/inductor/test_compiled_optimizers.py @@ -320,7 +320,7 @@ def build_opt_kwarg_db(): continue if has_tensor_lr: - for scheduler_cls in LR_SCHEDULER_TO_KWARGS.keys(): + for scheduler_cls in LR_SCHEDULER_TO_KWARGS: name_w_scheduler = name + f"_{scheduler_cls.__name__.lower()}" compiled_opt_db.append( ( diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py index 25fb60674e59e..fc128ba61907a 100644 --- a/test/profiler/test_profiler.py +++ b/test/profiler/test_profiler.py @@ -916,8 +916,7 @@ def judge(expected_event_count, prof): ) for key, count in expected_event_count.items(): self.assertTrue( - (key in actual_event_count.keys()) - and (count == actual_event_count[key]) + (key in actual_event_count) and (count == actual_event_count[key]) ) with _profile(use_kineto=kineto_available()) as prof: @@ -1406,10 +1405,7 @@ def test_profiler_fwd_bwd_link(self): s_ts_2 = flow_s_to_ts[2] f_ts_2 = flow_f_to_ts[2] self.assertTrue( - all( - ts in ts_to_name.keys() - for ts in [s_ts_1, f_ts_1, s_ts_2, f_ts_2] - ) + all(ts in ts_to_name for ts in [s_ts_1, f_ts_1, s_ts_2, f_ts_2]) ) self.assertTrue( ts_to_name[s_ts_1] == "aten::binary_cross_entropy_with_logits" diff --git a/test/quantization/core/test_quantized_module.py b/test/quantization/core/test_quantized_module.py index b2b2b402327ad..f2cdbfd2d6316 100644 --- a/test/quantization/core/test_quantized_module.py +++ b/test/quantization/core/test_quantized_module.py @@ -1840,7 +1840,7 @@ def test_cell_api(self, dtype): 'RNNTanh': torch.ops.quantized.quantized_rnn_tanh_cell_dynamic, 'RNNReLU': torch.ops.quantized.quantized_rnn_relu_cell_dynamic} - for rnn_type in cell_dict.keys(): + for rnn_type in cell_dict: if not (dtype == torch.float16 and torch.backends.quantized.engine in ("qnnpack", "onednn")): # fp16 dynamic quant is not supported for qnnpack or onednn kwargs = {'input_size': input_size, 'hidden_size': hidden_size, 'bias': bias, 'dtype': dtype} @@ -1903,7 +1903,7 @@ def test_rnn_cell(self): 'RNNTanh': nnqr.RNNCell, 'RNNReLU': nnqr.RNNCell} - for rnn_type in cell_dict.keys(): + for rnn_type in cell_dict: kwargs = {'input_size': input_size, 'hidden_size': hidden_size, 'bias': bias} if rnn_type == 'RNNReLU': kwargs['nonlinearity'] = "relu" diff --git a/test/quantization/core/test_workflow_module.py b/test/quantization/core/test_workflow_module.py index 9ea8d38828a63..93993fe33a49c 100644 --- a/test/quantization/core/test_workflow_module.py +++ b/test/quantization/core/test_workflow_module.py @@ -650,7 +650,7 @@ def test_record_observer(self): observer_dict = {} _get_observer_dict(model, observer_dict) - self.assertTrue('fc1.module.activation_post_process' in observer_dict.keys(), + self.assertTrue('fc1.module.activation_post_process' in observer_dict, 'observer is not recorded in the dict') self.assertEqual(len(observer_dict['fc1.module.activation_post_process'].get_tensor_value()), 2 * len(self.calib_data)) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index dfd591cb9419c..5b9aa34158b5e 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -2016,7 +2016,7 @@ def test_qat_conv2d_unary(self): } with override_quantized_engine("x86"): - for unary_op in unary_map.keys(): + for unary_op in unary_map: m = TestHelperModules.Conv2dUnaryModule( unary_map[unary_op][0], with_bn=True ) diff --git a/test/test_fx.py b/test/test_fx.py index f728187fd85f5..3ad21e64c8ce2 100644 --- a/test/test_fx.py +++ b/test/test_fx.py @@ -4746,7 +4746,7 @@ def check_symbols_have_bc_designation(m, seen): check_symbols_have_bc_designation(torch.fx.passes, set()) non_back_compat_strs = [ - torch.typename(obj) for obj in non_back_compat_objects.keys() + torch.typename(obj) for obj in non_back_compat_objects ] # Only want objects in torch.fx non_back_compat_strs = [ diff --git a/test/test_testing.py b/test/test_testing.py index c660eb83b8042..09887be17c47a 100644 --- a/test/test_testing.py +++ b/test/test_testing.py @@ -510,7 +510,7 @@ def test_trivial_passing_test(self, device): # Test without setting env var should run everything. env = dict(os.environ) for k in ['CI', PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY]: - if k in env.keys(): + if k in env: del env[k] _, stderr = TestCase.run_process_no_exception(test_filter_file_template, env=env) self.assertIn(f'Ran {test_bases_count} test', stderr.decode('ascii')) diff --git a/torch/ao/ns/fx/pattern_utils.py b/torch/ao/ns/fx/pattern_utils.py index c4d231e713b20..d10fdd39da908 100644 --- a/torch/ao/ns/fx/pattern_utils.py +++ b/torch/ao/ns/fx/pattern_utils.py @@ -72,7 +72,7 @@ def get_reversed_fusions() -> list[tuple[NSFusionType, int]]: all_quant_patterns = _get_pattern_to_quantize_handlers(get_native_backend_config()) default_base_op_idx = 0 - for quant_pattern in all_quant_patterns.keys(): + for quant_pattern in all_quant_patterns: # TODO: this is a temporary hack to flatten the patterns from quantization so # that it works with the ns matcher function, maybe we should use `_is_match` # in torch.ao.quantization.fx.match_utils to match the patterns diff --git a/torch/ao/pruning/sparsifier/base_sparsifier.py b/torch/ao/pruning/sparsifier/base_sparsifier.py index 14764c77cc604..59f6a46fe1350 100644 --- a/torch/ao/pruning/sparsifier/base_sparsifier.py +++ b/torch/ao/pruning/sparsifier/base_sparsifier.py @@ -196,7 +196,7 @@ def prepare(self, model, config): # check that whatever was put into local_args agrees with what was obtained # from tensor_fqn - for key in info_from_tensor_fqn.keys(): + for key in info_from_tensor_fqn: if key in local_args: if not ( info_from_tensor_fqn[key] == local_args[key] diff --git a/torch/ao/quantization/_equalize.py b/torch/ao/quantization/_equalize.py index a78dd307fc6d6..e4ff327f285aa 100644 --- a/torch/ao/quantization/_equalize.py +++ b/torch/ao/quantization/_equalize.py @@ -270,7 +270,7 @@ def converged(curr_modules, prev_modules, threshold=1e-4): summed_norms = torch.tensor(0.0) if None in prev_modules.values(): return False - for name in curr_modules.keys(): + for name in curr_modules: curr_weight = get_module_weight(curr_modules[name]) prev_weight = get_module_weight(prev_modules[name]) diff --git a/torch/ao/quantization/fx/_equalize.py b/torch/ao/quantization/fx/_equalize.py index b8809c1c60871..6c8c32b992ed4 100644 --- a/torch/ao/quantization/fx/_equalize.py +++ b/torch/ao/quantization/fx/_equalize.py @@ -350,7 +350,7 @@ def get_op_node_and_weight_eq_obs( # Find the op node that comes directly after the input equalization observer op_node = None - for user in input_eq_obs_node.users.keys(): + for user in input_eq_obs_node.users: if node_supports_equalization(user, modules): op_node = user break diff --git a/torch/ao/quantization/fx/_model_report/detector.py b/torch/ao/quantization/fx/_model_report/detector.py index 993a6c41f176f..0a48bbbaaee90 100644 --- a/torch/ao/quantization/fx/_model_report/detector.py +++ b/torch/ao/quantization/fx/_model_report/detector.py @@ -743,7 +743,7 @@ def generate_detector_report( # Populates the string based report with the information from module_dynamic_static_info # Compiles the complete report by appending relevant formatted strings - for module_fqn in module_dynamic_static_info.keys(): + for module_fqn in module_dynamic_static_info: # there is at least 1 module for suggestion modules_added = True module_info = module_dynamic_static_info[module_fqn] diff --git a/torch/ao/quantization/fx/convert.py b/torch/ao/quantization/fx/convert.py index 08ae102f69f41..06936e5327bce 100644 --- a/torch/ao/quantization/fx/convert.py +++ b/torch/ao/quantization/fx/convert.py @@ -683,7 +683,7 @@ def _maybe_get_observer_for_node( If the node is observed, return the observer instance. Otherwise, return None. """ - for maybe_obs_node in node.users.keys(): + for maybe_obs_node in node.users: if maybe_obs_node.op == "call_module": maybe_obs = modules[str(maybe_obs_node.target)] if _is_activation_post_process(maybe_obs): diff --git a/torch/ao/quantization/fx/prepare.py b/torch/ao/quantization/fx/prepare.py index 0c05e6499901d..8351dbedd07d7 100644 --- a/torch/ao/quantization/fx/prepare.py +++ b/torch/ao/quantization/fx/prepare.py @@ -950,7 +950,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( # we should remove this # removing this means we insert one observer for each use, even if they # have the same dtype, we can have an extra pass that removes the extra observers - for maybe_obs_node in arg.users.keys(): + for maybe_obs_node in arg.users: if maybe_obs_node.op == "call_module": maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index] if ( @@ -1440,7 +1440,7 @@ def _maybe_make_input_output_share_observers( setattr(named_modules[parent_name], name, obs_mod_to_use) # set the output observer node to use that module - for output_obs_node in node.users.keys(): + for output_obs_node in node.users: if not _is_activation_post_process_node(output_obs_node, named_modules): raise AssertionError( "output_obs_node must be an activation post process node" diff --git a/torch/ao/quantization/fx/qconfig_mapping_utils.py b/torch/ao/quantization/fx/qconfig_mapping_utils.py index 74f90505ea2af..951ca66703f47 100644 --- a/torch/ao/quantization/fx/qconfig_mapping_utils.py +++ b/torch/ao/quantization/fx/qconfig_mapping_utils.py @@ -206,7 +206,7 @@ def _check_is_valid_config_dict( `config_dict`: dictionary whose keys we want to check """ - for k in config_dict.keys(): + for k in config_dict: if k not in allowed_keys: raise ValueError( "Expected " @@ -250,7 +250,7 @@ def _compare_prepare_convert_qconfig_mappings( _MODULE_NAME_REGEX_DICT_KEY, ] for i in range(len(prepare_dicts)): - for name in prepare_dicts[i].keys(): + for name in prepare_dicts[i]: if name not in convert_dicts[i]: raise AssertionError( f"Missing key {dict_names[i]} {name} in convert QConfigMapping when it was present in prepare" diff --git a/torch/ao/quantization/fx/utils.py b/torch/ao/quantization/fx/utils.py index 3e2afaaa1d9f3..9f76f2a328df1 100644 --- a/torch/ao/quantization/fx/utils.py +++ b/torch/ao/quantization/fx/utils.py @@ -442,7 +442,7 @@ def maybe_get_next_module( target_functional_type: Functional type that we want to check """ - for user in node.users.keys(): + for user in node.users: if ( user.op == "call_module" and target_module_type is not None diff --git a/torch/ao/quantization/pt2e/port_metadata_pass.py b/torch/ao/quantization/pt2e/port_metadata_pass.py index aab4c435c872f..8e768592826e4 100644 --- a/torch/ao/quantization/pt2e/port_metadata_pass.py +++ b/torch/ao/quantization/pt2e/port_metadata_pass.py @@ -66,7 +66,7 @@ def _find_choose_qparams_node(node: torch.fx.Node) -> Optional[torch.fx.Node]: continue if n.op == "call_function" and n.target in _CHOOSE_QPARAMS_OPS: return n - for k in n.users.keys(): + for k in n.users: queue.append(k) return None diff --git a/torch/ao/quantization/pt2e/prepare.py b/torch/ao/quantization/pt2e/prepare.py index 9f7767101aba6..c15e7878eb2b7 100644 --- a/torch/ao/quantization/pt2e/prepare.py +++ b/torch/ao/quantization/pt2e/prepare.py @@ -391,7 +391,7 @@ def _maybe_insert_input_observer_for_arg_or_kwarg( # instead of inserting new observers we will have: # conv1 -> obs1 -> existing_obs -> conv2 # \ -> conv3 - for maybe_obs_node in arg.users.keys(): + for maybe_obs_node in arg.users: if not _is_activation_post_process_node(maybe_obs_node, named_modules): continue maybe_obs_mod = named_modules[maybe_obs_node.target] # type: ignore[index] diff --git a/torch/ao/quantization/qconfig_mapping.py b/torch/ao/quantization/qconfig_mapping.py index 10111d4ab8a2a..2bfce5d858cc4 100644 --- a/torch/ao/quantization/qconfig_mapping.py +++ b/torch/ao/quantization/qconfig_mapping.py @@ -187,7 +187,7 @@ def _get_default_qconfig_mapping_with_default_qconfig( else: qconfig_mapping = get_default_qconfig_mapping(backend) qconfig_mapping.set_global(default_qconfig) - for pattern in qconfig_mapping.object_type_qconfigs.keys(): + for pattern in qconfig_mapping.object_type_qconfigs: if pattern not in _FIXED_QPARAMS_OP_TO_OBSERVER: qconfig_mapping.set_object_type(pattern, default_qconfig) return qconfig_mapping diff --git a/torch/ao/quantization/quantize_jit.py b/torch/ao/quantization/quantize_jit.py index 79f8db1a792fc..ec4caab1edcd0 100644 --- a/torch/ao/quantization/quantize_jit.py +++ b/torch/ao/quantization/quantize_jit.py @@ -68,7 +68,7 @@ def fuse_conv_bn_jit(model, inplace=False): def _prepare_jit(model, qconfig_dict, inplace=False, quant_type=QuantType.STATIC): _check_is_script_module(model) _check_forward_method(model) - if not all(isinstance(x, str) for x in qconfig_dict.keys()): + if not all(isinstance(x, str) for x in qconfig_dict): raise ValueError("qconfig_dict should only contain names(str) as keys.") scripted_qconfig_dict = script_qconfig_dict(qconfig_dict) model = fuse_conv_bn_jit(model, inplace) @@ -90,7 +90,7 @@ def _prepare_ondevice_jit( quant_type=QuantType.STATIC, ): _check_is_script_module(model) - if not all(isinstance(x, str) for x in qconfig_dict.keys()): + if not all(isinstance(x, str) for x in qconfig_dict): raise ValueError("qconfig_dict should only contain names(str) as keys.") scripted_qconfig_dict = script_qconfig_dict(qconfig_dict) method_graph = model._c._get_method(method_name).graph diff --git a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py index b10163d4b1e50..816f48fd6267a 100644 --- a/torch/ao/quantization/quantizer/x86_inductor_quantizer.py +++ b/torch/ao/quantization/quantizer/x86_inductor_quantizer.py @@ -1361,9 +1361,7 @@ def is_all_inputs_connected_to_quantized_op(input_nodes): elif ( node.target is torch.ops.aten.flatten.using_ints and len(node.users) > 0 - and not any( - user.target in quantizable_ops for user in node.users.keys() - ) + and not any(user.target in quantizable_ops for user in node.users) ): # Recipe of flatten: check if any users of flatten node are quantizable ops or not return diff --git a/torch/fx/experimental/unify_refinements.py b/torch/fx/experimental/unify_refinements.py index bab662e0655a2..efafb146179a6 100644 --- a/torch/fx/experimental/unify_refinements.py +++ b/torch/fx/experimental/unify_refinements.py @@ -61,7 +61,7 @@ def substitute_solution_one_type(mapping, t): Apply the most general unifier to a type """ if isinstance(t, Var): - if t in mapping.keys(): + if t in mapping: return mapping[t] else: return t @@ -69,7 +69,7 @@ def substitute_solution_one_type(mapping, t): elif isinstance(t, TensorType): new_type = [] for typ in t.__args__: - if typ in mapping.keys(): + if typ in mapping: new_type.append(mapping[typ]) else: new_type.append(typ) @@ -102,7 +102,7 @@ def substitute_all_types(graph, mapping): flag = False for k in mapping: old_mapping_val = mapping[k] - if mapping[k] in mapping.keys(): + if mapping[k] in mapping: new_key = mapping[k] mapping[k] = mapping[new_key] if old_mapping_val != mapping[k]: diff --git a/torch/fx/graph.py b/torch/fx/graph.py index d924eac24d3c2..d8cfa42472b49 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -1145,7 +1145,7 @@ def find_nodes(self, *, op: str, target: Optional["Target"] = None): return [*self.table[(op, None)].keys()] # op is call_method, get_attr, call_module - return [node for node in self.table[(op, None)].keys() if node.target == target] + return [node for node in self.table[(op, None)] if node.target == target] @compatibility(is_backward_compatible=True) diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index 58aa801062824..1d3b0b33e7bce 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -373,11 +373,9 @@ def has_new_untracked_symbols(): shape_env, node.meta.get("unbacked_bindings", {}) ) - assert resolved_unbacked_bindings is not None - def has_new_unbacked_bindings(): - # pyrefly: ignore [missing-attribute] - for key in resolved_unbacked_bindings.keys(): + assert resolved_unbacked_bindings is not None + for key in resolved_unbacked_bindings: if key not in expr_to_proxy: return True return False diff --git a/torch/fx/passes/splitter_base.py b/torch/fx/passes/splitter_base.py index 6cf708a619069..8d90f9d55cfdb 100644 --- a/torch/fx/passes/splitter_base.py +++ b/torch/fx/passes/splitter_base.py @@ -204,7 +204,7 @@ def to_dict(self): Create dict dump on all events. """ ret: dict[str, list[str]] = {} - for name in self.node_events.keys(): + for name in self.node_events: ret[name] = [] for idx in self.node_events.get(name, []): event = self.events[idx] @@ -218,7 +218,7 @@ def print_all(self, writer=None): """ if not writer: writer = self.writer - for name in self.node_events.keys(): + for name in self.node_events: writer(f"Node: {name}:") self.print_node(name, recursive=False, tab=" ", writer=writer) diff --git a/torch/fx/passes/utils/source_matcher_utils.py b/torch/fx/passes/utils/source_matcher_utils.py index 043c65e6b77d2..82259b8a36ab7 100644 --- a/torch/fx/passes/utils/source_matcher_utils.py +++ b/torch/fx/passes/utils/source_matcher_utils.py @@ -113,7 +113,7 @@ def make_partition(nodes: list[Node], module_type: type) -> SourcePartition: # get_attr nodes won't be output nodes continue - for user in node.users.keys(): + for user in node.users: if user not in nodes: output_nodes.add(node) @@ -157,7 +157,7 @@ def check_subgraphs_connected( """ for node in reversed(subgraph1.nodes): - for user in node.users.keys(): + for user in node.users: if user in subgraph2.nodes: return True return False diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index 3a2b3ef8b6001..343871b1f94a2 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -574,7 +574,7 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn): def init_fn(script_module): # Initialize the ScriptModule: # 1. Copy the attributes/parameters/buffers from the original `nn_module` to the new ScriptModule. - for name in concrete_type.get_attributes().keys(): + for name in concrete_type.get_attributes(): orig_value = getattr(nn_module, name) orig_value = ( orig_value.value diff --git a/torch/jit/_script.py b/torch/jit/_script.py index a8bb3ba9bd8f5..46e6f47534108 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -856,7 +856,7 @@ def __setattr__(self, attr, value): self._c.setattr(attr, value) elif ( hasattr(self, "_concrete_type") - and attr in self._concrete_type.get_constants().keys() + and attr in self._concrete_type.get_constants() ): # TODO: we don't have _concrete_type set after load(), and in general we lose constant information. # We should encode constants as class type attributes (or something) so it persists across save/load. diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 10a240e3a9cf7..33bf35a1d852a 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -2521,7 +2521,7 @@ def _load_from_state_dict( unexpected_keys.append(extra_state_key) if strict: - for key in state_dict.keys(): + for key in state_dict: if key.startswith(prefix) and key != extra_state_key: input_name = key[len(prefix) :].split(".", 1) # Must be Module if it have attributes diff --git a/torch/onnx/_internal/exporter/_dynamic_shapes.py b/torch/onnx/_internal/exporter/_dynamic_shapes.py index e128ecf74e9e4..888db138736fb 100644 --- a/torch/onnx/_internal/exporter/_dynamic_shapes.py +++ b/torch/onnx/_internal/exporter/_dynamic_shapes.py @@ -67,7 +67,7 @@ def from_dynamic_axes_to_dynamic_shapes( # output names are not needed for dynamic_shapes continue if isinstance(axes, dict): - if any(not isinstance(k, int) for k in axes.keys()): + if any(not isinstance(k, int) for k in axes): raise ValueError( "The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]." ) diff --git a/torch/profiler/_memory_profiler.py b/torch/profiler/_memory_profiler.py index 3f21ce81171d7..dfa83f7467cd6 100644 --- a/torch/profiler/_memory_profiler.py +++ b/torch/profiler/_memory_profiler.py @@ -711,7 +711,7 @@ def timeline(self) -> tuple[tuple[int, Action, KeyAndID, int], ...]: events: list[tuple[int, Action, TensorAndID]] = [ (-1, Action.PREEXISTING, (key, version)) - for key, version in snapshot.keys() + for key, version in snapshot if (key, True) not in allocation_times and version == 0 ] @@ -938,7 +938,7 @@ def _set_parameters_using_data_flow(self) -> None: parameter_keys = {key.id for key, _ in candidate_parameters} parameter_keys &= self._any_version_depends_on_gradient() - for key, _ in snapshot.keys(): + for key, _ in snapshot: if key.id in parameter_keys: self._categories.set_by_id(key, Category.PARAMETER) diff --git a/torch/profiler/_utils.py b/torch/profiler/_utils.py index 47df87ce1678d..2c575b06509e5 100644 --- a/torch/profiler/_utils.py +++ b/torch/profiler/_utils.py @@ -103,7 +103,7 @@ def __init__(self, prof: profile) -> None: self.metrics: dict[EventKey, EventMetrics] = {} self.compute_self_time() self.event_keys = sorted( - (e for e in self.metrics.keys()), key=lambda x: x.event.start_time_ns + self.metrics.keys(), key=lambda x: x.event.start_time_ns ) self.events = [e.event for e in self.event_keys] self.cuda_events: list[_KinetoEvent] = [] @@ -265,7 +265,7 @@ def compute_idle_time(self) -> None: idle_intervals.append(Interval(idle_start, data_point.start)) idle = False - event_list = [e.event for e in self.metrics.keys()] + event_list = [e.event for e in self.metrics] for event in event_list: self.metrics[EventKey(event)].idle_time_ns = EventKey( event @@ -316,7 +316,7 @@ def rank_events(self, length): # Filter out events that are not in the decrease interval event_list = [ event - for event in self.metrics.keys() + for event in self.metrics if event.intervals_overlap(decrease_interval) ] if event_list: diff --git a/torch/utils/_config_module.py b/torch/utils/_config_module.py index 12ba497efd79c..f302a10b8338e 100644 --- a/torch/utils/_config_module.py +++ b/torch/utils/_config_module.py @@ -692,7 +692,7 @@ def __enter__(self) -> None: raise AssertionError( "prior should be empty when entering ConfigPatch" ) - for key in self.changes.keys(): + for key in self.changes: # KeyError on invalid entry prior[key] = config.__getattr__(key) for k, v in self.changes.items(): diff --git a/torch/utils/collect_env.py b/torch/utils/collect_env.py index 3b8b62cfde6d4..a643314f3b9cd 100644 --- a/torch/utils/collect_env.py +++ b/torch/utils/collect_env.py @@ -803,14 +803,14 @@ def get_version_or_na(cfg, prefix): def pretty_str(envinfo): def replace_nones(dct, replacement="Could not collect"): - for key in dct.keys(): + for key in dct: if dct[key] is not None: continue dct[key] = replacement return dct def replace_bools(dct, true="Yes", false="No"): - for key in dct.keys(): + for key in dct: if dct[key] is True: dct[key] = true elif dct[key] is False: diff --git a/torch/utils/data/datapipes/iter/callable.py b/torch/utils/data/datapipes/iter/callable.py index 1ce1c9c07196c..2e3bb18c80bb7 100644 --- a/torch/utils/data/datapipes/iter/callable.py +++ b/torch/utils/data/datapipes/iter/callable.py @@ -149,7 +149,7 @@ def _collate_helper(conversion, item): tuple_names: list = [] tuple_values: list = [] - for name in conversion.keys(): + for name in conversion: if name not in columns_name: raise RuntimeError("Conversion keys mismatch") diff --git a/torch/utils/data/datapipes/iter/grouping.py b/torch/utils/data/datapipes/iter/grouping.py index 865feb9953e35..a289bdb5e0949 100644 --- a/torch/utils/data/datapipes/iter/grouping.py +++ b/torch/utils/data/datapipes/iter/grouping.py @@ -234,7 +234,7 @@ def _remove_biggest_key(self): biggest_key = None biggest_size = 0 result_to_yield = None - for findkey in self.buffer_elements.keys(): + for findkey in self.buffer_elements: if len(self.buffer_elements[findkey]) > biggest_size: biggest_size = len(self.buffer_elements[findkey]) biggest_key = findkey diff --git a/torch/utils/tensorboard/summary.py b/torch/utils/tensorboard/summary.py index f36382cb42e16..1b6a2bb9bb66f 100644 --- a/torch/utils/tensorboard/summary.py +++ b/torch/utils/tensorboard/summary.py @@ -334,7 +334,7 @@ def hparams(hparam_dict=None, metric_dict=None, hparam_domain_discrete=None): # pyrefly: ignore [missing-attribute] ssi = Summary(value=[Summary.Value(tag=SESSION_START_INFO_TAG, metadata=smd)]) - mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict.keys()] + mts = [MetricInfo(name=MetricName(tag=k)) for k in metric_dict] exp = Experiment(hparam_infos=hps, metric_infos=mts) diff --git a/torch/utils/tensorboard/writer.py b/torch/utils/tensorboard/writer.py index 4fab33dc7ff09..0f533ae5b0f57 100644 --- a/torch/utils/tensorboard/writer.py +++ b/torch/utils/tensorboard/writer.py @@ -424,7 +424,7 @@ def add_scalars(self, main_tag, tag_scalar_dict, global_step=None, walltime=None fw_tag = fw_logdir + "/" + main_tag.replace("/", "_") + "_" + tag if self.all_writers is None: raise AssertionError("self.all_writers is None") - if fw_tag in self.all_writers.keys(): + if fw_tag in self.all_writers: fw = self.all_writers[fw_tag] else: fw = FileWriter( diff --git a/torchgen/gen_backend_stubs.py b/torchgen/gen_backend_stubs.py index 07097010f8f28..c9f1b660f02c5 100644 --- a/torchgen/gen_backend_stubs.py +++ b/torchgen/gen_backend_stubs.py @@ -287,8 +287,7 @@ def error_on_missing_kernels( expected_backend_native_funcs: list[NativeFunction] = [ f for f in native_functions - if f.func.name in expected_backend_op_names.keys() - and f.func.name not in full_codegen + if f.func.name in expected_backend_op_names and f.func.name not in full_codegen ] expected_backend_kernel_name_counts: dict[str, list[NativeFunction]] = defaultdict( list From 80ec2ab78e43a2f637bf5ceae753061c315eaaa5 Mon Sep 17 00:00:00 2001 From: Yuanyuan Chen Date: Thu, 6 Nov 2025 12:19:56 +0000 Subject: [PATCH 123/130] [8/N] Fix unused loop variables in tests (#166921) This PR continues to fix or remove unused loop variables in tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166921 Approved by: https://github.com/mlazos --- test/quantization/core/test_workflow_ops.py | 4 ++-- test/quantization/fx/test_quantize_fx.py | 2 +- test/test_datapipe.py | 2 +- test/test_jit_fuser_te.py | 15 ++++----------- torch/_export/serde/serialize.py | 2 +- 5 files changed, 9 insertions(+), 16 deletions(-) diff --git a/test/quantization/core/test_workflow_ops.py b/test/quantization/core/test_workflow_ops.py index f69852760e8a0..78e7799c864b1 100644 --- a/test/quantization/core/test_workflow_ops.py +++ b/test/quantization/core/test_workflow_ops.py @@ -368,8 +368,8 @@ def _test_forward_per_tensor_cachemask_impl(self, device): float_types = (torch.float32, torch.float16, torch.float64, torch.bfloat16) torch_types = (torch.qint8, torch.quint8) Xs = (torch.randn(4, 8, device=device), torch.randn(4, 16, device=device)[:, ::2]) - tensor_qparam = (True, False) - for float_type, torch_type, X, tensor_qparams in itertools.product(float_types, torch_types, Xs, tensor_qparam): + tensor_qparams = (True, False) + for float_type, torch_type, X, tensor_qparam in itertools.product(float_types, torch_types, Xs, tensor_qparams): # pick the scale + zp so that some values get clipped X = X.to(float_type) obs = torch.ao.quantization.MinMaxObserver(torch_type) diff --git a/test/quantization/fx/test_quantize_fx.py b/test/quantization/fx/test_quantize_fx.py index b33afc7a80363..9c0526fde6987 100644 --- a/test/quantization/fx/test_quantize_fx.py +++ b/test/quantization/fx/test_quantize_fx.py @@ -8807,7 +8807,7 @@ def forward(self, indices, offsets): # check it works in None and static qconfig for qconfig in [None, default_qconfig]: - qconfig_dict = {"": default_qconfig} + qconfig_dict = {"": qconfig} m = M().eval() m = prepare_fx(model, qconfig_dict, example_inputs=example_inputs) self.checkGraphModuleNodes(m, expected_node_occurrence={ diff --git a/test/test_datapipe.py b/test/test_datapipe.py index 5a535e7e00663..cab86e42734f1 100644 --- a/test/test_datapipe.py +++ b/test/test_datapipe.py @@ -1136,7 +1136,7 @@ def test_fork_iterdatapipe(self): ) break with warnings.catch_warnings(record=True) as wa: - for i, (n1, n2) in enumerate(zip(dp1, dp2)): + for n1, n2 in zip(dp1, dp2): output1.append(n1) output2.append(n2) self.assertEqual(len(wa), 1) diff --git a/test/test_jit_fuser_te.py b/test/test_jit_fuser_te.py index c3018be817d9b..8622d428cb4fe 100644 --- a/test/test_jit_fuser_te.py +++ b/test/test_jit_fuser_te.py @@ -1682,11 +1682,8 @@ def apply(fn): ] dtypes = ["int", "float", "bool"] values = {"int": [10, 3], "float": [12.34, 2.78], "bool": [True, False]} - devices = self.devices - for dtype_x, dtype_y, op, device in product( - dtypes, dtypes, binary_ops, devices - ): - code = ir_template.format(**locals()) + for dtype_x, dtype_y, op in product(dtypes, dtypes, binary_ops): + code = ir_template.format(dtype_x=dtype_x, dtype_y=dtype_y, op=op) # Interpret the graph try: @@ -1701,9 +1698,7 @@ def apply(fn): try: k = torch._C._te.TensorExprKernel(graph) except Exception as e: - raise RuntimeError( - " ".join(["Compilation failed:", device, str(code)]) - ) from e + raise RuntimeError(" ".join(["Compilation failed:", str(code)])) from e # Run the graph for x, y in product(values[dtype_x], values[dtype_y]): @@ -1713,9 +1708,7 @@ def apply(fn): self.assertEqual(ref, res) except Exception as e: raise RuntimeError( - " ".join( - ["Failed at runtime:", device, str(x), str(y), str(code)] - ) + " ".join(["Failed at runtime:", str(x), str(y), str(code)]) ) from e def test_matmul(self): diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 9c4629f13337d..e328422ec5e66 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -617,7 +617,7 @@ def get_triton_kernel_and_cache_entry(node: torch.fx.Node): return actual_kernel, matching_entries[0][1] if is_autotuner: - for sig_key, cache_entry in matching_entries: + for _sig_key, cache_entry in matching_entries: entry_metadata = cache_entry.metadata # pyrefly: ignore [missing-attribute] for config in kernel.configs: From b2d72a4008fa13612adc34c246e8e24c2185300e Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Thu, 6 Nov 2025 13:26:04 +0000 Subject: [PATCH 124/130] Revert "Don't hardcode double argument for reduction base (#166951)" This reverts commit a74fe75c450277eb88a95c764e8b0a664a550a86. Reverted https://github.com/pytorch/pytorch/pull/166951 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/166951#issuecomment-3497253260)) --- aten/src/ATen/native/cpu/Reduce.h | 4 ++-- aten/src/ATen/native/cpu/ReduceOpsKernel.cpp | 22 +++++++++++++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/native/cpu/Reduce.h b/aten/src/ATen/native/cpu/Reduce.h index ab9051ca8d2a2..6c9efbb0f6e7f 100644 --- a/aten/src/ATen/native/cpu/Reduce.h +++ b/aten/src/ATen/native/cpu/Reduce.h @@ -247,8 +247,8 @@ void binary_kernel_reduce(TensorIteratorBase& iter, ops_t ops, init_t init) { }); } -template -void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, ident_t ident = static_cast(0)) { +template +void binary_kernel_reduce_vec(TensorIteratorBase& iter, func_t op, vec_func_t vop, double ident = 0) { using traits = binary_function_traits; static_assert( all_same< diff --git a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp index 053db7b4eda00..3bad49a32d98c 100644 --- a/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp +++ b/aten/src/ATen/native/cpu/ReduceOpsKernel.cpp @@ -339,13 +339,33 @@ void or_kernel_impl(TensorIterator& iter) { } } +template +struct MinValuesOps: public at::native::MinOps { + using arg_t = typename MinOps::arg_t; + static scalar_t project(arg_t arg) { + return arg.first; + } +}; + void min_values_kernel_impl(TensorIterator& iter) { + // This case is special because of Vectorized does not + // handle upper_bound(). + // See: https://github.com/pytorch/pytorch/issues/43254 + if (iter.dtype() == kLong || iter.dtype() == kUInt64) { + AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] { + binary_kernel_reduce( + iter, + MinValuesOps{}, + std::pair(upper_bound(), -1)); + }), kLong, kUInt64); + return; + } AT_DISPATCH_V2(iter.dtype(), "min_values_cpu", AT_WRAP([&iter] { binary_kernel_reduce_vec( iter, [](scalar_t a, scalar_t b) -> scalar_t { return min_impl(a, b); }, [](Vectorized a, Vectorized b) { return minimum(a, b); }, - upper_bound()); + static_cast(upper_bound())); }), AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES), kBFloat16, kHalf, kBool); } From 2005b5f54842427839edb02a6782ea92a696560a Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Wed, 5 Nov 2025 07:22:07 -0800 Subject: [PATCH 125/130] [inductor] Use runtime estimations in iterative reorder collectives pass (#167080) Split of https://github.com/pytorch/pytorch/pull/162469 to be under 2K reorder iterative part Pull Request resolved: https://github.com/pytorch/pytorch/pull/167080 Approved by: https://github.com/eellison --- test/distributed/test_inductor_collectives.py | 6 +- torch/_inductor/comms.py | 1123 ++++++++++++----- torch/_inductor/config.py | 25 +- torch/_inductor/config_comms.py | 47 + torch/_inductor/utils.py | 7 +- 5 files changed, 851 insertions(+), 357 deletions(-) diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index ac3103e09341d..daa9bf2e309ff 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -1985,6 +1985,7 @@ def _reorder_communication_preserving_peak_memory( "bucket_reduce_scatters_fx_bucket_size_determinator": lambda _: 2, "reorder_for_compute_comm_overlap": True, "reorder_for_compute_comm_overlap_passes": [ + _reorder_communication_preserving_peak_memory, sink_waits_iterative, _reorder_communication_preserving_peak_memory, ], @@ -2046,11 +2047,6 @@ def _reorder_communication_preserving_peak_memory( assert node_stats is not None self.assertTrue(isinstance(node_stats, dict)) self.assertEqual(len(node_stats), 4) - it = iter(node_stats.values()) - node_stat0 = next(it) - self.assertTrue(node_stat0.limiting_factor == "None") - node_stat1 = next(it) - self.assertTrue("collective ordering" in node_stat1.limiting_factor) @skipIfXpu # https://github.com/intel/torch-xpu-ops/issues/1581 @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py index 6c7c9a8bd7dab..a4a4cac8e3ec2 100644 --- a/torch/_inductor/comms.py +++ b/torch/_inductor/comms.py @@ -18,7 +18,7 @@ from torch.multiprocessing.reductions import StorageWeakRef from torch.utils._ordered_set import OrderedSet -from . import config, ir +from . import config, config_comms, ir from .dependencies import WeakDep @@ -155,12 +155,15 @@ class ReorderInfo: Debug info describing how an individual snode was reordered """ - initial_exposed: float = -1 - final_exposed: float = -1 limiting_factor: str = "None" moves: int = 0 grouped: int = 0 grouped_info: str = "" + comm_time: float = -1.0 + comp_time: float = -1.0 + initial_exposed: float = -1.0 + final_exposed: float = -1.0 + overlap_info: str = "None" @property def improvement(self): @@ -193,7 +196,7 @@ def contains_gemm_like(snode: BaseSchedulerNode) -> bool: return is_gemm_like(snode.node) -def _temp_group_visit_leaves(snode, fn): +def _temp_group_visit_leaves(snode: BaseSchedulerNode, fn): from torch._inductor.scheduler import GroupedSchedulerNode if isinstance(snode, GroupedSchedulerNode) and snode.temp_grouping: @@ -203,6 +206,126 @@ def _temp_group_visit_leaves(snode, fn): fn(snode) +def wait_exposed_communication_time( + snodes_to_wait: list[BaseSchedulerNode], runtimes: dict[BaseSchedulerNode, float] +) -> tuple[float, float, str]: + """ + Calculate exposed communication time for a wait operation by finding its corresponding + collective and accumulating overlapping compute time between them. + + The Wait node must be the last in snodes_to_wait. + Compute time between corresponding Collective and Wait is accumulated. + If there is another pair of Collective and Wait inside, + Only compute before first such Wait' is considered as overlapping. + + Multiple process groups are not modeled so far. + """ + wait_snode = snodes_to_wait[-1] + assert is_wait(wait_snode.node) + assert len(snodes_to_wait) > 1 + idx = len(snodes_to_wait) - 2 + comm_time = 0.0 + comp_time = 0.0 + overlap_info = "" + waits_found = [] + for i in range(idx, -1, -1): + c = snodes_to_wait[i] + if contains_wait(c): + waits_found.append(c) + if contains_collective(c): + if is_corresponding_collective_wait(c, wait_snode): + comm_time = runtimes[c] + overlap_info += f"->C[{c.get_name()}]" + break + + if not contains_async_collective(c): + # Sync Collective + comp_time = 0.0 + continue + else: + for w in waits_found: + if is_corresponding_collective_wait(c, w): + # Similar to Sync Collective + # If after our Collective exist another Collective-Wait, + # All compute after it will not be overlapping + comp_time = 0.0 + continue + + comp_time_before = comp_time + + def accumulate_time(_snode: BaseSchedulerNode) -> None: + nonlocal comp_time + comp_time += runtimes[_snode] + + _temp_group_visit_leaves(c, accumulate_time) + comp_time_after = comp_time + overlap_info += f"+{c.get_name()}[{comp_time_after - comp_time_before}]" + + return comm_time, comp_time, overlap_info + + +def coll_exposed_communication_time( + snodes: list[BaseSchedulerNode], + runtimes: dict[BaseSchedulerNode, float], +) -> tuple[float, float, str]: + """ + Calculate exposed communication time for a collective operation by finding its corresponding + wait and accumulating compute time that can overlap with communication. + + The Collective node must be the first in snodes. + Compute time between corresponding Collective and Wait is accumulated. + If there is another pair of Collective and Wait inside, + Only compute before first such Wait' is considered as overlapping. + + Multiple process groups are not modeled so far. + """ + collective_snode = snodes[0] + comm_time = runtimes[collective_snode] + comp_time = 0.0 + collective_outs: OrderedSet[str] = OrderedSet( + o.get_name() for o in collective_snode.get_outputs() + ) + overlap_info = "" + collectives_found: list[BaseSchedulerNode] = [] + for snode in snodes[1:]: + # We may have some ops without Wait, + # e.g. DTensor torch.ops._dtensor.shard_dim_alltoall + unmet_deps = OrderedSet( + d.name for d in snode.unmet_dependencies if not _is_fake_dep(d) + ) + + if unmet_deps & collective_outs: + overlap_info += f"->W[{snode.get_name()}]" + break + + if contains_collective(snode): + if not contains_async_collective(snode): + break + else: + collectives_found.append(snode) + continue + if contains_wait(snode): + has_wait_for_collectives_found = False + for coll in collectives_found: + if is_corresponding_collective_wait(collective_snode, snode): + has_wait_for_collectives_found = True + break + if has_wait_for_collectives_found: + # Any compute after not overlapping original Collective + break + + comp_time_before = comp_time + + def accumulate_time(_snode: BaseSchedulerNode) -> None: + nonlocal comp_time + comp_time += runtimes[_snode] + + _temp_group_visit_leaves(snode, accumulate_time) + comp_time_after = comp_time + overlap_info += f"+{snode.get_name()}[{comp_time_after - comp_time_before}]" + return comm_time, comp_time, overlap_info + + def _group_name(snode, with_bufs=False) -> str: ret = "" for n in snode.snodes: @@ -258,369 +381,361 @@ def _initialize_double_linked_list( return _prev, _next, _head -def _reorder_communication_preserving_peak_memory_internal( - snodes: list[BaseSchedulerNode], -) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]: +def is_corresponding_collective_wait( + collective_snode: BaseSchedulerNode, wait_snode: BaseSchedulerNode +) -> bool: """ - Internal testing helper that also returns debug info. - Returns: - - reordered snodes list - - dict {snode: ReorderInfo} + Check if a wait node corresponds to a given collective node by verifying if the wait + depends on outputs from the collective. """ - has_collectives = False - for snode in snodes: - if contains_collective(snode): - has_collectives = True - break - if not has_collectives: - return snodes, {} + collective_outs = OrderedSet(o.get_name() for o in collective_snode.get_outputs()) + unmet_deps = OrderedSet(d.name for d in wait_snode.unmet_dependencies) + return bool(unmet_deps & collective_outs) - from torch._inductor.scheduler import GroupedSchedulerNode - original_snodes_num = len(snodes) - # heuristic to avoid degenerating to quadratic time - graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) - graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) - ( - peak_memory, - _curr_memory, - snodes_allocfree, - buf_to_snode_last_use, - name_to_freeable_input_buf, - ) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs) - runtimes: dict[BaseSchedulerNode, float] = { - snode: estimate_op_runtime(snode) for snode in snodes - } - # debug stats - stats: dict[BaseSchedulerNode, ReorderInfo] = {} +def _op_runtime_estimate_mult(snode): + # Apply multipliers for faster experimentation. + # TODO(ivankobzarev): Remove after confirmation that runtime estimations are correct. + if contains_collective(snode): + return config_comms.reorder_sink_runtime_estimations_comm_mult - def exposed_communication_time( - collective_snode: BaseSchedulerNode, remaining_snodes: list[BaseSchedulerNode] - ) -> float: - # assumes a linear schedule and computes the overlap of the collective with the remaining nodes - comm_time = estimate_op_runtime(collective_snode) - compute_time = 0.0 - for snode in remaining_snodes: - if contains_collective(snode): - continue - if contains_wait(snode): - # TODO - if the wait is for a collective that started before this collective or on another stream, - # we can ignore it. Otherwise, it's the end of the road for overlap opportunities - break + return config_comms.reorder_sink_runtime_estimations_non_comm_mult - def accumulate_time(_snode: BaseSchedulerNode) -> None: - nonlocal compute_time - compute_time += runtimes[_snode] - _temp_group_visit_leaves(snode, accumulate_time) - return max(0, comm_time - compute_time) +def is_async_collective(snode): + """ + Filtering out ops that contain Collective and Wait inside and considered as Collectives. + See contains_collective function. + If the op contains Wait inside - consider as Synchronous compute. + """ + if python_kernel_name := getattr(snode.node, "python_kernel_name", None): + if "torch.ops._dtensor.shard_dim_alltoall.default" in python_kernel_name: + return False - total_moves = 0 + return True - _prev, _next, _head = _initialize_double_linked_list(snodes) - def _group_nodes( - head: Optional[BaseSchedulerNode], tail: Optional[BaseSchedulerNode] - ) -> list[BaseSchedulerNode]: - ret = [] - n = head - while True: - if n is not None: - ret.append(n) - if n == tail: - break - n = _next[n] # type: ignore[index] - return ret +def contains_async_collective(snode): + return contains_collective(snode, is_async_collective) - def _perform_double_linked_list_swap(candidate, group_head, group_tail): - # swap (candidate, group_head...group_tail) - # Before: - # candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next - # After: - # candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next - # 0 - candidate_prev = _prev[candidate] - if candidate_prev: - _next[candidate_prev] = group_head - _prev[group_head] = candidate_prev - - # 2 - group_tail_next = _next[group_tail] - if group_tail_next: - _prev[group_tail_next] = candidate - _next[candidate] = group_tail_next - - # 1 - _prev[candidate] = group_tail - _next[group_tail] = candidate - nonlocal _head - if _head == candidate: - _head = group_head +def _group_nodes_from_linked_list( + head: Optional[BaseSchedulerNode], + tail: Optional[BaseSchedulerNode], + next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], +) -> list[BaseSchedulerNode]: + """ + Traverse doubly-linked list from head to tail and return nodes as a list. - def _calculate_potential_peak_memory( - candidate, group_ns, group_n_to_bufs_after_swap_dealloc_by_candidate - ): - # Caching calculations of memory for group nodes and candidate, - # to apply without recalculation after swap. - _post_alloc_update: dict[BaseSchedulerNode, int] = {} - potential_peak: int = 0 - if not group_n_to_bufs_after_swap_dealloc_by_candidate: - # Not accounting for buffers last use change - potential_peak = max( - group_peak_memory - candidate_delta_mem, - _curr_memory[group_tail][1] - - candidate_delta_mem - + candidate_allocfree.size_alloc, - ) - return potential_peak, _post_alloc_update + Args: + head: Starting node of the segment + tail: Ending node of the segment (inclusive) + next_dict: Dictionary mapping each node to its next node - # If candidate will be after group, the starting memory level of group nodes - # changes to the -(candidate.size_alloc - candidate.size_free) - mem_after_reorder_delta: int = -candidate_delta_mem - for gn in gns: - gn_post_alloc_mem = _curr_memory[gn][0] + mem_after_reorder_delta - _post_alloc_update[gn] = gn_post_alloc_mem - potential_peak = max(potential_peak, gn_post_alloc_mem) + Returns: + List of nodes from head to tail (inclusive) + """ + ret = [] + n = head + while True: + if n is not None: + ret.append(n) + if n == tail: + break + n = next_dict[n] # type: ignore[index] + return ret - bufs = group_n_to_bufs_after_swap_dealloc_by_candidate.get(gn, None) - if bufs is not None: - for buf in bufs: - # Candidate will deallocate those buffers - mem_after_reorder_delta += buf.mpi_buffer.size_free - candidate_mem_post_alloc = ( - _curr_memory[group_tail][1] - + mem_after_reorder_delta - + candidate_allocfree.size_alloc +def _perform_double_linked_list_swap( + candidate: BaseSchedulerNode, + group_head: BaseSchedulerNode, + group_tail: BaseSchedulerNode, + prev_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + head: BaseSchedulerNode, +) -> BaseSchedulerNode: + """ + Swap positions of candidate and group in doubly-linked list. + + Transforms: + candidate_prev -> candidate -> group_head...group_tail -> group_tail_next + Into: + candidate_prev -> group_head...group_tail -> candidate -> group_tail_next + + Args: + candidate: Node to swap with group + group_head: First node of group + group_tail: Last node of group + prev_dict: Dictionary mapping nodes to their previous nodes + next_dict: Dictionary mapping nodes to their next nodes + head: Current head of the linked list + + Returns: + New head of the linked list (may change if candidate was the head) + """ + # 0: Update candidate's previous node + candidate_prev = prev_dict[candidate] + if candidate_prev: + next_dict[candidate_prev] = group_head + prev_dict[group_head] = candidate_prev + + # 2: Update group_tail's next node + group_tail_next = next_dict[group_tail] + if group_tail_next: + prev_dict[group_tail_next] = candidate + next_dict[candidate] = group_tail_next + + # 1: Link group_tail to candidate + prev_dict[candidate] = group_tail + next_dict[group_tail] = candidate + + # Update head if candidate was the head + if head == candidate: + return group_head + return head + + +def _calculate_potential_peak_memory_reorder( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + group_tail: BaseSchedulerNode, + group_peak_memory: int, + candidate_delta_mem: int, + candidate_allocfree: SNodeMemory, + group_n_to_bufs_after_swap_dealloc_by_candidate: dict, + curr_memory: dict, +) -> tuple[int, dict[BaseSchedulerNode, int]]: + """ + Calculate potential peak memory after swapping candidate with group (reorder version). + + Computes new memory levels for all affected nodes and returns the potential + peak memory along with cached post-allocation memory values for each node. + + Args: + candidate: Node being moved + gns: Group nodes + group_tail: Last node of group + group_peak_memory: Current peak memory within the group + candidate_delta_mem: Net memory change from candidate (alloc - free) + candidate_allocfree: Candidate's allocation/free info + group_n_to_bufs_after_swap_dealloc_by_candidate: Buffers whose deallocation moves to candidate + curr_memory: Current memory state dict + + Returns: + Tuple of (potential_peak_memory, post_alloc_update_dict) + """ + # Caching calculations of memory for group nodes and candidate, + # to apply without recalculation after swap. + _post_alloc_update: dict[BaseSchedulerNode, int] = {} + potential_peak: int = 0 + if not group_n_to_bufs_after_swap_dealloc_by_candidate: + # Not accounting for buffers last use change + potential_peak = max( + group_peak_memory - candidate_delta_mem, + curr_memory[group_tail][1] + - candidate_delta_mem + + candidate_allocfree.size_alloc, ) - _post_alloc_update[candidate] = candidate_mem_post_alloc - potential_peak = max(potential_peak, candidate_mem_post_alloc) return potential_peak, _post_alloc_update - def _update_memory_tracking_after_swap( - candidate, - gns, - group_n_to_bufs_after_swap_dealloc_by_candidate, - _post_alloc_update, - ): - if not group_n_to_bufs_after_swap_dealloc_by_candidate: - for gn in gns: - cm = _curr_memory[gn] - _curr_memory[gn] = ( - cm[0] - candidate_delta_mem, - cm[1] - candidate_delta_mem, - ) - _candidate_post_alloc_mem = ( - _curr_memory[group_tail][1] + candidate_allocfree.size_alloc - ) - _candidate_post_free_mem = ( - _candidate_post_alloc_mem - candidate_allocfree.size_free - ) - _curr_memory[candidate] = ( - _candidate_post_alloc_mem, - _candidate_post_free_mem, - ) - return + # If candidate will be after group, the starting memory level of group nodes + # changes to the -(candidate.size_alloc - candidate.size_free) + mem_after_reorder_delta: int = -candidate_delta_mem + for gn in gns: + gn_post_alloc_mem = curr_memory[gn][0] + mem_after_reorder_delta + _post_alloc_update[gn] = gn_post_alloc_mem + potential_peak = max(potential_peak, gn_post_alloc_mem) - # Candidate becomes last use of some bufs - for bufs in group_n_to_bufs_after_swap_dealloc_by_candidate.values(): + bufs = group_n_to_bufs_after_swap_dealloc_by_candidate.get(gn) + if bufs is not None: for buf in bufs: - buf_to_snode_last_use[buf] = candidate - - size_free_to_move_to_candidate_sum: int = 0 - for n in gns: - _gn_post_alloc_mem: int = _post_alloc_update[n] - size_free_to_move_to_candidate: int = sum( - buf.mpi_buffer.size_free - for buf in group_n_to_bufs_after_swap_dealloc_by_candidate[n] - ) - size_free_to_move_to_candidate_sum += size_free_to_move_to_candidate - # group node does not deallocate this after swap - snodes_allocfree[n].size_free -= size_free_to_move_to_candidate - gn_post_free_mem: int = _gn_post_alloc_mem - snodes_allocfree[n].size_free - _curr_memory[n] = (_gn_post_alloc_mem, gn_post_free_mem) - _candidate_post_alloc_mem = _post_alloc_update[candidate] - snodes_allocfree[candidate].size_free += size_free_to_move_to_candidate_sum - candidate_post_free_mem = ( - _candidate_post_alloc_mem - snodes_allocfree[candidate].size_free - ) - _curr_memory[candidate] = ( - _candidate_post_alloc_mem, - candidate_post_free_mem, - ) + # Candidate will deallocate those buffers + mem_after_reorder_delta += buf.mpi_buffer.size_free - debug_num_collectives_to_reorder: Optional[int] = ( - config.reorder_iterative_debug_limit_to_reorder + candidate_mem_post_alloc = ( + curr_memory[group_tail][1] + + mem_after_reorder_delta + + candidate_allocfree.size_alloc ) + _post_alloc_update[candidate] = candidate_mem_post_alloc + potential_peak = max(potential_peak, candidate_mem_post_alloc) + return potential_peak, _post_alloc_update + + +def _update_memory_tracking_after_swap_reorder( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + group_tail: BaseSchedulerNode, + candidate_delta_mem: int, + candidate_allocfree: SNodeMemory, + group_n_to_bufs_after_swap_dealloc_by_candidate: dict, + post_alloc_update: dict[BaseSchedulerNode, int], + curr_memory: dict, + buf_to_snode_last_use: dict, + snodes_allocfree: dict, +) -> None: + """ + Update memory tracking structures after swap (reorder version). - num_processed_collectives: int = 0 - curr = _head - debug_iterative_memory_recompute = config.reorder_iterative_debug_memory_recompute - iterative_recompute_error = False - - while _next[curr] is not None: - if iterative_recompute_error: - break - # pyrefly: ignore [bad-argument-type] - if contains_collective(curr): - if debug_num_collectives_to_reorder is not None and ( - num_processed_collectives >= debug_num_collectives_to_reorder - ): - break - num_processed_collectives += 1 + Updates curr_memory, buf_to_snode_last_use, and snodes_allocfree dictionaries + to reflect the new memory state after swapping candidate with group. - info = stats[curr] = ReorderInfo() - info.initial_exposed = info.final_exposed = exposed_communication_time( - curr, _group_nodes(_next[curr], None) + Args: + candidate: Node that was moved + gns: Group nodes + group_tail: Last node of group + candidate_delta_mem: Net memory change from candidate (alloc - free) + candidate_allocfree: Candidate's allocation/free info + group_n_to_bufs_after_swap_dealloc_by_candidate: Buffers whose deallocation moves to candidate + post_alloc_update: Cached post-allocation memory values + curr_memory: Current memory state dict (mutated) + buf_to_snode_last_use: Buffer to last-use node mapping (mutated) + snodes_allocfree: Node allocation/free info dict (mutated) + """ + if not group_n_to_bufs_after_swap_dealloc_by_candidate: + for gn in gns: + cm = curr_memory[gn] + curr_memory[gn] = ( + cm[0] - candidate_delta_mem, + cm[1] - candidate_delta_mem, ) + _candidate_post_alloc_mem = ( + curr_memory[group_tail][1] + candidate_allocfree.size_alloc + ) + _candidate_post_free_mem = ( + _candidate_post_alloc_mem - candidate_allocfree.size_free + ) + curr_memory[candidate] = ( + _candidate_post_alloc_mem, + _candidate_post_free_mem, + ) + return - candidate = _prev[curr] - group_head = curr - group_tail = curr - group_peak_memory = _curr_memory[curr][0] # post_alloc memory - while candidate is not None: - if contains_collective(candidate): - info.limiting_factor = "collective ordering" - break - - gns: list[BaseSchedulerNode] = _group_nodes(group_head, group_tail) - group = GroupedSchedulerNode( - curr.scheduler, - gns, - temp_grouping=True, - ) - - # We can have multiple deps with the same name. - # As we ignore WeakDep(is_fake=True) => - # filter them out first to avoid overwriting of real dep. - data_deps = { - d.name: d for d in group.unmet_dependencies if not _is_fake_dep(d) - } - - candidate_outs = candidate.get_outputs() - data_dep = None - for o in candidate_outs: - if d := data_deps.get(o.get_name(), None): - data_dep = d - break + # Candidate becomes last use of some bufs + for bufs in group_n_to_bufs_after_swap_dealloc_by_candidate.values(): + for buf in bufs: + buf_to_snode_last_use[buf] = candidate + + size_free_to_move_to_candidate_sum: int = 0 + for n in gns: + _gn_post_alloc_mem: int = post_alloc_update[n] + size_free_to_move_to_candidate: int = sum( + buf.mpi_buffer.size_free + for buf in group_n_to_bufs_after_swap_dealloc_by_candidate[n] + ) + size_free_to_move_to_candidate_sum += size_free_to_move_to_candidate + # group node does not deallocate this after swap + snodes_allocfree[n].size_free -= size_free_to_move_to_candidate + gn_post_free_mem: int = _gn_post_alloc_mem - snodes_allocfree[n].size_free + curr_memory[n] = (_gn_post_alloc_mem, gn_post_free_mem) + _candidate_post_alloc_mem = post_alloc_update[candidate] + snodes_allocfree[candidate].size_free += size_free_to_move_to_candidate_sum + candidate_post_free_mem = ( + _candidate_post_alloc_mem - snodes_allocfree[candidate].size_free + ) + curr_memory[candidate] = ( + _candidate_post_alloc_mem, + candidate_post_free_mem, + ) - if data_dep is not None: - def is_groupable( - candidate: BaseSchedulerNode, - ) -> tuple[bool, Optional[str]]: - # preserve ordering - if contains_collective(candidate): - return False, "contains_collective" +def _find_buffers_with_changed_last_use( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + buf_to_snode_last_use: dict, +) -> dict[BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]]]: + """ + Find buffers whose last use will change after swapping candidate with group. - if contains_gemm_like(candidate): - return False, "contains_gemm_like" - return True, None + When we swap [candidate [group]] to [[group] candidate], some buffers that + were last used by a group node will now be last used by candidate instead. + This affects memory deallocation timing. - is_groupable_result, grouping_reason = is_groupable(candidate) - if is_groupable_result: - group_head = candidate - group_peak_memory = max( - group_peak_memory, _curr_memory[candidate][0] - ) - info.grouped += 1 - info.grouped_info = _group_names(gns) - candidate = _prev[candidate] - continue - else: - msg = ( - f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" - f"\n candidate:{candidate.get_name()}(outs:{[candidate.get_buffer_names()]})" - f"dep on {_group_names(gns)}" - f"\n non_group_reason:{grouping_reason}" - ) - info.limiting_factor = msg - break + Args: + candidate: The node being moved + gns: Group nodes being swapped with candidate + buf_to_snode_last_use: Mapping of buffers to their current last-use nodes - candidate_allocfree: SNodeMemory = snodes_allocfree[candidate] - candidate_delta_mem: int = ( - candidate_allocfree.size_alloc - candidate_allocfree.size_free - ) - # candidate and one of group nodes are successors of the same buffer - # and last use of the buffer happen in group nodes. - # This last use deallocates it. - # If we swap [candidate [group]] to [[group] candidate], - # candidate becomes the last use - # and deallocated this buffer instead of group node. - # we need to update size_free accordingly to group_node and candidate, - # and recalculate post_alloc, post_free for them. - # - # Buf that changes its last use snode, - # after swap will be deallocated only by candidate, - # while before it was deallocated by group node. - group_n_to_bufs_after_swap_dealloc_by_candidate: dict[ - BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]] - ] = defaultdict(list) - for ( - buf, - snode_last_use, - ) in buf_to_snode_last_use.items(): - succ_nodes = buf.mpi_buffer.succ_nodes - if candidate not in succ_nodes: - continue + Returns: + Dict mapping group nodes to buffers that will change their last-use node + """ + group_n_to_bufs_after_swap_dealloc_by_candidate: dict[ + BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]] + ] = defaultdict(list) + for ( + buf, + snode_last_use, + ) in buf_to_snode_last_use.items(): + succ_nodes = buf.mpi_buffer.succ_nodes + if candidate not in succ_nodes: + continue - if not any(gn == snode_last_use for gn in gns): - continue + if not any(gn == snode_last_use for gn in gns): + continue - group_n_to_bufs_after_swap_dealloc_by_candidate[ - snode_last_use - ].append(buf) + group_n_to_bufs_after_swap_dealloc_by_candidate[snode_last_use].append(buf) - potential_peak, _post_alloc_update = _calculate_potential_peak_memory( - candidate, gns, group_n_to_bufs_after_swap_dealloc_by_candidate - ) + return group_n_to_bufs_after_swap_dealloc_by_candidate - if potential_peak > peak_memory: - info.limiting_factor = ( - f"peak memory new:{potential_peak} vs base:{peak_memory}" - ) - break - info.moves += 1 - total_moves += 1 - _perform_double_linked_list_swap(candidate, group_head, group_tail) +def _is_node_groupable_for_reorder( + candidate: BaseSchedulerNode, +) -> tuple[bool, Optional[str]]: + """ + Check if a candidate node can be grouped with collective during reordering. - info.final_exposed = exposed_communication_time( - curr, _group_nodes(_next[curr], None) - ) + This pass processes collectives left to right, so we avoid grouping with + already-processed collectives based on configuration. - _update_memory_tracking_after_swap( - candidate, - gns, - group_n_to_bufs_after_swap_dealloc_by_candidate, - _post_alloc_update, - ) + Args: + candidate: Node to check for groupability - if debug_iterative_memory_recompute: - # Compare iteratively recomputed memory data - # with full run of estimate_peak_memory + Returns: + Tuple of (is_groupable, reason_if_not_groupable) + """ + # This pass processes collectives left to right, + # Do not group with processed collectives. + # Leaving config for experimentation in 2D + if not config_comms.reorder_iterative_group_with_collectives: + if contains_async_collective(candidate): + return ( + False, + f"candidate contains_collective {candidate.get_name()}", + ) + if not config_comms.reorder_iterative_use_runtime_estimations: + if contains_gemm_like(candidate): + return False, "contains_gemm_like" + return True, None + + +def _format_and_log_reordering_stats( + stats: dict[BaseSchedulerNode, ReorderInfo], + head: BaseSchedulerNode, + next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + original_snodes_num: int, + peak_memory: int, + name_to_freeable_input_buf: dict, + graph_outputs: OrderedSet[str], +) -> list[BaseSchedulerNode]: + """ + Format reordering statistics, log them, and return final node list. - from .comms_debug import _debug_iterative_memory_recompute + Computes improvement metrics, creates a formatted table (using tabulate if + available), validates the reordered node count, recalculates peak memory, + and logs all information. - iterative_recompute_error = _debug_iterative_memory_recompute( - candidate, - gns, - _group_names(gns), - _group_nodes(_head, None), - name_to_freeable_input_buf, - graph_outputs, - peak_memory, - _curr_memory, - snodes_allocfree, - "reorder_communication_preserving_peak_memory", - group_n_to_bufs_after_swap_dealloc_by_candidate, - ) - if iterative_recompute_error: - break - candidate = _prev[group_head] - curr = _next[curr] # type: ignore[assignment] + Args: + stats: Per-node reordering statistics + head: Head of the reordered linked list + next_dict: Linked list next pointers + original_snodes_num: Original number of nodes (for validation) + peak_memory: Initial peak memory before reordering + name_to_freeable_input_buf: Buffer memory tracking info + graph_outputs: Graph output names + Returns: + Final reordered list of scheduler nodes + """ node_stats = stats improvement = {snode: node_stats[snode].improvement for snode in node_stats} total_improvement = sum([improvement[snode] for snode in improvement]) @@ -632,28 +747,35 @@ def is_groupable( ) headers = [ "Collective node", - "initial exposed", - "final exposed", - "improvement", + "comm_time(us)", + "comp_time(us)", + "initial exposed(us)", + "final exposed(us)", + "improvement(us)", "limiting factor", "moves", "grouped", "grouped_info", + "overlap_info", ] rows = [ [ node_summary(snode), - node_info.initial_exposed, - node_info.final_exposed, - node_info.improvement, + node_info.comm_time / 1e3, + node_info.comp_time / 1e3, + node_info.initial_exposed / 1e3, + node_info.final_exposed / 1e3, + node_info.improvement / 1e3, node_info.limiting_factor, node_info.moves, node_info.grouped, node_info.grouped_info, + node_info.overlap_info, ] for snode, node_info in node_stats.items() ] if importlib.util.find_spec("tabulate"): + # pyrefly: ignore[import-error] from tabulate import tabulate reorder_log_str += tabulate( @@ -667,7 +789,7 @@ def is_groupable( reorder_log_str += str(headers) + "\n" reorder_log_str += "\n".join(map(str, rows)) - new_snodes = _group_nodes(_head, None) + new_snodes = _group_nodes_from_linked_list(head, None, next_dict) assert len(new_snodes) == original_snodes_num new_peak_memory, _, _, _ = estimate_peak_memory_allocfree( new_snodes, name_to_freeable_input_buf, graph_outputs @@ -685,6 +807,334 @@ def is_groupable( payload_fn=lambda: reorder_log_str, ) + return new_snodes + + +def _reorder_communication_preserving_peak_memory_internal( + snodes: list[BaseSchedulerNode], +) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]: + """ + Internal testing helper that also returns debug info. + Returns: + - reordered snodes list + - dict {snode: ReorderInfo} + """ + has_collectives = False + for snode in snodes: + if contains_collective(snode): + has_collectives = True + break + if not has_collectives: + return snodes, {} + + from torch._inductor.scheduler import GroupedSchedulerNode + + original_snodes_num = len(snodes) + # heuristic to avoid degenerating to quadratic time + graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) + graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) + ( + peak_memory, + _curr_memory, + snodes_allocfree, + buf_to_snode_last_use, + name_to_freeable_input_buf, + ) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs) + + runtimes: dict[BaseSchedulerNode, float] = { + snode: estimate_op_runtime(snode) * _op_runtime_estimate_mult(snode) + for snode in snodes + } + # debug stats + stats: dict[BaseSchedulerNode, ReorderInfo] = {} + + total_moves = 0 + + _prev, _next, _head = _initialize_double_linked_list(snodes) + + debug_num_collectives_to_reorder: Optional[int] = ( + config_comms.reorder_iterative_debug_limit_to_reorder + ) + + num_processed_collectives: int = 0 + curr: Optional[BaseSchedulerNode] = _head + debug_iterative_memory_recompute = ( + config_comms.reorder_iterative_debug_memory_recompute + ) + iterative_recompute_error = False + + while curr is not None and _next[curr] is not None: + _next_curr = _next[curr] + if iterative_recompute_error: + break + # pyrefly: ignore [bad-argument-type] + if not contains_async_collective(curr): + curr = _next_curr + continue + + if debug_num_collectives_to_reorder is not None and ( + num_processed_collectives >= debug_num_collectives_to_reorder + ): + break + num_processed_collectives += 1 + + info = stats[curr] = ReorderInfo() + comm_time, comp_time, overlap_info = coll_exposed_communication_time( + _group_nodes_from_linked_list(curr, None, _next), runtimes + ) + info.comm_time = comm_time + info.comp_time = comp_time + info.initial_exposed = info.final_exposed = comm_time - comp_time + info.overlap_info = overlap_info + + candidate = _prev[curr] + group_head = curr + group_tail = curr + group_waits = {} + group_runtime = 0.0 + group_peak_memory = _curr_memory[curr][0] # post_alloc memory + + while candidate is not None: + if config_comms.reorder_iterative_use_runtime_estimations and ( + info.final_exposed + < -config_comms.reorder_iterative_extra_comm_comp_overlap + * info.comm_time + ): + info.limiting_factor = "unexposed by runtime estimations" + break + + if ( + not config_comms.reorder_iterative_unsafe_collectives_reorder + and contains_collective(candidate) + ): + info.limiting_factor = "collective ordering" + break + + gns: list[BaseSchedulerNode] = _group_nodes_from_linked_list( + group_head, group_tail, _next + ) + group = GroupedSchedulerNode( + curr.scheduler, + gns, + temp_grouping=True, + ) + + # We can have multiple deps with the same name. + # As we ignore WeakDep(is_fake=True) => + # filter them out first to avoid overwriting of real dep. + data_deps = { + d.name: d for d in group.unmet_dependencies if not _is_fake_dep(d) + } + + candidate_outs = candidate.get_outputs() + data_dep = None + for o in candidate_outs: + if d := data_deps.get(o.get_name(), None): + data_dep = d + break + + if data_dep is not None: + is_groupable_result, grouping_reason = _is_node_groupable_for_reorder( + candidate + ) + if is_groupable_result: + group_head = candidate + # pyrefly: ignore[unbound-name] + if config_comms.reorder_iterative_use_runtime_estimations: + if contains_wait(candidate): + comm_time, comp_time, _ = wait_exposed_communication_time( + _group_nodes_from_linked_list(_head, candidate, _next), + runtimes, + ) + group_waits[candidate] = comm_time, comp_time + if not contains_async_collective(candidate): + group_runtime += runtimes[candidate] + + group_peak_memory = max( + group_peak_memory, _curr_memory[candidate][0] + ) + info.grouped += 1 + info.grouped_info = _group_names(gns) + candidate = _prev[candidate] + continue + else: + msg = ( + f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" + f"\n candidate:{candidate.get_name()}(outs:{[candidate.get_buffer_names()]})" + f"dep on {_group_names(gns)}" + f"\n non_group_reason:{grouping_reason}" + ) + info.limiting_factor = msg + break + + # pyrefly: ignore[unbound-name] + if config_comms.reorder_iterative_use_runtime_estimations: + # Check if candidate has sync runtime + if not contains_async_collective(candidate): + c_runtime = runtimes[candidate] + + if c_runtime > 0 and len(group_waits) > 0: + # pyrefly: ignore[no-matching-overload] + exposed_before = max(0, info.comm_time - info.comp_time) + # pyrefly: ignore[no-matching-overload] + exposed_after = max( + 0, info.comm_time - info.comp_time - c_runtime + ) + exposed_delta = exposed_after - exposed_before + for gw_comm_time, gw_comp_time in group_waits.values(): + gw_exposed_before = max(0, gw_comm_time - gw_comp_time) + gw_exposed_after = max( + 0, gw_comm_time - gw_comp_time + c_runtime + ) + + exposed_delta += gw_exposed_after - gw_exposed_before + + if exposed_delta > 0: + info.limiting_factor = ( + f"candidate has compute {c_runtime}," + f" group contains waits, total_exposed_delta {exposed_delta}" + ) + break + else: + # Update all group_colls comm_time, comp_time + for gw, ( + gw_comm_time, + gw_comp_time, + ) in group_waits.items(): + group_waits[gw] = ( + gw_comm_time, + gw_comp_time - c_runtime, + ) + else: + # Candidate is async_collective + + # Unsafe collectives reordering + # Cj -> [...group_runtime..., Ci] -> Wj + # Checking that we are not increasing exposed time of Cj + if group_runtime > 0: + comm_time, comp_time, _ = coll_exposed_communication_time( + _group_nodes_from_linked_list(candidate, None, _next), + runtimes, + ) + # pyrefly: ignore[no-matching-overload] + exposed_before = max(0, comm_time - comp_time) + # pyrefly: ignore[no-matching-overload] + exposed_after = max(0, comm_time - comp_time + group_runtime) + exposed_delta = exposed_after - exposed_before + if exposed_delta > 0: + info.limiting_factor = ( + f"candidate {candidate.get_name()} is collective," + f" group_runtime:{group_runtime}," + f" exposed_delta:{exposed_delta} c_comm_time:{comm_time} c_comp_time:{comp_time}" + ) + break + + candidate_allocfree: SNodeMemory = snodes_allocfree[candidate] + candidate_delta_mem: int = ( + candidate_allocfree.size_alloc - candidate_allocfree.size_free + ) + # candidate and one of group nodes are successors of the same buffer + # and last use of the buffer happen in group nodes. + # This last use deallocates it. + # If we swap [candidate [group]] to [[group] candidate], + # candidate becomes the last use + # and deallocated this buffer instead of group node. + # we need to update size_free accordingly to group_node and candidate, + # and recalculate post_alloc, post_free for them. + # + # Buf that changes its last use snode, + # after swap will be deallocated only by candidate, + # while before it was deallocated by group node. + group_n_to_bufs_after_swap_dealloc_by_candidate = ( + _find_buffers_with_changed_last_use( + candidate, gns, buf_to_snode_last_use + ) + ) + + potential_peak, _post_alloc_update = ( + _calculate_potential_peak_memory_reorder( + candidate, + gns, + group_tail, + group_peak_memory, + candidate_delta_mem, + candidate_allocfree, + group_n_to_bufs_after_swap_dealloc_by_candidate, + _curr_memory, + ) + ) + + if ( + potential_peak - peak_memory + # pyrefly: ignore[unbound-name] + > peak_memory * config_comms.reorder_iterative_peak_memory_budget + ): + info.limiting_factor = ( + f"peak memory new:{potential_peak} vs base:{peak_memory}" + ) + break + info.moves += 1 + total_moves += 1 + + _head = _perform_double_linked_list_swap( + candidate, group_head, group_tail, _prev, _next, _head + ) + + comm_time, comp_time, overlap_info = coll_exposed_communication_time( + _group_nodes_from_linked_list(curr, None, _next), runtimes + ) + info.comm_time = comm_time + info.comp_time = comp_time + info.overlap_info = overlap_info + info.final_exposed = comm_time - comp_time + + _update_memory_tracking_after_swap_reorder( + candidate, + gns, + group_tail, + candidate_delta_mem, + candidate_allocfree, + group_n_to_bufs_after_swap_dealloc_by_candidate, + _post_alloc_update, + _curr_memory, + buf_to_snode_last_use, + snodes_allocfree, + ) + + if debug_iterative_memory_recompute: + # Compare iteratively recomputed memory data + # with full run of estimate_peak_memory + + from .comms_debug import _debug_iterative_memory_recompute + + iterative_recompute_error = _debug_iterative_memory_recompute( + candidate, + gns, + _group_names(gns), + _group_nodes_from_linked_list(_head, None, _next), + name_to_freeable_input_buf, + graph_outputs, + peak_memory, + _curr_memory, + snodes_allocfree, + "reorder_communication_preserving_peak_memory", + group_n_to_bufs_after_swap_dealloc_by_candidate, + ) + if iterative_recompute_error: + break + candidate = _prev[group_head] + curr = _next_curr + + new_snodes = _format_and_log_reordering_stats( + stats, + _head, + _next, + original_snodes_num, + peak_memory, + name_to_freeable_input_buf, + graph_outputs, + ) + return new_snodes, stats @@ -1012,9 +1462,11 @@ def _update_memory_tracking_after_swap( curr = snodes[-1] processed_waits = OrderedSet() # type: ignore[var-annotated] - debug_iterative_memory_recompute = config.reorder_iterative_debug_memory_recompute + debug_iterative_memory_recompute = ( + config_comms.reorder_iterative_debug_memory_recompute + ) debug_num_sink_waits_to_reorder: Optional[int] = ( - config.sink_waits_iterative_debug_limit_to_sink + config_comms.sink_waits_iterative_debug_limit_to_sink ) iterative_recompute_error = False @@ -1213,6 +1665,7 @@ def is_groupable(snode): ] log_str = "" if importlib.util.find_spec("tabulate"): + # pyrefly: ignore[import-error] from tabulate import tabulate log_str += tabulate( @@ -1224,7 +1677,7 @@ def is_groupable(snode): log_str += str(headers) + "\n" log_str += "\n".join(map(str, rows)) overlap_log.info(log_str) - new_snodes = _group_nodes(_head, None) + new_snodes = _group_nodes_from_linked_list(_head, None, _next) assert len(new_snodes) == original_snodes_num new_peak_memory, _, _, _ = estimate_peak_memory_allocfree( new_snodes, name_to_freeable_input_buf, graph_outputs @@ -1267,7 +1720,7 @@ def node_summary(snode): if isinstance(snode.node, (ir.ExternKernelOut, ir._CollectiveKernel)): outs_str = f"outs:{[o.get_name() for o in snode.get_outputs()]}" ins_str = f"ins:{[d.name for d in snode.unmet_dependencies]}" - detail = f" {snode.get_name()} ({snode.node.python_kernel_name})\n {outs_str}\n ({ins_str})" + detail = f" {snode.get_name()} ({snode.node.python_kernel_name})\n {outs_str}({ins_str})" layouts = [child.node.get_output_spec() for child in snode.get_nodes()] out_tensor_info = ",".join( [ diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index aaf7fbd2f7f54..2d9e180db54f5 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -379,6 +379,15 @@ def prologue_fusion_enabled() -> bool: # for built-in passes, use string name; for user-defined passes, pass in the function handle # WARNING: Inductor scheduler IR is at prototype stage and subject to change, # hence custom IR passes built on top of it might break in the future. +# +# See aten_distributed_optimizations, it is recommended way for distributed optimizations. +# +# Recommended configuration for reorder_for_compute_comm_overlap_passes: +# [ +# "reorder_communication_preserving_peak_memory", +# "sink_waits_iterative", +# "reorder_communication_preserving_peak_memory", +# ] reorder_for_compute_comm_overlap_passes: list[ Union[ str, @@ -387,11 +396,7 @@ def prologue_fusion_enabled() -> bool: list["torch._inductor.scheduler.BaseSchedulerNode"], ], ] -] = [ - "reorder_compute_for_overlap", - "sink_waits", - "raise_comms", -] +] = [] # Maximum number of positions to advance a given collective, unlimited by default reorder_prefetch_limit: Optional[int] = None @@ -407,16 +412,6 @@ def prologue_fusion_enabled() -> bool: # is zero, which turns off this optimization. size_threshold_for_succ_based_strategy: int = 0 -reorder_iterative_debug_memory_recompute: bool = False -reorder_iterative_debug_limit_to_reorder: Optional[int] = ( - None - if (env_str := os.getenv("PYTORCH_REORDER_COLLECTIVES_LIMIT")) is None - else int(env_str) -) -sink_waits_iterative_debug_limit_to_sink: Optional[int] = ( - # pyrefly: ignore [unbound-name] - None if (env_str := os.getenv("PYTORCH_SINK_WAITS_LIMIT")) is None else int(env_str) -) bucket_all_gathers_fx: Literal["none", "all", "only_fsdp"] = "none" # By default torch._inductor.fx_passes.bucketing.bucket_size_determinator is used diff --git a/torch/_inductor/config_comms.py b/torch/_inductor/config_comms.py index b5dbf424f35b4..51242c7f2cf5b 100644 --- a/torch/_inductor/config_comms.py +++ b/torch/_inductor/config_comms.py @@ -1,4 +1,6 @@ +import os import sys +from typing import Optional from torch.utils._config_module import install_config_module @@ -11,5 +13,50 @@ # decisions on different distributed ranks. runtime_estimations_align_across_all_distributed_ranks: bool = False +reorder_iterative_debug_memory_recompute: bool = False +reorder_iterative_debug_limit_to_reorder: Optional[int] = ( + None + # pyrefly: ignore[unbound-name] + if (env_str := os.getenv("PYTORCH_REORDER_COLLECTIVES_LIMIT")) is None + else int(env_str) +) +sink_waits_iterative_debug_limit_to_sink: Optional[int] = ( + # pyrefly: ignore[unbound-name] + None if (env_str := os.getenv("PYTORCH_SINK_WAITS_LIMIT")) is None else int(env_str) +) + + +# Should be used with config.runtime_estimations_mms_benchmark = True +reorder_iterative_use_runtime_estimations: bool = False +sink_iterative_use_runtime_estimations: bool = False + +# Broadcast runtime estimations doing real Collective operation between all ranks. +# If non-deterministic runtime estimations are used this must be used to make +# all ranks to do identical decisions and prevent global Collectives reordering, +# (that will result un NCCL hangs) +reorder_for_compute_comm_overlap_broadcast_runtime_estimations: bool = False + +# Block of Ratios to workaround imperfection of current runtime estimations +# for collectives and compute for different scenarios. +# Multiplier of collectives estimated durations +reorder_sink_runtime_estimations_comm_mult: float = 2.0 +# Multiplier of compute estimated durations +reorder_sink_runtime_estimations_non_comm_mult: float = 1.0 +# The reordering will stop to reorder +# when overlap_comp >= (1 + extra_overlap_ratio) * comm_time +# Allows to configure more aggressive overlap +reorder_iterative_extra_comm_comp_overlap: float = 0.5 + +# Allow reorder iterative pass to increase peak memory +# up to peak_memory_before_pass * (1 + budget) +reorder_iterative_peak_memory_budget: float = 0.2 + +# Experimental unsafe configuration that allows changing relative collectives order. +# Must be used with runtime_estimations_align_across_all_distributed_ranks = True +reorder_iterative_unsafe_collectives_reorder: bool = True + +# Allow group and move other collectives during reordering +reorder_iterative_group_with_collectives: bool = False + # adds patch, save_config, etc install_config_module(sys.modules[__name__]) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 3f8652882af79..9579dbb3536e3 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -2813,13 +2813,16 @@ def is_wait(node: Optional[Union[IRNode, Operation]]) -> bool: return type(node) is ir._WaitKernel -def contains_collective(snode: BaseSchedulerNode) -> bool: +def contains_collective( + snode: BaseSchedulerNode, + filter_fn: Optional[Callable[[BaseSchedulerNode], bool]] = None, +) -> bool: from torch._inductor.scheduler import GroupedSchedulerNode if isinstance(snode, GroupedSchedulerNode): return any(contains_collective(x) for x in snode.snodes) - return is_collective(snode.node) + return is_collective(snode.node) and (filter_fn is None or filter_fn(snode)) def contains_wait(snode: BaseSchedulerNode) -> bool: From da2eb31b824820666445e3e232007f26eb825e28 Mon Sep 17 00:00:00 2001 From: Jessica Vandebon Date: Thu, 6 Nov 2025 15:43:45 +0000 Subject: [PATCH 126/130] [MTIA][PyTorch] Add mtia as native device for PyTorch tests (#167089) Summary: Add MTIA as a native device type in PyTorch. Test Plan: CI Reviewed By: PatriceVignola Differential Revision: D80111801 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167089 Approved by: https://github.com/andyanwang, https://github.com/nautsimon, https://github.com/albanD --- torch/testing/_internal/common_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py index 0c26738c2f52f..00572f9691380 100644 --- a/torch/testing/_internal/common_utils.py +++ b/torch/testing/_internal/common_utils.py @@ -333,7 +333,7 @@ def maybe_load_json(filename): if os.getenv("DISABLED_TESTS_FILE", ""): disabled_tests_dict = maybe_load_json(os.getenv("DISABLED_TESTS_FILE", "")) -NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', 'mps', torch._C._get_privateuse1_backend_name()) +NATIVE_DEVICES = ('cpu', 'cuda', 'xpu', 'meta', 'mps', 'mtia', torch._C._get_privateuse1_backend_name()) # used for managing devices testing for torch profiler UTs # for now cpu, cuda and xpu are added for testing torch profiler UTs From 7b055a0103008b84292dba154448547af424c739 Mon Sep 17 00:00:00 2001 From: Lakshay Garg Date: Thu, 6 Nov 2025 16:10:16 +0000 Subject: [PATCH 127/130] Add per_process_memory_fraction to PYTORCH_CUDA_ALLOC_CONF (#161035) torch.cuda.memory.set_per_process_memory_fraction allows setting an upper bound on how much device memory is allocated. This PR exposes this setting to an environment variable. For example, PYTORCH_CUDA_ALLOC_CONF="per_process_memory_fraction:0.5" will limit the device memory to half of the available memory. Pull Request resolved: https://github.com/pytorch/pytorch/pull/161035 Approved by: https://github.com/ngimel, https://github.com/eqy --- c10/cuda/CUDAAllocatorConfig.cpp | 15 ++++++ c10/cuda/CUDAAllocatorConfig.h | 11 ++++- c10/cuda/CUDACachingAllocator.cpp | 67 +++++++++++++-------------- c10/cuda/CUDACachingAllocator.h | 1 + c10/cuda/CUDAMallocAsyncAllocator.cpp | 1 - docs/source/notes/cuda.rst | 4 ++ test/test_cuda.py | 46 ++++++++++++++++++ 7 files changed, 108 insertions(+), 37 deletions(-) diff --git a/c10/cuda/CUDAAllocatorConfig.cpp b/c10/cuda/CUDAAllocatorConfig.cpp index 3046259b48a3e..5414d838cd8c4 100644 --- a/c10/cuda/CUDAAllocatorConfig.cpp +++ b/c10/cuda/CUDAAllocatorConfig.cpp @@ -106,6 +106,9 @@ void CUDAAllocatorConfig::parseArgs(const std::string& env) { } else if (key == "graph_capture_record_stream_reuse") { i = parseGraphCaptureRecordStreamReuse(tokenizer, i); used_native_specific_option = true; + } else if (key == "per_process_memory_fraction") { + i = parsePerProcessMemoryFraction(tokenizer, i); + used_native_specific_option = true; } else { const auto& keys = c10::CachingAllocator::AcceleratorAllocatorConfig::getKeys(); @@ -146,6 +149,18 @@ size_t CUDAAllocatorConfig::parseGraphCaptureRecordStreamReuse( return i; } +double CUDAAllocatorConfig::parsePerProcessMemoryFraction( + const c10::CachingAllocator::ConfigTokenizer& tokenizer, + size_t i) { + tokenizer.checkToken(++i, ":"); + double val_env = tokenizer.toDouble(++i); + TORCH_CHECK_VALUE( + val_env >= 0.0 && val_env <= 1.0, + "per_process_memory_fraction is invalid, set it in [0.0, 1.0]"); + m_per_process_memory_fraction = val_env; + return i; +} + size_t CUDAAllocatorConfig::parsePinnedNumRegisterThreads( const c10::CachingAllocator::ConfigTokenizer& tokenizer, size_t i) { diff --git a/c10/cuda/CUDAAllocatorConfig.h b/c10/cuda/CUDAAllocatorConfig.h index d61f69467a2dc..4e6097a406bc2 100644 --- a/c10/cuda/CUDAAllocatorConfig.h +++ b/c10/cuda/CUDAAllocatorConfig.h @@ -61,6 +61,10 @@ class C10_CUDA_API CUDAAllocatorConfig { return instance().m_graph_capture_record_stream_reuse; } + static double per_process_memory_fraction() { + return instance().m_per_process_memory_fraction; + } + /** Pinned memory allocator settings */ static bool pinned_use_cuda_host_register() { return instance().m_pinned_use_cuda_host_register; @@ -152,7 +156,8 @@ class C10_CUDA_API CUDAAllocatorConfig { "pinned_use_hip_host_register", "graph_capture_record_stream_reuse", "pinned_reserve_segment_size_mb", - "pinned_num_register_threads"}; + "pinned_num_register_threads", + "per_process_memory_fraction"}; return keys; } @@ -177,6 +182,9 @@ class C10_CUDA_API CUDAAllocatorConfig { size_t parseGraphCaptureRecordStreamReuse( const c10::CachingAllocator::ConfigTokenizer& tokenizer, size_t i); + double parsePerProcessMemoryFraction( + const c10::CachingAllocator::ConfigTokenizer& tokenizer, + size_t i); std::atomic m_pinned_num_register_threads{1}; std::atomic m_pinned_reserve_segment_size_mb{0}; @@ -189,6 +197,7 @@ class C10_CUDA_API CUDAAllocatorConfig { std::atomic m_release_lock_on_cudamalloc{false}; std::atomic m_pinned_use_cuda_host_register{false}; std::atomic m_graph_capture_record_stream_reuse{false}; + std::atomic m_per_process_memory_fraction{1.0}; }; // Keep this for backwards compatibility diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 091e580f95819..d66c3a16c0004 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -1100,7 +1100,7 @@ class RingBuffer { } // anonymous namespace } // namespace Native -static std::string reportProcessMemoryInfo(c10::DeviceIndex device) { +static std::string reportProcessMemoryInfo(const cudaDeviceProp& prop) { #ifdef PYTORCH_C10_DRIVER_API_SUPPORTED void* nvml_handle = DriverAPI::get_nvml_handle(); if (!nvml_handle) { @@ -1111,9 +1111,6 @@ static std::string reportProcessMemoryInfo(c10::DeviceIndex device) { return true; }(); - cudaDeviceProp prop{}; - C10_CUDA_CHECK(cudaGetDeviceProperties(&prop, device)); - // NOLINTNEXTLINE(*-c-arrays) char pci_id[80]; snprintf( @@ -1215,14 +1212,16 @@ class DeviceCachingAllocator { // record used memory. size_t total_allocated_memory = 0; - size_t allowed_memory_maximum = 0; + cudaDeviceProp device_prop; + + // maximum amount of memory that device is allowed to + // allocate. This is set iff memory fraction is less than 1 + std::optional allowed_memory_maximum{std::nullopt}; // all live expandable segments std::vector expandable_segments_; std::vector devices_with_peer_access_; - bool set_fraction = false; - bool record_history = false; std::atomic context_recorder_; @@ -1264,6 +1263,9 @@ class DeviceCachingAllocator { : device_id(id), large_blocks(/*small=*/false), small_blocks(/*small=*/true) { + C10_CUDA_CHECK(cudaGetDeviceProperties(&device_prop, id)); + + setMemoryFraction(CUDAAllocatorConfig::per_process_memory_fraction()); stats.max_split_size = static_cast(AcceleratorAllocatorConfig::max_split_size()); context_recorder_.store(nullptr); @@ -1399,7 +1401,7 @@ class DeviceCachingAllocator { if (!block_found) { // Do garbage collection if the flag is set. if (C10_UNLIKELY( - set_fraction && + allowed_memory_maximum.has_value() && AcceleratorAllocatorConfig::garbage_collection_threshold() > 0.0)) { garbage_collect_cached_blocks(context); @@ -1456,11 +1458,12 @@ class DeviceCachingAllocator { C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total)); std::string allowed_info; - if (set_fraction) { - allowed_info = format_size(allowed_memory_maximum) + " allowed; "; + if (allowed_memory_maximum.has_value()) { + allowed_info = + format_size(allowed_memory_maximum.value()) + " allowed; "; } - std::string proc_info = reportProcessMemoryInfo(device_id); + std::string proc_info = reportProcessMemoryInfo(device_prop); record_trace( TraceEntry::OOM, @@ -1518,7 +1521,7 @@ class DeviceCachingAllocator { for (const auto& obs : observers_local) { obs(device_id, alloc_size, - set_fraction ? allowed_memory_maximum : device_total, + allowed_memory_maximum.value_or(device_total), device_free); } @@ -2015,25 +2018,26 @@ class DeviceCachingAllocator { /** get memory fraction limiting maximum allocated memory **/ double getMemoryFraction() { - if (!set_fraction) { + if (!allowed_memory_maximum.has_value()) { return 1.0; } - size_t device_free = 0; - size_t device_total = 0; - C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total)); - return static_cast(allowed_memory_maximum) / - static_cast(device_total); + return static_cast(allowed_memory_maximum.value()) / + static_cast(device_prop.totalGlobalMem); } /** set memory fraction to limit maximum allocated memory **/ void setMemoryFraction(double fraction) { - size_t device_free = 0; - size_t device_total = 0; - C10_CUDA_CHECK(cudaMemGetInfo(&device_free, &device_total)); - allowed_memory_maximum = - static_cast(fraction * static_cast(device_total)); - set_fraction = true; + TORCH_CHECK( + 0 <= fraction && fraction <= 1, + "invalid fraction:", + fraction, + ". Please set within [0, 1]."); + allowed_memory_maximum = std::nullopt; + if (fraction < 1.0) { + allowed_memory_maximum = static_cast( + fraction * static_cast(device_prop.totalGlobalMem)); + } } /** get expandable segment size for all the streams on device **/ @@ -3010,7 +3014,7 @@ class DeviceCachingAllocator { BlockPool& pool = *p.pool; if (C10_UNLIKELY( - set_fraction && + allowed_memory_maximum.has_value() && AcceleratorAllocatorConfig::garbage_collection_threshold() > 0.0)) { // Track block reuse interval only when garbage collection is enabled. ++pool.get_free_blocks_call_count; @@ -3083,7 +3087,7 @@ class DeviceCachingAllocator { size_t gc_threshold = static_cast( AcceleratorAllocatorConfig::garbage_collection_threshold() * - static_cast(allowed_memory_maximum)); + static_cast(allowed_memory_maximum.value())); // No need to trigger GC yet if (total_allocated_memory <= gc_threshold) { return; @@ -3161,8 +3165,8 @@ class DeviceCachingAllocator { bool active_pool = p.pool->owner_PrivatePool && p.pool->owner_PrivatePool->allocator(); - if (set_fraction && - total_allocated_memory + size > allowed_memory_maximum) { + if (allowed_memory_maximum.has_value() && + total_allocated_memory + size > allowed_memory_maximum.value()) { p.err = cudaErrorMemoryAllocation; return false; // Temporarily disable checkpointing & cudagraphs internally @@ -3859,7 +3863,6 @@ class NativeCachingAllocator : public CUDAAllocator { "Allocator not initialized for device ", device, ": did you call init?"); - C10_CUDA_CHECK(c10::cuda::SetDevice(device)); return device_allocator[device]->getMemoryFraction(); } @@ -3869,12 +3872,6 @@ class NativeCachingAllocator : public CUDAAllocator { "Allocator not initialized for device ", device, ": did you call init?"); - TORCH_CHECK( - 0 <= fraction && fraction <= 1, - "invalid fraction:", - fraction, - ". Please set within [0, 1]."); - C10_CUDA_CHECK(c10::cuda::SetDevice(device)); device_allocator[device]->setMemoryFraction(fraction); } diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index fbe5dab18e0ae..8fee00dd621dc 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -2,6 +2,7 @@ #include #include +#include #include #include #include diff --git a/c10/cuda/CUDAMallocAsyncAllocator.cpp b/c10/cuda/CUDAMallocAsyncAllocator.cpp index 93bce51f1b9d0..674eb00035c50 100644 --- a/c10/cuda/CUDAMallocAsyncAllocator.cpp +++ b/c10/cuda/CUDAMallocAsyncAllocator.cpp @@ -427,7 +427,6 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator { // on the current device each later call sees. void init(int dev_count) override { static bool called = [](int dev_count) { - ; // Are there external guarantees init will be called before // any of the allocator's other functions? // std::lock_guard lk(general_mutex); diff --git a/docs/source/notes/cuda.rst b/docs/source/notes/cuda.rst index caabeb399c722..2c1a2e8cbb6be 100644 --- a/docs/source/notes/cuda.rst +++ b/docs/source/notes/cuda.rst @@ -619,6 +619,10 @@ Available options: and reallocate buffers across multiple streams, especially when the capture DAG frequently reaches joined frontiers. +* ``per_process_memory_fraction`` option limits the amount of memory that can be allocated + on all the CUDA devices to a specified fraction of the available memory. This is a value + between 0 and 1. Attempting to allocate more memory will raise an out of memory error. + .. note:: Some stats reported by the diff --git a/test/test_cuda.py b/test/test_cuda.py index 329261fba7d3a..dfbcdc1b40401 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -4626,6 +4626,52 @@ def check_output(script: str) -> str: rc = check_output(test_script) self.assertEqual(rc, "cudaMallocAsync") + def test_allocator_memory_fraction_setting(self): + def make_env(fraction): + env = os.environ.copy() + var = "PYTORCH_CUDA_ALLOC_CONF" + key = "per_process_memory_fraction" + value = [ + x + for x in env.get(var, "").split(",") + if len(x) > 0 and not x.startswith(f"{key}:") + ] + value.append(f"{key}:{fraction}") + env[var] = ",".join(value) + return env + + def run_test(value): + test_script = """\ +import os +import torch +device = torch._C._cuda_getDevice() +value = torch.cuda.memory.get_per_process_memory_fraction(device) +print(value, end="") + """ + return subprocess.run( + [sys.executable, "-c", test_script], + env=make_env(value), + text=True, + check=True, + capture_output=True, + ) + + self.assertEqual(run_test(0.0).stdout, "0.0") + self.assertEqual(run_test(0.5).stdout, "0.5") + self.assertEqual(run_test(1.0).stdout, "1.0") + + with self.assertRaises(subprocess.CalledProcessError) as e: + run_test(-0.1) + assert "per_process_memory_fraction is invalid" in e.exception.stderr, ( + e.exception.stderr + ) + + with self.assertRaises(subprocess.CalledProcessError) as e: + run_test(1.1) + assert "per_process_memory_fraction is invalid" in e.exception.stderr, ( + e.exception.stderr + ) + def test_cachingAllocator_raw_alloc(self): # Test that raw_alloc respects the setting that # activates/deactivates the caching allocator From cc477f600968a89a0e080ccaa6052277543bc84b Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Wed, 5 Nov 2025 07:22:13 -0800 Subject: [PATCH 128/130] [inductor] Use runtime estimations in iterative sink waits pass (#167081) Split of https://github.com/pytorch/pytorch/pull/162469 to be under 2K reorder iterative part Pull Request resolved: https://github.com/pytorch/pytorch/pull/167081 Approved by: https://github.com/eellison ghstack dependencies: #167080 --- torch/_inductor/comms.py | 932 ++++++++++++++++++++++---------- torch/_inductor/config_comms.py | 9 + 2 files changed, 644 insertions(+), 297 deletions(-) diff --git a/torch/_inductor/comms.py b/torch/_inductor/comms.py index a4a4cac8e3ec2..29efcb4a44493 100644 --- a/torch/_inductor/comms.py +++ b/torch/_inductor/comms.py @@ -24,7 +24,6 @@ if TYPE_CHECKING: from .ir import IRNode, Operation - from .scheduler import SchedulerBuffer from .memory import ( estimate_peak_memory, @@ -1325,341 +1324,289 @@ class SinkWaitInfo: moves: int = 0 moves_info: str = "" limiting_factor: str = "None" + comm_time: float = -1.0 + comp_time: float = -1.0 + initial_exposed: float = -1.0 + final_exposed: float = -1.0 + overlap_info: str = "None" + @property + def improvement(self): + return self.initial_exposed - self.final_exposed -def _sink_waits_iterative_internal( - snodes: list[BaseSchedulerNode], -) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]: - from torch._inductor.scheduler import GroupedSchedulerNode - original_snodes_num = len(snodes) - if original_snodes_num == 0: - return snodes, {} - graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) - graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) - ( - peak_memory, - _curr_memory, - snodes_allocfree, - buf_to_snode_last_use, - name_to_freeable_input_buf, - ) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs) +def _is_node_groupable_for_sink_waits( + candidate: BaseSchedulerNode, +) -> tuple[bool, Optional[str]]: + """ + Check if a candidate node can be grouped during sink_waits pass. - _prev, _next, _head = _initialize_double_linked_list(snodes) + Sink Waits traverses waits right to left, so we don't group with + processed waits on the right or with async collectives. - stats: dict[BaseSchedulerNode, SinkWaitInfo] = {} - - def _group_nodes( - head: Optional[BaseSchedulerNode], tail: Optional[BaseSchedulerNode] - ) -> list[BaseSchedulerNode]: - ret = [] - n = head - while True: - if n is not None: - ret.append(n) - if n == tail: - break - n = _next[n] # type: ignore[index] - return ret + Args: + candidate: Node to check for groupability - def _calculate_potential_peak_memory( - candidate, group_ns, group_n_to_bufs_after_swap_dealloc_instead_of_candidate - ): - pre_group_mem = ( - _curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc + Returns: + Tuple of (is_groupable, reason_if_not_groupable) + """ + # Sink Waits traverse Waits right to left, + # => we do not group with processed Waits on the right. + if contains_wait(candidate): + return False, f"candidate contains wait {candidate.get_name()}" + if contains_async_collective(candidate): + return ( + False, + f"candidate contains_async_collective {candidate.get_name()}", ) - # Stash memory tracing updates to not recompute them after swap - _post_alloc_update: dict[BaseSchedulerNode, int] = {} - _size_free_delta_update: dict[BaseSchedulerNode, int] = {} - - potential_peak = 0 - if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate: - # Not accounting for buffers liveliness change - potential_peak = max( - group_peak_memory + candidate_delta_mem, - pre_group_mem + candidate_allocfree.size_alloc, + + # pyrefly: ignore[unbound-name] + if not config_comms.sink_iterative_use_runtime_estimations: + # Heuristics pre-use_runtime_estimations: + # TODO(ivankobzarev): Remove them after confirming, + # that using runtime estimations always give better results. + # We do not want to group with collectives to not reorder them forward. + if contains_collective(candidate): + return ( + False, + f"candidate contains collective {candidate.get_name()}", + ) + if contains_gemm_like(candidate): + return ( + False, + f"candidate contains gemm_like {candidate.get_name()}", ) - return potential_peak, _post_alloc_update, _size_free_delta_update + return True, None + + +def _update_memory_tracking_after_swap_sink_waits( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + candidate_delta_mem: int, + candidate_allocfree: SNodeMemory, + group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict, + post_alloc_update: dict[BaseSchedulerNode, int], + size_free_delta_update: dict[BaseSchedulerNode, int], + curr_memory: dict, + snodes_allocfree: dict, +) -> None: + """ + Update memory tracking structures after swap (sink_waits version). + Updates curr_memory and snodes_allocfree dictionaries to reflect the new + memory state after swapping candidate with group. + + Args: + candidate: Node that was moved + gns: Group nodes + candidate_delta_mem: Net memory change from candidate (alloc - free) + candidate_allocfree: Candidate's allocation/free info + group_n_to_bufs_after_swap_dealloc_instead_of_candidate: Buffers whose deallocation moves from candidate to group + post_alloc_update: Cached post-allocation memory values + size_free_delta_update: Cached size-free delta values + curr_memory: Current memory state dict (mutated) + snodes_allocfree: Node allocation/free info dict (mutated) + """ + group_head = gns[0] + pre_group_mem = curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc + if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate: candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc - _post_alloc_update[candidate] = candidate_post_alloc - potential_peak = candidate_post_alloc - candidate_size_free_to_move = sum( - buf.mpi_buffer.size_free # type: ignore[attr-defined] - for buf in itertools.chain.from_iterable( - group_n_to_bufs_after_swap_dealloc_instead_of_candidate.values() - ) + curr_memory[candidate] = ( + candidate_post_alloc, + candidate_post_alloc - candidate_allocfree.size_free, ) - _size_free_delta_update[candidate] = -candidate_size_free_to_move - delta_mem = candidate_delta_mem + candidate_size_free_to_move for gn in gns: - gn_post_alloc = _curr_memory[gn][0] + delta_mem - _post_alloc_update[gn] = gn_post_alloc - potential_peak = max(potential_peak, gn_post_alloc) - gn_size_free_to_add = 0 - if gn in group_n_to_bufs_after_swap_dealloc_instead_of_candidate: - bufs = group_n_to_bufs_after_swap_dealloc_instead_of_candidate[gn] - for buf in bufs: - gn_size_free_to_add += buf.mpi_buffer.size_free - _size_free_delta_update[gn] = gn_size_free_to_add - delta_mem -= gn_size_free_to_add - return potential_peak, _post_alloc_update, _size_free_delta_update + cm = curr_memory[gn] + curr_memory[gn] = ( + cm[0] + candidate_delta_mem, + cm[1] + candidate_delta_mem, + ) + return - def _perform_double_linked_list_swap(candidate, group_head, group_tail): - # group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next - # 0: - group_head_prev = _prev[group_head] - if group_head_prev: - _next[group_head_prev] = candidate - _prev[candidate] = group_head_prev - - # 2: - candidate_next = _next[candidate] - if candidate_next: - _prev[candidate_next] = group_tail - _next[group_tail] = candidate_next - - # 1: - _prev[group_head] = candidate - _next[candidate] = group_head - nonlocal _head - if group_head == _head: - _head = candidate - - def _update_memory_tracking_after_swap( - candidate, - gns, - group_n_to_bufs_after_swap_dealloc_instead_of_candidate, - _post_alloc_update, - _size_free_delta_update, - ): - group_head = gns[0] - pre_group_mem = ( - _curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc + for n in [candidate, *gns]: + post_alloc = post_alloc_update[n] + snodes_allocfree[n].size_free += size_free_delta_update.get(n, 0) + curr_memory[n] = ( + post_alloc, + post_alloc - snodes_allocfree[n].size_free, ) - if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate: - candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc - _curr_memory[candidate] = ( - candidate_post_alloc, - candidate_post_alloc - candidate_allocfree.size_free, - ) - for gn in gns: - cm = _curr_memory[gn] - _curr_memory[gn] = ( - cm[0] + candidate_delta_mem, - cm[1] + candidate_delta_mem, - ) - return - - for n in [candidate, *gns]: - post_alloc = _post_alloc_update[n] - snodes_allocfree[n].size_free += _size_free_delta_update[n] - _curr_memory[n] = ( - post_alloc, - post_alloc - snodes_allocfree[n].size_free, - ) - curr = snodes[-1] - processed_waits = OrderedSet() # type: ignore[var-annotated] - debug_iterative_memory_recompute = ( - config_comms.reorder_iterative_debug_memory_recompute - ) - debug_num_sink_waits_to_reorder: Optional[int] = ( - config_comms.sink_waits_iterative_debug_limit_to_sink - ) +def _calculate_potential_peak_memory_sink_waits( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + group_head: BaseSchedulerNode, + group_peak_memory: int, + candidate_delta_mem: int, + candidate_allocfree: SNodeMemory, + group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict, + curr_memory: dict, + snodes_allocfree: dict, +) -> tuple[int, dict[BaseSchedulerNode, int], dict[BaseSchedulerNode, int]]: + """ + Calculate potential peak memory after swapping candidate with group (sink_waits version). - iterative_recompute_error = False + Computes new memory levels for all affected nodes and returns the potential + peak memory along with cached post-allocation and size-free delta values. - while _prev[curr] is not None: - if iterative_recompute_error: - break - if ( - debug_num_sink_waits_to_reorder is not None - and len(processed_waits) >= debug_num_sink_waits_to_reorder - ): - break + Args: + candidate: Node being moved + gns: Group nodes + group_head: First node of group + group_peak_memory: Current peak memory within the group + candidate_delta_mem: Net memory change from candidate (alloc - free) + candidate_allocfree: Candidate's allocation/free info + group_n_to_bufs_after_swap_dealloc_instead_of_candidate: Buffers whose deallocation moves from candidate to group + curr_memory: Current memory state dict + snodes_allocfree: Allocation/free info for all nodes - # pyrefly: ignore [bad-argument-type] - if contains_wait(curr) and curr not in processed_waits: - processed_waits.add(curr) - info = stats[curr] = SinkWaitInfo() - candidate = _next[curr] - wait_snode = curr - group_head = curr - group_tail = curr - group_peak_memory = _curr_memory[curr][0] - while candidate is not None: - if iterative_recompute_error: - break - gns: list[BaseSchedulerNode] = _group_nodes(group_head, group_tail) - group = GroupedSchedulerNode( - wait_snode.scheduler, - gns, - temp_grouping=True, - ) + Returns: + Tuple of (potential_peak_memory, post_alloc_update_dict, size_free_delta_update_dict) + """ + pre_group_mem = curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc + # Stash memory tracing updates to not recompute them after swap + _post_alloc_update: dict[BaseSchedulerNode, int] = {} + _size_free_delta_update: dict[BaseSchedulerNode, int] = {} - # We can have multiple deps with the same name. - # As we ignore WeakDep(is_fake=True) => - # filter them out first to avoid overwriting of real dep. - data_deps = { - d.name: d - for d in candidate.unmet_dependencies - if not _is_fake_dep(d) - } - - group_outs = group.get_outputs() - data_dep = None - for o in group_outs: - if d := data_deps.get(o.get_name(), None): - data_dep = d - break - # 1. If we have data_dep - we can not swap => trying to group - # 2. If swap candidate and current node both contain collectives => trying to group - if data_dep is not None or ( - both_contain_comms := ( - contains_collective(group) and contains_collective(candidate) - ) - ): + potential_peak = 0 + if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate: + # Not accounting for buffers liveliness change + potential_peak = max( + group_peak_memory + candidate_delta_mem, + pre_group_mem + candidate_allocfree.size_alloc, + ) + return potential_peak, _post_alloc_update, _size_free_delta_update - def is_groupable(snode): - # We do not want to group with collectives to not reorder them forward. - if contains_collective(snode): - return ( - False, - f"candidate contains collective {snode.get_name()}", - ) - if contains_gemm_like(snode): - return ( - False, - f"candidate contains gemm_like {snode.get_name()}", - ) - return True, None + candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc + _post_alloc_update[candidate] = candidate_post_alloc + potential_peak = candidate_post_alloc + candidate_size_free_to_move = sum( + buf.mpi_buffer.size_free # type: ignore[attr-defined] + for buf in itertools.chain.from_iterable( + group_n_to_bufs_after_swap_dealloc_instead_of_candidate.values() + ) + ) + _size_free_delta_update[candidate] = -candidate_size_free_to_move + delta_mem = candidate_delta_mem + candidate_size_free_to_move + for gn in gns: + gn_post_alloc = curr_memory[gn][0] + delta_mem + _post_alloc_update[gn] = gn_post_alloc + potential_peak = max(potential_peak, gn_post_alloc) + gn_size_free_to_add = 0 + if gn in group_n_to_bufs_after_swap_dealloc_instead_of_candidate: + bufs = group_n_to_bufs_after_swap_dealloc_instead_of_candidate[gn] + for buf in bufs: + gn_size_free_to_add += buf.mpi_buffer.size_free + _size_free_delta_update[gn] = gn_size_free_to_add + delta_mem -= gn_size_free_to_add + return potential_peak, _post_alloc_update, _size_free_delta_update - is_grp, grp_reason = is_groupable(candidate) - if is_grp: - group_tail = candidate - group_peak_memory = max( - group_peak_memory, _curr_memory[candidate][0] - ) - info.grouped += 1 - info.grouped_info = _group_names(gns) - candidate = _next[candidate] - continue - elif (data_dep is None) and both_contain_comms: - info.limiting_factor = ( - f"collective ordering {_group_names(gns)}" - f" with candidate:{candidate.get_name()}" - ) - break - else: - info.limiting_factor = ( - f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" - f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})" - f"dep on {gns}" - f"\n outs:{[o.get_name() for o in group_outs]}" - f"\n non_group_reason:{grp_reason}" - ) - break - candidate_allocfree: SNodeMemory = snodes_allocfree[candidate] - candidate_delta_mem = ( - candidate_allocfree.size_alloc - candidate_allocfree.size_free - ) - # [group] candidate -> candidate [group] - # Check for buffers with successors in group and candidate last successor - # - # Buf that changes its last use snode, - # It was deallocated by candidate, - # but after swap it will be deallocated by group node. - group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict[ - BaseSchedulerNode, list[Union[FreeableInputBuffer, SchedulerBuffer]] - ] = defaultdict(list) - for ( - buf, - snode_last_use, - ) in buf_to_snode_last_use.items(): - succ_nodes = buf.mpi_buffer.succ_nodes - if snode_last_use != candidate: # noqa: E711 - continue - # candidate is last use of buf - last_succ_gn = None - for gn in gns: - if gn in succ_nodes: - last_succ_gn = gn - if last_succ_gn is None: - continue +def _perform_double_linked_list_swap_sink_waits( + candidate: BaseSchedulerNode, + group_head: BaseSchedulerNode, + group_tail: BaseSchedulerNode, + prev_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + head: BaseSchedulerNode, +) -> BaseSchedulerNode: + """ + Swap positions of candidate and group in doubly-linked list (sink_waits version). - # gn has successors of buf that after potential swap will become - # last use of buf and start deallocating buf instead of candidate - group_n_to_bufs_after_swap_dealloc_instead_of_candidate[ - last_succ_gn - ].append(buf) - - potential_peak, _post_alloc_update, _size_free_delta_update = ( - _calculate_potential_peak_memory( - candidate, - gns, - group_n_to_bufs_after_swap_dealloc_instead_of_candidate, - ) - ) - if potential_peak > peak_memory: - info.limiting_factor = ( - f"peak memory new:{potential_peak} vs base:{peak_memory}" - ) - break + Transforms (moves candidate to the left): + group_head_prev -> group_head...group_tail -> candidate -> candidate_next + Into: + group_head_prev -> candidate -> group_head...group_tail -> candidate_next - info.moves += 1 - info.moves_info += f"+{candidate.get_name()}" + Args: + candidate: Node to swap with group + group_head: First node of group + group_tail: Last node of group + prev_dict: Dictionary mapping nodes to their previous nodes + next_dict: Dictionary mapping nodes to their next nodes + head: Current head of the linked list - _perform_double_linked_list_swap(candidate, group_head, group_tail) + Returns: + New head of the linked list (may change if group_head was the head) + """ + # 0: Update group_head's previous node + group_head_prev = prev_dict[group_head] + if group_head_prev: + next_dict[group_head_prev] = candidate + prev_dict[candidate] = group_head_prev + + # 2: Update candidate's next node + candidate_next = next_dict[candidate] + if candidate_next: + prev_dict[candidate_next] = group_tail + next_dict[group_tail] = candidate_next + + # 1: Link candidate to group_head + prev_dict[group_head] = candidate + next_dict[candidate] = group_head + + # Update head if group_head was the head + if group_head == head: + return candidate + return head - _update_memory_tracking_after_swap( - candidate, - gns, - group_n_to_bufs_after_swap_dealloc_instead_of_candidate, - _post_alloc_update, - _size_free_delta_update, - ) - if debug_iterative_memory_recompute: - from .comms_debug import _debug_iterative_memory_recompute - - iterative_recompute_error = _debug_iterative_memory_recompute( - candidate, - gns, - _group_names(gns), - _group_nodes(_head, None), - name_to_freeable_input_buf, - graph_outputs, - peak_memory, - _curr_memory, - snodes_allocfree, - "sink_waits_iterative", - group_n_to_bufs_after_swap_dealloc_instead_of_candidate, - ) - if iterative_recompute_error: - break +def _format_and_log_sink_waits_stats( + stats: dict[BaseSchedulerNode, SinkWaitInfo], + head: BaseSchedulerNode, + next_dict: dict[BaseSchedulerNode, Optional[BaseSchedulerNode]], + original_snodes_num: int, + peak_memory: int, + name_to_freeable_input_buf: dict, + graph_outputs: OrderedSet[str], +) -> list[BaseSchedulerNode]: + """ + Format sink_waits statistics, log them, and return final node list. + + Computes improvement metrics, creates a formatted table (using tabulate if + available), validates the reordered node count, recalculates peak memory, + and logs all information. - candidate = _next[group_tail] - curr = _prev[curr] # type: ignore[assignment] + Args: + stats: Per-node sink_waits statistics + head: Head of the reordered linked list + next_dict: Linked list next pointers + original_snodes_num: Original number of nodes (for validation) + peak_memory: Initial peak memory before reordering + name_to_freeable_input_buf: Buffer memory tracking info + graph_outputs: Graph output names + Returns: + Final reordered list of scheduler nodes + """ headers = [ "Wait node", + "comm_time(us)", + "comp_time(us)", + "initial exposed(us)", + "final exposed(us)", + "improvement(us)", + "limiting factor", "grouped", "grouped_info", "moves", "moves_info", - "limiting factor", + "overlap_info", ] rows = [ [ node_summary(snode), + info.comm_time / 1e3, + info.comp_time / 1e3, + info.initial_exposed / 1e3, + info.final_exposed / 1e3, + info.improvement / 1e3, + info.limiting_factor, info.grouped, info.grouped_info, info.moves, info.moves_info, - info.limiting_factor, + info.overlap_info, ] for snode, info in stats.items() ] @@ -1677,7 +1624,7 @@ def is_groupable(snode): log_str += str(headers) + "\n" log_str += "\n".join(map(str, rows)) overlap_log.info(log_str) - new_snodes = _group_nodes_from_linked_list(_head, None, _next) + new_snodes = _group_nodes_from_linked_list(head, None, next_dict) assert len(new_snodes) == original_snodes_num new_peak_memory, _, _, _ = estimate_peak_memory_allocfree( new_snodes, name_to_freeable_input_buf, graph_outputs @@ -1692,18 +1639,409 @@ def is_groupable(snode): }, payload_fn=lambda: log_str, ) - return new_snodes, stats + return new_snodes + + +def _find_buffers_with_changed_last_use_sink_waits( + candidate: BaseSchedulerNode, + gns: list[BaseSchedulerNode], + buf_to_snode_last_use: dict, +) -> dict[BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]]]: + """ + Find buffers whose last use will change after swapping in sink_waits pass. + When we swap [group] candidate to candidate [group], some buffers that + were last used by candidate will now be last used by a group node instead. + This is the opposite direction from the reorder version. -def sink_waits_iterative( + Args: + candidate: The node being moved (currently last use) + gns: Group nodes being swapped with candidate + buf_to_snode_last_use: Mapping of buffers to their current last-use nodes + + Returns: + Dict mapping group nodes to buffers that will change their last-use node + """ + group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict[ + BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]] + ] = defaultdict(list) + for ( + buf, + snode_last_use, + ) in buf_to_snode_last_use.items(): + succ_nodes = buf.mpi_buffer.succ_nodes + if snode_last_use != candidate: # noqa: E711 + continue + # candidate is last use of buf + last_succ_gn = None + for gn in gns: + if gn in succ_nodes: + last_succ_gn = gn + if last_succ_gn is None: + continue + + # gn has successors of buf that after potential swap will become + # last use of buf and start deallocating buf instead of candidate + group_n_to_bufs_after_swap_dealloc_instead_of_candidate[last_succ_gn].append( + buf + ) + + return group_n_to_bufs_after_swap_dealloc_instead_of_candidate + + +def _sink_waits_iterative_internal( snodes: list[BaseSchedulerNode], -) -> list[BaseSchedulerNode]: +) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]: + from torch._inductor.scheduler import GroupedSchedulerNode + + original_snodes_num = len(snodes) + if original_snodes_num == 0: + return snodes, {} + graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys()) + graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names()) + ( + peak_memory, + _curr_memory, + snodes_allocfree, + buf_to_snode_last_use, + name_to_freeable_input_buf, + ) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs) + + _prev, _next, _head = _initialize_double_linked_list(snodes) + + stats: dict[BaseSchedulerNode, SinkWaitInfo] = {} + + runtimes: dict[BaseSchedulerNode, float] = { + snode: estimate_op_runtime(snode) * _op_runtime_estimate_mult(snode) + for snode in snodes + } + + curr: Optional[BaseSchedulerNode] = snodes[-1] + + processed_waits = OrderedSet() # type: ignore[var-annotated] + debug_iterative_memory_recompute = ( + config_comms.reorder_iterative_debug_memory_recompute + ) + debug_num_sink_waits_to_reorder: Optional[int] = ( + config_comms.sink_waits_iterative_debug_limit_to_sink + ) + + iterative_recompute_error = False + while curr is not None and _prev[curr] is not None: + _prev_curr = _prev[curr] + if iterative_recompute_error: + break + if ( + debug_num_sink_waits_to_reorder is not None + and len(processed_waits) >= debug_num_sink_waits_to_reorder + ): + break + + # pyrefly: ignore [bad-argument-type] + if not (contains_wait(curr) and curr not in processed_waits): + curr = _prev_curr + continue + + processed_waits.add(curr) + info = stats[curr] = SinkWaitInfo() + comm_time, comp_time, overlap_info = wait_exposed_communication_time( + _group_nodes_from_linked_list(_head, curr, _next), runtimes + ) + info.initial_exposed = info.final_exposed = comm_time - comp_time + info.comm_time = comm_time + info.comp_time = comp_time + info.overlap_info = overlap_info + + candidate = _next[curr] + wait_snode = curr + group_head = curr + group_tail = curr + group_colls = {} + group_runtime = 0.0 + group_peak_memory = _curr_memory[curr][0] + + while candidate is not None: + if config_comms.sink_iterative_use_runtime_estimations and ( + info.final_exposed + < -config_comms.sink_iterative_extra_comm_comp_overlap * info.comm_time + ): + info.limiting_factor = "unexposed by runtime estimations" + break + + gns: list[BaseSchedulerNode] = _group_nodes_from_linked_list( + group_head, group_tail, _next + ) + group = GroupedSchedulerNode( + wait_snode.scheduler, + gns, + temp_grouping=True, + ) + + # We can have multiple deps with the same name. + # As we ignore WeakDep(is_fake=True) => + # filter them out first to avoid overwriting of real dep. + data_deps = { + d.name: d for d in candidate.unmet_dependencies if not _is_fake_dep(d) + } + + group_outs = group.get_outputs() + data_dep = None + for o in group_outs: + if d := data_deps.get(o.get_name(), None): + data_dep = d + break + # Conservative sink wait, limiting by space before next collective. + # The global strategy is that bucketing should create space. + # For 2D we can experiment with allowing to sink Wait beyond non current group collective. + # pyrefly: ignore[unbound-name] + if not config_comms.sink_waits_iterative_swap_with_collectives: + if contains_async_collective(candidate): + info.limiting_factor = ( + f"candidate contains_async_collective {candidate.get_name()}" + ) + break + + # 1. If we have data_dep - we can not swap => trying to group + # 2. If swap candidate and current node both contain collectives => trying to group + if data_dep is not None or ( + both_contain_comms := ( + contains_collective(group) and contains_collective(candidate) + ) + ): + _is_groupable, groupable_reason = _is_node_groupable_for_sink_waits( + candidate + ) + if _is_groupable: + group_tail = candidate + if ( + # pyrefly: ignore[unbound-name] + config_comms.sink_iterative_use_runtime_estimations + and contains_collective(candidate) + ): + comm_time, comp_time, _ = coll_exposed_communication_time( + _group_nodes_from_linked_list(candidate, None, _next), + runtimes, + ) + group_colls[candidate] = (comm_time, comp_time) + if not contains_async_collective(candidate): + group_runtime += runtimes[candidate] + + group_peak_memory = max( + group_peak_memory, _curr_memory[candidate][0] + ) + info.grouped += 1 + info.grouped_info = _group_names(gns) + candidate = _next[candidate] + continue + elif data_dep is None: + if ( + # pyrefly: ignore[unbound-name] + not config_comms.sink_waits_iterative_unsafe_collectives_reorder + and both_contain_comms + ): + info.limiting_factor = ( + f"collective ordering {_group_names(gns)}" + f"\n with candidate:{candidate.get_name()}" + ) + break + else: + info.limiting_factor = ( + f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})" + f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})" + f"\n dep on {_group_names(gns)}" + f"\n outs:{[o.get_name() for o in group_outs]}" + f"\n non_group_reason:{groupable_reason}" + ) + break + + # pyrefly: ignore[unbound-name] + if config_comms.sink_iterative_use_runtime_estimations: + if is_wait(candidate.node): + # Corresponding collective is before the group, + # Swap can increase exposed time of corresponding collective + comm_time, comp_time, _ = wait_exposed_communication_time( + _group_nodes_from_linked_list(_head, candidate, _next), runtimes + ) + # pyrefly: ignore[no-matching-overload] + exposed_before = max(0, comm_time - comp_time) + # pyrefly: ignore[no-matching-overload] + exposed_after = max(0, comm_time - comp_time + group_runtime) + # We do not know how much we can sink more after this swap, + # Just comparing advantage at the moment for now. + if exposed_after > exposed_before: + info.limiting_factor = ( + "candidate is wait," + f" exposed_before:{exposed_before} vs exposed_after:{exposed_after}" + ) + break + + # Check if candidate has sync runtime + if not contains_async_collective(candidate): + # If candidate has sync runtime, + # Waits of gorup_colls are on the right from group. + # Swap can increase their exposed time. + c_runtime = runtimes[candidate] + + if c_runtime > 0 and len(group_colls) > 0: + # Advantage for current Wait to do the Swap + # pyrefly: ignore[no-matching-overload] + exposed_delta = max( + 0, + info.comm_time - info.comp_time, + ) + # pyrefly: ignore[no-matching-overload] + -max(0, info.comm_time - info.comp_time - c_runtime) + for gc, (gc_comm_time, gc_comp_time) in group_colls.items(): + exposed_delta += max(0, gc_comm_time - gc_comp_time) - max( + 0, gc_comm_time - gc_comp_time + c_runtime + ) + if exposed_delta > 0: + info.limiting_factor = ( + f"candidate has compute {c_runtime}, group contains collectives," + f" total_exposed_delta {exposed_delta}" + ) + break + else: + # Update all group_colls comm_time, comp_time + for gc, ( + gc_comm_time, + gc_comp_time, + ) in group_colls.items(): + group_colls[gc] = ( + gc_comm_time, + gc_comp_time - c_runtime, + ) + + candidate_allocfree: SNodeMemory = snodes_allocfree[candidate] + candidate_delta_mem = ( + candidate_allocfree.size_alloc - candidate_allocfree.size_free + ) + # [group] candidate -> candidate [group] + # Check for buffers with successors in group and candidate last successor + # + # Buf that changes its last use snode, + # It was deallocated by candidate, + # but after swap it will be deallocated by group node. + group_n_to_bufs_after_swap_dealloc_instead_of_candidate = ( + _find_buffers_with_changed_last_use_sink_waits( + candidate, gns, buf_to_snode_last_use + ) + ) + + potential_peak, _post_alloc_update, _size_free_delta_update = ( + _calculate_potential_peak_memory_sink_waits( + candidate, + gns, + group_head, + group_peak_memory, + candidate_delta_mem, + candidate_allocfree, + group_n_to_bufs_after_swap_dealloc_instead_of_candidate, + _curr_memory, + snodes_allocfree, + ) + ) + if ( + potential_peak - peak_memory + # pyrefly: ignore[unbound-name] + > peak_memory * config_comms.sink_iterative_peak_memory_budget + ): + info.limiting_factor = ( + f"peak memory new:{potential_peak} vs base:{peak_memory}" + ) + break + + info.moves += 1 + info.moves_info += f"+{candidate.get_name()}" + + _head = _perform_double_linked_list_swap_sink_waits( + candidate, group_head, group_tail, _prev, _next, _head + ) + + comm_time, comp_time, overlap_info = wait_exposed_communication_time( + _group_nodes_from_linked_list(_head, curr, _next), runtimes + ) + info.comm_time = comm_time + info.comp_time = comp_time + info.final_exposed = comm_time - comp_time + info.overlap_info = overlap_info + + _update_memory_tracking_after_swap_sink_waits( + candidate, + gns, + candidate_delta_mem, + candidate_allocfree, + group_n_to_bufs_after_swap_dealloc_instead_of_candidate, + _post_alloc_update, + _size_free_delta_update, + _curr_memory, + snodes_allocfree, + ) + + if debug_iterative_memory_recompute: + from .comms_debug import _debug_iterative_memory_recompute + + iterative_recompute_error = _debug_iterative_memory_recompute( + candidate, + gns, + _group_names(gns), + _group_nodes_from_linked_list(_head, None, _next), + name_to_freeable_input_buf, + graph_outputs, + peak_memory, + _curr_memory, + snodes_allocfree, + "sink_waits_iterative", + group_n_to_bufs_after_swap_dealloc_instead_of_candidate, + ) + if iterative_recompute_error: + break + + candidate = _next[group_tail] + curr = _prev_curr + + new_snodes = _format_and_log_sink_waits_stats( + stats, + _head, + _next, + original_snodes_num, + peak_memory, + name_to_freeable_input_buf, + graph_outputs, + ) + + return new_snodes, stats + + +def sink_waits_iterative(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]: + """ + Similarly to reorder_communication_preserving_peak_memory this pass will try to iteratively + push Wait nodes later, recomputing estimated peak memory before each swap, + and preventing peak memory regressions. + + Pass will be applied to every Wait node. If there are immediate dependencies with next node, + pass will try to group them together and on the next step to swap the group with next candidate. + + If _inductor.config_comms.sink_iterative_use_runtime_estimations is set True, + pass will stop reordering of Wait once corresponding Collective is unexposed, + based on runtime estimations. + + inductor.config_comms.sink_iterative_peak_memory_budget allows to tune how much pass + can regress initial peak memory. + E.g.: + sink_iterative_peak_memory_budget == 0.0 - No regression of initial peak memory is allowed + sink_iterative_peak_memory_budget == 0.2 - Pass can improve comm-compute overlap, sacrificing + 20% of initial peak memory value. + + inductor.config_comms.sink_iterative_extra_comm_comp_overlap config allows to more aggressively + sink waits, stopping only when overlap_compute >= (1 + extra_comm_comp_overlap) * comm_time + """ return _sink_waits_iterative_internal(snodes)[0] def estimate_op_runtime(snode: BaseSchedulerNode) -> float: """ - Returns estimated op runtime in nanoseconds (ns) + Returns estimated op runtime in milliseconds (ms) """ if config.estimate_op_runtime == "default": runtime = snode.get_estimated_runtime() diff --git a/torch/_inductor/config_comms.py b/torch/_inductor/config_comms.py index 51242c7f2cf5b..31f38b867dd5e 100644 --- a/torch/_inductor/config_comms.py +++ b/torch/_inductor/config_comms.py @@ -46,17 +46,26 @@ # when overlap_comp >= (1 + extra_overlap_ratio) * comm_time # Allows to configure more aggressive overlap reorder_iterative_extra_comm_comp_overlap: float = 0.5 +# The sink waits reordering will stop to reorder +# when overlap_comp >= (1 + extra_overlap_ratio) * comm_time +# Allows to configure more aggressive sink waits +sink_iterative_extra_comm_comp_overlap: float = 0.5 # Allow reorder iterative pass to increase peak memory # up to peak_memory_before_pass * (1 + budget) reorder_iterative_peak_memory_budget: float = 0.2 +# Allow sink waits iterative pass to increase peak memory +# up to peak_memory_before_pass * (1 + budget) +sink_iterative_peak_memory_budget: float = 0.2 # Experimental unsafe configuration that allows changing relative collectives order. # Must be used with runtime_estimations_align_across_all_distributed_ranks = True reorder_iterative_unsafe_collectives_reorder: bool = True +sink_waits_iterative_unsafe_collectives_reorder: bool = True # Allow group and move other collectives during reordering reorder_iterative_group_with_collectives: bool = False +sink_waits_iterative_swap_with_collectives: bool = False # adds patch, save_config, etc install_config_module(sys.modules[__name__]) From 3fdc5dbf1d1742ed49aeebc190db28835fd6ddbf Mon Sep 17 00:00:00 2001 From: Nikita Shulga Date: Thu, 6 Nov 2025 07:24:19 -0800 Subject: [PATCH 129/130] Make CUDA preload logic more straightforward (#167046) I.e. remove distinction between two cases, and always preload full set of libraries For some reason, when one uses `virtualenv` instead of `venv`, preloading `cudart` works, but it fails to find cudnn or cublasLT later on Fix it, by getting read of partial preload logic for one of the cases and always preload full set of libraries Test plan on stock Ubuntu: ``` pip install virtualenv virtualenv --symlinks -p python3.11 --prompt virtv venv-virt source venv-virt/bin/activate pip install torch python -c 'import torch' ``` Fixes https://github.com/pytorch/pytorch/issues/165812 Pull Request resolved: https://github.com/pytorch/pytorch/pull/167046 Approved by: https://github.com/atalman --- torch/__init__.py | 73 +++++++++++++++++++++++++---------------------- 1 file changed, 39 insertions(+), 34 deletions(-) diff --git a/torch/__init__.py b/torch/__init__.py index 05a34bdd93200..b64961a9c56f6 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -303,8 +303,8 @@ def _get_cuda_dep_paths(path: str, lib_folder: str, lib_name: str) -> list[str]: return nvidia_lib_paths + lib_paths -def _preload_cuda_deps(lib_folder: str, lib_name: str, required: bool = True) -> None: # type: ignore[valid-type] - """Preloads cuda deps if they could not be found otherwise.""" +def _preload_cuda_lib(lib_folder: str, lib_name: str, required: bool = True) -> None: # type: ignore[valid-type] + """Preloads cuda library if it could not be found otherwise.""" # Should only be called on Linux if default path resolution have failed assert platform.system() == "Linux", "Should only be called on Linux" @@ -320,6 +320,39 @@ def _preload_cuda_deps(lib_folder: str, lib_name: str, required: bool = True) -> ctypes.CDLL(lib_path) +def _preload_cuda_deps(err: _Optional[OSError] = None) -> None: + cuda_libs: dict[str, str] = { + "cublas": "libcublas.so.*[0-9]", + "cudnn": "libcudnn.so.*[0-9]", + "cuda_nvrtc": "libnvrtc.so.*[0-9]", + "cuda_runtime": "libcudart.so.*[0-9]", + "cuda_cupti": "libcupti.so.*[0-9]", + "cufft": "libcufft.so.*[0-9]", + "curand": "libcurand.so.*[0-9]", + "nvjitlink": "libnvJitLink.so.*[0-9]", + "cusparse": "libcusparse.so.*[0-9]", + "cusparselt": "libcusparseLt.so.*[0-9]", + "cusolver": "libcusolver.so.*[0-9]", + "nccl": "libnccl.so.*[0-9]", + "nvshmem": "libnvshmem_host.so.*[0-9]", + "cufile": "libcufile.so.*[0-9]", + } + + # If error is passed, re-raise it if it's not about one of the abovementioned + # libraries + if err is not None and [ + lib for lib in cuda_libs.values() if lib.split(".", 1)[0] in err.args[0] + ]: + raise err + + # Otherwise, try to preload dependencies from site-packages + for lib_folder, lib_name in cuda_libs.items(): + _preload_cuda_lib(lib_folder, lib_name) + + # libnvToolsExt is Optional Dependency + _preload_cuda_lib("nvtx", "libnvToolsExt.so.*[0-9]", required=False) + + # See Note [Global dependencies] def _load_global_deps() -> None: if platform.system() == "Windows": @@ -346,43 +379,15 @@ def _load_global_deps() -> None: # libtorch_global_deps.so always depends in cudart, check if its installed and loaded if "libcudart.so" not in _maps: return - # If all above-mentioned conditions are met, preload nvrtc and nvjitlink - _preload_cuda_deps("cuda_nvrtc", "libnvrtc.so.*[0-9]") - _preload_cuda_deps("cuda_nvrtc", "libnvrtc-builtins.so.*[0-9]") - _preload_cuda_deps("nvjitlink", "libnvJitLink.so.*[0-9]") + # If all above-mentioned conditions are met, preload CUDA dependencies + _preload_cuda_deps() except Exception: pass except OSError as err: - # Can only happen for wheel with cuda libs as PYPI deps + # Can happen for wheel with cuda libs as PYPI deps # As PyTorch is not purelib, but nvidia-*-cu12 is - cuda_libs: dict[str, str] = { - "cublas": "libcublas.so.*[0-9]", - "cudnn": "libcudnn.so.*[0-9]", - "cuda_nvrtc": "libnvrtc.so.*[0-9]", - "cuda_runtime": "libcudart.so.*[0-9]", - "cuda_cupti": "libcupti.so.*[0-9]", - "cufft": "libcufft.so.*[0-9]", - "curand": "libcurand.so.*[0-9]", - "nvjitlink": "libnvJitLink.so.*[0-9]", - "cusparse": "libcusparse.so.*[0-9]", - "cusparselt": "libcusparseLt.so.*[0-9]", - "cusolver": "libcusolver.so.*[0-9]", - "nccl": "libnccl.so.*[0-9]", - "nvshmem": "libnvshmem_host.so.*[0-9]", - "cufile": "libcufile.so.*[0-9]", - } - - is_cuda_lib_err = [ - lib for lib in cuda_libs.values() if lib.split(".")[0] in err.args[0] - ] - if not is_cuda_lib_err: - raise err - for lib_folder, lib_name in cuda_libs.items(): - _preload_cuda_deps(lib_folder, lib_name) - - # libnvToolsExt is Optional Dependency - _preload_cuda_deps("nvtx", "libnvToolsExt.so.*[0-9]", required=False) + _preload_cuda_deps(err) ctypes.CDLL(global_deps_lib_path, mode=ctypes.RTLD_GLOBAL) From bfc0ba4af97a1169d6fee5692fae34051b750a12 Mon Sep 17 00:00:00 2001 From: Nikita Vedeneev Date: Thu, 6 Nov 2025 16:50:12 +0000 Subject: [PATCH 130/130] `nn.Linear`: nD contiguous input + bias -- dispatch to addmm also when weight is sparse (#166071) As per title. It seems safe to be able to generalize to arbitrary contiguous inputs since `at::matmul` is likely to do the flattening to avoid `baddmm`. Additionally, we guard for bias to be 1D and contiguous which is guaranteed to be fused with no copies. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166071 Approved by: https://github.com/ngimel --- aten/src/ATen/native/Linear.cpp | 61 ++++++++++++++++++++--------- test/profiler/test_profiler_tree.py | 6 +-- 2 files changed, 45 insertions(+), 22 deletions(-) diff --git a/aten/src/ATen/native/Linear.cpp b/aten/src/ATen/native/Linear.cpp index 1da245972f0cb..fbabba84dbb2d 100644 --- a/aten/src/ATen/native/Linear.cpp +++ b/aten/src/ATen/native/Linear.cpp @@ -50,18 +50,35 @@ static inline bool parseLinearFlatten3d() { // `_flatten_nd_linear` flattens all but the last dimension of the input tensor // before passing it to linear operation static inline Tensor _flatten_nd_linear(const Tensor& input, const Tensor& weight, const Tensor& bias) { - const auto input_sizes = input.sym_sizes(); - // can't use -1 in reshape because it errors when a dimension is 0 - c10::SymInt flattened_dim = 1; - for (int64_t i = 0, ndim = input_sizes.size(); i < ndim - 1; ++i) { - flattened_dim = flattened_dim * input_sizes[i]; + const auto input_sizes = input.sym_sizes(); + + const auto result_flattened = [&]() -> Tensor { + const auto input_ncols = input_sizes.back(); + const auto input_flattened_nrows = [&]() -> c10::SymInt { + // can't use -1 in reshape because it errors when a dimension is 0 + auto flattened_nrows = c10::SymInt{1}; + for (const auto& size : input_sizes.slice(0, input_sizes.size() - 1)) { + flattened_nrows *= size; + } + return flattened_nrows; + }(); + + const auto input_flattened = input.view_symint({input_flattened_nrows, input_ncols}); + if (weight.layout() == c10::kStrided) { + return at::addmm(bias, input_flattened, weight.t()); + } else { + // weight is sparse, and addmm for sparse expects matmul lhs to be sparse, + // so we transpose the problem. + // NOTE: at::matmul handles (dense @ sparse) similarly. + const auto bias_t = (bias.dim() >= 2) ? bias.mT() : bias.unsqueeze(-1); + return at::addmm(bias_t, weight, input_flattened.t()).t(); } - auto inp_reshape = input.reshape_symint({flattened_dim, input_sizes.at(input_sizes.size() -1)}); - const auto result = at::addmm(bias, inp_reshape, weight.t()); - auto new_size = input_sizes.slice(0, input_sizes.size() - 1); - c10::SymDimVector sizes_vec(new_size.begin(), new_size.end()); - sizes_vec.push_back(result.sym_size(1)); - return result.view_symint(sizes_vec); + }(); + + // Unflatten flattened row dims + auto result_sizes = c10::SymDimVector{input_sizes.begin(), input_sizes.end()}; + result_sizes.back() = result_flattened.sym_size(1); + return result_flattened.view_symint(result_sizes); } @@ -90,15 +107,23 @@ Tensor linear(const Tensor& input, const Tensor& weight, const std::optionaldefined() && !input.is_xla()) { - // Also hit the fused path for contiguous 3D input, if not using xla + + const auto is_bias_likely_fusable = ( + bias->defined() && + // cuBLASLt: will fuse in the epilogue without copies + // when input/weight/bias are all strided. + // When weight is not strided, bias will not be fused, + // but we can still dispatch here to avoid at::matmul + // path which will probably use a very similar + // flattening optimization. + ((bias->dim() == 1 || bias->squeeze().dim() == 1) && bias->is_contiguous_or_false()) + ); + if (is_bias_likely_fusable && !input.is_xla()) { + // Also hit the fused path for contiguous nD input, if not using xla // backend. Reshaping/flattening has some performance implications on xla. - bool is_contiguous = input.is_contiguous_or_false(); - if (is_contiguous && input_dim == 3) { - return _flatten_nd_linear(input, weight, *bias); - } else if (is_contiguous && input.layout() == c10::kStrided && weight.layout() == c10::kStrided && bias->dim() == 1) { + if (input.is_contiguous_or_false()) { return _flatten_nd_linear(input, weight, *bias); - } else if (parseLinearFlatten3d() && input_dim == 3) { + } else if (parseLinearFlatten3d()) { // If user forces flattening via env var const Tensor input_cont = input.contiguous(); return _flatten_nd_linear(input_cont, weight, *bias); diff --git a/test/profiler/test_profiler_tree.py b/test/profiler/test_profiler_tree.py index c6316fe3cd7e3..e8d28d7eff032 100644 --- a/test/profiler/test_profiler_tree.py +++ b/test/profiler/test_profiler_tree.py @@ -624,8 +624,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: torch/nn/modules/module.py(...): __getattr__ aten::linear - aten::reshape - aten::view + aten::view aten::t aten::transpose aten::as_strided @@ -671,8 +670,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: torch/nn/modules/module.py(...): __getattr__ aten::linear - aten::reshape - aten::view + aten::view aten::t aten::transpose aten::as_strided