NOW IT'S ALL ASYNC

This commit is contained in:
Casey 2023-08-25 21:49:10 +03:00
parent 6116030f66
commit de25182868
Signed by: hkc
GPG Key ID: F0F6CFE11CDB0960
10 changed files with 146 additions and 176 deletions

0
bta_proxy/__init__.py Normal file
View File

View File

@ -1,47 +1,21 @@
# x-run: cd .. && python -m bta_proxy '201:4f8c:4ea:0:71ec:6d7:6f1b:a4f9'
import asyncio
import socket
from asyncio.streams import StreamReader, StreamWriter
from argparse import ArgumentParser
from sys import argv
from .debug import debug_client, debug_server
from bta_proxy.proxy import BTAProxy
MAX_SIZE = 0x400000
async def handle_server(writer: StreamWriter, server: socket.socket, fp):
try:
while (packet := await loop.sock_recv(server, MAX_SIZE)):
try:
debug_server(packet, fp)
except Exception as e:
print(f'[S] error: {e}')
writer.write(packet)
await writer.drain()
except Exception as e:
print(f'handle_server(): {e}')
async def handle_client(reader: StreamReader, writer: StreamWriter):
print(reader, writer)
try:
server = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
server.connect(('201:4f8c:4ea:0:71ec:6d7:6f1b:a4f9', 25565))
server.setblocking(False)
with open("packets.txt", "w") as fp:
loop.create_task(handle_server(writer, server, fp))
while (packet := await reader.read(MAX_SIZE)):
try:
debug_client(packet, fp)
except Exception as e:
print(f'[C] error: {e}')
await loop.sock_sendall(server, packet)
except Exception as e:
print(f'handle_client(): {e}')
loop = asyncio.get_event_loop()
def main():
loop.create_task(asyncio.start_server(handle_client, 'localhost', 25565))
loop.run_forever()
async def main(args):
loop = asyncio.get_running_loop()
server = await asyncio.start_server(BTAProxy(args[0], 25565, loop).handle_client, "localhost", 25565)
async with server:
await server.serve_forever()
if __name__ == '__main__':
main()
asyncio.run(main(argv[1:]))

View File

@ -1,65 +1,55 @@
from asyncio.queues import Queue
import struct
class DataInputStream:
def __init__(self, buffer: bytes):
self._buffer = buffer
self._cursor = 0
class AsyncDataInputStream:
def __init__(self, queue: Queue):
self._queue = queue
self._last = b''
def read_bytes(self, n: int) -> bytes:
if self._cursor + n > len(self._buffer):
raise EOFError('stream overread')
blob = self._buffer[self._cursor : self._cursor + n]
self._cursor += n
return blob
async def read_bytes(self, n: int) -> bytes:
if len(self._last) < n:
self._last += await self._queue.get()
out, self._last = self._last[:n], self._last[n:]
return out
def empty(self):
return self._cursor >= len(self._buffer) - 1
async def read(self) -> int:
return (await self.read_bytes(1))[0]
def end(self) -> bytes:
return self.read_bytes(len(self._buffer) - self._cursor)
async def read_boolean(self) -> bool:
return (await self.read()) != 0
def read_byte(self) -> int:
if self._cursor >= len(self._buffer):
print(f'\033[91mstream overread in {self._buffer}\033[0m')
raise EOFError('stream overread')
self._cursor += 1
return self._buffer[self._cursor - 1]
async def read_short(self) -> int:
return struct.unpack('>h', await self.read_bytes(2))[0]
def read_boolean(self) -> bool:
return self.read_byte() != 0
async def read_ushort(self) -> int:
return struct.unpack('>H', await self.read_bytes(2))[0]
def read_short(self) -> int:
return struct.unpack('>h', self.read_bytes(2))[0]
async def read_int(self) -> int:
return struct.unpack('>i', await self.read_bytes(4))[0]
def read_ushort(self) -> int:
return struct.unpack('>H', self.read_bytes(2))[0]
async def read_uint(self) -> int:
return struct.unpack('>I', await self.read_bytes(4))[0]
def read_int(self) -> int:
return struct.unpack('>i', self.read_bytes(4))[0]
async def read_long(self) -> int:
return struct.unpack('>q', await self.read_bytes(8))[0]
def read_uint(self) -> int:
return struct.unpack('>I', self.read_bytes(4))[0]
async def read_ulong(self) -> int:
return struct.unpack('>Q', await self.read_bytes(8))[0]
def read_long(self) -> int:
return struct.unpack('>q', self.read_bytes(8))[0]
async def read_float(self) -> float:
return struct.unpack('>f', await self.read_bytes(4))[0]
def read_ulong(self) -> int:
return struct.unpack('>Q', self.read_bytes(8))[0]
async def read_double(self) -> float:
return struct.unpack('>d', await self.read_bytes(8))[0]
def read_float(self) -> float:
return struct.unpack('>f', self.read_bytes(4))[0]
async def read_char(self) -> str:
return chr(await self.read_ushort())
def read_double(self) -> float:
return struct.unpack('>d', self.read_bytes(8))[0]
def read_char(self) -> str:
return chr(self.read_ushort())
def read_varint(self, bits: int = 32) -> int:
async def read_varint(self, bits: int = 32) -> int:
value: int = 0
position: int = 0
while True:
byte = self.read_byte()
byte = await self.read()
value |= (byte & 0x7F) << position
if ((byte & 0x80) == 0):
break
@ -68,7 +58,7 @@ class DataInputStream:
raise ValueError('varint too big')
return value
def read_string(self) -> str:
size = self.read_short()
return self.read_bytes(size).decode('utf-8')
async def read_string(self) -> str:
size = await self.read_short()
return (await self.read_bytes(size)).decode('utf-8')

