]> git.jsancho.org Git - mojodb.git/blob - MySQL.py
Custom serializer in connection object; default is msgpack
[mojodb.git] / MySQL.py
1 # -*- coding: utf-8 -*-
2 ##############################################################################
3 #
4 #    mojo, a Python library for implementing document based databases
5 #    Copyright (C) 2013-2014 by Javier Sancho Fernandez <jsf at jsancho dot org>
6 #
7 #    This program is free software: you can redistribute it and/or modify
8 #    it under the terms of the GNU General Public License as published by
9 #    the Free Software Foundation, either version 3 of the License, or
10 #    (at your option) any later version.
11 #
12 #    This program is distributed in the hope that it will be useful,
13 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
14 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 #    GNU General Public License for more details.
16 #
17 #    You should have received a copy of the GNU General Public License
18 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20 ##############################################################################
21
22 import dbutils
23 import connection
24 import MySQLdb
25
26 SQL_FIELD_TYPES = {
27     'char': 'VARCHAR',
28     'text': 'LONGTEXT',
29     'float': 'DOUBLE',
30     }
31
32
33 class Query(dbutils.Query):
34     def sql(self):
35         res = "SELECT "
36         res += ",".join(["(%s)" % f.sql() for f in self.fields])
37
38         res += " FROM "
39         res += ",".join([t.sql() for t in self.tables])
40
41         if self.constraints:
42             res += " WHERE "
43             res += " AND ".join(["(%s)" % c.sql() for c in self.constraints])
44
45         return res
46
47
48 class Field(dbutils.Field):
49     def sql(self):
50         return "%s.`%s`" % (self.table.sql(), self.field_name)
51
52
53 class Table(dbutils.Table):
54     def sql(self):
55         return "`%s`.`%s`" % (self.db_name, self.table_name)
56
57
58 class Constraint(dbutils.Constraint):
59     def sql(self):
60         operator = self.operator.strip().lower()
61         if operator == "starts":
62             return "(%s) LIKE (%s)" % (self.args[0].sql(), self.args[1].sql()[:-1] + "%'")
63         elif operator == "in":
64             return "(%s) IN (%s)" % (self.args[0].sql(), ",".join(["(%s)" % a.sql() for a in self.args[1:]]))
65         elif operator == "=":
66             return "(%s) = (%s)" % (self.args[0].sql(), self.args[1].sql())
67         else:
68             token = " %s " % operator.upper()
69             return token.join(["(%s)" % a.sql() for a in self.args])
70
71
72 class Literal(dbutils.Literal):
73     def sql(self):
74         if type(self.value) in (int, float):
75             return "%s" % self.value
76         else:
77             return "'%s'" % str(self.value).replace("'", "''")
78
79
80 class Connection(connection.Connection):
81     Query = Query
82     Field = Field
83     Table = Table
84     Constraint = Constraint
85     Literal = Literal
86
87     def __init__(self, host="localhost", user=None, passwd=None, *args, **kwargs):
88         self._db_con = MySQLdb.connect(host=host, user=user, passwd=passwd)
89         self._db_con_autocommit = MySQLdb.connect(host=host, user=user, passwd=passwd)
90         super(Connection, self).__init__(*args, **kwargs)
91
92     def query(self, sql, db=None):
93         if db is None:
94             db = self._db_con
95         cur = db.cursor()
96         cur.execute(sql)
97         res = cur.fetchall()
98         cur.close()
99         cur = None
100         return res
101
102     def execute(self, sql, db=None):
103         if db is None:
104             db = self._db_con
105         cur = db.cursor()
106         res = cur.execute(sql)
107         cur.close()
108         cur = None
109         return res
110
111     def _get_databases(self):
112         return [x[0] for x in self.query("SHOW DATABASES")]
113
114     def _get_tables(self, db_name):
115         return [x[0] for x in self.query("SHOW TABLES FROM `%s`" % db_name)]
116
117     def _count_rows(self, db_name, table_name):
118         return self.query("SELECT COUNT(*) FROM `%s`.`%s`" % (db_name, table_name))[0][0]
119
120     def _create_database(self, db_name):
121         return (self.execute("CREATE DATABASE `%s`" % db_name, db=self._db_con_autocommit) or False) and True
122
123     def _get_sql_field_type(self, field_type):
124         return SQL_FIELD_TYPES.get(field_type, "UNKNOW")
125
126     def _create_table(self, db_name, table_name, fields):
127         primary = []
128         sql = "CREATE TABLE `%s`.`%s` (" % (db_name, table_name)
129
130         sql_fields = []
131         for f in fields:
132             sql_field = "%s %s" % (f['name'], self._get_sql_field_type(f['type']))
133             if f.get('size'):
134                 sql_field += "(%s)" % f['size']
135             if f.get('primary'):
136                 primary.append(f['name'])
137             if 'null' in f and not f['null']:
138                 sql_field += " NOT NULL"
139             sql_fields.append(sql_field)
140         sql += ",".join(sql_fields)
141
142         if primary:
143             sql += ", PRIMARY KEY(%s)" % ",".join(primary)
144
145         sql += ")"
146
147         return (self.execute(sql, db=self._db_con_autocommit) or False) and True
148
149     def _get_cursor(self, query):
150         cur = self._db_con.cursor()
151         cur.execute(query.sql())
152         return cur
153
154     def _next(self, cur):
155         return cur.fetchone()
156
157     def _insert(self, db_name, table_name, values):
158         keys = []
159         vals = []
160         for k, v in values.iteritems():
161             keys.append(k)
162             if type(v) is str:
163                 vals.append("'%s'" % v.replace("'", "''"))
164             else:
165                 vals.append(str(v))
166         sql = "INSERT INTO `%s`.`%s`(%s) VALUES (%s)" % (db_name, table_name, ",".join(keys), ",".join(vals))
167         return self.execute(sql)
168
169     def commit(self):
170         self._db_con.commit()
171
172     def rollback(self):
173         self._db_con.rollback()