瀏覽代碼

ncd: NCDStringIndex: support null bytes in strings

ambrop7 13 年之前
父節點
當前提交
c946119f9d
共有 7 個文件被更改,包括 43 次插入24 次删除
  1. 5 9
      ncd/NCDInterpProcess.c
  2. 3 0
      ncd/NCDModule.c
  3. 21 5
      ncd/NCDStringIndex.c
  4. 3 0
      ncd/NCDStringIndex.h
  5. 2 2
      ncd/NCDStringIndex_hash.h
  6. 6 5
      ncd/NCDVal.c
  7. 3 3
      ncd/modules/value.c

+ 5 - 9
ncd/NCDInterpProcess.c

@@ -78,17 +78,13 @@ static int convert_value_recurser (NCDPlaceholderDb *pdb, NCDStringIndex *string
             const char *str = NCDValue_StringValue(value);
             size_t len = NCDValue_StringLength(value);
             
-            if (strlen(str) == len) {
-                NCD_string_id_t string_id = NCDStringIndex_Get(string_index, str);
-                if (string_id < 0) {
-                    BLog(BLOG_ERROR, "NCDStringIndex_Get failed");
-                    goto fail;
-                }
-                *out = NCDVal_NewIdString(mem, string_id, string_index);
-            } else {
-                *out = NCDVal_NewStringBin(mem, (const uint8_t *)str, len);
+            NCD_string_id_t string_id = NCDStringIndex_GetBin(string_index, str, len);
+            if (string_id < 0) {
+                BLog(BLOG_ERROR, "NCDStringIndex_GetBin failed");
+                goto fail;
             }
             
+            *out = NCDVal_NewIdString(mem, string_id, string_index);
             if (NCDVal_IsInvalid(*out)) {
                 goto fail;
             }

+ 3 - 0
ncd/NCDModule.c

@@ -210,6 +210,9 @@ static int object_func_getvar (const NCDObject *obj, NCD_string_id_t name, NCDVa
     if (n->m->func_getvar2) {
         res = n->m->func_getvar2(n->mem, name, mem, out_value);
     } else {
+        if (NCDStringIndex_HasNulls(n->params->iparams->string_index, name)) {
+            return 0;
+        }
         const char *name_str = NCDStringIndex_Value(n->params->iparams->string_index, name);
         res = n->m->func_getvar(n->mem, name_str, mem, out_value);
     }

+ 21 - 5
ncd/NCDStringIndex.c

@@ -92,11 +92,6 @@ static NCD_string_id_t do_get (NCDStringIndex *o, const char *str, size_t str_le
         return ref.link;
     }
     
-    if (memchr(str, '\0', str_len)) {
-        BLog(BLOG_ERROR, "cannot store strings with nulls");
-        return -1;
-    }
-    
     if (o->entries_size == o->entries_capacity) {
         if (!Array_DoubleUp(o)) {
             BLog(BLOG_ERROR, "Array_DoubleUp failed");
@@ -118,6 +113,7 @@ static NCD_string_id_t do_get (NCDStringIndex *o, const char *str, size_t str_le
         return -1;
     }
     entry->str_len = str_len;
+    entry->has_nulls = !!memchr(str, '\0', str_len);
     
     NCDStringIndex__HashRef newref = {entry, o->entries_size};
     int res = NCDStringIndex__Hash_Insert(&o->hash, o->entries, newref, NULL);
@@ -220,6 +216,26 @@ const char * NCDStringIndex_Value (NCDStringIndex *o, NCD_string_id_t id)
     return o->entries[id].str;
 }
 
+size_t NCDStringIndex_Length (NCDStringIndex *o, NCD_string_id_t id)
+{
+    DebugObject_Access(&o->d_obj);
+    ASSERT(id >= 0)
+    ASSERT(id < o->entries_size)
+    ASSERT(o->entries[id].str)
+    
+    return o->entries[id].str_len;
+}
+
+int NCDStringIndex_HasNulls (NCDStringIndex *o, NCD_string_id_t id)
+{
+    DebugObject_Access(&o->d_obj);
+    ASSERT(id >= 0)
+    ASSERT(id < o->entries_size)
+    ASSERT(o->entries[id].str)
+    
+    return o->entries[id].has_nulls;
+}
+
 int NCDStringIndex_GetRequests (NCDStringIndex *o, struct NCD_string_request *requests)
 {
     DebugObject_Access(&o->d_obj);

+ 3 - 0
ncd/NCDStringIndex.h

@@ -46,6 +46,7 @@ typedef int NCD_string_id_t;
 struct NCDStringIndex__entry {
     char *str;
     size_t str_len;
+    int has_nulls;
     NCD_string_id_t hash_next;
 };
 
@@ -75,6 +76,8 @@ NCD_string_id_t NCDStringIndex_LookupBin (NCDStringIndex *o, const char *str, si
 NCD_string_id_t NCDStringIndex_Get (NCDStringIndex *o, const char *str);
 NCD_string_id_t NCDStringIndex_GetBin (NCDStringIndex *o, const char *str, size_t str_len);
 const char * NCDStringIndex_Value (NCDStringIndex *o, NCD_string_id_t id);
+size_t NCDStringIndex_Length (NCDStringIndex *o, NCD_string_id_t id);
+int NCDStringIndex_HasNulls (NCDStringIndex *o, NCD_string_id_t id);
 int NCDStringIndex_GetRequests (NCDStringIndex *o, struct NCD_string_request *requests) WARN_UNUSED;
 
 #endif

+ 2 - 2
ncd/NCDStringIndex_hash.h

@@ -5,9 +5,9 @@
 #define CHASH_PARAM_ARG NCDStringIndex_hash_arg
 #define CHASH_PARAM_NULL ((NCD_string_id_t)-1)
 #define CHASH_PARAM_DEREF(arg, link) (&(arg)[(link)])
-#define CHASH_PARAM_ENTRYHASH(arg, entry) badvpn_djb2_hash((const uint8_t *)(entry).ptr->str)
+#define CHASH_PARAM_ENTRYHASH(arg, entry) badvpn_djb2_hash_bin((const uint8_t *)(entry).ptr->str, (entry).ptr->str_len)
 #define CHASH_PARAM_KEYHASH(arg, key) badvpn_djb2_hash_bin((const uint8_t *)(key).str, (key).len)
 #define CHASH_PARAM_ENTRYHASH_IS_CHEAP 0
-#define CHASH_PARAM_COMPARE_ENTRIES(arg, entry1, entry2) (!strcmp((entry1).ptr->str, (entry2).ptr->str))
+#define CHASH_PARAM_COMPARE_ENTRIES(arg, entry1, entry2) ((entry1).ptr->str_len == (entry2).ptr->str_len && !memcmp((entry1).ptr->str, (entry2).ptr->str, (entry1).ptr->str_len))
 #define CHASH_PARAM_COMPARE_KEY_ENTRY(arg, key1, entry2) ((key1).len == (entry2).ptr->str_len && !memcmp((key1).str, (entry2).ptr->str, (key1).len))
 #define CHASH_PARAM_ENTRY_NEXT hash_next

+ 6 - 5
ncd/NCDVal.c

@@ -786,8 +786,7 @@ size_t NCDVal_StringLength (NCDValRef string)
     switch (*(int *)ptr) {
         case IDSTRING_TYPE: {
             struct NCDVal__idstring *ids_e = ptr;
-            const char *value = NCDStringIndex_Value(ids_e->string_index, ids_e->string_id);
-            return strlen(value);
+            return NCDStringIndex_Length(ids_e->string_index, ids_e->string_id);
         } break;
         
         case EXTERNALSTRING_TYPE: {
@@ -926,13 +925,15 @@ int NCDVal_StringEqualsId (NCDValRef string, NCD_string_id_t string_id,
         case EXTERNALSTRING_TYPE: {
             struct NCDVal__externalstring *exs_e = ptr;
             const char *string_data = NCDStringIndex_Value(string_index, string_id);
-            return strlen(string_data) == exs_e->length && !memcmp(string_data, exs_e->data, exs_e->length);
+            size_t string_length = NCDStringIndex_Length(string_index, string_id);
+            return (string_length == exs_e->length) && !memcmp(string_data, exs_e->data, string_length);
         } break;
     }
     
-    const char *string_data = NCDStringIndex_Value(string_index, string_id);
     struct NCDVal__string *str_e = ptr;
-    return !strcmp(str_e->data, string_data) && str_e->length == strlen(string_data);
+    const char *string_data = NCDStringIndex_Value(string_index, string_id);
+    size_t string_length = NCDStringIndex_Length(string_index, string_id);
+    return (string_length == str_e->length) && !memcmp(string_data, str_e->data, string_length);
 }
 
 int NCDVal_IsList (NCDValRef val)

+ 3 - 3
ncd/modules/value.c

@@ -1011,7 +1011,7 @@ static int value_append (NCDModuleInst *i, struct value *v, NCDValRef data)
             size_t append_length = NCDVal_StringLength(data);
             
             const char *string = NCDStringIndex_Value(v->idstring.string_index, v->idstring.id);
-            size_t length = strlen(string);
+            size_t length = NCDStringIndex_Length(v->idstring.string_index, v->idstring.id);
             
             if (append_length > SIZE_MAX - length) {
                 ModuleLog(i, BLOG_ERROR, "too much data to append");
@@ -1144,7 +1144,7 @@ static int func_getvar2 (void *vo, NCD_string_id_t name, NCDValMem *mem, NCDValR
                 len = v->string.length;
                 break;
             case IDSTRING_TYPE:
-                len = strlen(NCDStringIndex_Value(v->idstring.string_index, v->idstring.id));
+                len = NCDStringIndex_Length(v->idstring.string_index, v->idstring.id);
                 break;
             default:
                 ASSERT(0);
@@ -1617,7 +1617,7 @@ static void func_new_substr (void *vo, NCDModuleInst *i, const struct NCDModuleI
     size_t string_len;
     if (mov->type == IDSTRING_TYPE) {
         string_data = NCDStringIndex_Value(mov->idstring.string_index, mov->idstring.id);
-        string_len = strlen(string_data);
+        string_len = NCDStringIndex_Length(mov->idstring.string_index, mov->idstring.id);
     } else {
         string_data = mov->string.string;
         string_len = mov->string.length;