custom pytables sql
[datasette-pytables.git] / datasette_pytables / __init__.py
1 import tables
2 import datasette_connectors as dc
3 from .utils import parse_sql
4
5
class PyTablesConnection(dc.Connection):
    """Connection that holds an open PyTables handle for one file."""

    def __init__(self, path, connector):
        """Initialise the base connection, then open ``path`` with PyTables.

        Args:
            path: Location of the file, forwarded to ``tables.open_file``.
            connector: Connector object, forwarded to the base class.
        """
        super().__init__(path, connector)
        # Opened with PyTables' default arguments; the handle is stored on
        # the connection and is not closed anywhere in this class.
        handle = tables.open_file(path)
        self.h5file = handle
10
11
class PyTablesConnector(dc.Connector):
    """Connector that exposes the nodes of a PyTables file as tables.

    Every non-group node of the opened file is reported as a table.
    ``tables.table.Table`` nodes expose their own columns; any other node is
    presented as a single column named ``value``.
    """

    connector_type = 'pytables'
    connection_class = PyTablesConnection

    # Maps operator names found in the parsed SQL to the symbols used when
    # building the condition string handed to ``table.where``.
    operators = {
        'eq': '==',
        'neq': '!=',
        'gt': '>',
        'gte': '>=',
        'lt': '<',
        'lte': '<=',
        'and': '&',
        'or': '|',
        'binary_and': '&',
        'binary_or': '|',
    }

    def table_names(self):
        """Return the path name of every node in the file that is not a group."""
        return [
            node._v_pathname
            for node in self.conn.h5file
            if not isinstance(node, tables.group.Group)
        ]

    def table_count(self, table_name):
        """Return the number of rows of the node ``table_name`` as an ``int``."""
        table = self.conn.h5file.get_node(table_name)
        return int(table.nrows)

    def table_info(self, table_name):
        """Describe the columns of the node ``table_name``.

        Returns:
            A list with one dict per column holding ``idx`` (position),
            ``name`` and ``primary_key`` (always ``False``). Nodes that are
            not ``tables.table.Table`` instances get a single ``value``
            column.
        """
        table = self.conn.h5file.get_node(table_name)
        colnames = ['value']
        if isinstance(table, tables.table.Table):
            colnames = table.colnames

        return [
            {
                'idx': idx,
                'name': colname,
                'primary_key': False,
            }
            for idx, colname in enumerate(colnames)
        ]

    def hidden_table_names(self):
        """Return an empty list: no table is reported as hidden."""
        return []

    def detect_spatialite(self):
        """Return ``False``: SpatiaLite is never reported."""
        return False

    def view_names(self):
        """Return an empty list: no views are reported."""
        return []

    def detect_fts(self, table_name):
        """Return ``False``: full-text search is never reported."""
        return False

    def foreign_keys(self, table_name):
        """Return an empty list: no foreign keys are reported."""
        return []

    def execute(
        self,
        sql,
        params=None,
        truncate=False,
        custom_time_limit=None,
        page_size=None,
        log_sql_errors=True,
        max_returned_rows=None,
    ):
        """Run a SQL-like query against one node of the file.

        The statement is parsed with ``parse_sql``; its ``where`` clause is
        translated into a condition string for ``table.where`` and the
        selected fields are read from the matching rows.

        Args:
            sql: Query text.
            params: Mapping of parameter names to values. A dict is copied
                before use, so the caller's object is never modified;
                ``None`` is treated as an empty mapping.
            truncate: When true and ``max_returned_rows`` is set, stop after
                that many rows and report the result as truncated.
            custom_time_limit: Accepted for interface compatibility; not
                used by this method.
            page_size: Page size used by the caller. Only consulted to bump
                ``max_returned_rows`` by one when both are equal.
            log_sql_errors: Accepted for interface compatibility; not used
                by this method.
            max_returned_rows: Upper bound on returned rows when
                ``truncate`` is true. ``None`` (default) disables
                truncation.

        Returns:
            A ``(results, truncated, description)`` tuple: ``results`` is a
            list of dicts keyed by column name, ``truncated`` tells whether
            rows were cut off, and ``description`` is a tuple of 1-tuples
            built from the selected fields.
        """
        results = []
        truncated = False
        description = ()

        # Parameter values are cast in place further down, so work on a
        # private copy; ``None`` becomes an empty mapping so that the
        # membership tests below are always valid.
        if params is None:
            params = {}
        elif isinstance(params, dict):
            params = dict(params)

        parsed_sql = parse_sql(sql, params)

        table = self.conn.h5file.get_node(parsed_sql['from'])
        is_table = isinstance(table, tables.table.Table)
        fields = parsed_sql['select']
        colnames = table.colnames if is_table else ['value']

        query = ''
        start = 0
        end = table.nrows

        def _cast_param(field, pname):
            # Cast the parameter value to the type of the column it is
            # compared against.
            coltype = table.coltypes[field] if is_table else table.dtype.name
            fcast = None
            if coltype == 'string':
                fcast = str
            elif coltype.startswith('int'):
                fcast = int
            elif coltype.startswith('float'):
                fcast = float
            if fcast:
                params[pname] = fcast(params[pname])

        def _translate_where(where):
            # Translate the parsed SQL condition into a PyTables expression.
            # Conditions on ``rowid`` are not part of the expression: they
            # narrow the [start, end) range instead.
            nonlocal start, end
            expr = ''
            operator = list(where)[0]

            if operator in ['and', 'or']:
                translated = [_translate_where(e) for e in where[operator]]
                # Sub-conditions that only adjusted the row range come back
                # empty and are left out of the combined expression.
                subexpr = ["({})".format(e) for e in translated if e]
                expr = " {} ".format(self.operators[operator]).join(subexpr)
            elif operator == 'exists':
                pass
            elif where == {'eq': ['rowid', 'p0']}:
                start = int(params['p0'])
                end = start + 1
            elif where == {'gt': ['rowid', 'p0']}:
                start = int(params['p0']) + 1
            else:
                left, right = where[operator]

                if isinstance(left, dict):
                    left = "(" + _translate_where(left) + ")"
                elif left in params:
                    _cast_param(right, left)

                if isinstance(right, dict):
                    right = "(" + _translate_where(right) + ")"
                elif right in params:
                    _cast_param(left, right)

                expr = "{left} {operator} {right}".format(
                    left=left,
                    operator=self.operators.get(operator, operator),
                    right=right,
                )

            return expr

        # Use the 'where' clause, or get all the rows
        if 'where' in parsed_sql:
            where_clause = parsed_sql['where']
            if isinstance(where_clause, dict):
                query = _translate_where(where_clause)
            else:
                # Anything that is not a dict is used as the condition as-is.
                query = where_clause

        # Sort by column
        orderby = ''
        if 'orderby' in parsed_sql:
            orderby = parsed_sql['orderby']
            if isinstance(orderby, list):
                # Only the first sort key is honoured.
                orderby = orderby[0]
            orderby = orderby['value']
            if orderby == 'rowid':
                orderby = ''

        # Limit number of rows
        limit = None
        if 'limit' in parsed_sql:
            limit = int(parsed_sql['limit'])

        # Truncate if needed: allow one extra row when the cap equals the
        # page size.
        if page_size and max_returned_rows and truncate \
                and max_returned_rows == page_size:
            max_returned_rows += 1

        # Execute query
        # NOTE(review): sorting is only applied when there is no condition.
        if query:
            table_rows = table.where(query, params, start, end)
        elif orderby:
            table_rows = table.itersorted(orderby, start=start, stop=end)
        else:
            table_rows = table.iterrows(start, end)

        # Prepare rows
        def normalize_field_value(value):
            # Exact type comparison is kept so that values of any other type
            # (including subclasses of the numeric types) are rendered as
            # strings, matching the existing output.
            if type(value) is bytes:
                return value.decode('utf-8')
            if type(value) not in (int, float, complex):
                return str(value)
            return value

        def get_row_value(row, field):
            # Rows of a Table are indexed by column name; for any other node
            # the iterated item itself is the value.
            return row[field] if is_table else row

        if len(fields) == 1 and isinstance(fields[0]['value'], dict) and \
           fields[0]['value'].get('count') == '*':
            # NOTE(review): this is the total row count of the node; the
            # condition and row range computed above are not applied.
            results.append({'count(*)': int(table.nrows)})
        else:
            for count, table_row in enumerate(table_rows, start=1):
                # ``is not None`` so that a limit of 0 returns no rows.
                if limit is not None and count > limit:
                    break
                if truncate and max_returned_rows and count > max_returned_rows:
                    truncated = True
                    break
                row = {}
                for field in fields:
                    field_name = field['value']
                    if isinstance(field_name, dict) and 'distinct' in field_name:
                        field_name = field_name['distinct']
                    if field_name == 'rowid':
                        if is_table:
                            row['rowid'] = int(table_row.nrow)
                        else:
                            # Other nodes carry no row number; derive it
                            # from the position within the iterated range.
                            row['rowid'] = start + count - 1
                    elif field_name == '*':
                        for col in colnames:
                            row[col] = normalize_field_value(
                                get_row_value(table_row, col))
                    else:
                        row[field_name] = normalize_field_value(
                            get_row_value(table_row, field_name))
                results.append(row)

        # Prepare query description
        # NOTE(review): dict-valued fields (count/distinct) are added to the
        # description as dicts, unlike the keys used in ``results``.
        for field in [f['value'] for f in fields]:
            if field == '*':
                for col in colnames:
                    description += ((col,),)
            else:
                description += ((field,),)

        return results, truncated, description
251         return results, truncated, description