]> git.jsancho.org Git - datasette-connectors.git/blob - datasette_connectors/connectors.py
f09727f93bc07f59e60343ea3fcf78867e6ac1da
[datasette-connectors.git] / datasette_connectors / connectors.py
1 import pkg_resources
2 import functools
3 import re
4 import sqlite3
5
6 from .row import Row
7
8
# Registry of connector objects keyed by name. ConnectorList.load() and
# ConnectorList.add_connector() write to it; ConnectorList.connect() and
# for_each_connector() iterate over its values in insertion order.
db_connectors = dict()
10
def for_each_connector(func):
    """Decorate *func* so it is tried against each registered connector in turn.

    The wrapper calls ``func(connector, *args, **kwargs)`` for every value of
    the module-level ``db_connectors`` mapping, in insertion order, and returns
    the first result that does not raise an ``Exception``.

    Raises:
        Exception: ``"No database connector found!!"`` when no connector is
            registered or every attempt raised; the last failure is chained
            as ``__cause__`` for diagnostics. Non-``Exception`` errors such as
            ``KeyboardInterrupt`` propagate immediately.
    """
    @functools.wraps(func)
    def wrapper_for_each_connector(*args, **kwargs):
        last_error = None
        for connector in db_connectors.values():
            try:
                return func(connector, *args, **kwargs)
            except Exception as err:
                # A failure only means "not this connector"; keep it so the
                # final error can explain why nothing matched.
                last_error = err
        raise Exception("No database connector found!!") from last_error
    return wrapper_for_each_connector
22
23
class ConnectorList:
    """Static helpers that manage the module-level ``db_connectors`` registry."""

    @staticmethod
    def load():
        """Register every object published under the ``datasette.connectors``
        entry-point group, keyed by the entry point's name."""
        for entry_point in pkg_resources.iter_entry_points('datasette.connectors'):
            db_connectors[entry_point.name] = entry_point.load()

    @staticmethod
    def add_connector(name, connector):
        """Register *connector* under *name*, replacing any existing entry."""
        db_connectors[name] = connector

    class DatabaseNotSupported(Exception):
        """Raised by ``ConnectorList.connect`` when no connector accepts a path."""

    @staticmethod
    def connect(path):
        """Return the first connection any registered connector opens for *path*.

        Connectors are tried in registration order; one that raises an
        ``Exception`` is treated as not supporting *path*.

        Raises:
            ConnectorList.DatabaseNotSupported: no connector accepted *path*
                (args are ``(path,)``; the last connector error, if any, is
                chained as ``__cause__``).
        """
        last_error = None
        for connector in db_connectors.values():
            try:
                return connector.connect(path)
            except Exception as err:
                # Rejection is the normal "wrong backend" signal; remember it
                # instead of discarding it, but let KeyboardInterrupt etc. through.
                last_error = err
        raise ConnectorList.DatabaseNotSupported(path) from last_error
46
47
class Connection:
    """Connection object that mimics the subset of ``sqlite3.Connection`` used
    here (``execute``, ``cursor``, ``set_progress_handler``), delegating to a
    connector.

    Args:
        path: Database location, stored as ``self.path``.
        connector: Connector instance, stored as ``self.connector`` and read
            by every ``Cursor`` created from this connection.
    """

    def __init__(self, path, connector):
        self.path, self.connector = path, connector

    def execute(self, *args, **kwargs):
        """Create a fresh ``Cursor``, run the query on it and return it."""
        new_cursor = Cursor(self)
        new_cursor.execute(*args, **kwargs)
        return new_cursor

    def cursor(self):
        """Return a new ``Cursor`` bound to this connection."""
        return Cursor(self)

    def set_progress_handler(self, handler, n):
        """Accept a progress handler and ignore it; *handler* is never called."""
        return None
63
64
class OperationalError(Exception):
    """Error a connector raises from ``execute``.

    ``Cursor.execute`` catches it and re-raises ``sqlite3.OperationalError``
    with the same ``args``.
    """
