Use datasette-connectors 2.0 API (WIP)
[datasette-pytables.git] / datasette_pytables / __init__.py
1 from moz_sql_parser import parse
2 import re
3
4 import tables
5 import datasette_connectors as dc
6
7
class PyTablesConnection(dc.Connection):
    """Connection that opens the PyTables file at ``path``.

    The open file handle is stored in ``h5file``; ``PyTablesConnector``
    reads it through ``self.conn.h5file``.
    """
    def __init__(self, path, connector):
        # Delegate to the base class initializer first, then open the file
        # with tables.open_file using its default mode.
        super().__init__(path, connector)
        self.h5file = tables.open_file(path)
12
13
class PyTablesConnector(dc.Connector):
    """Connector that presents the non-group nodes of a PyTables file as tables."""

    connector_type = 'pytables'
    connection_class = PyTablesConnection

    def table_names(self):
        """Return the pathname of every node in the file that is not a group."""
        names = []
        for node in self.conn.h5file:
            if isinstance(node, tables.group.Group):
                continue
            names.append(node._v_pathname)
        return names

    def table_count(self, table_name):
        """Return the ``nrows`` of the node at *table_name* as an ``int``."""
        return int(self.conn.h5file.get_node(table_name).nrows)

    def table_info(self, table_name):
        """Return one ``{'idx', 'name', 'primary_key'}`` dict per column.

        Nodes that are ``Table`` instances report their ``colnames``; any
        other node is described as a single column named ``'value'``.
        """
        node = self.conn.h5file.get_node(table_name)
        if isinstance(node, tables.table.Table):
            columns = node.colnames
        else:
            columns = ['value']

        info = []
        for position, column in enumerate(columns):
            info.append({
                'idx': position,
                'name': column,
                'primary_key': False,
            })
        return info

    def hidden_table_names(self):
        """Return an empty list: this connector reports no hidden tables."""
        return []

    def detect_spatialite(self):
        """Return ``False``: this connector never reports SpatiaLite."""
        return False

    def view_names(self):
        """Return an empty list: this connector reports no views."""
        return []

    def detect_fts(self, table_name):
        """Return ``False`` for every table: no full-text search is reported."""
        return False
55
56
def inspect(path):
    """Open the PyTables file at *path* and return its tables info.

    Returns a ``(tables_info, views, connector_type)`` tuple:

    * ``tables_info`` maps each non-group node pathname to a dict holding
      its name, columns, row count and placeholder key/FTS metadata.
      ``Table`` nodes report their ``colnames``; any other node is
      described as a single ``'value'`` column.
    * ``views`` is always an empty list.
    * ``connector_type`` is ``PyTablesConnector.connector_type``.

    The file is closed before returning, even if reading a node fails.
    """
    h5tables = {}
    views = []
    h5file = tables.open_file(path)

    try:
        for node in h5file:
            if isinstance(node, tables.group.Group):
                continue

            colnames = ['value']
            if isinstance(node, tables.table.Table):
                colnames = node.colnames

            h5tables[node._v_pathname] = {
                'name': node._v_pathname,
                'columns': colnames,
                'primary_keys': [],
                'count': int(node.nrows),
                'label_column': None,
                'hidden': False,
                'fts_table': None,
                'foreign_keys': {'incoming': [], 'outgoing': []},
            }
    finally:
        # Release the file handle even when iteration raises.
        h5file.close()

    # The module defines no `_connector_type` name; the connector type
    # lives on the connector class.
    return h5tables, views, PyTablesConnector.connector_type
81
def _parse_sql(sql, params):
    """Parse *sql* with ``moz_sql_parser.parse`` after light normalisation.

    Before parsing, ``from [name]`` is rewritten to ``from "name"`` and
    every ``:param`` placeholder whose name is in *params* is replaced by
    the bare parameter name.

    If parsing fails, the ``where`` clause is assumed to be a raw PyTables
    expression: it is cut out of the statement (up to a following
    ``group by`` / ``order by`` / ``limit``, or to the end), the remainder
    is parsed, and the raw clause text is stored as a string under
    ``parsed['where']``.

    Returns the parsed structure with ``parsed['select']`` always a list.

    Raises the original parse error when the statement cannot be parsed
    and contains no ``where`` clause to strip.
    """
    # Table name
    sql = re.sub(r'(?i)from \[(.*)]', r'from "\g<1>"', sql)
    # Params
    for param in params:
        sql = sql.replace(":" + param, param)

    try:
        parsed = parse(sql)
    except Exception:
        # Probably it's a PyTables expression
        for token in ('group by', 'order by', 'limit', ''):
            res = re.search('(?i)where (.*)' + token, sql)
            if res:
                modified_sql = re.sub('(?i)where (.*)(' + token + ')', r'\g<2>', sql)
                parsed = parse(modified_sql)
                parsed['where'] = res.group(1).strip()
                break
        else:
            # No where clause to strip: surface the real parse error
            # instead of failing later on an unbound `parsed`.
            raise

    # Always a list of fields
    if not isinstance(parsed['select'], list):
        parsed['select'] = [parsed['select']]

    return parsed
