diff --git a/data/reference_implementation.py b/data/reference_implementation.py new file mode 100644 index 0000000..7ebac67 --- /dev/null +++ b/data/reference_implementation.py @@ -0,0 +1,283 @@ +# This is a reference implementation of parsing, serializing, and deserializing ChocolateBird's struct definitions. +# It is ported from the original GDscript to Python. +import logging +from io import TextIOWrapper +from struct import pack_into, unpack_from, calcsize +from collections.abc import Buffer + + +class LeftoverBits: + number_bits = 0 + bit_buffer = 0 # Stored as an integer + + +class ReadBuffer: + backing_buffer: Buffer + position: int = 0 + + def __init__(self, backing_buffer, position: int = 0) -> None: + self.backing_buffer = backing_buffer + self.position = position + + def get(self, format) -> tuple: + values = unpack_from(format, self.backing_buffer, self.position) + self.position += calcsize(format) # TODO: cache this + return values + + +class WriteBuffer: + backing_buffer: Buffer + position: int = 0 + + def __init__(self, backing_buffer, position: int = 0) -> None: + self.backing_buffer = backing_buffer + self.position = position + + def put(self, format, *values) -> None: + pack_into(format, self.backing_buffer, self.position, *values) + self.position += calcsize(format) # TODO: cache this + + +class StructType: + name: str = None + def __repr__(self) -> str: + if self.name: + return self.name + return super().__repr__() + + def get_value(self, buffer: ReadBuffer, leftover_bits: LeftoverBits): + raise NotImplementedError('Deserialization not implemented') + def put_value(self, buffer: WriteBuffer, value, leftover_bits: LeftoverBits): + raise NotImplementedError('Serialization not implemented') + + +class SimpleStruct(StructType): + format: str + def __init__(self, format, name=None) -> None: + self.format = format + self.name = name + def get_value(self, buffer: ReadBuffer, leftover_bits: LeftoverBits): + return buffer.get(self.format)[0] + def put_value(self, buffer: WriteBuffer, value, leftover_bits: LeftoverBits): + buffer.put(self.format, value) + + +class U24(StructType): + name = 'u24' + def get_value(self, buffer: ReadBuffer, leftover_bits: LeftoverBits): + u16, u8 = buffer.get('> 16) + + +class S24(StructType): + name = 's24' + def get_value(self, buffer: ReadBuffer, leftover_bits: LeftoverBits): + u16, u8 = buffer.get('> 16) + + +class UBits(StructType): + bits = 8 + + def __init__(self, bits: int): + self.bits = bits + self.name = f'u{bits:d}' + + def get_value(self, buffer: ReadBuffer, leftover_bits: LeftoverBits): + while leftover_bits.number_bits < self.bits: + leftover_bits.bit_buffer |= buffer.get('> self.bits + leftover_bits.number_bits -= self.bits + return value + + def put_value(self, buffer: WriteBuffer, value, leftover_bits: LeftoverBits): + leftover_bits.bit_buffer |= value << leftover_bits.number_bits + leftover_bits.number_bits += self.bits + while leftover_bits.number_bits >= 8: + buffer.put('> 8 + + +class Struct(StructType): + members: list # Array of [name, StructType] + + def __init__(self, name=None) -> None: + self.name = name + self.members = [] + + def get_value(self, buffer: ReadBuffer, leftover_bits: LeftoverBits): + return {key: struct_type.get_value(buffer, leftover_bits) for key, struct_type in self.members} + + def put_value(self, buffer: WriteBuffer, value, leftover_bits: LeftoverBits): + for key, struct_type in self.members: + if not (key in value): + logging.error(f'Key "{key}" missing from value supplied') + return + struct_type.put_value(buffer, value[key], leftover_bits) + + +class StructArrayType(StructType): + count: int + struct_type: StructType + + def __init__(self, count, struct_type, name=None) -> None: + self.count = count + self.struct_type = struct_type + self.name = name + + def get_value(self, buffer: ReadBuffer, leftover_bits: LeftoverBits) -> list: + return [self.struct_type.get_value(buffer, leftover_bits) for i in range(self.count)] + + def put_value(self, buffer: WriteBuffer, value, leftover_bits: LeftoverBits) -> None: + if len(value) < self.count: + logging.error('Not enough values supplied') + return + for i in range(self.count): + self.struct_type.put_value(buffer, value[i], leftover_bits) + + +def get_base_structarraytypes() -> dict: + return { + 'u8': SimpleStruct(' 0: + inner_type = UBits(b) + existing_structs[f'u{b:d}'] = inner_type # Cache it for future use + if not inner_type: + logging.error(f'typestring "{type}" has no matches for "{t}" in existing structs') + return + + l: int = len(tokens) + if l == 1: + return inner_type + # Our parsing goal is to turn 'a of b of c of d' into StructArrayType, b>, a> + # Our strategy is to parse backwards over the tokens, changing inner_type at each point + # a of b of c of (d) + # a of b of (c of d) + # a of (b of c of d) + # (a of b of c of d) + # done + i: int = l-2 + while i > -1: + match tokens[i]: + case 'of': + i -= 1 + l1: int = int(tokens[i]) + if l1 > 1: + inner_type = StructArrayType(l1, inner_type, name=type) # Might be worth caching these later on if we use them more + i -= 1 + case k: + logging.error(f'Invalid keyword used in type designator: "{k}"') + return + return inner_type + + +def parse_struct_definitions_from_tsv_file(tsv_file: TextIOWrapper, existing_structs: dict) -> None: + current_struct: Struct + lines = tsv_file.read().rstrip().split('\n') + for line in lines: + # logging.debug(line) + tokens = line.split('\t') + size = len(tokens) + if size < 2: + continue + # Size is at least 2 + type, label = tokens[:2] + if type == 'struct': + # New struct declaration + current_struct = Struct(name=label) + existing_structs[label] = current_struct + elif type and label: + current_struct.members.append([label, get_structarraytype(type, existing_structs)]) + # TODO: Maybe store the trailing comments somewhere? + + +def parse_struct_definitions_from_tsv_filename(filename: str, existing_structs: dict) -> None: + with open(filename, 'r') as file: + parse_struct_definitions_from_tsv_file(file, existing_structs) + + +def load_ff5_snes_struct_definitions() -> dict: + existing_structs = get_base_structarraytypes() + parse_struct_definitions_from_tsv_filename('structs_SNES_stubs.tsv', existing_structs) + parse_struct_definitions_from_tsv_filename('5/structs_SNES_stubs.tsv', existing_structs) + parse_struct_definitions_from_tsv_filename('5/structs_SNES.tsv', existing_structs) + parse_struct_definitions_from_tsv_filename('5/structs_SNES_save.tsv', existing_structs) + return existing_structs + + +# Basic TSV dumper +# This is mostly unrelated, but helpful for debugging. +def flatten_keys(d: dict, prefix: str = '') -> dict: + output = {} + for k, v in d.items(): + if isinstance(v, dict): + flat = flatten_keys(v, f'{prefix}{k}.') + for k2, v2 in flat.items(): + output[k2] = v2 + else: + output[f'{prefix}{k}'] = v + return output + + +def dump_tsv(filename, table): + table_flat = [flatten_keys(d) for d in table] + hex_digits = len(f'{len(table_flat)-1:X}') # See how long the hex representation of the last number will be, so we can zero-pad the rest to match. + hex_format = f'0{hex_digits}X' + + with open(filename, 'w') as file: + headers = list(table_flat[0].keys()) + file.write('\t'.join(['ID'] + headers) + '\n') + for i, entry in enumerate(table_flat): + file.write('\t'.join([f'0x{i:{hex_format}}'] + [str(entry[key]) for key in headers]) + '\n') + + +# Example usage: run this script with a positional filename argument for a FF5 .sfc and it will load EnemyStats from it. +if __name__ == '__main__': + existing_structs = load_ff5_snes_struct_definitions() + leftover_bits = LeftoverBits() + + from sys import argv + if len(argv) > 1: + rom_filename = argv[1] + if rom_filename.endswith('.sfc'): + with open(rom_filename, 'rb') as file: + rom_data = file.read() + buffer = ReadBuffer(rom_data, 0x100000) + # enemy_stats = existing_structs['EnemyStats'].get_value(buffer, leftover_bits) + enemy_stats = get_structarraytype('384 of EnemyStats', existing_structs).get_value(buffer, leftover_bits) + print('Loaded EnemyStats table to enemy_stats! If you ran the python interpreter with -i flag, you can now examine it in the REPL.') + # Also add a basic dumper + if len(argv) > 2 and argv[2] == 'dump_tsv': + dump_tsv('enemy_stats.tsv', enemy_stats) + print('Dumped enemy_stats table to enemy_stats.tsv!') + else: + print(f'Argument "{rom_filename}" doesn\'t end in ".sfc", so it was not parsed as a FF5 SNES ROM.') + else: + print('Example usage: run this script with a positional filename argument for a FF5 .sfc and it will load EnemyStats from it.')