Module refinery.lib.decompression
Expand source code Browse git
from __future__ import annotations
from typing import List, Optional, Union
from collections import Counter
from itertools import repeat
from refinery.lib.structures import StructReader
# Decode table entries are 16-bit values: the low DECODE_TABLE_SYMBOL_SHIFT bits hold a
# codeword length, the remaining high bits hold a symbol (or a subtable start index).
DECODE_TABLE_SYMBOL_SHIFT = 4
# Largest value that fits into the symbol field of an entry.
DECODE_TABLE_MAX_SYMBOL = 0xFFFF >> DECODE_TABLE_SYMBOL_SHIFT
# Largest value that fits into the length field of an entry.
DECODE_TABLE_MAX_LENGTH = (1 << DECODE_TABLE_SYMBOL_SHIFT) - 1
# Mask extracting the length field from an entry.
DECODE_TABLE_LENGTH_MASK = DECODE_TABLE_MAX_LENGTH
def MAKE_DECODE_TABLE_ENTRY(symbol, length):
    """
    Pack a symbol (or a subtable start index) and a bit length into a single 16-bit decode
    table entry. The symbol goes above bit DECODE_TABLE_SYMBOL_SHIFT, the length below it.
    """
    packed = length | symbol << DECODE_TABLE_SYMBOL_SHIFT
    # Internal consistency check: the packed entry must fit into 16 bits.
    assert packed == packed & 0xFFFF
    return packed
def s32shift(k: int, shift: int):
    """
    This helper function implements a signed left shift for 32-bit integers: `k` is shifted
    left by `shift` modulo 32 bits, truncated to 32 bits, and the result is reinterpreted as
    a two's-complement signed value.
    """
    width = 32
    modulus = 1 << width
    wrapped = (k << (shift % width)) % modulus
    # A set sign bit means the 32-bit value is negative.
    if wrapped >> (width - 1):
        return wrapped - modulus
    return wrapped
