Module refinery.lib.asn1.cms

Expand source code Browse git
from __future__ import annotations

import hashlib
import re

from collections import OrderedDict
from datetime import datetime, timezone

from refinery.lib.asn1 import ASN1Reader
from refinery.lib.asn1.defs import ContentInfo, SignedContentInfo, SpcSpOpusInfo
from refinery.lib.structures import StructReader

_TIME_VALUED_ATTRIBUTES = {'signingTime'}


def _skip00(sr: StructReader[memoryview]):
    """
    Consume a pair of zero bytes at the current reader position, if there is one.

    The next two bytes are peeked and compared against two NUL bytes. The reader is advanced
    by two only when the comparison succeeds. The comparison result is returned.
    """
    found = sr.peek(2) == b'\0\0'
    if found:
        sr.skip(2)
    return found


def _parse_asn1_time(value):
    """
    Render a `YYMMDDhhmmssZ` or `YYYYMMDDhhmmssZ` string as an ISO timestamp in UTC.

    Two-digit years below 50 are placed in the 2000s, all others in the 1900s. Anything that
    is not a string, does not match either layout, or does not form a valid calendar date is
    returned unchanged.
    """
    if not isinstance(value, str):
        return value
    short = re.fullmatch(r'(\d{2})(\d{2})(\d{2})(\d{2})(\d{2})(\d{2})Z', value)
    if short is not None:
        yy, *rest = map(int, short.groups())
        century = 2000 if yy < 50 else 1900
        fields = [century + yy, *rest]
    else:
        full = re.fullmatch(r'(\d{4})(\d{2})(\d{2})(\d{2})(\d{2})(\d{2})Z', value)
        if full is None:
            return value
        fields = [int(group) for group in full.groups()]
    try:
        stamp = datetime(*fields, tzinfo=timezone.utc)
    except ValueError:
        # Out-of-range components (month 13, year 0, ...) leave the input untouched.
        return value
    return stamp.isoformat(sep=' ')


def _flatten_name(name: list) -> OrderedDict:
    """
    Collapse a list of RDNs into a single ordered mapping from attribute type to value.

    Each element of `name` is either a list of attribute dictionaries or a single item. Only
    dictionaries whose `'type'` is a non-empty string contribute; a later entry with the same
    type overwrites the earlier value. A non-list input produces an empty mapping.
    """
    flat: OrderedDict[str, str] = OrderedDict()
    if isinstance(name, list):
        for rdn in name:
            for atv in (rdn if isinstance(rdn, list) else [rdn]):
                if not isinstance(atv, dict):
                    continue
                kind = atv.get('type', '')
                if isinstance(kind, str) and kind:
                    flat[kind] = atv.get('value', '')
    return flat


def _flatten_generic_name(name) -> OrderedDict:
    """
    Collapse a loosely structured name into an ordered mapping from attribute type to value.

    Every element of `name` may be a `[type, value]` pair, a dictionary with `'type'` and
    `'value'` keys, or a list containing such pairs and dictionaries. Dictionary entries
    are only used when their type is a non-empty string; pairs only require that the first
    item is a string. Anything else is ignored, and a non-list input yields an empty mapping.
    """
    flat: OrderedDict[str, str] = OrderedDict()
    if not isinstance(name, list):
        return flat

    def is_pair(node) -> bool:
        return isinstance(node, list) and len(node) == 2 and isinstance(node[0], str)

    def absorb(node) -> None:
        # Record one attribute given either as a pair or as a type/value dictionary.
        if is_pair(node):
            flat[node[0]] = node[1]
        elif isinstance(node, dict):
            kind = node.get('type', '')
            if isinstance(kind, str) and kind:
                flat[kind] = node.get('value', '')

    for rdn in name:
        if isinstance(rdn, list) and not is_pair(rdn):
            # A list that is not itself a pair is treated as a container of attributes.
            for atv in rdn:
                absorb(atv)
        else:
            absorb(rdn)
    return flat


