]> git.jsancho.org Git - mojodb.git/blobdiff - cursor.py
Separate classes and functionality into various files
[mojodb.git] / cursor.py
diff --git a/cursor.py b/cursor.py
new file mode 100644 (file)
index 0000000..73a0c13
--- /dev/null
+++ b/cursor.py
@@ -0,0 +1,122 @@
+# -*- coding: utf-8 -*-
+##############################################################################
+#
+#    mojo, a Python library for implementing document based databases
+#    Copyright (C) 2013-2014 by Javier Sancho Fernandez <jsf at jsancho dot org>
+#
+#    This program is free software: you can redistribute it and/or modify
+#    it under the terms of the GNU General Public License as published by
+#    the Free Software Foundation, either version 3 of the License, or
+#    (at your option) any later version.
+#
+#    This program is distributed in the hope that it will be useful,
+#    but WITHOUT ANY WARRANTY; without even the implied warranty of
+#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+#    GNU General Public License for more details.
+#
+#    You should have received a copy of the GNU General Public License
+#    along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+##############################################################################
+
+import cPickle
+
+class Cursor(object):
+    def __init__(self, collection, spec=None, fields=None, **kwargs):
+        if spec and not type(spec) is dict:
+            raise Exception("spec must be an instance of dict")
+
+        self.collection = collection
+        self.spec = spec
+        if self.collection.exists():
+            self.fields = self._get_fields(fields)
+            self.cursor = self._get_cursor()
+        else:
+            self.fields = None
+            self.cursor = None
+
+    def __iter__(self):
+        return self
+
+    def _get_fields(self, fields):
+        set_all_fields = set(self.collection._get_fields())
+        if fields is None:
+            res_fields = list(set_all_fields)
+        elif type(fields) is dict:
+            fields_without_id = filter(lambda x: x[0] != '_id', fields.iteritems())
+            if fields_without_id[0][1]:
+                first = True
+                res_fields = set()
+            else:
+                first = False
+                res_fields = set(set_all_fields)
+            for f in fields_without_id:
+                if f[1] and f[0] in set_all_fields:
+                    if first:
+                        res_fields.add(f[0])
+                    else:
+                        raise Exception("You cannot currently mix including and excluding fields. Contact us if this is an issue.")
+                elif not f[1]:
+                    if not first:
+                        res_fields.discard(f[0])
+                    else:
+                        raise Exception("You cannot currently mix including and excluding fields. Contact us if this is an issue.")
+            if '_id' in fields and not fields['_id']:
+                res_fields.discard('_id')
+            else:
+                res_fields.add('_id')
+            res_fields = list(res_fields)
+        else:
+            set_fields = set(list(fields))
+            set_fields.add('_id')
+            res_fields = list(set_all_fields.intersection(set_fields))
+
+        return res_fields
+
+    def _get_cursor(self):
+        query = {}
+        table_id = '%s$_id' % self.collection.table_name
+
+        query['select'] = [(table_id, 'id')]
+        for f in filter(lambda x: x != '_id', self.fields):
+            table_f = '%s$%s' % (self.collection.table_name, f)
+            q = self._get_cursor_field(table_id, table_f)
+            query['select'].append(q)
+
+        query['from'] = [table_id]
+
+        if self.spec:
+            query['where'] = []
+            for k, v in self.spec.iteritems():
+                table_f = '%s$%s' % (self.collection.table_name, k)
+                field_q = self._get_cursor_field(table_id, table_f)
+                query['where'].append((field_q, '=', v))
+
+        return self.collection.database.connection._get_cursor(self.collection.database.db_name, query)
+
+    def _get_cursor_field(self, table_id, table_field):
+        return {
+            'select': [(table_field, 'value')],
+            'from': [table_field],
+            'where': [((table_field, 'id'), '=', (table_id, 'id'))],
+            }
+
+    def next(self):
+        if self.cursor is None:
+            raise StopIteration
+
+        if self.cursor:
+            res = self.collection.database.connection._next(self.cursor)
+            if res is None:
+                raise StopIteration
+            else:
+                document = {}
+                if '_id' in self.fields:
+                    document['_id'] = res[0]
+                fields_without_id = filter(lambda x: x != '_id', self.fields)
+                for i in xrange(len(fields_without_id)):
+                    if not res[i + 1] is None:
+                        document[fields_without_id[i]] = cPickle.loads(res[i + 1])
+                return document
+        else:
+            return None