database page
[datasette-pytables.git] / datasette_pytables / __init__.py
1 from moz_sql_parser import parse
2 import re
3
4 import tables
5 import datasette_connectors as dc
6
7
class PyTablesConnection(dc.Connection):
    """Connection that also holds an open PyTables handle for ``path``."""

    def __init__(self, path, connector):
        """Initialise the base connection, then open ``path`` with PyTables.

        The opened file is stored on ``self.h5file``.
        """
        super(PyTablesConnection, self).__init__(path, connector)
        h5file = tables.open_file(path)
        self.h5file = h5file
12
13
class PyTablesConnector(dc.Connector):
    """Connector exposing the nodes of a PyTables file as database tables.

    Every node of the file that is not a group is reported as a table.
    Views, hidden tables, full-text search, SpatiaLite and foreign keys are
    all reported as absent.
    """

    connector_type = 'pytables'
    connection_class = PyTablesConnection

    def table_names(self):
        """Return the path name of every non-group node in the file."""
        names = []
        for node in self.conn.h5file:
            if isinstance(node, tables.group.Group):
                continue
            names.append(node._v_pathname)
        return names

    def table_count(self, table_name):
        """Return the number of rows of the node called ``table_name``."""
        return int(self.conn.h5file.get_node(table_name).nrows)

    def table_info(self, table_name):
        """Describe the columns of ``table_name``.

        Returns one dict per column with the keys ``idx``, ``name`` and
        ``primary_key`` (always ``False``).  Nodes that are not
        ``tables.table.Table`` instances get a single column named ``value``.
        """
        node = self.conn.h5file.get_node(table_name)
        if isinstance(node, tables.table.Table):
            columns = node.colnames
        else:
            columns = ['value']

        info = []
        for position, column in enumerate(columns):
            info.append({
                'idx': position,
                'name': column,
                'primary_key': False,
            })
        return info

    def hidden_table_names(self):
        """Return the hidden tables: always an empty list."""
        return []

    def detect_spatialite(self):
        """Report SpatiaLite support: always ``False``."""
        return False

    def view_names(self):
        """Return the views: always an empty list."""
        return []

    def detect_fts(self, table_name):
        """Report full-text search for ``table_name``: always ``False``."""
        return False

    def foreign_keys(self, table_name):
        """Return the foreign keys of ``table_name``: always an empty list."""
        return []
58
59
def inspect(path):
    """Open the file at ``path`` and return its tables info.

    Returns a ``(tables, views, connector_type)`` tuple:

    * ``tables`` maps each non-group node's path name to a dict describing
      it (name, columns, row count and fixed placeholder metadata).
    * ``views`` is always an empty list.
    * ``connector_type`` is the module-level ``_connector_type`` value.

    The file is closed before returning, including when reading a node
    raises.
    """
    h5tables = {}
    views = []
    h5file = tables.open_file(path)

    try:
        for node in h5file:
            # Groups are containers only; every other node is a table.
            if isinstance(node, tables.group.Group):
                continue

            # Nodes that are not Table instances expose one 'value' column.
            colnames = ['value']
            if isinstance(node, tables.table.Table):
                colnames = node.colnames

            h5tables[node._v_pathname] = {
                'name': node._v_pathname,
                'columns': colnames,
                'primary_keys': [],
                'count': int(node.nrows),
                'label_column': None,
                'hidden': False,
                'fts_table': None,
                'foreign_keys': {'incoming': [], 'outgoing': []},
            }
    finally:
        # Release the file handle even if a node could not be inspected.
        h5file.close()

    # NOTE(review): _connector_type is not defined in the visible part of
    # this module -- confirm it exists, otherwise this raises NameError.
    return h5tables, views, _connector_type
84
def _parse_sql(sql, params):
    """Parse ``sql`` into the dict structure produced by ``parse``.

    Before parsing, ``[name]`` table references after ``from`` are rewritten
    to ``"name"`` and every ``:param`` placeholder is replaced by the bare
    parameter name.

    If the statement cannot be parsed, the ``where`` clause is assumed to be
    a PyTables expression: it is cut out of the statement, the remainder is
    parsed, and the raw expression text is stored under ``'where'``.

    Returns the parsed dict, whose ``'select'`` entry is always a list.
    Raises the original parse error when the statement cannot be parsed and
    contains no ``where`` clause to cut out.
    """
    # Table name
    sql = re.sub(r'(?i)from \[(.*)]', r'from "\g<1>"', sql)
    # Params
    for param in params:
        sql = sql.replace(":" + param, param)

    try:
        parsed = parse(sql)
    except Exception:
        # Probably it's a PyTables expression
        parsed = None
        # The empty token comes last so a trailing clause, when present,
        # is kept out of the captured expression.
        for token in ['group by', 'order by', 'limit', '']:
            res = re.search('(?i)where (.*)' + token, sql)
            if res:
                modified_sql = re.sub('(?i)where (.*)(' + token + ')', r'\g<2>', sql)
                parsed = parse(modified_sql)
                parsed['where'] = res.group(1).strip()
                break
        if parsed is None:
            # No where clause to cut out: surface the real parse error.
            raise

    # Always a list of fields
    if not isinstance(parsed['select'], list):
        parsed['select'] = [parsed['select']]

    return parsed
