]> git.jsancho.org Git - mojodb.git/blobdiff - MySQL.py
New database scheme, storing key name inside tables for improving searching
[mojodb.git] / MySQL.py
index f5f888828b83109b43b41d7b8d109ee066b5742b..e983a670fa957f2d082f21b6c41a729512a05716 100644 (file)
--- a/MySQL.py
+++ b/MySQL.py
@@ -19,6 +19,7 @@
 #
 ##############################################################################
 
+import dbutils
 import connection
 import MySQLdb
 
@@ -28,7 +29,61 @@ SQL_FIELD_TYPES = {
     'float': 'DOUBLE',
     }
 
+
+class Query(dbutils.Query):
+    def sql(self):
+        res = "SELECT "
+        res += ",".join(["(%s)" % f.sql() for f in self.fields])
+
+        res += " FROM "
+        res += ",".join([t.sql() for t in self.tables])
+
+        if self.constraints:
+            res += " WHERE "
+            res += " AND ".join(["(%s)" % c.sql() for c in self.constraints])
+
+        return res
+
+
+class Field(dbutils.Field):
+    def sql(self):
+        return "%s.`%s`" % (self.table.sql(), self.field_name)
+
+
+class Table(dbutils.Table):
+    def sql(self):
+        return "`%s`.`%s`" % (self.db_name, self.table_name)
+
+
+class Constraint(dbutils.Constraint):
+    def sql(self):
+        operator = self.operator.strip().lower()
+        if operator == "starts":
+            return "(%s) LIKE (%s)" % (self.args[0].sql(), self.args[1].sql()[:-1] + "%'")
+        elif operator == "in":
+            return "(%s) IN (%s)" % (self.args[0].sql(), ",".join(["(%s)" % a.sql() for a in self.args[1:]]))
+        elif operator == "=":
+            return "(%s) = (%s)" % (self.args[0].sql(), self.args[1].sql())
+        else:
+            token = " %s " % operator.upper()
+            return token.join(["(%s)" % a.sql() for a in self.args])
+
+
+class Literal(dbutils.Literal):
+    def sql(self):
+        if type(self.value) in (int, float):
+            return "%s" % self.value
+        else:
+            return "'%s'" % str(self.value).replace("'", "''")
+
+
 class Connection(connection.Connection):
+    Query = Query
+    Field = Field
+    Table = Table
+    Constraint = Constraint
+    Literal = Literal
+
     def __init__(self, *args, **kwargs):
         self._db_con = MySQLdb.connect(*args, **kwargs)
         self._db_con_autocommit = MySQLdb.connect(*args, **kwargs)
@@ -68,6 +123,7 @@ class Connection(connection.Connection):
         return SQL_FIELD_TYPES.get(field_type, "UNKNOW")
 
     def _create_table(self, db_name, table_name, fields):
+        primary = []
         sql = "CREATE TABLE `%s`.`%s` (" % (db_name, table_name)
 
         sql_fields = []
@@ -76,44 +132,22 @@ class Connection(connection.Connection):
             if f.get('size'):
                 sql_field += "(%s)" % f['size']
             if f.get('primary'):
-                sql_field += " PRIMARY KEY"
+                primary.append(f['name'])
             if 'null' in f and not f['null']:
                 sql_field += " NOT NULL"
             sql_fields.append(sql_field)
         sql += ",".join(sql_fields)
 
+        if primary:
+            sql += ", PRIMARY KEY(%s)" % ",".join(primary)
+
         sql += ")"
 
         return (self.execute(sql, db=self._db_con_autocommit) or False) and True
 
-    def _get_sql_field(self, db_name, field):
-        if type(field) is tuple:
-            return "`%s`.`%s`.`%s`" % (db_name, field[0], field[1])
-        elif type(field) is dict:
-            return "(%s)" % self._get_sql_query(db_name, field)
-        else:
-            return "'%s'" % str(field).replace("'", "''")
-        
-    def _get_sql_query(self, db_name, query):
-        sql = "SELECT "
-        sql += ",".join([self._get_sql_field(db_name, x) for x in query['select']])
-
-        sql += " FROM "
-        sql += ",".join(query['from'])
-
-        if query.get('where'):
-            sql += " WHERE "
-            where = []
-            for cond in query['where']:
-                where.append("%s %s %s" % (self._get_sql_field(db_name, cond[0]), cond[1], self._get_sql_field(db_name, cond[2])))
-            sql += " AND ".join(where)
-
-        return sql
-
-    def _get_cursor(self, db_name, query):
+    def _get_cursor(self, query):
         cur = self._db_con.cursor()
-        cur.execute("USE `%s`" % db_name)
-        cur.execute(self._get_sql_query(db_name, query))
+        cur.execute(query.sql())
         return cur
 
     def _next(self, cur):