ChocolateBirdData/reference_implementation.py

284 lines
9.3 KiB
Python
Raw Normal View History

2024-06-26 20:30:36 +09:30
# This is a reference implementation of parsing, serializing, and deserializing ChocolateBird's struct definitions.
# It is ported to Python from the original GDScript implementation.
import logging
from io import TextIOWrapper
from struct import pack_into, unpack_from, calcsize
from collections.abc import Buffer
class LeftoverBits:
    """Sub-byte state shared between consecutive bitfield reads or writes.

    One instance is passed through every ``get_value``/``put_value`` call so that
    adjacent ``UBits`` fields can share bytes.
    """

    number_bits: int = 0  # count of bits currently held in bit_buffer
    bit_buffer: int = 0  # pending bits packed into a Python int, least-significant bit first
class ReadBuffer:
backing_buffer: Buffer
position: int = 0
def __init__(self, backing_buffer, position: int = 0) -> None:
self.backing_buffer = backing_buffer
self.position = position
def get(self, format) -> tuple:
values = unpack_from(format, self.backing_buffer, self.position)
self.position += calcsize(format) # TODO: cache this
return values
class WriteBuffer:
backing_buffer: Buffer
position: int = 0
def __init__(self, backing_buffer, position: int = 0) -> None:
self.backing_buffer = backing_buffer
self.position = position
def put(self, format, *values) -> None:
pack_into(format, self.backing_buffer, self.position, *values)
self.position += calcsize(format) # TODO: cache this
class StructType:
    """Base class for every type that can appear in a struct definition.

    Subclasses override ``get_value`` (deserialize) and ``put_value`` (serialize).
    """

    name: str | None = None  # human-readable type name shown by __repr__, if set

    def __repr__(self) -> str:
        # Anonymous types fall back to the default object repr.
        return self.name if self.name else super().__repr__()

    def get_value(self, buffer: ReadBuffer, leftover_bits: LeftoverBits):
        """Read one value of this type from ``buffer``; subclasses must override."""
        raise NotImplementedError('Deserialization not implemented')

    def put_value(self, buffer: WriteBuffer, value, leftover_bits: LeftoverBits):
        """Write ``value`` as this type into ``buffer``; subclasses must override."""
        raise NotImplementedError('Serialization not implemented')
class SimpleStruct(StructType):
    """A scalar whose layout is described by a single ``struct`` format string."""

    format: str  # struct format string, e.g. '<H'

    def __init__(self, format, name=None) -> None:
        self.name = name
        self.format = format

    def get_value(self, buffer: ReadBuffer, leftover_bits: LeftoverBits):
        """Unpack ``self.format`` and return the first unpacked value."""
        unpacked = buffer.get(self.format)
        return unpacked[0]

    def put_value(self, buffer: WriteBuffer, value, leftover_bits: LeftoverBits):
        """Pack ``value`` with ``self.format``."""
        fmt = self.format
        buffer.put(fmt, value)
class U24(StructType):
    """Unsigned 24-bit little-endian integer.

    ``struct`` has no 3-byte format code, so the value is split into a low u16
    followed by a high u8.
    """

    name = 'u24'

    def get_value(self, buffer: ReadBuffer, leftover_bits: LeftoverBits):
        low, high = buffer.get('<HB')
        return (high << 16) | low

    def put_value(self, buffer: WriteBuffer, value, leftover_bits: LeftoverBits):
        low = value & 0xFFFF
        high = value >> 16
        buffer.put('<HB', low, high)
class S24(StructType):
    """Signed (two's complement) 24-bit little-endian integer, stored as a u16 then a u8."""

    name = 's24'

    def get_value(self, buffer: ReadBuffer, leftover_bits: LeftoverBits):
        low, high = buffer.get('<HB')
        raw = (high << 16) | low
        # Sign-extend: when bit 23 is set, the stored pattern represents raw - 2**24.
        return raw - 0x1000000 if raw & 0x800000 else raw

    def put_value(self, buffer: WriteBuffer, value, leftover_bits: LeftoverBits):
        raw = value % 0x1000000  # wrap negatives into their 24-bit two's complement pattern
        low, high = raw & 0xFFFF, raw >> 16
        buffer.put('<HB', low, high)
