bta-proxy/bta_proxy/packets/base.py

203 lines
8.1 KiB
Python
Raw Normal View History

2023-08-26 01:12:19 +03:00
from typing import Any, ClassVar, Optional, Type, Union
2023-08-24 17:40:23 +03:00
2023-08-26 16:45:22 +03:00
import gzip
2023-08-24 17:40:23 +03:00
from bta_proxy.entitydata import EntityData
2023-08-25 23:11:36 +03:00
from bta_proxy.itemstack import ItemStack
2023-08-25 21:49:10 +03:00
from ..datainputstream import AsyncDataInputStream
2023-08-28 22:07:33 +03:00
from logging import getLogger
2023-08-24 15:29:57 +03:00
2023-08-28 22:07:33 +03:00
# Module-level logger for packet decoding.
logger = getLogger(__name__)


def try_int(v: str) -> Union[str, int]:
    """Return ``int(v)`` when *v* parses as an integer, otherwise *v* unchanged.

    Used for tokens in ``FIELDS`` specs that may be either an integer literal
    or the name of another field (size keys and ``==`` condition values).
    Only ``ValueError`` is treated as "not an integer"; other errors from
    ``int()`` propagate.
    """
    try:
        result: Union[str, int] = int(v)
    except ValueError:
        result = v
    return result
