Skip to content

Commit 1ab240a

Browse files
author
Alex McKenna
committed
Expose pure variant of isWorkFree
When changing the RewriteState to use MVar, the isWorkFree function causes something of a problem in it's current form, as it forms a recursive group with isWorkFreeBinder. It cannot simply take a lens to an MVar, as this is not the cache itself, but the means to obtain the cache. To make things simpler, the work free functions are rewritten to be pure, and the old lensy monadic interface is exposed as another function on top.
1 parent 64d0ef2 commit 1ab240a

File tree

3 files changed

+113
-49
lines changed

3 files changed

+113
-49
lines changed

clash-lib/src/Clash/Core/PartialEval/Monad.hs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ import Clash.Core.Util (mkUniqSystemId, mkUniqSystemTyVar)
8484
import Clash.Core.Var (Id, TyVar, Var)
8585
import Clash.Core.VarEnv
8686
import Clash.Driver.Types (Binding(..))
87-
import Clash.Rewrite.WorkFree (isWorkFree)
87+
import Clash.Rewrite.WorkFree (isWorkFreePure)
8888

8989
{-
9090
NOTE [RWS monad]
@@ -311,7 +311,11 @@ workFreeValue :: Value -> Eval Bool
311311
workFreeValue = \case
312312
VNeutral _ -> pure False
313313
VThunk x _ -> do
314-
bindings <- fmap (fmap asTerm) . genvBindings <$> getGlobalEnv
315-
isWorkFree workFreeCache bindings x
314+
env <- getGlobalEnv
315+
let bindings = fmap (fmap asTerm) (genvBindings env)
316+
let (cache, wf) = isWorkFreePure (genvWorkCache env) bindings x
317+
318+
modifyGlobalEnv (\genv -> genv { genvWorkCache = cache })
319+
pure wf
316320

317321
_ -> pure True

clash-lib/src/Clash/Core/PartialEval/NormalForm.hs

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,9 @@ module Clash.Core.PartialEval.NormalForm
2727
, Normal(..)
2828
, LocalEnv(..)
2929
, GlobalEnv(..)
30-
, workFreeCache
3130
) where
3231

3332
import Control.Concurrent.Supply (Supply)
34-
import Control.Lens (Lens', lens)
3533
import Data.IntMap.Strict (IntMap)
3634
import Data.Map.Strict (Map)
3735

@@ -192,6 +190,3 @@ data GlobalEnv = GlobalEnv
192190
-- ^ Cache for the results of isWorkFree. This is required to use
193191
-- Clash.Rewrite.WorkFree.isWorkFree.
194192
}
195-
196-
workFreeCache :: Lens' GlobalEnv (VarEnv Bool)
197-
workFreeCache = lens genvWorkCache (\env x -> env { genvWorkCache = x })

clash-lib/src/Clash/Rewrite/WorkFree.hs

Lines changed: 106 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -8,21 +8,23 @@ evaluation to check whether it is possible to perform changes without
88
duplicating work in the result, e.g. inlining.
99
-}
1010

