]> git.jsancho.org Git - datasette-pytables.git/commitdiff
Fix number of rows returned when limit is present or pagination is needed
authorJavier Sancho <jsf@jsancho.org>
Tue, 29 May 2018 11:03:09 +0000 (13:03 +0200)
committerJavier Sancho <jsf@jsancho.org>
Tue, 29 May 2018 11:03:09 +0000 (13:03 +0200)
datasette_pytables/__init__.py
tests/test_api.py

index b3d5b9a7d0ab3ec31eacf6bfe410169130fe094a..86f879c0a1b423019ccb88de9616e8ecfee0ce1f 100644 (file)
@@ -145,18 +145,14 @@ class Connection:
                 query = parsed_sql['where']
 
         # Limit number of rows
                 query = parsed_sql['where']
 
         # Limit number of rows
+        limit = None
         if 'limit' in parsed_sql:
         if 'limit' in parsed_sql:
-            max_rows = int(parsed_sql['limit'])
-            if end - start > max_rows:
-                end = start + max_rows
+            limit = int(parsed_sql['limit'])
 
         # Truncate if needed
         if page_size and max_returned_rows and truncate:
             if max_returned_rows == page_size:
                 max_returned_rows += 1
 
         # Truncate if needed
         if page_size and max_returned_rows and truncate:
             if max_returned_rows == page_size:
                 max_returned_rows += 1
-            if end - start > max_returned_rows:
-                end = start + max_returned_rows
-                truncated = True
 
         # Execute query
         if query:
 
         # Execute query
         if query:
@@ -170,7 +166,14 @@ class Connection:
             rows.append(Row({'count(*)': int(table.nrows)}))
         else:
             if type(table) is tables.table.Table:
             rows.append(Row({'count(*)': int(table.nrows)}))
         else:
             if type(table) is tables.table.Table:
+                count = 0
                 for table_row in table_rows:
                 for table_row in table_rows:
+                    count += 1
+                    if limit and count > limit:
+                        break
+                    if truncate and max_returned_rows and count > max_returned_rows:
+                        truncated = True
+                        break
                     row = Row()
                     for field in fields:
                         field_name = field['value']
                     row = Row()
                     for field in fields:
                         field_name = field['value']
@@ -190,7 +193,14 @@ class Connection:
             else:
                 # Any kind of array
                 rowid = start - 1
             else:
                 # Any kind of array
                 rowid = start - 1
+                count = 0
                 for table_row in table_rows:
                 for table_row in table_rows:
+                    count += 1
+                    if limit and count > limit:
+                        break
+                    if truncate and max_returned_rows and count > max_returned_rows:
+                        truncated = True
+                        break
                     row = Row()
                     rowid += 1
                     for field in fields:
                     row = Row()
                     rowid += 1
                     for field in fields:
index 39c41228019dfbebb53c84803dc7ffa05d29ebb9..97e07dc560cb4202f927c63b33b8990753c2ca0e 100644 (file)
@@ -55,6 +55,30 @@ def test_database_page(app_client):
     }] == data['tables']
 
 def test_custom_sql(app_client):
     }] == data['tables']
 
 def test_custom_sql(app_client):
+    response = app_client.get(
+        '/test_tables.json?' + urlencode({
+            'sql': 'select identity from [/group1/table1]',
+            '_shape': 'objects'
+        }),
+        gather_request=False
+    )
+    data = response.json
+    assert {
+        'sql': 'select identity from [/group1/table1]',
+        'params': {}
+    } == data['query']
+    assert 1000 == len(data['rows'])
+    assert [
+        {'identity': 'This is particle:  0'},
+        {'identity': 'This is particle:  1'},
+        {'identity': 'This is particle:  2'},
+        {'identity': 'This is particle:  3'}
+    ] == data['rows'][:4]
+    assert ['identity'] == data['columns']
+    assert 'test_tables' == data['database']
+    assert data['truncated']
+
+def test_custom_complex_sql(app_client):
     response = app_client.get(
         '/test_tables.json?' + urlencode({
             'sql': 'select identity from [/group1/table1] where speed > 100 and idnumber < 55',
     response = app_client.get(
         '/test_tables.json?' + urlencode({
             'sql': 'select identity from [/group1/table1] where speed > 100 and idnumber < 55',
@@ -63,7 +87,6 @@ def test_custom_sql(app_client):
         gather_request=False
     )
     data = response.json
         gather_request=False
     )
     data = response.json
-    print("*************************", data)
     assert {
         'sql': 'select identity from [/group1/table1] where speed > 100 and idnumber < 55',
         'params': {}
     assert {
         'sql': 'select identity from [/group1/table1] where speed > 100 and idnumber < 55',
         'params': {}
@@ -77,7 +100,7 @@ def test_custom_sql(app_client):
     ] == data['rows']
     assert ['identity'] == data['columns']
     assert 'test_tables' == data['database']
     ] == data['rows']
     assert ['identity'] == data['columns']
     assert 'test_tables' == data['database']
-    assert False == data['truncated']
+    assert not data['truncated']
 
 def test_custom_pytables_sql(app_client):
     response = app_client.get(
 
 def test_custom_pytables_sql(app_client):
     response = app_client.get(
@@ -100,7 +123,7 @@ def test_custom_pytables_sql(app_client):
     ] == data['rows'][:3]
     assert ['identity'] == data['columns']
     assert 'test_tables' == data['database']
     ] == data['rows'][:3]
     assert ['identity'] == data['columns']
     assert 'test_tables' == data['database']
-    assert data['truncated']
+    assert not data['truncated']
 
 def test_invalid_custom_sql(app_client):
     response = app_client.get(
 
 def test_invalid_custom_sql(app_client):
     response = app_client.get(