Source code for twomemo.etree

import base64
from typing import Optional, Set, Tuple, cast
import xml.etree.ElementTree as ET

from omemo import DeviceList, EncryptedKeyMaterial, KeyExchange, Message, SignedLabel
import x3dh
try:
    import xmlschema
except ImportError as e:
    raise ImportError(
        "Optional dependency xmlschema not found. Please install xmlschema, or install this package using"
        " `pip install twomemo[xml]`, to use the ElementTree-based XML serialization/parser helpers."
    ) from e

from .twomemo import NAMESPACE, BundleImpl, ContentImpl, EncryptedKeyMaterialImpl, KeyExchangeImpl


# Public API of this module: paired serialize/parse helpers for device lists, bundles and messages.
__all__ = [
    "serialize_device_list", "parse_device_list",
    "serialize_bundle", "parse_bundle",
    "serialize_message", "parse_message",
]


# ElementTree's "{namespace-uri}" tag prefix for the namespace imported from .twomemo; prepended to every tag name
# this module creates or looks up.
NS = f"{{{NAMESPACE}}}"


# XML schema for the device list (<devices> root element), compiled once at import time. parse_device_list validates
# its input against it before reading any attributes.
# Each <device> must carry an "id" (xs:unsignedInt); "label" (xs:string) and "labelsig" (xs:base64Binary) are
# optional and independent of each other at the schema level. A <devices> element with no children is valid.
# NOTE(review): presumably transcribed from the OMEMO 2 specification (XEP-0384) — keep in sync with it.
DEVICE_LIST_SCHEMA = xmlschema.XMLSchema("""<?xml version='1.0' encoding='UTF-8'?>
<xs:schema xmlns:xs='http://www.w3.org/2001/XMLSchema'
           targetNamespace='urn:xmpp:omemo:2'
           xmlns='urn:xmpp:omemo:2'>

    <xs:element name='devices'>
        <xs:complexType>
            <xs:sequence minOccurs='0' maxOccurs='unbounded'>
                <xs:element ref='device'/>
            </xs:sequence>
        </xs:complexType>
    </xs:element>

    <xs:element name='device'>
        <xs:complexType>
            <xs:attribute name='id' type='xs:unsignedInt' use='required'/>
            <xs:attribute name='label' type='xs:string'/>
            <xs:attribute name='labelsig' type='xs:base64Binary'/>
        </xs:complexType>
    </xs:element>
</xs:schema>
""")


# XML schema for a bundle (<bundle> root element), compiled once at import time. parse_bundle validates its input
# against it before reading any data.
# <xs:all> requires exactly one each of <spk>, <spks>, <ik> and <prekeys>, in any order. <spk> and every <pk> carry a
# required "id" (xs:unsignedInt) with base64 text content. <prekeys> must contain at least one <pk>, because the
# sequence has the default minOccurs of 1.
# NOTE(review): presumably transcribed from the OMEMO 2 specification (XEP-0384) — keep in sync with it.
BUNDLE_SCHEMA = xmlschema.XMLSchema("""<?xml version='1.0' encoding='UTF-8'?>
<xs:schema xmlns:xs='http://www.w3.org/2001/XMLSchema'
           targetNamespace='urn:xmpp:omemo:2'
           xmlns='urn:xmpp:omemo:2'>

    <xs:element name='bundle'>
        <xs:complexType>
            <xs:all>
                <xs:element ref='spk'/>
                <xs:element ref='spks'/>
                <xs:element ref='ik'/>
                <xs:element ref='prekeys'/>
            </xs:all>
        </xs:complexType>
    </xs:element>

    <xs:element name='spk'>
        <xs:complexType>
            <xs:simpleContent>
                <xs:extension base='xs:base64Binary'>
                    <xs:attribute name='id' type='xs:unsignedInt' use='required'/>
                </xs:extension>
            </xs:simpleContent>
        </xs:complexType>
    </xs:element>

    <xs:element name='spks' type='xs:base64Binary'/>
    <xs:element name='ik' type='xs:base64Binary'/>

    <xs:element name='prekeys'>
        <xs:complexType>
            <xs:sequence maxOccurs='unbounded'>
                <xs:element ref='pk'/>
            </xs:sequence>
        </xs:complexType>
    </xs:element>

    <xs:element name='pk'>
        <xs:complexType>
            <xs:simpleContent>
                <xs:extension base='xs:base64Binary'>
                    <xs:attribute name='id' type='xs:unsignedInt' use='required'/>
                </xs:extension>
            </xs:simpleContent>
        </xs:complexType>
    </xs:element>
</xs:schema>
""")


