]> git.jsancho.org Git - mojodb.git/blob - mojo.py
f3aa74456056a7d0c84eb8cb076bb502dc05e444
[mojodb.git] / mojo.py
1 # -*- coding: utf-8 -*-
2 ##############################################################################
3 #
4 #    mojo, a Python library for implementing document based databases
5 #    Copyright (C) 2013-2014 by Javier Sancho Fernandez <jsf at jsancho dot org>
6 #
7 #    This program is free software: you can redistribute it and/or modify
8 #    it under the terms of the GNU General Public License as published by
9 #    the Free Software Foundation, either version 3 of the License, or
10 #    (at your option) any later version.
11 #
12 #    This program is distributed in the hope that it will be useful,
13 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
14 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 #    GNU General Public License for more details.
16 #
17 #    You should have received a copy of the GNU General Public License
18 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20 ##############################################################################
21
22 import cPickle
23 import uuid
24
25
class Connection(object):
    """Backend-neutral base connection for the mojo document store.

    Documents are mapped onto relational tables named
    ``<collection>$<field>``; the ``<collection>$_id`` table holds one row
    per document.  Every ``_``-prefixed hook below returns an empty
    placeholder in this class (presumably concrete backends override them).

    Databases are reached as attributes or items: ``conn.name`` and
    ``conn['name']`` both build a ``Database`` bound to this connection.
    """

    def __init__(self, *args, **kwargs):
        # Handle of the underlying backend connection; unused by this base class.
        self._db_con = None

    def __getattr__(self, db_name):
        """Return the database called *db_name* (attribute-style access).

        Dunder names raise ``AttributeError`` so that protocol probes such
        as ``hasattr(conn, '__deepcopy__')`` are not answered with a bogus
        ``Database`` object.  Use item access for such names if needed.
        """
        if db_name.startswith('__') and db_name.endswith('__'):
            raise AttributeError(db_name)
        return Database(self, db_name)

    def __getitem__(self, *args, **kwargs):
        """Return the database for the given name (item-style access)."""
        return Database(self, *args, **kwargs)

    def __repr__(self):
        return "Connection(%s)" % self._db_con

    def _get_databases(self):
        """Backend hook: return the names of the existing databases."""
        return []

    def database_names(self):
        """Return the database names as unicode strings.

        Best effort: any backend error yields an empty list.
        """
        try:
            return [unicode(name) for name in self._get_databases()]
        except Exception:
            # Deliberately lenient, but KeyboardInterrupt/SystemExit must
            # still propagate, hence no bare ``except``.
            return []

    def _get_tables(self, db_name):
        """Backend hook: return the table names stored in *db_name*."""
        return []

    def collection_names(self, db_name):
        """Return the distinct collection names found in *db_name*.

        A collection name is the part before the first ``$`` of a table
        name; tables without ``$`` are ignored.  The order of the result is
        unspecified.  Best effort: any backend error yields an empty list.
        """
        try:
            tables = self._get_tables(db_name)
            names = set(unicode(table.split('$')[0])
                        for table in tables if '$' in table)
            return list(names)
        except Exception:
            return []

    def _count_rows(self, db_name, table_name):
        """Backend hook: return the number of rows of *table_name*."""
        return 0

    def _count(self, db_name, table_name):
        """Return the number of documents of collection *table_name*.

        Counts the rows of the ``<table_name>$_id`` table.  Best effort:
        any backend error yields 0.
        """
        try:
            return self._count_rows(db_name, table_name + '$_id')
        except Exception:
            return 0

    def _create_database(self, db_name):
        """Backend hook: create the database *db_name*."""
        return None

    def _create_table(self, db_name, table_name, fields):
        """Backend hook: create *table_name* with the given column specs.

        Example of *fields*:
        ``[{'name': 'id', 'type': 'char', 'size': 20, 'primary': True}]``
        """
        return None

    def _get_cursor(self, db_name, query):
        """Backend hook: run *query* and return a backend cursor handle.

        Example of *query*::

            {'select': [('t1$_id', 'id'),
                        {'select': [('t1$c1', 'value')],
                         'from': ['t1$c1'],
                         'where': [(('t1$c1', 'id'), '=', ('t1$_id', 'id'))]}],
             'from': ['t1$_id']}
        """
        return None

    def _next(self, cursor):
        """Backend hook: return the next row of *cursor*, or None at the end."""
        return None

    def _insert(self, db_name, table_name, values):
        """Backend hook: insert the column -> value mapping *values*."""
        return None

    def commit(self):
        """Commit pending changes (no-op in the base class)."""
        pass

    def rollback(self):
        """Discard pending changes (no-op in the base class)."""
        pass
