]> git.jsancho.org Git - mojodb.git/blob - mojo.py
16adc442ea6acdea1bc7bfb4976b7fb51ea1a978
[mojodb.git] / mojo.py
# -*- coding: utf-8 -*-
##############################################################################
#
#    mojo, a Python library for implementing document based databases
#    Copyright (C) 2013-2014 by Javier Sancho Fernandez <jsf at jsancho dot org>
#
#    This program is free software: you can redistribute it and/or modify
#    it under the terms of the GNU General Public License as published by
#    the Free Software Foundation, either version 3 of the License, or
#    (at your option) any later version.
#
#    This program is distributed in the hope that it will be useful,
#    but WITHOUT ANY WARRANTY; without even the implied warranty of
#    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#    GNU General Public License for more details.
#
#    You should have received a copy of the GNU General Public License
#    along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
##############################################################################


class Connection(object):
    """Root of the mojo API: attribute or item access yields a ``Database``.

    Storage is reached only through the ``_get_databases``, ``_get_tables``,
    ``_count_rows``, ``_get_cursor`` and ``_next`` hooks, which return empty
    results in this class.  Presumably concrete backends subclass it and
    override them (TODO confirm).  Collections are laid out as one table per
    field, named ``<collection>$<field>``.
    """

    # Text type used to normalise names: ``unicode`` on Python 2, ``str`` on
    # Python 3 (the ``u''`` literal is valid on both).
    _text = type(u'')

    def __init__(self, *args, **kwargs):
        # Arguments are accepted and ignored here; the backend handle starts
        # unset and is only shown by __repr__.
        self._db_con = None

    def __getattr__(self, db_name):
        """Return the ``Database`` named *db_name* (``conn.mydb``).

        Python calls this only when normal attribute lookup fails.  Names
        starting with ``_`` raise ``AttributeError`` so that ``hasattr`` and
        the ``copy``/``pickle`` protocols, which probe for private and dunder
        hooks, are not answered with a ``Database``.  Use ``conn['_name']``
        to reach such a database.

        Raises:
            AttributeError: if *db_name* starts with an underscore.
        """
        if db_name.startswith('_'):
            raise AttributeError("%r object has no attribute %r"
                                 % (type(self).__name__, db_name))
        return Database(self, db_name)

    def __getitem__(self, db_name):
        """Return the ``Database`` named *db_name* (``conn['mydb']``)."""
        return Database(self, db_name)

    def __repr__(self):
        return "Connection(%s)" % self._db_con

    def _get_databases(self):
        """Hook: names of the available databases (none in this base class)."""
        return []

    def database_names(self):
        """Return the names reported by ``_get_databases`` as text."""
        return [self._text(name) for name in self._get_databases()]

    def _get_tables(self, db_name):
        """Hook: table names stored in *db_name* (none in this base class)."""
        return []

    def collection_names(self, db_name):
        """Return the sorted, de-duplicated collection names of *db_name*.

        A collection name is the part of a table name before its first
        ``$``; tables whose name contains no ``$`` are ignored.
        """
        tables = self._get_tables(db_name)
        names = set(self._text(table.split('$', 1)[0])
                    for table in tables if '$' in table)
        return sorted(names)

    def _count_rows(self, db_name, table_name):
        """Hook: number of rows in *table_name* (0 in this base class)."""
        return 0

    def _count(self, db_name, table_name):
        """Return the number of documents in collection *table_name*.

        The count is the row count of the collection's ``<table_name>$_id``
        table.  Any backend error is treated as an empty collection.
        """
        try:
            return self._count_rows(db_name, table_name + '$_id')
        except Exception:
            # Best effort by design; ``Exception`` (not a bare ``except:``)
            # lets KeyboardInterrupt and SystemExit propagate.
            return 0

    def _get_cursor(self, db_name, query):
        """Hook: start running *query* on *db_name* and return a cursor.

        Returns ``None`` in this base class.  Example of the query shape
        built by ``Cursor``:
        """
        # {'select': [('t1$_id', 'id'), {'select': [('t1$c1', 'value')], 'from': ['t1$c1'], 'where': [(('t1$c1', 'id'), '=', ('t1$_id', 'id'))]}], 'from': ['t1$_id']}
        return None

    def _next(self, cursor):
        """Hook: next row of *cursor*, or ``None`` once it is exhausted."""
        return None


class Database(object):
    """A named database of a ``Connection``.

    Attribute or item access yields a ``Collection`` of that name.
    """

    # Text type used to normalise names: ``unicode`` on Python 2, ``str`` on
    # Python 3 (the ``u''`` literal is valid on both).
    _text = type(u'')

    def __init__(self, connection, db_name):
        self.connection = connection
        self.db_name = self._text(db_name)

    def __getattr__(self, table_name):
        """Return the ``Collection`` named *table_name* (``db.mycoll``).

        Python calls this only when normal attribute lookup fails.  Names
        starting with ``_`` raise ``AttributeError`` so that ``hasattr`` and
        the ``copy``/``pickle`` protocols keep working.  Use
        ``db['_name']`` to reach such a collection.

        Raises:
            AttributeError: if *table_name* starts with an underscore.
        """
        if table_name.startswith('_'):
            raise AttributeError("%r object has no attribute %r"
                                 % (type(self).__name__, table_name))
        return Collection(self, table_name)

    def __getitem__(self, table_name):
        """Return the ``Collection`` named *table_name* (``db['mycoll']``)."""
        return Collection(self, table_name)

    def __repr__(self):
        return "Database(%r, %r)" % (self.connection, self.db_name)

    def collection_names(self):
        """Return the connection's collection names for this database."""
        return self.connection.collection_names(self.db_name)


