from typing import Any, ClassVar, Type from bta_proxy.entitydata import EntityData from ..datainputstream import AsyncDataInputStream class Packet: REGISTRY: ClassVar[dict[int, Type['Packet']]] = {} FIELDS: ClassVar[list[tuple[str, Any]]] = [] packet_id: int def __init__(self, **params): for k, v in params.items(): setattr(self, k, v) @classmethod async def read_data_from(cls, stream: AsyncDataInputStream) -> 'Packet': fields: dict = {} for key, datatype in cls.FIELDS: fields[key] = await cls.read_field(stream, datatype) return cls(**fields) @staticmethod async def read_field(stream: AsyncDataInputStream, datatype: Any): match datatype: case 'uint': return await stream.read_uint() case 'int': return await stream.read_int() case 'str': return await stream.read_string() case 'str', length: return (await stream.read_string())[:length] case 'ulong': return await stream.read_ulong() case 'long': return await stream.read_long() case 'ushort': return await stream.read_ushort() case 'short': return await stream.read_short() case 'byte': return await stream.read() case 'float': return await stream.read_float() case 'double': return await stream.read_double() case 'bool': return await stream.read_boolean() case 'bytes', length: return await stream.read_bytes(length) case 'entitydata': return await EntityData.read_from(stream) case _: raise ValueError(f'unknown type {datatype}') def __init_subclass__(cls, packet_id: int, **kwargs) -> None: Packet.REGISTRY[packet_id] = cls cls.packet_id = packet_id super().__init_subclass__(**kwargs) @classmethod async def read_packet(cls, stream: AsyncDataInputStream) -> 'Packet': packet_id: int = await stream.read() if packet_id not in cls.REGISTRY: raise ValueError(f'invalid packet 0x{packet_id:02x}') pkt = await cls.REGISTRY[packet_id].read_data_from(stream) pkt.packet_id = packet_id return pkt def __repr__(self): pkt_name = self.REGISTRY[self.packet_id].__name__ fields = [] for name, _ in self.FIELDS: fields.append(f'{name}={getattr(self, name)!r}') return f'<{pkt_name} {str.join(", ", fields)}>'