class UBits(StructType):
    """Unsigned bitfield of arbitrary width, packed least-significant bit first.

    Bits not consumed by one field remain in ``leftover_bits`` for the next field,
    so consecutive UBits members can share bytes.
    """

    bits = 8  # field width in bits

    def __init__(self, bits: int):
        self.bits = bits
        self.name = f'u{bits:d}'

    def get_value(self, buffer: ReadBuffer, leftover_bits: LeftoverBits):
        """Read a ``self.bits``-wide unsigned value, pulling whole bytes from ``buffer`` as needed."""
        mask = (1 << self.bits) - 1
        while leftover_bits.number_bits < self.bits:
            # New bytes are appended above the bits already buffered.
            leftover_bits.bit_buffer |= buffer.get('<B')[0] << leftover_bits.number_bits
            leftover_bits.number_bits += 8
        value = leftover_bits.bit_buffer & mask
        leftover_bits.bit_buffer >>= self.bits
        leftover_bits.number_bits -= self.bits
        return value

    def put_value(self, buffer: WriteBuffer, value, leftover_bits: LeftoverBits):
        """Append ``value`` as ``self.bits`` bits, flushing each completed byte to ``buffer``.

        ``value`` is masked to the field width. Without the mask, out-of-range
        values would spill into the following fields, and negative values would
        leave ``bit_buffer`` negative, corrupting every later byte. An incomplete
        trailing byte stays in ``leftover_bits``.
        """
        mask = (1 << self.bits) - 1
        leftover_bits.bit_buffer |= (value & mask) << leftover_bits.number_bits
        leftover_bits.number_bits += self.bits
        while leftover_bits.number_bits >= 8:
            buffer.put('<B', leftover_bits.bit_buffer & 0xFF)
            leftover_bits.number_bits -= 8
            leftover_bits.bit_buffer >>= 8
class Struct(StructType):
    """An ordered record of named members, (de)serialized as a dict."""

    members: list  # ordered [member_name, StructType] pairs

    def __init__(self, name=None) -> None:
        self.members = []
        self.name = name

    def get_value(self, buffer: ReadBuffer, leftover_bits: LeftoverBits):
        """Read every member in declaration order and return them keyed by member name."""
        result = {}
        for member_name, member_type in self.members:
            result[member_name] = member_type.get_value(buffer, leftover_bits)
        return result

    def put_value(self, buffer: WriteBuffer, value, leftover_bits: LeftoverBits):
        """Write each member of ``value`` (looked up by member name) in declaration order.

        Logs an error and stops at the first missing key. Members before that key
        have already been written.
        """
        for member_name, member_type in self.members:
            if member_name not in value:
                logging.error(f'Key "{member_name}" missing from value supplied')
                return
            member_type.put_value(buffer, value[member_name], leftover_bits)
class StructArrayType(StructType):
    """A fixed-length array of ``count`` consecutive values of one element type."""

    count: int  # number of elements
    struct_type: StructType  # element type

    def __init__(self, count, struct_type, name=None) -> None:
        self.name = name
        self.struct_type = struct_type
        self.count = count

    def get_value(self, buffer: ReadBuffer, leftover_bits: LeftoverBits) -> list:
        """Read ``count`` elements and return them as a list."""
        element = self.struct_type
        return [element.get_value(buffer, leftover_bits) for _ in range(self.count)]

    def put_value(self, buffer: WriteBuffer, value, leftover_bits: LeftoverBits) -> None:
        """Write the first ``count`` items of ``value``; logs an error and writes nothing if too few."""
        if len(value) < self.count:
            logging.error('Not enough values supplied')
            return
        element = self.struct_type
        for index in range(self.count):
            element.put_value(buffer, value[index], leftover_bits)
