{- Copyright (C) 2009 John Millikin <jmillikin@gmail.com>
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
-}
-- | Opening and restarting XMPP client streams over a plain or
-- TLS-secured handle, and reading and writing whole XML trees on an
-- open stream.
--
-- The 'Stream' type is exported with only its language, version and
-- feature accessors; 'StreamFeature' is exported without its
-- @FeatureUnknown@ constructor.
module Network.Protocol.XMPP.Stream (
Stream (
streamLanguage
,streamVersion
,streamFeatures
)
,StreamFeature (
FeatureStartTLS
,FeatureSASL
,FeatureRegister
,FeatureBind
,FeatureSession
)
,beginStream
,restartStream
,getTree
,putTree
) where
import qualified System.IO as IO
import System.IO.Error (eofErrorType, mkIOError)
import Data.AssocList (lookupDef)
import Data.Char (toUpper)
-- XML Parsing
import Text.XML.HXT.Arrow ((>>>))
import qualified Text.XML.HXT.Arrow as A
import qualified Text.XML.HXT.DOM.Interface as DOM
import qualified Text.XML.HXT.DOM.XmlNode as XN
import qualified Text.XML.LibXML.SAX as SAX
-- TLS support
import qualified Network.GnuTLS as GnuTLS
import Foreign (allocaBytes)
import Foreign.C (peekCAStringLen)
import Network.Protocol.XMPP.JID (JID, jidFormat)
import qualified Network.Protocol.XMPP.Util as Util
-- | The value @XMPPVersion 1 0@, matching the @version='1.0'@ attribute
-- written in the stream header.
--
-- NOTE(review): nothing in this part of the file refers to this constant;
-- presumably it is intended as the highest version this client supports --
-- confirm before relying on it.
maxXMPPVersion :: XMPPVersion
maxXMPPVersion = XMPPVersion 1 0
-- | State of an open XMPP stream: the transport it runs over, the JID it
-- was opened for, the incremental parser consuming the peer's XML, and
-- the values obtained while the stream was being opened.
data Stream = Stream
    { streamHandle   :: Handle
      -- ^ Transport used for every read and write on this stream.
    , streamJID      :: JID
      -- ^ JID written into the @to@ attribute of the stream header.
    , streamParser   :: SAX.Parser
      -- ^ Parser that is fed the characters read from the handle.
    , streamLanguage :: XMLLanguage
      -- ^ Language produced by @parseStartStream@.
    , streamVersion  :: XMPPVersion
      -- ^ Version produced by @parseStartStream@.
    , streamFeatures :: [StreamFeature]
      -- ^ Features parsed from the first tree read after the header.
    }
-- | One child element of the server's @features@ element, classified by
-- its namespace and local name.
data StreamFeature
    = FeatureStartTLS Bool
      -- ^ @starttls@ element; the flag is filled in by @parseFeatureTLS@.
    | FeatureSASL [String]
      -- ^ @mechanisms@ element, carrying the upper-cased mechanism names.
    | FeatureRegister
      -- ^ @register@ element (in-band registration).
    | FeatureBind
      -- ^ @bind@ element (resource binding).
    | FeatureSession
      -- ^ @session@ element.
    | FeatureUnknown DOM.XmlTree
      -- ^ Any other node, kept verbatim.
    deriving (Show, Eq)
-- | Language identifier of a stream, wrapped in its own type so that it
-- cannot be mixed up with other strings.
newtype XMLLanguage
    = XMLLanguage String
    deriving (Show, Eq)
-- | Protocol version of a stream as a pair of integers; the header's
-- @version='1.0'@ is represented as @XMPPVersion 1 0@ (presumably major
-- number first, then minor).
data XMPPVersion
    = XMPPVersion Int Int
    deriving (Show, Eq)
-- | Transport underneath a stream.
data Handle
    = PlainHandle IO.Handle
      -- ^ Unencrypted: characters are read from and written to the handle
      -- directly.
    | SecureHandle IO.Handle (GnuTLS.Session GnuTLS.Client)
      -- ^ TLS: data passes through the GnuTLS session; the raw handle is
      -- kept so that reads can wait for input on it.
