]> git.jsancho.org Git - mojodb.git/commitdiff
Basic insertion of documents, using cPickle for fields codification
authorJavier Sancho <jsf@jsancho.org>
Tue, 28 Jan 2014 15:14:19 +0000 (16:14 +0100)
committerJavier Sancho <jsf@jsancho.org>
Tue, 28 Jan 2014 15:14:19 +0000 (16:14 +0100)
MySQL.py
mojo.py

index a3046e6456a5225abd865f52f897c215d4cce8a8..39c34354cae256b496811befd649e0042e04f90c 100644 (file)
--- a/MySQL.py
+++ b/MySQL.py
 import mojo
 import MySQLdb
 
+SQL_FIELD_TYPES = {
+    'char': 'VARCHAR',
+    'text': 'LONGTEXT',
+    'float': 'DOUBLE',
+    }
+
 class Connection(mojo.Connection):
     def __init__(self, *args, **kwargs):
         self._db_con = MySQLdb.connect(*args, **kwargs)
+        self._db_con_autocommit = MySQLdb.connect(*args, **kwargs)
 
-    def query(self, sql):
-        cur = self._db_con.cursor()
+    def query(self, sql, db=None):
+        if db is None:
+            db = self._db_con
+        cur = db.cursor()
         cur.execute(sql)
         res = cur.fetchall()
         cur.close()
         cur = None
         return res
 
+    def execute(self, sql, db=None):
+        if db is None:
+            db = self._db_con
+        cur = db.cursor()
+        res = cur.execute(sql)
+        cur.close()
+        cur = None
+        return res
+
     def _get_databases(self):
         return [x[0] for x in self.query("SHOW DATABASES")]
 
@@ -43,6 +61,31 @@ class Connection(mojo.Connection):
     def _count_rows(self, db_name, table_name):
         return self.query("SELECT COUNT(*) FROM `%s`.`%s`" % (db_name, table_name))[0][0]
 
+    def _create_database(self, db_name):
+        return (self.execute("CREATE DATABASE `%s`" % db_name, db=self._db_con_autocommit) or False) and True
+
+    def _get_sql_field_type(self, field_type):
+        return SQL_FIELD_TYPES.get(field_type, "UNKNOW")
+
+    def _create_table(self, db_name, table_name, fields):
+        sql = "CREATE TABLE `%s`.`%s` (" % (db_name, table_name)
+
+        sql_fields = []
+        for f in fields:
+            sql_field = "%s %s" % (f['name'], self._get_sql_field_type(f['type']))
+            if f.get('size'):
+                sql_field += "(%s)" % f['size']
+            if f.get('primary'):
+                sql_field += " PRIMARY KEY"
+            if 'null' in f and not f['null']:
+                sql_field += " NOT NULL"
+            sql_fields.append(sql_field)
+        sql += ",".join(sql_fields)
+
+        sql += ")"
+
+        return (self.execute(sql, db=self._db_con_autocommit) or False) and True
+
     def _get_sql_field(self, db_name, field):
         if type(field) is tuple:
             return "`%s`.`%s`.`%s`" % (db_name, field[0], field[1])
@@ -69,8 +112,27 @@ class Connection(mojo.Connection):
 
     def _get_cursor(self, db_name, query):
         cur = self._db_con.cursor()
+        cur.execute("USE `%s`" % db_name)
         cur.execute(self._get_sql_query(db_name, query))
         return cur
 
     def _next(self, cur):
         return cur.fetchone()
+
+    def _insert(self, db_name, table_name, values):
+        keys = []
+        vals = []
+        for k, v in values.iteritems():
+            keys.append(k)
+            if type(v) in (str, unicode):
+                vals.append("'%s'" % v)
+            else:
+                vals.append(str(v))
+        sql = "INSERT INTO `%s`.`%s`(%s) VALUES (%s)" % (db_name, table_name, ",".join(keys), ",".join(vals))
+        return self.execute(sql)
+
+    def commit(self):
+        self._db_con.commit()
+
+    def rollback(self):
+        self._db_con.rollback()
diff --git a/mojo.py b/mojo.py
index 16adc442ea6acdea1bc7bfb4976b7fb51ea1a978..f3aa74456056a7d0c84eb8cb076bb502dc05e444 100644 (file)
--- a/mojo.py
+++ b/mojo.py
@@ -19,6 +19,9 @@
 #
 ##############################################################################
 
+import cPickle
+import uuid
+
 
 class Connection(object):
     def __init__(self, *args, **kwargs):
@@ -37,13 +40,19 @@ class Connection(object):
         return []
 
     def database_names(self):
-        return [unicode(x) for x in self._get_databases()]
+        try:
+            return [unicode(x) for x in self._get_databases()]
+        except:
+            return []
 
     def _get_tables(self, db_name):
         return []
 
     def collection_names(self, db_name):
-        return list(set([unicode(x.split('$')[0]) for x in filter(lambda x: '$' in x, self._get_tables(db_name))]))
+        try:
+            return list(set([unicode(x.split('$')[0]) for x in filter(lambda x: '$' in x, self._get_tables(db_name))]))
+        except:
+            return []
 
     def _count_rows(self, db_name, table_name):
         return 0
