Rewrite Bellman-Ford and min-cost flow, especially to stop the latter from crashing.
[match/match.git] / program / NaiveMinCostFlow.hs
1 module NaiveMinCostFlow (minCostFlow) where
2 import BellmanFord
3 import MonadStuff
4 import Data.Array.IArray
5 import Data.Array.ST
6 import Control.Monad
7 import Control.Monad.ST
8 import Data.Graph.Inductive.Graph
9 import Data.Graph.Inductive.Internal.RootPath
10 import Data.List
11
12 data MCFEdge i f c = MCFEdge {
13         edgeIdx   :: i,
14         edgeCap   :: f,
15         edgeCost  :: c,
16         edgeIsRev :: Bool
17 }
18 data MCFState s gr a b i f c = MCFState {
19         mcfGraph  :: gr a (MCFEdge i f c),
20         mcfSource :: Node,
21         mcfSink   :: Node,
22         mcfFlow   :: STArray s i f
23 }
24
25 edgeCapLeft :: (Graph gr, Ix i, Real f, Real c) => MCFState s gr a b i f c ->
26         MCFEdge i f c -> ST s f
27 edgeCapLeft state (MCFEdge i cap _ isRev) = do
28         fwdFlow <- readArray (mcfFlow state) i
29         return (if isRev then fwdFlow else cap - fwdFlow)
30
31 edgePush :: (Graph gr, Ix i, Real f, Real c) => MCFState s gr a b i f c ->
32         MCFEdge i f c -> f -> ST s ()
33 edgePush state (MCFEdge i _ _ isRev) nf = do
34         oldFwdFlow <- readArray (mcfFlow state) i
35         let newFwdFlow = if isRev then oldFwdFlow - nf else oldFwdFlow + nf
36         writeArray (mcfFlow state) i newFwdFlow
37
38 pathCapLeft :: (Graph gr, Ix i, Real f, Real c) => MCFState s gr a b i f c ->
39         (MCFEdge i f c, BFPath (MCFEdge i f c) c) -> ST s f
40 pathCapLeft state (lastEdge, BFPath _ _ mFrom) = do
41         lastCL <- edgeCapLeft state lastEdge
42         case mFrom of
43                 Nothing -> return lastCL
44                 Just from -> do
45                         fromCL <- pathCapLeft state from
46                         return (min lastCL fromCL)
47
48 augment :: (Graph gr, Ix i, Real f, Real c) => MCFState s gr a b i f c ->
49         f -> BFPath (MCFEdge i f c) c -> ST s ()
50 augment state augAmt (BFPath _ _ mFrom) = case mFrom of
51         Nothing -> nop
52         Just (lastEdge, path1) -> do
53                 edgePush state lastEdge augAmt
54                 augment state augAmt path1
55
56 doFlow :: forall s gr a b i f c. (Graph gr, Ix i, Real f, Real c) => MCFState s gr a b i f c ->
57         ST s ()
58 doFlow state = do
59         filteredEdges <- filterM (\(_, _, l) -> do
60                         ecl <- edgeCapLeft state l
61                         return (ecl /= 0)
62                 ) (labEdges (mcfGraph state))
63         let filteredGraph = mkGraph (labNodes (mcfGraph state)) filteredEdges :: gr a (MCFEdge i f c)
64         -- Why won't we get a negative cycle?  The original graph from the
65         -- proposal matcher is acyclic (so no negative cycle), and if we
66         -- created a negative cycle, that would contradict the fact that we
67         -- always augment along the lowest-cost path.
68         let mAugPath = bellmanFord edgeCost (mcfSource state) filteredGraph
69                 ! (mcfSink state)
70         case mAugPath of
71                 Nothing -> nop -- Done.
72                 -- source /= sink, so augPasth is not empty.
73                 Just augPath@(BFPath _ _ (Just from)) -> do
74                         augAmt <- pathCapLeft state from
75                         augment state augAmt augPath
76                         doFlow state
77
78 minCostFlow :: forall s gr a b i f c. (Graph gr, Ix i, Real f, Real c) =>
79         (i, i)       -> -- Range of edge indices
80         (b -> i)     -> -- Edge label -> unique edge index
81         (b -> f)     -> -- Edge label -> flow capacity
82         (b -> c)     -> -- Edge label -> cost per unit of flow
83         gr a b       -> -- Graph
84         (Node, Node) -> -- (source, sink)
85         Array i f       -- ! edge index -> flow value
86 minCostFlow idxBounds edgeIdx edgeCap edgeCost theGraph (source, sink) = runSTArray (do
87                 let ourFlipF isRev l =
88                         MCFEdge (edgeIdx l) (edgeCap l)
89                                 (if isRev then -(edgeCost l) else edgeCost l)
90                                 isRev
91                 let graph2 = mkGraph (labNodes theGraph) (concatMap
92                         (\(n1, n2, l) -> [ -- Capacity of reverse edge is never used.
93                                 (n1, n2, MCFEdge (edgeIdx l) (edgeCap l) (  edgeCost l ) False),
94                                 (n2, n1, MCFEdge (edgeIdx l)  undefined  (-(edgeCost l)) True )
95                         ]) $ labEdges theGraph) :: gr a (MCFEdge i f c)
96                 flow <- newArray idxBounds 0
97                 let state = MCFState graph2 source sink flow
98                 doFlow state
99                 return (mcfFlow state)
100         )