X-Git-Url: https://git.jsancho.org/?p=mojodb.git;a=blobdiff_plain;f=MySQL.py;fp=MySQL.py;h=39c34354cae256b496811befd649e0042e04f90c;hp=a3046e6456a5225abd865f52f897c215d4cce8a8;hb=f63673b9ed08886c9ac1582f94f0e63103985497;hpb=78dffed58cc217f429a98f16cec881636f2cc444 diff --git a/MySQL.py b/MySQL.py index a3046e6..39c3435 100644 --- a/MySQL.py +++ b/MySQL.py @@ -22,18 +22,36 @@ import mojo import MySQLdb +SQL_FIELD_TYPES = { + 'char': 'VARCHAR', + 'text': 'LONGTEXT', + 'float': 'DOUBLE', + } + class Connection(mojo.Connection): def __init__(self, *args, **kwargs): self._db_con = MySQLdb.connect(*args, **kwargs) + self._db_con_autocommit = MySQLdb.connect(*args, **kwargs) - def query(self, sql): - cur = self._db_con.cursor() + def query(self, sql, db=None): + if db is None: + db = self._db_con + cur = db.cursor() cur.execute(sql) res = cur.fetchall() cur.close() cur = None return res + def execute(self, sql, db=None): + if db is None: + db = self._db_con + cur = db.cursor() + res = cur.execute(sql) + cur.close() + cur = None + return res + def _get_databases(self): return [x[0] for x in self.query("SHOW DATABASES")] @@ -43,6 +61,31 @@ class Connection(mojo.Connection): def _count_rows(self, db_name, table_name): return self.query("SELECT COUNT(*) FROM `%s`.`%s`" % (db_name, table_name))[0][0] + def _create_database(self, db_name): + return (self.execute("CREATE DATABASE `%s`" % db_name, db=self._db_con_autocommit) or False) and True + + def _get_sql_field_type(self, field_type): + return SQL_FIELD_TYPES.get(field_type, "UNKNOW") + + def _create_table(self, db_name, table_name, fields): + sql = "CREATE TABLE `%s`.`%s` (" % (db_name, table_name) + + sql_fields = [] + for f in fields: + sql_field = "%s %s" % (f['name'], self._get_sql_field_type(f['type'])) + if f.get('size'): + sql_field += "(%s)" % f['size'] + if f.get('primary'): + sql_field += " PRIMARY KEY" + if 'null' in f and not f['null']: + sql_field += " NOT NULL" + sql_fields.append(sql_field) + sql += ",".join(sql_fields) + + sql += ")" + + return (self.execute(sql, db=self._db_con_autocommit) or False) and True + def _get_sql_field(self, db_name, field): if type(field) is tuple: return "`%s`.`%s`.`%s`" % (db_name, field[0], field[1]) @@ -69,8 +112,27 @@ class Connection(mojo.Connection): def _get_cursor(self, db_name, query): cur = self._db_con.cursor() + cur.execute("USE `%s`" % db_name) cur.execute(self._get_sql_query(db_name, query)) return cur def _next(self, cur): return cur.fetchone() + + def _insert(self, db_name, table_name, values): + keys = [] + vals = [] + for k, v in values.iteritems(): + keys.append(k) + if type(v) in (str, unicode): + vals.append("'%s'" % v) + else: + vals.append(str(v)) + sql = "INSERT INTO `%s`.`%s`(%s) VALUES (%s)" % (db_name, table_name, ",".join(keys), ",".join(vals)) + return self.execute(sql) + + def commit(self): + self._db_con.commit() + + def rollback(self): + self._db_con.rollback()