def make_huffman_decode_table(
    table_data: bytearray,
    table_bits: int,
    max_codeword_len: int,
) -> List[int]:
    """
    Build a table-driven decoder for a Huffman code given by its codeword lengths, for use
    with read_huffman_symbol.

    Parameters:
        table_data: the codeword length of every symbol, indexed by symbol; symbols with
            length 0 receive no codeword.
        table_bits: the number of input bits that index the root table, which occupies the
            first 2**table_bits entries of the result.
        max_codeword_len: the longest codeword length taken into account.

    Returns:
        A list of entries built with MAKE_DECODE_TABLE_ENTRY. A root entry holds either a
        symbol together with its codeword length, or, for codewords longer than table_bits,
        the start index of a subtable (subtables are appended after the root table) together
        with the number of bits that index that subtable. Symbols are sorted by (length,
        symbol index) and assigned codewords in that order, i.e. a canonical Huffman code.
        If every length is 0, an all-zero root table is returned.

    Raises:
        OverflowError: the lengths overflow the code space.
        RuntimeError: the lengths form an incomplete code that is not empty.
        IndexError: from the defensive codeword search below, or while sorting if a
            length exceeds max_codeword_len.

    NOTE(review): lengths above max_codeword_len are not counted by the code-space check;
    depending on the other lengths they lead to an IndexError while sorting or to an
    all-zero table.
    """
    remainder = 1
    codeword_length = 1
    entry_pos = 0
    decode_table: List[int] = [0] * (1 << table_bits)
    sym_count = len(table_data)
    len_counts = Counter(table_data)
    # Code-space check: after processing length n, remainder is the number of codewords of
    # length n that are still unused.
    for sym in range(1, max_codeword_len + 1):
        remainder = (remainder << 1) - len_counts[sym]
        if remainder < 0:
            raise OverflowError('Lengths have overflowed the code space.')
    if remainder:
        # Incomplete code: only accepted when it is entirely empty. The all-zero table then
        # decodes symbol 0 while consuming no bits (entry 0 has length 0).
        if remainder != 1 << max_codeword_len:
            raise RuntimeError('Incomplete & nonempty code encountered.')
        return decode_table
    # offsets[n] is the number of symbols whose length is smaller than n (including 0).
    offsets = [0]
    for sym in range(max_codeword_len):
        offsets.append(offsets[sym] + len_counts[sym])
    # Counting sort: sorted_syms maps a position in (length, symbol index) order to the
    # symbol index; equal lengths keep increasing symbol order because of enumerate.
    sorted_syms = {}
    for i, sym in enumerate(table_data):
        offset = offsets[sym]
        offsets[sym] += 1
        sorted_syms[offset] = i
    # After sorting, offsets[0] equals the number of length-0 symbols; skip them.
    sym_index = offsets[0]
    # Root table: a codeword of length n <= table_bits fills 2**(table_bits - n)
    # consecutive entries, starting with the shortest codewords.
    stores_per_loop = 1 << (table_bits - codeword_length)
    while stores_per_loop:
        end_sym_idx = sym_index + len_counts[codeword_length]
        for k in range(sym_index, end_sym_idx):
            entry_end = entry_pos + stores_per_loop
            decode_table[entry_pos:entry_end] = repeat(
                MAKE_DECODE_TABLE_ENTRY(sorted_syms[k], codeword_length),
                stores_per_loop)
            entry_pos = entry_end
        codeword_length += 1
        sym_index = end_sym_idx
        stores_per_loop >>= 1
    assert sym_index <= sym_count
    # Every symbol fit into the root table; no subtables are needed.
    if sym_index == sym_count:
        return decode_table
    # codeword_length is now table_bits + 1. The first entry_pos root entries are in use,
    # so the next free codeword of that length is entry_pos * 2.
    codeword = entry_pos * 2
    subtable_pos = 1 << table_bits
    subtable_bits = table_bits
    subtable_prefix = -1
    while sym_index < sym_count:
        # Advance to the next length that still has unassigned symbols, extending the
        # codeword by one bit per step.
        while len_counts[codeword_length] == 0:
            # NOTE(review): defensive guard; should be unreachable for a code that passed
            # the code-space check above.
            if codeword_length > sym_count:
                raise IndexError('Error computing codeword')
            codeword_length += 1
            codeword <<= 1
        # The first table_bits bits of the codeword select the root entry.
        prefix = codeword >> (codeword_length - table_bits)
        if prefix != subtable_prefix:
            # Start a new subtable. It needs at least 2**(codeword_length - table_bits)
            # entries; grow it until the remaining codewords can fill it completely.
            subtable_prefix = prefix
            subtable_bits = codeword_length - table_bits
            remainder = s32shift(1, subtable_bits)
            while True:
                remainder -= len_counts[table_bits + subtable_bits]
                if remainder <= 0:
                    break
                subtable_bits += 1
                remainder <<= 1
            # Root entry pointing at the subtable: start index and number of index bits.
            decode_table[subtable_prefix] = MAKE_DECODE_TABLE_ENTRY(subtable_pos, subtable_bits)
        # Subtable entries store only the bits beyond the root table index as length.
        entry = MAKE_DECODE_TABLE_ENTRY(sorted_syms[sym_index], codeword_length - table_bits)
        count = 1 << (table_bits + subtable_bits - codeword_length)
        end = subtable_pos + count
        # subtable_pos always equals len(decode_table) here, so this slice assignment appends.
        decode_table[subtable_pos:end] = repeat(entry, count)
        subtable_pos = end
        len_counts[codeword_length] -= 1
        codeword += 1
        sym_index += 1
    return decode_table