def _interpret_spc_opus(value) -> OrderedDict:
    """
    Build a mapping with `programName` and `moreInfo` from tagged opus-info entries.

    `value` is a single entry or a list of entries. Dictionary entries tagged `context-0`
    supply the program name and entries tagged `context-1` supply the link. All other
    entries are skipped.
    """
    info: OrderedDict = OrderedDict()
    entries = [value] if not isinstance(value, list) else value
    for entry in entries:
        if not isinstance(entry, dict):
            continue
        tag, payload = entry.get('tag', ''), entry.get('value')
        if tag == 'context-0':
            info['programName'] = _extract_spc_string(payload)
            continue
        if tag == 'context-1':
            info['moreInfo'] = _extract_spc_link(payload)
    return info


def _extract_spc_string(value):
    """
    Reduce a decoded value to a plain string where possible.

    - Strings are returned unchanged.
    - Binary buffers are decoded as UTF-16 (big endian), falling back to Latin-1 when that
      fails.
    - Dictionaries are unwrapped through their `'value'` entry. A dictionary without such an
      entry is returned unchanged.
    - Lists yield the first element that reduces to a string.

    Anything that cannot be reduced to a string is returned as it was passed in.
    """
    if isinstance(value, str):
        return value
    if isinstance(value, (bytes, bytearray, memoryview)):
        raw = bytes(value)
        try:
            return raw.decode('utf-16-be')
        except UnicodeDecodeError:
            # Latin-1 maps every byte to a code point, so this fallback cannot fail.
            return raw.decode('latin-1')
    if isinstance(value, dict):
        if 'value' not in value:
            # Falling back to the dictionary itself would recurse on the same object forever.
            return value
        return _extract_spc_string(value['value'])
    if isinstance(value, list):
        for item in value:
            text = _extract_spc_string(item)
            if isinstance(text, str):
                return text
    return value


def _extract_spc_link(value):
    """
    Reduce a decoded link value to a string where possible.

    Strings pass through and binary buffers are decoded as ASCII with a Latin-1 fallback.
    For dictionaries, the `'value'` entry (or the dictionary itself when absent) is examined:
    a `context-2` tag delegates to `_extract_spc_string`, binary payloads are decoded, and a
    `context-0` tag or a nested dictionary returns the payload as-is. Other payloads go
    through `_extract_spc_string`. Any other input is returned unchanged.
    """
    binary = (bytes, bytearray, memoryview)

    def decode(blob):
        raw = bytes(blob)
        try:
            return raw.decode('ascii')
        except Exception:
            return raw.decode('latin-1')

    if isinstance(value, str):
        return value
    if isinstance(value, binary):
        return decode(value)
    if not isinstance(value, dict):
        return value

    tag = value.get('tag', '')
    inner = value.get('value', value)
    if tag == 'context-2':
        return _extract_spc_string(inner)
    if isinstance(inner, binary):
        return decode(inner)
    if tag == 'context-0' or isinstance(inner, dict):
        return inner
    return _extract_spc_string(inner)


def _interpret_counter_signature(value) -> OrderedDict:
    """
    Extract the signer identifier and signed attributes from a generically decoded signer.

    `value` must be a list with at least five elements, otherwise an empty mapping is
    returned. When the second element is a two-item list it is read as issuer name and
    serial number. The first dictionary tagged `context-0` supplies the signed attributes,
    provided that its value is a list.
    """
    parsed: OrderedDict = OrderedDict()
    if not (isinstance(value, list) and len(value) >= 5):
        return parsed

    signer = value[1]
    if isinstance(signer, list) and len(signer) == 2:
        parsed['sid'] = OrderedDict([
            ('issuer', _flatten_generic_name(signer[0])),
            ('serialNumber', signer[1]),
        ])

    # Only the first context-0 element is considered, even if its value is unusable.
    tagged = next(
        (e for e in value if isinstance(e, dict) and e.get('tag') == 'context-0'), None)
    if tagged is not None:
        attributes = tagged.get('value', [])
        if isinstance(attributes, list):
            parsed['signedAttrs'] = [
                _interpret_generic_attribute(a) for a in attributes]
    return parsed


