Module refinery.lib.un7z.bcj2

BCJ2 decoder for 7z archives.

Ported from 7-Zip's Bcj2.c (Igor Pavlov, public domain).

BCJ2 is a complex x86 branch filter with 4 input streams and 1 output stream. It uses a range coder to predict whether x86 CALL/JMP instructions should have their relative addresses converted to absolute.

Stream layout:

- Stream 0 (MAIN): Non-branch bytes and branch instruction opcodes
- Stream 1 (CALL): 4-byte big-endian CALL target addresses
- Stream 2 (JUMP): 4-byte big-endian JMP target addresses
- Stream 3 (RC): Range coder data for branch prediction

Expand source code Browse git
"""
BCJ2 decoder for 7z archives.

Ported from 7-Zip's Bcj2.c (Igor Pavlov, public domain).

BCJ2 is a complex x86 branch filter with 4 input streams and 1 output stream.
It uses a range coder to predict whether x86 CALL/JMP instructions should
have their relative addresses converted to absolute.

Stream layout:
  - Stream 0 (MAIN): Non-branch bytes and branch instruction opcodes
  - Stream 1 (CALL): 4-byte big-endian CALL target addresses
  - Stream 2 (JUMP): 4-byte big-endian JMP target addresses
  - Stream 3 (RC):   Range coder data for branch prediction
"""
from __future__ import annotations

from refinery.lib.un7z.headers import SzCorruptArchive

_TOP_VALUE = 1 << 24
_NUM_MODEL_BITS = 11
_BIT_MODEL_TOTAL = 1 << _NUM_MODEL_BITS
_NUM_MOVE_BITS = 5
_NUM_PROBS = 2 + 256

_MASK32 = 0xFFFFFFFF


def decode_bcj2(
    main_data: bytes | bytearray | memoryview,
    call_data: bytes | bytearray | memoryview,
    jump_data: bytes | bytearray | memoryview,
    rc_data: bytes | bytearray | memoryview,
    output_size: int,
) -> bytearray:
    main = memoryview(main_data)
    call = memoryview(call_data)
    jump = memoryview(jump_data)
    rc = memoryview(rc_data)

    main_pos = 0
    call_pos = 0
    jump_pos = 0
    rc_pos = 0

    probs = [_BIT_MODEL_TOTAL >> 1] * _NUM_PROBS

    if len(rc) < 5:
        raise SzCorruptArchive('BCJ2: range coder stream too short.')
    if rc[0] != 0:
        raise SzCorruptArchive('BCJ2: range coder stream must start with 0x00.')
    code = 0
    for i in range(1, 5):
        code = (code << 8) | rc[i]
    rc_pos = 5
    range_ = _MASK32

    output = bytearray(output_size)
    out_pos = 0
    ip = 0
    prev_byte = 0

    while out_pos < output_size:
        if range_ < _TOP_VALUE:
            if rc_pos >= len(rc):
                raise SzCorruptArchive('BCJ2: unexpected end of range coder stream.')
            range_ = (range_ << 8) & _MASK32
            code = ((code << 8) | rc[rc_pos]) & _MASK32
            rc_pos += 1

        found_branch = False

        while main_pos < len(main):
            b = main[main_pos]
            main_pos += 1
            if b == 0x0F and main_pos < len(main) and (main[main_pos] & 0xF0) == 0x80:
                output[out_pos] = b
                out_pos += 1
                b = main[main_pos]
                main_pos += 1
                output[out_pos] = b
                out_pos += 1
                ip += 2
                prev_byte = b
                continue
            if (b & 0xFE) == 0xE8:
                found_branch = True
                output[out_pos] = b
                out_pos += 1
                ip += 1
                break
            output[out_pos] = b
            out_pos += 1
            ip += 1
            prev_byte = b

        if not found_branch:
            break

        b = output[out_pos - 1]
        if b == 0xE8:
            prob_idx = 2 + prev_byte
        elif b == 0xE9:
            prob_idx = 1
        else:
            prob_idx = 0

        ttt = probs[prob_idx]
        bound = (range_ >> _NUM_MODEL_BITS) * ttt

        if (code & _MASK32) < bound:
            range_ = bound
            probs[prob_idx] = ttt + ((_BIT_MODEL_TOTAL - ttt) >> _NUM_MOVE_BITS)
            prev_byte = b
            continue

        range_ = (range_ - bound) & _MASK32
        code = (code - bound) & _MASK32
        probs[prob_idx] = ttt - (ttt >> _NUM_MOVE_BITS)

        if b == 0xE8:
            if call_pos + 4 > len(call):
                raise SzCorruptArchive('BCJ2: unexpected end of CALL stream.')
            val = int.from_bytes(call[call_pos:call_pos + 4], 'big')
            call_pos += 4
        else:
            if jump_pos + 4 > len(jump):
                raise SzCorruptArchive('BCJ2: unexpected end of JUMP stream.')
            val = int.from_bytes(jump[jump_pos:jump_pos + 4], 'big')
            jump_pos += 4

        ip += 4
        val = (val - ip) & _MASK32

        output[out_pos:out_pos + 4] = val.to_bytes(4, 'little')
        out_pos += 4
        prev_byte = (val >> 24) & 0xFF

    return output

Functions

