]> git.jsancho.org Git - datasette-connectors.git/commitdiff
Overwriting Connector class is enough to operate with
authorJavier Sancho <jsf@jsancho.org>
Thu, 17 Sep 2020 08:55:30 +0000 (10:55 +0200)
committerJavier Sancho <jsf@jsancho.org>
Thu, 17 Sep 2020 08:55:30 +0000 (10:55 +0200)
datasette-connectors

datasette_connectors/__init__.py
datasette_connectors/connectors.py
datasette_connectors/monkey.py
setup.py
tests/dummy.py
tests/fixtures.py
tests/test_api.py

index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..21cf96b66f2dfa9918c082f14beb4d61a71f0a04 100644 (file)
@@ -0,0 +1 @@
+from .connectors import Connector, OperationalError
index aa8f824cf4d0a51234107c7801e0fd2d3fa8d7a4..f09727f93bc07f59e60343ea3fcf78867e6ac1da 100644 (file)
@@ -1,5 +1,10 @@
 import pkg_resources
 import functools
+import re
+import sqlite3
+
+from .row import Row
+
 
 db_connectors = {}
 
@@ -26,76 +31,216 @@ class ConnectorList:
     def add_connector(name, connector):
         db_connectors[name] = connector
 
-    @staticmethod
-    @for_each_connector
-    def table_names(connector, path):
-        return connector.table_names(path)
-
-    @staticmethod
-    @for_each_connector
-    def hidden_table_names(connector, path):
-        return connector.hidden_table_names(path)
+    class DatabaseNotSupported(Exception):
+        pass
 
     @staticmethod
-    @for_each_connector
-    def view_names(connector, path):
-        return connector.view_names(path)
+    def connect(path):
+        for connector in db_connectors.values():
+            try:
+                return connector.connect(path)
+            except:
+                pass
+        else:
+            raise ConnectorList.DatabaseNotSupported
+
+
+class Connection:
+    def __init__(self, path, connector):
+        self.path = path
+        self.connector = connector
+
+    def execute(self, *args, **kwargs):
+        cursor = Cursor(self)
+        cursor.execute(*args, **kwargs)
+        return cursor
+
+    def cursor(self):
+        return Cursor(self)
+
+    def set_progress_handler(self, handler, n):
+        pass
+
+
+class OperationalError(Exception):
+    pass
+
+
+class Cursor:
+    class QueryNotSupported(Exception):
+        pass
+
+    def __init__(self, conn):
+        self.conn = conn
+        self.connector = conn.connector
+        self.rows = []
+        self.description = ()
+
+    def execute(
+        self,
+        sql,
+        params=None,
+        truncate=False,
+        custom_time_limit=None,
+        page_size=None,
+        log_sql_errors=True,
+    ):
+        if params is None:
+            params = {}
+        results = []
+        truncated = False
+        description = ()
+
+        # Normalize sql
+        sql = sql.strip()
+        sql = ' '.join(sql.split())
+
+        if sql == "select name from sqlite_master where type='table'" or \
+           sql == "select name from sqlite_master where type=\"table\"":
+            results = [{'name': name} for name in self.connector.table_names()]
+        elif sql == "select name from sqlite_master where rootpage = 0 and sql like '%VIRTUAL TABLE%USING FTS%'":
+            results = [{'name': name} for name in self.connector.hidden_table_names()]
+        elif sql == 'select 1 from sqlite_master where tbl_name = "geometry_columns"':
+            if self.connector.detect_spatialite():
+                results = [{'1': '1'}]
+        elif sql == "select name from sqlite_master where type='view'":
+            results = [{'name': name} for name in self.connector.view_names()]
+        elif sql.startswith("select count(*) from ["):
+            match = re.search(r'select count\(\*\) from \[(.*)\]', sql)
+            results = [{'count(*)': self.connector.table_count(match.group(1))}]
+        elif sql.startswith("select count(*) from "):
+            match = re.search(r'select count\(\*\) from (.*)', sql)
+            results = [{'count(*)': self.connector.table_count(match.group(1))}]
+        elif sql.startswith("PRAGMA table_info("):
+            match = re.search(r'PRAGMA table_info\((.*)\)', sql)
+            results = self.connector.table_info(match.group(1))
+        elif sql.startswith("select name from sqlite_master where rootpage = 0 and ( sql like \'%VIRTUAL TABLE%USING FTS%content="):
+            match = re.search(r'select name from sqlite_master where rootpage = 0 and \( sql like \'%VIRTUAL TABLE%USING FTS%content="(.*)"', sql)
+            if self.connector.detect_fts(match.group(1)):
+                results = [{'name': match.group(1)}]
+        elif sql.startswith("PRAGMA foreign_key_list(["):
+            match = re.search(r'PRAGMA foreign_key_list\(\[(.*)\]\)', sql)
+            results = self.connector.foreign_keys(match.group(1))
+        elif sql == "select 1 from sqlite_master where type='table' and name=?":
+            if self.connector.table_exists(params[0]):
+                results = [{'1': '1'}]
+        elif sql == "select sql from sqlite_master where name = :n and type=:t":
+            results = [{'sql': self.connector.table_definition(params['t'], params['n'])}]
+        elif sql == "select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null":
+            results = [{'sql': sql} for sql in self.connector.indices_definition(params['n'])]
+        else:
+            try:
+                results, truncated, description = \
+                    self.connector.execute(
+                        sql,
+                        params=params,
+                        truncate=truncate,
+                        custom_time_limit=custom_time_limit,
+                        page_size=page_size,
+                        log_sql_errors=log_sql_errors,
+                    )
+            except OperationalError as ex:
+                raise sqlite3.OperationalError(*ex.args)
 
