ambrop7 15 жил өмнө
parent
commit
d604b6ae51

+ 51 - 31
client/StreamPeerIO.c

@@ -22,10 +22,6 @@
 
 #include <stdlib.h>
 
-#include <openssl/rand.h>
-
-#include <prio.h>
-
 #include <ssl.h>
 #include <sslerr.h>
 
@@ -57,6 +53,10 @@
 #define LISTEN_STATE_GOTCLIENT 1
 #define LISTEN_STATE_FINISHED 2
 
+#define COMPONENT_SOURCE 1
+#define COMPONENT_SINK 2
+#define COMPONENT_DECODER 3
+
 static int init_persistent_io (StreamPeerIO *pio, PacketPassInterface *user_recv_if);
 static void free_persistent_io (StreamPeerIO *pio);
 static void connecting_connect_handler (StreamPeerIO *pio, int event);
@@ -74,10 +74,6 @@ static void reset_state (StreamPeerIO *pio);
 static void cleanup_socket (sslsocket *sock, int ssl);
 static void reset_and_report_error (StreamPeerIO *pio);
 
-#define COMPONENT_SOURCE 1
-#define COMPONENT_SINK 2
-#define COMPONENT_DECODER 3
-
 void connecting_connect_handler (StreamPeerIO *pio, int event)
 {
     ASSERT(event == BSOCKET_CONNECT)
@@ -164,28 +160,31 @@ SECStatus client_auth_certificate_callback (StreamPeerIO *pio, PRFileDesc *fd, P
     // don't have domain names. We byte-compare the certificate to the one reported
     // by the server anyway.
     
+    SECStatus ret = SECFailure;
+    
     CERTCertificate *server_cert = SSL_PeerCertificate(pio->connect.sock.ssl_prfd);
     if (!server_cert) {
         BLog(BLOG_ERROR, "SSL_PeerCertificate failed");
         PORT_SetError(SSL_ERROR_BAD_CERTIFICATE);
-        return SECFailure;
+        goto fail1;
     }
     
-    SECStatus verify_result = CERT_VerifyCertNow(CERT_GetDefaultCertDB(), server_cert, PR_TRUE, certUsageSSLServer, SSL_RevealPinArg(pio->connect.sock.ssl_prfd));
-    if (verify_result == SECFailure) {
-        goto out;
+    if (CERT_VerifyCertNow(CERT_GetDefaultCertDB(), server_cert, PR_TRUE, certUsageSSLServer, SSL_RevealPinArg(pio->connect.sock.ssl_prfd)) != SECSuccess) {
+        goto fail2;
     }
     
     // compare to certificate provided by the server
     if (!compare_certificate(pio, server_cert)) {
-        verify_result = SECFailure;
         PORT_SetError(SSL_ERROR_BAD_CERTIFICATE);
-        goto out;
+        goto fail2;
     }
     
-out:
+    ret = SECSuccess;
+    
+fail2:
     CERT_DestroyCertificate(server_cert);
-    return verify_result;
+fail1:
+    return ret;
 }
 
 SECStatus client_client_auth_data_callback (StreamPeerIO *pio, PRFileDesc *fd, CERTDistNames *caNames, CERTCertificate **pRetCert, SECKEYPrivateKey **pRetKey)
@@ -195,16 +194,26 @@ SECStatus client_client_auth_data_callback (StreamPeerIO *pio, PRFileDesc *fd, C
     ASSERT(pio->connect.state == CONNECT_STATE_HANDSHAKE)
     DebugObject_Access(&pio->d_obj);
     
-    if (!(*pRetCert = CERT_DupCertificate(pio->connect.ssl_cert))) {
-        return SECFailure;
+    CERTCertificate *cert = CERT_DupCertificate(pio->connect.ssl_cert);
+    if (!cert) {
+        BLog(BLOG_ERROR, "CERT_DupCertificate failed");
+        goto fail0;
     }
     
-    if (!(*pRetKey = SECKEY_CopyPrivateKey(pio->connect.ssl_key))) {
-        CERT_DestroyCertificate(*pRetCert);
-        return SECFailure;
+    SECKEYPrivateKey *key = SECKEY_CopyPrivateKey(pio->connect.ssl_key);
+    if (!key) {
+        BLog(BLOG_ERROR, "SECKEY_CopyPrivateKey failed");
+        goto fail1;
     }
     
+    *pRetCert = cert;
+    *pRetKey = key;
     return SECSuccess;
+    
+fail1:
+    CERT_DestroyCertificate(cert);
+fail0:
+    return SECFailure;
 }
 
 void connecting_try_handshake (StreamPeerIO *pio)
@@ -254,6 +263,9 @@ fail0:
 
 void connecting_handshake_read_handler (StreamPeerIO *pio, PRInt16 event)
 {
+    ASSERT(pio->ssl)
+    ASSERT(pio->mode == MODE_CONNECT)
+    ASSERT(pio->connect.state == CONNECT_STATE_HANDSHAKE)
     DebugObject_Access(&pio->d_obj);
     
     connecting_try_handshake(pio);
@@ -268,7 +280,7 @@ static void connecting_pwsender_handler (StreamPeerIO *pio, int is_error)
     
     if (is_error) {
         BLog(BLOG_NOTICE, "error sending password");
-        BLog(BLOG_NOTICE,"BSocket error %d", BSocket_GetError(&pio->connect.sock.sock));
+        BLog(BLOG_NOTICE, "BSocket error %d", BSocket_GetError(&pio->connect.sock.sock));
         if (pio->ssl) {
             BLog(BLOG_NOTICE, "NSPR error %d", (int)PR_GetError());
         }
@@ -305,7 +317,7 @@ void error_handler (StreamPeerIO *pio, int component, const void *data)
     switch (component) {
         case COMPONENT_SOURCE:
         case COMPONENT_SINK:
-            BLog(BLOG_NOTICE,"BSocket error %d", BSocket_GetError(&pio->sock->sock));
+            BLog(BLOG_NOTICE, "BSocket error %d", BSocket_GetError(&pio->sock->sock));
             if (pio->ssl) {
                 BLog(BLOG_NOTICE, "NSPR error %d", (int)PR_GetError());
             }
@@ -489,10 +501,13 @@ int compare_certificate (StreamPeerIO *pio, CERTCertificate *cert)
 {
     ASSERT(pio->ssl)
     
+    int ret = 0;
+    
+    // alloc arena
     PRArenaPool *arena = PORT_NewArena(DER_DEFAULT_CHUNKSIZE);
     if (!arena) {
-        BLog(BLOG_ERROR, "WARNING: PORT_NewArena failed");
-        return 0;
+        BLog(BLOG_ERROR, "PORT_NewArena failed");
+        goto fail0;
     }
     
     // encode server certificate
@@ -501,19 +516,21 @@ int compare_certificate (StreamPeerIO *pio, CERTCertificate *cert)
     der.data = NULL;
     if (!SEC_ASN1EncodeItem(arena, &der, cert, SEC_ASN1_GET(CERT_CertificateTemplate))) {
         BLog(BLOG_ERROR, "SEC_ASN1EncodeItem failed");
-        PORT_FreeArena(arena, PR_FALSE);
-        return 0;
+        goto fail1;
     }
     
     // byte compare
     if (der.len != pio->ssl_peer_cert_len || memcmp(der.data, pio->ssl_peer_cert, der.len)) {
         BLog(BLOG_NOTICE, "Client certificate doesn't match");
-        PORT_FreeArena(arena, PR_FALSE);
-        return 0;
+        goto fail1;
     }
     
+    ret = 1;
+    
+fail1:
     PORT_FreeArena(arena, PR_FALSE);
-    return 1;
+fail0:
+    return ret;
 }
 
 void reset_state (StreamPeerIO *pio)
@@ -617,7 +634,7 @@ int StreamPeerIO_Init (
     
     // init persistent I/O modules
     if (!init_persistent_io(pio, user_recv_if)) {
-        return 0;
+        goto fail0;
     }
     
     // set mode none
@@ -629,6 +646,9 @@ int StreamPeerIO_Init (
     DebugObject_Init(&pio->d_obj);
     
     return 1;
+    
+fail0:
+    return 0;
 }
 
 void StreamPeerIO_Free (StreamPeerIO *pio)

+ 29 - 58
flow/FragmentProtoDisassembler.c

@@ -21,44 +21,30 @@
  */
 
 #include <stdint.h>
-#include <stdlib.h>
+#include <stddef.h>
 #include <string.h>
 
 #include <misc/debug.h>
 #include <misc/byteorder.h>
+#include <misc/minmax.h>
 
 #include <flow/FragmentProtoDisassembler.h>
 
 static void write_chunks (FragmentProtoDisassembler *o)
 {
+    #define IN_AVAIL (o->in_len - o->in_used)
+    #define OUT_AVAIL ((o->output_mtu - o->out_used) - (int)sizeof(struct fragmentproto_chunk_header))
+    
     ASSERT(o->in_len >= 0)
     ASSERT(o->out)
-    ASSERT(o->output_mtu - o->out_used >= sizeof(struct fragmentproto_chunk_header))
-    
-    int in_avail = o->in_len - o->in_used;
-    int out_avail = (o->output_mtu - o->out_used) - sizeof(struct fragmentproto_chunk_header);
+    ASSERT(OUT_AVAIL > 0)
     
     // write chunks to output packet
     do {
-        ASSERT(in_avail >= 0)
-        ASSERT(!(in_avail == 0) || out_avail >= 0)
-        
-        // check if we have space in the output packet
-        // (if this is a zero input packet, only one chunk is written, which
-        // is always possible)
-        if (in_avail > 0 && out_avail <= 0) {
-            break;
-        }
-        
         // calculate chunk length
-        int chunk_len = in_avail;
-        if (chunk_len > out_avail) {
-            chunk_len = out_avail;
-        }
+        int chunk_len = BMIN(IN_AVAIL, OUT_AVAIL);
         if (o->chunk_mtu > 0) {
-            if (chunk_len > o->chunk_mtu) {
-                chunk_len = o->chunk_mtu;
-            }
+            chunk_len = BMIN(chunk_len, o->chunk_mtu);
         }
         
         // write chunk header
@@ -66,7 +52,7 @@ static void write_chunks (FragmentProtoDisassembler *o)
         header->frame_id = htol16(o->frame_id);
         header->chunk_start = htol16(o->in_used);
         header->chunk_len = htol16(chunk_len);
-        header->is_last = (chunk_len == in_avail);
+        header->is_last = (chunk_len == IN_AVAIL);
         
         // write chunk data
         memcpy(o->out + o->out_used + sizeof(struct fragmentproto_chunk_header), o->in + o->in_used, chunk_len);
@@ -74,56 +60,38 @@ static void write_chunks (FragmentProtoDisassembler *o)
         // increment pointers
         o->in_used += chunk_len;
         o->out_used += sizeof(struct fragmentproto_chunk_header) + chunk_len;
-        
-        in_avail = o->in_len - o->in_used;
-        out_avail = (o->output_mtu - o->out_used) - sizeof(struct fragmentproto_chunk_header);
-    } while (in_avail > 0);
+    } while (IN_AVAIL > 0 && OUT_AVAIL > 0);
     
     // have we finished the input packet?
-    if (in_avail == 0) {
+    if (IN_AVAIL == 0) {
+        // set no input packet
         o->in_len = -1;
+        
+        // increment frame ID
         o->frame_id++;
+        
+        // finish input
+        PacketPassInterface_Done(&o->input);
     }
     
     // should we finish the output packet?
-    if (
-        out_avail < 0 ||
-        (in_avail > 0 && out_avail <= 0) ||
-        o->latency < 0
-    ) {
-        // finish output packet
+    if (OUT_AVAIL <= 0 || o->latency < 0) {
+        // set no output packet
         o->out = NULL;
+        
         // stop timer (if it's running)
         if (o->latency >= 0) {
             BReactor_RemoveTimer(o->reactor, &o->timer);
         }
+        
+        // finish output
+        PacketRecvInterface_Done(&o->output, o->out_used);
     } else {
         // start timer if we have output and it's not running (output was empty before)
         if (!BTimer_IsRunning(&o->timer)) {
             BReactor_SetTimer(o->reactor, &o->timer);
         }
     }
-    
-    ASSERT(o->in_len < 0 || !o->out)
-}
-
-static void work_chunks (FragmentProtoDisassembler *o)
-{
-    ASSERT(o->in_len >= 0)
-    ASSERT(o->out)
-    
-    // write input to output
-    write_chunks(o);
-    
-    // finish input packet if needed
-    if (o->in_len == -1) {
-        PacketPassInterface_Done(&o->input);
-    }
-    
-    // finish output packet if needed
-    if (!o->out) {
-        PacketRecvInterface_Done(&o->output, o->out_used);
-    }
 }
 
 static void input_handler_send (FragmentProtoDisassembler *o, uint8_t *data, int data_len)
@@ -142,7 +110,7 @@ static void input_handler_send (FragmentProtoDisassembler *o, uint8_t *data, int
         return;
     }
     
-    work_chunks(o);
+    write_chunks(o);
 }
 
 static void input_handler_cancel (FragmentProtoDisassembler *o)
@@ -150,6 +118,7 @@ static void input_handler_cancel (FragmentProtoDisassembler *o)
     ASSERT(o->in_len >= 0)
     ASSERT(!o->out)
     
+    // set no input packet
     o->in_len = -1;
 }
 
@@ -167,7 +136,7 @@ static void output_handler_recv (FragmentProtoDisassembler *o, uint8_t *data)
         return;
     }
     
-    work_chunks(o);
+    write_chunks(o);
 }
 
 static void timer_handler (FragmentProtoDisassembler *o)
@@ -176,8 +145,10 @@ static void timer_handler (FragmentProtoDisassembler *o)
     ASSERT(o->out)
     ASSERT(o->in_len = -1)
     
-    // finish output packet
+    // set no output packet
     o->out = NULL;
+    
+    // finish output
     PacketRecvInterface_Done(&o->output, o->out_used);
 }