Skip to content

Commit 64d0ef2

Browse files
author
Alex McKenna
committed
Store NormalizeState values in MVars
1 parent fb128e5 commit 64d0ef2

File tree

8 files changed

+271
-187
lines changed

8 files changed

+271
-187
lines changed

clash-lib/clash-lib.cabal

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,9 @@ Library
155155
hint >= 0.7 && < 0.10,
156156
interpolate >= 0.2.0 && < 1.0,
157157
lens >= 4.10 && < 5.1.0,
158+
-- TODO bounds
159+
lifted-base,
160+
monad-control,
158161
mtl >= 2.1.2 && < 2.3,
159162
ordered-containers >= 0.2 && < 0.3,
160163
prettyprinter >= 1.2.0.1 && < 1.8,
@@ -166,6 +169,7 @@ Library
166169
text >= 1.2.2 && < 2.1,
167170
time >= 1.4.0.1 && < 1.14,
168171
transformers >= 0.5.2.0 && < 0.7,
172+
transformers-base,
169173
trifecta >= 1.7.1.1 && < 2.2,
170174
vector >= 0.11 && < 1.0,
171175
vector-binary-instances >= 0.2.3.5 && < 0.3,

clash-lib/src/Clash/Normalize.hs

Lines changed: 38 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
module Clash.Normalize where
1818

19+
import qualified Control.Concurrent.MVar.Lifted as MVar
1920
import Control.Concurrent.Supply (Supply)
2021
import Control.Exception (throw)
2122
import qualified Control.Lens as Lens
@@ -113,34 +114,35 @@ runNormalization
113114
-> NormalizeSession a
114115
-- ^ NormalizeSession to run
115116
-> IO a
116-
runNormalization env supply globals typeTrans peEval eval rcsMap topEnts =
117-
runRewriteSession rwEnv rwState
118-
where
119-
-- TODO The RewriteEnv should just take ClashOpts.
120-
rwEnv = RewriteEnv
121-
env
122-
typeTrans
123-
peEval
124-
eval
125-
(mkVarSet topEnts)
126-
127-
rwState = RewriteState
128-
mempty -- transformCounters Map
129-
globals
130-
supply
131-
(error $ $(curLoc) ++ "Report as bug: no curFun",noSrcSpan)
132-
0
133-
(IntMap.empty, 0)
134-
emptyVarEnv
135-
normState
136-
137-
normState = NormalizeState
138-
emptyVarEnv
139-
Map.empty
140-
emptyVarEnv
141-
emptyVarEnv
142-
Map.empty
143-
rcsMap
117+
runNormalization env supply globals typeTrans peEval eval rcsMap topEntities session = do
118+
normState <- NormalizeState
119+
<$> MVar.newMVar emptyVarEnv
120+
<*> MVar.newMVar Map.empty
121+
<*> MVar.newMVar emptyVarEnv
122+
<*> MVar.newMVar emptyVarEnv
123+
<*> MVar.newMVar Map.empty
124+
<*> MVar.newMVar rcsMap
125+
126+
runRewriteSession rwEnv (rwState normState) session
127+
where
128+
rwEnv = RewriteEnv
129+
{ _clashEnv = env
130+
, _typeTranslator = typeTrans
131+
, _peEvaluator = peEval
132+
, _evaluator = eval
133+
, _topEntities = mkVarSet topEntities
134+
}
135+
136+
rwState s = RewriteState
137+
{ _transformCounters = mempty
138+
, _bindings = globals
139+
, _uniqSupply = supply
140+
, _curFun = (error $ $(curLoc) ++ "Report as bug: no curFun", noSrcSpan)
141+
, _nameCounter = 0
142+
, _globalHeap = (IntMap.empty, 0)
143+
, _workFreeBinders = emptyVarEnv
144+
, _extra = s
145+
}
144146