class BitBufferedReader:
    """
    A helper class to read individual bits from the compressed input stream. Bits are
    buffered in an integer and handed out most significant bit first; the buffer is refilled
    from the underlying StructReader in multiples of `bits_per_read` bits.
    """

    def __init__(self, buffer: Union[bytearray, StructReader], bits_per_read: int = 32):
        if not isinstance(buffer, StructReader):
            buffer = StructReader(memoryview(buffer), bigendian=False)
        self._reader: StructReader[memoryview] = buffer
        # Buffered bits not yet consumed; read() asserts that its bit length never exceeds
        # _bit_buffer_size, so the next bit to consume sits at position _bit_buffer_size - 1.
        self._bit_buffer_data: int = 0
        # Number of valid bits currently held in _bit_buffer_data.
        self._bit_buffer_size: int = 0
        # Refill granularity in bits; collect() always requests a multiple of this.
        self._bits_per_read = bits_per_read

    def variable_length_integer(self) -> int:
        """
        Decode an integer stored as a sequence of 2-bit chunks. Starting from 1, the high bit
        of each chunk is appended to the value; a chunk whose low bit is clear ends the
        sequence.
        """
        value = 1
        while True:
            chunk = self.read(2)
            value = (value << 1) + (chunk >> 1)
            if not chunk & 1:
                return value

    @property
    def overshoot(self) -> int:
        """
        The number of whole bytes that were taken from the underlying reader but whose bits
        have not been consumed yet.
        """
        return self._bit_buffer_size // 8

    def __getattr__(self, k):
        """
        Delegate attributes not defined on this class to the underlying reader.
        """
        # Without this guard, accessing self._reader before it is assigned (instances made
        # via __new__, copy/unpickle probing for __setstate__, hasattr) re-enters
        # __getattr__ and recurses until RecursionError.
        if k == '_reader':
            raise AttributeError(k)
        return getattr(self._reader, k)

    def __next__(self) -> int:
        return self.read(1)

    def next(self):
        """
        Consume and return a single bit.
        """
        return self.read(1)

    def peek(self, count: int):
        """
        Return the next `count` bits as an integer without consuming them, buffering more
        input first if necessary.
        """
        # collect() may replace _bit_buffer_data when it refills the buffer, so it has to run
        # before the buffer is read. The former `self._bit_buffer_data >> self.collect(count)`
        # evaluated the left operand first and shifted the stale value.
        offset = self.collect(count)
        return self._bit_buffer_data >> offset

    def __len__(self):
        """
        The number of bits currently buffered.
        """
        return self._bit_buffer_size

    def __getitem__(self, k: int):
        """
        Return the buffered bit at index `k`, where index 0 is the next bit that read(1)
        would return. Only buffered bits are accessible; no input is read.
        """
        if k not in range(self._bit_buffer_size):
            raise IndexError(k)
        # Index 0 is the most significant buffered bit, at position size - 1.
        offset = self._bit_buffer_size - k - 1
        return (self._bit_buffer_data >> offset) & 1

    def __enter__(self):
        return self

    def __exit__(self, *_):
        return False

    def read(self, count: int) -> int:
        """
        Consume the next `count` bits and return them as an integer, most significant bit
        first.
        """
        offset = self.collect(count)
        bits = self._bit_buffer_data >> offset
        # Clear the consumed high bits from the buffer.
        self._bit_buffer_data ^= bits << offset
        self._bit_buffer_size -= count
        assert self._bit_buffer_data.bit_length() <= self._bit_buffer_size
        assert bits.bit_length() <= count
        return bits

    def collect(self, count: Optional[int] = None) -> int:
        """
        Ensure that at least `count` bits (default: `bits_per_read`) are buffered. Reads from
        the underlying reader in multiples of `bits_per_read` bits. Returns the number of
        buffered bits minus `count`, which is the shift that exposes the requested bits.
        """
        if count is None:
            count = self._bits_per_read
        offset = self._bit_buffer_size - count
        if offset < 0:
            more = count - self._bit_buffer_size
            # Round the missing bit count up to a multiple of bits_per_read.
            reads, _r = divmod(more, self._bits_per_read)
            reads += int(bool(_r))
            reads *= self._bits_per_read
            self._bit_buffer_data <<= reads
            self._bit_buffer_data |= self._reader.read_integer(reads)
            self._bit_buffer_size += reads
            offset += reads
        assert offset >= 0
        return offset

    def align(self):
        """
        Discard all buffered bits; the next read continues with fresh data from the
        underlying reader.
        """
        self._bit_buffer_size = 0
        self._bit_buffer_data = 0
