ソースを参照

Fixes and refactoring for SOCKS5 UDP.

Ambroz Bizjak 6 年 前
コミット
6241fc2978

+ 2 - 0
compile-tun2socks.sh

@@ -61,6 +61,7 @@ system/BConnection_common.c
 system/BTime.c
 system/BUnixSignal.c
 system/BNetwork.c
+system/BDatagram_common.c
 system/BDatagram_unix.c
 flow/StreamRecvInterface.c
 flow/PacketRecvInterface.c
@@ -109,6 +110,7 @@ base/BPending.c
 flowextra/PacketPassInactivityMonitor.c
 tun2socks/SocksUdpGwClient.c
 udpgw_client/UdpGwClient.c
+socks_udp_client/SocksUdpClient.c
 "
 
 set -e

+ 1 - 1
socks_udp_client/CMakeLists.txt

@@ -1 +1 @@
-badvpn_add_library(socks_udp_client "system;flow;flowextra" "" SocksUdpClient.c)
+badvpn_add_library(socks_udp_client "base;system;flow;flowextra;socksclient" "" SocksUdpClient.c)

+ 301 - 209
socks_udp_client/SocksUdpClient.c

@@ -1,5 +1,6 @@
 /*
  * Copyright (C) 2018 Jigsaw Operations LLC
+ * Copyright (C) 2019 Ambroz Bizjak (modifications)
  * 
  * Redistribution and use in source and binary forms, with or without
  * modification, are permitted provided that the following conditions are met:
@@ -24,6 +25,8 @@
  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  */
 
+#include <stddef.h>
+#include <stdint.h>
 #include <stdlib.h>
 #include <string.h>
 
@@ -31,38 +34,41 @@
 #include <misc/offset.h>
 #include <misc/byteorder.h>
 #include <misc/compare.h>
+#include <misc/socks_proto.h>
+#include <misc/debug.h>
+#include <misc/bsize.h>
 #include <base/BLog.h>
+#include <system/BAddr.h>
 
 #include <socks_udp_client/SocksUdpClient.h>
 
 #include <generated/blog_channel_SocksUdpClient.h>
 
-#define DNS_PORT 53
+static const int DnsPort = 53;
 
 static int addr_comparator (void *unused, BAddr *v1, BAddr *v2);
-static struct SocksUdpClient_connection * find_connection_by_addr (SocksUdpClient *o, BAddr addr);
-static void init_localhost4(uint32_t *ip4);
-static void init_localhost6(uint8_t ip6[16]);
-static void socks_state_handler(struct SocksUdpClient_connection *con, int event);
-static void datagram_state_handler(struct SocksUdpClient_connection *con, int event);
+static struct SocksUdpClient_connection * find_connection (SocksUdpClient *o, BAddr addr);
+static void socks_state_handler (struct SocksUdpClient_connection *con, int event);
+static void datagram_state_handler (struct SocksUdpClient_connection *con, int event);
 static void send_monitor_handler (struct SocksUdpClient_connection *con);
-static void recv_if_handler_send (struct SocksUdpClient_connection *con, uint8_t *data, int data_len);
-static struct SocksUdpClient_connection * connection_init(SocksUdpClient *o, BAddr local_addr,
-                                                          BAddr first_remote_addr,
-                                                          const uint8_t *first_data,
-                                                          int first_data_len);
+static void recv_if_handler_send (
+    struct SocksUdpClient_connection *con, uint8_t *data, int data_len);
+static struct SocksUdpClient_connection * connection_init (
+    SocksUdpClient *o, BAddr local_addr, BAddr first_remote_addr,
+    const uint8_t *first_data, int first_data_len);
 static void connection_free (struct SocksUdpClient_connection *con);
-static void connection_send (struct SocksUdpClient_connection *con, BAddr remote_addr, const uint8_t *data, int data_len);
-static void first_job_handler(struct SocksUdpClient_connection *con);
-static int compute_mtu(int udp_mtu);
-static int get_dns_id(BAddr *remote_addr, const uint8_t *data, int data_len);
+static void connection_send (struct SocksUdpClient_connection *con,
+    BAddr remote_addr, const uint8_t *data, int data_len);
+static void first_job_handler (struct SocksUdpClient_connection *con);
+static int compute_socks_mtu (int udp_mtu);
+static int get_dns_id (BAddr *remote_addr, const uint8_t *data, int data_len);
 
 int addr_comparator (void *unused, BAddr *v1, BAddr *v2)
 {
     return BAddr_CompareOrder(v1, v2);
 }
 
-struct SocksUdpClient_connection * find_connection_by_addr (SocksUdpClient *o, BAddr addr)
+struct SocksUdpClient_connection * find_connection (SocksUdpClient *o, BAddr addr)
 {
     BAVLNode *tree_node = BAVL_LookupExact(&o->connections_tree, &addr);
     if (!tree_node) {
@@ -72,195 +78,229 @@ struct SocksUdpClient_connection * find_connection_by_addr (SocksUdpClient *o, B
     return UPPER_OBJECT(tree_node, struct SocksUdpClient_connection, connections_tree_node);
 }
 
-void init_localhost4(uint32_t *ip4)
+void socks_state_handler (struct SocksUdpClient_connection *con, int event)
 {
-    *ip4 = 1<<24 | 127;
-}
-
-void init_localhost6(uint8_t ip6[16])
-{
-    memset(ip6, 0, 16);
-    ip6[15] = 1;
-}
+    DebugObject_Access(&con->client->d_obj);
 
-void socks_state_handler(struct SocksUdpClient_connection *con, int event)
-{
     switch (event) {
         case BSOCKSCLIENT_EVENT_UP: {
+            // Figure out the localhost address.
             BIPAddr localhost;
-            localhost.type = con->client->server_addr.type;
-            if (localhost.type == BADDR_TYPE_IPV4) {
-                init_localhost4(&localhost.ipv4);
-            } else if (localhost.type == BADDR_TYPE_IPV6) {
-                init_localhost6(localhost.ipv6);
-            } else {
-                BLog(BLOG_ERROR, "Bad address type");
-            }
-            // This will unblock the queue of pending packets.
-            BDatagram_SetSendAddrs(&con->socket, con->socks.bind_addr, localhost);
+            BIPAddr_InitLocalhost(&localhost, con->client->server_addr.type);
+
+            // Get the address to send datagrams to from BSocksClient.
+            BAddr remote_addr = BSocksClient_GetBindAddr(&con->socks);
+
+            // Set the local/remote send address for BDatagram.
+            // This will unblock the queue of outgoing packets.
+            BDatagram_SetSendAddrs(&con->socket, remote_addr, localhost);
         } break;
+
         case BSOCKSCLIENT_EVENT_ERROR: {
-            BLog(BLOG_ERROR, "Socks error event");
-        } // Fallthrough
+            char local_buffer[BADDR_MAX_PRINT_LEN];
+            BAddr_Print(&con->local_addr, local_buffer);
+            BLog(BLOG_ERROR,
+                "SOCKS error event for %s, removing connection.", local_buffer);
+
+            connection_free(con);
+        } break;
+
         case BSOCKSCLIENT_EVENT_ERROR_CLOSED: {
+            char local_buffer[BADDR_MAX_PRINT_LEN];
+            BAddr_Print(&con->local_addr, local_buffer);
+            BLog(BLOG_WARNING,
+                "SOCKS closed event for %s, removing connection.", local_buffer);
+
             connection_free(con);
         } break;
+
         default: {
-            BLog(BLOG_ERROR, "Unknown event");
-        }
+            BLog(BLOG_ERROR, "Unknown SOCKS event");
+        } break;
     }
 }
 