class Collection(object):
    """A collection of a ``Database``.

    Each field of the collection is stored in its own table, named
    ``<table_name>$<field>``.
    """

    # Text type used to normalise names: ``unicode`` on Python 2, ``str`` on
    # Python 3 (the ``u''`` literal is valid on both).
    _text = type(u'')

    def __init__(self, database, table_name):
        self.database = database
        self.table_name = self._text(table_name)

    def __repr__(self):
        return "Collection(%r, %r)" % (self.database, self.table_name)

    def _get_fields(self):
        """Return the names of the fields that have a ``<table_name>$`` table."""
        prefix = '%s$' % self.table_name
        tables = self.database.connection._get_tables(self.database.db_name)
        # Cut off the whole matched prefix, not everything up to the first
        # '$'; otherwise a '$' inside the table name corrupts the field name.
        return [self._text(table[len(prefix):])
                for table in tables if table.startswith(prefix)]

    def count(self):
        """Return the document count reported by the connection's ``_count``."""
        return self.database.connection._count(self.database.db_name, self.table_name)

    def find(self, *args, **kwargs):
        """Return a ``Cursor`` over this collection.

        Arguments are forwarded to ``Cursor`` (``spec``, ``fields``, ...).
        """
        return Cursor(self, *args, **kwargs)


class Cursor(object):
    """Iterator over the documents of a ``Collection`` matching ``spec``.

    On construction it builds a nested select/from/where query description
    (see ``_get_cursor``) and hands it to the connection's ``_get_cursor``.
    Each row fetched with the connection's ``_next`` is turned back into a
    dict.  Column 0 of a row is the document id; the remaining columns are
    the selected fields in ``_value_fields()`` order.
    """

    _MIX_ERROR = ("You cannot currently mix including and excluding fields. "
                  "Contact us if this is an issue.")

    def __init__(self, collection, spec=None, fields=None, **kwargs):
        """Prepare the query.

        Args:
            collection: the ``Collection`` to read.
            spec: optional dict of ``field: value`` equality filters.
            fields: ``None`` (all fields), a projection dict
                (``{'a': 1}`` includes, ``{'a': 0}`` excludes), or an
                iterable of field names to include.
            **kwargs: accepted and ignored.

        Raises:
            TypeError: if ``spec`` is truthy and not a dict.
            Exception: if a projection dict mixes inclusion and exclusion.
        """
        if spec and not isinstance(spec, dict):
            raise TypeError("spec must be an instance of dict")

        self.collection = collection
        self.spec = spec
        self.fields = self._get_fields(fields)
        self.cursor = self._get_cursor()

    def __iter__(self):
        return self

    def _get_fields(self, fields):
        """Resolve the *fields* argument into a sorted list of field names."""
        available = set(self.collection._get_fields())
        if fields is None:
            selected = available
        elif isinstance(fields, dict):
            selected = self._project(available, fields)
        else:
            # Plain iterable of names: keep those that exist, plus '_id'.
            wanted = set(fields)
            wanted.add('_id')
            selected = available & wanted
        # Sorted so the column order of the generated query is deterministic.
        return sorted(selected)

    def _project(self, available, fields):
        """Apply a projection dict *fields* to the set *available*.

        Every non-``_id`` key must be all-included or all-excluded.  With no
        such keys, ``{'_id': 1}`` selects only ``_id``, while ``{}`` and
        ``{'_id': 0}`` start from every field.  ``_id`` is kept unless it is
        explicitly excluded.
        """
        flags = dict((name, bool(flag))
                     for name, flag in fields.items() if name != '_id')
        if len(set(flags.values())) > 1:
            raise Exception(self._MIX_ERROR)

        if any(flags.values()):
            selected = available.intersection(flags)
        elif flags:
            selected = available.difference(flags)
        elif fields.get('_id'):
            selected = set()
        else:
            selected = set(available)

        if '_id' in fields and not fields['_id']:
            selected.discard('_id')
        else:
            selected.add('_id')
        return selected

    def _value_fields(self):
        """Return the selected fields other than ``_id``, in column order."""
        return [name for name in self.fields if name != '_id']

    def _get_cursor(self):
        """Build the query description and return the backend cursor."""
        table_name = self.collection.table_name
        table_id = '%s$_id' % table_name

        # Column 0 is the id; one sub-select per remaining field follows.
        query = {
            'select': [(table_id, 'id')],
            'from': [table_id],
        }
        for name in self._value_fields():
            table_f = '%s$%s' % (table_name, name)
            query['select'].append(self._get_cursor_field(table_id, table_f))

        if self.spec:
            query['where'] = []
            for key, value in self.spec.items():
                table_f = '%s$%s' % (table_name, key)
                field_q = self._get_cursor_field(table_id, table_f)
                query['where'].append((field_q, '=', value))

        database = self.collection.database
        return database.connection._get_cursor(database.db_name, query)

    def _get_cursor_field(self, table_id, table_field):
        """Sub-select for the value in *table_field* joined on the row id."""
        return {
            'select': [(table_field, 'value')],
            'from': [table_field],
            'where': [((table_field, 'id'), '=', (table_id, 'id'))],
            }

    def next(self):
        """Return the next matching document as a dict.

        Columns whose value is ``None`` are left out of the document.

        Raises:
            StopIteration: when there is no backend cursor or the backend's
                ``_next`` returns ``None``.
        """
        if not self.cursor:
            # Previously returned None forever, which never ends a for loop.
            raise StopIteration
        row = self.collection.database.connection._next(self.cursor)
        if row is None:
            raise StopIteration

        document = {}
        if '_id' in self.fields:
            document['_id'] = row[0]
        for column, name in enumerate(self._value_fields(), 1):
            value = row[column]
            if value is not None:
                document[name] = value
        return document

    # Python 3 iterator protocol; Python 2 uses ``next`` directly.
    __next__ = next