def _interpret_generic_attribute(value) -> OrderedDict:
    """
    Normalize an attribute given as a `[type, values]` list or as a dictionary.

    For lists, a single value is stored under `'value'` and any other count under `'values'`;
    a payload that is not a list counts as a single value. For dictionaries, `attrType` and
    `attrValues` are preferred over `type` and `value`. Unsupported input produces an empty
    mapping.
    """
    entry: OrderedDict = OrderedDict()
    if isinstance(value, dict):
        entry['type'] = value.get('attrType', value.get('type', ''))
        entry['value'] = value.get('attrValues', value.get('value', ''))
        return entry
    if not isinstance(value, list) or len(value) < 2:
        return entry
    kind, payload = value[0], value[1]
    entry['type'] = kind
    if not isinstance(payload, list):
        entry['value'] = payload
    elif len(payload) == 1:
        entry['value'] = payload[0]
    else:
        entry['values'] = payload
    return entry


def _decode_attribute_value(oid: str, values: list) -> list:
    """
    Decode the values of an attribute according to its type name.

    - `spcSpOpusInfo`: binary values are decoded with the `SpcSpOpusInfo` schema, other
      values are interpreted via `_interpret_spc_opus`.
    - `microsoftNestedSignature`: binary values are parsed via `parse_content_info`.
    - `counterSignature`: each value is passed to `_interpret_counter_signature`.

    Binary values that fail to decode are kept as they are. For any other type name, the
    input list itself is returned.
    """
    binary = (bytes, bytearray, memoryview)

    def opus(entry):
        if not isinstance(entry, binary):
            return _interpret_spc_opus(entry)
        try:
            reader = ASN1Reader(memoryview(entry), bigendian=True)
            return reader.decode_with_schema(SpcSpOpusInfo)
        except Exception:
            return entry

    def nested(entry):
        if not isinstance(entry, binary):
            return entry
        try:
            return parse_content_info(entry)
        except Exception:
            return entry

    if oid == 'spcSpOpusInfo':
        return [opus(v) for v in values]
    if oid == 'microsoftNestedSignature':
        return [nested(v) for v in values]
    if oid == 'counterSignature':
        return [_interpret_counter_signature(v) for v in values]
    return values


def _unsign(data):
    """
    Recursively replace negative integers by their unsigned two's complement value.

    A negative integer is reinterpreted over the smallest whole number of octets that can
    hold it in two's complement form, so `-1` becomes `255` and `-129` becomes `0xFF7F`.
    Integers that exceed 64 bits after this step are returned as lowercase hex strings of
    that octet width. Dictionaries are updated in place and returned; lists are rebuilt.
    All other values are returned unchanged.
    """
    if isinstance(data, int):
        if data < 0:
            # Minimal two's complement width: enough octets for ~data plus the sign bit.
            width = (~data).bit_length() // 8 + 1
            data += 1 << (8 * width)
        else:
            width = (data.bit_length() + 7) // 8
        if data > 0xFFFFFFFF_FFFFFFFF:
            return data.to_bytes(width, 'big').hex()
        return data
    if isinstance(data, dict):
        for key in list(data):
            data[key] = _unsign(data[key])
        return data
    if isinstance(data, list):
        return [_unsign(x) for x in data]
    return data


def _postprocess(obj, raw: bytes | memoryview | None = None):
    """
    Recursively convert a decoded structure into its presentation form.

    `bytes` become uppercase hex strings and lists are rebuilt element by element. An
    `OrderedDict` is updated in place and returned, with these keys handled specially:

    - `issuer` / `subject`: lists are flattened via `_flatten_name`, dictionaries are
      processed recursively, and any other value is left untouched.
    - `validity` (dictionary): each entry is processed and then passed to `_parse_asn1_time`.
    - `signedAttrs` / `unsignedAttrs` (list): each entry goes to `_postprocess_attribute`.

    Every other entry is processed recursively. `raw` is only forwarded when descending
    into a dictionary stored under `content`; this function does not read it otherwise.
    Values of any other type, including plain dictionaries, are returned unchanged.
    """
    if isinstance(obj, bytes):
        return obj.hex().upper()
    if isinstance(obj, list):
        return [_postprocess(entry) for entry in obj]
    if not isinstance(obj, OrderedDict):
        return obj

    for key in list(obj.keys()):
        value = obj[key]
        if key in ('issuer', 'subject'):
            if isinstance(value, list):
                obj[key] = _flatten_name(value)
            elif isinstance(value, dict):
                obj[key] = _postprocess(value)
            continue
        if key == 'validity' and isinstance(value, dict):
            obj[key] = OrderedDict(
                (name, _parse_asn1_time(_postprocess(moment)))
                for name, moment in value.items()
            )
            continue
        if key in ('signedAttrs', 'unsignedAttrs') and isinstance(value, list):
            obj[key] = [_postprocess_attribute(attribute) for attribute in value]
            continue
        if key == 'content' and isinstance(value, dict):
            obj[key] = _postprocess(value, raw)
            continue
        obj[key] = _postprocess(value)
    return obj


