]> git.jsancho.org Git - datasette-pytables.git/blobdiff - datasette_pytables/__init__.py
Fix number of rows returned when limit is present or pagination is needed
[datasette-pytables.git] / datasette_pytables / __init__.py
index d9530ae9a7edb969a64628c018d2b2c83d29c051..86f879c0a1b423019ccb88de9616e8ecfee0ce1f 100644 (file)
@@ -1,4 +1,3 @@
-from collections import OrderedDict
 from moz_sql_parser import parse
 import re
 import tables
@@ -71,7 +70,7 @@ class Connection:
         self.path = path
         self.h5file = tables.open_file(path)
 
-    def execute(self, sql, params=None, truncate=False, page_size=None):
+    def execute(self, sql, params=None, truncate=False, page_size=None, max_returned_rows=None):
         if params is None:
             params = {}
         rows = []
@@ -81,7 +80,9 @@ class Connection:
         parsed_sql = _parse_sql(sql, params)
 
         if parsed_sql['from'] == 'sqlite_master':
-            return self._execute_datasette_query(sql, params)
+            rows = self._execute_datasette_query(sql, params)
+            description = (('value',))
+            return rows, truncated, description
 
         table = self.h5file.get_node(parsed_sql['from'])
         table_rows = []
@@ -94,7 +95,10 @@ class Connection:
         # Use 'where' statement or get all the rows
         def _cast_param(field, pname):
             # Cast value to the column type
-            coltype = table.coltypes[field]
+            if type(table) is tables.table.Table:
+                coltype = table.coltypes[field]
+            else:
+                coltype = table.dtype.name
             fcast = None
             if coltype == 'string':
                 fcast = str
@@ -107,6 +111,7 @@ class Connection:
 
         def _translate_where(where):
             # Translate SQL to PyTables expression
+            nonlocal start, end
             expr = ''
             operator = list(where)[0]
 
@@ -118,9 +123,10 @@ class Connection:
             elif operator == 'exists':
                 pass
             elif where == {'eq': ['rowid', 'p0']}:
-                nonlocal start, end
                 start = int(params['p0'])
                 end = start + 1
+            elif where == {'gt': ['rowid', 'p0']}:
+                start = int(params['p0']) + 1
             else:
                 left, right = where[operator]
                 if left in params:
@@ -139,16 +145,14 @@ class Connection:
                 query = parsed_sql['where']
 
         # Limit number of rows
+        limit = None
         if 'limit' in parsed_sql:
-            max_rows = int(parsed_sql['limit'])
-            if end - start > max_rows:
-                end = start + max_rows
+            limit = int(parsed_sql['limit'])
 
         # Truncate if needed
-        if page_size and truncate:
-            if end - start > page_size:
-                end = start + page_size
-                truncated = True
+        if page_size and max_returned_rows and truncate:
+            if max_returned_rows == page_size:
+                max_returned_rows += 1
 
         # Execute query
         if query:
@@ -162,7 +166,14 @@ class Connection:
             rows.append(Row({'count(*)': int(table.nrows)}))
         else:
             if type(table) is tables.table.Table:
+                count = 0
                 for table_row in table_rows:
+                    count += 1
+                    if limit and count > limit:
+                        break
+                    if truncate and max_returned_rows and count > max_returned_rows:
+                        truncated = True
+                        break
                     row = Row()
                     for field in fields:
                         field_name = field['value']
@@ -182,7 +193,14 @@ class Connection:
             else:
                 # Any kind of array
                 rowid = start - 1
+                count = 0
                 for table_row in table_rows:
+                    count += 1
+                    if limit and count > limit:
+                        break
+                    if truncate and max_returned_rows and count > max_returned_rows:
+                        truncated = True
+                        break
                     row = Row()
                     rowid += 1
                     for field in fields:
@@ -210,10 +228,7 @@ class Connection:
                 description.append((field,))
 
         # Return the rows
-        if truncate:
-            return rows, truncated, tuple(description)
-        else:
-            return rows
+        return rows, truncated, tuple(description)
 
     def _execute_datasette_query(self, sql, params):
         "Datasette special queries for getting tables info"