@@ -1055,8 +1055,6 @@ defmodule EMLX.NxTest do
10551055 )
10561056 end
10571057
1058- # does not support window_dilations yet
1059- @ tag :skip
10601058 test "works with non-default options" do
10611059 t = Nx . tensor ( [ [ [ 4 , 2 , 1 , 3 ] , [ 4 , 2 , 1 , 7 ] ] , [ [ 1 , 2 , 5 , 7 ] , [ 1 , 8 , 9 , 2 ] ] ] )
10621060 opts = [ strides: [ 2 , 1 , 1 ] , padding: :valid , window_dilations: [ 1 , 2 , 2 ] ]
@@ -1117,8 +1115,6 @@ defmodule EMLX.NxTest do
11171115 )
11181116 end
11191117
1120- # window_dilations are not supported yet
1121- @ tag :skip
11221118 test "works with non-default options" do
11231119 t = Nx . tensor ( [ [ [ 4 , 2 , 1 , 3 ] , [ 4 , 2 , 1 , 7 ] ] , [ [ 1 , 2 , 5 , 7 ] , [ 1 , 8 , 9 , 2 ] ] ] )
11241120 opts = [ strides: [ 2 , 1 , 1 ] , padding: :valid , window_dilations: [ 1 , 2 , 2 ] ]
@@ -1244,8 +1240,6 @@ defmodule EMLX.NxTest do
12441240 )
12451241 end
12461242
1247- # window dilations are not supported yet
1248- @ tag :skip
12491243 test "supports window dilations" do
12501244 result = Nx . window_sum ( Nx . iota ( { 4 , 4 } ) , { 2 , 2 } , window_dilations: [ 2 , 1 ] )
12511245
@@ -1298,8 +1292,6 @@ defmodule EMLX.NxTest do
12981292 )
12991293 end
13001294
1301- # window dilations are not supported yet
1302- @ tag :skip
13031295 test "supports window dilations" do
13041296 result = Nx . window_product ( Nx . iota ( { 4 , 4 } ) , { 2 , 2 } , window_dilations: [ 2 , 1 ] )
13051297
@@ -1576,4 +1568,210 @@ defmodule EMLX.NxTest do
15761568 )
15771569 end
15781570 end
1571+
1572+ describe "window_scatter_max" do
1573+ test "window_scatter_max with strides [2, 3]" do
1574+ t =
1575+ Nx . tensor ( [
1576+ [ 7 , 2 , 5 , 3 , 10 , 2 ] ,
1577+ [ 3 , 8 , 9 , 3 , 4 , 2 ] ,
1578+ [ 1 , 5 , 7 , 5 , 6 , 1 ] ,
1579+ [ 0 , 6 , 2 , 7 , 2 , 8 ]
1580+ ] )
1581+
1582+ opts = [ strides: [ 2 , 3 ] , padding: :valid ]
1583+ result = Nx . window_scatter_max ( t , Nx . tensor ( [ [ 2 , 6 ] , [ 3 , 1 ] ] ) , 0 , { 2 , 3 } , opts )
1584+
1585+ assert_all_close (
1586+ result ,
1587+ Nx . tensor ( [
1588+ [ 0 , 0 , 0 , 0 , 6 , 0 ] ,
1589+ [ 0 , 0 , 2 , 0 , 0 , 0 ] ,
1590+ [ 0 , 0 , 3 , 0 , 0 , 0 ] ,
1591+ [ 0 , 0 , 0 , 0 , 0 , 1 ]
1592+ ] )
1593+ )
1594+ end
1595+
1596+ test "window_scatter_max with strides [2, 2]" do
1597+ t =
1598+ Nx . tensor ( [
1599+ [ 7 , 2 , 5 , 3 , 8 ] ,
1600+ [ 3 , 8 , 9 , 3 , 4 ] ,
1601+ [ 1 , 5 , 7 , 5 , 6 ] ,
1602+ [ 0 , 6 , 2 , 10 , 2 ]
1603+ ] )
1604+
1605+ opts = [ strides: [ 2 , 2 ] , padding: :valid ]
1606+ result = Nx . window_scatter_max ( t , Nx . tensor ( [ [ 2 , 6 ] , [ 3 , 1 ] ] ) , 0 , { 2 , 3 } , opts )
1607+
1608+ assert_all_close (
1609+ result ,
1610+ Nx . tensor ( [
1611+ [ 0 , 0 , 0 , 0 , 0 ] ,
1612+ [ 0 , 0 , 8 , 0 , 0 ] ,
1613+ [ 0 , 0 , 3 , 0 , 0 ] ,
1614+ [ 0 , 0 , 0 , 1 , 0 ]
1615+ ] )
1616+ )
1617+ end
1618+
1619+ test "window_scatter_max with vectorized input" do
1620+ t =
1621+ Nx . tensor ( [
1622+ [
1623+ [ 7 , 2 , 5 , 3 ] ,
1624+ [ 3 , 8 , 9 , 3 ]
1625+ ] ,
1626+ [
1627+ [ 1 , 5 , 7 , 5 ] ,
1628+ [ 0 , 6 , 2 , 8 ]
1629+ ]
1630+ ] )
1631+ |> Nx . vectorize ( :x )
1632+
1633+ opts = [ strides: [ 1 , 2 ] , padding: :valid ]
1634+
1635+ source =
1636+ Nx . tensor ( [
1637+ [ [ 2 , 6 ] ] ,
1638+ [ [ 3 , 1 ] ]
1639+ ] )
1640+ |> Nx . vectorize ( :y )
1641+
1642+ result = Nx . window_scatter_max ( t , source , 0 , { 2 , 2 } , opts )
1643+
1644+ Nx.Testing . assert_equal (
1645+ result ,
1646+ Nx . tensor ( [
1647+ [
1648+ [
1649+ [ 0 , 0 , 0 , 0 ] ,
1650+ [ 0 , 2 , 6 , 0 ]
1651+ ] ,
1652+ [
1653+ [ 0 , 0 , 0 , 0 ] ,
1654+ [ 0 , 3 , 1 , 0 ]
1655+ ]
1656+ ] ,
1657+ [
1658+ [
1659+ [ 0 , 0 , 0 , 0 ] ,
1660+ [ 0 , 2 , 0 , 6 ]
1661+ ] ,
1662+ [
1663+ [ 0 , 0 , 0 , 0 ] ,
1664+ [ 0 , 3 , 0 , 1 ]
1665+ ]
1666+ ]
1667+ ] )
1668+ |> Nx . vectorize ( [ :x , :y ] )
1669+ )
1670+ end
1671+ end
1672+
1673+ describe "window_scatter_min" do
1674+ test "window_scatter_min with strides [2, 3]" do
1675+ t =
1676+ Nx . tensor ( [
1677+ [ 7 , 2 , 5 , 3 , 10 , 2 ] ,
1678+ [ 3 , 8 , 9 , 3 , 4 , 2 ] ,
1679+ [ 1 , 5 , 7 , 5 , 6 , 1 ] ,
1680+ [ 0 , 6 , 2 , 7 , 2 , 8 ]
1681+ ] )
1682+
1683+ opts = [ strides: [ 2 , 3 ] , padding: :valid ]
1684+ result = Nx . window_scatter_min ( t , Nx . tensor ( [ [ 2 , 6 ] , [ 3 , 1 ] ] ) , 0 , { 2 , 3 } , opts )
1685+
1686+ assert_all_close (
1687+ result ,
1688+ Nx . tensor ( [
1689+ [ 0 , 2 , 0 , 0 , 0 , 0 ] ,
1690+ [ 0 , 0 , 0 , 0 , 0 , 6 ] ,
1691+ [ 0 , 0 , 0 , 0 , 0 , 1 ] ,
1692+ [ 3 , 0 , 0 , 0 , 0 , 0 ]
1693+ ] )
1694+ )
1695+ end
1696+
1697+ test "window_scatter_min with strides [2, 2]" do
1698+ t =
1699+ Nx . tensor ( [
1700+ [ 7 , 2 , 5 , 3 , 8 ] ,
1701+ [ 3 , 8 , 9 , 3 , 4 ] ,
1702+ [ 1 , 5 , 7 , 5 , 6 ] ,
1703+ [ 0 , 6 , 2 , 10 , 2 ]
1704+ ] )
1705+
1706+ opts = [ strides: [ 2 , 2 ] , padding: :valid ]
1707+ result = Nx . window_scatter_min ( t , Nx . tensor ( [ [ 2 , 6 ] , [ 3 , 1 ] ] ) , 0 , { 2 , 3 } , opts )
1708+
1709+ assert_all_close (
1710+ result ,
1711+ Nx . tensor ( [
1712+ [ 0 , 2 , 0 , 0 , 0 ] ,
1713+ [ 0 , 0 , 0 , 6 , 0 ] ,
1714+ [ 0 , 0 , 0 , 0 , 0 ] ,
1715+ [ 3 , 0 , 0 , 0 , 1 ]
1716+ ] )
1717+ )
1718+ end
1719+
1720+ test "window_scatter_min with vectorized input" do
1721+ t =
1722+ Nx . tensor ( [
1723+ [
1724+ [ 7 , 2 , 5 , 3 ] ,
1725+ [ 3 , 8 , 9 , 3 ]
1726+ ] ,
1727+ [
1728+ [ 1 , 5 , 7 , 5 ] ,
1729+ [ 0 , 6 , 2 , 8 ]
1730+ ]
1731+ ] )
1732+ |> Nx . vectorize ( :x )
1733+
1734+ opts = [ strides: [ 1 , 2 ] , padding: :valid ]
1735+
1736+ source =
1737+ Nx . tensor ( [
1738+ [
1739+ [ 2 , 6 ]
1740+ ] ,
1741+ [
1742+ [ 3 , 1 ]
1743+ ]
1744+ ] )
1745+ |> Nx . vectorize ( :y )
1746+
1747+ result = Nx . window_scatter_min ( t , source , 0 , { 2 , 2 } , opts )
1748+
1749+ assert_all_close (
1750+ result ,
1751+ Nx . tensor ( [
1752+ [
1753+ [
1754+ [ 0 , 2 , 0 , 0 ] ,
1755+ [ 0 , 0 , 0 , 6 ]
1756+ ] ,
1757+ [
1758+ [ 0 , 3 , 0 , 0 ] ,
1759+ [ 0 , 0 , 0 , 1 ]
1760+ ]
1761+ ] ,
1762+ [
1763+ [
1764+ [ 0 , 0 , 0 , 0 ] ,
1765+ [ 2 , 0 , 6 , 0 ]
1766+ ] ,
1767+ [
1768+ [ 0 , 0 , 0 , 0 ] ,
1769+ [ 3 , 0 , 1 , 0 ]
1770+ ]
1771+ ]
1772+ ] )
1773+ |> Nx . vectorize ( [ :x , :y ] )
1774+ )
1775+ end
1776+ end
15791777end
0 commit comments