# XML schema for an encrypted message (<encrypted> root element), compiled once at import time. parse_message
# validates its input against it before reading any data.
# <encrypted> holds exactly one <header> and at most one <payload>. <header> holds one or more <keys>, each with a
# required "jid". Each <keys> holds one or more <key>, each with a required "rid" and an optional xs:boolean "kex".
# Note that "sid" on <header> is NOT marked required here, so schema validation alone does not guarantee it exists.
# Likewise, xs:base64Binary accepts an empty value, so <key/> and <payload/> with no text are valid.
# NOTE(review): presumably transcribed from the OMEMO 2 specification (XEP-0384) — keep in sync with it.
MESSAGE_SCHEMA = xmlschema.XMLSchema("""<?xml version='1.0' encoding='UTF-8'?>
<xs:schema xmlns:xs='http://www.w3.org/2001/XMLSchema'
           targetNamespace='urn:xmpp:omemo:2'
           xmlns='urn:xmpp:omemo:2'>

    <xs:element name='encrypted'>
        <xs:complexType>
            <xs:all>
                <xs:element ref='header'/>
                <xs:element ref='payload' minOccurs='0' maxOccurs='1'/>
            </xs:all>
        </xs:complexType>
    </xs:element>

    <xs:element name='payload' type='xs:base64Binary'/>

    <xs:element name='header'>
        <xs:complexType>
            <xs:sequence maxOccurs='unbounded'>
                <xs:element ref='keys'/>
            </xs:sequence>
            <xs:attribute name='sid' type='xs:unsignedInt'/>
        </xs:complexType>
    </xs:element>

    <xs:element name='keys'>
        <xs:complexType>
            <xs:sequence maxOccurs='unbounded'>
                <xs:element ref='key'/>
            </xs:sequence>
            <xs:attribute name='jid' type='xs:string' use='required'/>
        </xs:complexType>
    </xs:element>

    <xs:element name='key'>
        <xs:complexType>
            <xs:simpleContent>
                <xs:extension base='xs:base64Binary'>
                    <xs:attribute name='rid' type='xs:unsignedInt' use='required'/>
                    <xs:attribute name='kex' type='xs:boolean' default='false'/>
                </xs:extension>
            </xs:simpleContent>
        </xs:complexType>
    </xs:element>
</xs:schema>
""")


def serialize_device_list(device_list: DeviceList) -> ET.Element:
    """
    Serialize a device list into its XML representation.

    Args:
        device_list: The device list to serialize. Each entry maps a device id to that device's optional signed
            label.

    Returns:
        The serialized device list as a ``<devices>`` XML element with one ``<device>`` child per entry.
    """

    root = ET.Element(f"{NS}devices")

    for dev_id, signed in device_list.items():
        # Attribute insertion order (id, label, labelsig) is kept stable for the serialized output.
        attributes = { "id": str(dev_id) }

        # The label is only emitted together with its base64-encoded signature.
        if signed is not None:
            attributes["label"] = signed.label
            attributes["labelsig"] = base64.b64encode(signed.signature).decode("ASCII")

        ET.SubElement(root, f"{NS}device", attrib=attributes)

    return root


def parse_device_list(element: ET.Element) -> DeviceList:
    """
    Parse a device list from its XML representation.

    Args:
        element: The XML element to parse the device list from.

    Returns:
        The extracted device list, mapping each device id to its optional signed label. A device only gets a signed
        label if both its ``label`` and ``labelsig`` attributes are present; otherwise it maps to ``None``.

    Raises:
        xmlschema.XMLSchemaValidationError: in case the element does not conform to the XML schema given in the
            specification.
    """

    DEVICE_LIST_SCHEMA.validate(element)

    result: DeviceList = {}

    for elt in element.iter(f"{NS}device"):
        # "id" is required by the schema, so the cast only narrows the type.
        dev_id = int(cast(str, elt.get("id")))
        label = elt.get("label")
        labelsig = elt.get("labelsig")

        signed: Optional[SignedLabel] = None
        if label is not None and labelsig is not None:
            signed = SignedLabel(
                label=label,
                signature=base64.b64decode(labelsig.encode("ASCII"))
            )

        result[dev_id] = signed

    return result