2023-08-24 15:29:57 +03:00
class Packet:
2023-08-26 01:12:19 +03:00
REGISTRY: ClassVar[dict[int, Type["Packet"]]] = {}
2023-08-24 15:29:57 +03:00
FIELDS: ClassVar[list[tuple[str, Any]]] = []
2023-08-24 17:40:23 +03:00
packet_id: int
2023-08-24 15:29:57 +03:00
def __init__(self, **params):
for k, v in params.items():
setattr(self, k, v)
@classmethod
2023-08-26 01:12:19 +03:00
async def read_data_from(cls, stream: AsyncDataInputStream) -> "Packet":
2023-08-28 22:07:33 +03:00
logger.debug("Packet.read_data_from(%r)", stream)
2023-08-24 15:29:57 +03:00
fields: dict = {}
for key, datatype in cls.FIELDS:
2023-08-26 01:12:19 +03:00
if "?" in key:
key, cond = key.split("?", 1)
if "==" in cond:
k, v = cond.split("==")
if fields[k] != try_int(v):
continue
elif not fields[cond]:
continue
try:
2023-08-28 22:07:33 +03:00
logger.debug(f"reading {key=} of type {datatype!r} ({fields=})")
2023-08-26 01:12:19 +03:00
fields[key] = await cls.read_field(stream, datatype, fields)
except Exception as e:
raise ValueError(f"Failed getting key {key} on {cls}") from e
2023-08-24 15:29:57 +03:00
return cls(**fields)
@staticmethod
2023-08-26 01:12:19 +03:00
async def read_field(
stream: AsyncDataInputStream,
datatype: Any,
fields: dict[str, Any] = {},
):
2023-08-28 22:07:33 +03:00
logger.debug(f"Packet.read_field(_, {datatype=}, {fields=})")
2023-08-24 15:29:57 +03:00
match datatype:
2023-08-26 01:12:19 +03:00
case "list", sizekey, *args:
2023-08-26 12:33:31 +03:00
args = args[0] if len(args) == 1 else tuple(args)
2023-08-28 22:07:33 +03:00
size = try_int(sizekey)
length = size if isinstance(size, int) else fields[sizekey]
2023-08-26 01:12:19 +03:00
return [
await Packet.read_field(stream, args, fields)
2023-08-26 12:33:31 +03:00
for _ in range(length)
2023-08-26 01:12:19 +03:00
]
2023-09-04 22:03:31 +03:00
case "tuple", *tuples:
out = []
for tup in tuples:
out.append(await Packet.read_field(stream, tup, fields))
return tuple(out)
2023-08-26 01:12:19 +03:00
case "uint":
2023-08-25 21:49:10 +03:00
return await stream.read_uint()
2023-08-26 01:12:19 +03:00
case "int":
2023-08-25 21:49:10 +03:00
return await stream.read_int()
2023-08-26 01:12:19 +03:00
case "str":
2023-08-25 21:49:10 +03:00
return await stream.read_string()
2023-08-26 01:12:19 +03:00
case "str", length:
2023-08-25 21:49:10 +03:00
return (await stream.read_string())[:length]
2023-08-26 01:12:19 +03:00
case "string":
2023-08-25 23:11:36 +03:00
return await stream.read_string()
2023-08-26 01:12:19 +03:00
case "string", length:
2023-08-25 23:11:36 +03:00
return (await stream.read_string())[:length]
2023-08-26 01:12:19 +03:00
case "ulong":
2023-08-25 21:49:10 +03:00
return await stream.read_ulong()
2023-08-26 01:12:19 +03:00
case "long":
2023-08-25 21:49:10 +03:00
return await stream.read_long()
2023-08-26 01:12:19 +03:00
case "ushort":
2023-08-25 21:49:10 +03:00
return await stream.read_ushort()
2023-08-26 01:12:19 +03:00
case "short":
2023-08-25 21:49:10 +03:00
return await stream.read_short()
2023-08-26 01:12:19 +03:00
case "byte":
return await stream.read_byte()
case "ubyte":
return await stream.read_ubyte()
case "float":
2023-08-25 21:49:10 +03:00
return await stream.read_float()
2023-08-26 01:12:19 +03:00
case "double":
2023-08-25 21:49:10 +03:00
return await stream.read_double()
2023-08-26 01:12:19 +03:00
case "bool":
2023-08-25 21:49:10 +03:00
return await stream.read_boolean()
2023-08-26 01:12:19 +03:00
case "bytes", length_or_key:
2023-08-25 23:11:36 +03:00
if isinstance(length_or_key, int):
return await stream.read_bytes(length_or_key)
elif isinstance(length_or_key, str):
2023-08-26 01:12:19 +03:00
if length_or_key == ".rest":
return stream.read_rest()
2023-08-25 23:11:36 +03:00
if length_or_key not in fields:
2023-08-26 01:12:19 +03:00
raise KeyError(
f"failed to find {length_or_key} in {fields} to read bytes length"
)
2023-08-25 23:11:36 +03:00
return await stream.read_bytes(fields[length_or_key])
2023-08-26 01:12:19 +03:00
raise ValueError(
f"invalid type for bytes length_or_key: {length_or_key!r}"
)
case "itemstack":
2023-08-25 23:11:36 +03:00
return await ItemStack.read_from(stream)
2023-08-26 01:12:19 +03:00
case "itemstack", length_or_key:
2023-08-26 12:33:31 +03:00
items: list[Optional[ItemStack]] = []
2023-08-26 01:12:19 +03:00
if isinstance(length_or_key, int):
for _ in range(length_or_key):
if (item_id := await stream.read_short()) >= 0:
count = await stream.read()
data = await stream.read_short()
items.append(ItemStack(item_id, count, data))
else:
items.append(None)
return items
elif isinstance(length_or_key, str):
if fields[length_or_key] <= 0:
return []
if length_or_key not in fields:
raise KeyError(
f"failed to find {length_or_key} in {fields} to read number of itemstacks"
)
for _ in range(fields[length_or_key]):
if (item_id := await stream.read_short()) >= 0:
count = await stream.read()
data = await stream.read_short()
items.append(ItemStack(item_id, count, data))
else:
items.append(None)
return items
raise ValueError(
f"invalid type for itemstack length_or_key: {length_or_key!r}"
)
case "itemstack_optional":
2023-08-25 23:11:36 +03:00
if (item_id := await stream.read_short()) >= 0:
count = await stream.read()
data = await stream.read_short()
return ItemStack(item_id, count, data)
return None
2023-08-26 01:12:19 +03:00
case "extendeditemstack_optional":
if (item_id := await stream.read_short()) >= 0:
count = await stream.read()
data = await stream.read_short()
tag_size = await stream.read_short()
tag = await stream.read_bytes(tag_size)
return ItemStack(item_id, count, data, tag)
return None
case "entitydata":
2023-08-25 21:49:10 +03:00
return await EntityData.read_from(stream)
2023-08-26 16:45:22 +03:00
case "nbt":
size = await stream.read_short()
if size < 0:
raise ValueError("Received tag length is less than zero! Weird tag!")
if size == 0:
return None
return gzip.decompress(await stream.read_bytes(size))
2023-08-24 15:29:57 +03:00
case _:
2023-08-26 01:12:19 +03:00
raise ValueError(f"unknown type {datatype}")
2023-08-24 15:29:57 +03:00
2023-08-24 17:40:23 +03:00
def __init_subclass__(cls, packet_id: int, **kwargs) -> None:
2023-08-28 22:07:33 +03:00
logger.debug(f"registered packet {cls} with id = {packet_id}")
2023-08-24 17:40:23 +03:00
Packet.REGISTRY[packet_id] = cls
cls.packet_id = packet_id
super().__init_subclass__(**kwargs)
2023-08-24 15:29:57 +03:00
2023-08-26 01:12:19 +03:00
def post_creation(self):
pass
2023-08-24 15:29:57 +03:00
@classmethod
2023-08-26 01:12:19 +03:00
async def read_packet(cls, stream: AsyncDataInputStream) -> "Packet":
2023-08-25 21:49:10 +03:00
packet_id: int = await stream.read()
2023-08-28 22:07:33 +03:00
logger.debug(f"incoming {packet_id=}")
2023-08-24 15:29:57 +03:00
if packet_id not in cls.REGISTRY:
2023-08-26 01:12:19 +03:00
raise ValueError(
2023-08-28 22:07:33 +03:00
f"invalid packet 0x{packet_id:02x} ({packet_id}) (rest: {stream.peek_rest()[:16]}...)"
2023-08-26 01:12:19 +03:00
)
2023-08-25 21:49:10 +03:00
pkt = await cls.REGISTRY[packet_id].read_data_from(stream)
pkt.packet_id = packet_id
2023-08-26 01:12:19 +03:00
pkt.post_creation()
2023-08-28 22:07:33 +03:00
logger.debug(f"received {pkt}")
2023-08-25 21:49:10 +03:00
return pkt
2023-08-24 15:29:57 +03:00
def __repr__(self):
2023-08-24 17:40:23 +03:00
pkt_name = self.REGISTRY[self.packet_id].__name__
2023-08-24 15:29:57 +03:00
fields = []
2023-08-26 01:12:19 +03:00
for key, _ in self.FIELDS:
if "?" in key:
key, cond = key.split("?", 1)
2023-09-04 22:03:31 +03:00
fields.append(f"{key}={getattr(self, key, None)!r} depending on {cond}")
2023-08-26 01:12:19 +03:00
else:
fields.append(f"{key}={getattr(self, key)!r}")
2023-08-25 21:49:10 +03:00
return f'<{pkt_name} {str.join(", ", fields)}>'
2023-09-04 22:03:31 +03:00