109
110 _operators = {
111     'eq': '==',
112     'neq': '!=',
113     'gt': '>',
114     'gte': '>=',
115     'lt': '<',
116     'lte': '<=',
117     'and': '&',
118     'or': '|',
119 }
120
class Connection:
    """Answer a limited subset of SQL from a PyTables file.

    Statements are parsed with ``_parse_sql`` and mapped onto PyTables node
    iteration; no SQLite engine is involved.
    """

    def __init__(self, path):
        """Open the file at ``path`` with PyTables and keep the handle."""
        self.path = path
        self.h5file = tables.open_file(path)

    def execute(self, sql, params=None, truncate=False, page_size=None, max_returned_rows=None):
        """Run ``sql`` and return ``(rows, truncated, description)``.

        Parameters:
            sql: statement text; ``:name`` placeholders refer to ``params``.
            params: dict of parameter values (may be cast in place to the
                type of the column they are compared with).
            truncate: when true, stop after ``max_returned_rows`` rows and
                report ``truncated=True``.
            page_size: only compared with ``max_returned_rows`` (see below).
            max_returned_rows: row cap applied when ``truncate`` is true.

        Returns:
            ``rows`` is a list of ``Row`` objects, ``truncated`` a bool and
            ``description`` a tuple of 1-tuples holding the column names.

        Queries on ``sqlite_master`` are delegated to
        ``_execute_datasette_query``.  An ``order by`` is only honoured when
        the statement yields no PyTables where-expression.
        """
        if params is None:
            params = {}
        rows = []
        truncated = False
        description = []

        parsed_sql = _parse_sql(sql, params)

        if parsed_sql['from'] == 'sqlite_master':
            rows = self._execute_datasette_query(sql, params)
            description = (('value',),)
            return rows, truncated, description

        table = self.h5file.get_node(parsed_sql['from'])
        is_table = isinstance(table, tables.table.Table)
        fields = parsed_sql['select']
        # Nodes that are not Table instances expose one 'value' column.
        colnames = table.colnames if is_table else ['value']

        query = ''
        start = 0
        end = table.nrows

        # Use 'where' statement or get all the rows
        def _cast_param(field, pname):
            # Cast value to the column type
            coltype = table.coltypes[field] if is_table else table.dtype.name
            fcast = None
            if coltype == 'string':
                fcast = str
            elif coltype.startswith('int'):
                fcast = int
            elif coltype.startswith('float'):
                fcast = float
            if fcast:
                params[pname] = fcast(params[pname])

        def _translate_where(where):
            # Translate SQL to PyTables expression
            nonlocal start, end
            expr = ''
            operator = list(where)[0]

            if operator in ['and', 'or']:
                subexpr = [_translate_where(e) for e in where[operator]]
                # Drop conditions that produced no expression (rowid, exists)
                subexpr = ["({})".format(e) for e in subexpr if e]
                expr = " {} ".format(_operators[operator]).join(subexpr)
            elif operator == 'exists':
                pass
            elif where == {'eq': ['rowid', 'p0']}:
                # rowid conditions narrow the scanned row range instead of
                # becoming part of the expression
                start = int(params['p0'])
                end = start + 1
            elif where == {'gt': ['rowid', 'p0']}:
                start = int(params['p0']) + 1
            else:
                left, right = where[operator]
                if left in params:
                    _cast_param(right, left)
                elif right in params:
                    _cast_param(left, right)

                expr = "{left} {operator} {right}".format(
                    left=left,
                    operator=_operators.get(operator, operator),
                    right=right,
                )

            return expr

        if 'where' in parsed_sql:
            if isinstance(parsed_sql['where'], dict):
                query = _translate_where(parsed_sql['where'])
            else:
                # Raw PyTables expression kept as text by _parse_sql
                query = parsed_sql['where']

        # Sort by column
        orderby = ''
        if 'orderby' in parsed_sql:
            orderby = parsed_sql['orderby']
            if isinstance(orderby, list):
                # Only the first sort key is used
                orderby = orderby[0]
            orderby = orderby['value']
            if orderby == 'rowid':
                orderby = ''

        # Limit number of rows
        limit = None
        if 'limit' in parsed_sql:
            limit = int(parsed_sql['limit'])

        # Truncate if needed
        if page_size and max_returned_rows and truncate:
            # NOTE(review): presumably the extra row lets the caller detect
            # that another page exists -- confirm against the caller.
            if max_returned_rows == page_size:
                max_returned_rows += 1

        # Execute query
        if query:
            table_rows = table.where(query, params, start, end)
        elif orderby:
            table_rows = table.itersorted(orderby, start=start, stop=end)
        else:
            table_rows = table.iterrows(start, end)

        # Prepare rows
        def normalize_field_value(value):
            if type(value) is bytes:
                return value.decode('utf-8')
            elif type(value) not in (int, float, complex):
                return str(value)
            else:
                return value

        def make_get_rowid():
            if is_table:
                def get_rowid(row):
                    return int(row.nrow)
            else:
                # No row object to ask: count positions from ``start``
                rowid = start - 1

                def get_rowid(row):
                    nonlocal rowid
                    rowid += 1
                    return rowid
            return get_rowid

        def make_get_row_value():
            if is_table:
                def get_row_value(row, field):
                    return row[field]
            else:
                # The iterated item itself is the single 'value' column
                def get_row_value(row, field):
                    return row
            return get_row_value

        first_value = fields[0]['value'] if len(fields) == 1 else None
        is_count_all = isinstance(first_value, dict) and first_value.get('count') == '*'

        if is_count_all:
            rows.append(Row({'count(*)': int(table.nrows)}))
        else:
            get_rowid = make_get_rowid()
            get_row_value = make_get_row_value()
            count = 0
            for table_row in table_rows:
                count += 1
                # ``is not None`` so that LIMIT 0 returns no rows
                if limit is not None and count > limit:
                    break
                if truncate and max_returned_rows and count > max_returned_rows:
                    truncated = True
                    break
                row = Row()
                for field in fields:
                    field_name = field['value']
                    if isinstance(field_name, dict) and 'distinct' in field_name:
                        field_name = field_name['distinct']
                    if field_name == 'rowid':
                        row['rowid'] = get_rowid(table_row)
                    elif field_name == '*':
                        for col in colnames:
                            row[col] = normalize_field_value(get_row_value(table_row, col))
                    else:
                        row[field_name] = normalize_field_value(get_row_value(table_row, field_name))
                rows.append(row)

        # Prepare query description, using the same labels as the rows
        if is_count_all:
            description.append(('count(*)',))
        else:
            for field in fields:
                name = field['value']
                if isinstance(name, dict) and 'distinct' in name:
                    name = name['distinct']
                if name == '*':
                    description.extend((col,) for col in colnames)
                else:
                    description.append((name,))

        # Return the rows
        return rows, truncated, tuple(description)

    def _execute_datasette_query(self, sql, params):
        """Answer the Datasette special queries for getting tables info.

        Only the exact statement
        ``select sql from sqlite_master where name = :n and type=:t`` is
        supported.  For views, or when the node cannot be described, an
        empty list is returned; otherwise a single ``Row`` whose ``sql``
        entry is a synthetic ``CREATE TABLE`` statement.

        Raises ``Exception`` for any other statement.
        """
        if sql != 'select sql from sqlite_master where name = :n and type=:t':
            raise Exception("SQLite queries cannot be executed with this connector: %s, %s" % (sql, params))

        if params['t'] == 'view':
            return []

        try:
            table = self.h5file.get_node(params['n'])
            colnames = ['value']
            if isinstance(table, tables.table.Table):
                colnames = table.colnames
            row = Row()
            row['sql'] = 'CREATE TABLE {} ({})'.format(params['n'], ", ".join(colnames))
            return [row]
        except Exception:
            # Best effort: a node that cannot be described is reported as
            # missing, without swallowing KeyboardInterrupt or SystemExit.
            return []