88
89
class Database(object):
    """A named database reached through a connection.

    Collections are reached as attributes or items: ``db.name`` and
    ``db['name']`` both build a ``Collection`` bound to this database.
    """

    def __init__(self, connection, db_name):
        self.connection = connection
        self.db_name = unicode(db_name)

    def __getattr__(self, table_name):
        """Return the collection called *table_name* (attribute-style access).

        Dunder names raise ``AttributeError`` so that protocol probes such
        as ``hasattr(db, '__deepcopy__')`` are not answered with a bogus
        ``Collection`` object.  Use item access for such names if needed.
        """
        if table_name.startswith('__') and table_name.endswith('__'):
            raise AttributeError(table_name)
        return Collection(self, table_name)

    def __getitem__(self, *args, **kwargs):
        """Return the collection for the given name (item-style access)."""
        return Collection(self, *args, **kwargs)

    def __repr__(self):
        return "Database(%r, %r)" % (self.connection, self.db_name)

    def _create_database(self):
        """Ask the connection to create this database."""
        return self.connection._create_database(self.db_name)

    def exists(self):
        """Return True if the connection lists this database."""
        return self.db_name in self.connection.database_names()

    def collection_names(self):
        """Return the collection names the connection reports for this database."""
        return self.connection.collection_names(self.db_name)
112
113
class Collection(object):
    """A collection of documents inside a database.

    Storage layout: the table ``<collection>$_id`` holds the document ids,
    and each document field lives in its own ``<collection>$<field>`` table
    with the columns ``id``, ``value`` (pickled field value) and ``number``
    (the raw value, filled only for int/float fields).
    """

    def __init__(self, database, table_name):
        self.database = database
        self.table_name = unicode(table_name)

    def __repr__(self):
        return "Collection(%r, %r)" % (self.database, self.table_name)

    def exists(self):
        """Return True if both the database and this collection exist."""
        return (self.database.exists()
                and self.table_name in self.database.collection_names())

    def _create_table(self):
        """Create the ``$_id`` table holding the ids of the documents."""
        fields = [
            {'name': 'id', 'type': 'char', 'size': 32, 'primary': True},
            ]
        table_id = '%s$_id' % self.table_name
        return self.database.connection._create_table(
            self.database.db_name, table_id, fields)

    def _create_field(self, field_name):
        """Create the table storing the values of *field_name*."""
        fields = [
            {'name': 'id', 'type': 'char', 'size': 32, 'primary': True},
            {'name': 'value', 'type': 'text', 'null': False},
            {'name': 'number', 'type': 'float'},
            ]
        table_field = '%s$%s' % (self.table_name, field_name)
        return self.database.connection._create_table(
            self.database.db_name, table_field, fields)

    def _get_fields(self):
        """Return the field names that already have a table.

        The names are whatever follows the first ``$`` in the tables whose
        name starts with ``<collection>$``, so ``_id`` is included once the
        ``$_id`` table exists.
        """
        prefix = '%s$' % self.table_name
        tables = self.database.connection._get_tables(self.database.db_name)
        return [unicode(table.partition('$')[2])
                for table in tables if table.startswith(prefix)]

    def count(self):
        """Return the number of documents as reported by the connection."""
        return self.database.connection._count(
            self.database.db_name, self.table_name)

    def find(self, *args, **kwargs):
        """Return a ``Cursor`` over this collection; arguments are forwarded."""
        return Cursor(self, *args, **kwargs)

    def insert(self, doc_or_docs):
        """Insert one document or a list/tuple of documents.

        The database and the collection are created on first use.  A
        document without ``_id`` gets a random 32-character hex id, which is
        written into the passed dictionary itself.

        Returns the ``_id`` of the document, or the list of ids when a
        list/tuple was passed.
        """
        database = self.database
        if not database.exists():
            database._create_database()
        if self.table_name not in database.collection_names():
            self._create_table()

        # isinstance also accepts list/tuple subclasses, unlike type(...) in.
        many = isinstance(doc_or_docs, (list, tuple))
        docs = doc_or_docs if many else [doc_or_docs]

        for doc in docs:
            if '_id' not in doc:
                doc['_id'] = uuid.uuid4().hex
            self._insert_document(doc)

        if many:
            return [doc['_id'] for doc in docs]
        return docs[0]['_id']

    def _insert_document(self, doc):
        """Write *doc*: one id row plus one row per field.

        Tables of fields seen for the first time are created on the fly.
        """
        connection = self.database.connection
        db_name = self.database.db_name
        known_fields = set(self._get_fields())

        table_id = '%s$_id' % self.table_name
        connection._insert(db_name, table_id, {'id': doc['_id']})

        for name in doc:
            if name == '_id':
                continue
            if name not in known_fields:
                self._create_field(name)
            value = doc[name]
            values = {
                'id': doc['_id'],
                'value': cPickle.dumps(value),
                }
            # bool is a subclass of int; it is kept out of the numeric
            # column, as it always was.
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                values['number'] = value
            table_field = '%s$%s' % (self.table_name, name)
            connection._insert(db_name, table_field, values)
