#
##############################################################################
-import mojo
+import connection
import MySQLdb
-class Connection(mojo.Connection):
+SQL_FIELD_TYPES = {
+ 'char': 'VARCHAR',
+ 'text': 'LONGTEXT',
+ 'float': 'DOUBLE',
+ }
+
+class Connection(connection.Connection):
def __init__(self, *args, **kwargs):
self._db_con = MySQLdb.connect(*args, **kwargs)
+ self._db_con_autocommit = MySQLdb.connect(*args, **kwargs)
- def query(self, sql):
- cur = self._db_con.cursor()
+ def query(self, sql, db=None):
+ if db is None:
+ db = self._db_con
+ cur = db.cursor()
cur.execute(sql)
res = cur.fetchall()
cur.close()
cur = None
return res
+ def execute(self, sql, db=None):
+ if db is None:
+ db = self._db_con
+ cur = db.cursor()
+ res = cur.execute(sql)
+ cur.close()
+ cur = None
+ return res
+
def _get_databases(self):
return [x[0] for x in self.query("SHOW DATABASES")]
def _count_rows(self, db_name, table_name):
return self.query("SELECT COUNT(*) FROM `%s`.`%s`" % (db_name, table_name))[0][0]
+ def _create_database(self, db_name):
+ return (self.execute("CREATE DATABASE `%s`" % db_name, db=self._db_con_autocommit) or False) and True
+
+ def _get_sql_field_type(self, field_type):
+ return SQL_FIELD_TYPES.get(field_type, "UNKNOW")
+
+ def _create_table(self, db_name, table_name, fields):
+ sql = "CREATE TABLE `%s`.`%s` (" % (db_name, table_name)
+
+ sql_fields = []
+ for f in fields:
+ sql_field = "%s %s" % (f['name'], self._get_sql_field_type(f['type']))
+ if f.get('size'):
+ sql_field += "(%s)" % f['size']
+ if f.get('primary'):
+ sql_field += " PRIMARY KEY"
+ if 'null' in f and not f['null']:
+ sql_field += " NOT NULL"
+ sql_fields.append(sql_field)
+ sql += ",".join(sql_fields)
+
+ sql += ")"
+
+ return (self.execute(sql, db=self._db_con_autocommit) or False) and True
+
def _get_sql_field(self, db_name, field):
if type(field) is tuple:
return "`%s`.`%s`.`%s`" % (db_name, field[0], field[1])
elif type(field) is dict:
return "(%s)" % self._get_sql_query(db_name, field)
else:
- return str(field)
+ return "'%s'" % str(field).replace("'", "''")
def _get_sql_query(self, db_name, query):
sql = "SELECT "
def _get_cursor(self, db_name, query):
cur = self._db_con.cursor()
+ cur.execute("USE `%s`" % db_name)
cur.execute(self._get_sql_query(db_name, query))
return cur
def _next(self, cur):
return cur.fetchone()
+
+ def _insert(self, db_name, table_name, values):
+ keys = []
+ vals = []
+ for k, v in values.iteritems():
+ keys.append(k)
+ if type(v) is str:
+ vals.append("'%s'" % v.replace("'", "''"))
+ else:
+ vals.append(str(v))
+ sql = "INSERT INTO `%s`.`%s`(%s) VALUES (%s)" % (db_name, table_name, ",".join(keys), ",".join(vals))
+ return self.execute(sql)
+
+ def commit(self):
+ self._db_con.commit()
+
+ def rollback(self):
+ self._db_con.rollback()