def serialize_bundle(bundle: BundleImpl) -> ET.Element:
    """
    Serialize a bundle into its XML representation.

    Args:
        bundle: The bundle to serialize.

    Returns:
        The serialized bundle as a ``<bundle>`` XML element containing ``<spk>``, ``<spks>``, ``<ik>`` and
        ``<prekeys>`` children.
    """

    def encode(data: bytes) -> str:
        # All binary values are carried as ASCII base64 text.
        return base64.b64encode(data).decode("ASCII")

    x3dh_bundle = bundle.bundle
    bundle_elt = ET.Element(f"{NS}bundle")

    spk_elt = ET.SubElement(bundle_elt, f"{NS}spk", attrib={ "id": str(bundle.signed_pre_key_id) })
    spk_elt.text = encode(x3dh_bundle.signed_pre_key)

    spks_elt = ET.SubElement(bundle_elt, f"{NS}spks")
    spks_elt.text = encode(x3dh_bundle.signed_pre_key_sig)

    ik_elt = ET.SubElement(bundle_elt, f"{NS}ik")
    ik_elt.text = encode(x3dh_bundle.identity_key)

    prekeys_elt = ET.SubElement(bundle_elt, f"{NS}prekeys")
    for pre_key in x3dh_bundle.pre_keys:
        # Each pre key is tagged with the id recorded for it in bundle.pre_key_ids.
        pk_elt = ET.SubElement(prekeys_elt, f"{NS}pk", attrib={ "id": str(bundle.pre_key_ids[pre_key]) })
        pk_elt.text = encode(pre_key)

    return bundle_elt


def parse_bundle(element: ET.Element, bare_jid: str, device_id: int) -> BundleImpl:
    """
    Parse a bundle from its XML representation.

    Args:
        element: The XML element to parse the bundle from.
        bare_jid: The bare JID this bundle belongs to.
        device_id: The device id of the specific device this bundle belongs to.

    Returns:
        The extracted bundle.

    Raises:
        xmlschema.XMLSchemaValidationError: in case the element does not conform to the XML schema given in the
            specification.
    """

    BUNDLE_SCHEMA.validate(element)

    def child(tag: str) -> ET.Element:
        # The schema requires <spk>, <spks> and <ik>, so the cast only narrows the Optional away.
        return cast(ET.Element, element.find(f"{NS}{tag}"))

    def decode(elt: ET.Element) -> bytes:
        return base64.b64decode(cast(str, elt.text))

    spk_elt = child("spk")

    identity_key = decode(child("ik"))
    signed_pre_key = decode(spk_elt)
    signed_pre_key_sig = decode(child("spks"))

    # Decode every pre key exactly once and map it to its id. The mapping's keys double as the set of pre keys.
    pre_key_ids = {
        decode(pk_elt): int(cast(str, pk_elt.get("id")))
        for pk_elt in element.iter(f"{NS}pk")
    }

    return BundleImpl(
        bare_jid,
        device_id,
        x3dh.Bundle(identity_key, signed_pre_key, signed_pre_key_sig, frozenset(pre_key_ids)),
        int(cast(str, spk_elt.get("id"))),
        pre_key_ids
    )