186
187
class Cursor(object):
    """Iterator over the documents of a collection that match a query.

    ``spec`` is a dict of ``field -> value`` equality conditions.
    ``fields`` is a projection: None (all fields), an iterable of field
    names to include, or a dict of ``field -> flag`` where all flags are
    either truthy (include) or falsy (exclude); ``_id`` may be switched
    independently and is returned unless explicitly excluded.
    """

    _MIX_ERROR = ("You cannot currently mix including and excluding fields. "
                  "Contact us if this is an issue.")

    def __init__(self, collection, spec=None, fields=None, **kwargs):
        # isinstance (not an exact type check) so dict subclasses are accepted.
        if spec and not isinstance(spec, dict):
            raise Exception("spec must be an instance of dict")

        self.collection = collection
        self.spec = spec
        if self.collection.exists():
            self.fields = self._get_fields(fields)
            self.cursor = self._get_cursor()
        else:
            # Nothing to query: iteration stops immediately.
            self.fields = None
            self.cursor = None

    def __iter__(self):
        return self

    def _get_fields(self, fields):
        """Resolve the projection *fields* to a list of field names.

        Only fields that already have a table are returned (``_id`` is
        added regardless in dict mode).  The order of the result is
        unspecified.

        Raises ``Exception`` when a dict projection mixes include and
        exclude flags (``_id`` apart).
        """
        all_fields = set(self.collection._get_fields())
        if fields is None:
            return list(all_fields)

        if not isinstance(fields, dict):
            wanted = set(fields)
            wanted.add('_id')
            return list(all_fields & wanted)

        flags = [(name, bool(flag))
                 for name, flag in fields.items() if name != '_id']
        if flags:
            including = flags[0][1]
        else:
            # Only ``_id`` (or nothing) given: ``{'_id': 1}`` selects just
            # the id, ``{}`` and ``{'_id': 0}`` start from every field.
            including = bool(fields.get('_id'))

        selected = set() if including else set(all_fields)
        for name, flag in flags:
            # Checked for every entry, known field or not, so the outcome
            # does not depend on the iteration order of the dict.
            if flag != including:
                raise Exception(self._MIX_ERROR)
            if including:
                if name in all_fields:
                    selected.add(name)
            else:
                selected.discard(name)

        if '_id' in fields and not fields['_id']:
            selected.discard('_id')
        else:
            selected.add('_id')
        return list(selected)

    def _get_cursor(self):
        """Build the query dict and hand it to the connection.

        The query selects the id column of the ``$_id`` table followed by
        one sub-query per projected field, in the order of ``self.fields``
        with ``_id`` skipped; ``next`` relies on that order.
        """
        table_name = self.collection.table_name
        table_id = '%s$_id' % table_name

        query = {'select': [(table_id, 'id')]}
        for name in self.fields:
            if name == '_id':
                continue
            table_field = '%s$%s' % (table_name, name)
            query['select'].append(self._get_cursor_field(table_id, table_field))

        query['from'] = [table_id]

        if self.spec:
            conditions = []
            for name, value in self.spec.items():
                if name == '_id':
                    # The $_id table only has an ``id`` column, so the id
                    # is compared directly instead of through a ``value``
                    # sub-query.
                    left = (table_id, 'id')
                else:
                    # NOTE(review): the sub-query yields the pickled
                    # ``value`` column while ``value`` here is the raw
                    # object; presumably the backend reconciles both.
                    table_field = '%s$%s' % (table_name, name)
                    left = self._get_cursor_field(table_id, table_field)
                conditions.append((left, '=', value))
            query['where'] = conditions

        database = self.collection.database
        return database.connection._get_cursor(database.db_name, query)

    def _get_cursor_field(self, table_id, table_field):
        """Return the sub-query fetching a field value for the current id."""
        return {
            'select': [(table_field, 'value')],
            'from': [table_field],
            'where': [((table_field, 'id'), '=', (table_id, 'id'))],
            }

    def next(self):
        """Return the next document as a dict.

        Fields whose column is None in the row are left out of the
        document.  Raises ``StopIteration`` when the collection does not
        exist or the backend returns no more rows.
        """
        # Only None means "no cursor"; the handle is never truth-tested
        # because a backend handle may legitimately be falsy.
        if self.cursor is None:
            raise StopIteration

        row = self.collection.database.connection._next(self.cursor)
        if row is None:
            raise StopIteration

        document = {}
        if '_id' in self.fields:
            document['_id'] = row[0]
        value_fields = [name for name in self.fields if name != '_id']
        # Column 0 is the id; field values follow in projection order.
        for position, name in enumerate(value_fields, 1):
            raw = row[position]
            if raw is not None:
                # NOTE: unpickling can execute arbitrary code; the storage
                # backend must be trusted.
                document[name] = cPickle.loads(raw)
        return document

    # Python 3 spelling of the iterator protocol.
    __next__ = next