@@ -506,7 +506,7 @@ defmodule Torchx.Backend do
506506
507507 result =
508508 if axes == [ ] do
509- aggregate_whole_tensor ( t , keep_axes , & Torchx . product / 1 )
509+ aggregate_whole_tensor ( t , & Torchx . product / 1 )
510510 else
511511 aggregate_over_axes ( t , axes , keep_axes , & Torchx . product / 3 )
512512 end
@@ -523,7 +523,7 @@ defmodule Torchx.Backend do
523523
524524 result =
525525 if axes == [ ] do
526- aggregate_whole_tensor ( t , keep_axes , & Torchx . any / 1 )
526+ aggregate_whole_tensor ( t , & Torchx . any / 1 )
527527 else
528528 aggregate_over_axes ( t , axes , keep_axes , & Torchx . any / 3 )
529529 end
@@ -538,7 +538,7 @@ defmodule Torchx.Backend do
538538
539539 result =
540540 if axes == [ ] do
541- aggregate_whole_tensor ( t , keep_axes , & Torchx . all / 1 )
541+ aggregate_whole_tensor ( t , & Torchx . all / 1 )
542542 else
543543 aggregate_over_axes ( t , axes , keep_axes , & Torchx . all / 3 )
544544 end
@@ -563,18 +563,10 @@ defmodule Torchx.Backend do
563563 |> to_nx ( out )
564564 end
565565
566- defp aggregate_whole_tensor ( t , keep_axes , fun ) when is_function ( fun , 1 ) do
567- result =
568- t
569- |> from_nx ( )
570- |> then ( fun )
571-
572- if keep_axes do
573- shape = t . shape |> Tuple . delete_at ( - 1 ) |> Tuple . append ( 1 )
574- Torchx . reshape ( result , shape )
575- else
576- result
577- end
566+ defp aggregate_whole_tensor ( t , fun ) when is_function ( fun , 1 ) do
567+ t
568+ |> from_nx ( )
569+ |> then ( fun )
578570 end
579571
580572 defp aggregate_over_axes ( t , axes , keep_axes , fun ) when is_function ( fun , 3 ) do
0 commit comments