11+
{-# LANGUAGE FlexibleContexts #-}
1112
{-# LANGUAGE OverloadedStrings #-}
1213
{-# LANGUAGE RankNTypes #-}
1314
{-# LANGUAGE TemplateHaskellQuotes #-}
1415

1516
module Clash.Rewrite.WorkFree
1617
( isWorkFree
18+
, isWorkFreePure
1719
, isWorkFreeClockOrResetOrEnable
1820
, isWorkFreeIsh
1921
, isConstant
2022
, isConstantNotClockReset
2123
) where
2224

23-
import Control.Lens (Lens')
24-
import Control.Monad.Extra (allM, andM, eitherM)
25+
import Control.Lens as Lens (Lens', use, (.=))
2526
import Control.Monad.State.Class (MonadState)
27+
import Control.Monad.Trans.Control (MonadBaseControl)
2628
import qualified Data.Text.Extra as Text
2729
import GHC.Stack (HasCallStack)
2830

@@ -35,41 +37,63 @@ import Clash.Core.TyCon (TyConMap)
3537
import Clash.Core.Type (isPolyFunTy)
3638
import Clash.Core.Util
3739
import Clash.Core.Var (Id, isLocalId)
38-
import Clash.Core.VarEnv (VarEnv, lookupVarEnv)
40+
import Clash.Core.VarEnv (VarEnv, extendVarEnv, lookupVarEnv, unionVarEnv)
3941
import Clash.Driver.Types (BindingMap, Binding(..))
4042
import Clash.Normalize.Primitives (removedArg)
41-
import Clash.Util (makeCachedU)
43+
44+
-- TODO I think isWorkFree only needs to exist within the rewriting monad, and
45+
-- this extra polymorphism is probably unnecessary. Needs checking. -- Alex
46+
47+
{-# INLINABLE isWorkFree #-}
48+
isWorkFree
49+
:: (HasCallStack, MonadState s m, MonadBaseControl IO m)
50+
=> Lens' s (VarEnv Bool)
51+
-> BindingMap
52+
-> Term
53+
-> m Bool
54+
isWorkFree cacheL bndrs bndr = do
55+
cache <- Lens.use cacheL
56+
let (cache', wf) = isWorkFreePure cache bndrs bndr
57+
58+
cacheL .= cache'
59+
pure wf
60+
4261

4362
-- | Determines whether a global binder is work free. Errors if binder does
4463
-- not exist.
4564
isWorkFreeBinder
46-
:: (HasCallStack, MonadState s m)
47-
=> Lens' s (VarEnv Bool)
65+
:: HasCallStack
66+
=> VarEnv Bool
4867
-> BindingMap
4968
-> Id
50-
-> m Bool
69+
-> (VarEnv Bool, Bool)
5170
isWorkFreeBinder cache bndrs bndr =
52-
makeCachedU bndr cache $
53-
case lookupVarEnv bndr bndrs of
54-
Nothing -> error ("isWorkFreeBinder: couldn't find binder: " ++ showPpr bndr)
55-
Just (bindingTerm -> t) ->
56-
if bndr `globalIdOccursIn` t
57-
then pure False
58-
else isWorkFree cache bndrs t
71+
case lookupVarEnv bndr cache of
72+
Just value ->
73+
(cache, value)
5974

60-
{-# INLINABLE isWorkFree #-}
75+
Nothing ->
76+
case lookupVarEnv bndr bndrs of
77+
Nothing ->
78+
error ("isWorkFreeBinder: couldn't find binder: " ++ showPpr bndr)
79+
80+
Just (bindingTerm -> t) ->
81+
if bndr `globalIdOccursIn` t
82+
then (extendVarEnv bndr False cache, False)
83+
else isWorkFreePure cache bndrs t
84+
85+
{-# INLINABLE isWorkFreePure #-}
6186
-- | Determine whether a term does any work, i.e. adds to the size of the
6287
-- circuit. This function requires a cache (specified as a lens) to store the
6388
-- result for querying work info of global binders.
6489
--
65-
isWorkFree
66-
:: forall s m
67-
. (HasCallStack, MonadState s m)
68-
=> Lens' s (VarEnv Bool)
90+
isWorkFreePure
91+
:: HasCallStack
92+
=> VarEnv Bool
6993
-> BindingMap
7094
-> Term
71-
-> m Bool
72-
isWorkFree cache bndrs = go True
95+
-> (VarEnv Bool, Bool)
96+
isWorkFreePure cache bndrs = go True
7397
where
7498
-- If we are in the outermost level of a term (i.e. not checking a subterm)
7599
-- then a term is work free if it simply refers to a local variable. This
@@ -79,7 +103,7 @@ isWorkFree cache bndrs = go True
79103
--
80104
-- as being work free, as the term bound to f may introduce work.
81105
--
82-
go :: HasCallStack => Bool -> Term -> m Bool
106+
go :: HasCallStack => Bool -> Term -> (VarEnv Bool, Bool)
83107
go isOutermost (collectArgs -> (fun, args)) =
84108
case fun of
85109
Var i
@@ -91,38 +115,79 @@ isWorkFree cache bndrs = go True
91115
-- would need to be changed to know the FVs of global binders first.
92116
--
93117
| isPolyFunTy (coreTypeOf i) ->
94-
pure (isLocalId i && isOutermost && null args)
118+
(cache, isLocalId i && isOutermost && null args)
95119
| isLocalId i ->
96-
pure True
120+
(cache, True)
97121
| otherwise ->
98-
andM [isWorkFreeBinder cache bndrs i, allM goArg args]
122+
let (cache', wf) = isWorkFreeBinder cache bndrs i
123+
(caches, wfs) = unzip (fmap goArg args)
124+
in (foldr unionVarEnv cache' caches, and (wf : wfs))
125+
126+
Data _ ->
127+
let (caches, wfs) = unzip (fmap goArg args)
128+
in (foldr unionVarEnv mempty caches, and wfs)
129+
130+
Literal _ ->
131+
(cache, True)
99132

100-
Data _ -> allM goArg args
101-
Literal _ -> pure True
102133
Prim pr ->
103134
case primWorkInfo pr of
104135
-- We can ignore arguments because the primitive outputs a constant
105136
-- regardless of their values.
106-
WorkConstant -> pure True
107-
WorkNever -> allM goArg args
108-
WorkIdentity _ _ -> allM goArg args
109-
WorkVariable -> pure (all isConstantArg args)
110-
WorkAlways -> pure False
111-
112-
Lam _ e -> andM [go False e, allM goArg args]
113-
TyLam _ e -> andM [go False e, allM goArg args]
114-
Let (NonRec _ x) e -> andM [go False e, go False x, allM goArg args]
115-
Let (Rec bs) e -> andM [go False e, allM (go False . snd) bs, allM goArg args]
116-
Case s _ [(_, a)] -> andM [go False s, go False a, allM goArg args]
117-
Case e _ _ -> andM [go False e, allM goArg args]
118-
Cast e _ _ -> andM [go False e, allM goArg args]
137+
WorkConstant -> (cache, True)
138+
WorkNever ->
139+
let (caches, wfs) = unzip (fmap goArg args)
140+
in (foldr unionVarEnv mempty caches, and wfs)
141+
WorkIdentity _ _ ->
142+
let (caches, wfs) = unzip (fmap goArg args)
143+
in (foldr unionVarEnv mempty caches, and wfs)
144+
WorkVariable -> (cache, all isConstantArg args)
145+
WorkAlways -> (cache, False)
146+
147+
Lam _ e ->
148+
let (cache', wf) = go False e
149+
(caches, wfs) = unzip (fmap goArg args)
150+
in (foldr unionVarEnv cache' caches, and (wf : wfs))
151+
152+
TyLam _ e ->
153+
let (cache', wf) = go False e
154+
(caches, wfs) = unzip (fmap goArg args)
155+
in (foldr unionVarEnv cache' caches, and (wf : wfs))
156+
157+
Let (NonRec _ x) e ->
158+
let (cacheE, wfE) = go False e
159+
(cacheX, wfX) = go False x
160+
(caches, wfs) = unzip (fmap goArg args)
161+
in (foldr unionVarEnv cacheE (cacheX : caches), and (wfE : wfX : wfs))
162+
163+
Let (Rec bs) e ->
164+
let (cacheE, wfE) = go False e
165+
(cacheBs, wfBs) = unzip (fmap (go False . snd) bs)
166+
(caches, wfs) = unzip (fmap goArg args)
167+
in (foldr unionVarEnv cacheE (cacheBs <> caches), and (wfE : (wfBs <> wfs)))
168+
169+
Case s _ [(_, a)] ->
170+
let (cacheS, wfS) = go False s
171+
(cacheA, wfA) = go False a
172+
(caches, wfs) = unzip (fmap goArg args)
173+
in (foldr unionVarEnv cacheS (cacheA : caches), and (wfS : wfA : wfs))
174+
175+
Case e _ _ ->
176+
let (cache', wf) = go False e
177+
(caches, wfs) = unzip (fmap goArg args)
178+
in (foldr unionVarEnv cache' caches, and (wf : wfs))
179+
180+
Cast e _ _ ->
181+
let (cache', wf) = go False e
182+
(caches, wfs) = unzip (fmap goArg args)
183+
in (foldr unionVarEnv cache' caches, and (wf : wfs))
119184

120185
-- (Ty)App's and Ticks are removed by collectArgs
121186
Tick _ _ -> error "isWorkFree: unexpected Tick"
122187
App {} -> error "isWorkFree: unexpected App"
123188
TyApp {} -> error "isWorkFree: unexpected TyApp"
124189

125-
goArg e = eitherM (go False) (pure . const True) (pure e)
190+
goArg e = either (go False) (const (cache, True)) e
126191
isConstantArg = either isConstant (const True)
127192

128193
-- | Determine if a term represents a constant

0 commit comments

Comments
 (0)