diff --git a/wsvpn.c b/wsvpn.c new file mode 100644 index 0000000..2f11ecb --- /dev/null +++ b/wsvpn.c @@ -0,0 +1,286 @@ +// x-run: ~/scripts/runc.sh % -lmongoose -Wall -Wextra +#include +#include +#include +#include +#include +#include + +#define MAX_CLIENTS 256 +#define MAX_OPEN_CHANNELS 128 +#define MAX_PACKET_SIZE 32767 + +struct client { + struct mg_connection *connection; + uint16_t open_channels[MAX_OPEN_CHANNELS]; + int next_open_channel_index; + bool receive_all; +}; + +struct client clients[MAX_CLIENTS] = { 0 }; + +static void handle_client(struct mg_connection *connection, int event_type, void *ev_data, void *fn_data); +static void on_ws_connect(struct mg_connection *connection, struct mg_http_message *message, void *data); +static void on_ws_message(struct mg_connection *connection, struct mg_ws_message *message, void *data); +static void on_ws_disconnect(struct mg_connection *connection, void *data); + +static void modem_open(struct client *client, uint16_t request_id, uint16_t channel); +static void modem_isOpen(struct client *client, uint16_t request_id, uint16_t channel); +static void modem_close(struct client *client, uint16_t request_id, uint16_t channel); +static void modem_closeAll(struct client *client, uint16_t request_id); +static void modem_transmit(struct client *client, uint16_t request_id, uint16_t channel, uint16_t reply_channel, void *data, uint16_t size); + +struct metrics { + uint64_t sent_bytes; + uint64_t sent_messages; + uint64_t received_bytes; + uint64_t received_messages; + uint64_t errors; +} metrics; + +int main(void) { + const char *address = "ws://0.0.0.0:8667"; + struct mg_mgr manager; + mg_mgr_init(&manager); + mg_http_listen(&manager, address, handle_client, NULL); + printf("Listening on %s\n", address); + while (1) mg_mgr_poll(&manager, 1000); + mg_mgr_free(&manager); +} + +static void handle_client(struct mg_connection *connection, int event_type, void *event_data, void *fn_data) { + if (event_type == MG_EV_OPEN) { + if (connection->rem.port == 0) return; + memset(connection->data, 0, 32); + } else if (event_type == MG_EV_HTTP_MSG) { + printf("http: %p\n", connection); + struct mg_http_message *http_message = (struct mg_http_message *) event_data; + if (mg_http_match_uri(http_message, "/open")) { + mg_ws_upgrade(connection, http_message, NULL); + } else if (mg_http_match_uri(http_message, "/stat")) { + mg_printf(connection, "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n"); + static char addr_str[256]; + for (struct mg_connection *conn = connection->mgr->conns; conn != NULL; conn = conn->next) { + if (conn->is_websocket) { + struct client *client = *(struct client **)&conn->data[0]; + if (client->connection == conn) { + mg_snprintf(addr_str, 256, "%M", mg_print_ip_port, &conn->rem); + mg_http_printf_chunk(connection, "%s - %d ports {", addr_str, client->next_open_channel_index); + for (int i = 0; i < MAX_OPEN_CHANNELS; i++) { + if (i == client->next_open_channel_index) { + mg_http_printf_chunk(connection, "/*end of used*/"); + } + mg_http_printf_chunk(connection, "%d,", client->open_channels[i]); + } + mg_http_printf_chunk(connection, "}\n"); + } + } + } + mg_http_printf_chunk(connection, ""); + } else { + mg_http_reply(connection, 404, "", "uwu"); + } + } else if (event_type == MG_EV_WS_OPEN) { + struct mg_http_message *http_message = (struct mg_http_message *) event_data; + printf("ws open: %p\n", connection); + on_ws_connect(connection, http_message, fn_data); + } else if (event_type == MG_EV_WS_MSG) { + printf("websocket: %p\n", connection); + struct mg_ws_message *ws_message = (struct mg_ws_message *)event_data; + on_ws_message(connection, ws_message, fn_data); + } else if (event_type == MG_EV_CLOSE) { + printf("closed: %p\n", connection); + if (connection->is_websocket) { + on_ws_disconnect(connection, fn_data); + } + } +} + +void ws_send_error(struct mg_connection *connection, uint16_t request_id, const char *fmt, ...) { + static char buffer[1024]; + memset(buffer, 0, 1024); + + va_list args; + va_start(args, fmt); + int text_size = vsnprintf(&buffer[5], 1019, fmt, args); + va_end(args); + if (text_size < 0) return; + + buffer[0] = 'E'; + buffer[1] = (request_id >> 8) & 0xFF; + buffer[2] = request_id & 0xFF; + buffer[3] = (text_size >> 8) & 0xFF; + buffer[4] = text_size & 0xFF; + + metrics.sent_bytes += 5 + text_size; + mg_ws_send(connection, buffer, 5 + text_size, WEBSOCKET_OP_BINARY); +} + +void ws_send_info(struct mg_connection *connection, const char *fmt, ...) { + static char buffer[1024]; + memset(buffer, 0, 1024); + + va_list args; + va_start(args, fmt); + int text_size = vsnprintf(&buffer[3], 1021, fmt, args); + va_end(args); + if (text_size < 0) return; + + buffer[0] = 'I'; + buffer[1] = (text_size >> 8) & 0xFF; + buffer[2] = text_size & 0xFF; + + metrics.sent_bytes += 3 + text_size; + mg_ws_send(connection, buffer, 3 + text_size, WEBSOCKET_OP_BINARY); +} + +void ws_respond(struct mg_connection *connection, uint16_t request_id, void *data, uint32_t size) { + static char buffer[MAX_PACKET_SIZE]; + assert(size < MAX_PACKET_SIZE); + buffer[0] = 'R'; + buffer[1] = (request_id >> 8) & 0xFF; + buffer[2] = request_id & 0xFF; + if (size != 0) memcpy(&buffer[3], data, size); + metrics.sent_bytes += 3 + size; + mg_ws_send(connection, buffer, size + 3, WEBSOCKET_OP_BINARY); +} + +static void on_ws_connect(struct mg_connection *connection, struct mg_http_message *message, void *data) { + static char addr_str[256]; + mg_snprintf(addr_str, 256, "%M", mg_print_ip_port, &connection->rem); + ws_send_info(connection, "Hello, %s", addr_str); + struct client *client = malloc(sizeof(struct client)); + memcpy(&connection->data[0], &client, sizeof(struct client *)); + client->connection = connection; +} + +static void on_ws_message(struct mg_connection *connection, struct mg_ws_message *message, void *data) { + if ((message->flags & 15) != WEBSOCKET_OP_BINARY) { + ws_send_error(connection, -1, "This text could've been a binary."); + connection->is_draining = 1; + return; + } + + struct client *client = *(struct client **)&connection->data[0]; + assert(client->connection == connection); + + metrics.received_bytes += message->data.len; + + if (message->data.len == 0) return; + + uint16_t request_id = ntohs(*(uint16_t*)&message->data.ptr[1]); + + switch (message->data.ptr[0]) { + case 'I': // info. We can safely ignore that channel + break; + case 'O': // open + { + uint16_t channel = ntohs(*(uint16_t*)&message->data.ptr[3]); + printf("%p[%04x] modem.open(%d)\n", client, request_id, channel); + modem_open(client, request_id, channel); + } + return; + case 'o': // isOpen + { + uint16_t channel = ntohs(*(uint16_t*)&message->data.ptr[3]); + printf("%p[%04x] modem.isOpen(%d)\n", client, request_id, channel); + modem_isOpen(client, request_id, channel); + } + return; + case 'c': // close + { + uint16_t channel = ntohs(*(uint16_t*)&message->data.ptr[3]); + printf("%p[%04x] modem.close(%d)\n", client, request_id, channel); + modem_close(client, request_id, channel); + } + return; + case 'C': // closeAll + { + printf("%p[%04x] modem.closeAll()\n", client, request_id); + modem_closeAll(client, request_id); + } + return; + case 'T': // transmit + { + uint16_t channel = ntohs(*(uint16_t*)&message->data.ptr[3]); + uint16_t reply_channel = ntohs(*(uint16_t*)&message->data.ptr[5]); + uint16_t data_length = ntohs(*(uint16_t*)&message->data.ptr[7]); + modem_transmit(client, request_id, channel, reply_channel, data, data_length); + } + return; + default: + ws_send_error(connection, request_id, "Unknown opcode: 0x%02x", message->data.ptr[0]); + connection->is_draining = 1; + return; + } +} + +static void on_ws_disconnect(struct mg_connection *connection, void *data) { + struct client *client = *(struct client **)&connection->data[0]; + if (client->connection == connection) { + free(client); + } +} + + +static void modem_open(struct client *client, uint16_t request_id, uint16_t channel) { + for (int i = 0; i < client->next_open_channel_index; i++) { + if (client->open_channels[i] == channel) { + ws_respond(client->connection, request_id, NULL, 0); + return; + } + } + if (client->next_open_channel_index == MAX_OPEN_CHANNELS) { + ws_send_error(client->connection, request_id, "Too many open channels"); + return; + } + client->open_channels[client->next_open_channel_index] = channel; + client->next_open_channel_index++; + ws_respond(client->connection, request_id, NULL, 0); +} + +static void modem_isOpen(struct client *client, uint16_t request_id, uint16_t channel) { + unsigned char is_open = 0; + for (int i = 0; i < client->next_open_channel_index; i++) { + if (client->open_channels[i] == channel) { + is_open = 42; + break; + } + } + ws_respond(client->connection, request_id, &is_open, 1); +} + +static void modem_close(struct client *client, uint16_t request_id, uint16_t channel) { + for (int i = 0; i < client->next_open_channel_index; i++) { + if (client->open_channels[i] == channel) { + client->open_channels[i] = client->open_channels[client->next_open_channel_index - 1]; + client->next_open_channel_index--; + break; + } + } + ws_respond(client->connection, request_id, NULL, 0); +} + +static void modem_closeAll(struct client *client, uint16_t request_id) { + client->next_open_channel_index = 0; + memset(client->open_channels, 0, sizeof(uint16_t) * MAX_OPEN_CHANNELS); + ws_respond(client->connection, request_id, NULL, 0); +} + +static void modem_transmit(struct client *client, uint16_t request_id, uint16_t channel, uint16_t reply_channel, void *data, uint16_t size) { + static uint8_t buffer[MAX_PACKET_SIZE + 7]; + + if (size > MAX_PACKET_SIZE) { + ws_send_error(client->connection, request_id, "Packet too big: %d > %d", size, MAX_PACKET_SIZE); + return; + } + + buffer[0] = 'M'; + buffer[1] = (channel >> 8) & 0xFF; + buffer[2] = channel & 0xFF; + buffer[3] = (reply_channel >> 8) & 0xFF; + buffer[4] = reply_channel & 0xFF; + buffer[5] = (size >> 8) & 0xFF; + buffer[6] = size & 0xFF; + memcpy(&buffer[7], data, size); +} diff --git a/wsvpn.lua b/wsvpn.lua new file mode 100644 index 0000000..c79bb38 --- /dev/null +++ b/wsvpn.lua @@ -0,0 +1,143 @@ +local expect = require("cc.expect") + +local WSModem = { + open = function(self, channel) + expect.expect(1, channel, "number") + expect.range(channel, 0, 65535) + self._request(0x4f, { + bit.band(0xFF, bit.brshift(channel, 8)), + bit.band(0xFF, channel) + }) + end, + isOpen = function(self, channel) + expect.expect(1, channel, "number") + expect.range(channel, 0, 65535) + return self._request(0x6f, { + bit.band(0xFF, bit.brshift(channel, 8)), + bit.band(0xFF, channel) + })[1] ~= 0 + end, + close = function(self, channel) + expect.expect(1, channel, "number") + expect.range(channel, 0, 65535) + self._request(0x63, { + bit.band(0xFF, bit.brshift(channel, 8)), + bit.band(0xFF, channel) + }) + end, + closeAll = function(self) + self._request(0x43) + end, + transmit = function(self, channel, replyChannel, data) + expect.expect(1, channel, "number") + expect.expect(2, replyChannel, "number") + expect.expect(3, data, "nil", "string", "number", "table") + expect.range(channel, 0, 65535) + expect.range(replyChannel, 0, 65535) + + local serialized = textutils.serializeJSON(data) + expect.range(#serialized, 0, 65535) + serialized = { serialized:byte(1, 65536) } + return self._request(0x54, { + bit.band(0xFF, bit.brshift(channel, 8)), + bit.band(0xFF, channel), + bit.band(0xFF, bit.brshift(replyChannel, 8)), + bit.band(0xFF, replyChannel), + table.unpack(serialized, 1, #serialized) + }) + end, + isWireless = function(self) return true end, + run = function(self) + while true do + local data, binary = self._socket.receive() + if not data then return true end + if binary == false then return false, "Not a binary message" end + data = { string.byte(data, 1, #data) } + local opcode = table.remove(data, 1) + if opcode == 0x49 then -- info + local len, msg = self._read_u16ne(data) + msg = string.char(table.unpack(msg)) + os.queueEvent("wsvpn:info", msg) + elseif opcode == 0x45 then -- Error + local request_id, error_length + request_id, data = self._read_u16ne(data) + error_length, data = self._read_u16ne(data) + local message = string.char(table.unpack(data, 1, error_length)) + os.queueEvent("wsvpn:response", false, request_id, message) + elseif opcode == 0x52 then -- Response + local request_id, response = self._read_u16ne(data) + os.queueEvent("wsvpn:response", true, request_id, response) + else + return false, string.format("Invalid opcode 0x%02x", opcode) + end + os.sleep(0) + end + end, + + -- low-level part + + _read_u16ne = function(self, data) + local v = bit.blshift(table.remove(data, 1), 8) + v = bit.bor(v, table.remove(data, 1)) + return v, data + end, + + _wait_response = function(self, request_id) + while true do + local ev, status, id, data = os.pullEvent("wsvpn:response") + if ev == "wsvpn:response" and id == request_id then + return status, data + end + end + end, + + _request = function(self, opcode, data) + local request_id = self._get_id() + self._socket.send( + string.char( + opcode, + bit.band(0xFF, bit.brshift(request_id, 8)), + bit.band(0xFF, request_id), + table.unpack(data or {}) + ), + true + ) + local status, response = self._wait_response(request_id) + if not status then + error(response) + end + return response + end, + + _get_id = function(self) + self._req_id = bit.band(0xFFFF, self._req_id + 1) + return self._req_id + end, + + _send_text = function(self, code, fmt, ...) + local msg = { fmt:format(...):byte(1, 1020) } + self._socket.send( + string.char( + code, + bit.band(0xFF, bit.brshift(#msg, 8)), + bit.band(0xFF, #msg), + table.unpack(msg, 1, #msg) + ), + true + ) + end, + + _init = function(self) + self._send_text(0x49, "Hello! I'm computer %d", os.getComputerID()) + end, +} + +return function(addr) + local ws = assert(http.websocket(addr)) + local sock = setmetatable({ _socket = ws, _req_id = 0 }, { __index = WSModem }) + for name, method in pairs(WSModem) do + sock[name] = function(...) return method(sock, ...) end + end + sock._init() + return sock +end