]> git.jsancho.org Git - mojodb.git/blob - MySQL.py
e983a670fa957f2d082f21b6c41a729512a05716
[mojodb.git] / MySQL.py
1 # -*- coding: utf-8 -*-
2 ##############################################################################
3 #
4 #    mojo, a Python library for implementing document based databases
5 #    Copyright (C) 2013-2014 by Javier Sancho Fernandez <jsf at jsancho dot org>
6 #
7 #    This program is free software: you can redistribute it and/or modify
8 #    it under the terms of the GNU General Public License as published by
9 #    the Free Software Foundation, either version 3 of the License, or
10 #    (at your option) any later version.
11 #
12 #    This program is distributed in the hope that it will be useful,
13 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
14 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 #    GNU General Public License for more details.
16 #
17 #    You should have received a copy of the GNU General Public License
18 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20 ##############################################################################
21
22 import dbutils
23 import connection
24 import MySQLdb
25
26 SQL_FIELD_TYPES = {
27     'char': 'VARCHAR',
28     'text': 'LONGTEXT',
29     'float': 'DOUBLE',
30     }
31
32
33 class Query(dbutils.Query):
34     def sql(self):
35         res = "SELECT "
36         res += ",".join(["(%s)" % f.sql() for f in self.fields])
37
38         res += " FROM "
39         res += ",".join([t.sql() for t in self.tables])
40
41         if self.constraints:
42             res += " WHERE "
43             res += " AND ".join(["(%s)" % c.sql() for c in self.constraints])
44
45         return res
46
47
48 class Field(dbutils.Field):
49     def sql(self):
50         return "%s.`%s`" % (self.table.sql(), self.field_name)
51
52
53 class Table(dbutils.Table):
54     def sql(self):
55         return "`%s`.`%s`" % (self.db_name, self.table_name)
56
57
58 class Constraint(dbutils.Constraint):
59     def sql(self):
60         operator = self.operator.strip().lower()
61         if operator == "starts":
62             return "(%s) LIKE (%s)" % (self.args[0].sql(), self.args[1].sql()[:-1] + "%'")
63         elif operator == "in":
64             return "(%s) IN (%s)" % (self.args[0].sql(), ",".join(["(%s)" % a.sql() for a in self.args[1:]]))
65         elif operator == "=":
66             return "(%s) = (%s)" % (self.args[0].sql(), self.args[1].sql())
67         else:
68             token = " %s " % operator.upper()
69             return token.join(["(%s)" % a.sql() for a in self.args])
70
71
72 class Literal(dbutils.Literal):
73     def sql(self):
74         if type(self.value) in (int, float):
75             return "%s" % self.value
76         else:
77             return "'%s'" % str(self.value).replace("'", "''")
78
79
80 class Connection(connection.Connection):
81     Query = Query
82     Field = Field
83     Table = Table
84     Constraint = Constraint
85     Literal = Literal
86
87     def __init__(self, *args, **kwargs):
88         self._db_con = MySQLdb.connect(*args, **kwargs)
89         self._db_con_autocommit = MySQLdb.connect(*args, **kwargs)
90
91     def query(self, sql, db=None):
92         if db is None:
93             db = self._db_con
94         cur = db.cursor()
95         cur.execute(sql)
96         res = cur.fetchall()
97         cur.close()
98         cur = None
99         return res
100
101     def execute(self, sql, db=None):
102         if db is None:
103             db = self._db_con
104         cur = db.cursor()
105         res = cur.execute(sql)
106         cur.close()
107         cur = None
108         return res
109
110     def _get_databases(self):
111         return [x[0] for x in self.query("SHOW DATABASES")]
112
113     def _get_tables(self, db_name):
114         return [x[0] for x in self.query("SHOW TABLES FROM `%s`" % db_name)]
115
116     def _count_rows(self, db_name, table_name):
117         return self.query("SELECT COUNT(*) FROM `%s`.`%s`" % (db_name, table_name))[0][0]
118
119     def _create_database(self, db_name):
120         return (self.execute("CREATE DATABASE `%s`" % db_name, db=self._db_con_autocommit) or False) and True
121
122     def _get_sql_field_type(self, field_type):
123         return SQL_FIELD_TYPES.get(field_type, "UNKNOW")
124
125     def _create_table(self, db_name, table_name, fields):
126         primary = []
127         sql = "CREATE TABLE `%s`.`%s` (" % (db_name, table_name)
128
129         sql_fields = []
130         for f in fields:
131             sql_field = "%s %s" % (f['name'], self._get_sql_field_type(f['type']))
132             if f.get('size'):
133                 sql_field += "(%s)" % f['size']
134             if f.get('primary'):
135                 primary.append(f['name'])
136             if 'null' in f and not f['null']:
137                 sql_field += " NOT NULL"
138             sql_fields.append(sql_field)
139         sql += ",".join(sql_fields)
140
141         if primary:
142             sql += ", PRIMARY KEY(%s)" % ",".join(primary)
143
144         sql += ")"
145
146         return (self.execute(sql, db=self._db_con_autocommit) or False) and True
147
148     def _get_cursor(self, query):
149         cur = self._db_con.cursor()
150         cur.execute(query.sql())
151         return cur
152
153     def _next(self, cur):
154         return cur.fetchone()
155
156     def _insert(self, db_name, table_name, values):
157         keys = []
158         vals = []
159         for k, v in values.iteritems():
160             keys.append(k)
161             if type(v) is str:
162                 vals.append("'%s'" % v.replace("'", "''"))
163             else:
164                 vals.append(str(v))
165         sql = "INSERT INTO `%s`.`%s`(%s) VALUES (%s)" % (db_name, table_name, ",".join(keys), ",".join(vals))
166         return self.execute(sql)
167
168     def commit(self):
169         self._db_con.commit()
170
171     def rollback(self):
172         self._db_con.rollback()