67
68
class Cursor:
    """Stand-in for ``sqlite3.Cursor`` backed by a connector.

    A fixed set of ``sqlite_master`` / ``PRAGMA`` introspection queries is
    recognised by its exact text (after whitespace normalisation) and answered
    from the connector's metadata methods; any other SQL goes to
    ``connector.execute``. Every result row is wrapped in ``Row``.
    """

    class QueryNotSupported(Exception):
        pass

    # Sentinel returned by _emulated_results() for SQL it does not answer.
    # A private object (not None) keeps any value a connector metadata method
    # returns, even None, on the emulated path.
    _NOT_EMULATED = object()

    def __init__(self, conn):
        self.conn = conn
        self.connector = conn.connector
        self.rows = []
        self.description = ()

    def execute(
        self,
        sql,
        params=None,
        truncate=False,
        custom_time_limit=None,
        page_size=None,
        log_sql_errors=True,
    ):
        """Run *sql* and store its rows in ``self.rows``.

        Args:
            sql: Query text; runs of whitespace are collapsed before matching.
            params: Bound parameters; ``None`` is replaced by ``{}``.
            truncate, custom_time_limit, page_size, log_sql_errors: Forwarded
                unchanged to ``connector.execute`` for non-emulated SQL.

        Raises:
            sqlite3.OperationalError: the connector raised this module's
                ``OperationalError``; that error is chained as ``__cause__``.
        """
        if params is None:
            params = {}

        # str.split() with no argument also drops leading/trailing blanks, so
        # this both strips and collapses whitespace for the comparisons below.
        sql = ' '.join(sql.split())

        description = ()
        results = self._emulated_results(sql, params)
        if results is self._NOT_EMULATED:
            try:
                results, _truncated, description = self.connector.execute(
                    sql,
                    params=params,
                    truncate=truncate,
                    custom_time_limit=custom_time_limit,
                    page_size=page_size,
                    log_sql_errors=log_sql_errors,
                )
            except OperationalError as ex:
                # Translate to sqlite3's type, keeping the connector error
                # as the cause so its traceback is not lost.
                raise sqlite3.OperationalError(*ex.args) from ex

        self.rows = [Row(result) for result in results]
        self.description = description

    def _emulated_results(self, sql, params):
        """Answer a recognised introspection query from connector metadata.

        Returns the result rows (possibly an empty list), or
        ``self._NOT_EMULATED`` when *sql* is not a recognised query, including
        when a recognised prefix is not followed by its expected closing text.
        """
        connector = self.connector

        if sql in (
            "select name from sqlite_master where type='table'",
            "select name from sqlite_master where type=\"table\"",
        ):
            return [{'name': name} for name in connector.table_names()]
        if sql == "select name from sqlite_master where rootpage = 0 and sql like '%VIRTUAL TABLE%USING FTS%'":
            return [{'name': name} for name in connector.hidden_table_names()]
        if sql == 'select 1 from sqlite_master where tbl_name = "geometry_columns"':
            return [{'1': '1'}] if connector.detect_spatialite() else []
        if sql == "select name from sqlite_master where type='view'":
            return [{'name': name} for name in connector.view_names()]
        if sql.startswith("select count(*) from ["):
            match = re.match(r'select count\(\*\) from \[(.*)\]', sql)
            if match is None:
                return self._NOT_EMULATED
            return [{'count(*)': connector.table_count(match.group(1))}]
        if sql.startswith("select count(*) from "):
            match = re.match(r'select count\(\*\) from (.*)', sql)
            return [{'count(*)': connector.table_count(match.group(1))}]
        if sql.startswith("PRAGMA table_info("):
            match = re.match(r'PRAGMA table_info\((.*)\)', sql)
            if match is None:
                return self._NOT_EMULATED
            return connector.table_info(match.group(1))
        if sql.startswith("select name from sqlite_master where rootpage = 0 and ( sql like '%VIRTUAL TABLE%USING FTS%content="):
            match = re.match(r'select name from sqlite_master where rootpage = 0 and \( sql like \'%VIRTUAL TABLE%USING FTS%content="(.*)"', sql)
            if match is None:
                return self._NOT_EMULATED
            table = match.group(1)
            return [{'name': table}] if connector.detect_fts(table) else []
        if sql.startswith("PRAGMA foreign_key_list(["):
            match = re.match(r'PRAGMA foreign_key_list\(\[(.*)\]\)', sql)
            if match is None:
                return self._NOT_EMULATED
            return connector.foreign_keys(match.group(1))
        if sql == "select 1 from sqlite_master where type='table' and name=?":
            return [{'1': '1'}] if connector.table_exists(params[0]) else []
        if sql == "select sql from sqlite_master where name = :n and type=:t":
            return [{'sql': connector.table_definition(params['t'], params['n'])}]
        if sql == "select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null":
            return [{'sql': definition} for definition in connector.indices_definition(params['n'])]
        return self._NOT_EMULATED

    def fetchall(self):
        """Return every row stored by the last ``execute``."""
        return self.rows

    def fetchmany(self, max):
        """Return at most *max* rows from the start of the result.

        Unlike ``sqlite3.Cursor.fetchmany`` this does not consume rows:
        repeated calls return the same leading rows. The parameter keeps its
        historical name ``max`` for keyword callers.
        """
        return self.rows[:max]

    def __getitem__(self, index):
        """Return row *index*; also makes the cursor iterable over its rows."""
        return self.rows[index]
