{-# OPTIONS_GHC -funbox-strict-fields #-}
module Data.UnionFind.ST
( Point, fresh, repr, union, union', equivalent, redundant,
descriptor, setDescriptor, modifyDescriptor )
where
import Control.Applicative
import Control.Monad ( when )
import Control.Monad.ST
import Data.STRef
-- | An element of a union-find structure that lives in state thread @s@
-- and whose equivalence class carries a descriptor of type @a@.
--
-- The derived equality compares the underlying 'STRef's. It therefore
-- tests whether two values are the /same point/, not whether they are in
-- the same class ('equivalent' answers that).
newtype Point s a = Pt (STRef s (Link s a))
  deriving (Eq)
-- | Contents of a point's mutable cell.
data Link s a
  = Info {-# UNPACK #-} !(STRef s (Info a))
    -- ^ The point is the representative (root) of its class. The
    -- reference holds the class's weight and descriptor.
  | Link {-# UNPACK #-} !(Point s a)
    -- ^ The point is not a root; follow this pointer towards the
    -- representative.
  deriving (Eq)
-- | Per-class bookkeeping, reachable from the class's representative.
data Info a = MkInfo
  { weight :: {-# UNPACK #-} !Int
    -- ^ Number of points in the class. 'fresh' sets it to 1 and
    -- 'union'' stores the sum of the two merged weights. It decides
    -- which root gets linked under the other.
  , descr :: a
    -- ^ User-supplied descriptor. The field is not strict, so it is
    -- stored unevaluated.
  }
  deriving (Eq)
-- | Create a new point that forms an equivalence class of its own and
-- carries the given descriptor.
fresh :: a -> ST s (Point s a)
fresh desc = do
  -- A singleton class: weight 1, and the point itself is the root.
  infoRef <- newSTRef (MkInfo 1 desc)
  Pt <$> newSTRef (Info infoRef)
-- | Return the representative of the point's equivalence class: the point
-- reached by following 'Link' pointers until a cell holding 'Info' is found.
--
-- Path compression: on the way back from the recursion, each non-root
-- point whose parent is not the root gets its link overwritten with the
-- parent's link. Afterwards every point on the traversed path refers
-- directly to the representative.
repr :: Point s a -> ST s (Point s a)
repr point@(Pt ref) = readSTRef ref >>= follow
  where
    follow (Info _) = return point
    follow (Link parent@(Pt parentRef)) = do
      root <- repr parent
      -- When the parent is not the root, its link refers directly to the
      -- root at this point (the recursive call saw to that). Copying it
      -- lets this point skip the parent.
      when (root /= parent) $
        readSTRef parentRef >>= writeSTRef ref
      return root
-- | Return the reference to the 'Info' record of the point's class.
--
-- The point and its immediate parent are checked directly. If neither is
-- the root, 'repr' is called (which compresses the path) and the lookup is
-- retried from the representative.
descrRef :: Point s a -> ST s (STRef s (Info a))
descrRef point@(Pt ref) = readSTRef ref >>= fromSelf
  where
    -- The point itself is the root.
    fromSelf (Info infoRef) = return infoRef
    fromSelf (Link (Pt parentRef)) = readSTRef parentRef >>= fromParent

    -- The parent is the root; otherwise fall back to a full 'repr'.
    fromParent (Info infoRef) = return infoRef
    fromParent (Link _) = repr point >>= descrRef
-- | Read the descriptor of the equivalence class that contains the point.
descriptor :: Point s a -> ST s a
descriptor point = do
  infoRef <- descrRef point
  info <- readSTRef infoRef
  return (descr info)
-- | Replace the descriptor of the point's equivalence class with the
-- second argument. The class's weight is left unchanged.
setDescriptor :: Point s a -> a -> ST s ()
setDescriptor point new_descr = do
  r <- descrRef point
  -- Strict modify: the updated record is stored, not a thunk. A thunk
  -- would keep the previous 'Info' record, and its old descriptor, alive.
  -- The descriptor field itself is still lazy.
  modifySTRef' r $ \i -> i { descr = new_descr }
-- | Apply a function to the descriptor of the point's equivalence class.
-- The class's weight is left unchanged.
modifyDescriptor :: Point s a -> (a -> a) -> ST s ()
modifyDescriptor point f = do
  r <- descrRef point
  -- Strict modify: with the lazy 'modifySTRef', repeated calls stack up a
  -- chain of unevaluated record updates inside the reference. The new
  -- descriptor value @f (descr i)@ itself is still left lazy, matching
  -- the non-strict 'descr' field.
  modifySTRef' r $ \i -> i { descr = f (descr i) }
-- | Merge the equivalence classes of the two points. The merged class keeps
-- the descriptor of the second point's class.
union :: Point s a -> Point s a -> ST s ()
union p1 p2 = union' p1 p2 keepSecond
  where
    -- Discard the first class's descriptor and keep the second's.
    keepSecond _ d2 = return d2
-- | Merge the equivalence classes of the two points. The merged class's
-- descriptor is the result of the callback.
--
-- The callback receives the descriptor of @p1@'s class, then that of
-- @p2@'s class. If both points are already in the same class, nothing is
-- written and the callback is not run.
--
-- Union by weight: the root of the class with the smaller 'weight' is
-- linked under the other root (on a tie, @p1@'s root stays the root). The
-- surviving root's weight becomes the sum of both weights.
union' :: Point s a -> Point s a -> (a -> a -> ST s a) -> ST s ()
union' p1 p2 update = do
  root1@(Pt linkRef1) <- repr p1
  root2@(Pt linkRef2) <- repr p2
  -- Linking a root beneath itself would create a cycle.
  when (root1 /= root2) $ do
    infoRef1 <- rootInfoRef linkRef1
    infoRef2 <- rootInfoRef linkRef2
    i1 <- readSTRef infoRef1
    let MkInfo w1 d1 = i1
    i2 <- readSTRef infoRef2
    let MkInfo w2 d2 = i2
    d2' <- update d1 d2
    -- NOTE(review): the references written below were read before the
    -- callback ran. If 'update' itself merges either of these classes,
    -- this bookkeeping may act on stale roots. Confirm callers never do
    -- that.
    let merged = MkInfo (w1 + w2) d2'
    if w1 >= w2
      then do
        writeSTRef linkRef2 (Link root1)
        writeSTRef infoRef1 merged
      else do
        writeSTRef linkRef1 (Link root2)
        writeSTRef infoRef2 merged
  where
    -- Extract the 'Info' reference from a representative's link cell.
    -- 'repr' only returns points whose cell holds 'Info', so the 'Link'
    -- case means the structure's invariant has been broken.
    rootInfoRef :: STRef t (Link t b) -> ST t (STRef t (Info b))
    rootInfoRef linkRef = do
      link <- readSTRef linkRef
      case link of
        Info infoRef -> return infoRef
        Link _ ->
          error "Data.UnionFind.ST.union': representative is not a root"
-- | Return 'True' exactly when both points have the same representative,
-- i.e. belong to the same equivalence class. Finding the representatives
-- also compresses both paths.
equivalent :: Point s a -> Point s a -> ST s Bool
equivalent p1 p2 = do
  r1 <- repr p1
  r2 <- repr p2
  return (r1 == r2)
-- | Return 'True' when the point is not the representative of its class,
-- i.e. its cell holds a 'Link'. Exactly one point per class (its root)
-- yields 'False'; which point that is depends on the order of unions.
-- The path is not compressed.
redundant :: Point s a -> ST s Bool
redundant (Pt link_r) = isLink <$> readSTRef link_r
  where
    isLink (Info _) = False
    isLink (Link _) = True