From 5a786ce8edf8516bb4751c47b1cde57c7f3b6551 Mon Sep 17 00:00:00 2001 From: Stephen Paul Weber Date: Sat, 13 Feb 2021 20:59:00 -0500 Subject: [PATCH] Switch monad transformer stack to a type alias Since we already allowed injecting any Session via runTLS or throwing any Error via throwE, this does not reduce safety at all but improves ergonomics considerably. The only downside here is that we must say goodbye to our transitional MonadIO instance. --- lib/Network/Protocol/TLS/GNU.hs | 59 +++++++++++---------------------- 1 file changed, 20 insertions(+), 39 deletions(-) diff --git a/lib/Network/Protocol/TLS/GNU.hs b/lib/Network/Protocol/TLS/GNU.hs index 3bf92b5..a9ab31a 100644 --- a/lib/Network/Protocol/TLS/GNU.hs +++ b/lib/Network/Protocol/TLS/GNU.hs @@ -20,7 +20,6 @@ module Network.Protocol.TLS.GNU , Session , Error (..) , throwE - , catchE , fromExceptT , runTLS @@ -41,13 +40,12 @@ module Network.Protocol.TLS.GNU , certificateCredentials ) where -import Control.Applicative (Applicative, pure, (<*>)) import qualified Control.Concurrent.MVar as M -import Control.Monad (ap, when, foldM, foldM_) +import Control.Monad (when, foldM, foldM_) import Control.Monad.Trans.Class (lift) import qualified Control.Monad.Trans.Except as E import qualified Control.Monad.Trans.Reader as R -import Control.Monad.IO.Class (MonadIO, liftIO) +import Control.Monad.IO.Class (liftIO) import qualified Data.ByteString as B import qualified Data.ByteString.Lazy as BL import qualified Data.ByteString.Unsafe as B @@ -62,7 +60,7 @@ import qualified UnexceptionalIO.Trans as UIO import qualified Network.Protocol.TLS.GNU.Foreign as F -data Error = Error Integer | IOError IOError +data Error = Error Integer deriving (Show) globalInitMVar :: M.MVar () @@ -91,34 +89,16 @@ data Session = Session , sessionCredentials :: IORef [F.ForeignPtr F.Credentials] } -newtype TLS a = TLS { unTLS :: E.ExceptT Error (R.ReaderT Session UIO) a } - -instance Functor TLS where - fmap f = TLS . fmap f . unTLS - -instance Applicative TLS where - pure = TLS . return - (<*>) = ap - -instance Monad TLS where - return = TLS . return - m >>= f = TLS $ unTLS m >>= unTLS . f - --- | This is a transitional instance and may be deprecated in the future -instance MonadIO TLS where - liftIO = TLS . E.withExceptT IOError . UIO.fromIO' (userError . show) +type TLS a = E.ExceptT Error (R.ReaderT Session UIO) a throwE :: Error -> TLS a throwE = fromExceptT . E.throwE -catchE :: TLS a -> (Error -> TLS a) -> TLS a -catchE inner handler = TLS $ unTLS inner `E.catchE` (unTLS . handler) - fromExceptT :: E.ExceptT Error UIO a -> TLS a -fromExceptT = TLS . E.mapExceptT lift +fromExceptT = E.mapExceptT lift runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a) -runTLS s tls = UIO.lift $ R.runReaderT (E.runExceptT (unTLS tls)) s +runTLS s tls = UIO.lift $ R.runReaderT (E.runExceptT tls) s runClient :: Transport -> TLS a -> IO (Either Error a) runClient transport tls = do @@ -149,18 +129,18 @@ newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do return (Session fp creds) getSession :: TLS Session -getSession = TLS $ lift R.ask +getSession = lift R.ask handshake :: TLS () -handshake = withSession F.gnutls_handshake >>= checkRC +handshake = unsafeWithSession F.gnutls_handshake >>= checkRC rehandshake :: TLS () -rehandshake = withSession F.gnutls_rehandshake >>= checkRC +rehandshake = unsafeWithSession F.gnutls_rehandshake >>= checkRC putBytes :: BL.ByteString -> TLS () putBytes = putChunks . BL.toChunks where putChunks chunks = do - maybeErr <- withSession $ \s -> foldM (putChunk s) Nothing chunks + maybeErr <- unsafeWithSession $ \s -> foldM (putChunk s) Nothing chunks case maybeErr of Nothing -> return () Just err -> throwE $ mapError $ fromIntegral err @@ -179,7 +159,7 @@ putBytes = putChunks . BL.toChunks where getBytes :: Integer -> TLS BL.ByteString getBytes count = do - (mbytes, len) <- withSession $ \s -> + (mbytes, len) <- unsafeWithSession $ \s -> F.allocaBytes (fromInteger count) $ \ptr -> do len <- F.gnutls_record_recv s ptr (fromInteger count) bytes <- if len >= 0 @@ -194,7 +174,7 @@ getBytes count = do Nothing -> throwE $ mapError $ fromIntegral len checkPending :: TLS Integer -checkPending = withSession $ \s -> do +checkPending = unsafeWithSession $ \s -> do pending <- F.gnutls_record_check_pending s return $ toInteger pending @@ -227,31 +207,32 @@ data Credentials = Credentials F.CredentialsType (F.ForeignPtr F.Credentials) setCredentials :: Credentials -> TLS () setCredentials (Credentials ctype fp) = do - rc <- withSession $ \s -> + rc <- unsafeWithSession $ \s -> F.withForeignPtr fp $ \ptr -> do F.gnutls_credentials_set s ctype ptr s <- getSession if F.unRC rc == 0 - then liftIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ()))) + then UIO.unsafeFromIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ()))) else checkRC rc certificateCredentials :: TLS Credentials certificateCredentials = do - (rc, ptr) <- liftIO $ F.alloca $ \ptr -> do + (rc, ptr) <- UIO.unsafeFromIO $ F.alloca $ \ptr -> do rc <- F.gnutls_certificate_allocate_credentials ptr ptr' <- if F.unRC rc < 0 then return F.nullPtr else F.peek ptr return (rc, ptr') checkRC rc - fp <- liftIO $ F.newForeignPtr F.gnutls_certificate_free_credentials_funptr ptr + fp <- UIO.unsafeFromIO $ F.newForeignPtr F.gnutls_certificate_free_credentials_funptr ptr return $ Credentials (F.CredentialsType 1) fp -withSession :: (F.Session -> IO a) -> TLS a -withSession io = do +-- | This must only be called with IO actions that do not throw NonPseudoException +unsafeWithSession :: (F.Session -> IO a) -> TLS a +unsafeWithSession io = do s <- getSession - liftIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session + UIO.unsafeFromIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session checkRC :: F.ReturnCode -> TLS () checkRC (F.ReturnCode x) = when (x < 0) $ throwE $ mapError x -- 2.38.5