def get_base_structarraytypes() -> dict:
    """Return a fresh mapping from built-in scalar type names to new StructType objects.

    Covers the 8/16/32-bit signed and unsigned little-endian integers (via ``struct``
    formats) and the custom 24-bit types.
    """
    base_types = {
        type_name: SimpleStruct(fmt, type_name)
        for type_name, fmt in (('u8', '<B'), ('s8', '<b'), ('u16', '<H'), ('s16', '<h'))
    }
    base_types['u24'] = U24()
    base_types['s24'] = S24()
    base_types['u32'] = SimpleStruct('<I', 'u32')
    base_types['s32'] = SimpleStruct('<i', 's32')
    return base_types
def get_structarraytype(type: str, existing_structs: dict):
tokens = type.split(' ')
t: str = tokens[-1]
inner_type: Struct = None
if t in existing_structs:
inner_type = existing_structs[t]
elif t[0] == 'u':
b: int = int(t[1:])
if b > 0:
inner_type = UBits(b)
existing_structs[f'u{b:d}'] = inner_type # Cache it for future use
if not inner_type:
logging.error(f'typestring "{type}" has no matches for "{t}" in existing structs')
return
l: int = len(tokens)
if l == 1:
return inner_type
# Our parsing goal is to turn 'a of b of c of d' into StructArrayType<StructArrayType<StructArrayType<d, c>, b>, a>
# Our strategy is to parse backwards over the tokens, changing inner_type at each point
# a of b of c of (d)
# a of b of (c of d)
# a of (b of c of d)
# (a of b of c of d)
# done
i: int = l-2
while i > -1:
match tokens[i]:
case 'of':
i -= 1
l1: int = int(tokens[i])
if l1 > 1:
inner_type = StructArrayType(l1, inner_type, name=type) # Might be worth caching these later on if we use them more
i -= 1
case k:
logging.error(f'Invalid keyword used in type designator: "{k}"')
return
return inner_type
def parse_struct_definitions_from_tsv_file(tsv_file: TextIOWrapper, existing_structs: dict) -> None:
    """Parse struct definitions from an open TSV file into ``existing_structs``.

    Each row is ``type<TAB>label[<TAB>...]``.

    - A row whose type is ``struct`` starts a new Struct named ``label``. It is
      registered in ``existing_structs`` immediately, so later rows can reference it.
    - Any other row with a non-empty type and label appends a member to the most
      recently declared struct.
    - Rows with fewer than two columns, or with an empty type or label, are ignored.
      Columns after the second are ignored.
    - Member rows that appear before any ``struct`` row are reported with
      ``logging.error`` and skipped.
    """
    current_struct = None  # no struct declared yet
    for line in tsv_file.read().rstrip().split('\n'):
        tokens = line.split('\t')
        if len(tokens) < 2:
            continue
        type_string, label = tokens[:2]
        if type_string == 'struct':
            # New struct declaration
            current_struct = Struct(name=label)
            existing_structs[label] = current_struct
        elif type_string and label:
            if current_struct is None:
                logging.error(f'Member "{label}" declared before any struct; skipping it')
                continue
            # NOTE(review): get_structarraytype returns None (after logging) for an
            # unresolvable type, and that None is still appended as the member type.
            current_struct.members.append([label, get_structarraytype(type_string, existing_structs)])
        # TODO: Maybe store the trailing comments somewhere?
def parse_struct_definitions_from_tsv_filename(filename: str, existing_structs: dict, encoding: str = 'utf-8') -> None:
    """Open ``filename`` and parse its struct definitions into ``existing_structs``.

    The file is decoded as ``encoding`` (UTF-8 by default) instead of the
    platform's locale encoding, so the result does not depend on the OS locale.
    """
    with open(filename, 'r', encoding=encoding) as file:
        parse_struct_definitions_from_tsv_file(file, existing_structs)
