1 from moz_sql_parser import parse
5 import datasette_connectors as dc
class PyTablesConnection(dc.Connection):
    """Connection that keeps an open PyTables file handle.

    The handle is stored on the instance as ``h5file``.
    """

    def __init__(self, path, connector):
        """Initialise the base connection and open ``path`` with PyTables.

        Args:
            path: Location of the file, forwarded both to the base class
                and to ``tables.open_file``.
            connector: Connector object forwarded to the base class.
        """
        super().__init__(path, connector)
        # NOTE(review): no matching close() is visible in this class —
        # confirm the handle is released elsewhere.
        self.h5file = tables.open_file(path)
# dc.Connector subclass: declares the 'pytables' connector type and names
# PyTablesConnection as its connection class.
14 class PyTablesConnector(dc.Connector):
15 connector_type = 'pytables'
16 connection_class = PyTablesConnection
# Selects the non-group nodes of the open file.
# NOTE(review): the start of the returned expression (what is produced for
# each node) is missing here — confirm against the complete file.
18 def table_names(self):
# Iterates the nodes of the PyTables file held by the connection...
21 for node in self.conn.h5file
# ...and skips every node that is a Group instance.
22 if not(isinstance(node, tables.group.Group))
def table_count(self, table_name):
    """Return the number of rows of the node stored at ``table_name``.

    Args:
        table_name: Identifier of the node, passed unchanged to
            ``h5file.get_node``.

    Returns:
        The node's ``nrows`` attribute converted to a built-in ``int``.
    """
    node = self.conn.h5file.get_node(table_name)
    # Coerced explicitly so callers always get a plain int, whatever
    # numeric type the node reports.
    return int(node.nrows)