def _postprocess_attribute(attr) -> OrderedDict:
    """
    Convert one attribute into a mapping with `type` and either `value` or `values`.

    Attributes that already carry a `type` key (and no `attrType`) keep their `value` or
    `values` entry after post-processing. Otherwise the `attrValues` are first decoded via
    `_decode_attribute_value`; exactly one decoded value is stored under `value` and any
    other count under `values`. Values of attribute types listed in `_TIME_VALUED_ATTRIBUTES`
    are additionally passed through `_parse_asn1_time`. Input that is not a dictionary is
    returned unchanged.
    """
    if not isinstance(attr, dict):
        return attr

    def render(kind, items):
        # Post-process everything first, then apply time formatting where it applies.
        cooked = [_postprocess(item) for item in items]
        if kind in _TIME_VALUED_ATTRIBUTES:
            cooked = [_parse_asn1_time(item) for item in cooked]
        return cooked

    out: OrderedDict = OrderedDict()

    if 'type' in attr and 'attrType' not in attr:
        kind = attr['type']
        out['type'] = kind
        if 'value' in attr:
            [single] = render(kind, [attr['value']])
            out['value'] = single
        elif 'values' in attr:
            out['values'] = render(kind, attr['values'])
        return out

    kind = attr.get('attrType', '')
    out['type'] = kind
    decoded = _decode_attribute_value(kind, attr.get('attrValues', []))
    if len(decoded) == 1:
        [single] = render(kind, decoded)
        out['value'] = single
    else:
        out['values'] = render(kind, decoded)
    return out


def _read_tag_length(sr: StructReader) -> tuple[int, int, int]:
    """
    Read one tag and length header from the reader.

    Returns a tuple `(tag_class, tag_number, length)`. Tag numbers of 31 and above are read
    from the following base-128 continuation bytes. A length byte of 0x80 is reported as
    `-1`; lengths with the high bit set are assembled big-endian from the announced number
    of bytes.
    """
    lead = sr.u8()
    tag_class = (lead >> 6) & 3
    tag_number = lead & 0x1F
    if tag_number == 0x1F:
        tag_number = 0
        more = True
        while more:
            octet = sr.u8()
            tag_number = (tag_number << 7) | (octet & 0x7F)
            more = bool(octet & 0x80)
    first = sr.u8()
    if first == 0x80:
        length = -1
    elif first < 0x80:
        length = first
    else:
        length = 0
        for _ in range(first & 0x7F):
            length = (length << 8) | sr.u8()
    return tag_class, tag_number, length


def _skip_tlv_complete(sr: StructReader) -> int:
    """
    Advance the reader past one complete tag-length-value element.

    Elements with a known length are skipped with a relative seek. For a constructed
    element with the 0x80 length marker, child elements are skipped recursively until a pair
    of zero bytes is consumed. For a primitive element with that marker, only the header is
    consumed. Returns the number of bytes the reader advanced.
    """
    origin = sr.tell()
    lead = sr.u8()
    if (lead & 0x1F) == 0x1F:
        # Multi-byte tag number: consume continuation bytes up to the last one.
        while sr.u8() & 0x80:
            pass
    marker = sr.u8()
    if marker == 0x80:
        if lead & 0x20:
            while not _skip00(sr):
                _skip_tlv_complete(sr)
    else:
        if marker < 0x80:
            size = marker
        else:
            size = 0
            for _ in range(marker & 0x7F):
                size = (size << 8) | sr.u8()
        sr.seekrel(size)
    return sr.tell() - origin