def read_huffman_symbol(reader: BitBufferedReader, decode_table: List[int], table_bits: int, max_codeword_len: int):
    """
    Decode one symbol from `reader` using a table built by make_huffman_decode_table.

    Enough bits for the longest codeword are buffered first. The root table is indexed by
    the next `table_bits` bits. If that entry refers to a subtable, those bits are consumed
    and the subtable is indexed by as many further bits as the entry specifies. Finally, the
    number of bits recorded in the selected entry is consumed and its symbol is returned.
    """
    reader.collect(max_codeword_len)
    entry = decode_table[reader.peek(table_bits)]
    # Subtables only exist for codewords longer than table_bits. A root entry refers to one
    # when its symbol field is at least 2**table_bits, since subtables start after the root
    # table.
    if max_codeword_len > table_bits:
        if entry >= 1 << (table_bits + DECODE_TABLE_SYMBOL_SHIFT):
            reader.read(table_bits)
            subtable_start = entry >> DECODE_TABLE_SYMBOL_SHIFT
            subtable_bits = entry & DECODE_TABLE_LENGTH_MASK
            entry = decode_table[subtable_start + reader.peek(subtable_bits)]
    reader.read(entry & DECODE_TABLE_LENGTH_MASK)
    return entry >> DECODE_TABLE_SYMBOL_SHIFT
Functions
def MAKE_DECODE_TABLE_ENTRY(symbol, length)
-
Expand source code Browse git
def MAKE_DECODE_TABLE_ENTRY(symbol, length):
    """
    Pack a symbol (or a subtable start index) and a bit length into a single 16-bit decode
    table entry. The symbol goes above bit DECODE_TABLE_SYMBOL_SHIFT, the length below it.
    """
    packed = length | symbol << DECODE_TABLE_SYMBOL_SHIFT
    # Internal consistency check: the packed entry must fit into 16 bits.
    assert packed == packed & 0xFFFF
    return packed
def s32shift(k, shift)
-
This helper function implements a signed left shift for 32-bit integers.
Expand source code Browse git
def s32shift(k: int, shift: int):
    """
    This helper function implements a signed left shift for 32-bit integers: `k` is shifted
    left by `shift` modulo 32 bits, truncated to 32 bits, and the result is reinterpreted as
    a two's-complement signed value.
    """
    width = 32
    modulus = 1 << width
    wrapped = (k << (shift % width)) % modulus
    # A set sign bit means the 32-bit value is negative.
    if wrapped >> (width - 1):
        return wrapped - modulus
    return wrapped
