Просмотр исходного кода

BSocket: interpret setsockopt return value correctly, fallback to WSASendTo/WSARecvFrom if WSASendMsg/WSARecvMsg isn't available

ambrop7 15 лет назад
Родитель
Сommit
d4fbebc2ff
1 измененных файлов с 147 добавлено и 110 удалено
  1. 147 110
      system/BSocket.c

+ 147 - 110
system/BSocket.c

@@ -456,12 +456,12 @@ static int setup_pktinfo (int socket, int type, int domain)
     if (type == BSOCKET_TYPE_DGRAM) {
         switch (domain) {
             case BADDR_TYPE_IPV4:
-                if (set_pktinfo(socket) == 0) {
+                if (set_pktinfo(socket)) {
                     return 0;
                 }
                 break;
             case BADDR_TYPE_IPV6:
-                if (set_pktinfo6(socket) == 0) {
+                if (set_pktinfo6(socket)) {
                     return 0;
                 }
                 break;
@@ -471,27 +471,29 @@ static int setup_pktinfo (int socket, int type, int domain)
     return 1;
 }
 
-static int setup_winsock_exts (int socket, BSocket *bs)
+static void setup_winsock_exts (int socket, int type, BSocket *bs)
 {
     #ifdef BADVPN_USE_WINAPI
     
-    DWORD out_bytes;
-    
-    // obtain WSASendMsg
-    GUID guid_send = WSAID_WSASENDMSG;
-    if (WSAIoctl(socket, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid_send, sizeof(guid_send), &bs->WSASendMsg, sizeof(bs->WSASendMsg), &out_bytes, NULL, NULL) != 0) {
-        return 0;
-    }
-    
-    // obtain WSARecvMsg
-    GUID guid_recv = WSAID_WSARECVMSG;
-    if (WSAIoctl(socket, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid_recv, sizeof(guid_recv), &bs->WSARecvMsg, sizeof(bs->WSARecvMsg), &out_bytes, NULL, NULL) != 0) {
-        return 0;
+    if (type == BSOCKET_TYPE_DGRAM) {
+        DWORD out_bytes;
+        
+        // obtain WSARecvMsg
+        GUID guid_recv = WSAID_WSARECVMSG;
+        if (WSAIoctl(socket, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid_recv, sizeof(guid_recv), &bs->WSARecvMsg, sizeof(bs->WSARecvMsg), &out_bytes, NULL, NULL) != 0) {
+            DEBUG("WSAIoctl(SIO_GET_EXTENSION_FUNCTION_POINTER WSAID_WSARECVMSG) failed (%u)", WSAGetLastError());
+            bs->WSARecvMsg = NULL;
+        }
+        
+        // obtain WSASendMsg
+        GUID guid_send = WSAID_WSASENDMSG;
+        if (WSAIoctl(socket, SIO_GET_EXTENSION_FUNCTION_POINTER, &guid_send, sizeof(guid_send), &bs->WSASendMsg, sizeof(bs->WSASendMsg), &out_bytes, NULL, NULL) != 0) {
+            DEBUG("WSAIoctl(SIO_GET_EXTENSION_FUNCTION_POINTER WSAID_WSASENDMSG) failed (%u)", WSAGetLastError());
+            bs->WSASendMsg = NULL;
+        }
     }
     
     #endif
-    
-    return 1;
 }
 
 int BSocket_GlobalInit (void)
@@ -576,10 +578,7 @@ int BSocket_Init (BSocket *bs, BReactor *bsys, int domain, int type)
     }
     
     // setup winsock exts
-    if (!setup_winsock_exts(fd, bs)) {
-        DEBUG("setup_winsock_exts failed");
-        goto fail1;
-    }
+    setup_winsock_exts(fd, type, bs);
     
     DEAD_INIT(bs->dead);
     bs->bsys = bsys;
@@ -905,10 +904,7 @@ int BSocket_Accept (BSocket *bs, BSocket *newsock, BAddr *addr)
         }
         
         // setup winsock exts
-        if (!setup_winsock_exts(fd, newsock)) {
-            DEBUG("setup_winsock_exts failed");
-            goto fail0;
-        }
+        setup_winsock_exts(fd, bs->type, newsock);
         
         DEAD_INIT(newsock->dead);
         newsock->bsys = bs->bsys;
@@ -948,6 +944,7 @@ fail0:
 int BSocket_Send (BSocket *bs, uint8_t *data, int len)
 {
     ASSERT(len >= 0)
+    ASSERT(bs->type == BSOCKET_TYPE_STREAM)
     
     #ifdef BADVPN_USE_WINAPI
     int flags = 0;
@@ -986,6 +983,7 @@ int BSocket_Send (BSocket *bs, uint8_t *data, int len)
 int BSocket_Recv (BSocket *bs, uint8_t *data, int len)
 {
     ASSERT(len >= 0)
+    ASSERT(bs->type == BSOCKET_TYPE_STREAM)
     
     if (limit_recv(bs)) {
         bs->error = BSOCKET_ERROR_LATER;
@@ -1026,6 +1024,7 @@ int BSocket_SendToFrom (BSocket *bs, uint8_t *data, int len, BAddr *addr, BIPAdd
     ASSERT(addr)
     ASSERT(!BAddr_IsInvalid(addr))
     ASSERT(local_addr)
+    ASSERT(bs->type == BSOCKET_TYPE_DGRAM)
     
     struct sys_addr remote_sysaddr;
     addr_socket_to_sys(&remote_sysaddr, addr);
@@ -1036,62 +1035,77 @@ int BSocket_SendToFrom (BSocket *bs, uint8_t *data, int len, BAddr *addr, BIPAdd
     buf.len = len;
     buf.buf = data;
     
-    union {
-        char in[WSA_CMSG_SPACE(sizeof(struct in_pktinfo))];
-        char in6[WSA_CMSG_SPACE(sizeof(struct in6_pktinfo))];
-    } cdata;
-    
-    WSAMSG msg;
-    memset(&msg, 0, sizeof(msg));
-    msg.name = &remote_sysaddr.addr.generic;
-    msg.namelen = remote_sysaddr.len;
-    msg.lpBuffers = &buf;
-    msg.dwBufferCount = 1;
-    msg.Control.buf = (char *)&cdata;
-    msg.Control.len = sizeof(cdata);
-    
-    int sum = 0;
-    
-    WSACMSGHDR *cmsg = WSA_CMSG_FIRSTHDR(&msg);
-    
-    switch (local_addr->type) {
-        case BADDR_TYPE_NONE:
-            break;
-        case BADDR_TYPE_IPV4: {
-            memset(cmsg, 0, WSA_CMSG_SPACE(sizeof(struct in_pktinfo)));
-            cmsg->cmsg_level = IPPROTO_IP;
-            cmsg->cmsg_type = IP_PKTINFO;
-            cmsg->cmsg_len = WSA_CMSG_LEN(sizeof(struct in_pktinfo));
-            struct in_pktinfo *pktinfo = (struct in_pktinfo *)WSA_CMSG_DATA(cmsg);
-            pktinfo->ipi_addr.s_addr = local_addr->ipv4;
-            sum += WSA_CMSG_SPACE(sizeof(struct in_pktinfo));
-        } break;
-        case BADDR_TYPE_IPV6: {
-            memset(cmsg, 0, WSA_CMSG_SPACE(sizeof(struct in6_pktinfo)));
-            cmsg->cmsg_level = IPPROTO_IPV6;
-            cmsg->cmsg_type = IPV6_PKTINFO;
-            cmsg->cmsg_len = WSA_CMSG_LEN(sizeof(struct in6_pktinfo));
-            struct in6_pktinfo *pktinfo = (struct in6_pktinfo *)WSA_CMSG_DATA(cmsg);
-            memcpy(pktinfo->ipi6_addr.s6_addr, local_addr->ipv6, 16);
-            sum += WSA_CMSG_SPACE(sizeof(struct in6_pktinfo));
-        } break;
-        default:
-            ASSERT(0);
-    }
-    
-    msg.Control.len = sum;
-    
     DWORD bytes;
-    if (bs->WSASendMsg(bs->socket, &msg, 0, &bytes, NULL, NULL) != 0) {
-        int error;
-        switch ((error = WSAGetLastError())) {
-            case WSAEWOULDBLOCK:
-                bs->error = BSOCKET_ERROR_LATER;
-                return -1;
+    
+    if (!bs->WSASendMsg) {
+        if (WSASendTo(bs->socket, &buf, 1, &bytes, 0, &remote_sysaddr.addr.generic, remote_sysaddr.len, NULL, NULL) != 0) {
+            int error;
+            switch ((error = WSAGetLastError())) {
+                case WSAEWOULDBLOCK:
+                    bs->error = BSOCKET_ERROR_LATER;
+                    return -1;
+            }
+            
+            bs->error = translate_error(error);
+            return -1;
         }
+    } else {
+        union {
+            char in[WSA_CMSG_SPACE(sizeof(struct in_pktinfo))];
+            char in6[WSA_CMSG_SPACE(sizeof(struct in6_pktinfo))];
+        } cdata;
         
-        bs->error = translate_error(error);
-        return -1;
+        WSAMSG msg;
+        memset(&msg, 0, sizeof(msg));
+        msg.name = &remote_sysaddr.addr.generic;
+        msg.namelen = remote_sysaddr.len;
+        msg.lpBuffers = &buf;
+        msg.dwBufferCount = 1;
+        msg.Control.buf = (char *)&cdata;
+        msg.Control.len = sizeof(cdata);
+        
+        int sum = 0;
+        
+        WSACMSGHDR *cmsg = WSA_CMSG_FIRSTHDR(&msg);
+        
+        switch (local_addr->type) {
+            case BADDR_TYPE_NONE:
+                break;
+            case BADDR_TYPE_IPV4: {
+                memset(cmsg, 0, WSA_CMSG_SPACE(sizeof(struct in_pktinfo)));
+                cmsg->cmsg_level = IPPROTO_IP;
+                cmsg->cmsg_type = IP_PKTINFO;
+                cmsg->cmsg_len = WSA_CMSG_LEN(sizeof(struct in_pktinfo));
+                struct in_pktinfo *pktinfo = (struct in_pktinfo *)WSA_CMSG_DATA(cmsg);
+                pktinfo->ipi_addr.s_addr = local_addr->ipv4;
+                sum += WSA_CMSG_SPACE(sizeof(struct in_pktinfo));
+            } break;
+            case BADDR_TYPE_IPV6: {
+                memset(cmsg, 0, WSA_CMSG_SPACE(sizeof(struct in6_pktinfo)));
+                cmsg->cmsg_level = IPPROTO_IPV6;
+                cmsg->cmsg_type = IPV6_PKTINFO;
+                cmsg->cmsg_len = WSA_CMSG_LEN(sizeof(struct in6_pktinfo));
+                struct in6_pktinfo *pktinfo = (struct in6_pktinfo *)WSA_CMSG_DATA(cmsg);
+                memcpy(pktinfo->ipi6_addr.s6_addr, local_addr->ipv6, 16);
+                sum += WSA_CMSG_SPACE(sizeof(struct in6_pktinfo));
+            } break;
+            default:
+                ASSERT(0);
+        }
+        
+        msg.Control.len = sum;
+        
+        if (bs->WSASendMsg(bs->socket, &msg, 0, &bytes, NULL, NULL) != 0) {
+            int error;
+            switch ((error = WSAGetLastError())) {
+                case WSAEWOULDBLOCK:
+                    bs->error = BSOCKET_ERROR_LATER;
+                    return -1;
+            }
+            
+            bs->error = translate_error(error);
+            return -1;
+        }
     }
     
     #else
@@ -1172,6 +1186,7 @@ int BSocket_RecvFromTo (BSocket *bs, uint8_t *data, int len, BAddr *addr, BIPAdd
     ASSERT(len >= 0)
     ASSERT(addr)
     ASSERT(local_addr)
+    ASSERT(bs->type == BSOCKET_TYPE_DGRAM)
     
     if (limit_recv(bs)) {
         bs->error = BSOCKET_ERROR_LATER;
@@ -1187,37 +1202,57 @@ int BSocket_RecvFromTo (BSocket *bs, uint8_t *data, int len, BAddr *addr, BIPAdd
     buf.len = len;
     buf.buf = data;
     
-    union {
-        char in[WSA_CMSG_SPACE(sizeof(struct in_pktinfo))];
-        char in6[WSA_CMSG_SPACE(sizeof(struct in6_pktinfo))];
-    } cdata;
+    DWORD bytes;
     
     WSAMSG msg;
-    memset(&msg, 0, sizeof(msg));
-    msg.name = &remote_sysaddr.addr.generic;
-    msg.namelen = remote_sysaddr.len;
-    msg.lpBuffers = &buf;
-    msg.dwBufferCount = 1;
-    msg.Control.buf = (char *)&cdata;
-    msg.Control.len = sizeof(cdata);
     
-    DWORD bytes;
-    if (bs->WSARecvMsg(bs->socket, &msg, &bytes, NULL, NULL) != 0) {
-        int error;
-        switch ((error = WSAGetLastError())) {
-            case WSAEWOULDBLOCK:
-                bs->error = BSOCKET_ERROR_LATER;
-                return -1;
+    if (!bs->WSARecvMsg) {
+        DWORD flags = 0;
+        INT fromlen = remote_sysaddr.len;
+        if (WSARecvFrom(bs->socket, &buf, 1, &bytes, &flags, &remote_sysaddr.addr.generic, &fromlen, NULL, NULL) != 0) {
+            int error;
+            switch ((error = WSAGetLastError())) {
+                case WSAEWOULDBLOCK:
+                    bs->error = BSOCKET_ERROR_LATER;
+                    return -1;
+            }
+            
+            bs->error = translate_error(error);
+            return -1;
         }
         
-        bs->error = translate_error(error);
-        return -1;
+        remote_sysaddr.len = fromlen;
+    } else {
+        union {
+            char in[WSA_CMSG_SPACE(sizeof(struct in_pktinfo))];
+            char in6[WSA_CMSG_SPACE(sizeof(struct in6_pktinfo))];
+        } cdata;
+        
+        memset(&msg, 0, sizeof(msg));
+        msg.name = &remote_sysaddr.addr.generic;
+        msg.namelen = remote_sysaddr.len;
+        msg.lpBuffers = &buf;
+        msg.dwBufferCount = 1;
+        msg.Control.buf = (char *)&cdata;
+        msg.Control.len = sizeof(cdata);
+        
+        if (bs->WSARecvMsg(bs->socket, &msg, &bytes, NULL, NULL) != 0) {
+            int error;
+            switch ((error = WSAGetLastError())) {
+                case WSAEWOULDBLOCK:
+                    bs->error = BSOCKET_ERROR_LATER;
+                    return -1;
+            }
+            
+            bs->error = translate_error(error);
+            return -1;
+        }
+        
+        remote_sysaddr.len = msg.namelen;
     }
     
-    remote_sysaddr.len = msg.namelen;
-    
     #else
-    
+        
     struct iovec iov;
     iov.iov_base = data;
     iov.iov_len = len;
@@ -1261,15 +1296,17 @@ int BSocket_RecvFromTo (BSocket *bs, uint8_t *data, int len, BAddr *addr, BIPAdd
     
     #ifdef BADVPN_USE_WINAPI
     
-    WSACMSGHDR *cmsg;
-    for (cmsg = WSA_CMSG_FIRSTHDR(&msg); cmsg; cmsg = WSA_CMSG_NXTHDR(&msg, cmsg)) {
-        if (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_PKTINFO) {
-            struct in_pktinfo *pktinfo = (struct in_pktinfo *)WSA_CMSG_DATA(cmsg);
-            BIPAddr_InitIPv4(local_addr, pktinfo->ipi_addr.s_addr);
-        }
-        else if (cmsg->cmsg_level == IPPROTO_IPV6 && cmsg->cmsg_type == IPV6_PKTINFO) {
-            struct in6_pktinfo *pktinfo = (struct in6_pktinfo *)WSA_CMSG_DATA(cmsg);
-            BIPAddr_InitIPv6(local_addr, pktinfo->ipi6_addr.s6_addr);
+    if (bs->WSARecvMsg) {
+        WSACMSGHDR *cmsg;
+        for (cmsg = WSA_CMSG_FIRSTHDR(&msg); cmsg; cmsg = WSA_CMSG_NXTHDR(&msg, cmsg)) {
+            if (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_PKTINFO) {
+                struct in_pktinfo *pktinfo = (struct in_pktinfo *)WSA_CMSG_DATA(cmsg);
+                BIPAddr_InitIPv4(local_addr, pktinfo->ipi_addr.s_addr);
+            }
+            else if (cmsg->cmsg_level == IPPROTO_IPV6 && cmsg->cmsg_type == IPV6_PKTINFO) {
+                struct in6_pktinfo *pktinfo = (struct in6_pktinfo *)WSA_CMSG_DATA(cmsg);
+                BIPAddr_InitIPv6(local_addr, pktinfo->ipi6_addr.s6_addr);
+            }
         }
     }