------------------------------------------------------------------------------
-- | Open a fresh stream on the transport of an existing one, for the same
-- JID. A new header is sent and a new parser is created; the old
-- 'Stream' value is not modified.
restartStream :: Stream -> IO Stream
restartStream s = beginStream' jid transport
  where
    jid       = streamJID s
    transport = streamHandle s
-- | Open an XMPP stream to @jid@'s server over @rawHandle@ and upgrade it
-- to TLS.
--
-- Steps, in order:
--
-- 1. Buffering on @rawHandle@ is switched off.
-- 2. A plain-text stream is opened and @\<starttls/\>@ is sent. This is
--    done unconditionally; the advertised features are not consulted.
-- 3. The server's reply is read. Unless it contains a @\<proceed/\>@
--    element in the TLS namespace, an 'IOError' is raised instead of
--    attempting a handshake the server never agreed to.
-- 4. The GnuTLS handshake runs on the raw handle and a new stream is
--    opened on top of the TLS session.
--
-- Returns the stream opened over TLS.
beginStream :: JID -> IO.Handle -> IO Stream
beginStream jid rawHandle = do
    IO.hSetBuffering rawHandle IO.NoBuffering

    plainStream <- beginStream' jid (PlainHandle rawHandle)

    putTree plainStream $ Util.mkElement ("", "starttls")
        [("", "xmlns", "urn:ietf:params:xml:ns:xmpp-tls")]
        []
    reply <- getTree plainStream

    -- 'A.deep' tests the node itself before descending, so the check
    -- works whether or not the tree is wrapped in a root node.
    if null (A.runLA (A.deep (A.hasQName proceedName)) reply)
        then ioError $ userError
            "beginStream: server did not answer <starttls/> with <proceed/>"
        else return ()

    session <- GnuTLS.tlsClient
        [ GnuTLS.handle GnuTLS.:= rawHandle
        , GnuTLS.priorities GnuTLS.:= [GnuTLS.CrtX509]
        , GnuTLS.credentials GnuTLS.:= GnuTLS.certificateCredentials
        ]
    GnuTLS.handshake session

    beginStream' jid (SecureHandle rawHandle session)
  where
    proceedName = Util.mkQName "urn:ietf:params:xml:ns:xmpp-tls" "proceed"
-- | Open a stream on an already established transport.
--
-- Sends the XML declaration and the opening @stream:stream@ tag, reads
-- events until the peer's own opening @stream@ tag has been seen, then
-- reads one complete tree and interprets it as the feature list.
beginStream' :: JID -> Handle -> IO Stream
beginStream' jid h = do
    parser <- SAX.mkParser
    hPutStr h openingTag

    headerEvents <- readEventsUntil isStreamStart h parser
    featureTree  <- getTree' h parser

    -- The reader stops on the event that satisfied the predicate, so the
    -- last collected event is the opening stream tag.
    let (language, version) = parseStartStream (last headerEvents)
    return Stream
        { streamHandle   = h
        , streamJID      = jid
        , streamParser   = parser
        , streamLanguage = language
        , streamVersion  = version
        , streamFeatures = parseFeatures featureTree
        }
  where
    -- Only the opening tag may be written, so the usual XML serializer
    -- cannot be used here; the JID is escaped by hand instead.
    openingTag = concat
        [ "<?xml version='1.0'?>\n"
        , "<stream:stream xmlns='jabber:client'"
        , " to='", DOM.attrEscapeXml (jidFormat jid), "'"
        , " version='1.0'"
        , " xmlns:stream='http://etherx.jabber.org/streams'>"
        ]

    streamName = Util.mkQName "http://etherx.jabber.org/streams" "stream"

    isStreamStart 1 (SAX.BeginElement elemName _) =
        streamName == Util.convertQName elemName
    isStreamStart _ _ = False
