|
69 | 69 | log, |
70 | 70 | log1mexp, |
71 | 71 | log1p, |
| 72 | + log1pexp, |
72 | 73 | lt, |
73 | 74 | maximum, |
74 | 75 | minimum, |
@@ -1968,27 +1969,53 @@ def test_exp_softplus(self, exp_op): |
1968 | 1969 | decimal=6, |
1969 | 1970 | ) |
1970 | 1971 |
|
1971 | | - def test_softplus_log(self): |
1972 | | - # softplus(log(x)) -> log1p(x) |
| 1972 | + def test_log1pexp_log(self): |
| 1973 | + # log1pexp(log(x)) -> log1p(x) |
1973 | 1974 | data_valid = np.random.random((4, 3)).astype("float32") * 2 |
1974 | 1975 | data_valid[0, 0] = 0 # edge case |
1975 | 1976 | data_invalid = data_valid - 2 |
1976 | 1977 |
|
1977 | 1978 | x = fmatrix() |
1978 | | - f = function([x], softplus(log(x)), mode=self.mode) |
1979 | | - graph = f.maker.fgraph.toposort() |
1980 | | - ops_graph = [ |
1981 | | - node |
1982 | | - for node in graph |
1983 | | - if isinstance(node.op, Elemwise) |
1984 | | - and isinstance(node.op.scalar_op, ps.Log | ps.Exp | ps.Softplus) |
1985 | | - ] |
1986 | | - assert len(ops_graph) == 0 |
| 1979 | + f = function([x], log1pexp(log(x)), mode=self.mode.excluding("inplace")) |
| 1980 | + assert equal_computations( |
| 1981 | + f.maker.fgraph.outputs, |
| 1982 | + [ |
| 1983 | + pt.switch( |
| 1984 | + x >= np.array([[0]], dtype=np.int8), |
| 1985 | + pt.log1p(x), |
| 1986 | + np.array([[np.nan]], dtype=np.float32), |
| 1987 | + ) |
| 1988 | + ], |
| 1989 | + ) |
1987 | 1990 |
|
1988 | 1991 | expected = np.log1p(data_valid) |
1989 | 1992 | np.testing.assert_almost_equal(f(data_valid), expected) |
1990 | 1993 | assert np.all(np.isnan(f(data_invalid))) |
1991 | 1994 |
|
| 1995 | + def test_log1mexp_log(self): |
| 1996 | + # log1mexp(log(x)) -> log1p(-x) |
| 1997 | + data_valid = np.random.random((4, 3)).astype("float32") |
| 1998 | + data_valid[0, 0] = 0 # edge case |
| 1999 | + data_valid[0, 1] = 1 # another edge case |
| 2000 | + data_invalid = np.concatenate([data_valid + 1.1, data_valid - 1.1]) |
| 2001 | + |
| 2002 | + x = fmatrix() |
| 2003 | + f = function([x], log1mexp(log(x)), mode=self.mode.excluding("inplace")) |
| 2004 | + assert equal_computations( |
| 2005 | + f.maker.fgraph.outputs, |
| 2006 | + [ |
| 2007 | + pt.switch( |
| 2008 | + x >= np.array([[0]], dtype=np.int8), |
| 2009 | + pt.log1p(-x), |
| 2010 | + np.array([[np.nan]], dtype=np.float32), |
| 2011 | + ) |
| 2012 | + ], |
| 2013 | + ) |
| 2014 | + |
| 2015 | + expected = np.log1p(-data_valid) |
| 2016 | + np.testing.assert_almost_equal(f(data_valid), expected) |
| 2017 | + assert np.all(np.isnan(f(data_invalid))) |
| 2018 | + |
1992 | 2019 | @pytest.mark.parametrize( |
1993 | 2020 | ["nested_expression", "expected_switches"], |
1994 | 2021 | [ |
|
0 commit comments