-void datagram_state_handler(struct SocksUdpClient_connection *con, int event)
+void datagram_state_handler (struct SocksUdpClient_connection *con, int event)
 {
+    DebugObject_Access(&con->client->d_obj);
+
     if (event == BDATAGRAM_EVENT_ERROR) {
         char local_buffer[BADDR_MAX_PRINT_LEN];
         BAddr_Print(&con->local_addr, local_buffer);
-        BLog(BLOG_ERROR, "Failing connection for %s due to a datagram send error", local_buffer);
+        BLog(BLOG_ERROR,
+            "Low-level datagram error %s, removing connection.", local_buffer);
+
+        // Remove the connection. Note that BDatagram requires that we free
+        // the BDatagram after an error is reported.
         connection_free(con);
     }
 }
 
 void send_monitor_handler (struct SocksUdpClient_connection *con)
 {
-    // The connection has passed its idle timeout.  Remove it.
+    DebugObject_Access(&con->client->d_obj);
+    
+    char local_buffer[BADDR_MAX_PRINT_LEN];
+    BAddr_Print(&con->local_addr, local_buffer);
+    BLog(BLOG_INFO,
+        "Removing connection for %s due to inactivity.", local_buffer);
+
+    // The connection has passed its idle timeout. Remove it.
     connection_free(con);
 }
 