def make_huffman_decode_table(table_data, table_bits, max_codeword_len)
-
Expand source code Browse git
def make_huffman_decode_table(
    table_data: bytearray,
    table_bits: int,
    max_codeword_len: int,
) -> List[int]:
    """
    Build a table-driven decoder for a Huffman code given by its codeword lengths, for use
    with read_huffman_symbol.

    Parameters:
        table_data: the codeword length of every symbol, indexed by symbol; symbols with
            length 0 receive no codeword.
        table_bits: the number of input bits that index the root table, which occupies the
            first 2**table_bits entries of the result.
        max_codeword_len: the longest codeword length taken into account.

    Returns:
        A list of entries built with MAKE_DECODE_TABLE_ENTRY. A root entry holds either a
        symbol together with its codeword length, or, for codewords longer than table_bits,
        the start index of a subtable (subtables are appended after the root table) together
        with the number of bits that index that subtable. Symbols are sorted by (length,
        symbol index) and assigned codewords in that order, i.e. a canonical Huffman code.
        If every length is 0, an all-zero root table is returned.

    Raises:
        OverflowError: the lengths overflow the code space.
        RuntimeError: the lengths form an incomplete code that is not empty.
        IndexError: from the defensive codeword search below, or while sorting if a
            length exceeds max_codeword_len.

    NOTE(review): lengths above max_codeword_len are not counted by the code-space check;
    depending on the other lengths they lead to an IndexError while sorting or to an
    all-zero table.
    """
    remainder = 1
    codeword_length = 1
    entry_pos = 0
    decode_table: List[int] = [0] * (1 << table_bits)
    sym_count = len(table_data)
    len_counts = Counter(table_data)
    # Code-space check: after processing length n, remainder is the number of codewords of
    # length n that are still unused.
    for sym in range(1, max_codeword_len + 1):
        remainder = (remainder << 1) - len_counts[sym]
        if remainder < 0:
            raise OverflowError('Lengths have overflowed the code space.')
    if remainder:
        # Incomplete code: only accepted when it is entirely empty. The all-zero table then
        # decodes symbol 0 while consuming no bits (entry 0 has length 0).
        if remainder != 1 << max_codeword_len:
            raise RuntimeError('Incomplete & nonempty code encountered.')
        return decode_table
    # offsets[n] is the number of symbols whose length is smaller than n (including 0).
    offsets = [0]
    for sym in range(max_codeword_len):
        offsets.append(offsets[sym] + len_counts[sym])
    # Counting sort: sorted_syms maps a position in (length, symbol index) order to the
    # symbol index; equal lengths keep increasing symbol order because of enumerate.
    sorted_syms = {}
    for i, sym in enumerate(table_data):
        offset = offsets[sym]
        offsets[sym] += 1
        sorted_syms[offset] = i
    # After sorting, offsets[0] equals the number of length-0 symbols; skip them.
    sym_index = offsets[0]
    # Root table: a codeword of length n <= table_bits fills 2**(table_bits - n)
    # consecutive entries, starting with the shortest codewords.
    stores_per_loop = 1 << (table_bits - codeword_length)
    while stores_per_loop:
        end_sym_idx = sym_index + len_counts[codeword_length]
        for k in range(sym_index, end_sym_idx):
            entry_end = entry_pos + stores_per_loop
            decode_table[entry_pos:entry_end] = repeat(
                MAKE_DECODE_TABLE_ENTRY(sorted_syms[k], codeword_length),
                stores_per_loop)
            entry_pos = entry_end
        codeword_length += 1
        sym_index = end_sym_idx
        stores_per_loop >>= 1
    assert sym_index <= sym_count
    # Every symbol fit into the root table; no subtables are needed.
    if sym_index == sym_count:
        return decode_table
    # codeword_length is now table_bits + 1. The first entry_pos root entries are in use,
    # so the next free codeword of that length is entry_pos * 2.
    codeword = entry_pos * 2
    subtable_pos = 1 << table_bits
    subtable_bits = table_bits
    subtable_prefix = -1
    while sym_index < sym_count:
        # Advance to the next length that still has unassigned symbols, extending the
        # codeword by one bit per step.
        while len_counts[codeword_length] == 0:
            # NOTE(review): defensive guard; should be unreachable for a code that passed
            # the code-space check above.
            if codeword_length > sym_count:
                raise IndexError('Error computing codeword')
            codeword_length += 1
            codeword <<= 1
        # The first table_bits bits of the codeword select the root entry.
        prefix = codeword >> (codeword_length - table_bits)
        if prefix != subtable_prefix:
            # Start a new subtable. It needs at least 2**(codeword_length - table_bits)
            # entries; grow it until the remaining codewords can fill it completely.
            subtable_prefix = prefix
            subtable_bits = codeword_length - table_bits
            remainder = s32shift(1, subtable_bits)
            while True:
                remainder -= len_counts[table_bits + subtable_bits]
                if remainder <= 0:
                    break
                subtable_bits += 1
                remainder <<= 1
            # Root entry pointing at the subtable: start index and number of index bits.
            decode_table[subtable_prefix] = MAKE_DECODE_TABLE_ENTRY(subtable_pos, subtable_bits)
        # Subtable entries store only the bits beyond the root table index as length.
        entry = MAKE_DECODE_TABLE_ENTRY(sorted_syms[sym_index], codeword_length - table_bits)
        count = 1 << (table_bits + subtable_bits - codeword_length)
        end = subtable_pos + count
        # subtable_pos always equals len(decode_table) here, so this slice assignment appends.
        decode_table[subtable_pos:end] = repeat(entry, count)
        subtable_pos = end
        len_counts[codeword_length] -= 1
        codeword += 1
        sym_index += 1
    return decode_table
def read_huffman_symbol(reader, decode_table, table_bits, max_codeword_len)
-
Expand source code Browse git
def read_huffman_symbol(reader: BitBufferedReader, decode_table: List[int], table_bits: int, max_codeword_len: int):
    """
    Decode one symbol from `reader` using a table built by make_huffman_decode_table.

    Enough bits for the longest codeword are buffered first. The root table is indexed by
    the next `table_bits` bits. If that entry refers to a subtable, those bits are consumed
    and the subtable is indexed by as many further bits as the entry specifies. Finally, the
    number of bits recorded in the selected entry is consumed and its symbol is returned.
    """
    reader.collect(max_codeword_len)
    entry = decode_table[reader.peek(table_bits)]
    # Subtables only exist for codewords longer than table_bits. A root entry refers to one
    # when its symbol field is at least 2**table_bits, since subtables start after the root
    # table.
    if max_codeword_len > table_bits:
        if entry >= 1 << (table_bits + DECODE_TABLE_SYMBOL_SHIFT):
            reader.read(table_bits)
            subtable_start = entry >> DECODE_TABLE_SYMBOL_SHIFT
            subtable_bits = entry & DECODE_TABLE_LENGTH_MASK
            entry = decode_table[subtable_start + reader.peek(subtable_bits)]
    reader.read(entry & DECODE_TABLE_LENGTH_MASK)
    return entry >> DECODE_TABLE_SYMBOL_SHIFT
