d4604d3cf23dfa59866cca9c97d9f9063071d32c
[mojodb.git] / MySQL.py
1 # -*- coding: utf-8 -*-
2 ##############################################################################
3 #
4 #    mojo, a Python library for implementing document based databases
5 #    Copyright (C) 2013-2014 by Javier Sancho Fernandez <jsf at jsancho dot org>
6 #
7 #    This program is free software: you can redistribute it and/or modify
8 #    it under the terms of the GNU General Public License as published by
9 #    the Free Software Foundation, either version 3 of the License, or
10 #    (at your option) any later version.
11 #
12 #    This program is distributed in the hope that it will be useful,
13 #    but WITHOUT ANY WARRANTY; without even the implied warranty of
14 #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 #    GNU General Public License for more details.
16 #
17 #    You should have received a copy of the GNU General Public License
18 #    along with this program.  If not, see <http://www.gnu.org/licenses/>.
19 #
20 ##############################################################################
21
22 import connection
23 import MySQLdb
24
25 SQL_FIELD_TYPES = {
26     'char': 'VARCHAR',
27     'text': 'LONGTEXT',
28     'float': 'DOUBLE',
29     }
30
31 class Connection(connection.Connection):
32     def __init__(self, *args, **kwargs):
33         self._db_con = MySQLdb.connect(*args, **kwargs)
34         self._db_con_autocommit = MySQLdb.connect(*args, **kwargs)
35
36     def query(self, sql, db=None):
37         if db is None:
38             db = self._db_con
39         cur = db.cursor()
40         cur.execute(sql)
41         res = cur.fetchall()
42         cur.close()
43         cur = None
44         return res
45
46     def execute(self, sql, db=None):
47         if db is None:
48             db = self._db_con
49         cur = db.cursor()
50         res = cur.execute(sql)
51         cur.close()
52         cur = None
53         return res
54
55     def _get_databases(self):
56         return [x[0] for x in self.query("SHOW DATABASES")]
57
58     def _get_tables(self, db_name):
59         return [x[0] for x in self.query("SHOW TABLES FROM `%s`" % db_name)]
60
61     def _count_rows(self, db_name, table_name):
62         return self.query("SELECT COUNT(*) FROM `%s`.`%s`" % (db_name, table_name))[0][0]
63
64     def _create_database(self, db_name):
65         return (self.execute("CREATE DATABASE `%s`" % db_name, db=self._db_con_autocommit) or False) and True
66
67     def _get_sql_field_type(self, field_type):
68         return SQL_FIELD_TYPES.get(field_type, "UNKNOW")
69
70     def _create_table(self, db_name, table_name, fields):
71         sql = "CREATE TABLE `%s`.`%s` (" % (db_name, table_name)
72
73         sql_fields = []
74         for f in fields:
75             sql_field = "%s %s" % (f['name'], self._get_sql_field_type(f['type']))
76             if f.get('size'):
77                 sql_field += "(%s)" % f['size']
78             if f.get('primary'):
79                 sql_field += " PRIMARY KEY"
80             if 'null' in f and not f['null']:
81                 sql_field += " NOT NULL"
82             sql_fields.append(sql_field)
83         sql += ",".join(sql_fields)
84
85         sql += ")"
86
87         return (self.execute(sql, db=self._db_con_autocommit) or False) and True
88
89     def _get_sql_field(self, db_name, field):
90         if type(field) is tuple:
91             return "`%s`.`%s`.`%s`" % (db_name, field[0], field[1])
92         elif type(field) is dict:
93             return "(%s)" % self._get_sql_query(db_name, field)
94         else:
95             return "'%s'" % str(field)
96         
97     def _get_sql_query(self, db_name, query):
98         sql = "SELECT "
99         sql += ",".join([self._get_sql_field(db_name, x) for x in query['select']])
100
101         sql += " FROM "
102         sql += ",".join(query['from'])
103
104         if query.get('where'):
105             sql += " WHERE "
106             where = []
107             for cond in query['where']:
108                 where.append("%s %s %s" % (self._get_sql_field(db_name, cond[0]), cond[1], self._get_sql_field(db_name, cond[2])))
109             sql += " AND ".join(where)
110
111         return sql
112
113     def _get_cursor(self, db_name, query):
114         cur = self._db_con.cursor()
115         cur.execute("USE `%s`" % db_name)
116         cur.execute(self._get_sql_query(db_name, query))
117         return cur
118
119     def _next(self, cur):
120         return cur.fetchone()
121
122     def _insert(self, db_name, table_name, values):
123         keys = []
124         vals = []
125         for k, v in values.iteritems():
126             keys.append(k)
127             if type(v) in (str, unicode):
128                 vals.append("'%s'" % v)
129             else:
130                 vals.append(str(v))
131         sql = "INSERT INTO `%s`.`%s`(%s) VALUES (%s)" % (db_name, table_name, ",".join(keys), ",".join(vals))
132         return self.execute(sql)
133
134     def commit(self):
135         self._db_con.commit()
136
137     def rollback(self):
138         self._db_con.rollback()