-from moz_sql_parser import parse
-import re
-
import tables
import datasette_connectors as dc
+from .utils import parse_sql
class PyTablesConnection(dc.Connection):
connector_type = 'pytables'
connection_class = PyTablesConnection
+ operators = {
+ 'eq': '==',
+ 'neq': '!=',
+ 'gt': '>',
+ 'gte': '>=',
+ 'lt': '<',
+ 'lte': '<=',
+ 'and': '&',
+ 'or': '|',
+ }
+
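# --- Editorial aside, not part of the patch -----------------------------------
# The operators map above rewrites parsed SQL operators into the NumExpr syntax
# that PyTables in-kernel queries understand. A minimal sketch of what such
# condition strings are ultimately fed to, assuming a hypothetical file
# 'data.h5' with a Table at '/measurements' holding 'temperature' and
# 'pressure' columns:
import tables

h5file = tables.open_file('data.h5')
measurements = h5file.get_node('/measurements')
# SQL "WHERE temperature > 20 AND pressure <= 1.5" becomes:
condition = '(temperature > 20) & (pressure <= 1.5)'
hot = [row['temperature'] for row in measurements.where(condition)]
h5file.close()
# -------------------------------------------------------------------------------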
def table_names(self):
return [
node._v_pathname
def foreign_keys(self, table_name):
return []
-
-def inspect(path):
- "Open file and return tables info"
- h5tables = {}
- views = []
- h5file = tables.open_file(path)
-
- for table in filter(lambda node: not(isinstance(node, tables.group.Group)), h5file):
- colnames = ['value']
- if isinstance(table, tables.table.Table):
- colnames = table.colnames
-
- h5tables[table._v_pathname] = {
- 'name': table._v_pathname,
- 'columns': colnames,
- 'primary_keys': [],
- 'count': int(table.nrows),
- 'label_column': None,
- 'hidden': False,
- 'fts_table': None,
- 'foreign_keys': {'incoming': [], 'outgoing': []},
- }
-
- h5file.close()
- return h5tables, views, _connector_type
-
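# --- Editorial aside, not part of the patch -----------------------------------
# Why the synthetic 'value' column above (and further down): only Table nodes
# carry named columns; Array-like leaves just hold values. A minimal sketch of
# that distinction, assuming a hypothetical 'data.h5':
import tables

h5file = tables.open_file('data.h5')
for node in h5file:
    if isinstance(node, tables.group.Group):
        continue                                   # groups are containers, skip
    if isinstance(node, tables.table.Table):
        print(node._v_pathname, node.colnames)     # real column names
    else:
        print(node._v_pathname, ['value'])         # single synthetic column
h5file.close()
# -------------------------------------------------------------------------------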
-def _parse_sql(sql, params):
- # Table name
- sql = re.sub(r'(?i)from \[(.*)]', r'from "\g<1>"', sql)
- # Params
- for param in params:
- sql = sql.replace(":" + param, param)
-
- try:
- parsed = parse(sql)
- except:
-        # Probably it's a PyTables expression
- for token in ['group by', 'order by', 'limit', '']:
- res = re.search('(?i)where (.*)' + token, sql)
- if res:
- modified_sql = re.sub('(?i)where (.*)(' + token + ')', r'\g<2>', sql)
- parsed = parse(modified_sql)
- parsed['where'] = res.group(1).strip()
- break
-
- # Always a list of fields
- if type(parsed['select']) is not list:
- parsed['select'] = [parsed['select']]
-
- return parsed
-
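# --- Editorial aside, not part of the patch -----------------------------------
# For reference, the removed helper relied on moz_sql_parser, which returns
# plain dicts; the parse_sql helper now imported from .utils is assumed to
# produce the same shape. Roughly:
from moz_sql_parser import parse

parsed = parse('select a, b from "/group1/table1" where a > 1 and b = 2')
# parsed is approximately:
# {'select': [{'value': 'a'}, {'value': 'b'}],
#  'from': '/group1/table1',
#  'where': {'and': [{'gt': ['a', 1]}, {'eq': ['b', 2]}]}}
# A single selected column comes back as a bare dict rather than a list, which
# is why the result is normalized to a list before use.
# -------------------------------------------------------------------------------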
-_operators = {
- 'eq': '==',
- 'neq': '!=',
- 'gt': '>',
- 'gte': '>=',
- 'lt': '<',
- 'lte': '<=',
- 'and': '&',
- 'or': '|',
-}
-
-class Connection:
- def __init__(self, path):
- self.path = path
- self.h5file = tables.open_file(path)
-
- def execute(self, sql, params=None, truncate=False, page_size=None, max_returned_rows=None):
- if params is None:
- params = {}
- rows = []
+ def execute(
+ self,
+ sql,
+ params=None,
+ truncate=False,
+ custom_time_limit=None,
+ page_size=None,
+ log_sql_errors=True,
+ ):
+ results = []
truncated = False
-        description = []
+        description = ()
-
-        parsed_sql = _parse_sql(sql, params)
-        if parsed_sql['from'] == 'sqlite_master':
-            rows = self._execute_datasette_query(sql, params)
-            description = (('value',),)
-            return rows, truncated, description
-        table = self.h5file.get_node(parsed_sql['from'])
+        parsed_sql = parse_sql(sql, params)
+        table = self.conn.h5file.get_node(parsed_sql['from'])
table_rows = []
fields = parsed_sql['select']
colnames = ['value']
subexpr = [_translate_where(e) for e in where[operator]]
subexpr = filter(lambda e: e, subexpr)
subexpr = ["({})".format(e) for e in subexpr]
- expr = " {} ".format(_operators[operator]).join(subexpr)
+ expr = " {} ".format(self.operators[operator]).join(subexpr)
elif operator == 'exists':
pass
elif where == {'eq': ['rowid', 'p0']}:
elif right in params:
_cast_param(left, right)
- expr = "{left} {operator} {right}".format(left=left, operator=_operators.get(operator, operator), right=right)
+ expr = "{left} {operator} {right}".format(left=left, operator=self.operators.get(operator, operator), right=right)
return expr
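# --- Editorial aside, not part of the patch -----------------------------------
# A simplified, standalone sketch of the translation performed around the lines
# above: the parsed WHERE dict is walked recursively and stitched into a NumExpr
# condition string via the operators map. The names OPERATORS/translate_where
# and the sample input are illustrative; parameter substitution, 'exists' and
# the rowid special case are left out.
OPERATORS = {
    'eq': '==', 'neq': '!=', 'gt': '>', 'gte': '>=',
    'lt': '<', 'lte': '<=', 'and': '&', 'or': '|',
}

def translate_where(where):
    operator = list(where)[0]
    operands = where[operator]
    if operator in ('and', 'or'):
        parts = ['({})'.format(translate_where(sub)) for sub in operands]
        return ' {} '.format(OPERATORS[operator]).join(parts)
    left, right = operands
    return '{} {} {}'.format(left, OPERATORS.get(operator, operator), right)

assert translate_where({'and': [{'gt': ['a', 1]}, {'eq': ['b', 2]}]}) == '(a > 1) & (b == 2)'
# -------------------------------------------------------------------------------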
if len(fields) == 1 and type(fields[0]['value']) is dict and \
fields[0]['value'].get('count') == '*':
- rows.append(Row({'count(*)': int(table.nrows)}))
+ results.append({'count(*)': int(table.nrows)})
else:
get_rowid = make_get_rowid()
get_row_value = make_get_row_value()
if truncate and max_returned_rows and count > max_returned_rows:
truncated = True
break
- row = Row()
+ row = {}
for field in fields:
field_name = field['value']
if type(field_name) is dict and 'distinct' in field_name:
row[col] = normalize_field_value(get_row_value(table_row, col))
else:
row[field_name] = normalize_field_value(get_row_value(table_row, field_name))
- rows.append(row)
+ results.append(row)
# Prepare query description
for field in [f['value'] for f in fields]:
if field == '*':
for col in colnames:
- description.append((col,))
+ description += ((col,),)
else:
- description.append((field,))
-
- # Return the rows
- return rows, truncated, tuple(description)
-
- def _execute_datasette_query(self, sql, params):
- "Datasette special queries for getting tables info"
- if sql == 'select sql from sqlite_master where name = :n and type=:t':
- if params['t'] == 'view':
- return []
- else:
- try:
- table = self.h5file.get_node(params['n'])
- colnames = ['value']
- if type(table) is tables.table.Table:
- colnames = table.colnames
- row = Row()
- row['sql'] = 'CREATE TABLE {} ({})'.format(params['n'], ", ".join(colnames))
- return [row]
- except:
- return []
- else:
- raise Exception("SQLite queries cannot be executed with this connector: %s, %s" % (sql, params))
-
-
-class Row(list):
- def __init__(self, values=None):
- self.labels = []
- self.values = []
- if values:
- for idx in values:
- self.__setitem__(idx, values[idx])
-
- def __setitem__(self, idx, value):
- if type(idx) is str:
- if idx in self.labels:
- self.values[self.labels.index(idx)] = value
- else:
- self.labels.append(idx)
- self.values.append(value)
- else:
- self.values[idx] = value
-
- def __getitem__(self, idx):
- if type(idx) is str:
- return self.values[self.labels.index(idx)]
- else:
- return self.values[idx]
-    def __iter__(self):
-        return self.values.__iter__()
+            description += ((field,),)
+        return results, truncated, description
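# --- Editorial aside, not part of the patch -----------------------------------
# The refactored execute() now returns plain data structures: a list of dicts,
# a truncation flag, and a cursor.description-style tuple of 1-tuples. Assuming
# `connection` is an instance of the class that owns this execute() (not fully
# shown in the diff), usage would look roughly like:
#
#   results, truncated, description = connection.execute(
#       'select temperature from "/measurements" where pressure > 1'
#   )
#   column_names = [col[0] for col in description]   # e.g. ['temperature']
# -------------------------------------------------------------------------------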