Commit d5c9823

apakbin and masnesral authored

Fix rand_like decomposition to preserve strides (#2456)

Cherry-pick of: pytorch#159294

Co-authored-by: Sam Larsen <slarsen@meta.com>

1 parent 377ae7f · commit d5c9823
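For context: before this change, the `*_like` random decompositions rebuilt outputs with `torch.rand`/`torch.randn`/`aten.randint` on `self.size()` and then applied a suggested memory format, so a permuted or sliced input could compile to an output whose strides differ from eager mode. A minimal repro sketch of the symptom (illustrative, not taken from the PR; assumes a build where `torch.compile` is available):

```python
import torch

def fn(x):
    return torch.randn_like(x)

# A non-contiguous logical view: shape (4, 3), strides (1, 4).
x = torch.zeros(3, 4).permute(1, 0)

eager = fn(x)
compiled = torch.compile(fn)(x)

# With this fix, the compiled output preserves the input's layout,
# matching eager's strides instead of falling back to contiguous.
assert eager.stride() == compiled.stride()
```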

File tree

3 files changed: +99 -78 lines

test/inductor/test_torchinductor.py

Lines changed: 19 additions & 3 deletions
```diff
@@ -8918,7 +8918,7 @@ def forward(self, v1: torch.Tensor):
         model = Model()
         x = torch.rand(10, 3, 0)
 
-        self.common(model, (x,))
+        self.common(model, (x,), exact_stride=True)
 
     def test_randint(self):
         @torch.compile(fullgraph=True)
@@ -8973,9 +8973,21 @@ def bin(index, max_size):
     @config.patch(fallback_random=True)
     def test_like_rands(self):
         def fn(x):
-            return torch.rand_like(x), torch.randn_like(x)
+            return torch.rand_like(x), torch.randn_like(x), torch.randint_like(x, 1, 11)
 
-        self.common(fn, [torch.zeros([20, 20])])
+        self.common(fn, [torch.zeros([20, 20])], exact_stride=True)
+
+    @config.patch(fallback_random=True)
+    @xfail_if_mps  # 100% are not close
+    def test_like_rands_sliced(self):
+        def fn(x):
+            return (
+                torch.randn_like(x),
+                torch.randn_like(x),
+                torch.randint_like(x, 1, 11),
+            )
+
+        self.common(fn, (torch.zeros([3, 4])[:, ::2].permute(1, 0),), exact_stride=True)
 
     @config.patch(check_stack_no_cycles_TESTING_ONLY=True)
     def test_check_stack_no_cycles(self):
@@ -9008,6 +9020,8 @@ def fn(x):
         a0 = fn(x).clone()
         a1 = fn(x).clone()
         self.assertFalse(torch.allclose(a0, a1))
+        self.assertEqual(a0.shape, a1.shape)
+        self.assertEqual(a0.stride(), a1.stride())
 
     @requires_gpu()
     @skip_if_triton_cpu("Flaky on Triton CPU")
@@ -9025,6 +9039,8 @@ def fn(x, device):
         a1 = test_like_rands_on_different_device(GPU_TYPE, "cpu")
         self.assertTrue(a0.device.type == GPU_TYPE)
         self.assertTrue(a1.device.type == "cpu")
+        self.assertEqual(a0.shape, a1.shape)
+        self.assertEqual(a0.stride(), a1.stride())
 
     def test_max_pool2d_with_indices_backward(self):
         def fn(a, b, c):
```
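The `exact_stride=True` flag added throughout makes the harness's `common` check compare output strides as well as values between eager and compiled runs. The device and seed tests instead assert shape and stride equality directly; a rough standalone sketch of that check (illustrative, not the harness code):

```python
import torch

def check_layout_matches(fn, x):
    # Random values differ between runs, so compare only layout metadata.
    eager = fn(x)
    compiled = torch.compile(fn)(x)
    assert eager.shape == compiled.shape
    assert eager.stride() == compiled.stride()
```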

test/inductor/test_torchinductor_codegen_dynamic_shapes.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -173,6 +173,7 @@ def run(*ex, **kwargs):
     "test_bucketize_int_dynamic_shapes": TestFailure(("cpu",)),
     "test_searchsorted_dynamic_shapes": TestFailure(("cpu",)),
     "test_like_rands_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
+    "test_like_rands_sliced_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_linspace2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_linspace3_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_linspace4_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
```

torch/_inductor/decomposition.py

Lines changed: 79 additions & 75 deletions
```diff
@@ -535,49 +535,17 @@ def view_copy_dtype(
     return self.to(dtype).clone()
 
 
-def get_like_layout(
-    tensor: torch.Tensor,
-    memory_format: Optional[torch.memory_format] = None,
-) -> torch.memory_format:
-    # TODO: _to_copy tensor to stride permutation
-    if memory_format is torch.preserve_format or memory_format is None:
-        return utils.suggest_memory_format(tensor)
-    else:
-        return memory_format
-
-
-@register_decomposition(aten.rand_like)
-def rand_like(
+def _get_shape_permutation_like(
     self: torch.Tensor,
-    *,
-    dtype: Optional[torch.dtype] = None,
-    device: Optional[torch.device] = None,
-    memory_format: Optional[torch.memory_format] = None,
-    **kwargs: Any,
-) -> torch.Tensor:
-    return torch.rand(
-        [*self.size()],
-        dtype=dtype or self.dtype,
-        device=device or self.device,
-        **kwargs,
-    ).to(memory_format=get_like_layout(self, memory_format))
+) -> tuple[utils.ShapeType, utils.StrideType]:
+    physical_layout = utils.compute_elementwise_output_logical_to_physical_perm(self)
+    shape = [self.shape[l] for l in physical_layout]
 
+    permutation = [0] * len(shape)
+    for p, l in enumerate(physical_layout):
+        permutation[l] = p
 
-@register_decomposition(aten.randn_like)
-def randn_like(
-    self: torch.Tensor,
-    *,
-    dtype: Optional[torch.dtype] = None,
-    device: Optional[torch.device] = None,
-    memory_format: Optional[torch.memory_format] = None,
-    **kwargs: Any,
-) -> torch.Tensor:
-    return torch.randn(
-        [*self.size()],
-        dtype=dtype or self.dtype,
-        device=device or self.device,
-        **kwargs,
-    ).to(memory_format=get_like_layout(self, memory_format))
+    return (shape, permutation)
 
 
 @register_decomposition(aten.full_like)
```
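The new helper replaces `get_like_layout`'s coarse memory-format guess with an exact layout recovery: it computes the logical-to-physical dimension permutation, allocates in physical order, and returns the inverse permutation needed to restore logical order. For a plain non-overlapping tensor this reduces to sorting dimensions by stride; a self-contained sketch of the idea (the real `compute_elementwise_output_logical_to_physical_perm` also handles broadcasting and zero or tied strides):

```python
import torch

def shape_permutation_like(t: torch.Tensor):
    # Physical order: dimensions sorted from largest to smallest stride,
    # i.e. the order in which elements sit in memory.
    physical = sorted(range(t.ndim), key=lambda d: t.stride(d), reverse=True)
    shape = [t.shape[d] for d in physical]
    # Inverse permutation mapping physical positions back to logical dims.
    permutation = [0] * t.ndim
    for p, l in enumerate(physical):
        permutation[l] = p
    return shape, permutation

t = torch.empty(3, 4).permute(1, 0)      # shape (4, 3), strides (1, 4)
shape, perm = shape_permutation_like(t)  # shape == [3, 4], perm == [1, 0]
out = torch.randn(shape).permute(perm)   # shape (4, 3), strides (1, 4)
assert out.shape == t.shape and out.stride() == t.stride()
```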
```diff
@@ -592,55 +560,91 @@ def full_like(
     requires_grad: bool = False,
     memory_format: torch.memory_format = torch.preserve_format,
 ) -> torch.Tensor:
-    return torch.full(
-        [*self.size()],
-        fill_value,
-        dtype=dtype or self.dtype,
-        layout=layout or self.layout,
-        device=device or self.device,
-        requires_grad=requires_grad,
-    ).to(memory_format=get_like_layout(self, memory_format))
+    dtype = self.dtype if dtype is None else dtype
+    layout = self.layout if layout is None else layout
+    device = self.device if device is None else device
+
+    if memory_format != torch.preserve_format:
+        result = torch.full(
+            self.shape,
+            fill_value,
+            dtype=dtype,
+            layout=layout,
+            device=device,
+            pin_memory=pin_memory,
+            requires_grad=requires_grad,
+        )
+        return result.to(memory_format=memory_format)
+
+    else:
+        assert layout == torch.strided
+        shape, permutation = _get_shape_permutation_like(self)
+        result = torch.full(
+            shape,
+            fill_value,
+            dtype=dtype,
+            layout=layout,
+            device=device,
+            pin_memory=pin_memory,
+            requires_grad=requires_grad,
+        )
+        if permutation == list(range(len(permutation))):
+            return result
+        return result.permute(permutation).clone()
 
 
-@register_decomposition(aten.randint_like.default)
-def randint_like(
+def _rand_like(
+    rand_fn: Callable[..., torch.Tensor],
     self: torch.Tensor,
-    high: int,
     *,
     dtype: Optional[torch.dtype] = None,
     device: Optional[torch.device] = None,
-    memory_format: Optional[torch.memory_format] = None,
+    memory_format: torch.memory_format = torch.preserve_format,
     **kwargs: Any,
 ) -> torch.Tensor:
-    return aten.randint.low(
-        0,
-        high,
-        [*self.size()],
-        dtype=dtype or self.dtype,
-        device=device or self.device,
+    dtype = self.dtype if dtype is None else dtype
+    device = self.device if device is None else device
+
+    if memory_format != torch.preserve_format:
+        return rand_fn(
+            self.shape,
+            dtype=dtype,
+            device=device,
+            **kwargs,
+        ).to(memory_format=memory_format)
+
+    shape, permutation = _get_shape_permutation_like(self)
+    result = rand_fn(
+        shape,
+        dtype=dtype,
+        device=device,
         **kwargs,
-    ).to(memory_format=get_like_layout(self, memory_format))
+    )
+    if permutation == list(range(len(permutation))):
+        return result
+    return result.permute(permutation).clone()
+
+
+@register_decomposition(aten.rand_like)
+def rand_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+    return _rand_like(torch.rand, self, **kwargs)
+
+
+@register_decomposition(aten.randn_like)
+def randn_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+    return _rand_like(torch.randn, self, **kwargs)
+
+
+@register_decomposition(aten.randint_like.default)
+def randint_like(self: torch.Tensor, high: int, **kwargs: Any) -> torch.Tensor:
+    return _rand_like(functools.partial(aten.randint.low, 0, high), self, **kwargs)
 
 
 @register_decomposition(aten.randint_like.low_dtype)
 def randint_like_low(
-    self: torch.Tensor,
-    low: int,
-    high: int,
-    *,
-    dtype: Optional[torch.dtype] = None,
-    device: Optional[torch.device] = None,
-    memory_format: Optional[torch.memory_format] = None,
-    **kwargs: Any,
+    self: torch.Tensor, low: int, high: int, **kwargs: Any
 ) -> torch.Tensor:
-    return aten.randint.low(
-        low,
-        high,
-        [*self.size()],
-        dtype=dtype or self.dtype,
-        device=device or self.device,
-        **kwargs,
-    ).to(memory_format=get_like_layout(self, memory_format))
+    return _rand_like(functools.partial(aten.randint.low, low, high), self, **kwargs)
 
 
 @register_decomposition(aten.randint.default)
```
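Net effect: an explicit `memory_format` argument still takes precedence, while the default `torch.preserve_format` path now reproduces the input's dimension ordering instead of collapsing to a suggested format. In eager terms, this is the behavior the compiled decompositions now match (a quick illustration, not code from the PR):

```python
import torch

# Channels-last-style layout: shape (2, 3, 8, 8), strides (192, 1, 24, 3).
x = torch.empty(2, 8, 8, 3).permute(0, 3, 1, 2)

# Default (torch.preserve_format): the output follows the input's layout.
r = torch.rand_like(x)
assert r.stride() == x.stride()

# An explicit memory_format still overrides preservation.
c = torch.rand_like(x, memory_format=torch.contiguous_format)
assert c.is_contiguous()
```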