-    @staticmethod
-    @for_each_connector
-    def table_columns(connector, path, table):
-        return connector.table_columns(path, table)
+        self.rows = [Row(result) for result in results]
+        self.description = description
 
-    @staticmethod
-    @for_each_connector
-    def primary_keys(connector, path, table):
-        return connector.primary_keys(path, table)
+    def fetchall(self):
+        return self.rows
 
-    @staticmethod
-    @for_each_connector
-    def fts_table(connector, path, table):
-        return connector.fts_table(path, table)
+    def fetchmany(self, max):
+        return self.rows[:max]
 
-    @staticmethod
-    @for_each_connector
-    def get_all_foreign_keys(connector, path):
-        return connector.get_all_foreign_keys(path)
-
-    @staticmethod
-    @for_each_connector
-    def table_counts(connector, path, *args, **kwargs):
-        return connector.table_counts(path, *args, **kwargs)
+    def __getitem__(self, index):
+        return self.rows[index]
 
 
 class Connector:
-    @staticmethod
-    def table_names(path):
-        return []
-
-    @staticmethod
-    def hidden_table_names(path):
-        return []
-
-    @staticmethod
-    def view_names(path):
-        return []
-
-    @staticmethod
-    def table_columns(path, table):
-        return []
-
-    @staticmethod
-    def primary_keys(path, table):
-        return []
-
-    @staticmethod
-    def fts_table(path, table):
-        return None
-
-    @staticmethod
-    def get_all_foreign_keys(path):
-        return {}
-
-    @staticmethod
-    def table_counts(path, *args, **kwargs):
-        return {}
+    connector_type = None
+    connection_class = Connection
+
+    def connect(self, path):
+        return self.connection_class(path, self)
+
+    def table_names(self):
+        """
+        Return a list of table names
+        """
+        raise NotImplementedError
+
+    def hidden_table_names(self):
+        raise NotImplementedError
+
+    def detect_spatialite(self):
+        """
+        Return boolean indicating if geometry_columns exists
+        """
+        raise NotImplementedError
+
+    def view_names(self):
+        """
+        Return a list of view names
+        """
+        raise NotImplementedError
+
+    def table_count(self, table_name):
+        """
+        Return an integer with the rows count of the table
+        """
+        raise NotImplementedError
+
+    def table_info(self, table_name):
+        """
+        Return a list of dictionaries with columns description, with format:
+        [
+            {
+                'idx': 0,
+                'name': 'column1',
+                'primary_key': False,
+            },
+            ...
+        ]
+        """
+        raise NotImplementedError
+
+    def detect_fts(self, table_name):
+        """
+        Return boolean indicating if table has a corresponding FTS virtual table
+        """
+        raise NotImplementedError
+
+    def foreign_keys(self, table_name):
+        """
+        Return a list of dictionaries with foreign keys description
+        id, seq, table_name, from_, to_, on_update, on_delete, match
+        """
+        raise NotImplementedError
+
+    def table_exists(self, table_name):
+        """
+        Return boolean indicating if table exists in the database
+        """
+        raise NotImplementedError
+
+    def table_definition(self, table_type, table_name):
+        """
+        Return string with a 'CREATE TABLE' sql definition
+        """
+        raise NotImplementedError
+
+    def indices_definition(self, table_name):
+        """
+        Return a list of strings with 'CREATE INDEX' sql definitions
+        """
+        raise NotImplementedError
+
+    def execute(
+        self,
+        sql,
+        params=None,
+        truncate=False,
+        custom_time_limit=None,
+        page_size=None,
+        log_sql_errors=True,
+    ):
+        raise NotImplementedError
index 6c656e45882f7efea8aa9bdcf1c3db0daa77fd56..5fd04703bf676b047a04c6a493341ab16ade2abf 100644 (file)
@@ -1,7 +1,11 @@
+import asyncio
 import threading
 import sqlite3
