{- |
Performs lambda-lifting of the program
-}
module Lift (liftCode) where

import List
import State
import PosCode
import SyntaxPos
import Util.Extra (emptySet,unionSet,removeSet,noPos,strace,pair)
import Maybe
import Id
import IntState
import TokenId
import IdKind
import Building (Compiler(..),compiler)

--------- ===========

-- | Read-only context that is passed downwards while traversing the code.
data LiftDown =
  LiftDown
     Bool                       -- strict
     ((TokenId,IdKind)->Id)     -- tidFun
     TokenId                    -- current function

-- | State that is threaded through the whole traversal.
data LiftThread =
  LiftThread
     [(Id,[Id])]  -- translation from lifted identifier to new free variables
     [PosBinding] -- new top-level definitions
     IntState

-- | Lambda-lift a list of top-level bindings.
--
-- Every top-level binding is processed in a strict context, starting with an
-- empty translation table and no pending lifted definitions.  The result is
-- each original binding followed by the definitions lifted out of it, plus
-- the updated 'IntState'.
liftCode :: [(Id,PosLambda)] -> IntState -> ((TokenId,IdKind) -> Id)
         -> ([(Id,PosLambda)],IntState)
liftCode code state tidFun =
  case (mapS liftTopBinding code) (LiftDown True tidFun tunknown)
                                  (LiftThread [] [] state) of
    (code,LiftThread _ _ state) -> (concat code,state)

-- | Lift one top-level binding: record it as the current function, lift its
-- body, then collect (and reset) the definitions that were lifted out of it.
liftTopBinding d =
  liftSetTid (fst d) >=>
  liftBinding d >>>= \ d ->
  liftTop >>>= \ sc ->
  unitS (d:sc)

-- | Lift a group of local bindings (the bindings of one @PosExpLet@).
--
-- Bindings that take arguments (see 'liftIt') are moved to top level; the
-- others stay where they are.  All lifted bindings of the group share one
-- set of extra arguments, @envLift@: the expanded environments of the lifted
-- bindings minus the lifted identifiers themselves.  The translation table
-- is extended so that each lifted identifier maps to @envLift@.
--
-- Lifted bindings are traversed in a strict context, staying bindings in a
-- non-strict one.  Returns the bindings that stay, together with the updated
-- thread (new translation table, new top-level definitions, new state).
liftScc pos bindingsIn down@(LiftDown strict tidFun ptid)
                       up@(LiftThread transIn scIn stateIn) =
  let (declsInLift,declsInStay) = partition liftIt bindingsIn
      definedLift = map fst declsInLift
      envLift = foldr unionSet emptySet
                      (map (expandEnv transIn)
                           (removeSet (foldr (unionSet . map snd . getEnvs)
                                             emptySet declsInLift)
                                      definedLift))
      transNew = map (`pair` envLift) definedLift ++ transIn

      (declsOutLift,LiftThread _ scInLift state1) =
          mapS liftBinding declsInLift (LiftDown True  tidFun ptid)
                                       (LiftThread transNew [] stateIn)
      (declsOutStay,LiftThread _ scInStay state2) =
          mapS liftBinding declsInStay (LiftDown False tidFun ptid)
                                       (LiftThread transNew [] state1)

      scHere = map (addArgs envLift) declsOutLift
      newBindings = map (addEnvs transNew) declsOutStay
      newSC = scHere ++ scInLift++scInStay++scIn
      newState = foldr (updateInfo ptid) state2 scHere
  in (newBindings, LiftThread transNew newSC newState)

-- | Replace the 'IntState' entry of a lifted function by an @InfoName@ that
-- carries its arity (the number of arguments of the given lambda).  The
-- existing token id is kept and forced with 'seq' before the new entry is
-- built.
updateInfo ptid (fun, PosLambda pos int envs args exp) state =
  let arity = length args
  in updateIS state fun (\info->
         let tid = tidI info
         in (seq tid (InfoName fun tid arity (tidPos ptid pos) True))) --PHtprof

-- | Turn a binding into a top-level one: the environment is emptied and the
-- given identifiers are prepended to the argument list.
addArgs newargs (fun, PosLambda pos int envs args exp) =
  (fun, PosLambda pos int [] (map (pair pos) newargs++args) exp)

-- | Rewrite the environment of a binding that stays in place: every
-- identifier is expanded through the translation table and the results are
-- merged with 'unionSet'.
addEnvs trans (fun, PosLambda pos int envs args exp) =
  (fun, PosLambda pos int
            (map (pair pos) (foldr unionSet emptySet
                                   (map (expandEnv trans . snd) envs)))
            args exp)

