Make the instance generator select proposal topics from a Zipf distribution, and
[match/match.git] / program / RandomizedMonad.hs
CommitLineData
967c39ef
MM
1module RandomizedMonad (
2 Randomized,
0df1b3e1
MM
3 msplit,
4 runRandom1, runRandom, runRandomStd, runRandomNewStd,
967c39ef 5 mrandomR, mrandom,
8723ed6a 6 withProb, withWeight,
066d7f53 7 filterRandomized,
0df1b3e1 8 indReplicateRandom, indRepeatRandom, indRandomArray
967c39ef
MM
9) where
import Control.Applicative (Applicative (..))
import Data.Array.IArray
import Data.Ix
import System.Random
967c39ef
MM
13
-- Needs -XRank2Types: the wrapped function is polymorphic in the generator.
--
-- A randomized computation producing an 'a'. It is a generator-passing
-- function that works for any RandomGen, so one computation can be run with
-- whatever generator the caller supplies.
newtype Randomized a = Randomized (forall g. RandomGen g => (g -> (a, g)))

-- These instances thread a single RandomGen through the whole process in
-- order to satisfy the monad laws.
--
-- All intermediate results are bound with lazy let-patterns, so a result can
-- be inspected before later generator states are demanded. indRepeatRandom
-- (which sequences an infinite list) presumably depends on this; note that
-- modern base implements 'sequence' via '<*>', so the Applicative instance
-- must be just as lazy as '>>='.
instance Functor Randomized where
  fmap f (Randomized fa) = Randomized (\g ->
    let (a, g2) = fa g
    in (f a, g2))

-- Required as a superclass of Monad since GHC 7.10 (Applicative-Monad
-- Proposal). Effects run left to right, matching 'ap'.
instance Applicative Randomized where
  pure x = Randomized (\g -> (x, g))
  Randomized ff <*> Randomized fa = Randomized (\g ->
    let (f, g2) = ff g
        (a, g3) = fa g2
    in (f a, g3))

instance Monad Randomized where
  -- Canonical definition; avoids -Wnoncanonical-monad-instances.
  return = pure
  -- Run the first computation, feed its result to the continuation, and
  -- continue from the generator the first computation left behind.
  Randomized fa >>= amb = Randomized (\g ->
    let (a, g2) = fa g
    in case amb a of Randomized fb -> fb g2)
27
-- Splits the generator, runs the argument on the left half (discarding the
-- generator it finishes with), and threads the right half on to whatever
-- follows. Compare unsafeInterleaveIO. Use this to make a sub-calculation
-- parallelizable and evolvable without breaking same-seed reproducibility of
-- the whole calculation.
msplit :: Randomized a -> Randomized a
msplit (Randomized fa) = Randomized (\g ->
  let (gLeft, gRight) = split g
      (result, _discarded) = fa gLeft
  in (result, gRight))
35
-- Runs a randomized computation from the given generator, returning both the
-- result and the generator state left over afterwards.
runRandom1 :: RandomGen g => g -> Randomized a -> (a, g)
runRandom1 gen ra = case ra of
  Randomized step -> step gen
967c39ef
MM
38
-- Runs a randomized computation from the given generator and keeps only the
-- result; the final generator state is thrown away.
runRandom :: RandomGen g => g -> Randomized a -> a
runRandom gen ra = case ra of
  Randomized step -> fst (step gen)
967c39ef
MM
41
-- Conveniences

-- Runs a randomized computation using the global standard generator. The
-- global generator is only read here (getStdGen), not updated, so two calls
-- with no intervening setStdGen/newStdGen use the same generator.
runRandomStd :: Randomized a -> IO a
runRandomStd ra = fmap (flip runRandom ra) getStdGen
47
-- Runs a randomized computation using a generator obtained from newStdGen,
-- which splits the global standard generator and updates it, so successive
-- calls do not reuse the same generator.
runRandomNewStd :: Randomized a -> IO a
runRandomNewStd ra = newStdGen >>= \gen -> return (runRandom gen ra)
52
-- Monadic versions of random and randomR (to generate primitive-ish values)

