]> git.jsancho.org Git - mojodb.git/blob - cursor.py
202c6c78aabb6429a8b8cd012b616e9f921a6169
[mojodb.git] / cursor.py
1 # -*- coding: utf-8 -*-
2 ##############################################################################
3 #
4 #    mojo, a Python library for implementing document based databases
5 #    Copyright (C) 2013-2014 by Javier Sancho Fernandez <jsf at jsancho dot org>
6 #
7 #    This program is free software: you can redistribute it and/or modify
8 #    it under the terms of the GNU General Public License as published by
9 #    the Free Software Foundation, either version 3 of the License, or
10 #    (at your option) any later version.
11 #
12 #    This program is distributed in the hope that it will be useful,
13 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
14 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 #    GNU General Public License for more details.
16 #
17 #    You should have received a copy of the GNU General Public License
18 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20 ##############################################################################
21
22 import cPickle
23
class Cursor(object):
    """Iterator over the documents of a collection that match a query.

    The cursor translates a (spec, fields) pair into the query structure
    handed to ``collection.database.connection._get_cursor`` and converts
    every row the connection returns into a document dict.
    """

    def __init__(self, collection, spec=None, fields=None, **kwargs):
        """Prepare a cursor over ``collection``.

        :param collection: collection being queried; it must provide
            ``exists()``, ``_get_fields()``, ``table_name`` and ``database``.
        :param spec: optional dict mapping field names to the values they
            must be equal to.
        :param fields: optional projection: ``None`` for every field, a dict
            of ``{name: include_flag}``, or an iterable of field names.
        :param kwargs: accepted and ignored.
        :raises Exception: if ``spec`` is truthy and not a dict.
        """
        if spec and not isinstance(spec, dict):
            raise Exception("spec must be an instance of dict")

        self.collection = collection
        self.spec = spec
        if self.collection.exists():
            self.fields = self._get_fields(fields)
            self.cursor = self._get_cursor()
        else:
            # Nothing to query: next() will stop immediately.
            self.fields = None
            self.cursor = None

    def __iter__(self):
        return self

    def _fields_without_id(self):
        """Return the selected field names other than ``_id``, in order.

        ``_get_cursor`` selects one value column per name in this order and
        ``next`` relies on the same order to map columns back to names.
        """
        return [f for f in self.fields if f != '_id']

    def _get_fields(self, fields):
        """Resolve a projection into the list of field names to return.

        :param fields: ``None`` (every field of the collection), a dict
            projection, or an iterable of field names to include (``_id`` is
            always added in that case).  In a dict projection the non-``_id``
            entries must be all truthy (inclusion) or all falsy (exclusion).
            ``_id`` is kept unless it is explicitly mapped to a falsy value.
            A dict without non-``_id`` entries selects only ``_id``, or every
            field except ``_id`` when ``_id`` is mapped to a falsy value.
        :returns: list of field names (order unspecified).
        :raises Exception: when a dict projection mixes inclusion and
            exclusion.
        """
        all_fields = set(self.collection._get_fields())
        if fields is None:
            return list(all_fields)

        if not isinstance(fields, dict):
            wanted = set(fields)
            wanted.add('_id')
            return list(all_fields & wanted)

        mix_error = ("You cannot currently mix including and excluding "
                     "fields. Contact us if this is an issue.")
        others = [(name, flag) for name, flag in fields.items() if name != '_id']
        exclude_id = '_id' in fields and not fields['_id']

        if others:
            # The first non-_id entry decides between inclusion and exclusion.
            including = bool(others[0][1])
        else:
            # Only `_id` (or nothing) was given; indexing others[0] would fail.
            # {'_id': falsy} keeps every other field (MongoDB semantics).
            # Anything else keeps only `_id`, matching the iterable branch,
            # where an empty list also selects just `_id`.
            including = not exclude_id

        selected = set() if including else set(all_fields)
        for name, flag in others:
            if flag and name in all_fields:
                if not including:
                    raise Exception(mix_error)
                selected.add(name)
            elif not flag:
                if including:
                    raise Exception(mix_error)
                selected.discard(name)

        if exclude_id:
            selected.discard('_id')
        else:
            selected.add('_id')
        return list(selected)

    def _get_cursor(self):
        """Build the query for this cursor and open it on the connection.

        The first selected column is the document id.  One sub-select per
        non-``_id`` field follows, in ``_fields_without_id`` order.  Each spec
        entry becomes an equality condition.  Values whose exact type is
        ``int`` or ``float`` are compared against the field's ``number``
        column.  Everything else (including ``bool``, and ``long`` on
        Python 2) is compared, pickled, against the ``value`` column.

        :returns: whatever ``connection._get_cursor`` returns.
        """
        table_name = self.collection.table_name
        table_id = '%s$_id' % table_name

        query = {}
        query['select'] = [(table_id, 'id')]
        for f in self._fields_without_id():
            table_f = '%s$%s' % (table_name, f)
            query['select'].append(self._get_cursor_field(table_id, table_f))

        query['from'] = [table_id]

        if self.spec:
            query['where'] = []
            for k, v in self.spec.items():
                table_f = '%s$%s' % (table_name, k)
                if type(v) in (int, float):
                    field_q = self._get_cursor_field(table_id, table_f, field_name='number')
                    query['where'].append((field_q, '=', v))
                else:
                    field_q = self._get_cursor_field(table_id, table_f)
                    query['where'].append((field_q, '=', cPickle.dumps(v)))

        database = self.collection.database
        return database.connection._get_cursor(database.db_name, query)

    def _get_cursor_field(self, table_id, table_field, field_name='value'):
        """Return a sub-query selecting ``field_name`` of ``table_field``.

        The sub-query is joined to the id table on the ``id`` column.
        """
        return {
            'select': [(table_field, field_name)],
            'from': [table_field],
            'where': [((table_field, 'id'), '=', (table_id, 'id'))],
            }

    def next(self):
        """Return the next matching document as a dict.

        The ``_id`` key is present only when ``_id`` is a selected field.
        Fields whose stored value is ``None`` are omitted.

        :raises StopIteration: when the collection did not exist, the cursor
            is empty/falsy, or the connection has no more rows.
        """
        # A falsy cursor used to make next() return None, which turns a
        # for-loop into an endless stream of None.  Treat it as exhausted.
        if not self.cursor:
            raise StopIteration

        row = self.collection.database.connection._next(self.cursor)
        if row is None:
            raise StopIteration

        document = {}
        if '_id' in self.fields:
            document['_id'] = row[0]
        # Column 0 is always the id; field values start at column 1.
        for i, name in enumerate(self._fields_without_id(), 1):
            if row[i] is not None:
                # NOTE(review): this unpickles database contents.  It is only
                # safe if the database is trusted.
                document[name] = cPickle.loads(row[i])
        return document

    # Python 3 iterator protocol.
    __next__ = next