{-# LANGUAGE CPP, NamedFieldPuns, RecordWildCards, ScopedTypeVariables, RankNTypes, DeriveDataTypeable #-}
#if MIN_VERSION_monad_control(0,3,0)
{-# LANGUAGE FlexibleContexts #-}
#endif
#if !MIN_VERSION_base(4,3,0)
{-# LANGUAGE RankNTypes #-}
#endif
module Data.Pool
(
Pool(idleTime, maxResources, numStripes)
, LocalPool
, createPool
, withResource
, takeResource
, tryWithResource
, tryTakeResource
, destroyResource
, putResource
, destroyAllResources
) where
import Control.Applicative ((<$>))
import Control.Concurrent (ThreadId, forkIOWithUnmask, killThread, myThreadId, threadDelay)
import Control.Concurrent.STM
import Control.Exception (SomeException, onException, mask_)
import Control.Monad (forM_, forever, join, liftM3, unless, when)
import Data.Hashable (hash)
import Data.IORef (IORef, newIORef, mkWeakIORef)
import Data.List (partition)
import Data.Time.Clock (NominalDiffTime, UTCTime, diffUTCTime, getCurrentTime)
import Data.Typeable (Typeable)
import GHC.Conc.Sync (labelThread)
import qualified Control.Exception as E
import qualified Data.Vector as V
#if MIN_VERSION_monad_control(0,3,0)
import Control.Monad.Trans.Control (MonadBaseControl, control)
import Control.Monad.Base (liftBase)
#else
import Control.Monad.IO.Control (MonadControlIO, controlIO)
import Control.Monad.IO.Class (liftIO)
#define control controlIO
#define liftBase liftIO
#endif
-- Compatibility shim providing mask for older versions of base.
#if MIN_VERSION_base(4,3,0)
import Control.Exception (mask)
#else
-- Fallback used when base is older than 4.3.
-- NOTE(review): this fallback does not mask asynchronous exceptions at all.
-- The action runs in whatever masking state the caller already has, and the
-- restore function handed to it is id. withResource and tryWithResource
-- therefore lose their async-exception safety when this branch is compiled.
-- NOTE(review): forkIOWithUnmask (imported above) is believed to first appear
-- in base-4.4, which would make this branch unreachable - confirm.
mask :: ((forall a. IO a -> IO a) -> IO b) -> IO b
mask f = f id
#endif
-- | An idle resource held by a stripe, stamped with the time at which it was
-- last handed back to the pool.
data Entry a = Entry
  { entry   :: a
    -- ^ The pooled resource itself.
  , lastUse :: UTCTime
    -- ^ When the resource was returned to its stripe; the reaper compares
    -- this against the pool's idle time to decide whether it is stale.
  }
-- | One stripe of a 'Pool'. Each stripe keeps its own count of live
-- resources and its own list of idle resources.
data LocalPool a = LocalPool
  { inUse   :: TVar Int
    -- ^ Number of live resources belonging to this stripe. It is incremented
    -- when a resource is created and decremented when one is destroyed.
    -- Idle resources are still counted, because returning a resource to the
    -- stripe does not change this value.
  , entries :: TVar [Entry a]
    -- ^ Idle resources available for reuse, most recently returned first.
  , lfin    :: IORef ()
    -- ^ Finalisation anchor. createPool attaches a weak-reference finaliser
    -- to it that destroys this stripe's idle resources.
  }
  deriving (Typeable)
-- | A striped pool of resources.
data Pool a = Pool
  { create       :: IO a
    -- ^ Creates a new resource when a stripe has no idle one to hand out.
  , destroy      :: a -> IO ()
    -- ^ Destroys a resource that is leaving the pool.
  , numStripes   :: Int
    -- ^ Number of stripes. A thread picks its stripe by hashing its ThreadId.
  , idleTime     :: NominalDiffTime
    -- ^ Idle resources left unused for longer than this are destroyed by the
    -- reaper thread.
  , maxResources :: Int
    -- ^ Maximum number of live resources per stripe (not per pool).
  , localPools   :: V.Vector (LocalPool a)
    -- ^ The stripes themselves; there are numStripes of them.
  , fin          :: IORef ()
    -- ^ Finalisation anchor. createPool attaches a weak-reference finaliser
    -- to it that kills the reaper thread.
  }
  deriving (Typeable)
-- | Renders only the configuration fields of a pool. The resource actions,
-- the stripes and the finalisation anchor are omitted.
instance Show (Pool a) where
  show Pool{numStripes, idleTime, maxResources} =
    concat
      [ "Pool {numStripes = ", show numStripes, ", "
      , "idleTime = ", show idleTime, ", "
      , "maxResources = ", show maxResources, "}"
      ]
-- | Create a striped pool of resources and start its reaper thread.
--
-- Calls 'error' (through 'modError') if the stripe count or the maximum
-- resource count is below 1, or if the idle time is below 0.5.
--
-- Weak-reference finalisers are attached to the pool's IORef anchors. When
-- those anchors are collected, the reaper thread is killed and each stripe's
-- idle resources are destroyed.
createPool
  :: IO a             -- ^ Action that creates a new resource.
  -> (a -> IO ())     -- ^ Action that destroys an existing resource.
  -> Int              -- ^ Number of stripes (sub-pools).
  -> NominalDiffTime  -- ^ Idle time after which an unused resource may be
                      --   destroyed by the reaper.
  -> Int              -- ^ Maximum number of live resources per stripe.
  -> IO (Pool a)
createPool create destroy numStripes idleTime maxResources = do
  when (numStripes < 1) $
    modError "pool " $ "invalid stripe count " ++ show numStripes
  when (idleTime < 0.5) $
    modError "pool " $ "invalid idle time " ++ show idleTime
  when (maxResources < 1) $
    modError "pool " $ "invalid maximum resource count " ++ show maxResources
  stripes <- V.replicateM numStripes newStripe
  reaperId <- forkIOLabeledWithUnmask "resource-pool: reaper" $ \unmask ->
    unmask (reaper destroy idleTime stripes)
  anchor <- newIORef ()
  -- The Weak values are discarded; they exist only for their finalisers.
  _ <- mkWeakIORef anchor (killThread reaperId)
  V.forM_ stripes $ \stripe ->
    mkWeakIORef (lfin stripe) (purgeLocalPool destroy stripe)
  return Pool
    { create
    , destroy
    , numStripes
    , idleTime
    , maxResources
    , localPools = stripes
    , fin = anchor
    }
  where
    -- A fresh, empty stripe with no live resources.
    newStripe = liftM3 LocalPool (newTVarIO 0) (newTVarIO []) (newIORef ())
-- | Fork a thread that starts with asynchronous exceptions masked. The new
-- thread gives itself a debugging label and then runs the body.
--
-- The body receives an @unmask@ function and decides itself when to let
-- asynchronous exceptions in.
forkIOLabeledWithUnmask :: String
                        -> ((forall a. IO a -> IO a) -> IO ())
                        -> IO ThreadId
forkIOLabeledWithUnmask label body =
  mask_ . forkIOWithUnmask $ \unmask ->
    myThreadId >>= (`labelThread` label) >> body unmask
-- | Background loop for the reaper thread. Roughly once per second it visits
-- every stripe, detaches the idle entries whose last use is older than
-- @idleTime@, and destroys them.
--
-- Synchronous exceptions thrown by @destroy@ are ignored, so one bad resource
-- cannot stop the reaper. Asynchronous 'E.AsyncException's, in particular
-- 'E.ThreadKilled', are re-thrown. createPool stops this thread with
-- 'killThread' from a finaliser. If that exception were swallowed while
-- @destroy@ was running, the reaper would keep looping forever and keep every
-- stripe reachable.
--
-- If the thread is killed in the middle of a batch, the remaining stale
-- resources of that batch are dropped without being destroyed. They have
-- already been detached from the stripe at that point. This is the price of
-- the thread being killable.
reaper :: (a -> IO ()) -> NominalDiffTime -> V.Vector (LocalPool a) -> IO ()
reaper destroy idleTime pools = forever $ do
  threadDelay (1 * 1000000)
  now <- getCurrentTime
  let isStale Entry{..} = now `diffUTCTime` lastUse > idleTime
  V.forM_ pools $ \LocalPool{..} -> do
    -- Detach stale entries and release their slots in one transaction, then
    -- destroy them outside STM.
    resources <- atomically $ do
      (stale, fresh) <- partition isStale <$> readTVar entries
      unless (null stale) $ do
        writeTVar entries fresh
        modifyTVar_ inUse (subtract (length stale))
      return (map entry stale)
    forM_ resources $ \resource ->
      destroy resource `E.catch` ignoreSyncException
  where
    -- Discard failures of the user's destroy action, but let asynchronous
    -- exceptions (e.g. ThreadKilled from killThread) terminate the thread.
    ignoreSyncException :: SomeException -> IO ()
    ignoreSyncException e =
      case E.fromException e :: Maybe E.AsyncException of
        Just _  -> E.throwIO e
        Nothing -> return ()
-- | Empty a stripe of its idle resources, release their slots, and destroy
-- them. Exceptions thrown by @destroy@ are ignored. Resources currently
-- checked out of the stripe are not touched.
purgeLocalPool :: (a -> IO ()) -> LocalPool a -> IO ()
purgeLocalPool destroy LocalPool{..} = do
  idle <- atomically $ do
    drained <- swapTVar entries []
    modifyTVar_ inUse (subtract (length drained))
    return drained
  forM_ idle $ \Entry{entry = resource} ->
    destroy resource `E.catch` \(_ :: SomeException) -> return ()
-- | Borrow a resource from the calling thread's stripe for the duration of an
-- action.
--
-- If the action succeeds, the resource is returned to its stripe. If the
-- action throws, the resource is destroyed, its slot is released, and the
-- exception propagates. Taking and returning the resource run with
-- asynchronous exceptions masked. Only the action itself runs with the
-- caller's masking state restored.
withResource ::
#if MIN_VERSION_monad_control(0,3,0)
    (MonadBaseControl IO m)
#else
    (MonadControlIO m)
#endif
  => Pool a -> (a -> m b) -> m b
{-# SPECIALIZE withResource :: Pool a -> (a -> IO b) -> IO b #-}
withResource pool act = control $ \runInIO -> mask $ \restore ->
  takeResource pool >>= \(resource, local) -> do
    let discard = destroyResource pool local resource
    result <- restore (runInIO (act resource)) `onException` discard
    putResource local resource
    return result
#if __GLASGOW_HASKELL__ >= 700
{-# INLINABLE withResource #-}
#endif
-- | Take a resource from the calling thread's stripe. Returns it together
-- with that stripe, which is needed later to give the resource back.
--
-- An idle resource is reused if one exists. Otherwise, if the stripe holds
-- fewer than 'maxResources' live resources, a slot is reserved and 'create'
-- runs outside the STM transaction. The slot is released again if 'create'
-- throws. If the stripe is full, the transaction retries, so the call blocks
-- until an idle resource or a free slot appears.
--
-- This function does not mask asynchronous exceptions itself.
takeResource :: Pool a -> IO (a, LocalPool a)
takeResource pool@Pool{..} = do
  local@LocalPool{..} <- getLocalPool pool
  let reserve = do
        idle <- readTVar entries
        case idle of
          Entry{..} : rest -> do
            writeTVar entries rest
            return (return entry)
          [] -> do
            used <- readTVar inUse
            when (used == maxResources) retry
            writeTVar inUse $! used + 1
            return $
              create `onException` atomically (modifyTVar_ inUse (subtract 1))
  resource <- liftBase (join (atomically reserve))
  return (resource, local)
#if __GLASGOW_HASKELL__ >= 700
{-# INLINABLE takeResource #-}
#endif
-- | Like 'withResource', but does not wait for a free slot. If the calling
-- thread's stripe has no idle resource and already holds 'maxResources' live
-- resources, the action is not run and 'Nothing' is returned.
tryWithResource :: forall m a b.
#if MIN_VERSION_monad_control(0,3,0)
    (MonadBaseControl IO m)
#else
    (MonadControlIO m)
#endif
  => Pool a -> (a -> m b) -> m (Maybe b)
tryWithResource pool act = control $ \runInIO -> mask $ \restore -> do
  taken <- tryTakeResource pool
  case taken of
    Nothing ->
      restore . runInIO $ return (Nothing :: Maybe b)
    Just (resource, local) -> do
      let discard = destroyResource pool local resource
      result <- restore (runInIO (Just <$> act resource)) `onException` discard
      putResource local resource
      return result
#if __GLASGOW_HASKELL__ >= 700
{-# INLINABLE tryWithResource #-}
#endif
-- | Non-blocking variant of 'takeResource'. It returns 'Nothing' instead of
-- waiting when the calling thread's stripe has no idle resource and already
-- holds 'maxResources' live resources.
--
-- When a new resource is created, its slot is released again if 'create'
-- throws.
tryTakeResource :: Pool a -> IO (Maybe (a, LocalPool a))
tryTakeResource pool@Pool{..} = do
  local@LocalPool{..} <- getLocalPool pool
  let attempt = do
        idle <- readTVar entries
        case idle of
          Entry{..} : rest -> do
            writeTVar entries rest
            return (return (Just entry))
          [] -> do
            used <- readTVar inUse
            if used == maxResources
              then return (return Nothing)
              else do
                writeTVar inUse $! used + 1
                return $ fmap Just $
                  create `onException` atomically (modifyTVar_ inUse (subtract 1))
  mresource <- liftBase (join (atomically attempt))
  return (fmap (\resource -> (resource, local)) mresource)
#if __GLASGOW_HASKELL__ >= 700
{-# INLINABLE tryTakeResource #-}
#endif
-- | Select the calling thread's stripe by hashing its 'ThreadId' modulo the
-- number of stripes. Because 'mod' with a positive divisor is non-negative,
-- the index is always in range for a pool built by createPool, which rejects
-- stripe counts below 1.
getLocalPool :: Pool a -> IO (LocalPool a)
getLocalPool Pool{..} = do
  tid <- liftBase myThreadId
  let stripe = hash tid `mod` numStripes
  return (localPools V.! stripe)
#if __GLASGOW_HASKELL__ >= 700
{-# INLINABLE getLocalPool #-}
#endif
-- | Destroy a resource taken from the given stripe and release its slot.
-- Any exception thrown by the pool's destroy action is discarded, so such an
-- exception cannot prevent the slot from being released.
destroyResource :: Pool a -> LocalPool a -> a -> IO ()
destroyResource Pool{..} LocalPool{..} resource = do
  E.catch (destroy resource) discardException
  atomically $ modifyTVar_ inUse (subtract 1)
  where
    discardException :: SomeException -> IO ()
    discardException _ = return ()
#if __GLASGOW_HASKELL__ >= 700
{-# INLINABLE destroyResource #-}
#endif
-- | Give a resource back to its stripe as an idle entry stamped with the
-- current time. The stripe's live-resource count is left unchanged.
putResource :: LocalPool a -> a -> IO ()
putResource LocalPool{..} resource =
  getCurrentTime >>= \now ->
    atomically (modifyTVar_ entries (Entry resource now :))
#if __GLASGOW_HASKELL__ >= 700
{-# INLINABLE putResource #-}
#endif
-- | Destroy every idle resource in every stripe of the pool. Resources that
-- are currently checked out are not affected.
destroyAllResources :: Pool a -> IO ()
destroyAllResources pool = V.mapM_ (purgeLocalPool (destroy pool)) (localPools pool)
-- | Apply a function to the contents of a 'TVar'. The new value is forced to
-- weak head normal form before it is written, so repeated updates do not
-- accumulate a chain of thunks.
modifyTVar_ :: TVar a -> (a -> a) -> STM ()
modifyTVar_ var f = do
  old <- readTVar var
  let new = f old
  new `seq` writeTVar var new
-- | Raise an 'error' whose message is prefixed with @Data.Pool.@ followed by
-- the given function name and a colon.
modError :: String -> String -> a
modError func msg = error (concat ["Data.Pool.", func, ": ", msg])