# Builds per-column information for the node stored at table_name.
# NOTE(review): several lines of this method are missing here (including
# the shape of each produced entry and the return) — confirm against the
# complete file.
29 def table_info(self, table_name):
30 table = self.conn.h5file.get_node(table_name)
# Column names are read from the node only when it is a Table; how other
# node kinds obtain colnames is not visible here.
32 if isinstance(table, tables.table.Table):
33 colnames = table.colnames
# One item is produced per column, paired with its position in colnames.
41 for idx, colname in enumerate(colnames)
# NOTE(review): only the signatures of the four methods below are present;
# their bodies are missing, so their behaviour cannot be documented here.
# These two take no argument besides self.
44 def hidden_table_names(self):
47 def detect_spatialite(self):
# These two receive a table_name argument.
53 def detect_fts(self, table_name):
56 def foreign_keys(self, table_name):
61 "Open file and return tables info"
# NOTE(review): the def line and the initialisation of h5tables, views and
# _connector_type are not visible here. No close() of this handle is visible
# either — confirm it is released.
64 h5file = tables.open_file(path)
# Visit every node of the file that is not a Group.
66 for table in filter(lambda node: not(isinstance(node, tables.group.Group)), h5file):
# Column names are read from the node only when it is a Table.
68 if isinstance(table, tables.table.Table):
69 colnames = table.colnames
# Each entry is keyed by, and named after, the node's _v_pathname.
71 h5tables[table._v_pathname] = {
72 'name': table._v_pathname,
# Row count stored as a built-in int.
75 'count': int(table.nrows),
# No foreign keys are reported in either direction.
79 'foreign_keys': {'incoming': [], 'outgoing': []},
# Result is a 3-tuple: the per-table dict, views and the connector type.
83 return h5tables, views, _connector_type
# Rewrites the incoming SQL text and parses it with moz_sql_parser's parse().
# NOTE(review): lines are missing here (the loop over params, the try/except
# around parsing and the return) — confirm against the complete file.
85 def _parse_sql(sql, params):
# Case-insensitively turn `from [name]` into `from "name"`.
87 sql = re.sub(r'(?i)from \[(.*)]', r'from "\g<1>"', sql)
# Drop the leading ':' of a named placeholder, leaving the bare name.
90 sql = sql.replace(":" + param, param)
95 # Probably it's a PyTables expression
# Try each keyword that may follow the WHERE text; the final empty token
# makes the pattern capture up to the end of the statement.
96 for token in ['group by', 'order by', 'limit', '']:
97 res = re.search('(?i)where (.*)' + token, sql)
# Remove the WHERE text, keeping only the trailing token, so the remaining
# statement can be handed to parse().
99 modified_sql = re.sub('(?i)where (.*)(' + token + ')', r'\g<2>', sql)
100 parsed = parse(modified_sql)
# The captured WHERE text is stored as a stripped string, not a parsed dict.
101 parsed['where'] = res.group(1).strip()
104 # Always a list of fields
105 if type(parsed['select']) is not list:
106 parsed['select'] = [parsed['select']]
# Opens the file at path with PyTables and keeps the handle as self.h5file.
# NOTE(review): the enclosing class header and one line of this method are
# not visible here.
122 def __init__(self, path):
124 self.h5file = tables.open_file(path)
# Runs a SQL statement against the open PyTables file and returns a 3-tuple
# (rows, truncated, description).
# NOTE(review): many lines of this method are missing here (initialisation
# of rows, truncated, start, end, limit, count, query, description and
# several branch bodies) — the comments below cover only what is visible.
126 def execute(self, sql, params=None, truncate=False, page_size=None, max_returned_rows=None):
133 parsed_sql = _parse_sql(sql, params)
# Statements targeting sqlite_master are answered by
# _execute_datasette_query and described as a single 'value' column.
135 if parsed_sql['from'] == 'sqlite_master':
136 rows = self._execute_datasette_query(sql, params)
137 description = (('value',),)
138 return rows, truncated, description
# The FROM target is used as the node identifier inside the file.
140 table = self.h5file.get_node(parsed_sql['from'])
142 fields = parsed_sql['select']
# Column names are read from the node only when it is exactly a Table.
144 if type(table) is tables.table.Table:
145 colnames = table.colnames
151 # Use 'where' statement or get all the rows
# Casts params[pname] in place according to the type of column `field`.
152 def _cast_param(field, pname):
153 # Cast value to the column type
# Default is the node-wide dtype name; a Table overrides it with the
# per-column type.
154 coltype = table.dtype.name
155 if type(table) is tables.table.Table:
156 coltype = table.coltypes[field]
# The branch bodies (which choose fcast) are not visible here.
158 if coltype == 'string':
160 elif coltype.startswith('int'):
162 elif coltype.startswith('float'):
165 params[pname] = fcast(params[pname])
# Recursively turns a parsed WHERE dict into an expression string.
167 def _translate_where(where):
168 # Translate SQL to PyTables expression
# The first key of the dict is taken as the operator.
171 operator = list(where)[0]
# and/or: translate every operand, drop empty results, parenthesise each
# one and join them with the mapped operator.
173 if operator in ['and', 'or']:
174 subexpr = [_translate_where(e) for e in where[operator]]
175 subexpr = filter(lambda e: e, subexpr)
176 subexpr = ["({})".format(e) for e in subexpr]
177 expr = " {} ".format(_operators[operator]).join(subexpr)
178 elif operator == 'exists':
# Conditions on rowid against parameter p0 set the start position instead:
# equality starts at that row, greater-than at the row after it.
180 elif where == {'eq': ['rowid', 'p0']}:
181 start = int(params['p0'])
183 elif where == {'gt': ['rowid', 'p0']}:
184 start = int(params['p0']) + 1
# Binary operator: a side that names a parameter is cast to the type of
# the column named on the other side.
186 left, right = where[operator]
188 _cast_param(right, left)
189 elif right in params:
190 _cast_param(left, right)
# Operators missing from _operators are emitted unchanged.
192 expr = "{left} {operator} {right}".format(left=left, operator=_operators.get(operator, operator), right=right)
# A dict WHERE is translated; anything else is used as the query as-is.
196 if 'where' in parsed_sql:
197 if type(parsed_sql['where']) is dict:
198 query = _translate_where(parsed_sql['where'])
200 query = parsed_sql['where']
# Ordering: the handling of list values and of 'rowid' is in lines not
# visible here.
204 if 'orderby' in parsed_sql:
205 orderby = parsed_sql['orderby']
206 if type(orderby) is list:
208 orderby = orderby['value']
209 if orderby == 'rowid':
212 # Limit number of rows
214 if 'limit' in parsed_sql:
215 limit = int(parsed_sql['limit'])
# When truncating with max_returned_rows equal to page_size, one extra row
# is allowed — presumably so truncation can be detected; TODO confirm.
218 if page_size and max_returned_rows and truncate:
219 if max_returned_rows == page_size:
220 max_returned_rows += 1
# Three row sources; the conditions choosing between them are not visible:
# a filtered query, a sorted iteration, or a plain range iteration.
224 table_rows = table.where(query, params, start, end)
226 table_rows = table.itersorted(orderby, start=start, stop=end)
228 table_rows = table.iterrows(start, end)
# Converts a raw cell value for output; bytes are decoded as UTF-8. The
# handling of other non-numeric values is not visible here.
231 def normalize_field_value(value):
232 if type(value) is bytes:
233 return value.decode('utf-8')
234 elif not type(value) in (int, float, complex):
# Factories whose result depends on whether the node is exactly a Table.
239 def make_get_rowid():
240 if type(table) is tables.table.Table:
251 def make_get_row_value():
252 if type(table) is tables.table.Table:
253 def get_row_value(row, field):
256 def get_row_value(row, field):
# A lone count(*) selection is answered with one row holding the node's
# nrows.
260 if len(fields) == 1 and type(fields[0]['value']) is dict and \
261 fields[0]['value'].get('count') == '*':
262 rows.append(Row({'count(*)': int(table.nrows)}))
264 get_rowid = make_get_rowid()
265 get_row_value = make_get_row_value()
267 for table_row in table_rows:
# Checks against limit and max_returned_rows; their bodies are not visible
# here.
269 if limit and count > limit:
271 if truncate and max_returned_rows and count > max_returned_rows:
# Resolve each selected field; a {'distinct': name} dict is unwrapped to
# the field name.
276 field_name = field['value']
277 if type(field_name) is dict and 'distinct' in field_name:
278 field_name = field_name['distinct']
279 if field_name == 'rowid':
280 row['rowid'] = get_rowid(table_row)
# '*' fills one entry per column; any other name fills a single entry.
281 elif field_name == '*':
283 row[col] = normalize_field_value(get_row_value(table_row, col))
285 row[field_name] = normalize_field_value(get_row_value(table_row, field_name))
288 # Prepare query description
# Each description entry is a 1-tuple holding a column or field name.
289 for field in [f['value'] for f in fields]:
292 description.append((col,))
294 description.append((field,))
297 return rows, truncated, tuple(description)
# Answers the sqlite_master lookup from the PyTables file.
# NOTE(review): lines are missing here (the view branch body, creation of
# row, colnames for non-Table nodes, the return) — confirm against the
# complete file.
299 def _execute_datasette_query(self, sql, params):
300 "Datasette special queries for getting tables info"
# Only this exact statement text is recognised.
301 if sql == 'select sql from sqlite_master where name = :n and type=:t':
302 if params['t'] == 'view':
# params['n'] is used as the node identifier inside the file.
306 table = self.h5file.get_node(params['n'])
308 if type(table) is tables.table.Table:
309 colnames = table.colnames
# A CREATE TABLE statement is synthesised from the name and column names.
311 row['sql'] = 'CREATE TABLE {} ({})'.format(params['n'], ", ".join(colnames))
# Presumably reached for any statement not handled above — TODO confirm the
# enclosing branch.
316 raise Exception("SQLite queries cannot be executed with this connector: %s, %s" % (sql, params))
# Optionally populates the row from `values`.
# NOTE(review): the class header, the attribute initialisation and the loop
# over values are not visible here.
320 def __init__(self, values=None):
# Each entry is stored through __setitem__: idx as key, values[idx] as value.
325 self.__setitem__(idx, values[idx])
# Stores `value` under a label or at a position.
# NOTE(review): the conditions separating the branches below are partly
# missing here.
327 def __setitem__(self, idx, value):
# Known label: overwrite the value held at that label's position.
329 if idx in self.labels:
330 self.values[self.labels.index(idx)] = value
# New label: label and value are appended together, so the two lists stay
# the same length.
332 self.labels.append(idx)
333 self.values.append(value)
# Otherwise idx is used directly as an index into values.
335 self.values[idx] = value
# Reads a value by label or by position.
# NOTE(review): the condition choosing between the two returns is missing
# here.
337 def __getitem__(self, idx):
# By label: look up the label's position, then return the value there.
339 return self.values[self.labels.index(idx)]
# By position: idx indexes values directly.
341 return self.values[idx]
# NOTE(review): the def line owning the next return is not visible here; it
# returns an iterator over the stored values.
344 return self.values.__iter__()