]> git.jsancho.org Git - mojodb.git/blobdiff - MySQL.py
msgpack instead cPickle (for multiple platforms) and str instead unicode (thinking...
[mojodb.git] / MySQL.py
index 76f7312ef1dc8e74db64e6ebdc8f86e84751e6cc..f5f888828b83109b43b41d7b8d109ee066b5742b 100644 (file)
--- a/MySQL.py
+++ b/MySQL.py
 #
 ##############################################################################
 
-import mojo
+import connection
 import MySQLdb
 
-class Connection(mojo.Connection):
+SQL_FIELD_TYPES = {
+    'char': 'VARCHAR',
+    'text': 'LONGTEXT',
+    'float': 'DOUBLE',
+    }
+
+class Connection(connection.Connection):
     def __init__(self, *args, **kwargs):
         self._db_con = MySQLdb.connect(*args, **kwargs)
+        self._db_con_autocommit = MySQLdb.connect(*args, **kwargs)
 
-    def query(self, sql):
-        cur = self._db_con.cursor()
+    def query(self, sql, db=None):
+        if db is None:
+            db = self._db_con
+        cur = db.cursor()
         cur.execute(sql)
         res = cur.fetchall()
         cur.close()
         cur = None
         return res
 
+    def execute(self, sql, db=None):
+        if db is None:
+            db = self._db_con
+        cur = db.cursor()
+        res = cur.execute(sql)
+        cur.close()
+        cur = None
+        return res
+
     def _get_databases(self):
         return [x[0] for x in self.query("SHOW DATABASES")]
 
@@ -43,13 +61,38 @@ class Connection(mojo.Connection):
     def _count_rows(self, db_name, table_name):
         return self.query("SELECT COUNT(*) FROM `%s`.`%s`" % (db_name, table_name))[0][0]
 
+    def _create_database(self, db_name):
+        return (self.execute("CREATE DATABASE `%s`" % db_name, db=self._db_con_autocommit) or False) and True
+
+    def _get_sql_field_type(self, field_type):
+        return SQL_FIELD_TYPES.get(field_type, "UNKNOW")
+
+    def _create_table(self, db_name, table_name, fields):
+        sql = "CREATE TABLE `%s`.`%s` (" % (db_name, table_name)
+
+        sql_fields = []
+        for f in fields:
+            sql_field = "%s %s" % (f['name'], self._get_sql_field_type(f['type']))
+            if f.get('size'):
+                sql_field += "(%s)" % f['size']
+            if f.get('primary'):
+                sql_field += " PRIMARY KEY"
+            if 'null' in f and not f['null']:
+                sql_field += " NOT NULL"
+            sql_fields.append(sql_field)
+        sql += ",".join(sql_fields)
+
+        sql += ")"
+
+        return (self.execute(sql, db=self._db_con_autocommit) or False) and True
+
     def _get_sql_field(self, db_name, field):
         if type(field) is tuple:
             return "`%s`.`%s`.`%s`" % (db_name, field[0], field[1])
         elif type(field) is dict:
             return "(%s)" % self._get_sql_query(db_name, field)
         else:
-            return str(field)
+            return "'%s'" % str(field).replace("'", "''")
         
     def _get_sql_query(self, db_name, query):
         sql = "SELECT "
@@ -69,8 +112,27 @@ class Connection(mojo.Connection):
 
     def _get_cursor(self, db_name, query):
         cur = self._db_con.cursor()
+        cur.execute("USE `%s`" % db_name)
         cur.execute(self._get_sql_query(db_name, query))
         return cur
 
     def _next(self, cur):
         return cur.fetchone()
+
+    def _insert(self, db_name, table_name, values):
+        keys = []
+        vals = []
+        for k, v in values.iteritems():
+            keys.append(k)
+            if type(v) is str:
+                vals.append("'%s'" % v.replace("'", "''"))
+            else:
+                vals.append(str(v))
+        sql = "INSERT INTO `%s`.`%s`(%s) VALUES (%s)" % (db_name, table_name, ",".join(keys), ",".join(vals))
+        return self.execute(sql)
+
+    def commit(self):
+        self._db_con.commit()
+
+    def rollback(self):
+        self._db_con.rollback()