From 055a7bfb2e80fbf3e1c12d5ea22b65195cf991e3 Mon Sep 17 00:00:00 2001 From: John Millikin Date: Thu, 18 Jun 2009 05:12:45 +0000 Subject: [PATCH] Added TLS support --- Network/Protocol/XMPP/Client.hs | 53 ++++++++-------- Network/Protocol/XMPP/Stream.hs | 106 ++++++++++++++++++++++---------- 2 files changed, 99 insertions(+), 60 deletions(-) diff --git a/Network/Protocol/XMPP/Client.hs b/Network/Protocol/XMPP/Client.hs index c7ee86c..b408460 100644 --- a/Network/Protocol/XMPP/Client.hs +++ b/Network/Protocol/XMPP/Client.hs @@ -39,13 +39,12 @@ import qualified Network.Protocol.XMPP.Stream as S import Network.Protocol.XMPP.Stanzas (Stanza) import Network.Protocol.XMPP.Util (mkElement, mkQName) -data ConnectedClient = ConnectedClient JID S.Stream Handle +data ConnectedClient = ConnectedClient JID S.Stream data Client = Client { - clientJID :: JID - ,clientServerJID :: JID - ,clientStream :: S.Stream - ,clientHandle :: Handle + clientJID :: JID + ,clientServerJID :: JID + ,clientStream :: S.Stream } type Username = String @@ -57,34 +56,30 @@ clientConnect jid host port = do hSetBuffering handle NoBuffering stream <- S.beginStream jid handle - - -- TODO: TLS support - - return $ ConnectedClient jid stream handle + return $ ConnectedClient jid stream clientAuthenticate :: ConnectedClient -> JID -> Username -> Password -> IO Client -clientAuthenticate (ConnectedClient serverJID stream h) jid username password = let - mechanisms = (advertisedMechanisms . S.streamFeatures) stream - saslMechanism = case bestMechanism mechanisms of +clientAuthenticate (ConnectedClient serverJID stream) jid username password = do + let mechanisms = (advertisedMechanisms . S.streamFeatures) stream + let saslMechanism = case bestMechanism mechanisms of Nothing -> error "No supported SASL mechanism" Just m -> m - in do - -- TODO: use detected mechanism - - let saslText = concat [(show jid), "\x00", username, "\x00", password] - let b64Text = encode saslText - - S.putTree stream $ mkElement ("", "auth") - [ ("", "xmlns", "urn:ietf:params:xml:ns:xmpp-sasl") - ,("", "mechanism", "PLAIN")] - [XN.mkText b64Text] - - response <- S.getTree stream - - -- TODO: check if response is success or failure - - newStream <- S.beginStream serverJID h - return $ Client serverJID jid newStream h + + -- TODO: use detected mechanism + let saslText = concat [(show jid), "\x00", username, "\x00", password] + let b64Text = encode saslText + + S.putTree stream $ mkElement ("", "auth") + [ ("", "xmlns", "urn:ietf:params:xml:ns:xmpp-sasl") + ,("", "mechanism", "PLAIN")] + [XN.mkText b64Text] + + response <- S.getTree stream + + -- TODO: check if response is success or failure + + newStream <- S.restartStream stream + return $ Client serverJID jid newStream clientBind :: Client -> IO JID clientBind c = do diff --git a/Network/Protocol/XMPP/Stream.hs b/Network/Protocol/XMPP/Stream.hs index c067e1c..1f9c39e 100644 --- a/Network/Protocol/XMPP/Stream.hs +++ b/Network/Protocol/XMPP/Stream.hs @@ -28,6 +28,7 @@ module Network.Protocol.XMPP.Stream ( ,FeatureSession ) ,beginStream + ,restartStream ,getTree ,putTree ) where @@ -44,6 +45,11 @@ import Text.XML.HXT.Arrow ((>>>), (>>.)) import Data.Tree.NTree.TypeDefs (NTree(NTree)) import qualified Text.XML.HXT.Arrow as A +-- TLS support +import qualified Network.GnuTLS as GnuTLS +import Foreign (allocaBytes) +import Foreign.C (peekCAStringLen) + import Network.Protocol.XMPP.JID (JID) import Network.Protocol.XMPP.SASL (Mechanism, findMechanism) import Network.Protocol.XMPP.Util (eventsToTree, mkQName, mkElement) @@ -52,10 +58,11 @@ maxXMPPVersion = XMPPVersion 1 0 data Stream = Stream { - streamHandle :: IO.Handle - ,streamParser :: XML.Parser + streamHandle :: Handle + ,streamJID :: JID + ,streamParser :: XML.Parser ,streamLanguage :: XMLLanguage - ,streamVersion :: XMPPVersion + ,streamVersion :: XMPPVersion ,streamFeatures :: [StreamFeature] } @@ -75,47 +82,72 @@ newtype XMLLanguage = XMLLanguage String data XMPPVersion = XMPPVersion Int Int deriving (Show, Eq) -------------------------------------------------------------------------------- +data Handle = + PlainHandle IO.Handle + | SecureHandle (GnuTLS.Session GnuTLS.Client) + +------------------------------------------------------------------------------ + +restartStream :: Stream -> IO Stream +restartStream s = beginStream' (streamJID s) (streamHandle s) beginStream :: JID -> IO.Handle -> IO Stream -beginStream jid handle = do - parser <- XML.newParser +beginStream jid rawHandle = do + plainStream <- beginStream' jid (PlainHandle rawHandle) + + putTree plainStream $ mkElement ("", "starttls") + [("", "xmlns", "urn:ietf:params:xml:ns:xmpp-tls")] + [] + getTree plainStream + session <- GnuTLS.tlsClient [ + GnuTLS.handle GnuTLS.:= rawHandle + ,GnuTLS.priorities GnuTLS.:= [GnuTLS.CrtX509] + ,GnuTLS.credentials GnuTLS.:= GnuTLS.certificateCredentials + ] + GnuTLS.handshake session + beginStream' jid (SecureHandle session) + +beginStream' :: JID -> Handle -> IO Stream +beginStream' jid h = do -- Since only the opening tag should be written, normal XML -- serialization cannot be used. Be careful to escape any embedded -- attributes. - IO.hPutStr handle $ + let xmlHeader = "\n" ++ "" - IO.hFlush handle - [startStreamEvent] <- readEventsUntil startOfStream handle parser 1000 - featureTree <- getTree' handle parser - return $ beginStream' handle parser startStreamEvent featureTree + parser <- XML.newParser + hPutStr h xmlHeader + [startStreamEvent] <- readEventsUntil startOfStream h parser 1000 + featureTree <- getTree' h parser + + let (language, version) = parseStartStream startStreamEvent + let features = parseFeatures featureTree + + return $ Stream h jid parser language version features + where streamName = mkQName "http://etherx.jabber.org/streams" "stream" + startOfStream depth event = case (depth, event) of (1, (XML.BeginElement streamName _)) -> True otherwise -> False -beginStream' handle parser streamStart featureTree = let - -- TODO: parse from streamStart - language = XMLLanguage "en" - version = XMPPVersion 1 0 - - featuresName = mkQName "http://etherx.jabber.org/streams" "features" - - featureRoots = A.runLA ( - A.getChildren - >>> A.hasQName featuresName) featureTree - features = case featureRoots of - [] -> [] - (t:_) -> map parseFeature (A.runLA A.getChildren t) - - in Stream handle parser language version features +parseStartStream :: XML.Event -> (XMLLanguage, XMPPVersion) +parseStartStream e = (XMLLanguage "en", XMPPVersion 1 0) -- TODO + +parseFeatures :: XmlTree -> [StreamFeature] +parseFeatures t = + A.runLA (A.getChildren + >>> A.hasQName featuresName + >>> A.getChildren + >>> A.arrL (\t' -> [parseFeature t'])) t + where + featuresName = mkQName "http://etherx.jabber.org/streams" "features" parseFeature :: XmlTree -> StreamFeature parseFeature t = lookupDef FeatureUnknown qname [ @@ -149,7 +181,7 @@ parseFeatureSASL t = let getTree :: Stream -> IO XmlTree getTree s = getTree' (streamHandle s) (streamParser s) -getTree' :: IO.Handle -> XML.Parser -> IO XmlTree +getTree' :: Handle -> XML.Parser -> IO XmlTree getTree' h p = do events <- readEventsUntil finished h p 1000 return $ eventsToTree events @@ -164,14 +196,13 @@ putTree s t = do [text] <- A.runX (A.constA root >>> A.writeDocumentToString [ (A.a_no_xml_pi, "1") ]) - IO.hPutStr h text - IO.hFlush h + hPutStr h text ------------------------------------------------------------------------------- -readEventsUntil :: (Int -> XML.Event -> Bool) -> IO.Handle -> XML.Parser -> Int -> IO [XML.Event] +readEventsUntil :: (Int -> XML.Event -> Bool) -> Handle -> XML.Parser -> Int -> IO [XML.Event] readEventsUntil done h parser timeout = readEventsUntil' done 0 [] $ do - char <- IO.hGetChar h + char <- hGetChar h XML.incrementalParse parser [char] readEventsUntil' done depth accum getEvents = do @@ -190,3 +221,16 @@ readEventsStep done (e:es) depth accum = let accum' = accum ++ [e] in if done depth' e then (True, depth', accum') else readEventsStep done es depth' accum' + +------------------------------------------------------------------------------- + +hPutStr :: Handle -> String -> IO () +hPutStr (PlainHandle h) = IO.hPutStr h +hPutStr (SecureHandle h) = GnuTLS.tlsSendString h + +hGetChar :: Handle -> IO Char +hGetChar (PlainHandle h) = IO.hGetChar h +hGetChar (SecureHandle h) = allocaBytes 1 $ \ptr -> do + len <- GnuTLS.tlsRecv h ptr 1 + [char] <- peekCAStringLen (ptr, len) + return char -- 2.38.4