Savepoints to protect integrity of documents
[mojodb.git] / collection.py
1 # -*- coding: utf-8 -*-
2 ##############################################################################
3 #
4 #    mojo, a Python library for implementing document based databases
5 #    Copyright (C) 2013-2014 by Javier Sancho Fernandez <jsf at jsancho dot org>
6 #
7 #    This program is free software: you can redistribute it and/or modify
8 #    it under the terms of the GNU General Public License as published by
9 #    the Free Software Foundation, either version 3 of the License, or
10 #    (at your option) any later version.
11 #
12 #    This program is distributed in the hope that it will be useful,
13 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
14 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 #    GNU General Public License for more details.
16 #
17 #    You should have received a copy of the GNU General Public License
18 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20 ##############################################################################
21
22 from cursor import Cursor
23 from objectid import ObjectId
24
25 class Collection(object):
26     def __init__(self, database, table_name):
27         self.database = database
28         self.table_name = str(table_name)
29
30     def __repr__(self):
31         return "Collection(%r, %r)" % (self.database, self.table_name)
32
33     def exists(self):
34         return (self.database.exists() and self.table_name in self.database.collection_names())
35
36     def _create_field(self, field_name):
37         fields = [
38             {'name': 'id', 'type': 'char', 'size': 512, 'primary': True},
39             {'name': 'name', 'type': 'char', 'size': 64, 'primary': True},
40             {'name': 'value', 'type': 'text', 'null': False},
41             {'name': 'number', 'type': 'float'},
42             ]
43         return self.database.connection._create_table(self.database.db_name, '%s$%s' % (self.table_name, field_name), fields)
44
45     def _get_fields(self):
46         tables = self.database.connection._get_tables(self.database.db_name)
47         return [str(x[x.find('$')+1:]) for x in filter(lambda x: x.startswith('%s$' % self.table_name), tables)]
48
49     def count(self):
50         return self.database.connection._count(self.database.db_name, self.table_name)
51
52     def find(self, *args, **kwargs):
53         return Cursor(self, *args, **kwargs)
54
55     def insert(self, doc_or_docs):
56         if not self.database.db_name in self.database.connection.database_names():
57             self.database._create_database()
58         if not self.table_name in self.database.collection_names():
59             self._create_field('_id')
60
61         if not type(doc_or_docs) in (list, tuple):
62             docs = [doc_or_docs]
63         else:
64             docs = doc_or_docs
65         for doc in docs:
66             doc_id = str(ObjectId())
67             if not '_id' in doc:
68                 doc['_id'] = doc_id
69             self._insert_document(doc_id, doc)
70
71         if type(doc_or_docs) in (list, tuple):
72             return [d['_id'] for d in docs]
73         else:
74             return docs[0]['_id']
75
76     def _insert_document(self, doc_id, doc):
77         fields = self._get_fields()
78         self.database.connection.savepoint("insert_document")
79         try:
80             for f in doc:
81                 if not f in fields:
82                     self._create_field(f)
83                 table_f = '%s$%s' % (self.table_name, f)
84                 self._insert_field(doc_id, table_f, f, doc[f])
85             self.database.connection.commit_savepoint("insert_document")
86         except:
87             self.database.connection.rollback_savepoint("insert_document")
88             raise
89
90     def _insert_field(self, doc_id, field_table, field_name, field_value):
91         values = {
92             'id': doc_id,
93             'name': field_name,
94             'value': self.database.connection.serializer.dumps(field_value),
95             }
96         if type(field_value) in (int, float):
97             values['number'] = field_value
98
99         self.database.connection._insert(self.database.db_name, field_table, values)
100
101         if type(field_value) in (list, tuple) and not '.' in field_name:
102             for i in xrange(len(field_value)):
103                 self._insert_field(doc_id, field_table, "%s..%s" % (field_name, i), field_value[i])
104         elif type(field_value) is dict:
105             for k, v in field_value.iteritems():
106                 self._insert_field(doc_id, field_table, "%s.%s" % (field_name, k), v)