~singpolyma/network-protocol-xmpp

055a7bfb2e80fbf3e1c12d5ea22b65195cf991e3 — John Millikin 15 years ago 5055a1d
Added TLS support
2 files changed, 99 insertions(+), 60 deletions(-)

M Network/Protocol/XMPP/Client.hs
M Network/Protocol/XMPP/Stream.hs
M Network/Protocol/XMPP/Client.hs => Network/Protocol/XMPP/Client.hs +24 -29
@@ 39,13 39,12 @@ import qualified Network.Protocol.XMPP.Stream as S
import Network.Protocol.XMPP.Stanzas (Stanza)
import Network.Protocol.XMPP.Util (mkElement, mkQName)

data ConnectedClient = ConnectedClient JID S.Stream Handle
data ConnectedClient = ConnectedClient JID S.Stream

data Client = Client {
	 clientJID       :: JID
	,clientServerJID :: JID
	,clientStream    :: S.Stream
	,clientHandle    :: Handle
	 clientJID        :: JID
	,clientServerJID  :: JID
	,clientStream     :: S.Stream
	}

type Username = String


@@ 57,34 56,30 @@ clientConnect jid host port = do
	hSetBuffering handle NoBuffering
	
	stream <- S.beginStream jid handle
	
	-- TODO: TLS support
	
	return $ ConnectedClient jid stream handle
	return $ ConnectedClient jid stream

clientAuthenticate :: ConnectedClient -> JID -> Username -> Password -> IO Client
clientAuthenticate (ConnectedClient serverJID stream h) jid username password = let
	mechanisms = (advertisedMechanisms . S.streamFeatures) stream
	saslMechanism = case bestMechanism mechanisms of
clientAuthenticate (ConnectedClient serverJID stream) jid username password = do
	let mechanisms = (advertisedMechanisms . S.streamFeatures) stream
	let saslMechanism = case bestMechanism mechanisms of
		Nothing -> error "No supported SASL mechanism"
		Just m -> m
	in do
		-- TODO: use detected mechanism
		
		let saslText = concat [(show jid), "\x00", username, "\x00", password]
		let b64Text = encode saslText
		
		S.putTree stream $ mkElement ("", "auth")
			[ ("", "xmlns", "urn:ietf:params:xml:ns:xmpp-sasl")
			 ,("", "mechanism", "PLAIN")]
			[XN.mkText b64Text]
		
		response <- S.getTree stream
		
		-- TODO: check if response is success or failure
		
		newStream <- S.beginStream serverJID h
		return $ Client serverJID jid newStream h
	
	-- TODO: use detected mechanism
	let saslText = concat [(show jid), "\x00", username, "\x00", password]
	let b64Text = encode saslText
	
	S.putTree stream $ mkElement ("", "auth")
		[ ("", "xmlns", "urn:ietf:params:xml:ns:xmpp-sasl")
		 ,("", "mechanism", "PLAIN")]
		[XN.mkText b64Text]
	
	response <- S.getTree stream
	
	-- TODO: check if response is success or failure
	
	newStream <- S.restartStream stream
	return $ Client serverJID jid newStream

clientBind :: Client -> IO JID
clientBind c = do

M Network/Protocol/XMPP/Stream.hs => Network/Protocol/XMPP/Stream.hs +75 -31
@@ 28,6 28,7 @@ module Network.Protocol.XMPP.Stream (
		,FeatureSession
		)
	,beginStream
	,restartStream
	,getTree
	,putTree
	) where


@@ 44,6 45,11 @@ import Text.XML.HXT.Arrow ((>>>), (>>.))
import Data.Tree.NTree.TypeDefs (NTree(NTree))
import qualified Text.XML.HXT.Arrow as A

-- TLS support
import qualified Network.GnuTLS as GnuTLS
import Foreign (allocaBytes)
import Foreign.C (peekCAStringLen)

import Network.Protocol.XMPP.JID (JID)
import Network.Protocol.XMPP.SASL (Mechanism, findMechanism)
import Network.Protocol.XMPP.Util (eventsToTree, mkQName, mkElement)


@@ 52,10 58,11 @@ maxXMPPVersion = XMPPVersion 1 0

data Stream = Stream
	{
		 streamHandle :: IO.Handle
		,streamParser :: XML.Parser
		 streamHandle   :: Handle
		,streamJID      :: JID
		,streamParser   :: XML.Parser
		,streamLanguage :: XMLLanguage
		,streamVersion :: XMPPVersion
		,streamVersion  :: XMPPVersion
		,streamFeatures :: [StreamFeature]
	}



@@ 75,47 82,72 @@ newtype XMLLanguage = XMLLanguage String
data XMPPVersion = XMPPVersion Int Int
	deriving (Show, Eq)

-------------------------------------------------------------------------------
data Handle =
	  PlainHandle IO.Handle
	| SecureHandle (GnuTLS.Session GnuTLS.Client)

------------------------------------------------------------------------------

restartStream :: Stream -> IO Stream
restartStream s = beginStream' (streamJID s) (streamHandle s)

beginStream :: JID -> IO.Handle -> IO Stream
beginStream jid handle = do
	parser <- XML.newParser