Classes
class BitBufferedReader (buffer, bits_per_read=32)
-
A helper class to read individual bits from the compressed input stream.
Expand source code Browse git
class BitBufferedReader:
    """
    A helper class to read individual bits from the compressed input stream. Bits are
    buffered in an integer and handed out most significant bit first; the buffer is refilled
    from the underlying StructReader in multiples of `bits_per_read` bits.
    """

    def __init__(self, buffer: Union[bytearray, StructReader], bits_per_read: int = 32):
        if not isinstance(buffer, StructReader):
            buffer = StructReader(memoryview(buffer), bigendian=False)
        self._reader: StructReader[memoryview] = buffer
        # Buffered bits not yet consumed; read() asserts that its bit length never exceeds
        # _bit_buffer_size, so the next bit to consume sits at position _bit_buffer_size - 1.
        self._bit_buffer_data: int = 0
        # Number of valid bits currently held in _bit_buffer_data.
        self._bit_buffer_size: int = 0
        # Refill granularity in bits; collect() always requests a multiple of this.
        self._bits_per_read = bits_per_read

    def variable_length_integer(self) -> int:
        """
        Decode an integer stored as a sequence of 2-bit chunks. Starting from 1, the high bit
        of each chunk is appended to the value; a chunk whose low bit is clear ends the
        sequence.
        """
        value = 1
        while True:
            chunk = self.read(2)
            value = (value << 1) + (chunk >> 1)
            if not chunk & 1:
                return value

    @property
    def overshoot(self) -> int:
        """
        The number of whole bytes that were taken from the underlying reader but whose bits
        have not been consumed yet.
        """
        return self._bit_buffer_size // 8

    def __getattr__(self, k):
        """
        Delegate attributes not defined on this class to the underlying reader.
        """
        # Without this guard, accessing self._reader before it is assigned (instances made
        # via __new__, copy/unpickle probing for __setstate__, hasattr) re-enters
        # __getattr__ and recurses until RecursionError.
        if k == '_reader':
            raise AttributeError(k)
        return getattr(self._reader, k)

    def __next__(self) -> int:
        return self.read(1)

    def next(self):
        """
        Consume and return a single bit.
        """
        return self.read(1)

    def peek(self, count: int):
        """
        Return the next `count` bits as an integer without consuming them, buffering more
        input first if necessary.
        """
        # collect() may replace _bit_buffer_data when it refills the buffer, so it has to run
        # before the buffer is read. The former `self._bit_buffer_data >> self.collect(count)`
        # evaluated the left operand first and shifted the stale value.
        offset = self.collect(count)
        return self._bit_buffer_data >> offset

    def __len__(self):
        """
        The number of bits currently buffered.
        """
        return self._bit_buffer_size

    def __getitem__(self, k: int):
        """
        Return the buffered bit at index `k`, where index 0 is the next bit that read(1)
        would return. Only buffered bits are accessible; no input is read.
        """
        if k not in range(self._bit_buffer_size):
            raise IndexError(k)
        # Index 0 is the most significant buffered bit, at position size - 1.
        offset = self._bit_buffer_size - k - 1
        return (self._bit_buffer_data >> offset) & 1

    def __enter__(self):
        return self

    def __exit__(self, *_):
        return False

    def read(self, count: int) -> int:
        """
        Consume the next `count` bits and return them as an integer, most significant bit
        first.
        """
        offset = self.collect(count)
        bits = self._bit_buffer_data >> offset
        # Clear the consumed high bits from the buffer.
        self._bit_buffer_data ^= bits << offset
        self._bit_buffer_size -= count
        assert self._bit_buffer_data.bit_length() <= self._bit_buffer_size
        assert bits.bit_length() <= count
        return bits

    def collect(self, count: Optional[int] = None) -> int:
        """
        Ensure that at least `count` bits (default: `bits_per_read`) are buffered. Reads from
        the underlying reader in multiples of `bits_per_read` bits. Returns the number of
        buffered bits minus `count`, which is the shift that exposes the requested bits.
        """
        if count is None:
            count = self._bits_per_read
        offset = self._bit_buffer_size - count
        if offset < 0:
            more = count - self._bit_buffer_size
            # Round the missing bit count up to a multiple of bits_per_read.
            reads, _r = divmod(more, self._bits_per_read)
            reads += int(bool(_r))
            reads *= self._bits_per_read
            self._bit_buffer_data <<= reads
            self._bit_buffer_data |= self._reader.read_integer(reads)
            self._bit_buffer_size += reads
            offset += reads
        assert offset >= 0
        return offset

    def align(self):
        """
        Discard all buffered bits; the next read continues with fresh data from the
        underlying reader.
        """
        self._bit_buffer_size = 0
        self._bit_buffer_data = 0