@@ -54,6 +63,13 @@ class Connection(object):
         except:
             return 0
 
+    def _create_database(self, db_name):
+        return None
+
+    def _create_table(self, db_name, table_name, fields):
+        # [{'name': 'id', 'type': 'char', 'size': 20, 'primary': True}]
+        return None
+
     def _get_cursor(self, db_name, query):
         # {'select': [('t1$_id', 'id'), {'select': [('t1$c1', 'value')], 'from': ['t1$c1'], 'where': [(('t1$c1', 'id'), '=', ('t1$_id', 'id'))]}], 'from': ['t1$_id']}
         return None
@@ -61,6 +77,15 @@ class Connection(object):
     def _next(self, cursor):
         return None
 
+    def _insert(self, db_name, table_name, values):
+        return None
+
+    def commit(self):
+        pass
+
+    def rollback(self):
+        pass
+
 
 class Database(object):
     def __init__(self, connection, db_name):
@@ -76,6 +101,12 @@ class Database(object):
     def __repr__(self):
         return "Database(%r, %r)" % (self.connection, self.db_name)
 
+    def _create_database(self):
+        return self.connection._create_database(self.db_name)
+
+    def exists(self):
+        return (self.db_name in self.connection.database_names())
+
     def collection_names(self):
         return self.connection.collection_names(self.db_name)
 
@@ -88,6 +119,23 @@ class Collection(object):
     def __repr__(self):
         return "Collection(%r, %r)" % (self.database, self.table_name)
 
+    def exists(self):
+        return (self.database.exists() and self.table_name in self.database.collection_names())
+
+    def _create_table(self):
+        fields = [
+            {'name': 'id', 'type': 'char', 'size': 32, 'primary': True},
+            ]
+        return self.database.connection._create_table(self.database.db_name, '%s$_id' % self.table_name, fields)
+
+    def _create_field(self, field_name):
+        fields = [
+            {'name': 'id', 'type': 'char', 'size': 32, 'primary': True},
+            {'name': 'value', 'type': 'text', 'null': False},
+            {'name': 'number', 'type': 'float'},
+            ]
+        return self.database.connection._create_table(self.database.db_name, '%s$%s' % (self.table_name, field_name), fields)
+
     def _get_fields(self):
         tables = self.database.connection._get_tables(self.database.db_name)
         return [unicode(x[x.find('$')+1:]) for x in filter(lambda x: x.startswith('%s$' % self.table_name), tables)]
@@ -98,6 +146,44 @@ class Collection(object):
     def find(self, *args, **kwargs):
         return Cursor(self, *args, **kwargs)
 
+    def insert(self, doc_or_docs):
+        if not self.database.db_name in self.database.connection.database_names():
+            self.database._create_database()
+        if not self.table_name in self.database.collection_names():
+            self._create_table()
+
+        if not type(doc_or_docs) in (list, tuple):
+            docs = [doc_or_docs]
+        else:
+            docs = doc_or_docs
+        for doc in docs:
+            if not '_id' in doc:
+                doc['_id'] = uuid.uuid4().hex
+            self._insert_document(doc)
+
+        if type(doc_or_docs) in (list, tuple):
+            return [d['_id'] for d in docs]
+        else:
+            return docs[0]['_id']
+
+    def _insert_document(self, doc):
+        table_id = '%s$_id' % self.table_name
+        fields = self._get_fields()
+        self.database.connection._insert(self.database.db_name, table_id, {'id': doc['_id']})
+        for f in doc:
+            if f == '_id':
+                continue
+            if not f in fields:
+                self._create_field(f)
+            table_f = '%s$%s' % (self.table_name, f)
+            values = {
+                'id': doc['_id'],
+                'value': cPickle.dumps(doc[f]),
+                }
+            if type(doc[f]) in (int, float):
+                values['number'] = doc[f]
+            self.database.connection._insert(self.database.db_name, table_f, values)
+
 
 class Cursor(object):
     def __init__(self, collection, spec=None, fields=None, **kwargs):
@@ -106,8 +192,12 @@ class Cursor(object):
 
         self.collection = collection
         self.spec = spec
-        self.fields = self._get_fields(fields)
-        self.cursor = self._get_cursor()
+        if self.collection.exists():
+            self.fields = self._get_fields(fields)
+            self.cursor = self._get_cursor()
+        else:
+            self.fields = None
+            self.cursor = None
 
     def __iter__(self):
         return self
@@ -176,6 +266,9 @@ class Cursor(object):
             }
 
     def next(self):
+        if self.cursor is None:
+            raise StopIteration
+
         if self.cursor:
             res = self.collection.database.connection._next(self.cursor)
             if res is None:
@@ -187,7 +280,7 @@ class Cursor(object):
                 fields_without_id = filter(lambda x: x != '_id', self.fields)
                 for i in xrange(len(fields_without_id)):
                     if not res[i + 1] is None:
-                        document[fields_without_id[i]] = res[i + 1]
+                        document[fields_without_id[i]] = cPickle.loads(res[i + 1])
                 return document
         else:
             return None