317
318
class Row(list):
    """Result row addressable by column label or by position.

    Data lives in the parallel lists ``labels`` and ``values``, not in the
    inherited list storage.  Iteration yields the values in insertion order
    and ``len()`` reports how many values the row holds.
    """

    def __init__(self, values=None):
        """Create a row, optionally filled from the mapping ``values``."""
        super().__init__()
        self.labels = []
        self.values = []
        if values:
            for label in values:
                self[label] = values[label]

    def __setitem__(self, idx, value):
        """Set a value by label (appending unknown labels) or by position."""
        if isinstance(idx, str):
            if idx in self.labels:
                self.values[self.labels.index(idx)] = value
            else:
                self.labels.append(idx)
                self.values.append(value)
        else:
            self.values[idx] = value

    def __getitem__(self, idx):
        """Return a value by label or by position.

        Raises ``ValueError`` for an unknown label and ``IndexError`` for a
        position that is out of range.
        """
        if isinstance(idx, str):
            return self.values[self.labels.index(idx)]
        return self.values[idx]

    def __iter__(self):
        """Iterate over the stored values."""
        return iter(self.values)

    def __len__(self):
        """Return the number of stored values.

        Without this override the inherited (always empty) list storage
        would make every row report length 0 and evaluate as false.
        """
        return len(self.values)
343     def __iter__(self):
344         return self.values.__iter__()