X-Git-Url: https://git.jsancho.org/?p=mojodb.git;a=blobdiff_plain;f=collection.py;h=804135379d089cf4ebc5f0e6f0f8b3a1a8981f42;hp=7af5caaa7817e872bb856741945436f617162506;hb=HEAD;hpb=4a34db8a057d135225e70ede89b767e89f827c8f diff --git a/collection.py b/collection.py index 7af5caa..8041353 100644 --- a/collection.py +++ b/collection.py @@ -19,14 +19,13 @@ # ############################################################################## -import cPickle from cursor import Cursor -import uuid +from objectid import ObjectId class Collection(object): def __init__(self, database, table_name): self.database = database - self.table_name = unicode(table_name) + self.table_name = str(table_name) def __repr__(self): return "Collection(%r, %r)" % (self.database, self.table_name) @@ -34,15 +33,10 @@ class Collection(object): def exists(self): return (self.database.exists() and self.table_name in self.database.collection_names()) - def _create_table(self): - fields = [ - {'name': 'id', 'type': 'char', 'size': 512, 'primary': True}, - ] - return self.database.connection._create_table(self.database.db_name, '%s$_id' % self.table_name, fields) - def _create_field(self, field_name): fields = [ {'name': 'id', 'type': 'char', 'size': 512, 'primary': True}, + {'name': 'name', 'type': 'char', 'size': 64, 'primary': True}, {'name': 'value', 'type': 'text', 'null': False}, {'name': 'number', 'type': 'float'}, ] @@ -50,7 +44,7 @@ class Collection(object): def _get_fields(self): tables = self.database.connection._get_tables(self.database.db_name) - return [unicode(x[x.find('$')+1:]) for x in filter(lambda x: x.startswith('%s$' % self.table_name), tables)] + return [str(x[x.find('$')+1:]) for x in filter(lambda x: x.startswith('%s$' % self.table_name), tables)] def count(self): return self.database.connection._count(self.database.db_name, self.table_name) @@ -62,37 +56,51 @@ class Collection(object): if not self.database.db_name in self.database.connection.database_names(): self.database._create_database() if not self.table_name in self.database.collection_names(): - self._create_table() + self._create_field('_id') if not type(doc_or_docs) in (list, tuple): docs = [doc_or_docs] else: docs = doc_or_docs for doc in docs: - if not u'_id' in doc: - doc[u'_id'] = uuid.uuid4().hex - self._insert_document(doc) + doc_id = str(ObjectId()) + if not '_id' in doc: + doc['_id'] = doc_id + self._insert_document(doc_id, doc) if type(doc_or_docs) in (list, tuple): - return [d[u'_id'] for d in docs] + return [d['_id'] for d in docs] else: - return docs[0][u'_id'] + return docs[0]['_id'] - def _insert_document(self, doc): - table_id = '%s$_id' % self.table_name + def _insert_document(self, doc_id, doc): fields = self._get_fields() - coded_id = cPickle.dumps(doc['_id']) - self.database.connection._insert(self.database.db_name, table_id, {'id': coded_id}) - for f in doc: - if f == '_id': - continue - if not f in fields: - self._create_field(f) - table_f = '%s$%s' % (self.table_name, f) - values = { - 'id': coded_id, - 'value': cPickle.dumps(doc[f]), - } - if type(doc[f]) in (int, float): - values['number'] = doc[f] - self.database.connection._insert(self.database.db_name, table_f, values) + self.database.connection.savepoint("insert_document") + try: + for f in doc: + if not f in fields: + self._create_field(f) + table_f = '%s$%s' % (self.table_name, f) + self._insert_field(doc_id, table_f, f, doc[f]) + self.database.connection.commit_savepoint("insert_document") + except: + self.database.connection.rollback_savepoint("insert_document") + raise + + def _insert_field(self, doc_id, field_table, field_name, field_value): + values = { + 'id': doc_id, + 'name': field_name, + 'value': self.database.connection.serializer.dumps(field_value), + } + if type(field_value) in (int, float): + values['number'] = field_value + + self.database.connection._insert(self.database.db_name, field_table, values) + + if type(field_value) in (list, tuple) and not '.' in field_name: + for i in xrange(len(field_value)): + self._insert_field(doc_id, field_table, "%s..%s" % (field_name, i), field_value[i]) + elif type(field_value) is dict: + for k, v in field_value.iteritems(): + self._insert_field(doc_id, field_table, "%s.%s" % (field_name, k), v)