]> git.jsancho.org Git - datasette-pytables.git/blobdiff - datasette_pytables/__init__.py
Support queries in PyTables style
[datasette-pytables.git] / datasette_pytables / __init__.py
index 217214f6f4d7ece9b9566e8018f85ee85c5ee47f..50628e7848ae704859bc41300b3f4c71f5531821 100644 (file)
@@ -19,7 +19,7 @@ def inspect(path):
             'name': table._v_pathname,
             'columns': colnames,
             'primary_keys': [],
-            'count': table.nrows,
+            'count': int(table.nrows),
             'label_column': None,
             'hidden': False,
             'fts_table': None,
@@ -99,19 +99,26 @@ class Connection:
 
         # Use 'where' statement or get all the rows
         if 'where' in parsed_sql:
-            query = []
+            query = ''
             start = 0
             end = table.nrows
-            for condition in parsed_sql['where'].get_sublists():
-                if str(condition) == '"rowid"=:p0':
-                    start = int(params['p0'])
-                    end = start + 1
-                else:
-                    translated, params = _translate_condition(table, condition, params)
-                    query.append(translated)
+            try:
+                conditions = []
+                for condition in parsed_sql['where'].get_sublists():
+                    if str(condition) == '"rowid"=:p0':
+                        start = int(params['p0'])
+                        end = start + 1
+                    else:
+                        translated, params = _translate_condition(table, condition, params)
+                        conditions.append(translated)
+                if conditions:
+                    query = ') & ('.join(conditions)
+                    query = '(' + query + ')'
+            except:
+                # Probably it's a PyTables query
+                query = str(parsed_sql['where'])[6:]    # without where keyword
+
             if query:
-                query = ') & ('.join(query)
-                query = '(' + query + ')'
                 table_rows = table.where(query, params, start, end)
             else:
                 table_rows = table.iterrows(start, end)
@@ -120,13 +127,13 @@ class Connection:
 
         # Prepare rows
         if len(fields) == 1 and fields[0] == 'count(*)':
-            rows.append(Row({fields[0]: table.nrows}))
+            rows.append(Row({fields[0]: int(table.nrows)}))
         else:
             for table_row in table_rows:
                 row = Row()
                 for field in fields:
                     if field == 'rowid':
-                        row[field] = table_row.nrow
+                        row[field] = int(table_row.nrow)
                     elif field == '*':
                         for col in table.colnames:
                             value = table_row[col]