Ver código fonte

structure: BAVL: add count keeping, ifdef'd out

ambrop7 14 anos atrás
pai
commit
803f79fc63
1 arquivos alterados com 121 adições e 0 exclusões
  1. 121 0
      structure/BAVL.h

+ 121 - 0
structure/BAVL.h

@@ -69,6 +69,9 @@ typedef struct BAVLNode {
     struct BAVLNode *parent;
     struct BAVLNode *link[2];
     int8_t balance;
+#ifdef BAVL_COUNT
+    uint64_t count;
+#endif
 } BAVLNode;
 
 /**
@@ -175,6 +178,12 @@ static BAVLNode * BAVL_GetNext (const BAVL *o, BAVLNode *n);
  */
 static BAVLNode * BAVL_GetPrev (const BAVL *o, BAVLNode *n);
 
+#ifdef BAVL_COUNT
+static uint64_t BAVL_Count (const BAVL *o);
+static uint64_t BAVL_IndexOf (const BAVL *o, BAVLNode *n);
+static BAVLNode * BAVL_GetAt (const BAVL *o, uint64_t index);
+#endif
+
 #define BAVL_MAX(_a, _b) ((_a) > (_b) ? (_a) : (_b))
 #define BAVL_OPTNEG(_a, _neg) ((_neg) ? -(_a) : (_a))
 
@@ -211,6 +220,10 @@ static int _BAVL_assert_recurser (BAVL *o, BAVLNode *n)
     
     int height_left = 0;
     int height_right = 0;
+#ifdef BAVL_COUNT
+    uint64_t count_left = 0;
+    uint64_t count_right = 0;
+#endif
     
     // check left subtree
     if (n->link[0]) {
@@ -220,6 +233,9 @@ static int _BAVL_assert_recurser (BAVL *o, BAVLNode *n)
         ASSERT(_BAVL_compare_nodes(o, n->link[0], n) == -1)
         // recursively calculate height
         height_left = _BAVL_assert_recurser(o, n->link[0]);
+#ifdef BAVL_COUNT
+        count_left = n->link[0]->count;
+#endif
     }
     
     // check right subtree
@@ -230,11 +246,19 @@ static int _BAVL_assert_recurser (BAVL *o, BAVLNode *n)
         ASSERT(_BAVL_compare_nodes(o, n->link[1], n) == 1)
         // recursively calculate height
         height_right = _BAVL_assert_recurser(o, n->link[1]);
+#ifdef BAVL_COUNT
+        count_right = n->link[1]->count;
+#endif
     }
     
     // check balance factor
     ASSERT(n->balance == height_right - height_left)
     
+#ifdef BAVL_COUNT
+    // check count
+    ASSERT(n->count == 1 + count_left + count_right)
+#endif
+    
     return (BAVL_MAX(height_left, height_right) + 1);
 }
 
@@ -248,6 +272,13 @@ static void _BAVL_assert (BAVL *o)
 
 #endif
 
+#ifdef BAVL_COUNT
+static void _BAVL_update_count_from_children (BAVLNode *n)
+{
+    n->count = 1 + (n->link[0] ? n->link[0]->count : 0) + (n->link[1] ? n->link[1]->count : 0);
+}
+#endif
+
 static void _BAVL_rotate (BAVL *tree, BAVLNode *r, uint8_t dir)
 {
     BAVLNode *nr = r->link[!dir];
@@ -264,6 +295,12 @@ static void _BAVL_rotate (BAVL *tree, BAVLNode *r, uint8_t dir)
         tree->root = nr;
     }
     r->parent = nr;
+    
+#ifdef BAVL_COUNT
+    // update counts
+    _BAVL_update_count_from_children(r); // first r!
+    _BAVL_update_count_from_children(nr);
+#endif
 }
 
 static BAVLNode * _BAVL_subtree_max (BAVLNode *n)
@@ -287,6 +324,18 @@ static void _BAVL_replace_subtree (BAVL *tree, BAVLNode *dest, BAVLNode *n)
     if (n) {
         n->parent = dest->parent;
     }
+    
+#ifdef BAVL_COUNT
+    // update counts
+    for (BAVLNode *c = dest->parent; c; c = c->parent) {
+        ASSERT(c->count >= dest->count)
+        c->count -= dest->count;
+        if (n) {
+            ASSERT(n->count <= UINT64_MAX - c->count)
+            c->count += n->count;
+        }
+    }
+#endif
 }
 
 static void _BAVL_swap_nodes (BAVL *tree, BAVLNode *n1, BAVLNode *n2)
@@ -360,6 +409,13 @@ static void _BAVL_swap_nodes (BAVL *tree, BAVLNode *n1, BAVLNode *n2)
     int8_t b = n1->balance;
     n1->balance = n2->balance;
     n2->balance = b;
+    
+#ifdef BAVL_COUNT
+    // swap counts
+    uint64_t c = n1->count;
+    n1->count = n2->count;
+    n2->count = c;
+#endif
 }
 
 static void _BAVL_rebalance (BAVL *o, BAVLNode *node, uint8_t side, int8_t deltac)
@@ -462,6 +518,9 @@ int BAVL_Insert (BAVL *o, BAVLNode *node, BAVLNode **ref)
         node->link[0] = NULL;
         node->link[1] = NULL;
         node->balance = 0;
+#ifdef BAVL_COUNT
+        node->count = 1;
+#endif
         
         BAVL_ASSERT(o)
         
@@ -502,6 +561,17 @@ int BAVL_Insert (BAVL *o, BAVLNode *node, BAVLNode **ref)
     node->link[0] = NULL;
     node->link[1] = NULL;
     node->balance = 0;
+#ifdef BAVL_COUNT
+    node->count = 1;
+#endif
+    
+#ifdef BAVL_COUNT
+    // update counts
+    for (BAVLNode *p = c; p; p = p->parent) {
+        ASSERT(p->count < UINT64_MAX)
+        p->count++;
+    }
+#endif
     
     // rebalance
     _BAVL_rebalance(o, c, side, 1);
@@ -667,4 +737,55 @@ BAVLNode * BAVL_GetPrev (const BAVL *o, BAVLNode *n)
     return n;
 }
 
+#ifdef BAVL_COUNT
+
+static uint64_t BAVL_Count (const BAVL *o)
+{
+    return (o->root ? o->root->count : 0);
+}
+
+static uint64_t BAVL_IndexOf (const BAVL *o, BAVLNode *n)
+{
+    uint64_t index = (n->link[0] ? n->link[0]->count : 0);
+    
+    for (BAVLNode *c = n; c->parent; c = c->parent) {
+        if (c == c->parent->link[1]) {
+            ASSERT(c->parent->count > c->count)
+            ASSERT(c->parent->count - c->count <= UINT64_MAX - index)
+            index += c->parent->count - c->count;
+        }
+    }
+    
+    return index;
+}
+
+static BAVLNode * BAVL_GetAt (const BAVL *o, uint64_t index)
+{
+    if (index >= BAVL_Count(o)) {
+        return NULL;
+    }
+    
+    BAVLNode *c = o->root;
+    
+    while (1) {
+        ASSERT(c)
+        ASSERT(index < c->count)
+        
+        uint64_t left_count = (c->link[0] ? c->link[0]->count : 0);
+        
+        if (index == left_count) {
+            return c;
+        }
+        
+        if (index < left_count) {
+            c = c->link[0];
+        } else {
+            c = c->link[1];
+            index -= left_count + 1;
+        }
+    }
+}
+
+#endif
+
 #endif