-- | A local binding is lifted exactly when it takes at least one argument.
liftIt (fun, PosLambda pos int envs args exp) = not (null args)

-- | The environment of a binding.
getEnvs (fun, PosLambda pos int envs args exp) = envs

-- | Look an identifier up in the translation table.  A lifted identifier
-- expands to its recorded variables, any other identifier to itself.
expandEnv trans f =
  case lookup f trans of
    Nothing -> [f]
    Just set -> set

-- | Lift an anonymous lambda to a new top-level function.
--
-- A fresh identifier is taken from the state with 'uniqueIS'; its name is
-- built from @"LAMBDA"@ and the identifier's string.  The expanded
-- environment is prepended to the lambda's arguments, the new function is
-- registered in the state and added to the lifted definitions, and the
-- lambda itself is replaced by an application of the new function to the
-- environment variables.
liftLambda pos int envs args exp down@(LiftDown strict tidFun ptid)
                                 up@(LiftThread transIn scIn stateIn) =
  let newEnvs = map (pair pos)
                    (foldr unionSet emptySet
                           (map (expandEnv transIn . snd) envs))
      fl = if int then LamFLIntro else LamFLLambda
      scHere = (fun, PosLambda pos fl [] (newEnvs++args) exp)
      arity = length newEnvs + length args
      (fun,state2) = uniqueIS stateIn
      tid = (visible (reverse ("LAMBDA" ++ strId fun)))   -- Not exported
      newSC = scHere:scIn
      newState = seq tid $ addIS fun (InfoName fun tid arity (tidPos ptid pos) True) --PHtprof
                           state2
  in
    (PosExpApp pos (PosVar pos fun:map (uncurry PosVar) newEnvs)
    ,LiftThread transIn newSC newState
    )

-- | Lift the body of a binding.  Primitive and foreign bindings have no body
-- and are returned unchanged.
liftBinding (fun,PosLambda pos int envs args exp) =
  liftExp exp >>>= \ exp ->
  unitS (fun,PosLambda pos int envs args exp)
liftBinding (fun,PosPrimitive pos fn) =
  unitS (fun,PosPrimitive pos fn)
liftBinding (fun,PosForeign pos fn ar t c ie) =
  unitS (fun,PosForeign pos fn ar t c ie)

-- | Lift a lambda expression: its body is lifted in a strict context, then
-- the lambda itself is moved to top level by 'liftLambda'.
liftExpLambda pos int envs args exp =
  liftStrict True (liftExp exp) >>>= liftLambda pos int envs args

-- | Lift an expression.
--
-- @case@, fat-bar and @if@ expressions are traversed directly only in a
-- strict context; in a non-strict context they are wrapped in a lambda
-- without arguments (after emitting a trace message) and lifted as such.
liftExp (PosExpLambda pos int envs args exp) =
  liftExpLambda pos int envs args exp
liftExp (PosExpLet rec pos bindings exp) =
  liftScc pos bindings >>>= \ bindings ->
  liftExp exp >>>= \ exp ->
  unitS (PosExpLet rec pos bindings exp)
liftExp e@(PosExpCase pos exp alts) =
  liftGetStrict >>>= \ strict ->
  if strict then
    unitS (PosExpCase pos) =>>> liftExp exp =>>> mapS liftAlt alts
  else
    strace "liftExp PosExpCase lazy!" $
    liftExpLambda pos True [] [] e
liftExp e@(PosExpFatBar b e1 e2) =
  liftGetStrict >>>= \ strict ->
  if strict then
    unitS (PosExpFatBar b) =>>> liftExp e1 =>>> liftExp e2
  else
    strace "liftExp PosExpFatBar lazy!" $
    liftExpLambda noPos True [] [] e
liftExp (PosExpFail) = unitS PosExpFail
liftExp e@(PosExpIf pos g c e1 e2) =
  liftGetStrict >>>= \ strict ->
  if strict then
    unitS (PosExpIf pos g) =>>> liftExp c =>>> liftExp e1 =>>> liftExp e2
  else
    strace "liftExp PosExpIf lazy!" $
    liftExpLambda pos True [] [] e
liftExp (PosExpApp pos es) =             -- hd es is not always strict !!!
  liftGetStrict >>>= \ strict ->
  if strict then