145147
normalize
146148
:: [Id]
@@ -191,10 +193,14 @@ normalize' nm = do
191193
, ") remains recursive after normalization:\n"
192194
, showPpr (bindingTerm tmNorm) ])
193195
(return ())
194-
prevNorm <- mapVarEnv bindingId <$> Lens.use (extra.normalized)
195-
let toNormalize = filter (`notElemVarSet` topEnts)
196-
$ filter (`notElemVarEnv` (extendVarEnv nm nm prevNorm)) usedBndrs
197-
return (toNormalize,(nm,tmNorm))
196+
197+
normV <- Lens.use (extra.normalized)
198+
199+
MVar.withMVar normV $ \norm ->
200+
let prevNorm = mapVarEnv bindingId norm
201+
toNormalize = filter (`notElemVarSet` topEnts)
202+
$ filter (`notElemVarEnv` extendVarEnv nm nm prevNorm) usedBndrs
203+
in return (toNormalize,(nm,tmNorm))
198204
else
199205
do
200206
-- Throw an error for unrepresentable topEntities and functions

clash-lib/src/Clash/Normalize/Transformations/Specialize.hs

Lines changed: 74 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@ module Clash.Normalize.Transformations.Specialize
2828
) where
2929

3030
import Control.Arrow ((***), (&&&))
31-
import Control.DeepSeq (deepseq)
31+
import qualified Control.Concurrent.MVar.Lifted as MVar
32+
import Control.DeepSeq (force)
3233
import Control.Exception (throw)
33-
import Control.Lens ((%=))
3434
import qualified Control.Lens as Lens
3535
import qualified Control.Monad as Monad
3636
import Control.Monad.Extra (orM)
@@ -387,74 +387,78 @@ specialize' (TransformContext is0 _) e (Var f, args, ticks) specArgIn = do
387387
specAbs :: Either Term Type
388388
specAbs = either (Left . stripAllTicks . (`mkAbstraction` specBndrs)) (Right . id) specArg
389389
-- Determine if 'f' has already been specialized on (a type-normalized) 'specArg'
390-
specM <- Map.lookup (f,argLen,specAbs) <$> Lens.use (extra.specialisationCache)
391-
case specM of
392-
-- Use previously specialized function
393-
Just f' ->
394-
traceIf (hasTransformationInfo AppliedTerm opts)
395-
("Using previous specialization of " ++ showPpr (varName f) ++ " on " ++
396-
(either showPpr showPpr) specAbs ++ ": " ++ showPpr (varName f')) $
397-
changed $ mkApps (mkTicks (Var f') ticks) (args ++ specVars)
398-
-- Create new specialized function
399-
Nothing -> do
400-
-- Determine if we can specialize f
401-
bodyMaybe <- fmap (lookupUniqMap (varName f)) $ Lens.use bindings
402-
case bodyMaybe of
403-
Just (Binding _ sp inl _ bodyTm _) -> do
404-
-- Determine if we see a sequence of specializations on a growing argument
405-
specHistM <- lookupUniqMap f <$> Lens.use (extra.specialisationHistory)
406-
specLim <- Lens.view specializationLimit
407-
if maybe False (> specLim) specHistM
408-
then throw (ClashException
409-
sp
410-
(unlines [ "Hit specialization limit " ++ show specLim ++ " on function `" ++ showPpr (varName f) ++ "'.\n"
411-
, "The function `" ++ showPpr f ++ "' is most likely recursive, and looks like it is being indefinitely specialized on a growing argument.\n"
412-
, "Body of `" ++ showPpr f ++ "':\n" ++ showPpr bodyTm ++ "\n"
413-
, "Argument (in position: " ++ show argLen ++ ") that triggered termination:\n" ++ (either showPpr showPpr) specArg
414-
, "Run with '-fclash-spec-limit=N' to increase the specialization limit to N."
415-
])
416-
Nothing)
417-
else do
418-
let existingNames = collectBndrsMinusApps bodyTm
419-
newNames = [ mkUnsafeInternalName ("pTS" `Text.append` Text.pack (show n)) n
420-
| n <- [(0::Int)..]
421-
]
422-
-- Make new binders for existing arguments
423-
(boundArgs,argVars) <- fmap (unzip . map (either (Left &&& Left . Var) (Right &&& Right . VarTy))) $
424-
Monad.zipWithM
425-
(mkBinderFor is0 tcm)
426-
(existingNames ++ newNames)
427-
args
428-
-- Determine name the resulting specialized function, and the
429-
-- form of the specialized-on argument
430-
(fId,inl',specArg') <- case specArg of
431-
Left a@(collectArgsTicks -> (Var g,gArgs,_gTicks)) -> if isPolyFun tcm a
432-
then do
433-
-- In case we are specialising on an argument that is a
434-
-- global function then we use that function's name as the
435-
-- name of the specialized higher-order function.
436-
-- Additionally, we will return the body of the global
437-
-- function, instead of a variable reference to the
438-
-- global function.
439-
--
440-
-- This will turn things like @mealy g k@ into a new
441-
-- binding @g'@ where both the body of @mealy@ and @g@
442-
-- are inlined, meaning the state-transition-function
443-
-- and the memory element will be in a single function.
444-
gTmM <- fmap (lookupUniqMap (varName g)) $ Lens.use bindings
445-
return (g,maybe inl bindingSpec gTmM, maybe specArg (Left . (`mkApps` gArgs) . bindingTerm) gTmM)
446-
else return (f,inl,specArg)
447-
_ -> return (f,inl,specArg)
448-
-- Create specialized functions
449-
let newBody = mkAbstraction (mkApps bodyTm (argVars ++ [specArg'])) (boundArgs ++ specBndrs)
450-
newf <- mkFunction (varName fId) sp inl' newBody
451-
-- Remember specialization
452-
(extra.specialisationHistory) %= extendUniqMapWith f 1 (+)
453-
(extra.specialisationCache) %= Map.insert (f,argLen,specAbs) newf
454-
-- use specialized function
455-
let newExpr = mkApps (mkTicks (Var newf) ticks) (args ++ specVars)
456-
newf `deepseq` changed newExpr
457-
Nothing -> return e
390+
specCacheV <- Lens.use (extra.specialisationCache)
391+
392+
MVar.modifyMVar specCacheV $ \specCache ->
393+
case Map.lookup (f, argLen, specAbs) specCache of
394+
-- Use previously specialized function
395+
Just f' ->
396+
traceIf (hasTransformationInfo AppliedTerm opts)
397+
("Using previous specialization of " ++ showPpr (varName f) ++ " on " ++
398+
(either showPpr showPpr) specAbs ++ ": " ++ showPpr (varName f')) $
399+
changed (specCache, mkApps (mkTicks (Var f') ticks) (args ++ specVars))
400+
-- Create new specialized function
401+
Nothing -> do
402+
-- Determine if we can specialize f
403+
bodyMaybe <- fmap (lookupUniqMap (varName f)) $ Lens.use bindings
404+
case bodyMaybe of
405+
Just (Binding _ sp inl _ bodyTm _) -> do
406+
-- Determine if we see a sequence of specializations on a growing argument
407+
specHistMV <- Lens.use (extra.specialisationHistory)
408+
specHist <- MVar.takeMVar specHistMV
409+
let specHistM = lookupUniqMap f specHist
410+
specLim <- Lens.view specializationLimit
411+
if maybe False (> specLim) specHistM
412+
then throw (ClashException
413+
sp
414+
(unlines [ "Hit specialization limit " ++ show specLim ++ " on function `" ++ showPpr (varName f) ++ "'.\n"
415+
, "The function `" ++ showPpr f ++ "' is most likely recursive, and looks like it is being indefinitely specialized on a growing argument.\n"
416+
, "Body of `" ++ showPpr f ++ "':\n" ++ showPpr bodyTm ++ "\n"
417+
, "Argument (in position: " ++ show argLen ++ ") that triggered termination:\n" ++ (either showPpr showPpr) specArg
418+
, "Run with '-fclash-spec-limit=N' to increase the specialization limit to N."
419+
])
420+
Nothing)
421+
else do
422+
let existingNames = collectBndrsMinusApps bodyTm
423+
newNames = [ mkUnsafeInternalName ("pTS" `Text.append` Text.pack (show n)) n
424+
| n <- [(0::Int)..]
425+
]
426+
-- Make new binders for existing arguments
427+
(boundArgs,argVars) <- fmap (unzip . map (either (Left &&& Left . Var) (Right &&& Right . VarTy))) $
428+
Monad.zipWithM
429+
(mkBinderFor is0 tcm)
430+
(existingNames ++ newNames)
431+
args
432+
-- Determine name the resulting specialized function, and the
433+
-- form of the specialized-on argument
434+
(fId,inl',specArg') <- case specArg of
435+
Left a@(collectArgsTicks -> (Var g,gArgs,_gTicks)) -> if isPolyFun tcm a
436+
then do
437+
-- In case we are specialising on an argument that is a
438+
-- global function then we use that function's name as the
439+
-- name of the specialized higher-order function.
440+
-- Additionally, we will return the body of the global
441+
-- function, instead of a variable reference to the
442+
-- global function.
443+
--
444+
-- This will turn things like @mealy g k@ into a new
445+
-- binding @g'@ where both the body of @mealy@ and @g@
446+
-- are inlined, meaning the state-transition-function
447+
-- and the memory element will be in a single function.
448+
gTmM <- fmap (lookupUniqMap (varName g)) $ Lens.use bindings
449+
return (g,maybe inl bindingSpec gTmM, maybe specArg (Left . (`mkApps` gArgs) . bindingTerm) gTmM)
450+
else return (f,inl,specArg)
451+
_ -> return (f,inl,specArg)
452+
-- Create specialized functions
453+
let newBody = mkAbstraction (mkApps bodyTm (argVars ++ [specArg'])) (boundArgs ++ specBndrs)
454+
newf <- force <$> mkFunction (varName fId) sp inl' newBody
455+
-- Remember specialization
456+
MVar.putMVar specHistMV (extendUniqMapWith f 1 (+) specHist)
457+
-- use specialized function
458+
let newCache = Map.insert (f, argLen, specAbs) newf specCache
459+
let newExpr = mkApps (mkTicks (Var newf) ticks) (args ++ specVars)
460+
changed (newCache, newExpr)
461+
Nothing -> return (specCache, e)
458462
where
459463
collectBndrsMinusApps :: Term -> [Name a]
460464
collectBndrsMinusApps = reverse . go []

clash-lib/src/Clash/Normalize/Types.hs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@
1212

1313
module Clash.Normalize.Types where
1414

15+
import Control.Concurrent.MVar (MVar)
1516
import qualified Control.Lens as Lens
16-
import Control.Monad.State.Strict (State)
17+
import Control.Monad.State.Strict (StateT)
1718
import Data.Map (Map)
1819
import Data.Set (Set)
1920
import Data.Text (Text)
@@ -28,25 +29,25 @@ import Clash.Rewrite.Types (Rewrite, RewriteMonad)
2829
-- | State of the 'NormalizeMonad'
2930
data NormalizeState
3031
= NormalizeState
31-
{ _normalized :: BindingMap
32+
{ _normalized :: MVar BindingMap
3233
-- ^ Global binders
33-
, _specialisationCache :: Map (Id,Int,Either Term Type) Id
34+
, _specialisationCache :: MVar (Map (Id,Int,Either Term Type) Id)
3435
-- ^ Cache of previously specialized functions:
3536
--
3637
-- * Key: (name of the original function, argument position, specialized term/type)
3738
--
3839
-- * Elem: (name of specialized function,type of specialized function)
39-
, _specialisationHistory :: VarEnv Int
40+
, _specialisationHistory :: MVar (VarEnv Int)
4041
-- ^ Cache of how many times a function was specialized
41-
, _inlineHistory :: VarEnv (VarEnv Int)
42+
, _inlineHistory :: MVar (VarEnv (VarEnv Int))
4243
-- ^ Cache of function where inlining took place:
4344
--
4445
-- * Key: function where inlining took place
4546
--
4647
-- * Elem: (functions which were inlined, number of times inlined)
47-
, _primitiveArgs :: Map Text (Set Int)
48+
, _primitiveArgs :: MVar (Map Text (Set Int))
4849
-- ^ Cache for looking up constantness of blackbox arguments
49-
, _recursiveComponents :: VarEnv Bool
50+
, _recursiveComponents :: MVar (VarEnv Bool)
5051
-- ^ Map telling whether a components is recursively defined.
5152
--
5253
-- NB: there are only no mutually-recursive component, only self-recursive
@@ -56,7 +57,7 @@ data NormalizeState
5657
Lens.makeLenses ''NormalizeState
5758

5859
-- | State monad that stores specialisation and inlining information
59-
type NormalizeMonad = State NormalizeState
60+
type NormalizeMonad = StateT NormalizeState IO
6061

6162
-- | RewriteSession with extra Normalisation information
6263
type NormalizeSession = RewriteMonad NormalizeState

0 commit comments

Comments
 (0)