Module refinery.lib.unrar.unpack30
RAR 3.0 decompression algorithm.
Expand source code Browse git
"""
RAR 3.0 decompression algorithm.
"""
from __future__ import annotations
import struct
import zlib
from dataclasses import dataclass, field
from typing import Callable
from refinery.lib.unrar.filters import (
V3FilterType,
execute_v3_filter,
identify_v3_filter,
)
from refinery.lib.unrar.reader import BitInput
from refinery.lib.unrar.unpack import (
LBits,
LDecode,
RarUnpacker,
SDBits,
SDDecode,
)
from refinery.lib.unrar.unpack50 import (
BlockTables,
decode_number,
make_decode_tables,
)
# Sizes of the RAR 3.0 Huffman alphabets. NC30, DC30, RC30 and LDC30 are summed
# into HUFF_TABLE_SIZE30 below; BC30 is not part of that sum.
# NOTE(review): presumably NC = literal/length, DC = distance, LDC = low
# distance, RC = repeat and BC = bit-length codes (UnRAR naming) -- confirm.
NC30 = 299
DC30 = 60
LDC30 = 17
RC30 = 28
BC30 = 20
# Length of the code-length table kept in Unpack30._unp_old_table.
HUFF_TABLE_SIZE30 = NC30 + DC30 + RC30 + LDC30
LOW_DIST_REP_COUNT = 16
# Upper bound checked when registering filters and filter-stack entries.
MAX3_UNPACK_FILTERS = 8192
# Block type tags (Unpack30._block_type starts out as BLOCK_LZ).
BLOCK_LZ = 0
BLOCK_PPM = 1
# Maximum number of states collected by _ModelPPM._create_successors.
MAX_O = 64
# A state frequency above this value triggers _ModelPPM._rescale.
MAX_FREQ = 124
# Fixed-point parameters of the binary-context and SEE estimators.
INT_BITS = 7
PERIOD_BITS = 7
TOT_BITS = INT_BITS + PERIOD_BITS
INTERVAL = 1 << INT_BITS
BIN_SCALE = 1 << TOT_BITS
# Range coder renormalisation thresholds (see _RangeCoder.normalize).
TOP = 1 << 24
BOT = 1 << 15
# Mask used to emulate unsigned 32-bit arithmetic.
_M32 = 0xFFFFFFFF
# Number of sub-allocator size classes per growth step: N1 classes grow by one
# unit, N2 by two, N3 by three and N4 by four (see _SubAllocator.init).
N1 = 4
N2 = 4
N3 = 4
N4 = (128 + 3 - 1 * N1 - 2 * N2 - 3 * N3) // 4
N_INDEXES = N1 + N2 + N3 + N4
# Size in bytes of one allocation unit; both values are equal in this port.
UNIT_SIZE = 12
FIXED_UNIT_SIZE = 12
# Initial escape estimates, indexed by the top bits of a binary-context
# probability (see _ModelPPM._decode_bin_symbol).
ExpEscape = [25, 14, 9, 7, 5, 5, 4, 4, 4, 3, 3, 3, 2, 2, 2, 2]
def _init_dist_tables():
    """
    Build the distance slot tables.

    Returns a pair of lists with DC30 entries each: the base distance of every
    slot and the number of extra bits that slot carries. The tables are built
    once, when the module is imported.
    """
    # Number of consecutive slots that share each extra-bit width (0, 1, 2, ...).
    slots_per_width = [4, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 14, 0, 12]
    bases = [0] * DC30
    widths = [0] * DC30
    index = 0
    base = 0
    for width, run in enumerate(slots_per_width):
        for _ in range(run):
            if index < DC30:
                bases[index] = base
                widths[index] = width
            index += 1
            # Every slot covers 2**width distances.
            base += 1 << width
    return bases, widths


_DDecode, _DBits = _init_dist_tables()
def vm_read_data(inp: BitInput) -> int:
    """
    Read a variable-length data value from the VM code stream.

    The two top bits of the next 16-bit window select the encoding:
    00 = 4-bit value, 01 = 8-bit value (with a sign-extended form when the
    following four bits are zero), 10 = 16-bit value, 11 = 32-bit value.
    The result is always an unsigned 32-bit integer.
    """
    head = inp.getbits()
    selector = head & 0xC000
    if selector == 0:
        inp.addbits(6)
        return (head >> 10) & 0xF
    if selector == 0x4000:
        if head & 0x3C00:
            inp.addbits(10)
            value = (head >> 6) & 0xFF
        else:
            # Upper 24 bits are set: the byte is extended with ones.
            inp.addbits(14)
            value = 0xFFFFFF00 | ((head >> 2) & 0xFF)
        return value & _M32
    # Both remaining encodings start by skipping the 2-bit selector.
    inp.addbits(2)
    if selector == 0x8000:
        word = inp.getbits()
        inp.addbits(16)
        return word & 0xFFFF
    upper = (inp.getbits() << 16) & _M32
    inp.addbits(16)
    lower = inp.getbits()
    inp.addbits(16)
    return (upper | lower) & _M32