def decode_bcj2(main_data, call_data, jump_data, rc_data, output_size)
Expand source code Browse git
def decode_bcj2(
    main_data: bytes | bytearray | memoryview,
    call_data: bytes | bytearray | memoryview,
    jump_data: bytes | bytearray | memoryview,
    rc_data: bytes | bytearray | memoryview,
    output_size: int,
) -> bytearray:
    """
    Merge the four BCJ2 input streams into a single output buffer.

    Bytes are copied from the MAIN stream until a branch opcode is seen: `E8`, `E9`, or a byte in
    the range `80..8F` that directly follows a `0F` byte in the output. For every such opcode one
    bit is read from the range coder stream. A zero bit means that the operand follows in MAIN
    unchanged; a one bit means that a 4-byte big-endian absolute target is taken from the CALL
    stream (for `E8`) or the JUMP stream (otherwise), turned back into a displacement relative to
    the end of the instruction, and written to the output in little-endian order.

    Args:
        main_data: Stream 0, opcodes and all bytes that are not converted branch operands.
        call_data: Stream 1, absolute targets of converted `E8` instructions.
        jump_data: Stream 2, absolute targets of converted `E9` and `0F 8x` instructions.
        rc_data: Stream 3, range coder data; at least 5 bytes, the first of which must be zero.
        output_size: Number of bytes to produce.

    Returns:
        A bytearray of exactly `output_size` bytes. If MAIN is exhausted before the output has
        been filled, the remaining bytes are left as zero.

    Raises:
        SzCorruptArchive: If the range coder stream is shorter than 5 bytes, does not start with
            a zero byte, or ends while more data is required, or if the CALL or JUMP stream has
            fewer than 4 bytes left when a target has to be read.
    """
    main = memoryview(main_data)
    call = memoryview(call_data)
    jump = memoryview(jump_data)
    rc = memoryview(rc_data)
    main_len = len(main)
    rc_len = len(rc)

    if rc_len < 5:
        raise SzCorruptArchive('BCJ2: range coder stream too short.')
    if rc[0] != 0:
        raise SzCorruptArchive('BCJ2: range coder stream must start with 0x00.')
    # Bytes 1..4 of the range coder stream form the initial 32-bit code value.
    code = int.from_bytes(rc[1:5], 'big')
    rc_pos = 5
    range_ = _MASK32
    # Probability slots: 0 for Jcc, 1 for E9, 2 + previous byte for E8.
    probs = [_BIT_MODEL_TOTAL >> 1] * _NUM_PROBS

    output = bytearray(output_size)
    out_pos = 0
    main_pos = 0
    call_pos = 0
    jump_pos = 0
    # Last byte written before the current one; a pending branch opcode is not recorded here
    # until its bit has been decoded, so that it still holds the byte preceding the opcode.
    prev_byte = 0

    while out_pos < output_size:
        # Renormalise before the next bit is decoded.
        if range_ < _TOP_VALUE:
            if rc_pos >= rc_len:
                raise SzCorruptArchive('BCJ2: unexpected end of range coder stream.')
            range_ = (range_ << 8) & _MASK32
            code = ((code << 8) | rc[rc_pos]) & _MASK32
            rc_pos += 1

        # Copy from MAIN up to and including the next branch opcode, never writing more than
        # output_size bytes in total.
        found_branch = False
        while main_pos < main_len and out_pos < output_size:
            b = main[main_pos]
            main_pos += 1
            output[out_pos] = b
            out_pos += 1
            if (b & 0xFE) == 0xE8 or (prev_byte == 0x0F and (b & 0xF0) == 0x80):
                found_branch = True
                break
            prev_byte = b

        # Stop when MAIN is exhausted, or when the opcode was the final output byte and there
        # is no room left for an operand.
        if not found_branch or out_pos >= output_size:
            break

        if b == 0xE8:
            prob_idx = 2 + prev_byte
        elif b == 0xE9:
            prob_idx = 1
        else:
            prob_idx = 0

        ttt = probs[prob_idx]
        bound = (range_ >> _NUM_MODEL_BITS) * ttt

        if code < bound:
            # Bit 0: the operand was not converted and follows in MAIN.
            range_ = bound
            probs[prob_idx] = ttt + ((_BIT_MODEL_TOTAL - ttt) >> _NUM_MOVE_BITS)
            prev_byte = b
            continue

        # Bit 1: the operand is stored as an absolute target in CALL or JUMP.
        range_ = (range_ - bound) & _MASK32
        code = (code - bound) & _MASK32
        probs[prob_idx] = ttt - (ttt >> _NUM_MOVE_BITS)

        if b == 0xE8:
            if call_pos + 4 > len(call):
                raise SzCorruptArchive('BCJ2: unexpected end of CALL stream.')
            val = int.from_bytes(call[call_pos:call_pos + 4], 'big')
            call_pos += 4
        else:
            if jump_pos + 4 > len(jump):
                raise SzCorruptArchive('BCJ2: unexpected end of JUMP stream.')
            val = int.from_bytes(jump[jump_pos:jump_pos + 4], 'big')
            jump_pos += 4

        # The displacement is relative to the address right after the 4-byte operand.
        val = (val - (out_pos + 4)) & _MASK32

        # Truncate the operand when fewer than four bytes of output remain; assigning a
        # same-length slice keeps the buffer at exactly output_size bytes.
        operand = val.to_bytes(4, 'little')[:output_size - out_pos]
        output[out_pos:out_pos + len(operand)] = operand
        out_pos += len(operand)
        prev_byte = (val >> 24) & 0xFF

    return output