View File

@ -1,52 +0,0 @@
from collections.abc import Iterable
from bta_proxy.packets.base import Packet
from bta_proxy.packets import *
from .datainputstream import DataInputStream
from typing import Generator, TextIO, TypeVar
T = TypeVar('T')
def chunks(gen: Iterable[T], size: int) -> Generator[list[T], None, None]:
bucket: list[T] = []
for item in gen:
bucket.append(item)
if len(bucket) >= size:
yield bucket
bucket.clear()
if bucket:
yield bucket
def debug_client(buffer: bytes, tmpfile: TextIO):
stream = DataInputStream(buffer)
while not stream.empty():
try:
packet = Packet.parse_packet(stream)
match packet.packet_id:
case _:
print('[C]', packet)
except ValueError:
# print(f'[C:rest] {stream.end()}')
buf = stream.end()
print(f"[C] {buf[0]=} {len(buf)=}, {buf=}", file=tmpfile)
def debug_server(buffer: bytes, tmpfile: TextIO):
stream = DataInputStream(buffer)
while not stream.empty():
try:
packet = Packet.parse_packet(stream)
match packet.packet_id:
case Packet50PreChunk.packet_id:
continue
case Packet38EntityStatus.packet_id:
continue
case _:
print('[S]', packet)
except ValueError:
# print(f'[S:rest] {stream.end()}')
buf = stream.end()
print(f"[S] {buf[0]=} {len(buf)=}, {buf=}", file=tmpfile)

18
bta_proxy/dpi.py Normal file
View File

@ -0,0 +1,18 @@
from asyncio.queues import Queue
from bta_proxy.datainputstream import AsyncDataInputStream
from bta_proxy.packets.base import Packet
async def inspect_client(queue: Queue, addr: tuple[str, int]):
dis = AsyncDataInputStream(queue)
while True:
pkt = await Packet.read_packet(dis)
print("C", pkt)
async def inspect_server(queue: Queue, addr: tuple[str, int]):
dis = AsyncDataInputStream(queue)
while True:
pkt = await Packet.read_packet(dis)
print("S", pkt)

View File

@ -1,6 +1,6 @@
from typing import Any
from bta_proxy.datainputstream import DataInputStream
from bta_proxy.datainputstream import AsyncDataInputStream
from enum import Enum
from dataclasses import dataclass
@ -23,26 +23,26 @@ class DataItem:
class EntityData:
@classmethod
def read_from(cls, dis: DataInputStream) -> list[DataItem]:
async def read_from(cls, dis: AsyncDataInputStream) -> list[DataItem]:
items = []
while (data := dis.read_byte()) != 0x7F:
while (data := await dis.read()) != 0x7F:
item_type = DataItemType((data & 0xE0) >> 5)
item_id: int = data & 0x1F
match item_type:
case DataItemType.BYTE:
items.append(DataItem(item_type, item_id, dis.read_byte()))
items.append(DataItem(item_type, item_id, await dis.read()))
case DataItemType.SHORT:
items.append(DataItem(item_type, item_id, dis.read_short()))
items.append(DataItem(item_type, item_id, await dis.read_short()))
case DataItemType.FLOAT:
items.append(DataItem(item_type, item_id, dis.read_float()))
items.append(DataItem(item_type, item_id, await dis.read_float()))
case DataItemType.STRING:
items.append(DataItem(item_type, item_id, dis.read_string()))
items.append(DataItem(item_type, item_id, await dis.read_string()))
case DataItemType.ITEMSTACK:
items.append(DataItem(item_type, item_id, ItemStack.read_from(dis)))
items.append(DataItem(item_type, item_id, await ItemStack.read_from(dis)))
case DataItemType.CHUNK_COORDINATES:
x = dis.read_float()
y = dis.read_float()
z = dis.read_float()
x = await dis.read_float()
y = await dis.read_float()
z = await dis.read_float()
items.append(DataItem(item_type, item_id, (x, y, z)))
return items

View File

