@@ -535,49 +535,17 @@ def view_copy_dtype(
535535 return self .to (dtype ).clone ()
536536
537537
538- def get_like_layout (
539- tensor : torch .Tensor ,
540- memory_format : Optional [torch .memory_format ] = None ,
541- ) -> torch .memory_format :
542- # TODO: _to_copy tensor to stride permutation
543- if memory_format is torch .preserve_format or memory_format is None :
544- return utils .suggest_memory_format (tensor )
545- else :
546- return memory_format
547-
548-
549- @register_decomposition (aten .rand_like )
550- def rand_like (
538+ def _get_shape_permutation_like (
551539 self : torch .Tensor ,
552- * ,
553- dtype : Optional [torch .dtype ] = None ,
554- device : Optional [torch .device ] = None ,
555- memory_format : Optional [torch .memory_format ] = None ,
556- ** kwargs : Any ,
557- ) -> torch .Tensor :
558- return torch .rand (
559- [* self .size ()],
560- dtype = dtype or self .dtype ,
561- device = device or self .device ,
562- ** kwargs ,
563- ).to (memory_format = get_like_layout (self , memory_format ))
540+ ) -> tuple [utils .ShapeType , utils .StrideType ]:
541+ physical_layout = utils .compute_elementwise_output_logical_to_physical_perm (self )
542+ shape = [self .shape [l ] for l in physical_layout ]
564543
544+ permutation = [0 ] * len (shape )
545+ for p , l in enumerate (physical_layout ):
546+ permutation [l ] = p
565547
566- @register_decomposition (aten .randn_like )
567- def randn_like (
568- self : torch .Tensor ,
569- * ,
570- dtype : Optional [torch .dtype ] = None ,
571- device : Optional [torch .device ] = None ,
572- memory_format : Optional [torch .memory_format ] = None ,
573- ** kwargs : Any ,
574- ) -> torch .Tensor :
575- return torch .randn (
576- [* self .size ()],
577- dtype = dtype or self .dtype ,
578- device = device or self .device ,
579- ** kwargs ,
580- ).to (memory_format = get_like_layout (self , memory_format ))
548+ return (shape , permutation )
581549
582550
583551@register_decomposition (aten .full_like )
@@ -592,55 +560,91 @@ def full_like(
592560 requires_grad : bool = False ,
593561 memory_format : torch .memory_format = torch .preserve_format ,
594562) -> torch .Tensor :
595- return torch .full (
596- [* self .size ()],
597- fill_value ,
598- dtype = dtype or self .dtype ,
599- layout = layout or self .layout ,
600- device = device or self .device ,
601- requires_grad = requires_grad ,
602- ).to (memory_format = get_like_layout (self , memory_format ))
563+ dtype = self .dtype if dtype is None else dtype
564+ layout = self .layout if layout is None else layout
565+ device = self .device if device is None else device
566+
567+ if memory_format != torch .preserve_format :
568+ result = torch .full (
569+ self .shape ,
570+ fill_value ,
571+ dtype = dtype ,
572+ layout = layout ,
573+ device = device ,
574+ pin_memory = pin_memory ,
575+ requires_grad = requires_grad ,
576+ )
577+ return result .to (memory_format = memory_format )
578+
579+ else :
580+ assert layout == torch .strided
581+ shape , permutation = _get_shape_permutation_like (self )
582+ result = torch .full (
583+ shape ,
584+ fill_value ,
585+ dtype = dtype ,
586+ layout = layout ,
587+ device = device ,
588+ pin_memory = pin_memory ,
589+ requires_grad = requires_grad ,
590+ )
591+ if permutation == list (range (len (permutation ))):
592+ return result
593+ return result .permute (permutation ).clone ()
603594
604595
605- @ register_decomposition ( aten . randint_like . default )
606- def randint_like (
596+ def _rand_like (
597+ rand_fn : Callable [..., torch . Tensor ],
607598 self : torch .Tensor ,
608- high : int ,
609599 * ,
610600 dtype : Optional [torch .dtype ] = None ,
611601 device : Optional [torch .device ] = None ,
612- memory_format : Optional [ torch .memory_format ] = None ,
602+ memory_format : torch .memory_format = torch . preserve_format ,
613603 ** kwargs : Any ,
614604) -> torch .Tensor :
615- return aten .randint .low (
616- 0 ,
617- high ,
618- [* self .size ()],
619- dtype = dtype or self .dtype ,
620- device = device or self .device ,
605+ dtype = self .dtype if dtype is None else dtype
606+ device = self .device if device is None else device
607+
608+ if memory_format != torch .preserve_format :
609+ return rand_fn (
610+ self .shape ,
611+ dtype = dtype ,
612+ device = device ,
613+ ** kwargs ,
614+ ).to (memory_format = memory_format )
615+
616+ shape , permutation = _get_shape_permutation_like (self )
617+ result = rand_fn (
618+ shape ,
619+ dtype = dtype ,
620+ device = device ,
621621 ** kwargs ,
622- ).to (memory_format = get_like_layout (self , memory_format ))
622+ )
623+ if permutation == list (range (len (permutation ))):
624+ return result
625+ return result .permute (permutation ).clone ()
626+
627+
628+ @register_decomposition (aten .rand_like )
629+ def rand_like (self : torch .Tensor , ** kwargs : Any ) -> torch .Tensor :
630+ return _rand_like (torch .rand , self , ** kwargs )
631+
632+
633+ @register_decomposition (aten .randn_like )
634+ def randn_like (self : torch .Tensor , ** kwargs : Any ) -> torch .Tensor :
635+ return _rand_like (torch .randn , self , ** kwargs )
636+
637+
638+ @register_decomposition (aten .randint_like .default )
639+ def randint_like (self : torch .Tensor , high : int , ** kwargs : Any ) -> torch .Tensor :
640+ return _rand_like (functools .partial (aten .randint .low , 0 , high ), self , ** kwargs )
623641
624642
625643@register_decomposition (aten .randint_like .low_dtype )
626644def randint_like_low (
627- self : torch .Tensor ,
628- low : int ,
629- high : int ,
630- * ,
631- dtype : Optional [torch .dtype ] = None ,
632- device : Optional [torch .device ] = None ,
633- memory_format : Optional [torch .memory_format ] = None ,
634- ** kwargs : Any ,
645+ self : torch .Tensor , low : int , high : int , ** kwargs : Any
635646) -> torch .Tensor :
636- return aten .randint .low (
637- low ,
638- high ,
639- [* self .size ()],
640- dtype = dtype or self .dtype ,
641- device = device or self .device ,
642- ** kwargs ,
643- ).to (memory_format = get_like_layout (self , memory_format ))
647+ return _rand_like (functools .partial (aten .randint .low , low , high ), self , ** kwargs )
644648
645649
646650@register_decomposition (aten .randint .default )
0 commit comments