156
157
class Connector:
    """Base class for database connectors.

    Only ``connect`` is implemented here; every other method raises
    ``NotImplementedError`` and is meant to be overridden. ``Cursor`` calls
    the metadata methods to answer introspection queries and ``execute``
    for all other SQL.
    """

    # NOTE(review): presumably a label for the backend kind; None in the base
    # class and not read anywhere in this file.
    connector_type = None
    # Class that connect() instantiates as connection_class(path, self).
    connection_class = Connection

    def connect(self, path):
        """Return a new connection object for *path* bound to this connector."""
        factory = self.connection_class
        return factory(path, self)

    def table_names(self):
        """
        Return a list with the name of every table.
        """
        raise NotImplementedError

    def hidden_table_names(self):
        """
        Return a list of table names reported for the sqlite_master query
        that lists FTS virtual tables (presumably tables to hide).
        """
        raise NotImplementedError

    def detect_spatialite(self):
        """
        Return True if a geometry_columns table exists, else False.
        """
        raise NotImplementedError

    def view_names(self):
        """
        Return a list with the name of every view.
        """
        raise NotImplementedError

    def table_count(self, table_name):
        """
        Return the number of rows in table_name as an integer.
        """
        raise NotImplementedError

    def table_info(self, table_name):
        """
        Return one dictionary per column of table_name, in this format:
        [
            {
                'idx': 0,
                'name': 'column1',
                'primary_key': False,
            },
            ...
        ]
        """
        raise NotImplementedError

    def detect_fts(self, table_name):
        """
        Return True if table_name has a matching FTS virtual table, else False.
        """
        raise NotImplementedError

    def foreign_keys(self, table_name):
        """
        Return one dictionary per foreign key of table_name, with the keys
        id, seq, table_name, from_, to_, on_update, on_delete, match
        """
        raise NotImplementedError

    def table_exists(self, table_name):
        """
        Return True if table_name exists in the database, else False.
        """
        raise NotImplementedError

    def table_definition(self, table_type, table_name):
        """
        Return the 'CREATE TABLE' SQL statement for table_name as a string.
        """
        raise NotImplementedError

    def indices_definition(self, table_name):
        """
        Return a list with one 'CREATE INDEX' SQL statement per index.
        """
        raise NotImplementedError

    def execute(
        self,
        sql,
        params=None,
        truncate=False,
        custom_time_limit=None,
        page_size=None,
        log_sql_errors=True,
    ):
        """
        Run arbitrary SQL.

        Cursor.execute unpacks the return value as
        (results, truncated, description): results is an iterable whose items
        are each passed to Row, truncated is ignored by Cursor (presumably a
        bool), and description becomes cursor.description. An OperationalError
        raised here is re-raised by Cursor as sqlite3.OperationalError.
        """
        raise NotImplementedError