+
 import datasette.views.base
+from datasette.tracer import trace
 from datasette.database import Database
+from datasette.database import Results
 
 from .connectors import ConnectorList
 
@@ -13,36 +17,6 @@ def patch_datasette():
     Monkey patching for original Datasette
     """
 
-    async def table_names(self):
-        try:
-            return await self.original_table_names()
-        except sqlite3.DatabaseError:
-            return ConnectorList.table_names(self.path)
-
-    Database.original_table_names = Database.table_names
-    Database.table_names = table_names
-
-
-    async def hidden_table_names(self):
-        try:
-            return await self.original_hidden_table_names()
-        except sqlite3.DatabaseError:
-            return ConnectorList.hidden_table_names(self.path)
-
-    Database.original_hidden_table_names = Database.hidden_table_names
-    Database.hidden_table_names = hidden_table_names
-
-
-    async def view_names(self):
-        try:
-            return await self.original_view_names()
-        except sqlite3.DatabaseError:
-            return ConnectorList.view_names(self.path)
-
-    Database.original_view_names = Database.view_names
-    Database.view_names = view_names
-
-
     async def table_columns(self, table):
         try:
             return await self.original_table_columns(table)
@@ -73,21 +47,33 @@ def patch_datasette():
     Database.fts_table = fts_table
 
 
-    async def get_all_foreign_keys(self):
+    def connect(self, write=False):
         try:
-            return await self.original_get_all_foreign_keys()
+            # Check if it's a sqlite database
+            conn = self.original_connect(write=write)
+            conn.execute("select name from sqlite_master where type='table'")
+            return conn
         except sqlite3.DatabaseError:
-            return ConnectorList.get_all_foreign_keys(self.path)
+            conn = ConnectorList.connect(self.path)
+            return conn
+
+    Database.original_connect = Database.connect
+    Database.connect = connect
 
-    Database.original_get_all_foreign_keys = Database.get_all_foreign_keys
-    Database.get_all_foreign_keys = get_all_foreign_keys
 
+    async def execute_fn(self, fn):
+        def in_thread():
+            conn = getattr(connections, self.name, None)
+            if not conn:
+                conn = self.connect()
+                if isinstance(conn, sqlite3.Connection):
+                    self.ds._prepare_connection(conn, self.name)
+                setattr(connections, self.name, conn)
+            return fn(conn)
 
-    async def table_counts(self, *args, **kwargs):
-        counts = await self.original_table_counts(**kwargs)
-        # If all tables has None as counts, an error had ocurred
-        if len(list(filter(lambda table_count: table_count is not None, counts.values()))) == 0:
-            return ConnectorList.table_counts(self.path, *args, **kwargs)
+        return await asyncio.get_event_loop().run_in_executor(
+            self.ds.executor, in_thread
+        )
 
-    Database.original_table_counts = Database.table_counts
-    Database.table_counts = table_counts
+    Database.original_execute_fn = Database.execute_fn
+    Database.execute_fn = execute_fn
index 144666284ce9f58c4f819f51bad7490c2c1748dc..8d52909194b98ea299471d051a9ad894e415d524 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -22,7 +22,9 @@ setup(
     url='https://github.com/pytables/datasette-connectors',
     license='Apache License, Version 2.0',
     packages=['datasette_connectors'],
-    install_requires=['datasette==0.48'],
+    install_requires=[
+        'datasette==0.48',
+    ],
     tests_require=[
         'pytest',
         'aiohttp',
index feadf0a2e1ed2e81c1e703769b5225149a391650..04c6686ad459a81eba587adbed51c984c526eccb 100644 (file)
-from datasette_connectors.row import Row
-from datasette_connectors.connectors import Connector
+import datasette_connectors as dc
 
 
-class DummyConnector(Connector):
-    _connector_type = 'dummy'
+class DummyConnector(dc.Connector):
+    connector_type = 'dummy'
 
-    @staticmethod
-    def table_names(path):
+    def table_names(self):
         return ['table1', 'table2']
 
-    @staticmethod
-    def table_columns(path, table):
-        return ['c1', 'c2', 'c3']
-
-    @staticmethod
-    def get_all_foreign_keys(path):
-        return {
-            'table1': {'incoming': [], 'outgoing': []},
-            'table2': {'incoming': [], 'outgoing': []},
-        }
-
-    @staticmethod
-    def table_counts(path, *args, **kwargs):
-        return {
-            'table1': 2,
-            'table2': 2,
-        }
-
-
-def inspect(path):
-    tables = {}
-    views = []
-
-    for table in ['table1', 'table2']:
-        tables[table] = {
-            'name': table,
-            'columns': ['c1', 'c2', 'c3'],
-            'primary_keys': [],
-            'count': 2,
-            'label_column': None,
-            'hidden': False,
-            'fts_table': None,
-            'foreign_keys': {'incoming': [], 'outgoing': []},
-        }
-
-    return tables, views, _connector_type
-
-
-class Connection:
-    def __init__(self, path):
-        self.path = path
-
-    def execute(self, sql, params=None, truncate=False, page_size=None, max_returned_rows=None):
-        sql = sql.strip()
-
-        rows = []
+    def hidden_table_names(self):
+        return []
+
+    def detect_spatialite(self):
+        return False
+
+    def view_names(self):
+        return []
+
+    def table_count(self, table_name):
+        return 2
+
+    def table_info(self, table_name):
+        return [
+            {
+                'idx': 0,
+                'name': 'c1',
+                'primary_key': False,
+            },
+            {
+                'idx': 0,
+                'name': 'c2',
+                'primary_key': False,
+            },
+            {
+                'idx': 0,
+                'name': 'c3',
+                'primary_key': False,
+            },
+        ]
+
+    def detect_fts(self, table_name):
+        return False
+
+    def foreign_keys(self, table_name):
+        return []
+
+    def table_exists(self, table_name):
+        return table_name in ['table1', 'table2']
+
+    def table_definition(self, table_type, table_name):
+        return 'CREATE TABLE ' + table_name + ' (c1, c2, c3)'
+
+    def indices_definition(self, table_name):
+        return []
+
+    def execute(
+        self,
+        sql,
+        params=None,
+        truncate=False,
+        custom_time_limit=None,
+        page_size=None,
+        log_sql_errors=True,
+    ):
+        results = []
         truncated = False
-        description = []
+        description = ()
 
         if sql == 'select c1 from table1':
-            rows = [
-                Row({'c1': 10}),
-                Row({'c1': 20})
+            results = [
+                {'c1': 10},
+                {'c1': 20},
             ]
             description = (('c1',),)
-        elif sql == 'select rowid, * from table2 order by rowid limit 51':
-            rows = [
-                Row({'rowid': 1, 'c1': 100, 'c2': 120, 'c3': 130}),
-                Row({'rowid': 2, 'c1': 200, 'c2': 220, 'c3': 230})
-            ]
-            description = (('rowid',), ('c1',), ('c2',), ('c3',))
-        elif sql == 'select count(*) from table2':
-            rows = [Row({'count(*)': 2})]
-            description = (('count(*)',),)
-        elif sql == """select distinct rowid from table2 
-                        where rowid is not null
-                        limit 31""":
-            rows = [
-                Row({'rowid': 1}),
-                Row({'rowid': 2})
+        elif sql == 'select c1, c2, c3 from table2 limit 51':
+            results = [
+                {'c1': 100, 'c2': 120, 'c3': 130},
+                {'c1': 200, 'c2': 220, 'c3': 230},
             ]
-            description = (('rowid',),)
-        elif sql == """select distinct c1 from table2 
-                        where c1 is not null
-                        limit 31""":
-            rows = [
-                Row({'c1': 100}),
-                Row({'c1': 200})
+            description = (('c1',), ('c2',), ('c3',))
+        elif sql == "select * from (select c1, c2, c3 from table2 ) limit 0":
+            pass
+        elif sql == "select c1, count(*) as n from ( select c1, c2, c3 from table2 ) where c1 is not null group by c1 limit 31":
+            results = [
+                {'c1': 100, 'n': 1},
+                {'c1': 200, 'n': 1},
             ]
-            description = (('c1',),)
-        elif sql == """select distinct c2 from table2 
-                        where c2 is not null
-                        limit 31""":
-            rows = [
-                Row({'c2': 120}),
-                Row({'c2': 220})
+            description = (('c1',), ('n',))
+        elif sql == "select c2, count(*) as n from ( select c1, c2, c3 from table2 ) where c2 is not null group by c2 limit 31":
+            results = [
+                {'c2': 120, 'n': 1},
+                {'c2': 220, 'n': 1},
             ]
-            description = (('c2',),)
-        elif sql == """select distinct c3 from table2 
-                        where c3 is not null
-                        limit 31""":
-            rows = [
-                Row({'c3': 130}),
-                Row({'c3': 230})
+            description = (('c2',), ('n',))
+        elif sql == "select c3, count(*) as n from ( select c1, c2, c3 from table2 ) where c3 is not null group by c3 limit 31":
+            results = [
+                {'c3': 130, 'n': 1},
+                {'c3': 230, 'n': 1},
             ]
-            description = (('c3',),)
-        elif sql == 'select sql from sqlite_master where name = :n and type=:t':
-            if params['t'] != 'view':
-                rows = [Row({'sql': 'CREATE TABLE ' + params['n'] + ' (c1, c2, c3)'})]
-                description = (('sql',),)
+            description = (('c3',), ('n',))
+        elif sql == 'select date(c1) from ( select c1, c2, c3 from table2 ) where c1 glob "????-??-*" limit 100;':
+            pass
+        elif sql == "select c1, c2, c3 from blah limit 51":
+            raise dc.OperationalError("no such table: blah")
         else:
-            raise Exception("Unexpected query: %s" % sql)
+            raise Exception("Unexpected query:", sql)
 
-        return rows, truncated, description
+        return results, truncated, description
index 70a59d841c772eaf95c79d884e403c7d2f19d543..4cb60cef297aaf5fe12e21ff3591e94f4c9ea613 100644 (file)
@@ -1,7 +1,7 @@
 from datasette_connectors import monkey; monkey.patch_datasette()
 from datasette_connectors.connectors import ConnectorList
 from .dummy import DummyConnector
-ConnectorList.add_connector('dummy', DummyConnector)
+ConnectorList.add_connector('dummy', DummyConnector())
 
 from datasette.app import Datasette
 from datasette.utils.testing import TestClient
index 2d74c95f15fe2edd656d11c2226bc8e583d71beb..25bd29757da230af08fece32442bbcd7398dc487 100644 (file)
@@ -39,93 +39,75 @@ def test_custom_sql(app_client):
     assert not data['truncated']
 
 def test_invalid_custom_sql(app_client):
-    response = app_client.get(
-        '/dummy_tables.json?sql=.schema',
-        gather_request=False
-    )
+    response = app_client.get('/dummy_tables.json?sql=.schema')
     assert response.status == 400
     assert response.json['ok'] is False
     assert 'Statement must be a SELECT' == response.json['error']
 
 def test_table_json(app_client):
-    response = app_client.get(
-        '/dummy_tables/table2.json?_shape=objects',
-        gather_request=False
-    )
+    response = app_client.get('/dummy_tables/table2.json?_shape=objects')
     assert response.status == 200
     data = response.json
-    assert data['query']['sql'] == 'select rowid, * from table2 order by rowid limit 51'
-    assert data['rows'] == [{
-        'rowid': 1,
-        'c1': 100,
-        'c2': 120,
-        'c3': 130
-    }, {
-        'rowid': 2,
-        'c1': 200,
-        'c2': 220,
-        'c3': 230
-    }]
+    assert data['query']['sql'] == 'select c1, c2, c3 from table2 limit 51'
+    assert data['rows'] == [
+        {
+            'c1': 100,
+            'c2': 120,
+            'c3': 130,
+        },
+        {
+            'c1': 200,
+            'c2': 220,
+            'c3': 230,
+        }]
 
 def test_table_not_exists_json(app_client):
     assert {
         'ok': False,
-        'error': 'Table not found: blah',
-        'status': 404,
-        'title': None,
-    } == app_client.get(
-        '/dummy_tables/blah.json', gather_request=False
-    ).json
+        'title': 'Invalid SQL',
+        'error': 'no such table: blah',
+        'status': 400,
+    } == app_client.get('/dummy_tables/blah.json').json
 
 def test_table_shape_arrays(app_client):
-    response = app_client.get(
-        '/dummy_tables/table2.json?_shape=arrays',
-        gather_request=False
-    )
+    response = app_client.get('/dummy_tables/table2.json?_shape=arrays')
     assert [
-        [1, 100, 120, 130],
-        [2, 200, 220, 230],
+        [100, 120, 130],
+        [200, 220, 230],
     ] == response.json['rows']
 
 def test_table_shape_objects(app_client):
-    response = app_client.get(
-        '/dummy_tables/table2.json?_shape=objects',
-        gather_request=False
-    )
-    assert [{
-        'rowid': 1,
-        'c1': 100,
-        'c2': 120,
-        'c3': 130,
-    }, {
-        'rowid': 2,
-        'c1': 200,
-        'c2': 220,
-        'c3': 230,
-    }] == response.json['rows']
+    response = app_client.get('/dummy_tables/table2.json?_shape=objects')
+    assert [
+        {
+            'c1': 100,
+            'c2': 120,
+            'c3': 130,
+        },
+        {
+            'c1': 200,
+            'c2': 220,
+            'c3': 230,
+        },
+    ] == response.json['rows']
 
 def test_table_shape_array(app_client):
-    response = app_client.get(
-        '/dummy_tables/table2.json?_shape=array',
-        gather_request=False
-    )
-    assert [{
-        'rowid': 1,
-        'c1': 100,
-        'c2': 120,
-        'c3': 130,
-    }, {
-        'rowid': 2,
-        'c1': 200,
-        'c2': 220,
-        'c3': 230,
-    }] == response.json
+    response = app_client.get('/dummy_tables/table2.json?_shape=array')
+    assert [
+        {
+            'c1': 100,
+            'c2': 120,
+            'c3': 130,
+        },
+        {
+            'c1': 200,
+            'c2': 220,
+            'c3': 230,
+        },
+    ] == response.json
 
 def test_table_shape_invalid(app_client):
-    response = app_client.get(
-        '/dummy_tables/table2.json?_shape=invalid',
-        gather_request=False
-    )
+    response = app_client.get('/dummy_tables/table2.json?_shape=invalid')
     assert {
         'ok': False,
         'error': 'Invalid _shape: invalid',