def load_ff5_snes_struct_definitions() -> dict:
    """Build the FF5 SNES type table: base scalar types plus every definition file.

    Files are parsed in the listed order into one shared table. Each type lookup
    only sees entries added before it. Paths are relative to the current working
    directory.
    """
    existing_structs = get_base_structarraytypes()
    for definitions_path in (
        'structs_SNES_stubs.tsv',
        '5/structs/SNES_stubs.tsv',
        '5/structs/SNES.tsv',
        '5/structs/SNES_save.tsv',
    ):
        parse_struct_definitions_from_tsv_filename(definitions_path, existing_structs)
    return existing_structs
# Basic TSV dumper
# This is mostly unrelated to the struct parsing above, but helpful for debugging.
def flatten_keys(d: dict, prefix: str = '') -> dict:
    """Flatten nested dicts into one level, joining key paths with '.'.

    Non-dict values (including lists) are kept as-is. ``prefix`` is prepended to
    every resulting key.

    >>> flatten_keys({'a': 1, 'b': {'c': 2}})
    {'a': 1, 'b.c': 2}
    """
    flat = {}
    for key, value in d.items():
        path = f'{prefix}{key}'
        if isinstance(value, dict):
            flat.update(flatten_keys(value, path + '.'))
        else:
            flat[path] = value
    return flat
def dump_tsv(filename, table):
    """Write a list of (possibly nested) dicts to ``filename`` as a tab-separated table.

    Nested dicts are flattened with ``flatten_keys``. The first row's keys become
    the column headers, and every row must contain those keys. An ``ID`` column is
    prepended; it holds each row's index in zero-padded uppercase hex (e.g.
    ``0x0A``). An empty table produces a file containing only the ``ID`` header.
    The output is UTF-8 and is written only after all rows format successfully.
    """
    table_flat = [flatten_keys(d) for d in table]
    headers = list(table_flat[0].keys()) if table_flat else []
    # Pad every ID to the width of the largest one so the column lines up.
    hex_digits = len(f'{max(len(table_flat) - 1, 0):X}')
    lines = ['\t'.join(['ID'] + headers)]
    for i, entry in enumerate(table_flat):
        lines.append('\t'.join([f'0x{i:0{hex_digits}X}'] + [str(entry[key]) for key in headers]))
    with open(filename, 'w', encoding='utf-8') as file:
        file.write('\n'.join(lines) + '\n')
# Example usage: run this script with a positional filename argument for a FF5 .sfc and it will load EnemyStats from it.
# Pass "dump_tsv" as a second argument to also write the table to enemy_stats.tsv.
if __name__ == '__main__':
    from sys import argv

    # Built before the argument checks, so these module-level names exist under
    # `python -i` even when no ROM is given.
    existing_structs = load_ff5_snes_struct_definitions()
    leftover_bits = LeftoverBits()
    if len(argv) <= 1:
        print('Example usage: run this script with a positional filename argument for a FF5 .sfc and it will load EnemyStats from it.')
    elif not (rom_filename := argv[1]).endswith('.sfc'):
        print(f'Argument "{rom_filename}" doesn\'t end in ".sfc", so it was not parsed as a FF5 SNES ROM.')
    else:
        with open(rom_filename, 'rb') as rom_file:
            rom_data = rom_file.read()
        # Read 384 EnemyStats records starting at file offset 0x100000.
        buffer = ReadBuffer(rom_data, 0x100000)
        # For a single record instead: existing_structs['EnemyStats'].get_value(buffer, leftover_bits)
        enemy_stats = get_structarraytype('384 of EnemyStats', existing_structs).get_value(buffer, leftover_bits)
        print('Loaded EnemyStats table to enemy_stats! If you ran the python interpreter with -i flag, you can now examine it in the REPL.')
        # Optional basic dumper
        if len(argv) > 2 and argv[2] == 'dump_tsv':
            dump_tsv('enemy_stats.tsv', enemy_stats)
            print('Dumped enemy_stats table to enemy_stats.tsv!')