106
# Maps operator keys found in the parsed SQL ``where`` tree to the symbols
# used when building the condition string in Connection.execute. Keys that
# are missing here are used verbatim by the comparison branch there.
_operators = {
    'eq': '==',
    'neq': '!=',
    'gt': '>',
    'gte': '>=',
    'lt': '<',
    'lte': '<=',
    'and': '&',
    'or': '|',
}
117
class Connection:
    """Execute a small subset of SQL against the nodes of a PyTables file.

    The file at ``path`` is opened with ``tables.open_file`` when the
    connection is created and the handle is kept in ``h5file``.
    """

    def __init__(self, path):
        self.path = path
        self.h5file = tables.open_file(path)

    def execute(self, sql, params=None, truncate=False, page_size=None, max_returned_rows=None):
        """Run *sql* and return ``(rows, truncated, description)``.

        *sql* is parsed with ``_parse_sql``. Queries against
        ``sqlite_master`` are routed to ``_execute_datasette_query``.
        Otherwise the ``from`` node is read, optionally filtered by the
        ``where`` clause, sorted by a single ``order by`` column and cut
        by ``limit``.

        Parameters:
            sql: the SQL statement.
            params: mapping of parameter names to values. When passed in,
                it is updated in place with values cast to the column type.
            truncate: when true, stop after ``max_returned_rows`` rows and
                report ``truncated=True``.
            page_size: only used together with ``max_returned_rows`` and
                ``truncate`` to widen the cap by one row when both are equal.
            max_returned_rows: row cap applied when ``truncate`` is true.

        Returns:
            ``rows`` (list of ``Row``), ``truncated`` (bool) and
            ``description`` (tuple of 1-tuples, one per selected field).
        """
        if params is None:
            params = {}
        rows = []
        truncated = False
        description = []

        parsed_sql = _parse_sql(sql, params)

        if parsed_sql['from'] == 'sqlite_master':
            rows = self._execute_datasette_query(sql, params)
            description = (('value',),)
            return rows, truncated, description

        table = self.h5file.get_node(parsed_sql['from'])
        is_table = isinstance(table, tables.table.Table)
        fields = parsed_sql['select']
        # Non-Table nodes are exposed as a single 'value' column
        colnames = ['value']
        if is_table:
            colnames = table.colnames

        query = ''
        start = 0
        end = table.nrows

        # Use 'where' statement or get all the rows
        def _cast_param(field, pname):
            # Cast value to the column type
            coltype = table.dtype.name
            if is_table:
                coltype = table.coltypes[field]
            fcast = None
            if coltype == 'string':
                fcast = str
            elif coltype.startswith('int'):
                fcast = int
            elif coltype.startswith('float'):
                fcast = float
            if fcast:
                params[pname] = fcast(params[pname])

        def _translate_where(where):
            # Translate SQL to PyTables expression
            nonlocal start, end
            expr = ''
            operator = list(where)[0]

            if operator in ['and', 'or']:
                subexpr = [_translate_where(e) for e in where[operator]]
                # Drop empty sub-expressions (rowid conditions, 'exists')
                subexpr = ["({})".format(e) for e in subexpr if e]
                expr = " {} ".format(_operators[operator]).join(subexpr)
            elif operator == 'exists':
                pass
            elif where == {'eq': ['rowid', 'p0']}:
                # rowid lookups become a row range instead of a condition
                start = int(params['p0'])
                end = start + 1
            elif where == {'gt': ['rowid', 'p0']}:
                start = int(params['p0']) + 1
            else:
                left, right = where[operator]
                if left in params:
                    _cast_param(right, left)
                elif right in params:
                    _cast_param(left, right)

                expr = "{left} {operator} {right}".format(left=left, operator=_operators.get(operator, operator), right=right)

            return expr

        if 'where' in parsed_sql:
            if isinstance(parsed_sql['where'], dict):
                query = _translate_where(parsed_sql['where'])
            else:
                # _parse_sql left the raw clause text: use it verbatim
                query = parsed_sql['where']

        # Sort by column
        orderby = ''
        if 'orderby' in parsed_sql:
            orderby = parsed_sql['orderby']
            if isinstance(orderby, list):
                # Only the first order-by term is honoured
                orderby = orderby[0]
            orderby = orderby['value']
            if orderby == 'rowid':
                orderby = ''

        # Limit number of rows
        limit = None
        if 'limit' in parsed_sql:
            limit = int(parsed_sql['limit'])

        # Truncate if needed
        if page_size and max_returned_rows and truncate:
            if max_returned_rows == page_size:
                # NOTE(review): presumably lets the caller see one row past
                # the page to detect a next page; confirm with the caller.
                max_returned_rows += 1

        # Execute query (a where condition takes precedence over sorting)
        if query:
            table_rows = table.where(query, params, start, end)
        elif orderby:
            table_rows = table.itersorted(orderby, start=start, stop=end)
        else:
            table_rows = table.iterrows(start, end)

        # Prepare rows
        def normalize_field_value(value):
            # Exact type checks: anything that is not exactly bytes, int,
            # float or complex is converted with str().
            if type(value) is bytes:
                return value.decode('utf-8')
            elif not type(value) in (int, float, complex):
                return str(value)
            else:
                return value

        def make_get_rowid():
            if is_table:
                def get_rowid(row):
                    return int(row.nrow)
            else:
                # Other nodes carry no row number: count up from `start`
                rowid = start - 1
                def get_rowid(row):
                    nonlocal rowid
                    rowid += 1
                    return rowid
            return get_rowid

        def make_get_row_value():
            if is_table:
                def get_row_value(row, field):
                    return row[field]
            else:
                def get_row_value(row, field):
                    return row
            return get_row_value

        if len(fields) == 1 and isinstance(fields[0]['value'], dict) and \
           fields[0]['value'].get('count') == '*':
            rows.append(Row({'count(*)': int(table.nrows)}))
        else:
            get_rowid = make_get_rowid()
            get_row_value = make_get_row_value()
            count = 0
            for table_row in table_rows:
                count += 1
                # `is not None` so that an explicit `limit 0` yields no rows
                if limit is not None and count > limit:
                    break
                if truncate and max_returned_rows and count > max_returned_rows:
                    truncated = True
                    break
                row = Row()
                for field in fields:
                    field_name = field['value']
                    if isinstance(field_name, dict) and 'distinct' in field_name:
                        field_name = field_name['distinct']
                    if field_name == 'rowid':
                        row['rowid'] = get_rowid(table_row)
                    elif field_name == '*':
                        for col in colnames:
                            row[col] = normalize_field_value(get_row_value(table_row, col))
                    else:
                        row[field_name] = normalize_field_value(get_row_value(table_row, field_name))
                rows.append(row)

        # Prepare query description
        for field in [f['value'] for f in fields]:
            if field == '*':
                for col in colnames:
                    description.append((col,))
            else:
                description.append((field,))

        # Return the rows
        return rows, truncated, tuple(description)

    def _execute_datasette_query(self, sql, params):
        """Answer the one ``sqlite_master`` query this connector supports.

        For ``select sql from sqlite_master where name = :n and type=:t``
        this returns a single ``Row`` whose ``sql`` value is a synthetic
        ``CREATE TABLE`` statement for node ``params['n']``. It returns an
        empty list when ``params['t']`` is ``'view'`` or the statement
        cannot be built.

        Raises ``Exception`` for any other SQL text.
        """
        if sql == 'select sql from sqlite_master where name = :n and type=:t':
            if params['t'] == 'view':
                return []
            else:
                try:
                    table = self.h5file.get_node(params['n'])
                    colnames = ['value']
                    if isinstance(table, tables.table.Table):
                        colnames = table.colnames
                    row = Row()
                    row['sql'] = 'CREATE TABLE {} ({})'.format(params['n'], ", ".join(colnames))
                    return [row]
                except Exception:
                    # Best effort: a failed node lookup means "no such
                    # table", without swallowing KeyboardInterrupt/SystemExit.
                    return []
        else:
            raise Exception("SQLite queries cannot be executed with this connector: %s, %s" % (sql, params))
314
315
class Row(list):
    """Ordered row whose values can be read by position or by column label.

    Values are kept in ``self.values`` and their labels, in the same
    order, in ``self.labels``. The inherited list storage is left empty,
    so length and iteration are answered from ``self.values``.

    ``values``, when given, is a mapping of label to value that is
    inserted in iteration order.
    """

    def __init__(self, values=None):
        super().__init__()
        self.labels = []
        self.values = []
        if values:
            for idx in values:
                self[idx] = values[idx]

    def __setitem__(self, idx, value):
        """Set a value by label (appending new labels) or by position."""
        if isinstance(idx, str):
            if idx in self.labels:
                self.values[self.labels.index(idx)] = value
            else:
                self.labels.append(idx)
                self.values.append(value)
        else:
            self.values[idx] = value

    def __getitem__(self, idx):
        """Return a value by label or by position.

        An unknown label raises ``ValueError`` (from ``list.index``).
        """
        if isinstance(idx, str):
            return self.values[self.labels.index(idx)]
        else:
            return self.values[idx]

    def __iter__(self):
        return iter(self.values)

    def __len__(self):
        # Without this, len() and truthiness would consult the empty
        # inherited list and report a populated row as empty.
        return len(self.values)