def serialize_message(message: Message) -> ET.Element:
    """
    Serialize a message into its XML representation.

    Args:
        message: The message to serialize. Its content, key material and key exchanges must be the twomemo
            implementation classes (enforced via assertions).

    Returns:
        The serialized message as an ``<encrypted>`` XML element.
    """

    assert isinstance(message.content, ContentImpl)

    def encode(data: bytes) -> str:
        return base64.b64encode(data).decode("ASCII")

    encrypted_elt = ET.Element(f"{NS}encrypted")
    header_elt = ET.SubElement(encrypted_elt, f"{NS}header", attrib={ "sid": str(message.device_id) })

    # One <keys> element per recipient JID, each holding the <key> elements addressed to that JID.
    recipient_jids = frozenset(key_material.bare_jid for key_material, _ in message.keys)
    for recipient_jid in recipient_jids:
        keys_elt = ET.SubElement(header_elt, f"{NS}keys", attrib={ "jid": recipient_jid })

        entries_for_jid = frozenset(entry for entry in message.keys if entry[0].bare_jid == recipient_jid)
        for key_material, key_exchange in entries_for_jid:
            assert isinstance(key_material, EncryptedKeyMaterialImpl)

            key_elt = ET.SubElement(keys_elt, f"{NS}key", attrib={ "rid": str(key_material.device_id) })
            authenticated_message = key_material.serialize()

            if key_exchange is not None:
                # The key exchange is serialized together with the authenticated message and flagged with kex="true".
                assert isinstance(key_exchange, KeyExchangeImpl)
                key_elt.set("kex", "true")
                key_elt.text = encode(key_exchange.serialize(authenticated_message))
            else:
                key_elt.text = encode(authenticated_message)

    # The <payload> element is omitted entirely for empty content.
    if not message.content.empty:
        payload_elt = ET.SubElement(encrypted_elt, f"{NS}payload")
        payload_elt.text = encode(message.content.ciphertext)

    return encrypted_elt


def parse_message(element: ET.Element, bare_jid: str) -> Message:
    """
    Parse a message from its XML representation.

    Args:
        element: The XML element to parse the message from.
        bare_jid: The bare JID of the sender.

    Returns:
        The extracted message. If the element has no ``<payload>`` child, the message content is empty.

    Raises:
        ValueError: in case there is malformed data not caught by the XML schema validation, e.g. a ``<header>``
            without the ``sid`` attribute, or an empty ``<key>``/``<payload>`` element.
        xmlschema.XMLSchemaValidationError: in case the element does not conform to the XML schema given in the
            specification.
    """

    MESSAGE_SCHEMA.validate(element)

    # The schema does not mark "sid" as required, so validation does not guarantee its presence. Check explicitly
    # rather than letting int(None) raise an undocumented TypeError.
    sender_device_id = cast(ET.Element, element.find(f"{NS}header")).get("sid")
    if sender_device_id is None:
        raise ValueError("The header element is missing the sid attribute.")

    payload_elt = element.find(f"{NS}payload")

    keys: Set[Tuple[EncryptedKeyMaterial, Optional[KeyExchange]]] = set()
    for keys_elt in element.iter(f"{NS}keys"):
        recipient_bare_jid = cast(str, keys_elt.get("jid"))
        for key_elt in keys_elt.iter(f"{NS}key"):
            recipient_device_id = int(cast(str, key_elt.get("rid")))

            # An empty <key/> is valid xs:base64Binary but carries no key material; b64decode(None) would raise a
            # TypeError.
            if key_elt.text is None:
                raise ValueError(
                    f"Empty key element for recipient {recipient_bare_jid} / device {recipient_device_id}."
                )
            content = base64.b64decode(key_elt.text)

            key_exchange: Optional[KeyExchangeImpl] = None
            authenticated_message: bytes
            # xs:boolean collapses surrounding whitespace, so kex=" true " passes validation. Strip before comparing
            # against the lexical forms of true.
            if key_elt.get("kex", "false").strip() in ("true", "1"):
                key_exchange, authenticated_message = KeyExchangeImpl.parse(content)
            else:
                authenticated_message = content

            encrypted_key_material = EncryptedKeyMaterialImpl.parse(
                authenticated_message,
                recipient_bare_jid,
                recipient_device_id
            )

            keys.add((encrypted_key_material, key_exchange))

    if payload_elt is None:
        message_content = ContentImpl.make_empty()
    else:
        # An empty <payload/> is valid xs:base64Binary; reject it explicitly instead of crashing in b64decode(None).
        if payload_elt.text is None:
            raise ValueError("The payload element is empty.")
        message_content = ContentImpl(base64.b64decode(payload_elt.text))

    return Message(
        NAMESPACE,
        bare_jid,
        int(sender_device_id),
        message_content,
        frozenset(keys)
    )