def _find_implicit_set_of_certificates(
    sr: StructReader,
) -> list[tuple[int, int]]:
    """
    Walk the TLV structure under `sr` and return `(start, end)` reader offsets for every
    element inside the first context-specific `[0]` element found after a fixed prefix of
    skipped fields. The offsets are taken from `sr.tell()` before and after skipping each
    element. An empty list is returned when the second element of the outermost structure
    is not a context-specific `[0]` element.

    NOTE(review): the skip sequence looks like it follows the CMS SignedData layout
    (contentType, `[0]` content, then version, digestAlgorithms, encapContentInfo, optional
    `[0]` certificates and `[1]` crls) - confirm against the schema used by the caller.
    """
    # Enter the outermost element: only its header is consumed.
    _read_tag_length(sr)
    # Skip its first child entirely.
    _skip_tlv_complete(sr)
    # The next child must have tag class 2 (context-specific) and tag number 0.
    tc, tn, outer_len = _read_tag_length(sr)
    if tc != 2 or tn != 0:
        return []
    if outer_len < 0:
        # Length marker 0x80 (reported as -1): bound the scan by what the reader has left.
        outer_end = sr.tell() + sr.remaining_bytes
    else:
        outer_end = sr.tell() + outer_len
    # Enter the element wrapped by [0], then skip its first three children.
    _read_tag_length(sr)
    _skip_tlv_complete(sr)
    _skip_tlv_complete(sr)
    _skip_tlv_complete(sr)
    positions = []
    while sr.tell() < outer_end:
        # A pair of zero bytes ends the scan of the enclosing element.
        if sr.remaining_bytes >= 2 and _skip00(sr):
            break
        tc, tn, length = _read_tag_length(sr)
        if tc == 2 and tn == 0:
            # Found the [0] element: record the extent of each of its children.
            if length < 0:
                cert_set_end = sr.tell() + sr.remaining_bytes
            else:
                cert_set_end = sr.tell() + length
            while sr.tell() < cert_set_end:
                if sr.remaining_bytes >= 2 and _skip00(sr):
                    break
                cert_start = sr.tell()
                _skip_tlv_complete(sr)
                positions.append((cert_start, sr.tell()))
            break
        elif tc == 2 and tn == 1:
            # A [1] element is skipped and the scan continues with the next sibling.
            if length >= 0:
                sr.seekrel(length)
            else:
                while not _skip00(sr):
                    _skip_tlv_complete(sr)
        else:
            # Any other element ends the search; nothing is recorded.
            if length >= 0:
                sr.seekrel(length)
            break
    return positions


def _find_certificate_hashes(raw: bytes | memoryview) -> list[str]:
    """
    Return the SHA-1 hex digests of the byte ranges located in `raw` by
    `_find_implicit_set_of_certificates`, in the order they were found.

    This is best-effort: any exception ends the scan, and the digests collected up to that
    point (possibly none) are returned.
    """
    digests: list[str] = []
    try:
        reader = StructReader(memoryview(raw), bigendian=True)
        for begin, end in _find_implicit_set_of_certificates(reader):
            digests.append(hashlib.sha1(bytes(raw[begin:end])).hexdigest())
    except Exception:
        pass
    return digests


def parse_content_info(data: bytes | bytearray | memoryview) -> OrderedDict:
    """
    Parse a DER-encoded PKCS#7/CMS ContentInfo structure and return a fully post-processed
    OrderedDict ready for JSON serialization. Names are flattened, times are formatted,
    attribute values are decoded, and negative ASN.1 integers are converted to unsigned
    representation.

    The `SignedContentInfo` and `ContentInfo` schemas are tried in that order. The decoding
    that leaves the fewest unread bytes is kept, and the search stops early when a schema
    consumes the whole input. If neither schema decodes, a single generic TLV is read
    instead and returned without post-processing. A `RuntimeError` is raised when the
    outcome is not an OrderedDict.
    """
    view = memoryview(data)
    winner = None
    leftover = len(view) + 1
    for schema in (SignedContentInfo, ContentInfo):
        try:
            attempt = ASN1Reader(view, bigendian=True)
            decoded = attempt.decode_with_schema(schema)
            unread = attempt.remaining_bytes
            if unread < leftover:
                winner, leftover = decoded, unread
            if unread == 0:
                break
        except Exception:
            continue
    if winner is None:
        parsed = ASN1Reader(view, bigendian=True).read_tlv()
    else:
        parsed = _unsign(_postprocess(winner, view))
    if isinstance(parsed, OrderedDict):
        return parsed
    raise RuntimeError('The ContentInfo data did not parse as a dictionary.')


