]> git.jsancho.org Git - mojodb.git/blobdiff - collection.py
Some tests
[mojodb.git] / collection.py
index 664961a040680aebadf96e4e3724e1e765713b9c..e0f00026ba5dea8c411cfc4cd84782bd78df1582 100644 (file)
 #
 ##############################################################################
 
-import cPickle
 from cursor import Cursor
 import uuid
 
 class Collection(object):
     def __init__(self, database, table_name):
         self.database = database
-        self.table_name = unicode(table_name)
+        self.table_name = str(table_name)
 
     def __repr__(self):
         return "Collection(%r, %r)" % (self.database, self.table_name)
@@ -34,15 +33,10 @@ class Collection(object):
     def exists(self):
         return (self.database.exists() and self.table_name in self.database.collection_names())
 
-    def _create_table(self):
-        fields = [
-            {'name': 'id', 'type': 'char', 'size': 32, 'primary': True},
-            ]
-        return self.database.connection._create_table(self.database.db_name, '%s$_id' % self.table_name, fields)
-
     def _create_field(self, field_name):
         fields = [
-            {'name': 'id', 'type': 'char', 'size': 32, 'primary': True},
+            {'name': 'id', 'type': 'char', 'size': 512, 'primary': True},
+            {'name': 'name', 'type': 'char', 'size': 64, 'primary': True},
             {'name': 'value', 'type': 'text', 'null': False},
             {'name': 'number', 'type': 'float'},
             ]
@@ -50,7 +44,7 @@ class Collection(object):
 
     def _get_fields(self):
         tables = self.database.connection._get_tables(self.database.db_name)
-        return [unicode(x[x.find('$')+1:]) for x in filter(lambda x: x.startswith('%s$' % self.table_name), tables)]
+        return [str(x[x.find('$')+1:]) for x in filter(lambda x: x.startswith('%s$' % self.table_name), tables)]
 
     def count(self):
         return self.database.connection._count(self.database.db_name, self.table_name)
@@ -62,36 +56,45 @@ class Collection(object):
         if not self.database.db_name in self.database.connection.database_names():
             self.database._create_database()
         if not self.table_name in self.database.collection_names():
-            self._create_table()
+            self._create_field('_id')
 
         if not type(doc_or_docs) in (list, tuple):
             docs = [doc_or_docs]
         else:
             docs = doc_or_docs
         for doc in docs:
+            doc_id = uuid.uuid4().hex
             if not '_id' in doc:
-                doc['_id'] = uuid.uuid4().hex
-            self._insert_document(doc)
+                doc['_id'] = doc_id
+            self._insert_document(doc_id, doc)
 
         if type(doc_or_docs) in (list, tuple):
             return [d['_id'] for d in docs]
         else:
             return docs[0]['_id']
 
-    def _insert_document(self, doc):
-        table_id = '%s$_id' % self.table_name
+    def _insert_document(self, doc_id, doc):
         fields = self._get_fields()
-        self.database.connection._insert(self.database.db_name, table_id, {'id': doc['_id']})
         for f in doc:
-            if f == '_id':
-                continue
             if not f in fields:
                 self._create_field(f)
             table_f = '%s$%s' % (self.table_name, f)
-            values = {
-                'id': doc['_id'],
-                'value': cPickle.dumps(doc[f]),
-                }
-            if type(doc[f]) in (int, float):
-                values['number'] = doc[f]
-            self.database.connection._insert(self.database.db_name, table_f, values)
+            self._insert_field(doc_id, table_f, f, doc[f])
+
+    def _insert_field(self, doc_id, field_table, field_name, field_value):
+        values = {
+            'id': doc_id,
+            'name': field_name,
+            'value': self.database.connection.serializer.dumps(field_value),
+            }
+        if type(field_value) in (int, float):
+            values['number'] = field_value
+
+        self.database.connection._insert(self.database.db_name, field_table, values)
+
+        if type(field_value) in (list, tuple) and not '.' in field_name:
+            for i in xrange(len(field_value)):
+                self._insert_field(doc_id, field_table, "%s..%s" % (field_name, i), field_value[i])
+        elif type(field_value) is dict:
+            for k, v in field_value.iteritems():
+                self._insert_field(doc_id, field_table, "%s.%s" % (field_name, k), v)