cc-stuff/wsvpn.c

382 lines
13 KiB
C

// x-run: ~/scripts/runc.sh % -lmongoose -Wall -Wextra
#include <mongoose.h>
#include <assert.h>
#include <netinet/in.h>
#include <signal.h>
#include <stdarg.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
// Element count of the file-scope `clients` array.
#define MAX_CLIENTS 256
// Capacity of each client's open_channels table (modem_open rejects more).
#define MAX_OPEN_CHANNELS 128
// Largest transmit ('T') payload accepted; also sizes the 'R' reply buffer.
#define MAX_PACKET_SIZE 32767
/*
```c
// ALL integers are network-endian
struct msg_info { // side: both. safe to ignore
uint8_t code; // 0x49, 'I'
uint16_t len;
char msg[1024]; // $.len bytes sent
};
struct msg_res_error { // side: server
uint8_t code; // 0x45, 'E'
uint16_t req_id; // request ID that caused that error
uint16_t len;
char msg[1024]; // $.len bytes sent
};
struct msg_address { // side: server
uint8_t code; // 0x41, 'A'
uint8_t size;
char name[256]; // $.size long
};
struct msg_res_success { // side: server
uint8_t code; // 0x52, 'R'
uint16_t req_id; // request ID we're replying to
void *data; // packet-specific
};
struct msg_transmission { // side: server
uint8_t code; // 0x54, 'T'
uint16_t channel;
uint16_t replyChannel;
uint16_t size;
void *data;
};
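struct msg_req_open { // side: client
uint8_t code; // 0x4f, 'O'
uint16_t req_id; // incremental request ID
uint16_t channel; // channel to open
};
// 'o' (isOpen) and 'c' (close) requests share this layout; 'C' (closeAll)
// carries only code and req_id. A client 'T' (transmit) request is code,
// req_id, channel, replyChannel, size, then $.size bytes of data.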
```
*/
// Per-websocket state; a pointer to it is stored in connection->data by on_ws_connect().
struct client {
// Owning connection; handlers compare it against the event's connection as a sanity check.
struct mg_connection *connection;
// Channels this client listens on; only the first next_open_channel_index entries are valid.
uint16_t open_channels[MAX_OPEN_CHANNELS];
// Number of valid entries in open_channels (also the next free slot).
int next_open_channel_index;
// NOTE(review): not read or written anywhere in the visible code - confirm intent.
bool receive_all;
};
// NOTE(review): clients are heap-allocated in on_ws_connect(); this array is not
// referenced in the visible code. Confirm nothing else uses it before removing.
struct client clients[MAX_CLIENTS] = { 0 };
/* Mongoose event entry point and the websocket lifecycle hooks it dispatches to. */
static void handle_client(struct mg_connection *connection, int event_type, void *event_data);
static void on_ws_connect(struct mg_connection *connection, struct mg_http_message *message);
static void on_ws_message(struct mg_connection *connection, struct mg_ws_message *message);
static void on_ws_disconnect(struct mg_connection *connection);

/* Channel-table lookup shared by the request handlers below. */
bool client_is_open(struct client *client, uint16_t channel);

/* One handler per request opcode; each answers `request_id` with 'R' or 'E'. */
static void modem_open(struct client *client, uint16_t request_id, uint16_t channel);
static void modem_isOpen(struct client *client, uint16_t request_id, uint16_t channel);
static void modem_close(struct client *client, uint16_t request_id, uint16_t channel);
static void modem_closeAll(struct client *client, uint16_t request_id);
static void modem_transmit(struct client *client, uint16_t request_id,
                           uint16_t channel, uint16_t reply_channel,
                           void *data, uint16_t size);
/* Process-wide counters exported by the /metrics HTTP endpoint. */
struct metrics {
    uint64_t sent_bytes;        // bytes passed to mg_ws_send()
    uint64_t sent_messages;     // websocket messages sent
    uint64_t received_bytes;    // payload bytes of received websocket messages
    uint64_t received_messages; // websocket messages received
    uint64_t errors;            // 'E' error messages sent
    uint64_t method_calls[5];   // per-request counters, indexed like method_names
} metrics = { 0 };
/*
 * Label values for metrics.method_calls, in the same order.
 * The inner dimension must leave room for the terminating NUL. "closeAll" and
 * "transmit" are 8 characters long, so an [8] row would hold them unterminated,
 * and printing them with %s would read past the end of the row.
 */
const char method_names[5][9] = {
    "open", "isOpen", "close", "closeAll", "transmit"
};
/* Set from a signal handler to ask the poll loop in main() to stop. */
static volatile sig_atomic_t s_stop_signal = 0;

/* SIGINT/SIGTERM handler: record the signal so main() can shut down cleanly. */
static void on_stop_signal(int signo) {
    s_stop_signal = signo;
}

/*
 * Start the relay: serve HTTP on port 8667, where /open upgrades to a
 * websocket and /metrics serves counters. Poll until SIGINT or SIGTERM,
 * then release every connection via mg_mgr_free().
 * Returns EXIT_FAILURE if the listening socket cannot be created.
 */
int main(void) {
    const char *address = "ws://0.0.0.0:8667";
    struct mg_mgr manager;
    mg_mgr_init(&manager);
    if (mg_http_listen(&manager, address, handle_client, NULL) == NULL) {
        fprintf(stderr, "Failed to listen on %s\n", address);
        mg_mgr_free(&manager);
        return EXIT_FAILURE;
    }
    printf("Listening on %s\n", address);
    signal(SIGINT, on_stop_signal);
    signal(SIGTERM, on_stop_signal);
    // A signal interrupts the poll, so shutdown happens promptly.
    while (s_stop_signal == 0) {
        mg_mgr_poll(&manager, 1000);
    }
    // Closes every connection (firing MG_EV_CLOSE, which frees client state).
    mg_mgr_free(&manager);
    return EXIT_SUCCESS;
}
/*
 * Write the metrics page to `connection` as a chunked HTTP/1.1 200 response
 * in Prometheus text exposition format.
 */
static void send_metrics(struct mg_connection *connection) {
    mg_printf(connection, "HTTP/1.1 200 OK\r\n"
                          "Content-Type: text/plain; version=0.0.4\r\n"
                          "Transfer-Encoding: chunked\r\n\r\n");
    // Counters are uint64_t; cast so %llu matches on every platform.
    mg_http_printf_chunk(connection, "# HELP ws_bytes_sent_total Number of bytes sent to clients\n");
    mg_http_printf_chunk(connection, "# TYPE ws_bytes_sent_total counter\n");
    mg_http_printf_chunk(connection, "ws_bytes_sent_total %llu\n", (unsigned long long) metrics.sent_bytes);
    mg_http_printf_chunk(connection, "# HELP ws_bytes_received_total Number of bytes received from clients\n");
    mg_http_printf_chunk(connection, "# TYPE ws_bytes_received_total counter\n");
    mg_http_printf_chunk(connection, "ws_bytes_received_total %llu\n", (unsigned long long) metrics.received_bytes);
    mg_http_printf_chunk(connection, "# HELP ws_messages_sent_total Number of messages sent to clients\n");
    mg_http_printf_chunk(connection, "# TYPE ws_messages_sent_total counter\n");
    mg_http_printf_chunk(connection, "ws_messages_sent_total %llu\n", (unsigned long long) metrics.sent_messages);
    mg_http_printf_chunk(connection, "# HELP ws_messages_received_total Number of messages received from clients\n");
    mg_http_printf_chunk(connection, "# TYPE ws_messages_received_total counter\n");
    mg_http_printf_chunk(connection, "ws_messages_received_total %llu\n", (unsigned long long) metrics.received_messages);
    mg_http_printf_chunk(connection, "# HELP ws_errors_total Number of error messages sent to clients\n");
    mg_http_printf_chunk(connection, "# TYPE ws_errors_total counter\n");
    mg_http_printf_chunk(connection, "ws_errors_total %llu\n", (unsigned long long) metrics.errors);

    // Gauge: websocket connections currently known to the manager.
    int ws_clients = 0;
    for (struct mg_connection *conn = connection->mgr->conns; conn != NULL; conn = conn->next) {
        if (conn->is_websocket) {
            ws_clients++;
        }
    }
    mg_http_printf_chunk(connection, "# HELP ws_clients Number of active websocket clients\n");
    mg_http_printf_chunk(connection, "# TYPE ws_clients gauge\n");
    mg_http_printf_chunk(connection, "ws_clients %d\n", ws_clients);

    mg_http_printf_chunk(connection, "# HELP method_calls Times each method was called\n");
    mg_http_printf_chunk(connection, "# TYPE method_calls counter\n");
    size_t method_count = sizeof metrics.method_calls / sizeof metrics.method_calls[0];
    for (size_t i = 0; i < method_count; i++) {
        mg_http_printf_chunk(connection, "method_calls{method=\"%s\"} %llu\n",
                             method_names[i], (unsigned long long) metrics.method_calls[i]);
    }
    // A zero-length chunk terminates the chunked response.
    mg_http_printf_chunk(connection, "");
}

/*
 * Mongoose event handler for the listener and every accepted connection.
 *   MG_EV_OPEN     - zero the per-connection data area.
 *   MG_EV_HTTP_MSG - /open upgrades to a websocket, /metrics serves counters,
 *                    anything else gets a 404.
 *   MG_EV_WS_OPEN / MG_EV_WS_MSG / MG_EV_CLOSE - forwarded to the on_ws_* hooks.
 */
static void handle_client(struct mg_connection *connection, int event_type, void *event_data) {
    if (event_type == MG_EV_OPEN) {
        // Skip connections without a remote port (presumably the listening socket itself).
        if (connection->rem.port == 0) return;
        // Websocket connections later keep their struct client pointer here.
        memset(connection->data, 0, sizeof(connection->data));
    } else if (event_type == MG_EV_HTTP_MSG) {
        struct mg_http_message *http_message = (struct mg_http_message *) event_data;
        if (mg_match(http_message->uri, mg_str_s("/open"), NULL)) {
            mg_ws_upgrade(connection, http_message, NULL);
        } else if (mg_match(http_message->uri, mg_str_s("/metrics"), NULL)) {
            send_metrics(connection);
        } else {
            mg_http_reply(connection, 404, "", "uwu");
        }
    } else if (event_type == MG_EV_WS_OPEN) {
        on_ws_connect(connection, (struct mg_http_message *) event_data);
    } else if (event_type == MG_EV_WS_MSG) {
        on_ws_message(connection, (struct mg_ws_message *) event_data);
    } else if (event_type == MG_EV_CLOSE) {
        if (connection->is_websocket) {
            on_ws_disconnect(connection);
        }
    }
}
/*
 * Send an 'E' (error) message answering `request_id`, with printf-style text.
 * Frame: 'E', request_id (big-endian u16), text length (big-endian u16), text.
 * Text that does not fit the 1024-byte frame buffer is truncated, and the
 * length field describes only the bytes actually sent.
 * Increments the sent-bytes, sent-messages and error counters.
 */
__attribute__((format(printf, 3, 4)))
void ws_send_error(struct client *client, uint16_t request_id, const char *fmt, ...) {
    enum { HEADER_SIZE = 5 };
    static unsigned char buffer[1024];
    va_list args;
    va_start(args, fmt);
    int text_size = vsnprintf((char *) &buffer[HEADER_SIZE], sizeof buffer - HEADER_SIZE, fmt, args);
    va_end(args);
    if (text_size < 0) return;
    // vsnprintf returns the length the full text would have had. Clamp it to
    // what was actually written (minus the NUL) so mg_ws_send stays in bounds.
    if ((size_t) text_size >= sizeof buffer - HEADER_SIZE) {
        text_size = (int) (sizeof buffer - HEADER_SIZE - 1);
    }
    buffer[0] = 'E';
    buffer[1] = (request_id >> 8) & 0xFF;
    buffer[2] = request_id & 0xFF;
    buffer[3] = (text_size >> 8) & 0xFF;
    buffer[4] = text_size & 0xFF;
    metrics.sent_bytes += HEADER_SIZE + text_size;
    metrics.sent_messages++;
    metrics.errors++;
    mg_ws_send(client->connection, buffer, HEADER_SIZE + text_size, WEBSOCKET_OP_BINARY);
}
/*
 * Send an 'I' (info) message with printf-style text.
 * Frame: 'I', text length (big-endian u16), text.
 * Text that does not fit the 1024-byte frame buffer is truncated, and the
 * length field describes only the bytes actually sent.
 * Increments the sent-bytes and sent-messages counters.
 */
__attribute__((format(printf, 2, 3)))
void ws_send_info(struct client *client, const char *fmt, ...) {
    enum { HEADER_SIZE = 3 };
    static unsigned char buffer[1024];
    va_list args;
    va_start(args, fmt);
    int text_size = vsnprintf((char *) &buffer[HEADER_SIZE], sizeof buffer - HEADER_SIZE, fmt, args);
    va_end(args);
    if (text_size < 0) return;
    // vsnprintf returns the untruncated length; clamp to what was written.
    if ((size_t) text_size >= sizeof buffer - HEADER_SIZE) {
        text_size = (int) (sizeof buffer - HEADER_SIZE - 1);
    }
    buffer[0] = 'I';
    buffer[1] = (text_size >> 8) & 0xFF;
    buffer[2] = text_size & 0xFF;
    metrics.sent_bytes += HEADER_SIZE + text_size;
    metrics.sent_messages++;
    mg_ws_send(client->connection, buffer, HEADER_SIZE + text_size, WEBSOCKET_OP_BINARY);
}
/*
 * Send an 'R' (success) message answering `request_id`.
 * Frame: 'R', request_id (big-endian u16), then `size` bytes of
 * request-specific `data` (`data` may be NULL when `size` is 0).
 * Payloads larger than MAX_PACKET_SIZE are answered with an 'E' error instead.
 */
void ws_respond(struct client *client, uint16_t request_id, const void *data, uint32_t size) {
    // 3-byte header plus up to MAX_PACKET_SIZE bytes of payload.
    static char buffer[3 + MAX_PACKET_SIZE];
    // Checked at runtime rather than by assert() so release (NDEBUG) builds
    // can never write past the end of the buffer.
    if (size > MAX_PACKET_SIZE) {
        ws_send_error(client, request_id, "Response too big: %lu > %d", (unsigned long) size, MAX_PACKET_SIZE);
        return;
    }
    buffer[0] = 'R';
    buffer[1] = (request_id >> 8) & 0xFF;
    buffer[2] = request_id & 0xFF;
    if (size != 0) memcpy(&buffer[3], data, size);
    metrics.sent_bytes += 3 + size;
    metrics.sent_messages++;
    mg_ws_send(client->connection, buffer, size + 3, WEBSOCKET_OP_BINARY);
}
/*
 * MG_EV_WS_OPEN hook. Allocates zeroed per-connection client state, stores
 * its pointer in connection->data, and sends the 'A' (address) message.
 * Frame: 'A', name length (u8), name "wsvpn_<connection id>".
 * Aborts the process if the state cannot be allocated.
 */
static void on_ws_connect(struct mg_connection *connection, struct mg_http_message *message) {
    (void) message;
    // calloc, not malloc: the channel table and next_open_channel_index must
    // start at zero, otherwise client_is_open() scans uninitialised memory.
    struct client *client = calloc(1, sizeof *client);
    if (client == NULL) {
        // The other websocket handlers assume the stored pointer is valid,
        // so there is no safe degraded mode; fail loudly.
        fprintf(stderr, "wsvpn: out of memory allocating client state\n");
        abort();
    }
    client->connection = connection;
    memcpy(&connection->data[0], &client, sizeof client);

    static uint8_t buffer[256];
    buffer[0] = 'A';
    int name_len = snprintf((char *) &buffer[2], sizeof buffer - 2, "wsvpn_%lu",
                            (unsigned long) connection->id);
    // snprintf reports the untruncated length; keep it within what was written.
    if (name_len < 0) name_len = 0;
    if ((size_t) name_len > sizeof buffer - 3) name_len = (int) (sizeof buffer - 3);
    buffer[1] = (uint8_t) name_len;
    metrics.sent_bytes += 2 + name_len;
    metrics.sent_messages++;
    mg_ws_send(connection, buffer, 2 + name_len, WEBSOCKET_OP_BINARY);
}
/* Decode a big-endian 16-bit field from two raw bytes (no unaligned or type-punned load). */
static uint16_t wsvpn_read_be16(const char *bytes) {
    return (uint16_t) (((unsigned) (unsigned char) bytes[0] << 8) | (unsigned) (unsigned char) bytes[1]);
}

/*
 * MG_EV_WS_MSG hook: decode one binary request and dispatch it.
 *
 * Layout (integers big-endian): byte 0 is the opcode and bytes 1-2 the request ID.
 *   'O' open / 'o' isOpen / 'c' close add a channel at bytes 3-4.
 *   'C' closeAll has no further fields.
 *   'T' transmit adds channel (3-4), reply channel (5-6), payload length
 *       (7-8), and the payload from byte 9.
 *   'I' info messages are ignored.
 * Non-binary frames get a text apology and the connection is drained.
 * Unknown opcodes get an 'E' error and the connection is drained.
 * Packets too short for their opcode, and 'T' packets whose declared payload
 * length exceeds the received bytes, get an 'E' error. Nothing is read past
 * the end of the message.
 */
static void on_ws_message(struct mg_connection *connection, struct mg_ws_message *message) {
    if ((message->flags & 15) != WEBSOCKET_OP_BINARY) {
        const char *err_str = "This server only works in binary mode. Sorry!";
        mg_ws_send(connection, err_str, strlen(err_str), WEBSOCKET_OP_TEXT);
        connection->is_draining = 1;
        return;
    }
    struct client *client;
    memcpy(&client, &connection->data[0], sizeof client);
    assert(client != NULL && client->connection == connection);
    metrics.received_bytes += message->data.len;
    metrics.received_messages++;

    const char *buf = message->data.buf;
    size_t len = message->data.len;
    if (len == 0) return;
    unsigned char opcode = (unsigned char) buf[0];
    if (opcode == 'I') return; // info: safe to ignore, whatever its length
    if (len < 3) {
        ws_send_error(client, 0, "Malformed packet: %lu bytes is too short", (unsigned long) len);
        return;
    }
    uint16_t request_id = wsvpn_read_be16(&buf[1]);
    switch (opcode) {
    case 'O': // open
        if (len < 5) break;
        {
            metrics.method_calls[0]++;
            uint16_t channel = wsvpn_read_be16(&buf[3]);
            printf("%p[%04x] modem.open(%d)\n", (void*)client, request_id, channel);
            modem_open(client, request_id, channel);
        }
        return;
    case 'o': // isOpen
        if (len < 5) break;
        {
            metrics.method_calls[1]++;
            uint16_t channel = wsvpn_read_be16(&buf[3]);
            printf("%p[%04x] modem.isOpen(%d)\n", (void*)client, request_id, channel);
            modem_isOpen(client, request_id, channel);
        }
        return;
    case 'c': // close
        if (len < 5) break;
        {
            metrics.method_calls[2]++;
            uint16_t channel = wsvpn_read_be16(&buf[3]);
            printf("%p[%04x] modem.close(%d)\n", (void*)client, request_id, channel);
            modem_close(client, request_id, channel);
        }
        return;
    case 'C': // closeAll
        metrics.method_calls[3]++;
        printf("%p[%04x] modem.closeAll()\n", (void*)client, request_id);
        modem_closeAll(client, request_id);
        return;
    case 'T': // transmit
        if (len < 9) break;
        {
            uint16_t data_length = wsvpn_read_be16(&buf[7]);
            // Never trust the declared length: it must fit in what was received.
            if (data_length > len - 9) break;
            metrics.method_calls[4]++;
            uint16_t channel = wsvpn_read_be16(&buf[3]);
            uint16_t reply_channel = wsvpn_read_be16(&buf[5]);
            modem_transmit(client, request_id, channel, reply_channel, (void *) &buf[9], data_length);
        }
        return;
    default:
        ws_send_error(client, request_id, "Unknown opcode: 0x%02x", opcode);
        connection->is_draining = 1;
        return;
    }
    // Reached only via `break`: the packet is too short for its opcode.
    ws_send_error(client, request_id, "Malformed packet for opcode 0x%02x: %lu bytes",
                  opcode, (unsigned long) len);
}
/*
 * MG_EV_CLOSE hook for websocket connections. Frees the client state stored
 * in connection->data if it belongs to this connection, then clears the
 * stored pointer.
 */
static void on_ws_disconnect(struct mg_connection *connection) {
    struct client *client;
    // memcpy avoids reading the char data area through a struct client ** lvalue.
    memcpy(&client, &connection->data[0], sizeof client);
    if (client != NULL && client->connection == connection) {
        free(client);
    }
    // Drop the now-dangling pointer so nothing can reach freed memory through it.
    memset(&connection->data[0], 0, sizeof client);
}
/*
 * Handle 'O' (open): add `channel` to the client's open-channel table and
 * reply with an empty 'R'. Reopening an already-open channel only acknowledges.
 * When the table already holds MAX_OPEN_CHANNELS entries, reply with an 'E' error.
 */
static void modem_open(struct client *client, uint16_t request_id, uint16_t channel) {
    if (client_is_open(client, channel)) {
        // Stop here so the channel is not stored twice and the request is
        // answered exactly once.
        ws_respond(client, request_id, NULL, 0);
        return;
    }
    if (client->next_open_channel_index >= MAX_OPEN_CHANNELS) {
        ws_send_error(client, request_id, "Too many open channels");
        return;
    }
    client->open_channels[client->next_open_channel_index] = channel;
    client->next_open_channel_index++;
    ws_respond(client, request_id, NULL, 0);
}
/*
 * Handle 'o' (isOpen): reply with a one-byte 'R' payload of 42 when `channel`
 * is open for this client, or 0 otherwise.
 */
static void modem_isOpen(struct client *client, uint16_t request_id, uint16_t channel) {
    unsigned char reply = 0;
    if (client_is_open(client, channel)) {
        reply = 42;
    }
    ws_respond(client, request_id, &reply, sizeof reply);
}
/*
 * Handle 'c' (close): remove the first occurrence of `channel` from the
 * client's table, if present, then reply with an empty 'R' either way.
 */
static void modem_close(struct client *client, uint16_t request_id, uint16_t channel) {
    int last = client->next_open_channel_index - 1;
    for (int slot = 0; slot <= last; ++slot) {
        if (client->open_channels[slot] != channel) {
            continue;
        }
        // Table order is irrelevant: fill the hole with the last entry and shrink.
        client->open_channels[slot] = client->open_channels[last];
        client->next_open_channel_index = last;
        break;
    }
    ws_respond(client, request_id, NULL, 0);
}
/* Handle 'C' (closeAll): empty the client's channel table and reply with an empty 'R'. */
static void modem_closeAll(struct client *client, uint16_t request_id) {
    memset(client->open_channels, 0, sizeof client->open_channels);
    client->next_open_channel_index = 0;
    ws_respond(client, request_id, NULL, 0);
}
/*
 * Handle 'T' (transmit) and acknowledge the sender with an empty 'R'.
 * Wraps `data` in a server 'T' frame: 'T', channel, reply_channel, size
 * (all big-endian u16), then the payload. The frame goes to every other
 * websocket client that has `channel` open.
 * Payloads over MAX_PACKET_SIZE are rejected with an 'E' error instead.
 */
static void modem_transmit(struct client *client, uint16_t request_id, uint16_t channel, uint16_t reply_channel, void *data, uint16_t size) {
    static uint8_t buffer[MAX_PACKET_SIZE + 7];
    if (size > MAX_PACKET_SIZE) {
        ws_send_error(client, request_id, "Packet too big: %d > %d", size, MAX_PACKET_SIZE);
        return;
    }
    buffer[0] = 'T';
    buffer[1] = (channel >> 8) & 0xFF;
    buffer[2] = channel & 0xFF;
    buffer[3] = (reply_channel >> 8) & 0xFF;
    buffer[4] = reply_channel & 0xFF;
    buffer[5] = (size >> 8) & 0xFF;
    buffer[6] = size & 0xFF;
    if (size != 0) memcpy(&buffer[7], data, size);
    for (struct mg_connection *conn = client->connection->mgr->conns; conn != NULL; conn = conn->next) {
        // Only websocket peers, never echo back to the sender.
        if (!conn->is_websocket || conn == client->connection) continue;
        struct client *other_client;
        memcpy(&other_client, &conn->data[0], sizeof other_client);
        // Skip connections whose stored state is missing or not their own.
        if (other_client == NULL || other_client->connection != conn) continue;
        if (!client_is_open(other_client, channel)) continue;
        metrics.sent_bytes += size + 7;
        metrics.sent_messages++;
        mg_ws_send(conn, buffer, size + 7, WEBSOCKET_OP_BINARY);
    }
    ws_respond(client, request_id, NULL, 0);
}
/*
 * Report whether `channel` is among the first next_open_channel_index
 * entries of the client's open-channel table (linear scan).
 */
bool client_is_open(struct client *client, uint16_t channel) {
    int count = client->next_open_channel_index;
    int pos = 0;
    while (pos < count && client->open_channels[pos] != channel) {
        pos++;
    }
    return pos < count;
}