From 99e9b2dcab8651cb999967cd87846a6d3528e24b Mon Sep 17 00:00:00 2001 From: Stephen Paul Weber Date: Mon, 15 Feb 2021 21:51:41 -0500 Subject: [PATCH] Switch TLS to TLST to allow any Unexceptional base monad --- lib/Network/Protocol/TLS/GNU.hs | 64 ++++++++++++++++----------------- 1 file changed, 30 insertions(+), 34 deletions(-) diff --git a/lib/Network/Protocol/TLS/GNU.hs b/lib/Network/Protocol/TLS/GNU.hs index 0f99536..711e263 100644 --- a/lib/Network/Protocol/TLS/GNU.hs +++ b/lib/Network/Protocol/TLS/GNU.hs @@ -17,10 +17,9 @@ module Network.Protocol.TLS.GNU ( TLS + , TLST , Session , Error (..) - , throwE - , fromExceptT , runTLS , runTLS' @@ -46,7 +45,6 @@ import Control.Monad (when, foldM, foldM_) import Control.Monad.Trans.Class (lift) import qualified Control.Monad.Trans.Except as E import qualified Control.Monad.Trans.Reader as R -import Control.Monad.IO.Class (liftIO) import qualified Data.ByteString as B import qualified Data.ByteString.Lazy as BL import qualified Data.ByteString.Unsafe as B @@ -56,7 +54,7 @@ import qualified Foreign.C as F import Foreign.Concurrent as FC import qualified System.IO as IO import System.IO.Unsafe (unsafePerformIO) -import UnexceptionalIO.Trans (UIO, Unexceptional) +import UnexceptionalIO.Trans (Unexceptional) import qualified UnexceptionalIO.Trans as UIO import qualified Network.Protocol.TLS.GNU.Foreign as F @@ -68,10 +66,10 @@ globalInitMVar :: M.MVar () {-# NOINLINE globalInitMVar #-} globalInitMVar = unsafePerformIO $ M.newMVar () -globalInit :: E.ExceptT Error IO () +globalInit :: (Unexceptional m) => E.ExceptT Error m () globalInit = do let init_ = M.withMVar globalInitMVar $ \_ -> F.gnutls_global_init - F.ReturnCode rc <- liftIO init_ + F.ReturnCode rc <- UIO.unsafeFromIO init_ when (rc < 0) $ E.throwE $ mapError rc globalDeinit :: IO () @@ -90,33 +88,31 @@ data Session = Session , sessionCredentials :: IORef [F.ForeignPtr F.Credentials] } -type TLS a = E.ExceptT Error (R.ReaderT Session UIO) a +type TLS a = TLST IO a +type TLST m a = E.ExceptT Error (R.ReaderT Session m) a -throwE :: Error -> TLS a -throwE = fromExceptT . E.throwE - -fromExceptT :: E.ExceptT Error UIO a -> TLS a -fromExceptT = E.mapExceptT lift - -runTLS :: (Unexceptional m) => Session -> TLS a -> m (Either Error a) +runTLS :: (Unexceptional m) => Session -> TLST m a -> m (Either Error a) runTLS s = E.runExceptT . runTLS' s -runTLS' :: (Unexceptional m) => Session -> TLS a -> E.ExceptT Error m a -runTLS' s = E.mapExceptT (UIO.lift . flip R.runReaderT s) +runTLS' :: Session -> TLST m a -> E.ExceptT Error m a +runTLS' s = E.mapExceptT (flip R.runReaderT s) -runClient :: Transport -> TLS a -> IO (Either Error a) +runClient :: (Unexceptional m) => Transport -> TLST m a -> m (Either Error a) runClient transport tls = do eitherSession <- newSession transport (F.ConnectionEnd 2) case eitherSession of Left err -> return (Left err) Right session -> runTLS session tls -newSession :: Transport -> F.ConnectionEnd -> IO (Either Error Session) -newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do +newSession :: (Unexceptional m) => + Transport + -> F.ConnectionEnd + -> m (Either Error Session) +newSession transport end = UIO.unsafeFromIO . F.alloca $ \sPtr -> E.runExceptT $ do globalInit - F.ReturnCode rc <- liftIO $ F.gnutls_init sPtr end + F.ReturnCode rc <- UIO.unsafeFromIO $ F.gnutls_init sPtr end when (rc < 0) $ E.throwE $ mapError rc - liftIO $ do + UIO.unsafeFromIO $ do ptr <- F.peek sPtr let session = F.Session ptr push <- F.wrapTransportFunc (pushImpl transport) @@ -132,22 +128,22 @@ newSession transport end = F.alloca $ \sPtr -> E.runExceptT $ do F.freeHaskellFunPtr pull return (Session fp creds) -getSession :: TLS Session +getSession :: (Monad m) => TLST m Session getSession = lift R.ask -handshake :: TLS () +handshake :: (Unexceptional m) => TLST m () handshake = unsafeWithSession F.gnutls_handshake >>= checkRC -rehandshake :: TLS () +rehandshake :: (Unexceptional m) => TLST m () rehandshake = unsafeWithSession F.gnutls_rehandshake >>= checkRC -putBytes :: BL.ByteString -> TLS () +putBytes :: (Unexceptional m) => BL.ByteString -> TLST m () putBytes = putChunks . BL.toChunks where putChunks chunks = do maybeErr <- unsafeWithSession $ \s -> foldM (putChunk s) Nothing chunks case maybeErr of Nothing -> return () - Just err -> throwE $ mapError $ fromIntegral err + Just err -> E.throwE $ mapError $ fromIntegral err putChunk s Nothing chunk = B.unsafeUseAsCStringLen chunk $ uncurry loop where loop ptr len = do @@ -161,7 +157,7 @@ putBytes = putChunks . BL.toChunks where putChunk _ err _ = return err -getBytes :: Integer -> TLS BL.ByteString +getBytes :: (Unexceptional m) => Integer -> TLST m BL.ByteString getBytes count = do (mbytes, len) <- unsafeWithSession $ \s -> F.allocaBytes (fromInteger count) $ \ptr -> do @@ -175,9 +171,9 @@ getBytes count = do case mbytes of Just bytes -> return bytes - Nothing -> throwE $ mapError $ fromIntegral len + Nothing -> E.throwE $ mapError $ fromIntegral len -checkPending :: TLS Integer +checkPending :: (Unexceptional m) => TLST m Integer checkPending = unsafeWithSession $ \s -> do pending <- F.gnutls_record_check_pending s return $ toInteger pending @@ -209,7 +205,7 @@ handleTransport h = Transport (BL.hPut h) (BL.hGet h . fromInteger) data Credentials = Credentials F.CredentialsType (F.ForeignPtr F.Credentials) -setCredentials :: Credentials -> TLS () +setCredentials :: (Unexceptional m) => Credentials -> TLST m () setCredentials (Credentials ctype fp) = do rc <- unsafeWithSession $ \s -> F.withForeignPtr fp $ \ptr -> do @@ -220,7 +216,7 @@ setCredentials (Credentials ctype fp) = do then UIO.unsafeFromIO (atomicModifyIORef (sessionCredentials s) (\creds -> (fp:creds, ()))) else checkRC rc -certificateCredentials :: TLS Credentials +certificateCredentials :: (Unexceptional m) => TLST m Credentials certificateCredentials = do (rc, ptr) <- UIO.unsafeFromIO $ F.alloca $ \ptr -> do rc <- F.gnutls_certificate_allocate_credentials ptr @@ -233,13 +229,13 @@ certificateCredentials = do return $ Credentials (F.CredentialsType 1) fp -- | This must only be called with IO actions that do not throw NonPseudoException -unsafeWithSession :: (F.Session -> IO a) -> TLS a +unsafeWithSession :: (Unexceptional m) => (F.Session -> IO a) -> TLST m a unsafeWithSession io = do s <- getSession UIO.unsafeFromIO $ F.withForeignPtr (sessionPtr s) $ io . F.Session -checkRC :: F.ReturnCode -> TLS () -checkRC (F.ReturnCode x) = when (x < 0) $ throwE $ mapError x +checkRC :: (Monad m) => F.ReturnCode -> TLST m () +checkRC (F.ReturnCode x) = when (x < 0) $ E.throwE $ mapError x mapError :: F.CInt -> Error mapError = Error . toInteger -- 2.38.5