beginStream jid rawHandle = do
	plainStream <- beginStream' jid (PlainHandle rawHandle)
	
	putTree plainStream $ mkElement ("", "starttls")
		[("", "xmlns", "urn:ietf:params:xml:ns:xmpp-tls")]
		[]
	getTree plainStream
	
	session <- GnuTLS.tlsClient [
		 GnuTLS.handle GnuTLS.:= rawHandle
		,GnuTLS.priorities GnuTLS.:= [GnuTLS.CrtX509]
		,GnuTLS.credentials GnuTLS.:= GnuTLS.certificateCredentials
		]
	GnuTLS.handshake session
	beginStream' jid (SecureHandle session)

beginStream' :: JID -> Handle -> IO Stream
beginStream' jid h = do
	-- Since only the opening tag should be written, normal XML
	-- serialization cannot be used. Be careful to escape any embedded
	-- attributes.
	IO.hPutStr handle $
	let xmlHeader =
		"<?xml version='1.0'?>\n" ++
		"<stream:stream xmlns='jabber:client'" ++
		" to='" ++ (attrEscapeXml . show) jid ++ "'" ++
		" version='1.0'" ++
		" xmlns:stream='http://etherx.jabber.org/streams'>"
	IO.hFlush handle
	
	[startStreamEvent] <- readEventsUntil startOfStream handle parser 1000
	featureTree <- getTree' handle parser
	return $ beginStream' handle parser startStreamEvent featureTree
	parser <- XML.newParser
	hPutStr h xmlHeader
	[startStreamEvent] <- readEventsUntil startOfStream h parser 1000
	featureTree <- getTree' h parser
	
	let (language, version) = parseStartStream startStreamEvent
	let features = parseFeatures featureTree
	
	return $ Stream h jid parser language version features
	
	where
		streamName = mkQName "http://etherx.jabber.org/streams" "stream"
		
		startOfStream depth event = case (depth, event) of
			(1, (XML.BeginElement streamName _)) -> True
			otherwise -> False

beginStream' handle parser streamStart featureTree = let
	-- TODO: parse from streamStart
	language = XMLLanguage "en"
	version = XMPPVersion 1 0
	
	featuresName = mkQName "http://etherx.jabber.org/streams" "features"
	
	featureRoots = A.runLA (
		A.getChildren
		>>> A.hasQName featuresName) featureTree
	features = case featureRoots of
		[] -> []
		(t:_) -> map parseFeature (A.runLA A.getChildren t)
	
	in Stream handle parser language version features
parseStartStream :: XML.Event -> (XMLLanguage, XMPPVersion)
parseStartStream e = (XMLLanguage "en", XMPPVersion 1 0) -- TODO

parseFeatures :: XmlTree -> [StreamFeature]
parseFeatures t =
	A.runLA (A.getChildren
		>>> A.hasQName featuresName
		>>> A.getChildren
		>>> A.arrL (\t' -> [parseFeature t'])) t
	where
		featuresName = mkQName "http://etherx.jabber.org/streams" "features"

parseFeature :: XmlTree -> StreamFeature
parseFeature t = lookupDef FeatureUnknown qname [


@@ 149,7 181,7 @@ parseFeatureSASL t = let
getTree :: Stream -> IO XmlTree
getTree s = getTree' (streamHandle s) (streamParser s)

getTree' :: IO.Handle -> XML.Parser -> IO XmlTree
getTree' :: Handle -> XML.Parser -> IO XmlTree
getTree' h p = do
	events <- readEventsUntil finished h p 1000
	return $ eventsToTree events


@@ 164,14 196,13 @@ putTree s t = do
	[text] <- A.runX (A.constA root >>> A.writeDocumentToString [
		(A.a_no_xml_pi, "1")
		])
	IO.hPutStr h text
	IO.hFlush h
	hPutStr h text

-------------------------------------------------------------------------------

readEventsUntil :: (Int -> XML.Event -> Bool) -> IO.Handle -> XML.Parser -> Int -> IO [XML.Event]
readEventsUntil :: (Int -> XML.Event -> Bool) -> Handle -> XML.Parser -> Int -> IO [XML.Event]
readEventsUntil done h parser timeout = readEventsUntil' done 0 [] $ do
	char <- IO.hGetChar h
	char <- hGetChar h
	XML.incrementalParse parser [char]

readEventsUntil' done depth accum getEvents = do


@@ 190,3 221,16 @@ readEventsStep done (e:es) depth accum = let
	accum' = accum ++ [e]
	in if done depth' e then (True, depth', accum')
	else readEventsStep done es depth' accum'

-------------------------------------------------------------------------------

hPutStr :: Handle -> String -> IO ()
hPutStr (PlainHandle h) = IO.hPutStr h
hPutStr (SecureHandle h) = GnuTLS.tlsSendString h

hGetChar :: Handle -> IO Char
hGetChar (PlainHandle h) = IO.hGetChar h
hGetChar (SecureHandle h) = allocaBytes 1 $ \ptr -> do
	len <- GnuTLS.tlsRecv h ptr 1
	[char] <- peekCAStringLen (ptr, len)
	return char