| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612 |
- /**
- * @file ServerConnection.c
- * @author Ambroz Bizjak <ambrop7@gmail.com>
- *
- * @section LICENSE
- *
- * This file is part of BadVPN.
- *
- * BadVPN is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License version 2
- * as published by the Free Software Foundation.
- *
- * BadVPN is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License along
- * with this program; if not, write to the Free Software Foundation, Inc.,
- * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
- */
- #include <stdio.h>
- #include <misc/debug.h>
- #include <base/BLog.h>
- #include <server_connection/ServerConnection.h>
- #include <generated/blog_channel_ServerConnection.h>
- #define STATE_CONNECTING 1
- #define STATE_WAITINIT 2
- #define STATE_COMPLETE 3
- static void report_error (ServerConnection *o);
- static void connector_handler (ServerConnection *o, int is_error);
- static void pending_handler (ServerConnection *o);
- static SECStatus client_auth_data_callback (ServerConnection *o, PRFileDesc *fd, CERTDistNames *caNames, CERTCertificate **pRetCert, SECKEYPrivateKey **pRetKey);
- static void connection_handler (ServerConnection *o, int event);
- static void sslcon_handler (ServerConnection *o, int event);
- static void decoder_handler_error (ServerConnection *o);
- static void input_handler_send (ServerConnection *o, uint8_t *data, int data_len);
- static void packet_hello (ServerConnection *o, uint8_t *data, int data_len);
- static void packet_newclient (ServerConnection *o, uint8_t *data, int data_len);
- static void packet_endclient (ServerConnection *o, uint8_t *data, int data_len);
- static void packet_inmsg (ServerConnection *o, uint8_t *data, int data_len);
- static int start_packet (ServerConnection *o, void **data, int len);
- static void end_packet (ServerConnection *o, uint8_t type);
- static void newclient_job_handler (ServerConnection *o);
- void report_error (ServerConnection *o)
- {
- DEBUGERROR(&o->d_err, o->handler_error(o->user))
- }
- void connector_handler (ServerConnection *o, int is_error)
- {
- DebugObject_Access(&o->d_obj);
- ASSERT(o->state == STATE_CONNECTING)
-
- // check connection attempt result
- if (is_error) {
- BLog(BLOG_ERROR, "connection failed");
- goto fail0;
- }
-
- BLog(BLOG_NOTICE, "connected");
-
- // init connection
- if (!BConnection_Init(&o->con, BCONNECTION_SOURCE_CONNECTOR(&o->connector), o->reactor, o, (BConnection_handler)connection_handler)) {
- BLog(BLOG_ERROR, "BConnection_Init failed");
- goto fail0;
- }
-
- // init connection interfaces
- BConnection_SendAsync_Init(&o->con);
- BConnection_RecvAsync_Init(&o->con);
-
- StreamPassInterface *send_iface = BConnection_SendAsync_GetIf(&o->con);
- StreamRecvInterface *recv_iface = BConnection_RecvAsync_GetIf(&o->con);
-
- if (o->have_ssl) {
- // create bottom NSPR file descriptor
- if (!BSSLConnection_MakeBackend(&o->bottom_prfd, send_iface, recv_iface)) {
- BLog(BLOG_ERROR, "BSSLConnection_MakeBackend failed");
- goto fail0a;
- }
-
- // create SSL file descriptor from the bottom NSPR file descriptor
- if (!(o->ssl_prfd = SSL_ImportFD(NULL, &o->bottom_prfd))) {
- BLog(BLOG_ERROR, "SSL_ImportFD failed");
- ASSERT_FORCE(PR_Close(&o->bottom_prfd) == PR_SUCCESS)
- goto fail0a;
- }
-
- // set client mode
- if (SSL_ResetHandshake(o->ssl_prfd, PR_FALSE) != SECSuccess) {
- BLog(BLOG_ERROR, "SSL_ResetHandshake failed");
- goto fail1;
- }
-
- // set server name
- if (SSL_SetURL(o->ssl_prfd, o->server_name) != SECSuccess) {
- BLog(BLOG_ERROR, "SSL_SetURL failed");
- goto fail1;
- }
-
- // set client certificate callback
- if (SSL_GetClientAuthDataHook(o->ssl_prfd, (SSLGetClientAuthData)client_auth_data_callback, o) != SECSuccess) {
- BLog(BLOG_ERROR, "SSL_GetClientAuthDataHook failed");
- goto fail1;
- }
-
- // init BSSLConnection
- BSSLConnection_Init(&o->sslcon, o->ssl_prfd, 0, BReactor_PendingGroup(o->reactor), o, (BSSLConnection_handler)sslcon_handler);
-
- send_iface = BSSLConnection_GetSendIf(&o->sslcon);
- recv_iface = BSSLConnection_GetRecvIf(&o->sslcon);
- }
-
- // init input chain
- PacketPassInterface_Init(&o->input_interface, SC_MAX_ENC, (PacketPassInterface_handler_send)input_handler_send, o, BReactor_PendingGroup(o->reactor));
- if (!PacketProtoDecoder_Init(&o->input_decoder, recv_iface, &o->input_interface, BReactor_PendingGroup(o->reactor), o, (PacketProtoDecoder_handler_error)decoder_handler_error)) {
- BLog(BLOG_ERROR, "PacketProtoDecoder_Init failed");
- goto fail2;
- }
-
- // set job to send hello
- // this needs to be in here because hello sending must be done after sending started (so we can write into the send buffer),
- // but before receiving started (so we don't get into conflict with the user sending packets)
- BPending_Init(&o->start_job, BReactor_PendingGroup(o->reactor), (BPending_handler)pending_handler, o);
- BPending_Set(&o->start_job);
-
- // init keepalive output branch
- SCKeepaliveSource_Init(&o->output_ka_zero, BReactor_PendingGroup(o->reactor));
- PacketProtoEncoder_Init(&o->output_ka_encoder, SCKeepaliveSource_GetOutput(&o->output_ka_zero), BReactor_PendingGroup(o->reactor));
-
- // init output common
-
- // init sender
- PacketStreamSender_Init(&o->output_sender, send_iface, PACKETPROTO_ENCLEN(SC_MAX_ENC), BReactor_PendingGroup(o->reactor));
-
- // init keepalives
- if (!KeepaliveIO_Init(&o->output_keepaliveio, o->reactor, PacketStreamSender_GetInput(&o->output_sender), PacketProtoEncoder_GetOutput(&o->output_ka_encoder), o->keepalive_interval)) {
- BLog(BLOG_ERROR, "KeepaliveIO_Init failed");
- goto fail3;
- }
-
- // init queue
- PacketPassPriorityQueue_Init(&o->output_queue, KeepaliveIO_GetInput(&o->output_keepaliveio), BReactor_PendingGroup(o->reactor), 0);
-
- // init output local flow
-
- // init queue flow
- PacketPassPriorityQueueFlow_Init(&o->output_local_qflow, &o->output_queue, 0);
-
- // init PacketProtoFlow
- if (!PacketProtoFlow_Init(&o->output_local_oflow, SC_MAX_ENC, o->buffer_size, PacketPassPriorityQueueFlow_GetInput(&o->output_local_qflow), BReactor_PendingGroup(o->reactor))) {
- BLog(BLOG_ERROR, "PacketProtoFlow_Init failed");
- goto fail4;
- }
- o->output_local_if = PacketProtoFlow_GetInput(&o->output_local_oflow);
-
- // have no output packet
- o->output_local_packet_len = -1;
-
- // init output user flow
- PacketPassPriorityQueueFlow_Init(&o->output_user_qflow, &o->output_queue, 1);
-
- // update state
- o->state = STATE_WAITINIT;
-
- return;
-
- fail4:
- PacketPassPriorityQueueFlow_Free(&o->output_local_qflow);
- PacketPassPriorityQueue_Free(&o->output_queue);
- KeepaliveIO_Free(&o->output_keepaliveio);
- fail3:
- PacketStreamSender_Free(&o->output_sender);
- PacketProtoEncoder_Free(&o->output_ka_encoder);
- SCKeepaliveSource_Free(&o->output_ka_zero);
- BPending_Free(&o->start_job);
- PacketProtoDecoder_Free(&o->input_decoder);
- fail2:
- PacketPassInterface_Free(&o->input_interface);
- if (o->have_ssl) {
- BSSLConnection_Free(&o->sslcon);
- fail1:
- ASSERT_FORCE(PR_Close(o->ssl_prfd) == PR_SUCCESS)
- }
- fail0a:
- BConnection_RecvAsync_Free(&o->con);
- BConnection_SendAsync_Free(&o->con);
- BConnection_Free(&o->con);
- fail0:
- // report error
- report_error(o);
- }
- void pending_handler (ServerConnection *o)
- {
- ASSERT(o->state == STATE_WAITINIT)
- DebugObject_Access(&o->d_obj);
-
- // send hello
- struct sc_client_hello *packet;
- if (!start_packet(o, (void **)&packet, sizeof(struct sc_client_hello))) {
- BLog(BLOG_ERROR, "no buffer for hello");
- report_error(o);
- return;
- }
- packet->version = htol16(SC_VERSION);
- end_packet(o, SCID_CLIENTHELLO);
- }
- SECStatus client_auth_data_callback (ServerConnection *o, PRFileDesc *fd, CERTDistNames *caNames, CERTCertificate **pRetCert, SECKEYPrivateKey **pRetKey)
- {
- ASSERT(o->have_ssl)
- DebugObject_Access(&o->d_obj);
-
- CERTCertificate *newcert;
- if (!(newcert = CERT_DupCertificate(o->client_cert))) {
- return SECFailure;
- }
-
- SECKEYPrivateKey *newkey;
- if (!(newkey = SECKEY_CopyPrivateKey(o->client_key))) {
- CERT_DestroyCertificate(newcert);
- return SECFailure;
- }
-
- *pRetCert = newcert;
- *pRetKey = newkey;
- return SECSuccess;
- }
- void connection_handler (ServerConnection *o, int event)
- {
- DebugObject_Access(&o->d_obj);
- ASSERT(o->state >= STATE_WAITINIT)
-
- if (event == BCONNECTION_EVENT_RECVCLOSED) {
- BLog(BLOG_INFO, "connection closed");
- } else {
- BLog(BLOG_INFO, "connection error");
- }
-
- report_error(o);
- return;
- }
- void sslcon_handler (ServerConnection *o, int event)
- {
- DebugObject_Access(&o->d_obj);
- ASSERT(o->have_ssl)
- ASSERT(o->state >= STATE_WAITINIT)
- ASSERT(event == BSSLCONNECTION_EVENT_ERROR)
-
- BLog(BLOG_ERROR, "SSL error");
-
- report_error(o);
- return;
- }
- void decoder_handler_error (ServerConnection *o)
- {
- DebugObject_Access(&o->d_obj);
- ASSERT(o->state >= STATE_WAITINIT)
-
- BLog(BLOG_ERROR, "decoder error");
-
- report_error(o);
- return;
- }
- void input_handler_send (ServerConnection *o, uint8_t *data, int data_len)
- {
- ASSERT(o->state >= STATE_WAITINIT)
- ASSERT(data_len >= 0)
- ASSERT(data_len <= SC_MAX_ENC)
- DebugObject_Access(&o->d_obj);
-
- // accept packet
- PacketPassInterface_Done(&o->input_interface);
-
- // parse header
- if (data_len < sizeof(struct sc_header)) {
- BLog(BLOG_ERROR, "packet too short (no sc header)");
- report_error(o);
- return;
- }
- struct sc_header *header = (struct sc_header *)data;
- data += sizeof(*header);
- data_len -= sizeof(*header);
- uint8_t type = ltoh8(header->type);
-
- // call appropriate handler based on packet type
- switch (type) {
- case SCID_SERVERHELLO:
- packet_hello(o, data, data_len);
- return;
- case SCID_NEWCLIENT:
- packet_newclient(o, data, data_len);
- return;
- case SCID_ENDCLIENT:
- packet_endclient(o, data, data_len);
- return;
- case SCID_INMSG:
- packet_inmsg(o, data, data_len);
- return;
- default:
- BLog(BLOG_ERROR, "unknown packet type %d", (int)type);
- report_error(o);
- return;
- }
- }
- void packet_hello (ServerConnection *o, uint8_t *data, int data_len)
- {
- if (o->state != STATE_WAITINIT) {
- BLog(BLOG_ERROR, "hello: not expected");
- report_error(o);
- return;
- }
-
- if (data_len != sizeof(struct sc_server_hello)) {
- BLog(BLOG_ERROR, "hello: invalid length");
- report_error(o);
- return;
- }
- struct sc_server_hello *msg = (struct sc_server_hello *)data;
- peerid_t id = ltoh16(msg->id);
-
- // change state
- o->state = STATE_COMPLETE;
-
- // report
- o->handler_ready(o->user, id, msg->clientAddr);
- return;
- }
- void packet_newclient (ServerConnection *o, uint8_t *data, int data_len)
- {
- if (o->state != STATE_COMPLETE) {
- BLog(BLOG_ERROR, "newclient: not expected");
- report_error(o);
- return;
- }
-
- if (data_len < sizeof(struct sc_server_newclient) || data_len > sizeof(struct sc_server_newclient) + SCID_NEWCLIENT_MAX_CERT_LEN) {
- BLog(BLOG_ERROR, "newclient: invalid length");
- report_error(o);
- return;
- }
-
- struct sc_server_newclient *msg = (struct sc_server_newclient *)data;
- peerid_t id = ltoh16(msg->id);
-
- // schedule reporting new client
- o->newclient_data = data;
- o->newclient_data_len = data_len;
- BPending_Set(&o->newclient_job);
-
- // send acceptpeer
- struct sc_client_acceptpeer *packet;
- if (!start_packet(o, (void **)&packet, sizeof(*packet))) {
- BLog(BLOG_ERROR, "newclient: out of buffer for acceptpeer");
- report_error(o);
- return;
- }
- packet->clientid = htol16(id);
- end_packet(o, SCID_ACCEPTPEER);
- }
- void packet_endclient (ServerConnection *o, uint8_t *data, int data_len)
- {
- if (o->state != STATE_COMPLETE) {
- BLog(BLOG_ERROR, "endclient: not expected");
- report_error(o);
- return;
- }
-
- if (data_len != sizeof(struct sc_server_endclient)) {
- BLog(BLOG_ERROR, "endclient: invalid length");
- report_error(o);
- return;
- }
-
- struct sc_server_endclient *msg = (struct sc_server_endclient *)data;
- peerid_t id = ltoh16(msg->id);
-
- // report
- o->handler_endclient(o->user, id);
- return;
- }
- void packet_inmsg (ServerConnection *o, uint8_t *data, int data_len)
- {
- if (o->state != STATE_COMPLETE) {
- BLog(BLOG_ERROR, "inmsg: not expected");
- report_error(o);
- return;
- }
-
- if (data_len < sizeof(struct sc_server_inmsg)) {
- BLog(BLOG_ERROR, "inmsg: missing header");
- report_error(o);
- return;
- }
-
- if (data_len > sizeof(struct sc_server_inmsg) + SC_MAX_MSGLEN) {
- BLog(BLOG_ERROR, "inmsg: too long");
- report_error(o);
- return;
- }
-
- struct sc_server_inmsg *msg = (struct sc_server_inmsg *)data;
- peerid_t peer_id = ltoh16(msg->clientid);
- uint8_t *payload = data + sizeof(struct sc_server_inmsg);
- int payload_len = data_len - sizeof(struct sc_server_inmsg);
-
- // report
- o->handler_message(o->user, peer_id, payload, payload_len);
- return;
- }
- int start_packet (ServerConnection *o, void **data, int len)
- {
- ASSERT(o->state >= STATE_WAITINIT)
- ASSERT(o->output_local_packet_len == -1)
- ASSERT(len >= 0)
- ASSERT(len <= SC_MAX_PAYLOAD)
- ASSERT(data || len == 0)
-
- // obtain memory location
- if (!BufferWriter_StartPacket(o->output_local_if, &o->output_local_packet)) {
- BLog(BLOG_ERROR, "out of buffer");
- return 0;
- }
-
- o->output_local_packet_len = len;
-
- if (data) {
- *data = o->output_local_packet + sizeof(struct sc_header);
- }
-
- return 1;
- }
- void end_packet (ServerConnection *o, uint8_t type)
- {
- ASSERT(o->state >= STATE_WAITINIT)
- ASSERT(o->output_local_packet_len >= 0)
- ASSERT(o->output_local_packet_len <= SC_MAX_PAYLOAD)
-
- // write header
- struct sc_header *header = (struct sc_header *)o->output_local_packet;
- header->type = htol8(type);
-
- // finish writing packet
- BufferWriter_EndPacket(o->output_local_if, sizeof(struct sc_header) + o->output_local_packet_len);
-
- o->output_local_packet_len = -1;
- }
- int ServerConnection_Init (
- ServerConnection *o,
- BReactor *reactor,
- BAddr addr,
- int keepalive_interval,
- int buffer_size,
- int have_ssl,
- CERTCertificate *client_cert,
- SECKEYPrivateKey *client_key,
- const char *server_name,
- void *user,
- ServerConnection_handler_error handler_error,
- ServerConnection_handler_ready handler_ready,
- ServerConnection_handler_newclient handler_newclient,
- ServerConnection_handler_endclient handler_endclient,
- ServerConnection_handler_message handler_message
- )
- {
- ASSERT(addr.type == BADDR_TYPE_IPV4 || addr.type == BADDR_TYPE_IPV6)
- ASSERT(keepalive_interval > 0)
- ASSERT(buffer_size > 0)
- ASSERT(have_ssl == 0 || have_ssl == 1)
-
- // init arguments
- o->reactor = reactor;
- o->keepalive_interval = keepalive_interval;
- o->buffer_size = buffer_size;
- o->have_ssl = have_ssl;
- if (have_ssl) {
- o->client_cert = client_cert;
- o->client_key = client_key;
- snprintf(o->server_name, sizeof(o->server_name), "%s", server_name);
- }
- o->user = user;
- o->handler_error = handler_error;
- o->handler_ready = handler_ready;
- o->handler_newclient = handler_newclient;
- o->handler_endclient = handler_endclient;
- o->handler_message = handler_message;
-
- // init connector
- if (!BConnector_Init(&o->connector, addr, o->reactor, o, (BConnector_handler)connector_handler)) {
- BLog(BLOG_ERROR, "BConnector_Init failed");
- goto fail0;
- }
-
- // init newclient job
- BPending_Init(&o->newclient_job, BReactor_PendingGroup(o->reactor), (BPending_handler)newclient_job_handler, o);
-
- // set state
- o->state = STATE_CONNECTING;
-
- DebugError_Init(&o->d_err, BReactor_PendingGroup(o->reactor));
- DebugObject_Init(&o->d_obj);
- return 1;
-
- fail0:
- return 0;
- }
- void ServerConnection_Free (ServerConnection *o)
- {
- DebugObject_Free(&o->d_obj);
- DebugError_Free(&o->d_err);
-
- if (o->state > STATE_CONNECTING) {
- // allow freeing queue flows
- PacketPassPriorityQueue_PrepareFree(&o->output_queue);
-
- // free output user flow
- PacketPassPriorityQueueFlow_Free(&o->output_user_qflow);
-
- // free output local flow
- PacketProtoFlow_Free(&o->output_local_oflow);
- PacketPassPriorityQueueFlow_Free(&o->output_local_qflow);
-
- // free output common
- PacketPassPriorityQueue_Free(&o->output_queue);
- KeepaliveIO_Free(&o->output_keepaliveio);
- PacketStreamSender_Free(&o->output_sender);
-
- // free output keep-alive branch
- PacketProtoEncoder_Free(&o->output_ka_encoder);
- SCKeepaliveSource_Free(&o->output_ka_zero);
-
- // free job
- BPending_Free(&o->start_job);
-
- // free input chain
- PacketProtoDecoder_Free(&o->input_decoder);
- PacketPassInterface_Free(&o->input_interface);
-
- // free SSL
- if (o->have_ssl) {
- BSSLConnection_Free(&o->sslcon);
- ASSERT_FORCE(PR_Close(o->ssl_prfd) == PR_SUCCESS)
- }
-
- // free connection interfaces
- BConnection_RecvAsync_Free(&o->con);
- BConnection_SendAsync_Free(&o->con);
-
- // free connection
- BConnection_Free(&o->con);
- }
-
- // free newclient job
- BPending_Free(&o->newclient_job);
-
- // free connector
- BConnector_Free(&o->connector);
- }
- int ServerConnection_IsReady (ServerConnection *o)
- {
- DebugObject_Access(&o->d_obj);
-
- return (o->state == STATE_COMPLETE);
- }
- PacketPassInterface * ServerConnection_GetSendInterface (ServerConnection *o)
- {
- ASSERT(o->state == STATE_COMPLETE)
- DebugError_AssertNoError(&o->d_err);
- DebugObject_Access(&o->d_obj);
-
- return PacketPassPriorityQueueFlow_GetInput(&o->output_user_qflow);
- }
- void newclient_job_handler (ServerConnection *o)
- {
- DebugObject_Access(&o->d_obj);
- ASSERT(o->state == STATE_COMPLETE)
-
- struct sc_server_newclient *msg = (struct sc_server_newclient *)o->newclient_data;
- peerid_t id = ltoh16(msg->id);
- int flags = ltoh16(msg->flags);
-
- uint8_t *cert_data = (uint8_t *)(msg + 1);
- int cert_len = o->newclient_data_len - sizeof(*msg);
-
- // report new client
- o->handler_newclient(o->user, id, flags, cert_data, cert_len);
- return;
- }
|