-- Draws a value with 'random', i.e. using the type's default range.
mrandom :: Random a => Randomized a
mrandom = Randomized (\gen -> random gen)
-- Draws a value within the given (lo, hi) bounds with 'randomR'.
mrandomR :: Random a => (a, a) -> Randomized a
-- Eta-expand this one to keep GHC 6.6.1 on birdy happy.
mrandomR bounds = Randomized (\gen -> randomR bounds gen)
967c39ef
MM
59
-- Treats each (width, choice) pair as the next consecutive interval of that
-- width on the number line, starting at 0. Returns the choice whose interval
-- contains val, or elseR when val lies beyond all of them.
chooseCase :: Double -> [(Double, a)] -> a -> a
chooseCase _ [] elseR = elseR
chooseCase val ((width, choice) : more) elseR
  | val < width = choice
  | otherwise = chooseCase (val - width) more elseR
66
-- An if-elsif...else-style construct where each "if" has a probability.
-- A single Double is drawn with mrandom (expected to lie in [0, 1)). The
-- first branch whose probability band contains it is taken; if the
-- probabilities are exhausted first, elseR runs instead.
withProb :: [(Double, Randomized a)] -> Randomized a -> Randomized a
withProb ifCs elseR = mrandom >>= \val -> chooseCase val ifCs elseR
72
8723ed6a
MM
-- Like withProb, but without an else case and with the "probabilities" scaled
-- so that they sum to 1: a value is drawn from [0, total weight] and used to
-- pick a branch.
--
-- The first alternative is passed to chooseCase as the fallback rather than
-- as an interval. It therefore receives whatever part of [0, total] the
-- remaining weights do not cover, which (for non-negative weights, up to
-- floating-point rounding) is its own weight.
--
-- The list must be non-empty; an empty list is a caller error and is
-- reported with a descriptive message instead of a bare Prelude.head/tail
-- failure.
withWeight :: [(Double, Randomized a)] -> Randomized a
withWeight ifCs = do
  val <- mrandomR (0, sum (map fst ifCs))
  -- Matching here (not in the function head) keeps the original laziness:
  -- the list is only inspected once the computation is actually run.
  case ifCs of
    (_, firstR) : restCs -> chooseCase val restCs firstR
    [] -> error "RandomizedMonad.withWeight: empty list of weighted alternatives"
79
967c39ef
MM
-- Keep trying until we get what we want: reruns ra, each time continuing
-- from the generator the previous attempt left behind, until the result
-- satisfies f. Never returns if no attempt ever satisfies f.
filterRandomized :: (a -> Bool) -> Randomized a -> Randomized a
filterRandomized f ra = retry
  where
    retry = ra >>= \candidate ->
      if f candidate then return candidate else retry
066d7f53 85
0df1b3e1
MM
-- A randomized list of n elements chosen independently from a distribution
-- (empty when n <= 0). Each element is drawn under msplit so that it can be
-- computed in parallel without changing same-seed results.
indReplicateRandom :: Int -> Randomized a -> Randomized [a]
indReplicateRandom n ra = sequence [msplit ra | _ <- [1 .. n]]
90
-- An infinite randomized list of elements chosen independently from a
-- distribution, each drawn under msplit. The whole list is also wrapped in
-- msplit so that binding it does not loop: the outer computation never
-- needs the generator state at the end of the infinite sequence.
indRepeatRandom :: Randomized a -> Randomized [a]
indRepeatRandom = msplit . sequence . repeat . msplit
96
-- Produces an array over the index bounds bds whose elements are chosen
-- independently from a distribution. The elements come from
-- indReplicateRandom (one per index in bds) and are laid out with listArray.
indRandomArray :: (IArray a e, Ix i) =>
  (i, i) -> Randomized e -> Randomized (a i e)
indRandomArray bds re =
  indReplicateRandom (rangeSize bds) re >>= return . listArray bds