from typing import Any, ClassVar, Type from bta_proxy.entitydata import EntityData from bta_proxy.itemstack import ItemStack from ..datainputstream import AsyncDataInputStream class Packet: REGISTRY: ClassVar[dict[int, Type['Packet']]] = {} FIELDS: ClassVar[list[tuple[str, Any]]] = [] packet_id: int def __init__(self, **params): for k, v in params.items(): setattr(self, k, v) @classmethod async def read_data_from(cls, stream: AsyncDataInputStream) -> 'Packet': fields: dict = {} for key, datatype in cls.FIELDS: fields[key] = await cls.read_field(stream, datatype, fields) return cls(**fields) @staticmethod async def read_field(stream: AsyncDataInputStream, datatype: Any, fields: dict[str, Any] = {}): match datatype: case 'uint': return await stream.read_uint() case 'int': return await stream.read_int() case 'str': return await stream.read_string() case 'str', length: return (await stream.read_string())[:length] case 'string': return await stream.read_string() case 'string', length: return (await stream.read_string())[:length] case 'ulong': return await stream.read_ulong() case 'long': return await stream.read_long() case 'ushort': return await stream.read_ushort() case 'short': return await stream.read_short() case 'byte': return await stream.read() case 'float': return await stream.read_float() case 'double': return await stream.read_double() case 'bool': return await stream.read_boolean() case 'bytes', length_or_key: if isinstance(length_or_key, int): return await stream.read_bytes(length_or_key) elif isinstance(length_or_key, str): if length_or_key not in fields: raise KeyError(f'failed to find {length_or_key} in {fields} to read bytes length') return await stream.read_bytes(fields[length_or_key]) raise ValueError(f'invalid type for bytes length_or_key: {length_or_key!r}') case 'itemstack': return await ItemStack.read_from(stream) case 'itemstack_optional': if (item_id := await stream.read_short()) >= 0: count = await stream.read() data = await stream.read_short() return ItemStack(item_id, count, data) return None case 'entitydata': return await EntityData.read_from(stream) case _: raise ValueError(f'unknown type {datatype}') def __init_subclass__(cls, packet_id: int, **kwargs) -> None: Packet.REGISTRY[packet_id] = cls cls.packet_id = packet_id super().__init_subclass__(**kwargs) @classmethod async def read_packet(cls, stream: AsyncDataInputStream) -> 'Packet': packet_id: int = await stream.read() if packet_id not in cls.REGISTRY: raise ValueError(f'invalid packet 0x{packet_id:02x} ({packet_id})') pkt = await cls.REGISTRY[packet_id].read_data_from(stream) pkt.packet_id = packet_id return pkt def __repr__(self): pkt_name = self.REGISTRY[self.packet_id].__name__ fields = [] for name, _ in self.FIELDS: fields.append(f'{name}={getattr(self, name)!r}') return f'<{pkt_name} {str.join(", ", fields)}>'