Instance variables
var overshoot
-
Expand source code Browse git
@property
def overshoot(self) -> int:
    """
    The number of whole bytes currently held in the bit buffer, i.e. bytes that were taken
    from the underlying reader but whose bits have not been consumed yet.
    """
    whole_bytes, _ = divmod(self._bit_buffer_size, 8)
    return whole_bytes
Methods
def variable_length_integer(self)
-
Expand source code Browse git
def variable_length_integer(self) -> int:
    """
    Decode an integer stored as a sequence of 2-bit chunks. Starting from 1, the high bit of
    each chunk is appended to the value; a chunk whose low bit is clear ends the sequence.
    """
    result = 1
    continues = True
    while continues:
        pair = self.read(2)
        result = result * 2 + (pair >> 1)
        continues = bool(pair & 1)
    return result
def next(self)
-
Expand source code Browse git
def next(self):
    """
    Consume and return a single bit.
    """
    bit = self.read(1)
    return bit
def peek(self, count)
-
Expand source code Browse git
def peek(self, count: int):
    """
    Return the next `count` bits as an integer without consuming them, buffering more input
    first if necessary.
    """
    # collect() may replace _bit_buffer_data when it refills the buffer, so it has to run
    # before the buffer is read. The former `self._bit_buffer_data >> self.collect(count)`
    # evaluated the left operand first and shifted the stale value.
    offset = self.collect(count)
    return self._bit_buffer_data >> offset
def read(self, count)
-
Expand source code Browse git
def read(self, count: int) -> int:
    """
    Consume the next `count` bits and return them as an integer, most significant bit first.
    """
    shift = self.collect(count)
    value = self._bit_buffer_data >> shift
    # Remove the bits being returned from the top of the buffer.
    leftover = self._bit_buffer_data ^ (value << shift)
    self._bit_buffer_data = leftover
    self._bit_buffer_size -= count
    assert leftover.bit_length() <= self._bit_buffer_size
    assert value.bit_length() <= count
    return value
def collect(self, count=None)
-
Expand source code Browse git
def collect(self, count: Optional[int] = None) -> int:
    """
    Ensure that at least `count` bits (default: `bits_per_read`) are buffered. Reads from the
    underlying reader in multiples of `bits_per_read` bits. Returns the number of buffered
    bits minus `count`, which is the shift that exposes the requested bits.
    """
    if count is None:
        count = self._bits_per_read
    missing = count - self._bit_buffer_size
    if missing > 0:
        step = self._bits_per_read
        # Ceiling division: fetch whole multiples of bits_per_read.
        fetch = -(-missing // step) * step
        self._bit_buffer_data <<= fetch
        self._bit_buffer_data |= self._reader.read_integer(fetch)
        self._bit_buffer_size += fetch
    offset = self._bit_buffer_size - count
    assert offset >= 0
    return offset
def align(self)
-
Expand source code Browse git
def align(self):
    """
    Discard all buffered bits; the next read continues with fresh data from the underlying
    reader.
    """
    self._bit_buffer_data = self._bit_buffer_size = 0