From: Javier Sancho Date: Tue, 28 Jan 2014 15:14:19 +0000 (+0100) Subject: Basic insertion of documents, using cPickle for fields codification X-Git-Url: https://git.jsancho.org/?p=mojodb.git;a=commitdiff_plain;h=f63673b9ed08886c9ac1582f94f0e63103985497 Basic insertion of documents, using cPickle for fields codification --- diff --git a/MySQL.py b/MySQL.py index a3046e6..39c3435 100644 --- a/MySQL.py +++ b/MySQL.py @@ -22,18 +22,36 @@ import mojo import MySQLdb +SQL_FIELD_TYPES = { + 'char': 'VARCHAR', + 'text': 'LONGTEXT', + 'float': 'DOUBLE', + } + class Connection(mojo.Connection): def __init__(self, *args, **kwargs): self._db_con = MySQLdb.connect(*args, **kwargs) + self._db_con_autocommit = MySQLdb.connect(*args, **kwargs) - def query(self, sql): - cur = self._db_con.cursor() + def query(self, sql, db=None): + if db is None: + db = self._db_con + cur = db.cursor() cur.execute(sql) res = cur.fetchall() cur.close() cur = None return res + def execute(self, sql, db=None): + if db is None: + db = self._db_con + cur = db.cursor() + res = cur.execute(sql) + cur.close() + cur = None + return res + def _get_databases(self): return [x[0] for x in self.query("SHOW DATABASES")] @@ -43,6 +61,31 @@ class Connection(mojo.Connection): def _count_rows(self, db_name, table_name): return self.query("SELECT COUNT(*) FROM `%s`.`%s`" % (db_name, table_name))[0][0] + def _create_database(self, db_name): + return (self.execute("CREATE DATABASE `%s`" % db_name, db=self._db_con_autocommit) or False) and True + + def _get_sql_field_type(self, field_type): + return SQL_FIELD_TYPES.get(field_type, "UNKNOW") + + def _create_table(self, db_name, table_name, fields): + sql = "CREATE TABLE `%s`.`%s` (" % (db_name, table_name) + + sql_fields = [] + for f in fields: + sql_field = "%s %s" % (f['name'], self._get_sql_field_type(f['type'])) + if f.get('size'): + sql_field += "(%s)" % f['size'] + if f.get('primary'): + sql_field += " PRIMARY KEY" + if 'null' in f and not f['null']: + sql_field += " NOT NULL" + sql_fields.append(sql_field) + sql += ",".join(sql_fields) + + sql += ")" + + return (self.execute(sql, db=self._db_con_autocommit) or False) and True + def _get_sql_field(self, db_name, field): if type(field) is tuple: return "`%s`.`%s`.`%s`" % (db_name, field[0], field[1]) @@ -69,8 +112,27 @@ class Connection(mojo.Connection): def _get_cursor(self, db_name, query): cur = self._db_con.cursor() + cur.execute("USE `%s`" % db_name) cur.execute(self._get_sql_query(db_name, query)) return cur def _next(self, cur): return cur.fetchone() + + def _insert(self, db_name, table_name, values): + keys = [] + vals = [] + for k, v in values.iteritems(): + keys.append(k) + if type(v) in (str, unicode): + vals.append("'%s'" % v) + else: + vals.append(str(v)) + sql = "INSERT INTO `%s`.`%s`(%s) VALUES (%s)" % (db_name, table_name, ",".join(keys), ",".join(vals)) + return self.execute(sql) + + def commit(self): + self._db_con.commit() + + def rollback(self): + self._db_con.rollback() diff --git a/mojo.py b/mojo.py index 16adc44..f3aa744 100644 --- a/mojo.py +++ b/mojo.py @@ -19,6 +19,9 @@ # ############################################################################## +import cPickle +import uuid + class Connection(object): def __init__(self, *args, **kwargs): @@ -37,13 +40,19 @@ class Connection(object): return [] def database_names(self): - return [unicode(x) for x in self._get_databases()] + try: + return [unicode(x) for x in self._get_databases()] + except: + return [] def _get_tables(self, db_name): return [] def collection_names(self, db_name): - return list(set([unicode(x.split('$')[0]) for x in filter(lambda x: '$' in x, self._get_tables(db_name))])) + try: + return list(set([unicode(x.split('$')[0]) for x in filter(lambda x: '$' in x, self._get_tables(db_name))])) + except: + return [] def _count_rows(self, db_name, table_name): return 0 @@ -54,6 +63,13 @@ class Connection(object): except: return 0 + def _create_database(self, db_name): + return None + + def _create_table(self, db_name, table_name, fields): + # [{'name': 'id', 'type': 'char', 'size': 20, 'primary': True}] + return None + def _get_cursor(self, db_name, query): # {'select': [('t1$_id', 'id'), {'select': [('t1$c1', 'value')], 'from': ['t1$c1'], 'where': [(('t1$c1', 'id'), '=', ('t1$_id', 'id'))]}], 'from': ['t1$_id']} return None @@ -61,6 +77,15 @@ class Connection(object): def _next(self, cursor): return None + def _insert(self, db_name, table_name, values): + return None + + def commit(self): + pass + + def rollback(self): + pass + class Database(object): def __init__(self, connection, db_name): @@ -76,6 +101,12 @@ class Database(object): def __repr__(self): return "Database(%r, %r)" % (self.connection, self.db_name) + def _create_database(self): + return self.connection._create_database(self.db_name) + + def exists(self): + return (self.db_name in self.connection.database_names()) + def collection_names(self): return self.connection.collection_names(self.db_name) @@ -88,6 +119,23 @@ class Collection(object): def __repr__(self): return "Collection(%r, %r)" % (self.database, self.table_name) + def exists(self): + return (self.database.exists() and self.table_name in self.database.collection_names()) + + def _create_table(self): + fields = [ + {'name': 'id', 'type': 'char', 'size': 32, 'primary': True}, + ] + return self.database.connection._create_table(self.database.db_name, '%s$_id' % self.table_name, fields) + + def _create_field(self, field_name): + fields = [ + {'name': 'id', 'type': 'char', 'size': 32, 'primary': True}, + {'name': 'value', 'type': 'text', 'null': False}, + {'name': 'number', 'type': 'float'}, + ] + return self.database.connection._create_table(self.database.db_name, '%s$%s' % (self.table_name, field_name), fields) + def _get_fields(self): tables = self.database.connection._get_tables(self.database.db_name) return [unicode(x[x.find('$')+1:]) for x in filter(lambda x: x.startswith('%s$' % self.table_name), tables)] @@ -98,6 +146,44 @@ class Collection(object): def find(self, *args, **kwargs): return Cursor(self, *args, **kwargs) + def insert(self, doc_or_docs): + if not self.database.db_name in self.database.connection.database_names(): + self.database._create_database() + if not self.table_name in self.database.collection_names(): + self._create_table() + + if not type(doc_or_docs) in (list, tuple): + docs = [doc_or_docs] + else: + docs = doc_or_docs + for doc in docs: + if not '_id' in doc: + doc['_id'] = uuid.uuid4().hex + self._insert_document(doc) + + if type(doc_or_docs) in (list, tuple): + return [d['_id'] for d in docs] + else: + return docs[0]['_id'] + + def _insert_document(self, doc): + table_id = '%s$_id' % self.table_name + fields = self._get_fields() + self.database.connection._insert(self.database.db_name, table_id, {'id': doc['_id']}) + for f in doc: + if f == '_id': + continue + if not f in fields: + self._create_field(f) + table_f = '%s$%s' % (self.table_name, f) + values = { + 'id': doc['_id'], + 'value': cPickle.dumps(doc[f]), + } + if type(doc[f]) in (int, float): + values['number'] = doc[f] + self.database.connection._insert(self.database.db_name, table_f, values) + class Cursor(object): def __init__(self, collection, spec=None, fields=None, **kwargs): @@ -106,8 +192,12 @@ class Cursor(object): self.collection = collection self.spec = spec - self.fields = self._get_fields(fields) - self.cursor = self._get_cursor() + if self.collection.exists(): + self.fields = self._get_fields(fields) + self.cursor = self._get_cursor() + else: + self.fields = None + self.cursor = None def __iter__(self): return self @@ -176,6 +266,9 @@ class Cursor(object): } def next(self): + if self.cursor is None: + raise StopIteration + if self.cursor: res = self.collection.database.connection._next(self.cursor) if res is None: @@ -187,7 +280,7 @@ class Cursor(object): fields_without_id = filter(lambda x: x != '_id', self.fields) for i in xrange(len(fields_without_id)): if not res[i + 1] is None: - document[fields_without_id[i]] = res[i + 1] + document[fields_without_id[i]] = cPickle.loads(res[i + 1]) return document else: return None