tables in json format
[datasette-pytables.git] / datasette_pytables / __init__.py
index 9dc2a099bc32666af22122d94f19d39d4da532a5..c4087730d24699c52968cb34b1de0955635b24a0 100644 (file)
@@ -67,6 +67,27 @@ class PyTablesConnector(dc.Connector):
     def foreign_keys(self, table_name):
         return []
 
+    def table_exists(self, table_name):
+        try:
+            self.conn.h5file.get_node(table_name)
+            return True
+        except:
+            return False
+
+    def table_definition(self, table_type, table_name):
+        table = self.conn.h5file.get_node(table_name)
+        colnames = ['value']
+        if isinstance(table, tables.table.Table):
+            colnames = table.colnames
+
+        return 'CREATE TABLE {} ({})'.format(
+            table_name,
+            ', '.join(colnames),
+        )
+
+    def indices_definition(self, table_name):
+        return []
+
     def execute(
         self,
         sql,
@@ -82,6 +103,10 @@ class PyTablesConnector(dc.Connector):
 
         parsed_sql = parse_sql(sql, params)
 
+        while isinstance(parsed_sql['from'], dict):
+            # Pytables does not support subqueries
+            parsed_sql['from'] = parsed_sql['from']['value']['from']
+
         table = self.conn.h5file.get_node(parsed_sql['from'])
         table_rows = []
         fields = parsed_sql['select']
@@ -176,7 +201,8 @@ class PyTablesConnector(dc.Connector):
 
         # Execute query
         if query:
-            table_rows = table.where(query, params, start, end)
+            if not ' glob ' in query:
+                table_rows = table.where(query, params, start, end)
         elif orderby:
             table_rows = table.itersorted(orderby, start=start, stop=end)
         else:
@@ -212,36 +238,43 @@ class PyTablesConnector(dc.Connector):
                     return row
             return get_row_value
 
-        if len(fields) == 1 and type(fields[0]['value']) is dict and \
-           fields[0]['value'].get('count') == '*':
-            results.append({'count(*)': int(table.nrows)})
-        else:
-            get_rowid = make_get_rowid()
-            get_row_value = make_get_row_value()
-            count = 0
-            for table_row in table_rows:
-                count += 1
-                if limit and count > limit:
-                    break
-                if truncate and max_returned_rows and count > max_returned_rows:
-                    truncated = True
-                    break
-                row = {}
-                for field in fields:
+        # Get results
+        get_rowid = make_get_rowid()
+        get_row_value = make_get_row_value()
+        count = 0
+        for table_row in table_rows:
+            count += 1
+            if limit is not None and count > limit:
+                break
+            if truncate and max_returned_rows and count > max_returned_rows:
+                truncated = True
+                break
+            row = {}
+            for field in fields:
+                field_name = field
+                if isinstance(field, dict):
                     field_name = field['value']
-                    if type(field_name) is dict and 'distinct' in field_name:
-                        field_name = field_name['distinct']
-                    if field_name == 'rowid':
-                        row['rowid'] = get_rowid(table_row)
-                    elif field_name == '*':
-                        for col in colnames:
-                            row[col] = normalize_field_value(get_row_value(table_row, col))
+                if isinstance(field_name, dict) and 'distinct' in field_name:
+                    field_name = field_name['distinct']
+                if field_name == 'rowid':
+                    row['rowid'] = get_rowid(table_row)
+                elif field_name == '*':
+                    for col in colnames:
+                        row[col] = normalize_field_value(get_row_value(table_row, col))
+                elif isinstance(field_name, dict):
+                    if field_name.get('count') == '*':
+                        row['count(*)'] = int(table.nrows)
+                    elif field_name.get('json_type'):
+                        field_name = field_name.get('json_type')
+                        row['json_type(' + field_name + ')'] = table.coltypes[field_name]
                     else:
-                        row[field_name] = normalize_field_value(get_row_value(table_row, field_name))
-                results.append(row)
+                        raise Exception("Function not recognized")
+                else:
+                    row[field_name] = normalize_field_value(get_row_value(table_row, field_name))
+            results.append(row)
 
         # Prepare query description
-        for field in [f['value'] for f in fields]:
+        for field in [f['value'] if isinstance(f, dict) else f for f in fields]:
             if field == '*':
                 for col in colnames:
                     description += ((col,),)