@ -1,5 +1,5 @@
from bta_proxy.datainputstream import DataInputStream
from bta_proxy.datainputstream import AsyncDataInputStream
class ItemStack:
@ -10,8 +10,8 @@ class ItemStack:
self.data = data
@classmethod
def read_from(cls, stream: DataInputStream) -> 'ItemStack':
item_id = stream.read_short()
count = stream.read_byte()
data = stream.read_ushort()
async def read_from(cls, stream: AsyncDataInputStream) -> 'ItemStack':
item_id = await stream.read_short()
count = await stream.read()
data = await stream.read_ushort()
return cls(item_id, count, data)

View File

@ -1,7 +1,7 @@
from typing import Any, ClassVar, Type
from bta_proxy.entitydata import EntityData
from ..datainputstream import DataInputStream
from ..datainputstream import AsyncDataInputStream
class Packet:
REGISTRY: ClassVar[dict[int, Type['Packet']]] = {}
@ -13,43 +13,43 @@ class Packet:
setattr(self, k, v)
@classmethod
def read_from(cls, stream: DataInputStream) -> 'Packet':
async def read_data_from(cls, stream: AsyncDataInputStream) -> 'Packet':
fields: dict = {}
for key, datatype in cls.FIELDS:
fields[key] = cls.read_field(stream, datatype)
fields[key] = await cls.read_field(stream, datatype)
return cls(**fields)
@staticmethod
def read_field(stream: DataInputStream, datatype: Any):
async def read_field(stream: AsyncDataInputStream, datatype: Any):
match datatype:
case 'uint':
return stream.read_uint()
return await stream.read_uint()
case 'int':
return stream.read_int()
return await stream.read_int()
case 'str':
return stream.read_string()
return await stream.read_string()
case 'str', length:
return stream.read_string()[:length]
return (await stream.read_string())[:length]
case 'ulong':
return stream.read_ulong()
return await stream.read_ulong()
case 'long':
return stream.read_long()
return await stream.read_long()
case 'ushort':
return stream.read_ushort()
return await stream.read_ushort()
case 'short':
return stream.read_short()
return await stream.read_short()
case 'byte':
return stream.read_byte()
return await stream.read()
case 'float':
return stream.read_float()
return await stream.read_float()
case 'double':
return stream.read_double()
return await stream.read_double()
case 'bool':
return stream.read_boolean()
return await stream.read_boolean()
case 'bytes', length:
return stream.read_bytes(length)
return await stream.read_bytes(length)
case 'entitydata':
return EntityData.read_from(stream)
return await EntityData.read_from(stream)
case _:
raise ValueError(f'unknown type {datatype}')
@ -59,16 +59,17 @@ class Packet:
super().__init_subclass__(**kwargs)
@classmethod
def parse_packet(cls, stream: DataInputStream) -> 'Packet':
packet_id: int = stream.read_byte()
async def read_packet(cls, stream: AsyncDataInputStream) -> 'Packet':
packet_id: int = await stream.read()
if packet_id not in cls.REGISTRY:
stream._cursor -= 1
raise ValueError(f'invalid packet 0x{packet_id:02x}')
return cls.REGISTRY[packet_id].read_from(stream)
pkt = await cls.REGISTRY[packet_id].read_data_from(stream)
pkt.packet_id = packet_id
return pkt
def __repr__(self):
pkt_name = self.REGISTRY[self.packet_id].__name__
fields = []
for name, _ in self.FIELDS:
fields.append(f'{name}={getattr(self, name)!r}')
return f'<{pkt_name} {str.join(", ", fields)}'
return f'<{pkt_name} {str.join(", ", fields)}>'

39
bta_proxy/proxy.py Normal file
View File

@ -0,0 +1,39 @@
from asyncio.protocols import Protocol
from asyncio.queues import Queue
from asyncio import AbstractEventLoop, get_event_loop
from asyncio.streams import StreamReader, StreamWriter, open_connection
from typing import Optional
from bta_proxy.dpi import inspect_client, inspect_server
class BTAProxy:
def __init__(self, host: str, port: int, loop: Optional[AbstractEventLoop] = None):
self.host = host
self.port = port
self.loop = loop or get_event_loop()
@staticmethod
async def pipe(reader: StreamReader, writer: StreamWriter, queue: Queue):
try:
while not reader.at_eof():
packet = await reader.read(0x400000)
queue.put_nowait(packet)
writer.write(packet)
finally:
writer.close()
async def handle_client(self, cli_reader: StreamReader, cli_writer: StreamWriter):
try:
peername = cli_writer.get_extra_info("peername")
srv_reader, srv_writer = await open_connection(self.host, self.port)
queue_srv: Queue = Queue()
queue_cli: Queue = Queue()
self.loop.create_task(inspect_client(queue_cli, peername))
self.loop.create_task(inspect_server(queue_srv, peername))
self.loop.create_task(self.pipe(cli_reader, srv_writer, queue_cli))
self.loop.create_task(self.pipe(srv_reader, cli_writer, queue_srv))
except Exception as e:
print(f"oopsie whoopsie {e}")