-void recv_if_handler_send(struct SocksUdpClient_connection *con, uint8_t *data, int data_len)
+void recv_if_handler_send (
+    struct SocksUdpClient_connection *con, uint8_t *data, int data_len)
 {
-    SocksUdpClient *o = con->client;
     DebugObject_Access(&con->client->d_obj);
+    SocksUdpClient *o = con->client;
     ASSERT(data_len >= 0)
-    ASSERT(data_len <= compute_mtu(o->udp_mtu))
+    ASSERT(data_len <= o->socks_mtu)
     
     // accept packet
     PacketPassInterface_Done(&con->recv_if);
     
     // check header
-    if (data_len < sizeof(struct socks_udp_header)) {
-        BLog(BLOG_ERROR, "missing header");
+    struct socks_udp_header header;
+    if (data_len < sizeof(header)) {
+        BLog(BLOG_ERROR, "Missing SOCKS-UDP header.");
         return;
     }
-    struct socks_udp_header *header = (struct socks_udp_header *)data;
-    uint8_t *addr_data = data + sizeof(struct socks_udp_header);
+    memcpy(&header, data, sizeof(header));
+    data += sizeof(header);
+    data_len -= sizeof(header);
     
     // parse address
     BAddr remote_addr;
-    size_t addr_size;
-    switch (header->atyp) {
+    switch (header.atyp) {
         case SOCKS_ATYP_IPV4: {
+            struct socks_addr_ipv4 addr_ipv4;
+            if (data_len < sizeof(addr_ipv4)) {
+                BLog(BLOG_ERROR, "Missing IPv4 address.");
+                return;
+            }
+            memcpy(&addr_ipv4, data, sizeof(addr_ipv4));
+            data += sizeof(addr_ipv4);
+            data_len -= sizeof(addr_ipv4);
             remote_addr.type = BADDR_TYPE_IPV4;
-            struct socks_addr_ipv4 *addr_ipv4 = (struct socks_addr_ipv4 *)addr_data;
-            remote_addr.ipv4.ip = addr_ipv4->addr;
-            remote_addr.ipv4.port = addr_ipv4->port;
-            addr_size = sizeof(*addr_ipv4);
+            remote_addr.ipv4.ip = addr_ipv4.addr;
+            remote_addr.ipv4.port = addr_ipv4.port;
         } break;
         case SOCKS_ATYP_IPV6: {
+            struct socks_addr_ipv6 addr_ipv6;
+            if (data_len < sizeof(addr_ipv6)) {
+                BLog(BLOG_ERROR, "Missing IPv6 address.");
+                return;
+            }
+            memcpy(&addr_ipv6, data, sizeof(addr_ipv6));
+            data += sizeof(addr_ipv6);
+            data_len -= sizeof(addr_ipv6);
             remote_addr.type = BADDR_TYPE_IPV6;
-            struct socks_addr_ipv6 *addr_ipv6 = (struct socks_addr_ipv6 *)addr_data;
-            memcpy(remote_addr.ipv6.ip, addr_ipv6->addr, sizeof(remote_addr.ipv6.ip));
-            remote_addr.ipv6.port = addr_ipv6->port;
-            addr_size = sizeof(*addr_ipv6);
+            memcpy(remote_addr.ipv6.ip, addr_ipv6.addr, sizeof(remote_addr.ipv6.ip));
+            remote_addr.ipv6.port = addr_ipv6.port;
         } break;
         default: {
             BLog(BLOG_ERROR, "Bad address type");
             return;
-        }
+        } break;
     }
     
-    uint8_t *body_data = addr_data + addr_size;
-    size_t body_len = data_len - (body_data - data);
-    
     // check remaining data
-    if (body_len > o->udp_mtu) {
+    if (data_len > o->udp_mtu) {
         BLog(BLOG_ERROR, "too much data");
         return;
     }
     
     // pass packet to user
     SocksUdpClient *client = con->client;
-    client->handler_received(client->user, con->local_addr, remote_addr, body_data, body_len);
+    client->handler_received(client->user, con->local_addr, remote_addr, data, data_len);
 
+    // Was this connection used for a DNS query?
     if (con->dns_id >= 0) {
-        // This connection has only been used for a single DNS query.
-        int recv_dns_id = get_dns_id(&remote_addr, body_data, body_len);
+        // Get the DNS transaction ID of the response.
+        int recv_dns_id = get_dns_id(&remote_addr, data, data_len);
+
+        // Does the transaction ID matche that of the request?
         if (recv_dns_id == con->dns_id) {
             // We have now forwarded the response, so this connection is no longer needed.
+            char local_buffer[BADDR_MAX_PRINT_LEN];
+            BAddr_Print(&con->local_addr, local_buffer);
+            BLog(BLOG_DEBUG,
+                "Removing connection for %s after the DNS response.", local_buffer);
+
             connection_free(con);
         } else {
-            BLog(BLOG_INFO,
-                 "DNS client port received an unexpected non-DNS packet.  "
-                 "Disabling DNS optimization.");
+            BLog(BLOG_INFO, "DNS client port received an unexpected non-DNS packet, "
+                "disabling DNS optimization.");
+            
             con->dns_id = -1;
         }
     }
 }
 
-struct SocksUdpClient_connection *connection_init(SocksUdpClient *o, BAddr local_addr,
-                                                  BAddr first_remote_addr,
-                                                  const uint8_t *first_data,
-                                                  int first_data_len)
+struct SocksUdpClient_connection * connection_init (
+    SocksUdpClient *o, BAddr local_addr, BAddr first_remote_addr,
+    const uint8_t *first_data, int first_data_len)
 {
-    DebugObject_Access(&o->d_obj);
     ASSERT(o->num_connections <= o->max_connections)
-    ASSERT(!find_connection_by_addr(o, local_addr))
+    ASSERT(!find_connection(o, local_addr))
     
-    char buffer[BADDR_MAX_PRINT_LEN];
-    BAddr_Print(&local_addr, buffer);
-    BLog(BLOG_DEBUG, "Creating new connection for %s", buffer);
+    char local_buffer[BADDR_MAX_PRINT_LEN];
+    BAddr_Print(&local_addr, local_buffer);
+    BLog(BLOG_DEBUG, "Creating connection for %s.", local_buffer);
     
     // allocate structure
-    struct SocksUdpClient_connection *con = (struct SocksUdpClient_connection *)malloc(sizeof(*con));
+    struct SocksUdpClient_connection *con =
+        (struct SocksUdpClient_connection *)BAlloc(sizeof(*con));
     if (!con) {
-        BLog(BLOG_ERROR, "malloc failed");
+        BLog(BLOG_ERROR, "BAlloc connection failed");
         goto fail0;
     }
     
-    // init arguments
+    // set basic things
     con->client = o;
     con->local_addr = local_addr;
+
+    // store first outgoing packet
     con->first_data = BAlloc(first_data_len);
+    if (!con->first_data) {
+        BLog(BLOG_ERROR, "BAlloc first data failed");
+        goto fail1;
+    }
+    memcpy(con->first_data, first_data, first_data_len);
     con->first_data_len = first_data_len;
     con->first_remote_addr = first_remote_addr;
-    memcpy(con->first_data, first_data, first_data_len);
     
+    // Get the DNS transaction ID from the packet, if any.
     con->dns_id = get_dns_id(&first_remote_addr, first_data, first_data_len);
     
     BPendingGroup *pg = BReactor_PendingGroup(o->reactor);
     
-    // init first job, to send the first packet asynchronously.  This has to happen asynchronously
-    // because con->send_writer (a BufferWriter) cannot accept writes until after it is linked with
-    // its PacketBuffer (con->send_buffer), which happens asynchronously.
+    // Init first job, to send the first packet asynchronously. This has to happen
+    // asynchronously because con->send_writer (a BufferWriter) cannot accept writes until
+    // after it is linked with its PacketBuffer (con->send_buffer), which happens
+    // asynchronously.
     BPending_Init(&con->first_job, pg, (BPending_handler)first_job_handler, con);
-    // Add the first job to the pending set.  BPending acts as a LIFO stack, and first_job_handler
-    // needs to run after async actions that occur in PacketBuffer_Init, so we need to put first_job
-    // on the stack first.
+    // Add the first job to the pending set. BPending acts as a LIFO stack, and
+    // first_job_handler needs to run after async actions that occur in PacketBuffer_Init,
+    // so we need to put first_job on the stack first.
     BPending_Set(&con->first_job);
     
     // Create a datagram socket
     if (!BDatagram_Init(&con->socket, con->local_addr.type, o->reactor, con,
-                        (BDatagram_handler)datagram_state_handler)) {
+                        (BDatagram_handler)datagram_state_handler))
+    {
         BLog(BLOG_ERROR, "Failed to create a UDP socket");
-        goto fail1;
-    }
-    
-    // Bind to 127.0.0.1:0 (or [::1]:0).  Port 0 signals the kernel to choose an open port.
-    BAddr socket_addr;
-    socket_addr.type = local_addr.type;
-    if (local_addr.type == BADDR_TYPE_IPV4) {
-        init_localhost4(&socket_addr.ipv4.ip);
-        socket_addr.ipv4.port = 0;
-    } else if (local_addr.type == BADDR_TYPE_IPV6) {
-        init_localhost6(socket_addr.ipv6.ip);
-        socket_addr.ipv6.port = 0;
-    } else {
-        BLog(BLOG_ERROR, "Unknown local address type");
         goto fail2;
     }
+    
+    // Bind to localhost, port 0 signals the kernel to choose an open port.
+    BIPAddr localhost;
+    BIPAddr_InitLocalhost(&localhost, local_addr.type);
+    BAddr socket_addr = BAddr_MakeFromIpaddrAndPort(localhost, 0);
     if (!BDatagram_Bind(&con->socket, socket_addr)) {
         BLog(BLOG_ERROR, "Bind to localhost failed");
-        goto fail2;
+        goto fail3;
     }
     
     // Bind succeeded, so the kernel has found an open port.
@@ -268,7 +308,7 @@ struct SocksUdpClient_connection *connection_init(SocksUdpClient *o, BAddr local
     uint16_t port;
     if (!BDatagram_GetLocalPort(&con->socket, &port)) {
         BLog(BLOG_ERROR, "Failed to get bound port");
-        goto fail2;
+        goto fail3;
     }
     if (socket_addr.type == BADDR_TYPE_IPV4) {
         socket_addr.ipv4.port = port;
@@ -277,60 +317,66 @@ struct SocksUdpClient_connection *connection_init(SocksUdpClient *o, BAddr local
     }
     
     // Initiate connection to socks server
-    if (!BSocksClient_Init(&con->socks, o->server_addr, o->auth_info, o->num_auth_info, socket_addr,
-                           true, (BSocksClient_handler)socks_state_handler, con, o->reactor)) {
+    if (!BSocksClient_Init(&con->socks, o->server_addr, o->auth_info, o->num_auth_info,
+        socket_addr, true, (BSocksClient_handler)socks_state_handler, con, o->reactor))
+    {
         BLog(BLOG_ERROR, "Failed to initialize SOCKS client");
-        goto fail2;
+        goto fail3;
     }
     
-    // Ensure that the UDP handling pipeline can handle queries big enough to include
-    // all data plus the SOCKS-UDP header.
-    int socks_mtu = compute_mtu(o->udp_mtu);
-    
+    // Since we use o->socks_mtu for send and receive pipelines, we can handle maximally
+    // sized packets (o->udp_mtu) including the SOCKS-UDP header.
+
     // Send pipeline: send_writer -> send_buffer -> send_monitor -> send_if -> socket.
-    BDatagram_SendAsync_Init(&con->socket, socks_mtu);
-    PacketPassInterface *send_if = BDatagram_SendAsync_GetIf(&con->socket);
-    PacketPassInactivityMonitor_Init(&con->send_monitor, send_if, o->reactor, o->keepalive_time,
-                                     (PacketPassInactivityMonitor_handler)send_monitor_handler, con);
-    BufferWriter_Init(&con->send_writer, compute_mtu(o->udp_mtu), pg);
+    BDatagram_SendAsync_Init(&con->socket, o->socks_mtu);
+    PacketPassInactivityMonitor_Init(&con->send_monitor,
+        BDatagram_SendAsync_GetIf(&con->socket), o->reactor, o->keepalive_time,
+        (PacketPassInactivityMonitor_handler)send_monitor_handler, con);
+    BufferWriter_Init(&con->send_writer, o->socks_mtu, pg);
     if (!PacketBuffer_Init(&con->send_buffer, BufferWriter_GetOutput(&con->send_writer),
-                           PacketPassInactivityMonitor_GetInput(&con->send_monitor),
-                           SOCKS_UDP_SEND_BUFFER_PACKETS, pg)) {
+        PacketPassInactivityMonitor_GetInput(&con->send_monitor), o->send_buf_size, pg))
+    {
         BLog(BLOG_ERROR, "Send buffer init failed");
-        goto fail3;
+        goto fail4;
     }
     
     // Receive pipeline: socket -> recv_buffer -> recv_if
-    BDatagram_RecvAsync_Init(&con->socket, socks_mtu);
-    PacketPassInterface_Init(&con->recv_if, socks_mtu,
-                            (PacketPassInterface_handler_send)recv_if_handler_send, con, pg);
-    if (!SinglePacketBuffer_Init(&con->recv_buffer, BDatagram_RecvAsync_GetIf(&con->socket),
-                                &con->recv_if, pg)) {
+    BDatagram_RecvAsync_Init(&con->socket, o->socks_mtu);
+    PacketPassInterface_Init(&con->recv_if, o->socks_mtu,
+        (PacketPassInterface_handler_send)recv_if_handler_send, con, pg);
+    if (!SinglePacketBuffer_Init(&con->recv_buffer,
+        BDatagram_RecvAsync_GetIf(&con->socket), &con->recv_if, pg))
+    {
         BLog(BLOG_ERROR, "Receive buffer init failed");
-        goto fail4;
+        goto fail5;
     }
     
-    // insert to connections tree
-    ASSERT_EXECUTE(BAVL_Insert(&o->connections_tree, &con->connections_tree_node, NULL))
+    // Insert to connections tree, it must succeed because of the assert.
+    int inserted = BAVL_Insert(&o->connections_tree, &con->connections_tree_node, NULL);
+    ASSERT(inserted)
+    B_USE(inserted)
     
+    // increment number of connections
     o->num_connections++;
     
     return con;
     
-fail4:
+fail5:
     PacketPassInterface_Free(&con->recv_if);
     BDatagram_RecvAsync_Free(&con->socket);
     PacketBuffer_Free(&con->send_buffer);
-fail3:
+fail4:
     BufferWriter_Free(&con->send_writer);
     PacketPassInactivityMonitor_Free(&con->send_monitor);
     BDatagram_SendAsync_Free(&con->socket);
-fail2:
+    BSocksClient_Free(&con->socks);
+fail3:
     BDatagram_Free(&con->socket);
-fail1:
+fail2:
     BPending_Free(&con->first_job);
     BFree(con->first_data);
-    free(con);
+fail1:
+    BFree(con);
 fail0:
     return NULL;
 }
@@ -338,44 +384,44 @@ fail0:
 void connection_free (struct SocksUdpClient_connection *con)
 {
     SocksUdpClient *o = con->client;
-    DebugObject_Access(&o->d_obj);
     
     // decrement number of connections
+    ASSERT(o->num_connections > 0)
     o->num_connections--;
     
     // remove from connections tree
     BAVL_Remove(&o->connections_tree, &con->connections_tree_node);
     
+    // Free UDP receive pipeline components
+    SinglePacketBuffer_Free(&con->recv_buffer);
+    PacketPassInterface_Free(&con->recv_if);
+    BDatagram_RecvAsync_Free(&con->socket);
+    
     // Free UDP send pipeline components
     PacketBuffer_Free(&con->send_buffer);
     BufferWriter_Free(&con->send_writer);
     PacketPassInactivityMonitor_Free(&con->send_monitor);
     BDatagram_SendAsync_Free(&con->socket);
     
-    // Free UDP receive pipeline components
-    SinglePacketBuffer_Free(&con->recv_buffer);
-    PacketPassInterface_Free(&con->recv_if);
-    BDatagram_RecvAsync_Free(&con->socket);
+    // Free SOCKS client
+    BSocksClient_Free(&con->socks);
     
     // Free UDP socket
     BDatagram_Free(&con->socket);
     
-    // Free SOCKS client
-    BSocksClient_Free(&con->socks);
-    
+    // Free first job
     BPending_Free(&con->first_job);
-    if (con->first_data) {
-      BFree(con->first_data);
-    }
-    // free structure
-    free(con);
+
+    // Free first outgoing packet
+    BFree(con->first_data);
+
+    // Free structure
+    BFree(con);
 }
 
-void connection_send (struct SocksUdpClient_connection *con, BAddr remote_addr,
-                      const uint8_t *data, int data_len)
+void connection_send (struct SocksUdpClient_connection *con,
+    BAddr remote_addr, const uint8_t *data, int data_len)
 {
-    SocksUdpClient *o = con->client;
-    DebugObject_Access(&o->d_obj);
     ASSERT(data_len >= 0)
     ASSERT(data_len <= o->udp_mtu)
     
@@ -383,7 +429,7 @@ void connection_send (struct SocksUdpClient_connection *con, BAddr remote_addr,
         // So far, this connection has only sent a single DNS query.
         int new_dns_id = get_dns_id(&remote_addr, data, data_len);
         if (new_dns_id != con->dns_id) {
-            BLog(BLOG_DEBUG, "Client reused DNS query port.  Disabling DNS optimization.");
+            BLog(BLOG_DEBUG, "Client reused DNS query port. Disabling DNS optimization.");
             con->dns_id = -1;
         }
     }
@@ -402,128 +448,174 @@ void connection_send (struct SocksUdpClient_connection *con, BAddr remote_addr,
             address_size = sizeof(struct socks_addr_ipv6);
         } break;
         default: {
-          BLog(BLOG_ERROR, "bad address type");
-          return;
-        }
+            BLog(BLOG_ERROR, "Bad address type in outgoing packet.");
+            return;
+        } break;
     }
     
-    // Wrap the payload in a UDP SOCKS header.
-    size_t socks_data_len = sizeof(struct socks_udp_header) + address_size + data_len;
-    if (socks_data_len > compute_mtu(o->udp_mtu)) {
-        BLog(BLOG_ERROR, "Packet is too big: %d > %d", socks_data_len, compute_mtu(o->udp_mtu));
-        return;
-    }
-    uint8_t *socks_data;
-    if (!BufferWriter_StartPacket(&con->send_writer, &socks_data)) {
-        BLog(BLOG_ERROR, "Send buffer is full");
+    // Determine total packet size in the buffer.
+    // This cannot exceed o->socks_mtu because data_len is required to not exceed
+    // o->udp_mtu and o->socks_mtu is calculated to accomodate any UDP packet not
+    // not exceeding o->udp_mtu.
+    size_t total_len = sizeof(struct socks_udp_header) + address_size + data_len;
+    ASSERT(total_len <= o->socks_mtu)
+
+    // Get a pointer to write the packet to.
+    uint8_t *out_data_begin;
+    if (!BufferWriter_StartPacket(&con->send_writer, &out_data_begin)) {
+        BLog(BLOG_ERROR, "Send buffer is full.");
         return;
     }
+    uint8_t *out_data = out_data_begin;
+
     // Write header
-    struct socks_udp_header *header = (struct socks_udp_header *)socks_data;
-    header->rsv = 0;
-    header->frag = 0;
-    header->atyp = atyp;
-    uint8_t *addr_data = socks_data + sizeof(struct socks_udp_header);
+    struct socks_udp_header header;
+    header.rsv = 0;
+    header.frag = 0;
+    header.atyp = atyp;
+    memcpy(out_data, &header, sizeof(header));
+    out_data += sizeof(header);
+
+    // Write address
     switch (atyp) {
         case SOCKS_ATYP_IPV4: {
-            struct socks_addr_ipv4 *addr_ipv4 = (struct socks_addr_ipv4 *)addr_data;
-            addr_ipv4->addr = remote_addr.ipv4.ip;
-            addr_ipv4->port = remote_addr.ipv4.port;
+            struct socks_addr_ipv4 addr_ipv4;
+            addr_ipv4.addr = remote_addr.ipv4.ip;
+            addr_ipv4.port = remote_addr.ipv4.port;
+            memcpy(out_data, &addr_ipv4, sizeof(addr_ipv4));
+            out_data += sizeof(addr_ipv4);
         } break;
         case SOCKS_ATYP_IPV6: {
-            struct socks_addr_ipv6 *addr_ipv6 = (struct socks_addr_ipv6 *)addr_data;
-            memcpy(addr_ipv6->addr, remote_addr.ipv6.ip, sizeof(addr_ipv6->addr));
-            addr_ipv6->port = remote_addr.ipv6.port;
+            struct socks_addr_ipv6 addr_ipv6;
+            memcpy(addr_ipv6.addr, remote_addr.ipv6.ip, sizeof(addr_ipv6.addr));
+            addr_ipv6.port = remote_addr.ipv6.port;
+            memcpy(out_data, &addr_ipv6, sizeof(addr_ipv6));
+            out_data += sizeof(addr_ipv6);
         } break;
     }
-    // write packet to buffer
-    memcpy(addr_data + address_size, data, data_len);
-    BufferWriter_EndPacket(&con->send_writer, socks_data_len);
+
+    // Write payload
+    memcpy(out_data, data, data_len);
+    out_data += data_len;
+
+    ASSERT(out_data - out_data_begin == total_len)
+
+    // Submit packet to buffer
+    BufferWriter_EndPacket(&con->send_writer, total_len);
 }
 
-void first_job_handler(struct SocksUdpClient_connection *con)
+void first_job_handler (struct SocksUdpClient_connection *con)
 {
+    DebugObject_Access(&con->client->d_obj);
+    ASSERT(con->first_data)
+    
+    // Send the first packet.
     connection_send(con, con->first_remote_addr, con->first_data, con->first_data_len);
+
+    // Release the first packet buffer.
     BFree(con->first_data);
     con->first_data = NULL;
     con->first_data_len = 0;
 }
 
-int compute_mtu(int udp_mtu)
+int compute_socks_mtu (int udp_mtu)
 {
-    return udp_mtu + sizeof(struct socks_udp_header) + sizeof(struct socks_addr_ipv6);
+    bsize_t bs = bsize_add(
+        bsize_fromint(udp_mtu),
+        bsize_add(
+            bsize_fromsize(sizeof(struct socks_udp_header)),
+            bsize_fromsize(sizeof(struct socks_addr_ipv6))
+        )
+    );
+    int s;
+    return bsize_toint(bs, &s) ? s : -1;
 }
 
-int get_dns_id(BAddr *remote_addr, const uint8_t *data, int data_len)
+// Get the DNS transaction ID, or -1 if this does not look like a DNS packet.
+int get_dns_id (BAddr *remote_addr, const uint8_t *data, int data_len)
 {
-    if (BAddr_GetPort(remote_addr) == htons(DNS_PORT) && data_len >= 2) {
-        return (data[0] << 8) + data[1];
+    if (ntoh16(BAddr_GetPort(remote_addr)) == DnsPort && data_len >= 2) {
+        return (data[0] << 8) | data[1];
+    } else {
+        return -1;
     }
-    return -1;
 }
 
-void SocksUdpClient_Init (SocksUdpClient *o, int udp_mtu, int max_connections, btime_t keepalive_time,
-                          BAddr server_addr, const struct BSocksClient_auth_info *auth_info, size_t num_auth_info,
-                          BReactor *reactor, void *user,
-                          SocksUdpClient_handler_received handler_received)
+int SocksUdpClient_Init (SocksUdpClient *o, int udp_mtu, int max_connections,
+    int send_buf_size, btime_t keepalive_time, BAddr server_addr,
+    const struct BSocksClient_auth_info *auth_info, size_t num_auth_info,
+    BReactor *reactor, void *user, SocksUdpClient_handler_received handler_received)
 {
     ASSERT(udp_mtu >= 0)
-    ASSERT(compute_mtu(udp_mtu) >= 0)
     ASSERT(max_connections > 0)
+    ASSERT(send_buf_size > 0)
     
-    // init arguments
+    // init simple things
     o->server_addr = server_addr;
     o->auth_info = auth_info;
     o->num_auth_info = num_auth_info;
-    o->udp_mtu = udp_mtu;
-    o->max_connections = max_connections;
     o->num_connections = 0;
+    o->max_connections = max_connections;
+    o->send_buf_size = send_buf_size;
+    o->udp_mtu = udp_mtu;
     o->keepalive_time = keepalive_time;
     o->reactor = reactor;
     o->user = user;
     o->handler_received = handler_received;
-    
-    // limit max connections to number of conid's
-    if (o->max_connections > UINT16_MAX + 1) {
-        o->max_connections = UINT16_MAX + 1;
+
+    // calculate full MTU with SOCKS header
+    o->socks_mtu = compute_socks_mtu(udp_mtu);
+    if (o->socks_mtu < 0) {
+        BLog(BLOG_ERROR, "SocksUdpClient_Init: MTU too large.");
+        goto fail0;
     }
     
     // init connections tree
-    BAVL_Init(&o->connections_tree, OFFSET_DIFF(struct SocksUdpClient_connection, local_addr, connections_tree_node), (BAVL_comparator)addr_comparator, NULL);
+    BAVL_Init(&o->connections_tree,
+        OFFSET_DIFF(struct SocksUdpClient_connection, local_addr, connections_tree_node),
+        (BAVL_comparator)addr_comparator, NULL);
     
     DebugObject_Init(&o->d_obj);
+    return 1;
+
+fail0:
+    return 0;
 }
 
 void SocksUdpClient_Free (SocksUdpClient *o)
 {
+    DebugObject_Free(&o->d_obj);
+
     // free connections
     while (!BAVL_IsEmpty(&o->connections_tree)) {
-        struct SocksUdpClient_connection *con = UPPER_OBJECT(BAVL_GetFirst(&o->connections_tree), struct SocksUdpClient_connection, connections_tree_node);
+        BAVLNode *node = BAVL_GetFirst(&o->connections_tree);
+        struct SocksUdpClient_connection *con =
+            UPPER_OBJECT(node, struct SocksUdpClient_connection, connections_tree_node);
         connection_free(con);
     }
-
-    DebugObject_Free(&o->d_obj);
 }
 
-void SocksUdpClient_SubmitPacket (SocksUdpClient *o, BAddr local_addr, BAddr remote_addr, const uint8_t *data, int data_len)
+void SocksUdpClient_SubmitPacket (SocksUdpClient *o,
+    BAddr local_addr, BAddr remote_addr, const uint8_t *data, int data_len)
 {
     DebugObject_Access(&o->d_obj);
     ASSERT(local_addr.type == BADDR_TYPE_IPV4 || local_addr.type == BADDR_TYPE_IPV6)
     ASSERT(remote_addr.type == BADDR_TYPE_IPV4 || remote_addr.type == BADDR_TYPE_IPV6)
     ASSERT(data_len >= 0)
+    ASSERT(data_len <= o->udp_mtu)
     
     // lookup connection
-    struct SocksUdpClient_connection *con = find_connection_by_addr(o, local_addr);
+    struct SocksUdpClient_connection *con = find_connection(o, local_addr);
     if (!con) {
-        if (o->num_connections == o->max_connections) {
+        if (o->num_connections >= o->max_connections) {
             // Drop the packet.
-            BLog(BLOG_ERROR, "Dropping UDP packet, reached max number of connections.");
+            BLog(BLOG_WARNING, "Dropping UDP packet, reached max number of connections.");
             return;
         }
         // create new connection and enqueue the packet
         connection_init(o, local_addr, remote_addr, data, data_len);
     } else {
-      // send packet
-      connection_send(con, remote_addr, data, data_len);
+        // send packet
+        connection_send(con, remote_addr, data, data_len);
     }
 }

+ 40 - 25
socks_udp_client/SocksUdpClient.h

@@ -1,5 +1,6 @@
 /*
  * Copyright (C) 2018 Jigsaw Operations LLC
+ * Copyright (C) 2019 Ambroz Bizjak (modifications)
  * 
  * Redistribution and use in source and binary forms, with or without
  * modification, are permitted provided that the following conditions are met:
@@ -27,6 +28,7 @@
 #ifndef BADVPN_SOCKS_UDP_CLIENT_SOCKSUDPCLIENT_H
 #define BADVPN_SOCKS_UDP_CLIENT_SOCKSUDPCLIENT_H
 
+#include <stddef.h>
 #include <stdint.h>
 
 #include <base/BPending.h>
@@ -34,9 +36,8 @@
 #include <flow/BufferWriter.h>
 #include <flow/PacketBuffer.h>
 #include <flow/SinglePacketBuffer.h>
+#include <flow/PacketPassInterface.h>
 #include <flowextra/PacketPassInactivityMonitor.h>
-#include <misc/debug.h>
-#include <misc/socks_proto.h>
 #include <socksclient/BSocksClient.h>
 #include <structure/BAVL.h>
 #include <system/BAddr.h>
@@ -44,12 +45,8 @@
 #include <system/BReactor.h>
 #include <system/BTime.h>
 
-// This sets the number of packets to accept while waiting for SOCKS server to authenticate and
-// connect.  A slow or far-away SOCKS server could require 300 ms to connect, and a chatty
-// client (e.g. STUN) could send a packet every 20 ms, so a limit of 16 seems reasonable.
-#define SOCKS_UDP_SEND_BUFFER_PACKETS 16
-
-typedef void (*SocksUdpClient_handler_received) (void *user, BAddr local_addr, BAddr remote_addr, const uint8_t *data, int data_len);
+typedef void (*SocksUdpClient_handler_received) (
+    void *user, BAddr local_addr, BAddr remote_addr, const uint8_t *data, int data_len);
 
 typedef struct {
     BAddr server_addr;
@@ -57,7 +54,9 @@ typedef struct {
     size_t num_auth_info;
     int num_connections;
     int max_connections;
+    int send_buf_size;
     int udp_mtu;
+    int socks_mtu;
     btime_t keepalive_time;
     BReactor *reactor;
     void *user;
@@ -77,8 +76,8 @@ struct SocksUdpClient_connection {
     BDatagram socket;
     PacketPassInterface recv_if;
     SinglePacketBuffer recv_buffer;
-    // The first_* members represent the initial packet, which has to be stored so it can wait for
-    // send_writer to become ready.
+    // The first_* members represent the initial packet, which has to be stored so it can
+    // wait for send_writer to become ready.
     uint8_t *first_data;
     int first_data_len;
     BAddr first_remote_addr;
@@ -92,43 +91,59 @@ struct SocksUdpClient_connection {
 
 /**
  * Initializes the SOCKS5-UDP client object.
- * This function does not perform network access, so it will always succeed if the arguments
- * are valid.
+ * 
+ * This function only initialzies the object and does not perform network access.
  * 
  * Currently, this function only supports connection to a SOCKS5 server that is routable from
- * localhost (i.e. running on the local machine).  It may be possible to add support for remote
- * servers, but SOCKS5 does not support UDP if there is a NAT or firewall between the client
- * and the proxy.
+ * localhost (i.e. running on the local machine). It may be possible to add support for
+ * remote servers, but SOCKS5 does not support UDP if there is a NAT or firewall between the
+ * client and the proxy.
  * 
  * @param o the object
  * @param udp_mtu the maximum size of packets that will be sent through the tunnel
  * @param max_connections how many local ports to track before dropping packets
+ * @param send_buf_size maximum number of buffered outgoing packets per connection
  * @param keepalive_time how long to track an idle local port before forgetting it
  * @param server_addr SOCKS5 server address.  MUST BE ON LOCALHOST.
+ * @param auth_info List of authentication info for BSocksClient. The pointer must remain
+ *        valid while this object exists, the data is not copied.
+ * @param num_auth_info Number of the above.
  * @param reactor reactor we live in
  * @param user value passed to handler
  * @param handler_received handler for incoming UDP packets
+ * @return 1 on success, 0 on failure
+ */
+int SocksUdpClient_Init (SocksUdpClient *o, int udp_mtu, int max_connections,
+    int send_buf_size, btime_t keepalive_time, BAddr server_addr,
+    const struct BSocksClient_auth_info *auth_info, size_t num_auth_info,
+    BReactor *reactor, void *user, SocksUdpClient_handler_received handler_received);
+
+/**
+ * Frees the SOCKS5-UDP client object.
+ *
+ * @param o the object
  */
-void SocksUdpClient_Init (SocksUdpClient *o, int udp_mtu, int max_connections, btime_t keepalive_time,
-                          BAddr server_addr, const struct BSocksClient_auth_info *auth_info, size_t num_auth_info,
-                          BReactor *reactor, void *user, SocksUdpClient_handler_received handler_received);
 void SocksUdpClient_Free (SocksUdpClient *o);
 
 /**
  * Submit a packet to be sent through the proxy.
  *
  * This will reuse an existing connection for packets from local_addr, or create one if
- * there is none.  If the number of live connections exceeds max_connections, or if the number of
- * buffered packets from this port exceeds a limit, packets will be dropped silently.
+ * there is none. If the number of live connections exceeds max_connections, or if the
+ * number of buffered packets from this port exceeds a limit, packets will be dropped
+ * silently.
  * 
- * As a resource optimization, if a connection has only been used to send one DNS query, then
- * the connection will be closed and freed once the reply is received.
+ * As a resource optimization, if a connection has only been used to send one DNS query,
+ * then the connection will be closed and freed once the reply is received.
  * 
  * @param o the object
- * @param local_addr the UDP packet's source address, and the expected destination for replies
+ * @param local_addr the UDP packet's source address, and the expected destination for
+ *        replies
  * @param remote_addr the destination of the packet after it exits the proxy
- * @param data the packet contents.  Caller retains ownership.
+ * @param data the packet contents. Caller retains ownership.
+ * @param data_len number of bytes in the data
  */
-void SocksUdpClient_SubmitPacket (SocksUdpClient *o, BAddr local_addr, BAddr remote_addr, const uint8_t *data, int data_len);
+void SocksUdpClient_SubmitPacket (SocksUdpClient *o,
+    BAddr local_addr, BAddr remote_addr, const uint8_t *data, int data_len);
 
 #endif

+ 20 - 6
socksclient/BSocksClient.c

@@ -375,14 +375,16 @@ void recv_handler_done (BSocksClient *o, int data_len)
             void *addr_buffer = o->buffer + sizeof(struct socks_reply_header);
             switch (o->bind_addr.type) {
                 case BADDR_TYPE_IPV4: {
-                    struct socks_addr_ipv4 *ip4 = addr_buffer;
-                    o->bind_addr.ipv4.ip = ip4->addr;
-                    o->bind_addr.ipv4.port = ip4->port;
+                    struct socks_addr_ipv4 ip4;
+                    memcpy(&ip4, addr_buffer, sizeof(ip4));
+                    o->bind_addr.ipv4.ip = ip4.addr;
+                    o->bind_addr.ipv4.port = ip4.port;
                 } break;
                 case BADDR_TYPE_IPV6: {
-                    struct socks_addr_ipv6 *ip6 = addr_buffer;
-                    memcpy(o->bind_addr.ipv6.ip, ip6->addr, sizeof(ip6->addr));
-                    o->bind_addr.ipv6.port = ip6->port;
+                    struct socks_addr_ipv6 ip6;
+                    memcpy(&ip6, addr_buffer, sizeof(ip6));
+                    memcpy(o->bind_addr.ipv6.ip, ip6.addr, sizeof(ip6.addr));
+                    o->bind_addr.ipv6.port = ip6.port;
                 } break;
                 default: ASSERT(0);
             }
@@ -395,6 +397,8 @@ void recv_handler_done (BSocksClient *o, int data_len)
             free_control_io(o);
             
             // init up I/O
+            // Initializing this is not needed for UDP ASSOCIATE but it doesn't hurt.
+            // We anyway don't allow the user to use these interfaces in that case.
             init_up_io(o);
             
             // set state
@@ -614,9 +618,18 @@ void BSocksClient_Free (BSocksClient *o)
     }
 }
 
+BAddr BSocksClient_GetBindAddr (BSocksClient *o)
+{
+    ASSERT(o->state == STATE_UP)
+    DebugObject_Access(&o->d_obj);
+
+    return o->bind_addr;
+}
+
 StreamPassInterface * BSocksClient_GetSendInterface (BSocksClient *o)
 {
     ASSERT(o->state == STATE_UP)
+    ASSERT(!o->udp)
     DebugObject_Access(&o->d_obj);
     
     return BConnection_SendAsync_GetIf(&o->con);
@@ -625,6 +638,7 @@ StreamPassInterface * BSocksClient_GetSendInterface (BSocksClient *o)
 StreamRecvInterface * BSocksClient_GetRecvInterface (BSocksClient *o)
 {
     ASSERT(o->state == STATE_UP)
+    ASSERT(!o->udp)
     DebugObject_Access(&o->d_obj);
     
     return BConnection_RecvAsync_GetIf(&o->con);

+ 22 - 2
socksclient/BSocksClient.h

@@ -34,6 +34,7 @@
 #ifndef BADVPN_SOCKS_BSOCKSCLIENT_H
 #define BADVPN_SOCKS_BSOCKSCLIENT_H
 
+#include <stddef.h>
 #include <stdint.h>
 #include <stdbool.h>
 
@@ -112,7 +113,14 @@ struct BSocksClient_auth_info BSocksClient_auth_password (const char *username,
  * 
  * @param o the object
  * @param server_addr SOCKS5 server address
+ * @param auth_info List of supported authentication methods and associated parameters.
+ *        Initialize these using functions such as BSocksClient_auth_none() and
+ *        BSocksClient_auth_password(). The pointer must remain valid while this object
+ *        exists, the data is not copied.
+ * @param num_auth_info Number of the above. There should be at least one, otherwise it
+ *        certainly won't work.
  * @param dest_addr remote address
+ * @param udp whether to do UDP ASSOCIATE instead of CONNECT
  * @param handler handler for up and error events
  * @param user value passed to handler
  * @param reactor reactor we live in
@@ -129,9 +137,20 @@ int BSocksClient_Init (BSocksClient *o,
  */
 void BSocksClient_Free (BSocksClient *o);
 
+/**
+ * Return the bind address that the SOCKS server reported.
+ * The object must be in up state. The bind address is needed for UDP ASSOCIATE
+ * because it is the address that the client should send UDP packets to.
+ *
+ * @param o the object
+ * @return The bind address, of type BADDR_TYPE_IPV4 or BADDR_TYPE_IPV6.
+ */
+BAddr BSocksClient_GetBindAddr (BSocksClient *o);
+
 /**
  * Returns the send interface.
- * The object must be in up state.
+ * The object must be in up state. Additionally this must not be called if the
+ * object was initialized in UDP ASSOCIATE mode.
  * 
  * @param o the object
  * @return send interface
@@ -140,7 +159,8 @@ StreamPassInterface * BSocksClient_GetSendInterface (BSocksClient *o);
 
 /**
  * Returns the receive interface.
- * The object must be in up state.
+ * The object must be in up state. Additionally this must not be called if the
+ * object was initialized in UDP ASSOCIATE mode.
  * 
  * @param o the object
  * @return receive interface

+ 18 - 0
system/BAddr.h

@@ -97,6 +97,8 @@ static int BIPAddr_Resolve (BIPAddr *addr, char *str, int noresolve) WARN_UNUSED
 
 static int BIPAddr_Compare (BIPAddr *addr1, BIPAddr *addr2);
 
+static void BIPAddr_InitLocalhost (BIPAddr *addr, int addr_type);
+
 /**
  * Converts an IP address to human readable form.
  *
@@ -805,4 +807,20 @@ int BAddr_CompareOrder (BAddr *addr1, BAddr *addr2)
     }
 }
 
+void BIPAddr_InitLocalhost (BIPAddr *addr, int addr_type)
+{
+    if (addr_type == BADDR_TYPE_IPV4) {
+        addr->type = addr_type;
+        addr->ipv4 = hton32(0x7f000001);
+    }
+    else if (addr_type == BADDR_TYPE_IPV6) {
+        addr->type = addr_type;
+        memset(addr->ipv6, 0, 16);
+        addr->ipv6[15] = 1;
+    }
+    else {
+        addr->type = BADDR_TYPE_NONE;
+    }
+}
+
 #endif

+ 16 - 2
system/BDatagram.h

@@ -124,8 +124,22 @@ void BDatagram_SetSendAddrs (BDatagram *o, BAddr remote_addr, BIPAddr local_addr
 int BDatagram_GetLastReceiveAddrs (BDatagram *o, BAddr *remote_addr, BIPAddr *local_addr);
 
 /**
- * Returns the bound port.
- * Fails if and only if a port is not yet bound.
+ * Determines the local address.
+ * 
+ * This calls getsockname() to determine the local address and returns the result as
+ * BAddr. This function fails if the address cannot be determined or translated to
+ * BAddr (it never succeeds but returns a BADDR_TYPE_NONE address).
+ *
+ * @param o the object
+ * @param local_addr returns the local bound address.
+ * @return 1 on success, 0 on failure
+ */
+int BDatagram_GetLocalAddr (BDatagram *o, BAddr *local_addr);
+
+/**
+ * Returns the local port.
+ *
+ * This is a convenience function implemented based on BDatagram_GetLocalAddr.
  * 
  * @param o the object
  * @param local_port returns the local bound port.

+ 52 - 0
system/BDatagram_common.c

@@ -0,0 +1,52 @@
+/**
+ * @file BDatagram_unix.c
+ * @author Ambroz Bizjak <ambrop7@gmail.com>
+ * 
+ * @section LICENSE
+ * 
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ * 1. Redistributions of source code must retain the above copyright
+ *    notice, this list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright
+ *    notice, this list of conditions and the following disclaimer in the
+ *    documentation and/or other materials provided with the distribution.
+ * 3. Neither the name of the author nor the
+ *    names of its contributors may be used to endorse or promote products
+ *    derived from this software without specific prior written permission.
+ * 
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#include <base/BLog.h>
+#include <system/BAddr.h>
+
+#include "BDatagram.h"
+
+#include <generated/blog_channel_BDatagram.h>
+
+int BDatagram_GetLocalPort (BDatagram *o, uint16_t *local_port)
+{
+    BAddr addr;
+    if (!BDatagram_GetLocalAddr(o, &addr)) {
+        return 0;
+    }
+
+    if (addr.type != BADDR_TYPE_IPV4 && addr.type != BADDR_TYPE_IPV6) {
+        BLog(BLOG_ERROR,
+            "BDatagram_GetLocalPort: Port not defined for this address type.");
+        return 0;
+    }
+
+    *local_port = BAddr_GetPort(&addr);
+    return 1;
+}

+ 12 - 12
system/BDatagram_unix.c

@@ -734,28 +734,28 @@ int BDatagram_GetLastReceiveAddrs (BDatagram *o, BAddr *remote_addr, BIPAddr *lo
     return 1;
 }
 
-int BDatagram_GetLocalPort (BDatagram *o, uint16_t *local_port)
+int BDatagram_GetLocalAddr (BDatagram *o, BAddr *local_addr)
 {
     DebugObject_Access(&o->d_obj);
     
     struct sys_addr sysaddr;
-    BAddr addr;
     sysaddr.len = sizeof(sysaddr.addr);
     if (getsockname(o->fd, &sysaddr.addr.generic, &sysaddr.len) != 0) {
-        BLog(BLOG_ERROR, "getsockname failed");
+        BLog(BLOG_ERROR, "BDatagram_GetLocalAddr: getsockname failed");
         return 0;
     }
+
+    BAddr addr;
     addr_sys_to_socket(&addr, sysaddr);
-    if (addr.type == BADDR_TYPE_IPV4) {
-        *local_port = addr.ipv4.port;
-        return 1;
-    }
-    if (addr.type == BADDR_TYPE_IPV6) {
-        *local_port = addr.ipv6.port;
-        return 1;
+
+    if (addr.type == BADDR_TYPE_NONE) {
+        BLog(BLOG_ERROR, "BDatagram_GetLocalAddr: Unsupported address family "
+            "from getsockname: %d", (int)sysaddr.addr.generic.sa_family);
+        return 0;
     }
-    BLog(BLOG_ERROR, "Unknown address type from getsockname: %d", addr.type);
-    return 0;
+
+    *local_addr = addr;
+    return 1;
 }
 
 int BDatagram_GetFd (BDatagram *o)

+ 12 - 12
system/BDatagram_win.c

@@ -635,29 +635,29 @@ int BDatagram_GetLastReceiveAddrs (BDatagram *o, BAddr *remote_addr, BIPAddr *lo
     return 1;
 }
 
-int BDatagram_GetLocalPort (BDatagram *o, uint16_t *local_port)
+int BDatagram_GetLocalAddr (BDatagram *o, BAddr *local_addr)
 {
     DebugObject_Access(&o->d_obj);
     
     struct BDatagram_sys_addr sysaddr;
-    BAddr addr;
     socklen_t addr_size = sizeof(sysaddr.addr.generic);
     if (getsockname(o->sock, &sysaddr.addr.generic, &addr_size) != 0) {
-        BLog(BLOG_ERROR, "getsockname failed");
+        BLog(BLOG_ERROR, "BDatagram_GetLocalAddr: getsockname failed");
         return 0;
     }
+    sysaddr.len = addr_size;
     
+    BAddr addr;
     addr_sys_to_socket(&addr, sysaddr);
-    if (addr.type == BADDR_TYPE_IPV4) {
-        *local_port = addr.ipv4.port;
-        return 1;
-    }
-    if (addr.type == BADDR_TYPE_IPV6) {
-        *local_port = addr.ipv6.port;
-        return 1;
+    
+    if (addr.type == BADDR_TYPE_NONE) {
+        BLog(BLOG_ERROR, "BDatagram_GetLocalAddr: Unsupported address family "
+            "from getsockname: %d", int(sysaddr.addr.generic.sa_family));
+        return 0;
     }
-    BLog(BLOG_ERROR, "Unknown address type from getsockname: %d", addr.type);
-    return 0;
+
+    *local_addr = addr;
+    return 1;
 }
 
 int BDatagram_SetReuseAddr (BDatagram *o, int reuse)

+ 1 - 0
system/CMakeLists.txt

@@ -6,6 +6,7 @@ if (NOT EMSCRIPTEN)
         BSignal.c
         BNetwork.c
         BConnection_common.c
+        BDatagram_common.c
     )
 
     if (WIN32)

+ 4 - 1
tun2socks/SocksUdpGwClient.c

@@ -62,7 +62,10 @@ static void try_connect (SocksUdpGwClient *o)
     ASSERT(!BTimer_IsRunning(&o->reconnect_timer))
     
     // init SOCKS client
-    if (!BSocksClient_Init(&o->socks_client, o->socks_server_addr, o->auth_info, o->num_auth_info, o->remote_udpgw_addr, false, (BSocksClient_handler)socks_client_handler, o, o->reactor)) {
+    if (!BSocksClient_Init(&o->socks_client, o->socks_server_addr,
+        o->auth_info, o->num_auth_info, o->remote_udpgw_addr, /*udp=*/false,
+        (BSocksClient_handler)socks_client_handler, o, o->reactor))
+    {
         BLog(BLOG_ERROR, "BSocksClient_Init failed");
         goto fail0;
     }

+ 37 - 20
tun2socks/tun2socks.c

@@ -181,6 +181,10 @@ uint8_t *device_write_buf;
 SinglePacketBuffer device_read_buffer;
 PacketPassInterface device_read_interface;
 
+// UDP support mode
+enum UdpMode {UdpModeNone, UdpModeUdpgw, UdpModeSocks};
+enum UdpMode udp_mode;
+
 // udpgw client
 SocksUdpGwClient udpgw_client;
 int udp_mtu;
@@ -249,7 +253,6 @@ static int client_socks_recv_send_out (struct tcp_client *client);
 static err_t client_sent_func (void *arg, struct tcp_pcb *tpcb, u16_t len);
 static void udp_send_packet_to_device (void *unused, BAddr local_addr, BAddr remote_addr, const uint8_t *data, int data_len);
 
-
 int main (int argc, char **argv)
 {
     if (argc <= 0) {
@@ -357,7 +360,8 @@ int main (int argc, char **argv)
         goto fail4;
     }
     
-    // compute maximum UDP payload size we need to pass through udpgw
+    // Compute the largest possible UDP payload that we can receive from or send to the
+    // TUN device.
     udp_mtu = BTap_GetMTU(&device) - (int)(sizeof(struct ipv4_header) + sizeof(struct udp_header));
     if (options.netif_ip6addr) {
         int udp_ip6_mtu = BTap_GetMTU(&device) - (int)(sizeof(struct ipv6_header) + sizeof(struct udp_header));
@@ -370,6 +374,8 @@ int main (int argc, char **argv)
     }
 
     if (options.udpgw_remote_server_addr) {
+        udp_mode = UdpModeUdpgw;
+
         // make sure our UDP payloads aren't too large for udpgw
         int udpgw_mtu = udpgw_compute_mtu(udp_mtu);
         if (udpgw_mtu < 0 || udpgw_mtu > PACKETPROTO_MAXPAYLOAD) {
@@ -378,17 +384,23 @@ int main (int argc, char **argv)
         }
         
         // init udpgw client
-        if (!SocksUdpGwClient_Init(&udpgw_client, udp_mtu, DEFAULT_UDPGW_MAX_CONNECTIONS, options.udpgw_connection_buffer_size, UDPGW_KEEPALIVE_TIME,
-                                   socks_server_addr, socks_auth_info, socks_num_auth_info,
-                                   udpgw_remote_server_addr, UDPGW_RECONNECT_TIME, &ss, NULL, udp_send_packet_to_device
-        )) {
+        if (!SocksUdpGwClient_Init(&udpgw_client, udp_mtu, DEFAULT_UDPGW_MAX_CONNECTIONS,
+            options.udpgw_connection_buffer_size, UDPGW_KEEPALIVE_TIME, socks_server_addr,
+            socks_auth_info, socks_num_auth_info, udpgw_remote_server_addr,
+            UDPGW_RECONNECT_TIME, &ss, NULL, udp_send_packet_to_device))
+        {
             BLog(BLOG_ERROR, "SocksUdpGwClient_Init failed");
             goto fail4a;
         }
     } else if (options.socks5_udp) {
+        udp_mode = UdpModeSocks;
+
+        // init SOCKS UDP client
         SocksUdpClient_Init(&socks_udp_client, udp_mtu, DEFAULT_UDPGW_MAX_CONNECTIONS,
-                            UDPGW_KEEPALIVE_TIME, socks_server_addr, socks_auth_info,
-                            socks_num_auth_info, &ss, NULL, udp_send_packet_to_device);
+            SOCKS_UDP_SEND_BUFFER_PACKETS, UDPGW_KEEPALIVE_TIME, socks_server_addr,
+            socks_auth_info, socks_num_auth_info, &ss, NULL, udp_send_packet_to_device);
+    } else {
+        udp_mode = UdpModeNone;
     }
     
     // init lwip init job
@@ -448,9 +460,9 @@ int main (int argc, char **argv)
     BFree(device_write_buf);
 fail5:
     BPending_Free(&lwip_init_job);
-    if (options.udpgw_remote_server_addr) {
+    if (udp_mode == UdpModeUdpgw) {
         SocksUdpGwClient_Free(&udpgw_client);
-    } else if (options.socks5_udp) {
+    } else if (udp_mode == UdpModeSocks) {
         SocksUdpClient_Free(&socks_udp_client);
     }
 fail4a:
@@ -1066,8 +1078,8 @@ int process_device_udp_packet (uint8_t *data, int data_len)
 {
     ASSERT(data_len >= 0)
     
-    // do nothing if we don't have udpgw
-    if (!options.udpgw_remote_server_addr && !options.socks5_udp) {
+    // do nothing if we don't use udpgw or SOCKS UDP
+    if (udp_mode == UdpModeNone) {
         goto fail;
     }
     
@@ -1172,11 +1184,11 @@ int process_device_udp_packet (uint8_t *data, int data_len)
         goto fail;
     }
     
-    if (options.udpgw_remote_server_addr) {
-        // submit packet to udpgw
+    // submit packet to udpgw or SOCKS UDP
+    if (udp_mode == UdpModeUdpgw) {
         SocksUdpGwClient_SubmitPacket(&udpgw_client, local_addr, remote_addr,
                                       is_dns, data, data_len);
-    } else if (options.socks5_udp) {
+    } else if (udp_mode == UdpModeSocks) {
         SocksUdpClient_SubmitPacket(&socks_udp_client, local_addr, remote_addr, data, data_len);
     }
     
@@ -1326,8 +1338,10 @@ err_t listener_accept_func (void *arg, struct tcp_pcb *newpcb, err_t err)
     }
     
     // init SOCKS
-    if (!BSocksClient_Init(&client->socks_client, socks_server_addr, socks_auth_info, socks_num_auth_info,
-                           addr, false, (BSocksClient_handler)client_socks_handler, client, &ss)) {
+    if (!BSocksClient_Init(&client->socks_client,
+        socks_server_addr, socks_auth_info, socks_num_auth_info, addr, /*udp=*/false,
+        (BSocksClient_handler)client_socks_handler, client, &ss))
+    {
         BLog(BLOG_ERROR, "listener accept: BSocksClient_Init failed");
         goto fail1;
     }
@@ -1840,15 +1854,18 @@ out:
 
 void udp_send_packet_to_device (void *unused, BAddr local_addr, BAddr remote_addr, const uint8_t *data, int data_len)
 {
+    ASSERT(udp_mode != UdpModeNone)
     ASSERT(local_addr.type == BADDR_TYPE_IPV4 || local_addr.type == BADDR_TYPE_IPV6)
     ASSERT(local_addr.type == remote_addr.type)
     ASSERT(data_len >= 0)
+
+    char const *source_name = (udp_mode == UdpModeUdpgw) ? "udpgw" : "SOCKS UDP";
     
     int packet_length = 0;
     
     switch (local_addr.type) {
         case BADDR_TYPE_IPV4: {
-            BLog(BLOG_INFO, "UDP: from udpgw %d bytes", data_len);
+            BLog(BLOG_INFO, "UDP: from %s %d bytes", source_name, data_len);
             
             if (data_len > UINT16_MAX - (sizeof(struct ipv4_header) + sizeof(struct udp_header)) ||
                 data_len > BTap_GetMTU(&device) - (int)(sizeof(struct ipv4_header) + sizeof(struct udp_header))
@@ -1887,10 +1904,10 @@ void udp_send_packet_to_device (void *unused, BAddr local_addr, BAddr remote_add
         } break;
         
         case BADDR_TYPE_IPV6: {
-            BLog(BLOG_INFO, "UDP/IPv6: from udpgw %d bytes", data_len);
+            BLog(BLOG_INFO, "UDP/IPv6: from %s %d bytes", source_name, data_len);
             
             if (!options.netif_ip6addr) {
-                BLog(BLOG_ERROR, "got IPv6 packet from udpgw but IPv6 is disabled");
+                BLog(BLOG_ERROR, "got IPv6 packet from %s but IPv6 is disabled", source_name);
                 return;
             }
             

+ 6 - 0
tun2socks/tun2socks.h

@@ -44,3 +44,9 @@
 
 // option to override the destination addresses to give the SOCKS server
 //#define OVERRIDE_DEST_ADDR "10.111.0.2:2000"
+
+// Max number of buffered outgoing UDP packets for SOCKS5-UDP. It should be large
+// enough to prevent packet loss while the SOCKS UDP association is being set up. A slow
+// or far-away SOCKS server could require 300 ms to connect, and a chatty client (e.g.
+// STUN) could send a packet every 20 ms, so a default limit of 16 seems reasonable.
+#define SOCKS_UDP_SEND_BUFFER_PACKETS 16