-- | Language and version of the peer's opening stream tag.
--
-- TODO: the event is not inspected yet; language @"en"@ and version 1.0
-- are returned for every input.
parseStartStream :: SAX.Event -> (XMLLanguage, XMPPVersion)
parseStartStream _ = (language, version)
  where
    language = XMLLanguage "en"
    version  = XMPPVersion 1 0
-- | Classify every child node of each @features@ element found among the
-- children of the given tree. Returns an empty list when there is no
-- such element.
parseFeatures :: DOM.XmlTree -> [StreamFeature]
parseFeatures = map parseFeature . A.runLA featureNodes
  where
    featureNodes =
        A.getChildren
            >>> A.hasQName featuresName
            >>> A.getChildren

    featuresName = Util.mkQName "http://etherx.jabber.org/streams" "features"
-- | Classify one node by its (namespace URI, local name) pair. Nodes
-- without a name, or with a name that is not in the table, become
-- 'FeatureUnknown' carrying the node itself.
parseFeature :: DOM.XmlTree -> StreamFeature
parseFeature t = classify t
  where
    classify = lookupDef FeatureUnknown (nameOf t) knownFeatures

    knownFeatures =
        [ (("urn:ietf:params:xml:ns:xmpp-tls", "starttls"), parseFeatureTLS)
        , (("urn:ietf:params:xml:ns:xmpp-sasl", "mechanisms"), parseFeatureSASL)
        , (("http://jabber.org/features/iq-register", "register"), const FeatureRegister)
        , (("urn:ietf:params:xml:ns:xmpp-bind", "bind"), const FeatureBind)
        , (("urn:ietf:params:xml:ns:xmpp-session", "session"), const FeatureSession)
        ]

    -- Unnamed nodes get a pair that matches no table entry.
    nameOf node = case XN.getName node of
        Nothing -> ("", "")
        Just n  -> (DOM.namespaceUri n, DOM.localPart n)
-- | Interpret a @starttls@ feature element.
--
-- The flag of the resulting 'FeatureStartTLS' is 'True' exactly when the
-- element has a @required@ child in the TLS namespace, which is how a
-- server announces that TLS negotiation is mandatory.
parseFeatureTLS :: DOM.XmlTree -> StreamFeature
parseFeatureTLS t = FeatureStartTLS required
  where
    requiredName = Util.mkQName "urn:ietf:params:xml:ns:xmpp-tls" "required"
    required = not . null $ A.runLA
        (A.getChildren >>> A.hasQName requiredName) t
-- | Interpret a @mechanisms@ feature element: collect the text of every
-- @mechanism@ child and upper-case each name.
parseFeatureSASL :: DOM.XmlTree -> StreamFeature
parseFeatureSASL t = FeatureSASL [map toUpper name | name <- names]
  where
    names = A.runLA mechanismText t

    mechanismText =
        A.getChildren
            >>> A.hasQName mechanismName
            >>> A.getChildren
            >>> A.getText

    mechanismName = Util.mkQName "urn:ietf:params:xml:ns:xmpp-sasl" "mechanism"
-------------------------------------------------------------------------------
-- | Read the next complete XML tree from an open stream, using the
-- stream's own handle and parser.
getTree :: Stream -> IO DOM.XmlTree
getTree s = getTree' transport parser
  where
    transport = streamHandle s
    parser    = streamParser s
-- | Read events until an element closes at depth 0, counted from where
-- this read started, then build a tree from the collected events.
getTree' :: Handle -> SAX.Parser -> IO DOM.XmlTree
getTree' h p = fmap Util.eventsToTree (readEventsUntil treeClosed h p)
  where
    treeClosed depth event = case (depth, event) of
        (0, SAX.EndElement _) -> True
        _                     -> False