class _SubAllocator:
    """
    PPMd sub-allocator.

    Memory is one flat bytearray and every "pointer" is an integer offset into
    it; offset 0 doubles as the null value. Storage is handed out in units of
    UNIT_SIZE bytes, grouped into N_INDEXES size classes with one singly
    linked free list per class (the link is stored in the first four bytes of
    each free block).
    """

    def __init__(self):
        self.heap: bytearray = bytearray()
        self.heap_start = 0
        self.heap_end = 0
        self.p_text = 0
        self.units_start = 0
        self.lo_unit = 0
        self.hi_unit = 0
        self.fake_units_start = 0
        self.glue_count = 0
        self.sub_allocator_size = 0
        self.free_list = [0] * N_INDEXES
        self.indx2units = [0] * N_INDEXES
        self.units2indx = [0] * 128

    def start(self, mb_count: int) -> bool:
        """
        Make sure a heap for `mb_count` megabytes exists. An existing heap of
        the same nominal size is reused untouched. Always returns True.
        """
        requested = mb_count << 20
        if requested != self.sub_allocator_size:
            total = requested // FIXED_UNIT_SIZE * UNIT_SIZE + 2 * UNIT_SIZE
            self.heap = bytearray(total)
            self.heap_start = 0
            self.heap_end = total - UNIT_SIZE
            self.sub_allocator_size = requested
        return True

    def stop(self):
        """
        Drop the heap and mark the allocator as unallocated.
        """
        self.sub_allocator_size = 0
        self.heap = bytearray()

    def init(self):
        """
        Reset all bookkeeping: empty the free lists, split the heap into a
        text area (first eighth) and a unit area (remaining seven eighths) and
        rebuild the size-class lookup tables.
        """
        self.free_list = [0] * N_INDEXES
        self.p_text = self.heap_start
        total = self.sub_allocator_size
        unit_area = FIXED_UNIT_SIZE * (total // 8 // FIXED_UNIT_SIZE * 7)
        text_area = total - unit_area
        unit_bytes = unit_area // FIXED_UNIT_SIZE * UNIT_SIZE
        text_bytes = text_area // FIXED_UNIT_SIZE * UNIT_SIZE + UNIT_SIZE
        self.units_start = self.heap_start + text_bytes
        self.lo_unit = self.units_start
        self.fake_units_start = self.heap_start + text_area
        self.hi_unit = self.lo_unit + unit_bytes
        # Size classes: N1 classes growing by 1 unit, then N2 by 2, N3 by 3
        # and N4 by 4, with one extra unit skipped between groups.
        slot = 0
        units = 1
        for step, group in enumerate((N1, N2, N3, N4), start=1):
            for _ in range(group):
                self.indx2units[slot] = units
                slot += 1
                units += step
            units += 1
        self.glue_count = 0
        # Map "unit count - 1" to the smallest class that can hold it.
        slot = 0
        for size in range(128):
            if slot < N_INDEXES and self.indx2units[slot] < size + 1:
                slot += 1
            self.units2indx[size] = min(slot, N_INDEXES - 1)

    def _u2b(self, nu: int) -> int:
        """
        Convert a unit count into a byte count.
        """
        return nu * UNIT_SIZE

    def _insert_node(self, p: int, indx: int):
        """
        Push block `p` onto the free list of size class `indx`.
        """
        struct.pack_into('<I', self.heap, p, self.free_list[indx])
        self.free_list[indx] = p

    def _remove_node(self, indx: int) -> int:
        """
        Pop and return the head of the free list of size class `indx`.
        """
        head = self.free_list[indx]
        (self.free_list[indx],) = struct.unpack_from('<I', self.heap, head)
        return head

    def _split_block(self, pv: int, old_indx: int, new_indx: int):
        """
        Keep the first `indx2units[new_indx]` units of block `pv` and return
        the remainder of the `old_indx`-sized block to the free lists, split
        into at most two pieces.
        """
        spare = self.indx2units[old_indx] - self.indx2units[new_indx]
        tail = pv + self._u2b(self.indx2units[new_indx])
        slot = self.units2indx[spare - 1] if spare > 0 else 0
        if slot < N_INDEXES and self.indx2units[slot] != spare:
            # No class matches exactly: free the next smaller class first.
            slot -= 1
            self._insert_node(tail, slot)
            piece = self.indx2units[slot]
            tail += self._u2b(piece)
            spare -= piece
        if spare > 0:
            self._insert_node(tail, self.units2indx[spare - 1])

    def alloc_context(self) -> int:
        """
        Allocate one unit for a context, preferring the top of the unit area.
        Returns 0 when no memory is left.
        """
        if self.hi_unit != self.lo_unit:
            self.hi_unit -= UNIT_SIZE
            return self.hi_unit
        if self.free_list[0]:
            return self._remove_node(0)
        return self._alloc_units_rare(0)

    def alloc_units(self, nu: int) -> int:
        """
        Allocate a block able to hold `nu` units. Returns 0 on failure.
        """
        indx = self.units2indx[min(nu - 1, 127)]
        if self.free_list[indx]:
            return self._remove_node(indx)
        size = self._u2b(self.indx2units[indx])
        block = self.lo_unit
        if block + size <= self.hi_unit:
            self.lo_unit = block + size
            return block
        return self._alloc_units_rare(indx)

    def _alloc_units_rare(self, indx: int) -> int:
        """
        Slow allocation path: optionally defragment the free lists, then take
        a block from a larger size class, and finally fall back to carving
        space out of the gap below `units_start`. Returns 0 on failure.
        """
        if not self.glue_count:
            self.glue_count = 255
            self._glue_free_blocks()
            if self.free_list[indx]:
                return self._remove_node(indx)
        for donor in range(indx + 1, N_INDEXES):
            if self.free_list[donor]:
                block = self._remove_node(donor)
                self._split_block(block, donor, indx)
                return block
        # No larger free block exists.
        self.glue_count -= 1
        units = self.indx2units[indx]
        needed = FIXED_UNIT_SIZE * units
        if self.fake_units_start - self.p_text > needed:
            self.fake_units_start -= needed
            self.units_start -= self._u2b(units)
            return self.units_start
        return 0

    def _glue_free_blocks(self):
        """
        Empty every free list, merge blocks that are adjacent in memory and
        redistribute the merged runs over the size classes.
        """
        blocks: list[tuple[int, int]] = []
        for indx in range(N_INDEXES):
            while self.free_list[indx]:
                addr = self._remove_node(indx)
                blocks.append((addr, self.indx2units[indx]))
        blocks.sort()
        runs: list[tuple[int, int]] = []
        for addr, nu in blocks:
            if runs:
                start, length = runs[-1]
                if start + self._u2b(length) == addr:
                    runs[-1] = (start, length + nu)
                    continue
            runs.append((addr, nu))
        largest = N_INDEXES - 1
        for addr, nu in runs:
            # Anything beyond 128 units is cut into maximum-size blocks.
            while nu > 128:
                self._insert_node(addr, largest)
                addr += self._u2b(128)
                nu -= 128
            if nu <= 0:
                continue
            idx = self.units2indx[min(nu - 1, 127)]
            if idx < N_INDEXES and self.indx2units[idx] != nu:
                # Not an exact class size: free the excess separately and
                # file the rest under the next smaller class.
                excess = nu - self.indx2units[idx - 1] if idx > 0 else nu
                if 0 < excess <= 128:
                    self._insert_node(
                        addr + self._u2b(nu - excess),
                        self.units2indx[min(excess - 1, 127)],
                    )
                if idx > 0:
                    idx -= 1
            self._insert_node(addr, idx)

    def expand_units(self, old_ptr: int, old_nu: int) -> int:
        """
        Grow a block from `old_nu` to `old_nu + 1` units, moving it if the
        size class changes. Returns the (possibly new) offset, 0 on failure.
        """
        current = self.units2indx[min(old_nu - 1, 127)]
        grown = self.units2indx[min(old_nu, 127)]
        if current == grown:
            return old_ptr
        new_ptr = self.alloc_units(old_nu + 1)
        if new_ptr:
            nbytes = self._u2b(old_nu)
            self.heap[new_ptr:new_ptr + nbytes] = self.heap[old_ptr:old_ptr + nbytes]
            self._insert_node(old_ptr, current)
        return new_ptr

    def shrink_units(self, old_ptr: int, old_nu: int, new_nu: int) -> int:
        """
        Shrink a block from `old_nu` to `new_nu` units and return its
        (possibly new) offset.
        """
        current = self.units2indx[min(old_nu - 1, 127)]
        target = self.units2indx[min(new_nu - 1, 127)]
        if current == target:
            return old_ptr
        if not self.free_list[target]:
            # Shrink in place and free the tail.
            self._split_block(old_ptr, current, target)
            return old_ptr
        new_ptr = self._remove_node(target)
        nbytes = self._u2b(new_nu)
        self.heap[new_ptr:new_ptr + nbytes] = self.heap[old_ptr:old_ptr + nbytes]
        self._insert_node(old_ptr, current)
        return new_ptr

    def free_units(self, ptr: int, old_nu: int):
        """
        Return a block of `old_nu` units to its free list.
        """
        self._insert_node(ptr, self.units2indx[min(old_nu - 1, 127)])

    @property
    def allocated(self) -> int:
        """
        Nominal heap size in bytes; 0 while no heap is allocated.
        """
        return self.sub_allocator_size
# Byte size of one PPM state record: symbol (1), frequency (1), successor (4).
_STATE_SIZE = 6
# Byte size of one PPM context record; equal to one allocation unit.
_CTX_SIZE = 12
class _PPMHeap:
    """
    Accessor for PPM structures stored in sub-allocator heap.

    State record (6 bytes):    symbol u8 @0, freq u8 @1, successor u32 @2.
    Context record (12 bytes): num_stats u16 @0, summ_freq u16 @2,
    stats u32 @4, suffix u32 @8. A context with a single state stores that
    state inline at offset 2, overlapping summ_freq and stats.
    All multi-byte fields are little-endian.
    """
    _U16 = struct.Struct('<H')
    _U32 = struct.Struct('<I')

    def __init__(self, sa: _SubAllocator):
        self.sa = sa
        # The heap object is captured once; a new accessor is needed whenever
        # the allocator replaces its heap.
        self.h = sa.heap

    # State fields.

    def st_symbol(self, p: int) -> int:
        return self.h[p]

    def st_set_symbol(self, p: int, v: int):
        self.h[p] = v & 0xFF

    def st_freq(self, p: int) -> int:
        return self.h[p + 1]

    def st_set_freq(self, p: int, v: int):
        self.h[p + 1] = v & 0xFF

    def st_successor(self, p: int) -> int:
        (value,) = self._U32.unpack_from(self.h, p + 2)
        return value

    def st_set_successor(self, p: int, v: int):
        self._U32.pack_into(self.h, p + 2, v & _M32)

    def st_copy(self, dst: int, src: int):
        """
        Copy one state record from `src` to `dst`.
        """
        heap = self.h
        heap[dst:dst + _STATE_SIZE] = heap[src:src + _STATE_SIZE]

    def st_swap(self, a: int, b: int):
        """
        Exchange the state records at `a` and `b`.
        """
        heap = self.h
        end_a = a + _STATE_SIZE
        end_b = b + _STATE_SIZE
        heap[a:end_a], heap[b:end_b] = heap[b:end_b], heap[a:end_a]

    # Context fields.

    def ctx_num_stats(self, c: int) -> int:
        (value,) = self._U16.unpack_from(self.h, c)
        return value

    def ctx_set_num_stats(self, c: int, v: int):
        self._U16.pack_into(self.h, c, v & 0xFFFF)

    def ctx_summ_freq(self, c: int) -> int:
        (value,) = self._U16.unpack_from(self.h, c + 2)
        return value

    def ctx_set_summ_freq(self, c: int, v: int):
        self._U16.pack_into(self.h, c + 2, v & 0xFFFF)

    def ctx_stats(self, c: int) -> int:
        (value,) = self._U32.unpack_from(self.h, c + 4)
        return value

    def ctx_set_stats(self, c: int, v: int):
        self._U32.pack_into(self.h, c + 4, v & _M32)

    def ctx_one_state(self, c: int) -> int:
        """
        Return offset of the OneState within context (starts at c+2).
        """
        return c + 2

    def ctx_suffix(self, c: int) -> int:
        (value,) = self._U32.unpack_from(self.h, c + 8)
        return value

    def ctx_set_suffix(self, c: int, v: int):
        self._U32.pack_into(self.h, c + 8, v & _M32)

    def state_at(self, stats: int, i: int) -> int:
        """
        Return the offset of the i-th state in the array starting at `stats`.
        """
        return stats + _STATE_SIZE * i
class _SEE2Context:
    """
    Adaptive escape estimator: a 16-bit running sum together with the shift
    that scales it and a countdown until the next precision increase.
    """
    __slots__ = ('summ', 'shift', 'count')

    def __init__(self, init_val: int = 0):
        self.summ = self.shift = self.count = 0
        if init_val:
            self.init(init_val)

    def init(self, init_val: int):
        """
        Reset the estimator so that its mean equals `init_val`.
        """
        shift = PERIOD_BITS - 4
        self.shift = shift
        self.summ = (init_val << shift) & 0xFFFF
        self.count = 4

    def get_mean(self) -> int:
        """
        Return the current mean (at least 1) and subtract it from the sum.
        """
        mean = (self.summ & 0xFFFF) >> self.shift
        self.summ = (self.summ - mean) & 0xFFFF
        return mean or 1

    def update(self):
        """
        Count one use; when the countdown expires, double the sum and raise
        the shift by one, until the shift reaches PERIOD_BITS.
        """
        if self.shift >= PERIOD_BITS:
            return
        self.count -= 1
        if self.count:
            return
        self.summ = (self.summ << 1) & 0xFFFF
        self.count = 3 << self.shift
        self.shift += 1
class _RangeCoder:
    """
    32-bit range decoder.

    `low`, `code` and `range` are kept within 32 bits by masking. A decoding
    step consists of setting `scale`, querying a count, setting
    `low_count` / `high_count`, and then calling `decode` and `normalize`.
    """
    __slots__ = (
        'low',
        'code',
        'range',
        'low_count',
        'high_count',
        'scale',
        '_reader',
    )
    _reader: Callable[[], int]

    def __init__(self):
        self.low = 0
        self.code = 0
        self.range = _M32
        self.low_count = 0
        self.high_count = 0
        self.scale = 0

    def init_decoder(self, reader: Callable[[], int]):
        """
        Initialize from a byte reader (callable returning int).

        The first four bytes are consumed to prime `code`.
        """
        self._reader = reader
        self.low = 0
        self.code = 0
        self.range = _M32
        for _ in range(4):
            self.code = ((self.code << 8) | (self._reader() & 0xFF)) & _M32

    def _get_char(self) -> int:
        """
        Fetch the next input byte.
        """
        return self._reader() & 0xFF

    def get_current_count(self) -> int:
        """
        Divide `range` by `scale` and return the position of `code` within
        the scaled range.

        When `scale` is zero (only possible with a corrupt model) the raw
        offset is returned and `range` is left untouched. The result is then
        never smaller than `scale`, so callers that verify `count < scale`
        reject the symbol instead of hitting a ZeroDivisionError.
        """
        offset = (self.code - self.low) & _M32
        if not self.scale:
            return offset
        self.range = (self.range // self.scale) & _M32
        if self.range == 0:
            # Keep the division below well defined.
            self.range = 1
        return offset // self.range

    def get_current_shift_count(self, shift: int) -> int:
        """
        Like `get_current_count` for a power-of-two scale of `1 << shift`.
        """
        self.range = (self.range >> shift) & _M32
        if self.range == 0:
            self.range = 1
        return ((self.code - self.low) & _M32) // self.range

    def decode(self):
        """
        Narrow the interval to [low_count, high_count) of the scaled range.
        """
        self.low = (self.low + self.range * self.low_count) & _M32
        self.range = (self.range * (self.high_count - self.low_count)) & _M32

    def normalize(self):
        """
        Shift in input bytes until the top byte of the interval is undecided
        and the range is at least BOT.
        """
        while True:
            if ((self.low ^ ((self.low + self.range) & _M32)) & _M32) >= TOP:
                if (self.range & _M32) >= BOT:
                    break
                # Range underflow: clamp the range to the distance from
                # `low` to the next BOT boundary.
                self.range = ((-(self.low & 0xFFFFFFFF)) & (BOT - 1)) & _M32
            self.code = ((self.code << 8) | self._get_char()) & _M32
            self.range = (self.range << 8) & _M32
            self.low = (self.low << 8) & _M32
class _ModelPPM:
    """
    PPMd Model H for RAR3 decompression.

    Contexts and states live in the sub-allocator heap and are referenced by
    integer offsets; offset 0 is used as the null reference throughout.
    """

    def __init__(self):
        self.sa = _SubAllocator()
        self.hp: _PPMHeap = _PPMHeap(self.sa)
        self.coder = _RangeCoder()
        # Context offsets: the one currently decoded from and the one the
        # model update starts from.
        self.min_context = 0
        self.max_context = 0
        # Offset of the state chosen by the last decode step; 0 means escape.
        self.found_state = 0
        self.num_masked = 0
        self.init_esc = 0
        self.order_fall = 0
        self.max_order = 0
        self.run_length = 0
        self.init_rl = 0
        # Symbols whose char_mask entry equals esc_count are masked.
        self.esc_count = 0
        self.prev_success = 0
        self.hi_bits_flag = 0
        self.char_mask = [0] * 256
        self.ns2indx = [0] * 256
        self.ns2bs_indx = [0] * 256
        self.hb2flag = [0] * 256
        self.bin_summ = [[0] * 64 for _ in range(128)]
        self.see2_cont = [[_SEE2Context() for _ in range(16)] for _ in range(25)]
        self.dummy_see2 = _SEE2Context()

    def _restart_model_rare(self):
        """
        Reset the allocator and rebuild the initial model: a single root
        context holding all 256 symbols with frequency 1, plus freshly
        initialised binary and SEE estimators. Returns early, leaving
        min_context at 0, if the root context cannot be allocated.
        """
        hp = self.hp
        sa = self.sa
        self.char_mask = [0] * 256
        sa.init()
        self.init_rl = -(min(self.max_order, 12)) - 1
        self.min_context = self.max_context = sa.alloc_context()
        if not self.min_context:
            return
        hp.ctx_set_suffix(self.min_context, 0)
        self.order_fall = self.max_order
        hp.ctx_set_num_stats(self.min_context, 256)
        hp.ctx_set_summ_freq(self.min_context, 257)
        # 256 states of 6 bytes each occupy 128 units of 12 bytes.
        stats = sa.alloc_units(128)
        if not stats:
            return
        hp.ctx_set_stats(self.min_context, stats)
        self.found_state = stats
        self.run_length = self.init_rl
        self.prev_success = 0
        for i in range(256):
            s = hp.state_at(stats, i)
            hp.st_set_symbol(s, i)
            hp.st_set_freq(s, 1)
            hp.st_set_successor(s, 0)
        init_bin_esc = [0x3CDD, 0x1F3F, 0x59BF, 0x48F3, 0x64A1, 0x5ABC, 0x6632, 0x6051]
        for i in range(128):
            for k in range(8):
                for m in range(0, 64, 8):
                    self.bin_summ[i][k + m] = BIN_SCALE - init_bin_esc[k] // (i + 2)
        for i in range(25):
            for k in range(16):
                self.see2_cont[i][k].init(5 * i + 10)

    def _start_model_rare(self, max_order: int):
        """
        Start a model of the given maximum order and fill the static lookup
        tables (ns2bs_indx, ns2indx, hb2flag).
        """
        self.esc_count = 1
        self.max_order = max_order
        self._restart_model_rare()
        self.ns2bs_indx[0] = 0
        self.ns2bs_indx[1] = 2
        for i in range(2, 11):
            self.ns2bs_indx[i] = 4
        for i in range(11, 256):
            self.ns2bs_indx[i] = 6
        for i in range(3):
            self.ns2indx[i] = i
        # From index 3 on, each value m is repeated for a run that grows by
        # one entry every time m advances.
        m = 3
        step = 1
        k = step
        for i in range(3, 256):
            self.ns2indx[i] = m
            k -= 1
            if k == 0:
                step += 1
                k = step
                m += 1
        for i in range(0x40):
            self.hb2flag[i] = 0
        for i in range(0x40, 0x100):
            self.hb2flag[i] = 0x08
        # With shift == PERIOD_BITS the dummy estimator never adapts.
        self.dummy_see2.shift = PERIOD_BITS

    def decode_init(self, byte_reader: Callable[[], int], esc_char_ref: list) -> bool:
        """
        Initialize PPM decoding. esc_char_ref is [esc_char] mutable.

        The first byte carries the model order in its low five bits, a reset
        flag (0x20) and an escape-character flag (0x40). On reset a further
        byte gives the heap size in megabytes minus one. Returns False when
        the reader signals an error (negative value), when no reset is
        requested but no heap exists, when the order is 1, or when the model
        could not be set up.
        """
        max_order = byte_reader()
        if max_order < 0:
            return False
        reset = bool(max_order & 0x20)
        max_mb = 0
        if reset:
            max_mb = byte_reader()
            if max_mb < 0:
                return False
        elif self.sa.allocated == 0:
            return False
        if max_order & 0x40:
            ch = byte_reader()
            if ch < 0:
                return False
            esc_char_ref[0] = ch
        self.coder.init_decoder(byte_reader)
        if reset:
            order = (max_order & 0x1F) + 1
            if order > 16:
                order = 16 + (order - 16) * 3
            if order == 1:
                self.sa.stop()
                return False
            self.sa.start(max_mb + 1)
            # The allocator may have replaced its heap; rebind the accessor.
            self.hp = _PPMHeap(self.sa)
            self._start_model_rare(order)
        return self.min_context != 0

    def decode_char(self) -> int:
        """
        Decode one character. Returns -1 on error.

        Decodes in the current context and, on escape, walks the suffix chain
        until a symbol is found; afterwards the model is updated.
        """
        hp = self.hp
        sa = self.sa
        if not hp or not self.min_context:
            return -1
        if self.min_context >= sa.heap_end or self.min_context < sa.p_text:
            return -1
        if hp.ctx_num_stats(self.min_context) != 1:
            stats = hp.ctx_stats(self.min_context)
            if stats < sa.p_text or stats >= sa.heap_end:
                return -1
            if not self._decode_symbol1(self.min_context):
                return -1
        else:
            self._decode_bin_symbol(self.min_context)
        self.coder.decode()
        while not self.found_state:
            # Escape: move to the nearest suffix context that has symbols
            # which are not masked yet.
            self.coder.normalize()
            while True:
                self.order_fall += 1
                self.min_context = hp.ctx_suffix(self.min_context)
                if not self.min_context or self.min_context < sa.p_text or self.min_context >= sa.heap_end:
                    return -1
                if hp.ctx_num_stats(self.min_context) != self.num_masked:
                    break
            if not self._decode_symbol2(self.min_context):
                return -1
            self.coder.decode()
        symbol = hp.st_symbol(self.found_state)
        if not self.order_fall and hp.st_successor(self.found_state) > sa.p_text:
            # The successor is already a real context; follow it directly.
            self.min_context = self.max_context = hp.st_successor(self.found_state)
        else:
            self._update_model()
            if self.esc_count == 0:
                self._clear_mask()
        self.coder.normalize()
        return symbol

    def _decode_bin_symbol(self, ctx: int):
        """
        Decode in a context that holds a single (inline) state: either that
        symbol is accepted or an escape is signalled via found_state = 0.
        """
        hp = self.hp
        os = hp.ctx_one_state(ctx)
        rs_freq = hp.st_freq(os)
        rs_symbol = hp.st_symbol(os)
        suffix = hp.ctx_suffix(ctx)
        self.hi_bits_flag = self.hb2flag[hp.st_symbol(self.found_state)] if self.found_state else 0
        ns = hp.ctx_num_stats(suffix) if suffix else 1
        bs_idx = (rs_freq - 1) & 0x7F
        # Column of the probability table; the last term contributes 0x20
        # while run_length is negative.
        bs_off = (
            self.prev_success + self.ns2bs_indx[min(ns - 1, 255)]
            + self.hi_bits_flag + 2 * self.hb2flag[rs_symbol]
            + ((self.run_length >> 26) & 0x20)
        )
        bs_off = min(bs_off, 63)
        bs = self.bin_summ[bs_idx][bs_off]
        if self.coder.get_current_shift_count(TOT_BITS) < bs:
            self.found_state = os
            if rs_freq < 128:
                hp.st_set_freq(os, rs_freq + 1)
            self.coder.low_count = 0
            self.coder.high_count = bs
            self.bin_summ[bs_idx][bs_off] = min(bs + INTERVAL - self._get_mean(bs, PERIOD_BITS, 2), 0xFFFF) & 0xFFFF
            self.prev_success = 1
            self.run_length += 1
        else:
            self.coder.low_count = bs
            bs = max(bs - self._get_mean(bs, PERIOD_BITS, 2), 0) & 0xFFFF
            self.bin_summ[bs_idx][bs_off] = bs
            self.coder.high_count = BIN_SCALE
            self.init_esc = ExpEscape[min(bs >> 10, 15)]
            self.num_masked = 1
            self.char_mask[rs_symbol] = self.esc_count
            self.prev_success = 0
            self.found_state = 0

    @staticmethod
    def _get_mean(summ: int, shift: int, rnd: int) -> int:
        """
        Return `summ` scaled down by `shift` bits after adding the rounding
        term `1 << (shift - rnd)`.
        """
        return (summ + (1 << (shift - rnd))) >> shift

    def _decode_symbol1(self, ctx: int) -> bool:
        """
        Decode in a multi-symbol context without masking. Returns False when
        the decoded count is inconsistent with the context; on escape all
        symbols of the context are masked and found_state is set to 0.
        """
        hp = self.hp
        self.coder.scale = hp.ctx_summ_freq(ctx)
        stats = hp.ctx_stats(ctx)
        count = self.coder.get_current_count()
        if count >= self.coder.scale:
            return False
        p = stats
        hi_cnt = hp.st_freq(p)
        if count < hi_cnt:
            # The first (most frequent) state matched.
            self.prev_success = 1 if (2 * hi_cnt > self.coder.scale) else 0
            self.coder.high_count = hi_cnt
            self.run_length += self.prev_success
            self.found_state = p
            hi_cnt += 4
            hp.st_set_freq(p, min(hi_cnt, 255))
            summ_freq = hp.ctx_summ_freq(ctx) + 4
            hp.ctx_set_summ_freq(ctx, min(summ_freq, 0xFFFF))
            if hi_cnt > MAX_FREQ:
                self._rescale(ctx)
            self.coder.low_count = 0
            return True
        elif not self.found_state:
            return False
        self.prev_success = 0
        num_stats = hp.ctx_num_stats(ctx)
        i = num_stats - 1
        while i > 0:
            p += _STATE_SIZE
            hi_cnt += hp.st_freq(p)
            if hi_cnt > count:
                break
            i -= 1
        else:
            # All states were passed without a match: escape.
            if hi_cnt <= count:
                self.hi_bits_flag = self.hb2flag[hp.st_symbol(self.found_state)] if self.found_state else 0
                self.coder.low_count = hi_cnt
                self.char_mask[hp.st_symbol(p)] = self.esc_count
                self.num_masked = num_stats
                i = num_stats - 1
                self.found_state = 0
                while i > 0:
                    p -= _STATE_SIZE
                    self.char_mask[hp.st_symbol(p)] = self.esc_count
                    i -= 1
                self.coder.high_count = self.coder.scale
                return True
        self.coder.low_count = hi_cnt - hp.st_freq(p)
        self.coder.high_count = hi_cnt
        self._update1(ctx, p)
        return True

    def _update1(self, ctx: int, p: int):
        """
        Credit state `p` after an unmasked hit and move it one slot towards
        the front if it now outweighs its predecessor.
        """
        hp = self.hp
        self.found_state = p
        freq = hp.st_freq(p) + 4
        hp.st_set_freq(p, min(freq, 255))
        summ = hp.ctx_summ_freq(ctx) + 4
        hp.ctx_set_summ_freq(ctx, min(summ, 0xFFFF))
        stats = hp.ctx_stats(ctx)
        if p > stats and hp.st_freq(p) > hp.st_freq(p - _STATE_SIZE):
            hp.st_swap(p, p - _STATE_SIZE)
            self.found_state = p - _STATE_SIZE
            p = self.found_state
            if hp.st_freq(p) > MAX_FREQ:
                self._rescale(ctx)

    def _decode_symbol2(self, ctx: int) -> bool:
        """
        Decode in a multi-symbol context after an escape, considering only
        the symbols that are not masked. Returns False on inconsistent data.
        """
        hp = self.hp
        num_stats = hp.ctx_num_stats(ctx)
        diff = num_stats - self.num_masked
        # Sets coder.scale to the escape estimate as a side effect.
        see2c = self._make_esc_freq2(ctx, diff)
        stats = hp.ctx_stats(ctx)
        ps = []
        hi_cnt = 0
        p = stats - _STATE_SIZE
        i = diff
        while i > 0:
            p += _STATE_SIZE
            while self.char_mask[hp.st_symbol(p)] == self.esc_count:
                p += _STATE_SIZE
            hi_cnt += hp.st_freq(p)
            ps.append(p)
            i -= 1
        self.coder.scale += hi_cnt
        count = self.coder.get_current_count()
        if count >= self.coder.scale:
            return False
        if count < hi_cnt:
            hi2 = 0
            idx = 0
            p = ps[0]
            while True:
                hi2 += hp.st_freq(p)
                if hi2 > count:
                    break
                idx += 1
                if idx >= len(ps):
                    return False
                p = ps[idx]
            self.coder.low_count = hi2 - hp.st_freq(p)
            self.coder.high_count = hi2
            if see2c is not self.dummy_see2:
                see2c.update()
            self._update2(ctx, p)
        else:
            # Escape again: mask the symbols just considered.
            self.coder.low_count = hi_cnt
            self.coder.high_count = self.coder.scale
            i = diff
            for pp in ps:
                self.char_mask[hp.st_symbol(pp)] = self.esc_count
            if see2c is not self.dummy_see2:
                see2c.summ = (see2c.summ + self.coder.scale) & 0xFFFF
            self.num_masked = num_stats
        return True

    def _update2(self, ctx: int, p: int):
        """
        Credit state `p` after a hit that followed an escape and start a new
        masking generation.
        """
        hp = self.hp
        self.found_state = p
        freq = hp.st_freq(p) + 4
        hp.st_set_freq(p, min(freq, 255))
        summ = hp.ctx_summ_freq(ctx) + 4
        hp.ctx_set_summ_freq(ctx, min(summ, 0xFFFF))
        if hp.st_freq(p) > MAX_FREQ:
            self._rescale(ctx)
        self.esc_count += 1
        self.run_length = self.init_rl

    def _make_esc_freq2(self, ctx: int, diff: int) -> _SEE2Context:
        """
        Select the SEE estimator for `ctx`, store its mean in coder.scale and
        return it. The root-sized context (256 symbols) uses the dummy
        estimator with a scale of 1.
        """
        hp = self.hp
        num_stats = hp.ctx_num_stats(ctx)
        if num_stats != 256:
            ns_idx = self.ns2indx[min(diff - 1, 255)]
            suffix = hp.ctx_suffix(ctx)
            suffix_ns = hp.ctx_num_stats(suffix) if suffix else 1
            off = (
                int(diff < suffix_ns - num_stats)
                + 2 * int(hp.ctx_summ_freq(ctx) < 11 * num_stats)
                + 4 * int(self.num_masked > diff)
                + self.hi_bits_flag
            )
            off = min(off, 15)
            psee2c = self.see2_cont[min(ns_idx, 24)][off]
            self.coder.scale = psee2c.get_mean()
            return psee2c
        else:
            self.coder.scale = 1
            return self.dummy_see2

    def _rescale(self, ctx: int):
        """
        Halve all frequencies of `ctx`, keep the states sorted by frequency
        with found_state in front, drop states whose frequency became zero
        and convert the context to the single-state form if only one remains.
        """
        hp = self.hp
        sa = self.sa
        num_stats = hp.ctx_num_stats(ctx)
        stats = hp.ctx_stats(ctx)
        if self.found_state and self.found_state != stats:
            # Bubble the found state to the front of the array.
            p = self.found_state
            while p > stats:
                hp.st_swap(p, p - _STATE_SIZE)
                p -= _STATE_SIZE
        freq0 = hp.st_freq(stats)
        hp.st_set_freq(stats, min(freq0 + 4, 255))
        summ_freq = hp.ctx_summ_freq(ctx) + 4
        esc_freq = summ_freq - hp.st_freq(stats)
        adder = 1 if self.order_fall != 0 else 0
        new_freq = ((hp.st_freq(stats) + adder) >> 1)
        hp.st_set_freq(stats, max(new_freq, 1))
        new_summ = hp.st_freq(stats)
        i = num_stats - 1
        p = stats + _STATE_SIZE
        while i > 0:
            esc_freq -= hp.st_freq(p)
            f = ((hp.st_freq(p) + adder) >> 1)
            f = max(f, 0)
            hp.st_set_freq(p, f)
            new_summ += f
            if f > hp.st_freq(p - _STATE_SIZE):
                # Insertion step: move this state ahead of lighter ones.
                tmp = hp.h[p:p + _STATE_SIZE]
                q = p
                while q > stats and f > hp.st_freq(q - _STATE_SIZE):
                    hp.st_copy(q, q - _STATE_SIZE)
                    q -= _STATE_SIZE
                hp.h[q:q + _STATE_SIZE] = tmp
            p += _STATE_SIZE
            i -= 1
        # States with frequency zero are now at the end of the array.
        p = stats + (num_stats - 1) * _STATE_SIZE
        zero_count = 0
        while p > stats and hp.st_freq(p) == 0:
            zero_count += 1
            p -= _STATE_SIZE
        if zero_count > 0:
            esc_freq += zero_count
            num_stats -= zero_count
            hp.ctx_set_num_stats(ctx, num_stats)
            if num_stats == 1:
                # Read the surviving state before its storage is freed.
                tmp_sym = hp.st_symbol(stats)
                tmp_freq = hp.st_freq(stats)
                tmp_succ = hp.st_successor(stats)
                while tmp_freq > 1 and esc_freq > 1:
                    tmp_freq -= tmp_freq >> 1
                    esc_freq >>= 1
                old_nu = ((num_stats + zero_count) + 1) >> 1
                sa.free_units(stats, old_nu)
                os = hp.ctx_one_state(ctx)
                hp.st_set_symbol(os, tmp_sym)
                hp.st_set_freq(os, tmp_freq)
                hp.st_set_successor(os, tmp_succ)
                self.found_state = os
                return
            else:
                old_n = ((num_stats + zero_count) + 1) >> 1
                new_n = (num_stats + 1) >> 1
                if old_n != new_n:
                    new_stats = sa.shrink_units(stats, old_n, new_n)
                    hp.ctx_set_stats(ctx, new_stats)
                    stats = new_stats
        esc_freq -= esc_freq >> 1
        new_summ += max(esc_freq, 1)
        hp.ctx_set_summ_freq(ctx, min(new_summ, 0xFFFF))
        self.found_state = hp.ctx_stats(ctx)

    def _create_successors(self, skip: bool, p1: int) -> int:
        """
        Create the chain of single-state child contexts for the symbol that
        follows found_state. States whose successor still equals found_state's
        raw successor value are collected along the suffix chain (at most
        MAX_O of them) and each receives a new child. Returns the offset of
        the last context created, the existing context if nothing had to be
        created, or 0 on failure.
        """
        hp = self.hp
        sa = self.sa
        pc = self.min_context
        up_branch = hp.st_successor(self.found_state)
        ps = []
        if not skip:
            ps.append(self.found_state)
            if not hp.ctx_suffix(pc):
                # No suffix to walk: disable both loops below.
                p1 = 0
        if p1:
            # The state in the first suffix context is already known.
            p = p1
            pc = hp.ctx_suffix(pc)
            if hp.st_successor(p) != up_branch:
                pc = hp.st_successor(p)
            else:
                if len(ps) >= MAX_O:
                    return 0
                ps.append(p)
                while hp.ctx_suffix(pc):
                    pc = hp.ctx_suffix(pc)
                    if hp.ctx_num_stats(pc) != 1:
                        p = hp.ctx_stats(pc)
                        if hp.st_symbol(p) != hp.st_symbol(self.found_state):
                            while hp.st_symbol(p) != hp.st_symbol(self.found_state):
                                p += _STATE_SIZE
                    else:
                        p = hp.ctx_one_state(pc)
                    if hp.st_successor(p) != up_branch:
                        pc = hp.st_successor(p)
                        break
                    if len(ps) >= MAX_O:
                        return 0
                    ps.append(p)
        elif hp.ctx_suffix(pc):
            pc = hp.ctx_suffix(pc)
            while True:
                if hp.ctx_num_stats(pc) != 1:
                    p = hp.ctx_stats(pc)
                    if hp.st_symbol(p) != hp.st_symbol(self.found_state):
                        while hp.st_symbol(p) != hp.st_symbol(self.found_state):
                            p += _STATE_SIZE
                else:
                    p = hp.ctx_one_state(pc)
                if hp.st_successor(p) != up_branch:
                    pc = hp.st_successor(p)
                    break
                if len(ps) >= MAX_O:
                    return 0
                ps.append(p)
                if not hp.ctx_suffix(pc):
                    break
                pc = hp.ctx_suffix(pc)
        if len(ps) == 0:
            return pc
        if up_branch >= len(hp.h) or up_branch < 1:
            return 0
        # up_branch is read as a raw heap position: the byte stored there is
        # the symbol of the new state and the following position becomes its
        # successor.
        up_symbol = hp.h[up_branch] if up_branch < len(hp.h) else 0
        up_successor = up_branch + 1
        if hp.ctx_num_stats(pc) != 1:
            if pc <= sa.p_text:
                return 0
            stats = hp.ctx_stats(pc)
            pp = stats
            if hp.st_symbol(pp) != up_symbol:
                while hp.st_symbol(pp) != up_symbol:
                    pp += _STATE_SIZE
            # Derive the initial frequency from the symbol's share in pc.
            cf = hp.st_freq(pp) - 1
            s0 = hp.ctx_summ_freq(pc) - hp.ctx_num_stats(pc) - cf
            if s0 <= 0:
                up_freq = 1
            elif 2 * cf <= s0:
                up_freq = 1 + (1 if 5 * cf > s0 else 0)
            else:
                up_freq = 1 + min((2 * cf + 3 * s0 - 1) // (2 * s0), 255)
        else:
            up_freq = hp.st_freq(hp.ctx_one_state(pc))
        while ps:
            # Build children starting from the last state collected.
            p_state = ps.pop()
            new_ctx = sa.alloc_context()
            if not new_ctx:
                return 0
            hp.ctx_set_num_stats(new_ctx, 1)
            os = hp.ctx_one_state(new_ctx)
            hp.st_set_symbol(os, up_symbol)
            hp.st_set_freq(os, up_freq)
            hp.st_set_successor(os, up_successor)
            hp.ctx_set_suffix(new_ctx, pc)
            hp.st_set_successor(p_state, new_ctx)
            pc = new_ctx
        return pc

    def _update_model(self):
        """
        Update the model after a symbol was decoded: credit the symbol in the
        suffix context, record it in the text area, add it to every context
        between max_context and min_context, and select the next context.
        Whenever memory runs out the model is restarted and esc_count is set
        to 0 so that decode_char clears the mask.
        """
        hp = self.hp
        sa = self.sa
        # Snapshot of the found state; the record may be rewritten below.
        fs_symbol = hp.st_symbol(self.found_state)
        fs_freq = hp.st_freq(self.found_state)
        fs_successor = hp.st_successor(self.found_state)
        p = 0
        suffix = hp.ctx_suffix(self.min_context)
        if fs_freq < MAX_FREQ // 4 and suffix:
            if hp.ctx_num_stats(suffix) != 1:
                stats = hp.ctx_stats(suffix)
                p = stats
                if hp.st_symbol(p) != fs_symbol:
                    while hp.st_symbol(p) != fs_symbol:
                        p += _STATE_SIZE
                    if hp.st_freq(p) >= hp.st_freq(p - _STATE_SIZE):
                        hp.st_swap(p, p - _STATE_SIZE)
                        p -= _STATE_SIZE
                if hp.st_freq(p) < MAX_FREQ - 9:
                    hp.st_set_freq(p, hp.st_freq(p) + 2)
                    summ = hp.ctx_summ_freq(suffix) + 2
                    hp.ctx_set_summ_freq(suffix, min(summ, 0xFFFF))
            else:
                p = hp.ctx_one_state(suffix)
                if hp.st_freq(p) < 32:
                    hp.st_set_freq(p, hp.st_freq(p) + 1)
        if not self.order_fall:
            successor = self._create_successors(True, p)
            if not successor:
                self._restart_model_rare()
                self.esc_count = 0
                return
            self.min_context = self.max_context = successor
            hp.st_set_successor(self.found_state, successor)
            return
        # Append the symbol to the text area; the position after it serves as
        # a raw successor value.
        if sa.p_text < len(sa.heap):
            sa.heap[sa.p_text] = fs_symbol & 0xFF
        sa.p_text += 1
        successor_ptr = sa.p_text
        if sa.p_text >= sa.fake_units_start:
            self._restart_model_rare()
            self.esc_count = 0
            return
        if fs_successor:
            if fs_successor <= sa.p_text:
                # The successor is still a raw text position: build contexts.
                new_succ = self._create_successors(False, p)
                if not new_succ:
                    self._restart_model_rare()
                    self.esc_count = 0
                    return
                hp.st_set_successor(self.found_state, new_succ)
                fs_successor = new_succ
            self.order_fall -= 1
            if not self.order_fall:
                successor_ptr = fs_successor
                sa.p_text -= int(self.max_context != self.min_context)
        else:
            hp.st_set_successor(self.found_state, successor_ptr)
            fs_successor = self.min_context
        ns = hp.ctx_num_stats(self.min_context)
        s0 = hp.ctx_summ_freq(self.min_context) - ns - (fs_freq - 1)
        if s0 < 1:
            s0 = 1
        pc = self.max_context
        while pc != self.min_context:
            ns1 = hp.ctx_num_stats(pc)
            if ns1 != 1:
                if (ns1 & 1) == 0:
                    # An even count means the array is full: grow by a unit.
                    new_stats = sa.expand_units(hp.ctx_stats(pc), ns1 >> 1)
                    if not new_stats:
                        self._restart_model_rare()
                        self.esc_count = 0
                        return
                    hp.ctx_set_stats(pc, new_stats)
                summ = hp.ctx_summ_freq(pc)
                summ += int(2 * ns1 < ns) + 2 * int((4 * ns1 <= ns) and (summ <= 8 * ns1))
                hp.ctx_set_summ_freq(pc, min(summ, 0xFFFF))
            else:
                # Convert the inline state into a separate array. The copy
                # has to happen before stats and summ_freq are written
                # because those fields overlap the inline state.
                pp = sa.alloc_units(1)
                if not pp:
                    self._restart_model_rare()
                    self.esc_count = 0
                    return
                os = hp.ctx_one_state(pc)
                hp.h[pp:pp + _STATE_SIZE] = hp.h[os:os + _STATE_SIZE]
                hp.ctx_set_stats(pc, pp)
                f = hp.st_freq(pp)
                if f < MAX_FREQ // 4 - 1:
                    hp.st_set_freq(pp, min(f + f, 255))
                else:
                    hp.st_set_freq(pp, MAX_FREQ - 4)
                summ = hp.st_freq(pp) + self.init_esc + int(ns > 3)
                hp.ctx_set_summ_freq(pc, min(summ, 0xFFFF))
            # Initial frequency of the new state in this context.
            cf = 2 * fs_freq * (hp.ctx_summ_freq(pc) + 6)
            sf = s0 + hp.ctx_summ_freq(pc)
            if cf < 6 * sf:
                new_cf = 1 + int(cf > sf) + int(cf >= 4 * sf)
                summ = hp.ctx_summ_freq(pc) + 3
                hp.ctx_set_summ_freq(pc, min(summ, 0xFFFF))
            else:
                new_cf = 4 + int(cf >= 9 * sf) + int(cf >= 12 * sf) + int(cf >= 15 * sf)
                summ = hp.ctx_summ_freq(pc) + new_cf
                hp.ctx_set_summ_freq(pc, min(summ, 0xFFFF))
            new_state = hp.state_at(hp.ctx_stats(pc), ns1)
            hp.st_set_successor(new_state, successor_ptr)
            hp.st_set_symbol(new_state, fs_symbol)
            hp.st_set_freq(new_state, new_cf)
            hp.ctx_set_num_stats(pc, ns1 + 1)
            pc = hp.ctx_suffix(pc)
        self.max_context = self.min_context = fs_successor

    def _clear_mask(self):
        """
        Start a fresh masking generation with no symbol masked.
        """
        self.esc_count = 1
        self.char_mask = [0] * 256

    def cleanup(self):
        """
        Replace the heap by a 1 MB heap and start an order-2 model on it.
        """
        self.sa.stop()
        self.sa.start(1)
        self.hp = _PPMHeap(self.sa)
        self._start_model_rare(2)
@dataclass
class _UnpackFilter30:
    """
    One entry of the RAR 3.0 filter list or filter stack.
    """
    # Window position at which the block to be filtered starts.
    block_start: int = 0
    # Number of bytes covered by the filter.
    block_length: int = 0
    # When set, the filter is skipped once by the write pass, which then
    # clears the flag.
    next_window: bool = False
    # Index of the corresponding entry in Unpack30._filters.
    parent_filter: int = 0
    # Filter type handed to execute_v3_filter; VMSF_NONE leaves data as is.
    prg_type: V3FilterType = V3FilterType.VMSF_NONE
    # Initial values of the eight filter registers.
    init_r: list[int] = field(default_factory=lambda: [0] * 8)
class Unpack30(RarUnpacker):
"""
RAR 3.0 decompression engine.
"""
def __init__(
    self,
    data: bytes | memoryview,
    unp_size: int,
    win_size: int,
    solid: bool = False,
):
    """
    Set up the decoder state for one packed stream.

    `data` is the packed input, `unp_size` the expected output size and
    `win_size` the dictionary size, which is raised to at least 0x40000.
    """
    # Input and job parameters.
    self._inp = BitInput(data)
    self._dest_size = unp_size
    # Sliding window; the mask is the window size minus one.
    window_size = max(win_size, 0x40000)
    self._win_size = window_size
    self._win_mask = window_size - 1
    self._window = bytearray(window_size)
    self._solid = solid
    # LZ match history.
    self._old_dist = [0, 0, 0, 0]
    self._last_length = 0
    # Window positions and output accounting.
    self._unp_ptr = 0
    self._wr_ptr = 0
    self._written = 0
    self._output = bytearray()
    # Huffman / PPM block state.
    self._tables_read = False
    self._block_type = BLOCK_LZ
    self._block_tables = BlockTables()
    self._unp_old_table = bytearray(HUFF_TABLE_SIZE30)
    self._ppm = _ModelPPM()
    self._ppm_esc_char = 2
    self._prev_low_dist = 0
    self._low_dist_rep_count = 0
    # Filter bookkeeping.
    self._filters: list[_UnpackFilter30] = []
    self._old_filter_lengths: list[int] = []
    self._prgstack: list[_UnpackFilter30 | None] = []
    self._last_filter = 0
def _get_char(self) -> int:
    """
    Return the next whole byte of the input and advance past it; once the
    input is exhausted, return 0 without advancing.
    """
    stream = self._inp
    position = stream.in_addr
    if position < len(stream.buf):
        value = stream.buf[position]
        stream.in_addr = position + 1
        return value
    return 0
def _write_buf(self):
    """
    Flush the window region between the write pointer and the unpack
    pointer, running every pending filter whose block lies completely inside
    that region.

    Data in front of a filter block is written unfiltered. If a filter block
    extends beyond the data unpacked so far, flushing stops at its start and
    resumes on a later call. Filters flagged `next_window` are skipped once
    and have the flag cleared.
    """
    written_border = self._wr_ptr
    mask = self._win_mask
    write_size = (self._unp_ptr - written_border) & mask
    for i, flt in enumerate(self._prgstack):
        if flt is None:
            continue
        if flt.next_window:
            flt.next_window = False
            continue
        block_start = flt.block_start
        block_length = flt.block_length
        if ((block_start - written_border) & mask) >= write_size:
            # The block starts outside the region being flushed.
            continue
        if written_border != block_start:
            self._write_area(written_border, block_start)
            written_border = block_start
            write_size = (self._unp_ptr - written_border) & mask
        if block_length > write_size:
            # The block is not fully unpacked yet: postpone it and make sure
            # the remaining filters are not skipped on the next pass.
            for pending in self._prgstack[i:]:
                if pending is not None and pending.next_window:
                    pending.next_window = False
            self._wr_ptr = written_border
            return
        block_end = (block_start + block_length) & mask
        # Copy the block out of the window, because the window contents must
        # stay unfiltered. Building the buffer by slicing keeps it exactly
        # block_length bytes long, also for a zero-length block; slice
        # assignment into a pre-sized bytearray would grow it in that case.
        if block_start + block_length <= self._win_size:
            mem = bytearray(self._window[block_start:block_start + block_length])
        else:
            mem = bytearray(self._window[block_start:])
            mem += self._window[:block_end]
        flt.init_r[6] = self._written & _M32
        if flt.prg_type != V3FilterType.VMSF_NONE:
            out_mem = execute_v3_filter(flt.prg_type, mem, block_length, flt.init_r)
        else:
            out_mem = mem
        self._prgstack[i] = None
        # Apply directly following filters that cover the very same block.
        # Consumed entries are set to None, so the outer loop skips them.
        last = i
        while last + 1 < len(self._prgstack):
            nf = self._prgstack[last + 1]
            if nf is None or nf.block_start != block_start or nf.block_length != len(out_mem) or nf.next_window:
                break
            nf.init_r[6] = self._written & _M32
            if nf.prg_type != V3FilterType.VMSF_NONE:
                out_mem = execute_v3_filter(nf.prg_type, out_mem, len(out_mem), nf.init_r)
            last += 1
            self._prgstack[last] = None
        self._write_data(out_mem)
        written_border = block_end
        write_size = (self._unp_ptr - written_border) & mask
    self._write_area(written_border, self._unp_ptr)
    self._wr_ptr = self._unp_ptr
def _read_end_of_block(self) -> bool:
    """
    Handle an end-of-block marker.

    Bit pattern `1`: same file, new tables follow immediately.
    Bit pattern `00`: end of file, tables are kept.
    Bit pattern `01`: end of file, new tables are needed afterwards.

    Returns False at the end of a file, otherwise the result of reading the
    new tables.
    """
    flags = self._inp.getbits()
    same_file = bool(flags & 0x8000)
    if same_file:
        new_table = True
        self._inp.addbits(1)
    else:
        new_table = bool(flags & 0x4000)
        self._inp.addbits(2)
    self._tables_read = not new_table
    return self._read_tables() if same_file else False
def _read_vm_code(self) -> bool:
    """
    Read a filter definition from the bit stream of an LZ block and hand it
    to `_add_vm_code`.

    The low three bits of the first byte encode the length (1 to 6); the
    values 7 and 8 announce an extended length in one or two further bytes.
    Returns False for a zero length.
    """
    inp = self._inp

    def take_byte() -> int:
        # The upper half of the 16-bit window is the next byte.
        value = inp.getbits() >> 8
        inp.addbits(8)
        return value

    first_byte = take_byte()
    length = (first_byte & 7) + 1
    if length == 7:
        length = take_byte() + 7
    elif length == 8:
        length = inp.getbits()
        inp.addbits(16)
    if length == 0:
        return False
    vm_code = bytearray(take_byte() & 0xFF for _ in range(length))
    return self._add_vm_code(first_byte, vm_code)
def _read_vm_code_ppm(self) -> bool:
    """
    Read a filter definition through the PPM decoder and hand it to
    `_add_vm_code`.

    The layout is the same as in `_read_vm_code`, but every byte is a
    PPM-decoded character. Returns False if any character fails to decode or
    if the length is zero.
    """
    decode = self._ppm.decode_char
    first_byte = decode()
    if first_byte < 0:
        return False
    length = (first_byte & 7) + 1
    if length == 7:
        extra = decode()
        if extra < 0:
            return False
        length = extra + 7
    elif length == 8:
        high = decode()
        if high < 0:
            return False
        low = decode()
        if low < 0:
            return False
        length = (high << 8) + low
    if length == 0:
        return False
    vm_code = bytearray(length)
    for position in range(length):
        ch = decode()
        if ch < 0:
            return False
        vm_code[position] = ch & 0xFF
    return self._add_vm_code(first_byte, vm_code)
def _add_vm_code(self, first_byte: int, code: bytearray) -> bool:
    """
    Parse one VM code record and queue the filter invocation it describes.

    Flag bits of `first_byte` used here:

    - 0x80: an explicit filter index follows (0 resets all filters);
      otherwise the most recently used filter index is taken.
    - 0x40: add 258 to the block start offset.
    - 0x20: an explicit block length follows; otherwise the last length
      recorded for this filter index is reused.
    - 0x10: a 7-bit mask follows that selects which of the first seven
      initial registers are given explicitly.

    Returns False if the record is rejected (index out of range, too many
    filters, or an invalid program size) and True otherwise.
    """
    stream = BitInput(code)

    # Pick the filter slot this record refers to.
    if first_byte & 0x80:
        slot = vm_read_data(stream)
        if slot:
            slot -= 1
        else:
            self._init_filters(False)
    else:
        slot = self._last_filter

    known = len(self._filters)
    if slot > known or slot > len(self._old_filter_lengths):
        return False
    self._last_filter = slot
    is_new = slot == known

    entry = _UnpackFilter30()
    if is_new:
        if slot > MAX3_UNPACK_FILTERS:
            return False
        self._filters.append(_UnpackFilter30())
        entry.parent_filter = len(self._filters) - 1
        # Placeholder length; overwritten below if the record carries one.
        self._old_filter_lengths.append(0)
    else:
        entry.parent_filter = slot

    # Drop consumed (None) entries; the size limit is only enforced when
    # there was no free entry to reuse.
    live = [item for item in self._prgstack if item is not None]
    had_gaps = len(live) != len(self._prgstack)
    self._prgstack = live
    if not had_gaps and len(live) > MAX3_UNPACK_FILTERS:
        return False
    live.append(entry)

    offset = vm_read_data(stream)
    if first_byte & 0x40:
        offset += 258
    entry.block_start = (offset + self._unp_ptr) & self._win_mask

    lengths = self._old_filter_lengths
    if first_byte & 0x20:
        entry.block_length = vm_read_data(stream)
        if slot < len(lengths):
            lengths[slot] = entry.block_length
    elif slot < len(lengths):
        entry.block_length = lengths[slot]
    else:
        entry.block_length = 0

    # The comparison uses the unmasked offset, not the wrapped block start.
    pending = self._wr_ptr != self._unp_ptr
    entry.next_window = pending and ((self._wr_ptr - self._unp_ptr) & self._win_mask) <= offset

    entry.init_r = registers = [0] * 8
    registers[4] = entry.block_length
    if first_byte & 0x10:
        present = stream.getbits() >> 9
        stream.addbits(7)
        for reg in range(7):
            if (present >> reg) & 1:
                registers[reg] = vm_read_data(stream)

    if is_new:
        program_size = vm_read_data(stream)
        if program_size == 0 or program_size >= 0x10000:
            return False
        if stream.in_addr + program_size > len(stream.buf):
            return False
        program = bytearray(program_size)
        for pos in range(program_size):
            program[pos] = (stream.getbits() >> 8) & 0xFF
            stream.addbits(8)
        # The program is identified by its CRC32 rather than executed.
        checksum = zlib.crc32(program) & 0xFFFFFFFF
        self._filters[slot].prg_type = identify_v3_filter(checksum)

    entry.prg_type = self._filters[entry.parent_filter].prg_type
    return True
def _init_filters(self, solid: bool):
if not solid:
self._old_filter_lengths.clear()
self._last_filter = 0
self._filters.clear()
self._prgstack.clear()
def _read_tables(self) -> bool:
    """
    Read the header of the next block and set up decoding for it.

    After aligning to a byte boundary, the top bit of the next 16-bit window
    selects the block type. For a PPM block the PPM model is initialised and
    its result returned. For an LZ block the Huffman code lengths are read
    (as deltas against the previous table unless bit 0x4000 is clear, in
    which case the previous table is zeroed first) and the four decode
    tables are rebuilt.

    Returns False if the length table starts with a "repeat previous"
    code, otherwise True (or the PPM initialisation result).
    """
    reader = self._inp
    reader.addbits((8 - reader.in_bit) & 7)
    header = reader.getbits()

    if header & 0x8000:
        self._block_type = BLOCK_PPM
        # The escape character is passed in a one-element list so that
        # decode_init can update it.
        escape_box = [self._ppm_esc_char]
        outcome = self._ppm.decode_init(self._get_char, escape_box)
        self._ppm_esc_char = escape_box[0]
        return outcome

    self._block_type = BLOCK_LZ
    self._prev_low_dist = 0
    self._low_dist_rep_count = 0
    if not (header & 0x4000):
        self._unp_old_table = bytearray(HUFF_TABLE_SIZE30)
    reader.addbits(2)

    def take(width: int) -> int:
        # Fetch the top `width` bits of the 16-bit window and consume them.
        value = (reader.getbits() >> (16 - width)) & ((1 << width) - 1)
        reader.addbits(width)
        return value

    # Code lengths of the pre-code: 4 bits each, where 15 is an escape that
    # is followed by a 4-bit count (0 = literal 15, else count + 2 zeros).
    code_lengths = bytearray(BC30)
    pos = 0
    while pos < BC30:
        nibble = take(4)
        if nibble != 15:
            code_lengths[pos] = nibble
            pos += 1
            continue
        run = take(4)
        if run == 0:
            code_lengths[pos] = 15
            pos += 1
            continue
        stop = min(pos + run + 2, BC30)
        code_lengths[pos:stop] = bytes(stop - pos)
        pos = stop
    make_decode_tables(code_lengths, self._block_tables.BD, BC30)

    lengths = bytearray(HUFF_TABLE_SIZE30)
    pos = 0
    while pos < HUFF_TABLE_SIZE30:
        symbol = decode_number(reader, self._block_tables.BD)
        if symbol < 16:
            lengths[pos] = (symbol + self._unp_old_table[pos]) & 0xF
            pos += 1
            continue
        # 16 and 18 carry a 3-bit run (+3); the other codes a 7-bit run (+11).
        if symbol == 16 or symbol == 18:
            run = take(3) + 3
        else:
            run = take(7) + 11
        if symbol < 18:
            # Repeat the previous length; impossible at the first position.
            if pos == 0:
                return False
            fill = lengths[pos - 1]
        else:
            fill = 0
        stop = min(pos + run, HUFF_TABLE_SIZE30)
        lengths[pos:stop] = bytes([fill]) * (stop - pos)
        pos = stop

    self._tables_read = True
    tables = self._block_tables
    layout = (
        (tables.LD, NC30),
        (tables.DD, DC30),
        (tables.LDD, LDC30),
        (tables.RD, RC30),
    )
    start = 0
    for target, count in layout:
        make_decode_tables(lengths[start:] if start else lengths, target, count)
        start += count
    self._unp_old_table[:] = lengths
    return True
def init_solid(self, data: bytes | memoryview, dest_size: int):
    """
    Prepare the unpacker for the following member of a solid archive.

    Delegates to the base class and then resets the filter state in solid
    mode.
    """
    super().init_solid(data, dest_size)
    self._init_filters(solid=True)
def decompress(self) -> bytearray:
    """
    Decode the RAR3 stream and return the output buffer.

    Unless the unpacker is in solid mode, all per-file state is reset first.
    The main loop then decodes one symbol per iteration, either through the
    PPM model or through the LZ Huffman tables, flushing the window via
    `_write_buf` when the write pointer is less than 260 bytes ahead of the
    unpack pointer. The loop ends when the input is exhausted, a block or
    file end is signalled, or decoding fails.
    """
    reader = self._inp
    wmask = self._win_mask
    window = self._window
    ppm = self._ppm

    fresh = not self._solid
    if fresh:
        self._tables_read = False
        self._unp_old_table = bytearray(HUFF_TABLE_SIZE30)
        self._ppm_esc_char = 2
        self._block_type = BLOCK_LZ
        self._init_filters(False)
    if fresh or not self._tables_read:
        if not self._read_tables():
            return self._output

    tables = self._block_tables
    total_in = len(reader.buf)

    while True:
        self._unp_ptr &= wmask
        if reader.in_addr >= total_in:
            break
        if self._wr_ptr != self._unp_ptr and ((self._wr_ptr - self._unp_ptr) & wmask) < 260:
            self._write_buf()
            if self._written > self._dest_size:
                return self._output

        if self._block_type == BLOCK_PPM:
            symbol = ppm.decode_char()
            if symbol < 0:
                # Decoder failure: reset the model and fall back to LZ mode.
                ppm.cleanup()
                self._block_type = BLOCK_LZ
                break
            if symbol == self._ppm_esc_char:
                command = ppm.decode_char()
                if command < 0 or command == 2:
                    break
                if command == 0:
                    if self._read_tables():
                        continue
                    break
                if command == 3:
                    if self._read_vm_code_ppm():
                        continue
                    break
                if command == 4:
                    # Three distance bytes (big endian), then one length byte.
                    fields = []
                    for _ in range(4):
                        value = ppm.decode_char()
                        if value < 0:
                            break
                        fields.append(value & 0xFF)
                    if len(fields) < 4:
                        break
                    offset = (fields[0] << 16) | (fields[1] << 8) | fields[2]
                    self._copy_string(fields[3] + 32, offset + 2)
                    continue
                if command == 5:
                    run = ppm.decode_char()
                    if run < 0:
                        break
                    self._copy_string(run + 4, 1)
                    continue
                # Any other command: the escape character itself is stored.
            window[self._unp_ptr] = symbol & 0xFF
            self._unp_ptr = (self._unp_ptr + 1) & wmask
            continue

        code = decode_number(reader, tables.LD)

        if code < 256:
            # Literal byte.
            window[self._unp_ptr] = code & 0xFF
            self._unp_ptr = (self._unp_ptr + 1) & wmask

        elif code >= 271:
            # Match with explicitly coded length and distance.
            slot = code - 271
            length = LDecode[slot] + 3
            extra = LBits[slot]
            if extra > 0:
                length += reader.getbits() >> (16 - extra)
                reader.addbits(extra)

            dist_slot = decode_number(reader, tables.DD)
            distance = _DDecode[dist_slot] + 1
            dist_bits = _DBits[dist_slot]
            if dist_bits > 0:
                if dist_slot <= 9:
                    distance += reader.getbits() >> (16 - dist_bits)
                    reader.addbits(dist_bits)
                else:
                    # High bits are read directly; the low four bits come
                    # from a separate table or are repeated.
                    if dist_bits > 4:
                        distance += (reader.getbits() >> (20 - dist_bits)) << 4
                        reader.addbits(dist_bits - 4)
                    if self._low_dist_rep_count > 0:
                        self._low_dist_rep_count -= 1
                        distance += self._prev_low_dist
                    else:
                        low = decode_number(reader, tables.LDD)
                        if low == 16:
                            self._low_dist_rep_count = LOW_DIST_REP_COUNT - 1
                            distance += self._prev_low_dist
                        else:
                            distance += low
                            self._prev_low_dist = low

            # Longer distances lengthen the match by one or two bytes.
            length += (distance >= 0x2000) + (distance >= 0x40000)
            self._insert_old_dist(distance)
            self._last_length = length
            self._copy_string(length, distance)

        elif code == 256:
            if not self._read_end_of_block():
                break

        elif code == 257:
            if not self._read_vm_code():
                break

        elif code == 258:
            # Repeat the last match, if there was one.
            if self._last_length != 0:
                self._copy_string(self._last_length, self._old_dist[0])

        elif code < 263:
            # Reuse one of the recent distances and move it to the front.
            recent = self._old_dist
            distance = recent.pop(code - 259)
            recent.insert(0, distance)

            slot = decode_number(reader, tables.RD)
            length = LDecode[slot] + 2
            extra = LBits[slot]
            if extra > 0:
                length += reader.getbits() >> (16 - extra)
                reader.addbits(extra)
            self._last_length = length
            self._copy_string(length, distance)

        else:
            # Codes 263-270: two-byte match with a short distance.
            slot = code - 263
            distance = SDDecode[slot] + 1
            extra = SDBits[slot]
            if extra > 0:
                distance += reader.getbits() >> (16 - extra)
                reader.addbits(extra)
            self._insert_old_dist(distance)
            self._last_length = 2
            self._copy_string(2, distance)

    self._write_buf()
    return self._output
Functions
def vm_read_data(inp)
Read a variable-length data value from the VM code stream.
Expand source code Browse git
def vm_read_data(inp: BitInput) -> int:
    """
    Decode one variable-width integer from a VM code bit stream.

    The two leading bits of the next 16-bit window select the layout:

    - 00: a 4-bit value;
    - 01: an 8-bit value, or, when the next four bits are all zero, an 8-bit
      value combined with 0xFFFFFF00;
    - 10: a 16-bit value;
    - 11: a 32-bit value read as two 16-bit halves, high half first.
    """
    head = inp.getbits()
    selector = head & 0xC000

    if selector == 0:
        inp.addbits(6)
        return (head >> 10) & 0xF

    if selector == 0x4000:
        if head & 0x3C00:
            inp.addbits(10)
            return ((head >> 6) & 0xFF) & _M32
        inp.addbits(14)
        return (0xFFFFFF00 | ((head >> 2) & 0xFF)) & _M32

    # Both remaining layouts skip the selector and read a 16-bit word.
    inp.addbits(2)
    upper = inp.getbits()
    inp.addbits(16)
    if selector == 0x8000:
        return upper & 0xFFFF

    lower = inp.getbits()
    inp.addbits(16)
    return (((upper << 16) & _M32) | lower) & _M32
Classes
class Unpack30 (data, unp_size, win_size, solid=False)
RAR 3.0 decompression engine.
Expand source code Browse git
class Unpack30(RarUnpacker): """ RAR 3.0 decompression engine. """ def __init__( self, data: bytes | memoryview, unp_size: int, win_size: int, solid: bool = False, ): self._inp = BitInput(data) self._dest_size = unp_size self._win_size = max(win_size, 0x40000) self._win_mask = self._win_size - 1 self._window = bytearray(self._win_size) self._solid = solid self._old_dist = [0, 0, 0, 0] self._last_length = 0 self._unp_ptr = 0 self._wr_ptr = 0 self._written = 0 self._output = bytearray() self._tables_read = False self._block_type = BLOCK_LZ self._block_tables = BlockTables() self._unp_old_table = bytearray(HUFF_TABLE_SIZE30) self._ppm = _ModelPPM() self._ppm_esc_char = 2 self._prev_low_dist = 0 self._low_dist_rep_count = 0 self._filters: list[_UnpackFilter30] = [] self._old_filter_lengths: list[int] = [] self._prgstack: list[_UnpackFilter30 | None] = [] self._last_filter = 0 def _get_char(self) -> int: inp = self._inp if inp.in_addr >= len(inp.buf): return 0 ch = inp.buf[inp.in_addr] inp.in_addr += 1 return ch def _write_buf(self): written_border = self._wr_ptr mask = self._win_mask write_size = (self._unp_ptr - written_border) & mask for i, flt in enumerate(self._prgstack): if flt is None: continue if flt.next_window: flt.next_window = False continue block_start = flt.block_start block_length = flt.block_length if ((block_start - written_border) & mask) < write_size: if written_border != block_start: self._write_area(written_border, block_start) written_border = block_start write_size = (self._unp_ptr - written_border) & mask if block_length <= write_size: block_end = (block_start + block_length) & mask mem = bytearray(block_length) if block_start < block_end or block_end == 0: mem[:] = self._window[block_start:block_start + block_length] else: first = self._win_size - block_start mem[:first] = self._window[block_start:] mem[first:] = self._window[:block_end] flt.init_r[6] = self._written & _M32 if flt.prg_type != V3FilterType.VMSF_NONE: out_mem = 
execute_v3_filter(flt.prg_type, mem, block_length, flt.init_r) else: out_mem = mem self._prgstack[i] = None while i + 1 < len(self._prgstack): nf = self._prgstack[i + 1] if nf is None or nf.block_start != block_start or nf.block_length != len(out_mem) or nf.next_window: break nf.init_r[6] = self._written & _M32 if nf.prg_type != V3FilterType.VMSF_NONE: out_mem = execute_v3_filter(nf.prg_type, out_mem, len(out_mem), nf.init_r) i += 1 self._prgstack[i] = None self._write_data(out_mem) written_border = block_end write_size = (self._unp_ptr - written_border) & mask else: for j in range(i, len(self._prgstack)): f2 = self._prgstack[j] if f2 is not None and f2.next_window: f2.next_window = False self._wr_ptr = written_border return self._write_area(written_border, self._unp_ptr) self._wr_ptr = self._unp_ptr def _read_end_of_block(self) -> bool: bit_field = self._inp.getbits() if bit_field & 0x8000: new_table = True new_file = False self._inp.addbits(1) else: new_file = True new_table = bool(bit_field & 0x4000) self._inp.addbits(2) self._tables_read = not new_table if new_file: return False return self._read_tables() def _read_vm_code(self) -> bool: inp = self._inp first_byte = inp.getbits() >> 8 inp.addbits(8) length = (first_byte & 7) + 1 if length == 7: length = (inp.getbits() >> 8) + 7 inp.addbits(8) elif length == 8: length = inp.getbits() inp.addbits(16) if length == 0: return False vm_code = bytearray(length) for ii in range(length): vm_code[ii] = (inp.getbits() >> 8) & 0xFF inp.addbits(8) return self._add_vm_code(first_byte, vm_code) def _read_vm_code_ppm(self) -> bool: first_byte = self._ppm.decode_char() if first_byte < 0: return False length = (first_byte & 7) + 1 if length == 7: b1 = self._ppm.decode_char() if b1 < 0: return False length = b1 + 7 elif length == 8: b1 = self._ppm.decode_char() if b1 < 0: return False b2 = self._ppm.decode_char() if b2 < 0: return False length = b1 * 256 + b2 if length == 0: return False vm_code = bytearray(length) for ii in 
range(length): ch = self._ppm.decode_char() if ch < 0: return False vm_code[ii] = ch & 0xFF return self._add_vm_code(first_byte, vm_code) def _add_vm_code(self, first_byte: int, code: bytearray) -> bool: vm_inp = BitInput(code) if first_byte & 0x80: filt_pos = vm_read_data(vm_inp) if filt_pos == 0: self._init_filters(False) filt_pos = 0 else: filt_pos -= 1 else: filt_pos = self._last_filter if filt_pos > len(self._filters) or filt_pos > len(self._old_filter_lengths): return False self._last_filter = filt_pos new_filter = (filt_pos == len(self._filters)) stack_filter = _UnpackFilter30() if new_filter: if filt_pos > MAX3_UNPACK_FILTERS: return False parent = _UnpackFilter30() self._filters.append(parent) stack_filter.parent_filter = len(self._filters) - 1 self._old_filter_lengths.append(0) else: stack_filter.parent_filter = filt_pos empty_count = sum(1 for x in self._prgstack if x is None) self._prgstack = [x for x in self._prgstack if x is not None] if not empty_count: if len(self._prgstack) > MAX3_UNPACK_FILTERS: return False self._prgstack.append(stack_filter) block_start = vm_read_data(vm_inp) if first_byte & 0x40: block_start += 258 stack_filter.block_start = (block_start + self._unp_ptr) & self._win_mask if first_byte & 0x20: stack_filter.block_length = vm_read_data(vm_inp) if filt_pos < len(self._old_filter_lengths): self._old_filter_lengths[filt_pos] = stack_filter.block_length else: stack_filter.block_length = self._old_filter_lengths[filt_pos] if filt_pos < len(self._old_filter_lengths) else 0 stack_filter.next_window = (self._wr_ptr != self._unp_ptr and ((self._wr_ptr - self._unp_ptr) & self._win_mask) <= block_start) stack_filter.init_r = [0] * 8 stack_filter.init_r[4] = stack_filter.block_length if first_byte & 0x10: init_mask = vm_inp.getbits() >> 9 vm_inp.addbits(7) for ii in range(7): if init_mask & (1 << ii): stack_filter.init_r[ii] = vm_read_data(vm_inp) if new_filter: vm_code_size = vm_read_data(vm_inp) if vm_code_size >= 0x10000 or vm_code_size == 
0: return False if vm_inp.in_addr + vm_code_size > len(vm_inp.buf): return False vm_code_data = bytearray(vm_code_size) for ii in range(vm_code_size): vm_code_data[ii] = (vm_inp.getbits() >> 8) & 0xFF vm_inp.addbits(8) code_crc = zlib.crc32(vm_code_data) & 0xFFFFFFFF parent = self._filters[filt_pos] parent.prg_type = identify_v3_filter(code_crc) stack_filter.prg_type = self._filters[stack_filter.parent_filter].prg_type return True def _init_filters(self, solid: bool): if not solid: self._old_filter_lengths.clear() self._last_filter = 0 self._filters.clear() self._prgstack.clear() def _read_tables(self) -> bool: inp = self._inp inp.addbits((8 - inp.in_bit) & 7) bit_field = inp.getbits() if bit_field & 0x8000: self._block_type = BLOCK_PPM esc_ref = [self._ppm_esc_char] result = self._ppm.decode_init(self._get_char, esc_ref) self._ppm_esc_char = esc_ref[0] return result self._block_type = BLOCK_LZ self._prev_low_dist = 0 self._low_dist_rep_count = 0 if not (bit_field & 0x4000): self._unp_old_table = bytearray(HUFF_TABLE_SIZE30) inp.addbits(2) bit_length = bytearray(BC30) i = 0 while i < BC30: length = (inp.getbits() >> 12) & 0xF inp.addbits(4) if length == 15: zero_count = (inp.getbits() >> 12) & 0xF inp.addbits(4) if zero_count == 0: bit_length[i] = 15 else: zero_count += 2 while zero_count > 0 and i < BC30: bit_length[i] = 0 i += 1 zero_count -= 1 continue else: bit_length[i] = length i += 1 make_decode_tables(bit_length, self._block_tables.BD, BC30) table = bytearray(HUFF_TABLE_SIZE30) i = 0 while i < HUFF_TABLE_SIZE30: number = decode_number(inp, self._block_tables.BD) if number < 16: table[i] = (number + self._unp_old_table[i]) & 0xF i += 1 elif number < 18: if number == 16: n = ((inp.getbits() >> 13) & 7) + 3 inp.addbits(3) else: n = ((inp.getbits() >> 9) & 0x7F) + 11 inp.addbits(7) if i == 0: return False while n > 0 and i < HUFF_TABLE_SIZE30: table[i] = table[i - 1] i += 1 n -= 1 else: if number == 18: n = ((inp.getbits() >> 13) & 7) + 3 inp.addbits(3) else: n 
= ((inp.getbits() >> 9) & 0x7F) + 11 inp.addbits(7) while n > 0 and i < HUFF_TABLE_SIZE30: table[i] = 0 i += 1 n -= 1 self._tables_read = True make_decode_tables(table, self._block_tables.LD, NC30) off = NC30 make_decode_tables(table[off:], self._block_tables.DD, DC30) off += DC30 make_decode_tables(table[off:], self._block_tables.LDD, LDC30) off += LDC30 make_decode_tables(table[off:], self._block_tables.RD, RC30) self._unp_old_table[:] = table return True def init_solid(self, data: bytes | memoryview, dest_size: int): """ Reinitialize for the next file in a solid archive chain. """ super().init_solid(data, dest_size) self._init_filters(True) def decompress(self) -> bytearray: """ Run the RAR3 decompression and return the extracted data. """ inp = self._inp mask = self._win_mask win = self._window if not self._solid: self._tables_read = False self._unp_old_table = bytearray(HUFF_TABLE_SIZE30) self._ppm_esc_char = 2 self._block_type = BLOCK_LZ self._init_filters(False) if (not self._solid or not self._tables_read) and not self._read_tables(): return self._output tbl = self._block_tables inp_len = len(inp.buf) while True: self._unp_ptr &= mask if inp.in_addr >= inp_len: break if ((self._wr_ptr - self._unp_ptr) & mask) < 260 and self._wr_ptr != self._unp_ptr: self._write_buf() if self._written > self._dest_size: return self._output if self._block_type == BLOCK_PPM: ch = self._ppm.decode_char() if ch < 0: self._ppm.cleanup() self._block_type = BLOCK_LZ break if ch == self._ppm_esc_char: next_ch = self._ppm.decode_char() if next_ch < 0: break if next_ch == 0: if not self._read_tables(): break continue if next_ch == 2: break if next_ch == 3: if not self._read_vm_code_ppm(): break continue if next_ch == 4: distance = 0 failed = False length = 0 for ii in range(4): c = self._ppm.decode_char() if c < 0: failed = True break if ii == 3: length = c & 0xFF else: distance = (distance << 8) + (c & 0xFF) if failed: break self._copy_string(length + 32, distance + 2) continue if 
next_ch == 5: ll = self._ppm.decode_char() if ll < 0: break self._copy_string(ll + 4, 1) continue win[self._unp_ptr] = ch & 0xFF self._unp_ptr = (self._unp_ptr + 1) & mask continue number = decode_number(inp, tbl.LD) if number < 256: win[self._unp_ptr] = number & 0xFF self._unp_ptr = (self._unp_ptr + 1) & mask continue if number >= 271: num = number - 271 length = LDecode[num] + 3 bits = LBits[num] if bits > 0: length += inp.getbits() >> (16 - bits) inp.addbits(bits) dist_number = decode_number(inp, tbl.DD) distance = _DDecode[dist_number] + 1 d_bits = _DBits[dist_number] if d_bits > 0: if dist_number > 9: if d_bits > 4: distance += (inp.getbits() >> (20 - d_bits)) << 4 inp.addbits(d_bits - 4) if self._low_dist_rep_count > 0: self._low_dist_rep_count -= 1 distance += self._prev_low_dist else: low_dist = decode_number(inp, tbl.LDD) if low_dist == 16: self._low_dist_rep_count = LOW_DIST_REP_COUNT - 1 distance += self._prev_low_dist else: distance += low_dist self._prev_low_dist = low_dist else: distance += inp.getbits() >> (16 - d_bits) inp.addbits(d_bits) if distance >= 0x2000: length += 1 if distance >= 0x40000: length += 1 self._insert_old_dist(distance) self._last_length = length self._copy_string(length, distance) continue if number == 256: if not self._read_end_of_block(): break continue if number == 257: if not self._read_vm_code(): break continue if number == 258: if self._last_length != 0: self._copy_string(self._last_length, self._old_dist[0]) continue if number < 263: dist_num = number - 259 distance = self._old_dist[dist_num] for idx in range(dist_num, 0, -1): self._old_dist[idx] = self._old_dist[idx - 1] self._old_dist[0] = distance length_number = decode_number(inp, tbl.RD) length = LDecode[length_number] + 2 bits = LBits[length_number] if bits > 0: length += inp.getbits() >> (16 - bits) inp.addbits(bits) self._last_length = length self._copy_string(length, distance) continue if number < 272: num = number - 263 distance = SDDecode[num] + 1 bits = 
SDBits[num] if bits > 0: distance += inp.getbits() >> (16 - bits) inp.addbits(bits) self._insert_old_dist(distance) self._last_length = 2 self._copy_string(2, distance) continue self._write_buf() return self._output
Ancestors
Methods
def decompress(self)
Run the RAR3 decompression and return the extracted data.
Expand source code Browse git
def decompress(self) -> bytearray: """ Run the RAR3 decompression and return the extracted data. """ inp = self._inp mask = self._win_mask win = self._window if not self._solid: self._tables_read = False self._unp_old_table = bytearray(HUFF_TABLE_SIZE30) self._ppm_esc_char = 2 self._block_type = BLOCK_LZ self._init_filters(False) if (not self._solid or not self._tables_read) and not self._read_tables(): return self._output tbl = self._block_tables inp_len = len(inp.buf) while True: self._unp_ptr &= mask if inp.in_addr >= inp_len: break if ((self._wr_ptr - self._unp_ptr) & mask) < 260 and self._wr_ptr != self._unp_ptr: self._write_buf() if self._written > self._dest_size: return self._output if self._block_type == BLOCK_PPM: ch = self._ppm.decode_char() if ch < 0: self._ppm.cleanup() self._block_type = BLOCK_LZ break if ch == self._ppm_esc_char: next_ch = self._ppm.decode_char() if next_ch < 0: break if next_ch == 0: if not self._read_tables(): break continue if next_ch == 2: break if next_ch == 3: if not self._read_vm_code_ppm(): break continue if next_ch == 4: distance = 0 failed = False length = 0 for ii in range(4): c = self._ppm.decode_char() if c < 0: failed = True break if ii == 3: length = c & 0xFF else: distance = (distance << 8) + (c & 0xFF) if failed: break self._copy_string(length + 32, distance + 2) continue if next_ch == 5: ll = self._ppm.decode_char() if ll < 0: break self._copy_string(ll + 4, 1) continue win[self._unp_ptr] = ch & 0xFF self._unp_ptr = (self._unp_ptr + 1) & mask continue number = decode_number(inp, tbl.LD) if number < 256: win[self._unp_ptr] = number & 0xFF self._unp_ptr = (self._unp_ptr + 1) & mask continue if number >= 271: num = number - 271 length = LDecode[num] + 3 bits = LBits[num] if bits > 0: length += inp.getbits() >> (16 - bits) inp.addbits(bits) dist_number = decode_number(inp, tbl.DD) distance = _DDecode[dist_number] + 1 d_bits = _DBits[dist_number] if d_bits > 0: if dist_number > 9: if d_bits > 4: distance += 
(inp.getbits() >> (20 - d_bits)) << 4 inp.addbits(d_bits - 4) if self._low_dist_rep_count > 0: self._low_dist_rep_count -= 1 distance += self._prev_low_dist else: low_dist = decode_number(inp, tbl.LDD) if low_dist == 16: self._low_dist_rep_count = LOW_DIST_REP_COUNT - 1 distance += self._prev_low_dist else: distance += low_dist self._prev_low_dist = low_dist else: distance += inp.getbits() >> (16 - d_bits) inp.addbits(d_bits) if distance >= 0x2000: length += 1 if distance >= 0x40000: length += 1 self._insert_old_dist(distance) self._last_length = length self._copy_string(length, distance) continue if number == 256: if not self._read_end_of_block(): break continue if number == 257: if not self._read_vm_code(): break continue if number == 258: if self._last_length != 0: self._copy_string(self._last_length, self._old_dist[0]) continue if number < 263: dist_num = number - 259 distance = self._old_dist[dist_num] for idx in range(dist_num, 0, -1): self._old_dist[idx] = self._old_dist[idx - 1] self._old_dist[0] = distance length_number = decode_number(inp, tbl.RD) length = LDecode[length_number] + 2 bits = LBits[length_number] if bits > 0: length += inp.getbits() >> (16 - bits) inp.addbits(bits) self._last_length = length self._copy_string(length, distance) continue if number < 272: num = number - 263 distance = SDDecode[num] + 1 bits = SDBits[num] if bits > 0: distance += inp.getbits() >> (16 - bits) inp.addbits(bits) self._insert_old_dist(distance) self._last_length = 2 self._copy_string(2, distance) continue self._write_buf() return self._output
Inherited members