X-Git-Url: https://git.jsancho.org/?p=datasette-pytables.git;a=blobdiff_plain;f=datasette_pytables%2F__init__.py;h=a1f11d91e3fb2e8b142068703fd4a8181e49ac33;hp=e8d0a6ccb4519f72d2e0680e0c3b4b8f44dc2574;hb=24c552a5f4d88eaf4eb30a8a9fb06b0d7ec75bc8;hpb=2cf7538022a4290be1cdbe6f8cf13e44f8190bee diff --git a/datasette_pytables/__init__.py b/datasette_pytables/__init__.py index e8d0a6c..a1f11d9 100644 --- a/datasette_pytables/__init__.py +++ b/datasette_pytables/__init__.py @@ -37,7 +37,19 @@ def _parse_sql(sql, params): for param in params: sql = sql.replace(":" + param, param) - parsed = parse(sql) + try: + parsed = parse(sql) + except: + # Propably it's a PyTables expression + for token in ['group by', 'order by', 'limit']: + res = re.search('(?i)where (.*)' + token, sql) + if res: + modified_sql = re.sub('(?i)where (.*)(' + token + ')', '\g<2>', sql) + parsed = parse(modified_sql) + parsed['where'] = res.group(1) + break + + # Always a list of fields if type(parsed['select']) is not list: parsed['select'] = [parsed['select']] @@ -95,8 +107,12 @@ class Connection: operator = list(where)[0] if operator in ['and', 'or']: - subexpr = ["({})".format(_translate_where(q)) for q in where[operator]] + subexpr = [_translate_where(e) for e in where[operator]] + subexpr = filter(lambda e: e, subexpr) + subexpr = ["({})".format(e) for e in subexpr] expr = " {} ".format(_operators[operator]).join(subexpr) + elif operator == 'exists': + pass elif where == {'eq': ['rowid', 'p0']}: nonlocal start, end start = int(params['p0']) @@ -113,11 +129,10 @@ class Connection: return expr if 'where' in parsed_sql: - try: + if type(parsed_sql['where']) is dict: query = _translate_where(parsed_sql['where']) - except: - # Probably it's a PyTables query - query = str(parsed_sql['where'])[6:] # without where keyword + else: + query = parsed_sql['where'] # Limit number of rows if 'limit' in parsed_sql: @@ -139,16 +154,19 @@ class Connection: for table_row in table_rows: row = Row() for field in fields: - if field['value'] == 'rowid': + field_name = field['value'] + if type(field_name) is dict and 'distinct' in field_name: + field_name = field_name['distinct'] + if field_name == 'rowid': row['rowid'] = int(table_row.nrow) - elif field['value'] == '*': + elif field_name == '*': for col in table.colnames: value = table_row[col] if type(value) is bytes: value = value.decode('utf-8') row[col] = value else: - row[field['value']] = table_row[field['value']] + row[field_name] = table_row[field_name] rows.append(row) # Prepare query description