-- | Serialize one XML tree, without an XML declaration, and write it to
-- the stream's handle.
putTree :: Stream -> DOM.XmlTree -> IO ()
putTree s t = do
    -- The serializer is expected to produce exactly one string.
    [text] <- A.runX (A.constA document >>> A.writeDocumentToString options)
    hPutStr (streamHandle s) text
  where
    document = XN.mkRoot [] [t]
    options  = [(A.a_no_xml_pi, "1")]
-------------------------------------------------------------------------------
-- | Feed the parser one character at a time from the handle, collecting
-- events until @done@ accepts one.
--
-- @done@ is given the element depth after the event has been applied,
-- counted from 0 at the start of this call, and the event itself. The
-- result holds every event seen, ending with the accepted one.
readEventsUntil :: (Int -> SAX.Event -> Bool) -> Handle -> SAX.Parser -> IO [SAX.Event]
readEventsUntil done h parser = readEventsUntil' done 0 [] nextEvents
  where
    nextEvents = hGetChar h >>= \c -> SAX.parse parser [c] False
-- | Worker for 'readEventsUntil': repeatedly run @getEvents@ and fold
-- each batch into the depth and accumulator until the predicate fires,
-- then return the accumulated events.
readEventsUntil' :: (Int -> SAX.Event -> Bool) -> Int -> [SAX.Event] -> IO [SAX.Event] -> IO [SAX.Event]
readEventsUntil' done depth accum getEvents =
    getEvents >>= \batch ->
        case readEventsStep done batch depth accum of
            (True, _, collected) ->
                return collected
            (False, depth', collected) ->
                readEventsUntil' done depth' collected getEvents
-- | Fold one batch of events into the running depth and accumulator.
--
-- Each event is appended to the accumulator and moves the depth up by
-- one for a begin-element, down by one for an end-element, and leaves it
-- alone otherwise. The fold stops at the first event accepted by @done@,
-- which sees the updated depth; later events in the batch are dropped.
--
-- Returns whether the predicate fired, the resulting depth, and the
-- resulting accumulator.
readEventsStep :: (Int -> SAX.Event -> Bool) -> [SAX.Event] -> Int -> [SAX.Event] -> (Bool, Int, [SAX.Event])
readEventsStep _ [] depth accum = (False, depth, accum)
readEventsStep done (e:es) depth accum
    | done depth' e = (True, depth', accum')
    | otherwise     = readEventsStep done es depth' accum'
  where
    depth' = depth + delta e
    accum' = accum ++ [e]

    delta (SAX.BeginElement _ _) = 1
    delta (SAX.EndElement _)     = -1
    delta _                      = 0
-------------------------------------------------------------------------------
-- | Write a string to the transport: straight to the handle when it is
-- plain, through the GnuTLS session when it is secured.
hPutStr :: Handle -> String -> IO ()
hPutStr transport = case transport of
    PlainHandle h          -> IO.hPutStr h
    SecureHandle _ session -> GnuTLS.tlsSendString session
-- | Read one character from the transport.
--
-- A plain handle is read directly. For a secured handle, if the TLS
-- session has no pending data the call first blocks until the raw handle
-- has input, then receives a single byte from the session.
--
-- Raises an end-of-file 'IOError' when the TLS read does not deliver
-- exactly one byte, so both kinds of handle report a closed connection
-- as EOF rather than as a pattern-match failure.
-- NOTE(review): this treats every short read as EOF -- confirm that the
-- GnuTLS binding reports other failures by raising its own errors.
hGetChar :: Handle -> IO Char
hGetChar (PlainHandle h) = IO.hGetChar h
hGetChar (SecureHandle h session) = allocaBytes 1 $ \ptr -> do
    pending <- GnuTLS.tlsCheckPending session
    if pending == 0
        -- A negative timeout waits without limit.
        then IO.hWaitForInput h (-1) >> return ()
        else return ()

    len <- GnuTLS.tlsRecv session ptr 1
    received <- peekCAStringLen (ptr, len)
    case received of
        [char] -> return char
        _      -> ioError $ mkIOError eofErrorType
            "Network.Protocol.XMPP.Stream.hGetChar" (Just h) Nothing