def compute_certificate_fingerprints(
    result,
    raw: bytes | memoryview,
) -> None:
    """
    Compute SHA-1 fingerprints for each certificate by locating their DER boundaries in the raw
    data and add them in-place to the result dict.

    Nothing happens unless `result['content']['certificates']` exists and is a list. The
    hashes are matched to the certificates by position. Entries that are not dictionaries,
    or for which no hash was found, are left unchanged.
    """
    content = result.get('content') if isinstance(result, dict) else None
    if not isinstance(content, dict):
        return
    certs = content.get('certificates')
    if not isinstance(certs, list):
        return
    digests = _find_certificate_hashes(memoryview(raw))
    for cert, digest in zip(certs, digests):
        if isinstance(cert, dict):
            cert['fingerprint'] = digest

Functions

def parse_content_info(data)

Parse a DER-encoded PKCS#7/CMS ContentInfo structure and return a fully post-processed OrderedDict ready for JSON serialization. Names are flattened, times are formatted, attribute values are decoded, and negative ASN.1 integers are converted to unsigned representation.

Expand source code Browse git
def parse_content_info(data: bytes | bytearray | memoryview) -> OrderedDict:
    """
    Parse a DER-encoded PKCS#7/CMS ContentInfo structure and return a fully post-processed
    OrderedDict ready for JSON serialization. Names are flattened, times are formatted,
    attribute values are decoded, and negative ASN.1 integers are converted to unsigned
    representation.

    The `SignedContentInfo` and `ContentInfo` schemas are tried in that order. The decoding
    that leaves the fewest unread bytes is kept, and the search stops early when a schema
    consumes the whole input. If neither schema decodes, a single generic TLV is read
    instead and returned without post-processing. A `RuntimeError` is raised when the
    outcome is not an OrderedDict.
    """
    view = memoryview(data)
    winner = None
    leftover = len(view) + 1
    for schema in (SignedContentInfo, ContentInfo):
        try:
            attempt = ASN1Reader(view, bigendian=True)
            decoded = attempt.decode_with_schema(schema)
            unread = attempt.remaining_bytes
            if unread < leftover:
                winner, leftover = decoded, unread
            if unread == 0:
                break
        except Exception:
            continue
    if winner is None:
        parsed = ASN1Reader(view, bigendian=True).read_tlv()
    else:
        parsed = _unsign(_postprocess(winner, view))
    if isinstance(parsed, OrderedDict):
        return parsed
    raise RuntimeError('The ContentInfo data did not parse as a dictionary.')
def compute_certificate_fingerprints(result, raw)

Compute SHA-1 fingerprints for each certificate by locating their DER boundaries in the raw data and add them in-place to the result dict.

Expand source code Browse git
def compute_certificate_fingerprints(
    result,
    raw: bytes | memoryview,
) -> None:
    """
    Compute SHA-1 fingerprints for each certificate by locating their DER boundaries in the raw
    data and add them in-place to the result dict.

    Nothing happens unless `result['content']['certificates']` exists and is a list. The
    hashes are matched to the certificates by position. Entries that are not dictionaries,
    or for which no hash was found, are left unchanged.
    """
    content = result.get('content') if isinstance(result, dict) else None
    if not isinstance(content, dict):
        return
    certs = content.get('certificates')
    if not isinstance(certs, list):
        return
    digests = _find_certificate_hashes(memoryview(raw))
    for cert, digest in zip(certs, digests):
        if isinstance(cert, dict):
            cert['fingerprint'] = digest