--OLD:    unitS (posExpApp pos) =>>> liftStrict False (mapS liftExp es)
    case es of
      (fun:args) ->
        -- The function position keeps the current (strict) context; the
        -- arguments are lifted in a non-strict context.
        liftExp fun >>>= \ fun' ->
        liftStrict False (mapS liftExp args) >>>= \ args' ->
        unitS (posExpApp pos (fun':args'))
      [] -> error "Lift.liftExp: PosExpApp without any expression"
  else
    -- A non-strict application is first turned into a thunk by 'liftApply'
    -- and the result is lifted again.
    liftApply es >>>= liftExp
liftExp (PosExpThunk pos ap (e:es)) = -- A primitive/con/apply with correct number of arguments
  liftExp e >>>= \ e ->
  liftStrict False (mapS liftExp es) >>>= \ es ->
  unitS (PosExpThunk pos ap (e:es))
liftExp (PosVar pos i) = liftIdent pos i
liftExp a = unitS a

-- | Lift the right-hand side of a case alternative.
liftAlt (PosAltCon pos con args exp) =
  unitS (PosAltCon pos con args) =>>> liftExp exp
liftAlt (PosAltInt pos int b exp) =
  unitS (PosAltInt pos int b) =>>> liftExp exp

--------------------

{- NOTE: the APPLY instruction in YHC works lazily anyway, so we don't need to
   bother lifting lazy applications.
-}
-- | Turn a non-strict application into a thunk.
--
-- For Yhc the whole application becomes a single thunk.  For Nhc98 the
-- application is expressed with the apply functions @t_apply1@ .. @t_apply4@
-- (looked up through 'liftTidFun'); an application with more than four
-- arguments is split: the function and its first four arguments form an
-- inner thunk that is then applied to the remaining arguments.
liftApply
  | compiler==Yhc   = liftY
  | compiler==Nhc98 = lift
  | otherwise       = error "Lift.liftApply: unsupported compiler"
  where
  liftY (e:[]) = unitS e
  liftY (e:es) = unitS (PosExpThunk (getPos e) False (e:es))
  liftY []     = error "Lift.liftApply: empty application"

  lift (e1:[]) = unitS e1
  lift es@(e1:e2:[]) =
    liftTidFun (t_apply1,Var) >>>= \ f ->
    unitS (PosExpThunk (getPos e1) True (f:es))
  lift es@(e1:e2:e3:[]) =
    liftTidFun (t_apply2,Var) >>>= \ f ->
    unitS (PosExpThunk (getPos e1) True (f:es))
  lift es@(e1:e2:e3:e4:[]) =
    liftTidFun (t_apply3,Var) >>>= \ f ->
    unitS (PosExpThunk (getPos e1) True (f:es))
  lift (e1:e2:e3:e4:e5:es) =
    liftTidFun (t_apply4,Var) >>>= \ f ->
    lift (PosExpThunk (getPos e1) True (f:e1:e2:e3:e4:e5:[]):es)
  lift [] = error "Lift.liftApply: empty application"

-- | Replace a reference to a lifted identifier by an application of that
-- identifier to its extra variables; other identifiers are left alone.
liftIdent pos ident down@(LiftDown strict tidFun ptid)
                    up@(LiftThread trans sc state) =
  case lookup ident trans of
    Nothing  -> (PosVar pos ident,up)
    Just env -> (PosExpApp pos (PosVar pos ident:map (PosVar pos) env),up)

-- | Make the given function the "current function" of the downward context.
-- The function must be present in the state; the lookup is only performed
-- when the resulting token id is demanded.
liftSetTid fun down@(LiftDown strict tidFun ptid)
               up@(LiftThread trans sc state) =
  (LiftDown strict tidFun
            (maybe (error "Lift.liftSetTid: function is not in the IntState")
                   profI
                   (lookupIS state fun))
  ,up)

-- | Return the definitions lifted so far and reset both the translation
-- table and the list of lifted definitions.
liftTop down up@(LiftThread trans sc state) =
  (sc,LiftThread [] [] state)

-- | Run a lifting action with the given strictness flag.
liftStrict strict lift down@(LiftDown _ tidFun ptid) up =
  lift (LiftDown strict tidFun ptid) up

-- | Read the strictness flag of the current context.
liftGetStrict down@(LiftDown strict tidFun ptid) up =
  (strict,up)

-- | A variable (without source position) for the identifier that @tidFun@
-- assigns to the given token.
liftTidFun tid down@(LiftDown strict tidFun ptid) up =
  (PosVar noPos (tidFun tid) ,up)