/** * @file ServerConnection.c * @author Ambroz Bizjak * * @section LICENSE * * This file is part of BadVPN. * * BadVPN is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License version 2 * as published by the Free Software Foundation. * * BadVPN is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * * You should have received a copy of the GNU General Public License along * with this program; if not, write to the Free Software Foundation, Inc., * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. */ #include #include #include #include #include #define STATE_CONNECTING 1 #define STATE_WAITINIT 2 #define STATE_COMPLETE 3 static void report_error (ServerConnection *o); static void connector_handler (ServerConnection *o, int is_error); static void pending_handler (ServerConnection *o); static SECStatus client_auth_data_callback (ServerConnection *o, PRFileDesc *fd, CERTDistNames *caNames, CERTCertificate **pRetCert, SECKEYPrivateKey **pRetKey); static void connection_handler (ServerConnection *o, int event); static void sslcon_handler (ServerConnection *o, int event); static void decoder_handler_error (ServerConnection *o); static void input_handler_send (ServerConnection *o, uint8_t *data, int data_len); static void packet_hello (ServerConnection *o, uint8_t *data, int data_len); static void packet_newclient (ServerConnection *o, uint8_t *data, int data_len); static void packet_endclient (ServerConnection *o, uint8_t *data, int data_len); static void packet_inmsg (ServerConnection *o, uint8_t *data, int data_len); static int start_packet (ServerConnection *o, void **data, int len); static void end_packet (ServerConnection *o, uint8_t type); static void newclient_job_handler (ServerConnection *o); void report_error (ServerConnection *o) { DEBUGERROR(&o->d_err, o->handler_error(o->user)) } void connector_handler (ServerConnection *o, int is_error) { DebugObject_Access(&o->d_obj); ASSERT(o->state == STATE_CONNECTING) // check connection attempt result if (is_error) { BLog(BLOG_ERROR, "connection failed"); goto fail0; } BLog(BLOG_NOTICE, "connected"); // init connection if (!BConnection_Init(&o->con, BCONNECTION_SOURCE_CONNECTOR(&o->connector), o->reactor, o, (BConnection_handler)connection_handler)) { BLog(BLOG_ERROR, "BConnection_Init failed"); goto fail0; } // init connection interfaces BConnection_SendAsync_Init(&o->con); BConnection_RecvAsync_Init(&o->con); StreamPassInterface *send_iface = BConnection_SendAsync_GetIf(&o->con); StreamRecvInterface *recv_iface = BConnection_RecvAsync_GetIf(&o->con); if (o->have_ssl) { // create bottom NSPR file descriptor if (!BSSLConnection_MakeBackend(&o->bottom_prfd, send_iface, recv_iface)) { BLog(BLOG_ERROR, "BSSLConnection_MakeBackend failed"); goto fail0a; } // create SSL file descriptor from the bottom NSPR file descriptor if (!(o->ssl_prfd = SSL_ImportFD(NULL, &o->bottom_prfd))) { BLog(BLOG_ERROR, "SSL_ImportFD failed"); ASSERT_FORCE(PR_Close(&o->bottom_prfd) == PR_SUCCESS) goto fail0a; } // set client mode if (SSL_ResetHandshake(o->ssl_prfd, PR_FALSE) != SECSuccess) { BLog(BLOG_ERROR, "SSL_ResetHandshake failed"); goto fail1; } // set server name if (SSL_SetURL(o->ssl_prfd, o->server_name) != SECSuccess) { BLog(BLOG_ERROR, "SSL_SetURL failed"); goto fail1; } // set client certificate callback if (SSL_GetClientAuthDataHook(o->ssl_prfd, (SSLGetClientAuthData)client_auth_data_callback, o) != SECSuccess) { BLog(BLOG_ERROR, "SSL_GetClientAuthDataHook failed"); goto fail1; } // init BSSLConnection BSSLConnection_Init(&o->sslcon, o->ssl_prfd, 0, BReactor_PendingGroup(o->reactor), o, (BSSLConnection_handler)sslcon_handler); send_iface = BSSLConnection_GetSendIf(&o->sslcon); recv_iface = BSSLConnection_GetRecvIf(&o->sslcon); } // init input chain PacketPassInterface_Init(&o->input_interface, SC_MAX_ENC, (PacketPassInterface_handler_send)input_handler_send, o, BReactor_PendingGroup(o->reactor)); if (!PacketProtoDecoder_Init(&o->input_decoder, recv_iface, &o->input_interface, BReactor_PendingGroup(o->reactor), o, (PacketProtoDecoder_handler_error)decoder_handler_error)) { BLog(BLOG_ERROR, "PacketProtoDecoder_Init failed"); goto fail2; } // set job to send hello // this needs to be in here because hello sending must be done after sending started (so we can write into the send buffer), // but before receiving started (so we don't get into conflict with the user sending packets) BPending_Init(&o->start_job, BReactor_PendingGroup(o->reactor), (BPending_handler)pending_handler, o); BPending_Set(&o->start_job); // init keepalive output branch SCKeepaliveSource_Init(&o->output_ka_zero, BReactor_PendingGroup(o->reactor)); PacketProtoEncoder_Init(&o->output_ka_encoder, SCKeepaliveSource_GetOutput(&o->output_ka_zero), BReactor_PendingGroup(o->reactor)); // init output common // init sender PacketStreamSender_Init(&o->output_sender, send_iface, PACKETPROTO_ENCLEN(SC_MAX_ENC), BReactor_PendingGroup(o->reactor)); // init keepalives if (!KeepaliveIO_Init(&o->output_keepaliveio, o->reactor, PacketStreamSender_GetInput(&o->output_sender), PacketProtoEncoder_GetOutput(&o->output_ka_encoder), o->keepalive_interval)) { BLog(BLOG_ERROR, "KeepaliveIO_Init failed"); goto fail3; } // init queue PacketPassPriorityQueue_Init(&o->output_queue, KeepaliveIO_GetInput(&o->output_keepaliveio), BReactor_PendingGroup(o->reactor), 0); // init output local flow // init queue flow PacketPassPriorityQueueFlow_Init(&o->output_local_qflow, &o->output_queue, 0); // init PacketProtoFlow if (!PacketProtoFlow_Init(&o->output_local_oflow, SC_MAX_ENC, o->buffer_size, PacketPassPriorityQueueFlow_GetInput(&o->output_local_qflow), BReactor_PendingGroup(o->reactor))) { BLog(BLOG_ERROR, "PacketProtoFlow_Init failed"); goto fail4; } o->output_local_if = PacketProtoFlow_GetInput(&o->output_local_oflow); // have no output packet o->output_local_packet_len = -1; // init output user flow PacketPassPriorityQueueFlow_Init(&o->output_user_qflow, &o->output_queue, 1); // update state o->state = STATE_WAITINIT; return; fail4: PacketPassPriorityQueueFlow_Free(&o->output_local_qflow); PacketPassPriorityQueue_Free(&o->output_queue); KeepaliveIO_Free(&o->output_keepaliveio); fail3: PacketStreamSender_Free(&o->output_sender); PacketProtoEncoder_Free(&o->output_ka_encoder); SCKeepaliveSource_Free(&o->output_ka_zero); BPending_Free(&o->start_job); PacketProtoDecoder_Free(&o->input_decoder); fail2: PacketPassInterface_Free(&o->input_interface); if (o->have_ssl) { BSSLConnection_Free(&o->sslcon); fail1: ASSERT_FORCE(PR_Close(o->ssl_prfd) == PR_SUCCESS) } fail0a: BConnection_RecvAsync_Free(&o->con); BConnection_SendAsync_Free(&o->con); BConnection_Free(&o->con); fail0: // report error report_error(o); } void pending_handler (ServerConnection *o) { ASSERT(o->state == STATE_WAITINIT) DebugObject_Access(&o->d_obj); // send hello struct sc_client_hello *packet; if (!start_packet(o, (void **)&packet, sizeof(struct sc_client_hello))) { BLog(BLOG_ERROR, "no buffer for hello"); report_error(o); return; } packet->version = htol16(SC_VERSION); end_packet(o, SCID_CLIENTHELLO); } SECStatus client_auth_data_callback (ServerConnection *o, PRFileDesc *fd, CERTDistNames *caNames, CERTCertificate **pRetCert, SECKEYPrivateKey **pRetKey) { ASSERT(o->have_ssl) DebugObject_Access(&o->d_obj); CERTCertificate *newcert; if (!(newcert = CERT_DupCertificate(o->client_cert))) { return SECFailure; } SECKEYPrivateKey *newkey; if (!(newkey = SECKEY_CopyPrivateKey(o->client_key))) { CERT_DestroyCertificate(newcert); return SECFailure; } *pRetCert = newcert; *pRetKey = newkey; return SECSuccess; } void connection_handler (ServerConnection *o, int event) { DebugObject_Access(&o->d_obj); ASSERT(o->state >= STATE_WAITINIT) if (event == BCONNECTION_EVENT_RECVCLOSED) { BLog(BLOG_INFO, "connection closed"); } else { BLog(BLOG_INFO, "connection error"); } report_error(o); return; } void sslcon_handler (ServerConnection *o, int event) { DebugObject_Access(&o->d_obj); ASSERT(o->have_ssl) ASSERT(o->state >= STATE_WAITINIT) ASSERT(event == BSSLCONNECTION_EVENT_ERROR) BLog(BLOG_ERROR, "SSL error"); report_error(o); return; } void decoder_handler_error (ServerConnection *o) { DebugObject_Access(&o->d_obj); ASSERT(o->state >= STATE_WAITINIT) BLog(BLOG_ERROR, "decoder error"); report_error(o); return; } void input_handler_send (ServerConnection *o, uint8_t *data, int data_len) { ASSERT(o->state >= STATE_WAITINIT) ASSERT(data_len >= 0) ASSERT(data_len <= SC_MAX_ENC) DebugObject_Access(&o->d_obj); // accept packet PacketPassInterface_Done(&o->input_interface); // parse header if (data_len < sizeof(struct sc_header)) { BLog(BLOG_ERROR, "packet too short (no sc header)"); report_error(o); return; } struct sc_header *header = (struct sc_header *)data; data += sizeof(*header); data_len -= sizeof(*header); uint8_t type = ltoh8(header->type); // call appropriate handler based on packet type switch (type) { case SCID_SERVERHELLO: packet_hello(o, data, data_len); return; case SCID_NEWCLIENT: packet_newclient(o, data, data_len); return; case SCID_ENDCLIENT: packet_endclient(o, data, data_len); return; case SCID_INMSG: packet_inmsg(o, data, data_len); return; default: BLog(BLOG_ERROR, "unknown packet type %d", (int)type); report_error(o); return; } } void packet_hello (ServerConnection *o, uint8_t *data, int data_len) { if (o->state != STATE_WAITINIT) { BLog(BLOG_ERROR, "hello: not expected"); report_error(o); return; } if (data_len != sizeof(struct sc_server_hello)) { BLog(BLOG_ERROR, "hello: invalid length"); report_error(o); return; } struct sc_server_hello *msg = (struct sc_server_hello *)data; peerid_t id = ltoh16(msg->id); // change state o->state = STATE_COMPLETE; // report o->handler_ready(o->user, id, msg->clientAddr); return; } void packet_newclient (ServerConnection *o, uint8_t *data, int data_len) { if (o->state != STATE_COMPLETE) { BLog(BLOG_ERROR, "newclient: not expected"); report_error(o); return; } if (data_len < sizeof(struct sc_server_newclient) || data_len > sizeof(struct sc_server_newclient) + SCID_NEWCLIENT_MAX_CERT_LEN) { BLog(BLOG_ERROR, "newclient: invalid length"); report_error(o); return; } struct sc_server_newclient *msg = (struct sc_server_newclient *)data; peerid_t id = ltoh16(msg->id); // schedule reporting new client o->newclient_data = data; o->newclient_data_len = data_len; BPending_Set(&o->newclient_job); // send acceptpeer struct sc_client_acceptpeer *packet; if (!start_packet(o, (void **)&packet, sizeof(*packet))) { BLog(BLOG_ERROR, "newclient: out of buffer for acceptpeer"); report_error(o); return; } packet->clientid = htol16(id); end_packet(o, SCID_ACCEPTPEER); } void packet_endclient (ServerConnection *o, uint8_t *data, int data_len) { if (o->state != STATE_COMPLETE) { BLog(BLOG_ERROR, "endclient: not expected"); report_error(o); return; } if (data_len != sizeof(struct sc_server_endclient)) { BLog(BLOG_ERROR, "endclient: invalid length"); report_error(o); return; } struct sc_server_endclient *msg = (struct sc_server_endclient *)data; peerid_t id = ltoh16(msg->id); // report o->handler_endclient(o->user, id); return; } void packet_inmsg (ServerConnection *o, uint8_t *data, int data_len) { if (o->state != STATE_COMPLETE) { BLog(BLOG_ERROR, "inmsg: not expected"); report_error(o); return; } if (data_len < sizeof(struct sc_server_inmsg)) { BLog(BLOG_ERROR, "inmsg: missing header"); report_error(o); return; } if (data_len > sizeof(struct sc_server_inmsg) + SC_MAX_MSGLEN) { BLog(BLOG_ERROR, "inmsg: too long"); report_error(o); return; } struct sc_server_inmsg *msg = (struct sc_server_inmsg *)data; peerid_t peer_id = ltoh16(msg->clientid); uint8_t *payload = data + sizeof(struct sc_server_inmsg); int payload_len = data_len - sizeof(struct sc_server_inmsg); // report o->handler_message(o->user, peer_id, payload, payload_len); return; } int start_packet (ServerConnection *o, void **data, int len) { ASSERT(o->state >= STATE_WAITINIT) ASSERT(o->output_local_packet_len == -1) ASSERT(len >= 0) ASSERT(len <= SC_MAX_PAYLOAD) ASSERT(data || len == 0) // obtain memory location if (!BufferWriter_StartPacket(o->output_local_if, &o->output_local_packet)) { BLog(BLOG_ERROR, "out of buffer"); return 0; } o->output_local_packet_len = len; if (data) { *data = o->output_local_packet + sizeof(struct sc_header); } return 1; } void end_packet (ServerConnection *o, uint8_t type) { ASSERT(o->state >= STATE_WAITINIT) ASSERT(o->output_local_packet_len >= 0) ASSERT(o->output_local_packet_len <= SC_MAX_PAYLOAD) // write header struct sc_header *header = (struct sc_header *)o->output_local_packet; header->type = htol8(type); // finish writing packet BufferWriter_EndPacket(o->output_local_if, sizeof(struct sc_header) + o->output_local_packet_len); o->output_local_packet_len = -1; } int ServerConnection_Init ( ServerConnection *o, BReactor *reactor, BAddr addr, int keepalive_interval, int buffer_size, int have_ssl, CERTCertificate *client_cert, SECKEYPrivateKey *client_key, const char *server_name, void *user, ServerConnection_handler_error handler_error, ServerConnection_handler_ready handler_ready, ServerConnection_handler_newclient handler_newclient, ServerConnection_handler_endclient handler_endclient, ServerConnection_handler_message handler_message ) { ASSERT(keepalive_interval > 0) ASSERT(buffer_size > 0) ASSERT(have_ssl == 0 || have_ssl == 1) // init arguments o->reactor = reactor; o->keepalive_interval = keepalive_interval; o->buffer_size = buffer_size; o->have_ssl = have_ssl; if (have_ssl) { o->client_cert = client_cert; o->client_key = client_key; snprintf(o->server_name, sizeof(o->server_name), "%s", server_name); } o->user = user; o->handler_error = handler_error; o->handler_ready = handler_ready; o->handler_newclient = handler_newclient; o->handler_endclient = handler_endclient; o->handler_message = handler_message; if (!BConnection_AddressSupported(addr)) { BLog(BLOG_ERROR, "BConnection_AddressSupported failed"); goto fail0; } // init connector if (!BConnector_Init(&o->connector, addr, o->reactor, o, (BConnector_handler)connector_handler)) { BLog(BLOG_ERROR, "BConnector_Init failed"); goto fail0; } // init newclient job BPending_Init(&o->newclient_job, BReactor_PendingGroup(o->reactor), (BPending_handler)newclient_job_handler, o); // set state o->state = STATE_CONNECTING; DebugError_Init(&o->d_err, BReactor_PendingGroup(o->reactor)); DebugObject_Init(&o->d_obj); return 1; fail0: return 0; } void ServerConnection_Free (ServerConnection *o) { DebugObject_Free(&o->d_obj); DebugError_Free(&o->d_err); if (o->state > STATE_CONNECTING) { // allow freeing queue flows PacketPassPriorityQueue_PrepareFree(&o->output_queue); // free output user flow PacketPassPriorityQueueFlow_Free(&o->output_user_qflow); // free output local flow PacketProtoFlow_Free(&o->output_local_oflow); PacketPassPriorityQueueFlow_Free(&o->output_local_qflow); // free output common PacketPassPriorityQueue_Free(&o->output_queue); KeepaliveIO_Free(&o->output_keepaliveio); PacketStreamSender_Free(&o->output_sender); // free output keep-alive branch PacketProtoEncoder_Free(&o->output_ka_encoder); SCKeepaliveSource_Free(&o->output_ka_zero); // free job BPending_Free(&o->start_job); // free input chain PacketProtoDecoder_Free(&o->input_decoder); PacketPassInterface_Free(&o->input_interface); // free SSL if (o->have_ssl) { BSSLConnection_Free(&o->sslcon); ASSERT_FORCE(PR_Close(o->ssl_prfd) == PR_SUCCESS) } // free connection interfaces BConnection_RecvAsync_Free(&o->con); BConnection_SendAsync_Free(&o->con); // free connection BConnection_Free(&o->con); } // free newclient job BPending_Free(&o->newclient_job); // free connector BConnector_Free(&o->connector); } PacketPassInterface * ServerConnection_GetSendInterface (ServerConnection *o) { ASSERT(o->state == STATE_COMPLETE) DebugError_AssertNoError(&o->d_err); DebugObject_Access(&o->d_obj); return PacketPassPriorityQueueFlow_GetInput(&o->output_user_qflow); } void newclient_job_handler (ServerConnection *o) { DebugObject_Access(&o->d_obj); ASSERT(o->state == STATE_COMPLETE) struct sc_server_newclient *msg = (struct sc_server_newclient *)o->newclient_data; peerid_t id = ltoh16(msg->id); int flags = ltoh16(msg->flags); uint8_t *cert_data = (uint8_t *)(msg + 1); int cert_len = o->newclient_data_len - sizeof(*msg); // report new client o->handler_newclient(o->user, id, flags, cert_data, cert_len); return; }