]> git.jsancho.org Git - mojodb.git/blob - cursor.py
Savepoints to protect integrity of documents
[mojodb.git] / cursor.py
1 # -*- coding: utf-8 -*-
2 ##############################################################################
3 #
4 #    mojo, a Python library for implementing document based databases
5 #    Copyright (C) 2013-2014 by Javier Sancho Fernandez <jsf at jsancho dot org>
6 #
7 #    This program is free software: you can redistribute it and/or modify
8 #    it under the terms of the GNU General Public License as published by
9 #    the Free Software Foundation, either version 3 of the License, or
10 #    (at your option) any later version.
11 #
12 #    This program is distributed in the hope that it will be useful,
13 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
14 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 #    GNU General Public License for more details.
16 #
17 #    You should have received a copy of the GNU General Public License
18 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20 ##############################################################################
21
22
class Cursor(object):
    """Iterate over the documents of a collection that match a query spec.

    The SQL-like query is assembled with the ``Query``/``Field``/``Table``/
    ``Constraint``/``Literal`` factories exposed by the collection's
    connection and executed through ``connection._get_cursor``; each row is
    turned back into a document dict with ``connection.serializer``.
    Field values are read from tables named ``<table_name>$<top-level field>``
    and document ids from ``<table_name>$_id``.

    Usable as an iterator on both Python 2 (``next``) and Python 3
    (``__next__``).
    """

    def __init__(self, collection, spec=None, fields=None, **kwargs):
        """Build the cursor and, if the collection exists, open its query.

        :param collection: collection whose documents are iterated.
        :param spec: optional dict mapping (possibly dotted) field names to the
            values a document must have; ``None``/empty matches everything.
        :param fields: projection -- ``None`` for every field, a dict with
            all-truthy (include) or all-falsy (exclude) values plus an
            optional ``'_id'`` entry, or any iterable of field names to
            include.
        :param kwargs: accepted but not used by this class.
        :raises Exception: if ``spec`` is given and is not a dict.
        """
        if spec and not isinstance(spec, dict):
            raise Exception("spec must be an instance of dict")

        connection = collection.database.connection
        self.Query = connection.Query
        self.Field = connection.Field
        self.Table = connection.Table
        self.Constraint = connection.Constraint
        self.Literal = connection.Literal

        self.serializer = connection.serializer

        self.collection = collection
        self.spec = spec
        if self.collection.exists():
            self.fields = self._get_fields(fields)
            self.cursor = self._get_cursor()
        else:
            # Nothing to read: next() stops immediately when cursor is None.
            self.fields = None
            self.cursor = None

    def __iter__(self):
        return self

    def _get_table(self, field_name):
        """Return the Table holding ``field_name`` (dotted names use their first component)."""
        return self.Table(self.collection.database.db_name,
                          '%s$%s' % (self.collection.table_name, field_name.split(".")[0]))

    def _get_fields(self, fields):
        """Resolve a projection into the list of field names to return.

        :param fields: ``None``, a projection dict, or an iterable of names
            (see ``__init__``).
        :returns: list of field names (order unspecified); ``'_id'`` is
            included unless the dict projection sets it to a falsy value.
        :raises Exception: if a dict projection mixes include and exclude
            entries (``'_id'`` is exempt from this rule).
        """
        all_fields = set(self.collection._get_fields())
        if fields is None:
            return list(all_fields)

        if not isinstance(fields, dict):
            # Plain iterable of names: keep the known ones, always add _id.
            wanted = set(fields)
            wanted.add('_id')
            return list(all_fields.intersection(wanted))

        entries = [(name, flag) for name, flag in fields.items() if name != '_id']
        # The first non-_id entry decides the mode; with no such entries
        # (e.g. {} or {'_id': 0}) nothing is excluded.
        including = bool(entries) and bool(entries[0][1])
        selected = set() if including else set(all_fields)
        for name, flag in entries:
            # Checked before the existence test so the outcome does not
            # depend on dict ordering.
            if bool(flag) != including:
                raise Exception("You cannot currently mix including and excluding fields. Contact us if this is an issue.")
            if not including:
                selected.discard(name)
            elif name in all_fields:
                selected.add(name)

        if '_id' in fields and not fields['_id']:
            selected.discard('_id')
        else:
            selected.add('_id')
        return list(selected)

    def _get_cursor(self):
        """Build the main query and open it through the connection.

        Column 0 is the serialized ``_id``; columns 1..n are sub-queries for
        every non-``_id`` entry of ``self.fields``, in that order (``next``
        relies on this ordering).
        """
        table_id = self._get_table('_id')

        columns = [self.Field(table_id, 'value')]
        columns.extend(self._get_cursor_field(name)
                       for name in self.fields if name != '_id')

        tables = [table_id]

        constraints = [self.Constraint('=', self.Field(table_id, 'name'), self.Literal('_id'))]
        if self.spec:
            constraints.extend(self._get_cursor_constraint(name, value)
                               for name, value in self.spec.items())

        query = self.Query(columns, tables, constraints)
        return self.collection.database.connection._get_cursor(query)

    def _get_cursor_field(self, field_name):
        """Return a sub-query selecting ``field_name``'s value for the current id row."""
        table_id = self._get_table('_id')
        table_field = self._get_table(field_name)

        fields = [self.Field(table_field, 'value')]
        tables = [table_field]
        constraints = [
            self.Constraint('=', self.Field(table_field, 'id'), self.Field(table_id, 'id')),
            self.Constraint('=', self.Field(table_field, 'name'), self.Literal(field_name)),
            ]
        return self.Query(fields, tables, constraints)

    def _get_cursor_constraint(self, field_name, field_value):
        """Return an ``id IN (...)`` constraint matching ``field_name == field_value``.

        Exactly-``int``/``float`` values are compared against the ``number``
        column; anything else (including ``bool``, since ``type()`` is an
        exact check) is serialized and compared against ``value``.
        """
        table_id = self._get_table('_id')
        table_field = self._get_table(field_name)

        if type(field_value) in (int, float):
            field_type = 'number'
        else:
            field_type = 'value'
            field_value = self.serializer.dumps(field_value)

        fields = [self.Field(table_field, 'id')]
        tables = [table_field]
        constraints = [
            # NOTE(review): the 'starts' branch presumably matches entries
            # stored under names prefixed with '<field>..' (e.g. list items);
            # confirm against the writer side.
            self.Constraint('or', self.Constraint('=', self.Field(table_field, 'name'), self.Literal(field_name)),
                                  self.Constraint('starts', self.Field(table_field, 'name'), self.Literal('%s..' % field_name))),
            self.Constraint('=', self.Field(table_field, field_type), self.Literal(field_value)),
            ]

        return self.Constraint('in', self.Field(table_id, 'id'), self.Query(fields, tables, constraints))

    def next(self):
        """Return the next matching document as a dict.

        Fields whose column is ``None`` are left out of the document.

        :raises StopIteration: when there is no cursor (missing collection),
            the cursor is empty/falsy, or the connection returns no more rows.
        """
        if not self.cursor:
            # Previously a falsy cursor returned None forever, which made
            # `for doc in cursor` loop endlessly; treat it as exhausted.
            raise StopIteration

        row = self.collection.database.connection._next(self.cursor)
        if row is None:
            raise StopIteration

        document = {}
        if '_id' in self.fields:
            document['_id'] = self.serializer.loads(row[0])
        other_fields = [name for name in self.fields if name != '_id']
        # Columns after the id follow self.fields order (see _get_cursor).
        for column, name in enumerate(other_fields, 1):
            raw = row[column]
            if raw is not None:
                document[name] = self.serializer.loads(raw)
        return document

    # Python 3 iterator protocol.
    __next__ = next