]> git.jsancho.org Git - datasette-pytables.git/blobdiff - datasette_pytables/__init__.py
Check limit to return the appropiate number of rows
[datasette-pytables.git] / datasette_pytables / __init__.py
index 0bbb25f9cd63eb55c4ab37212984d0aaea1a221a..b857ee179bb744a1e64dca0fa09834fb8ee35b53 100644 (file)
@@ -97,26 +97,39 @@ class Connection:
         table_rows = []
         fields = parsed_sql['select'].split(',')
 
+        query = ''
+        start = 0
+        end = table.nrows
+
         # Use 'where' statement or get all the rows
         if 'where' in parsed_sql:
-            query = []
-            start = 0
-            end = table.nrows
-            for condition in parsed_sql['where'].get_sublists():
-                if str(condition) == '"rowid"=:p0':
-                    start = int(params['p0'])
-                    end = start + 1
-                else:
-                    translated, params = _translate_condition(table, condition, params)
-                    query.append(translated)
-            if query:
-                query = ') & ('.join(query)
-                query = '(' + query + ')'
-                table_rows = table.where(query, params, start, end)
-            else:
-                table_rows = table.iterrows(start, end)
+            try:
+                conditions = []
+                for condition in parsed_sql['where'].get_sublists():
+                    if str(condition) == '"rowid"=:p0':
+                        start = int(params['p0'])
+                        end = start + 1
+                    else:
+                        translated, params = _translate_condition(table, condition, params)
+                        conditions.append(translated)
+                if conditions:
+                    query = ') & ('.join(conditions)
+                    query = '(' + query + ')'
+            except:
+                # Probably it's a PyTables query
+                query = str(parsed_sql['where'])[6:]    # without where keyword
+
+        # Limit number of rows
+        if 'limit' in parsed_sql:
+            max_rows = int(parsed_sql['limit'])
+            if end - start > max_rows:
+                end = start + max_rows
+
+        # Execute query
+        if query:
+            table_rows = table.where(query, params, start, end)
         else:
-            table_rows = table.iterrows()
+            table_rows = table.iterrows(start, end)
 
         # Prepare rows
         if len(fields) == 1 and fields[0] == 'count(*)':