]> git.jsancho.org Git - datasette-connectors.git/commitdiff
Adapting project to latest datasette version (WIP)
authorJavier Sancho <jsf@jsancho.org>
Sun, 23 Aug 2020 11:51:58 +0000 (13:51 +0200)
committerJavier Sancho <jsf@jsancho.org>
Sun, 23 Aug 2020 11:51:58 +0000 (13:51 +0200)
datasette_connectors/cli.py
datasette_connectors/connectors.py
datasette_connectors/monkey.py
setup.py
tests/dummy.py
tests/fixtures.py
tests/test_api.py

index 74375e017bc7cd65acea23d5ae826cdcc8ab6087..0fe48dcccbac8b1ce649b9f329e8d14f9839ae21 100644 (file)
@@ -1,3 +1,3 @@
 from .monkey import patch_datasette; patch_datasette()
-from .connectors import load; load()
+from .connectors import ConnectorList; ConnectorList.load()
 from datasette.cli import cli
index e3d10ba4dacd0fd5da758558377a38c6b9924c13..aa8f824cf4d0a51234107c7801e0fd2d3fa8d7a4 100644 (file)
 import pkg_resources
+import functools
 
 db_connectors = {}
 
-def load():
-    for entry_point in pkg_resources.iter_entry_points('datasette.connectors'):
-        db_connectors[entry_point.name] = entry_point.load()
-
-def inspect(path):
-    for connector in db_connectors.values():
-        try:
-            return connector.inspect(path)
-        except:
-            pass
-    else:
-        raise Exception("No database connector found for %s" % path)
-
-def connect(path, dbtype):
-    try:
-        return db_connectors[dbtype].Connection(path)
-    except:
-        raise Exception("No database connector found for %s" % path)
+def for_each_connector(func):
+    @functools.wraps(func)
+    def wrapper_for_each_connector(*args, **kwargs):
+        for connector in db_connectors.values():
+            try:
+                return func(connector, *args, **kwargs)
+            except:
+                pass
+        else:
+            raise Exception("No database connector found!!")
+    return wrapper_for_each_connector
+
+
+class ConnectorList:
+    @staticmethod
+    def load():
+        for entry_point in pkg_resources.iter_entry_points('datasette.connectors'):
+            db_connectors[entry_point.name] = entry_point.load()
+
+    @staticmethod
+    def add_connector(name, connector):
+        db_connectors[name] = connector
+
+    @staticmethod
+    @for_each_connector
+    def table_names(connector, path):
+        return connector.table_names(path)
+
+    @staticmethod
+    @for_each_connector
+    def hidden_table_names(connector, path):
+        return connector.hidden_table_names(path)
+
+    @staticmethod
+    @for_each_connector
+    def view_names(connector, path):
+        return connector.view_names(path)
+
+    @staticmethod
+    @for_each_connector
+    def table_columns(connector, path, table):
+        return connector.table_columns(path, table)
+
+    @staticmethod
+    @for_each_connector
+    def primary_keys(connector, path, table):
+        return connector.primary_keys(path, table)
+
+    @staticmethod
+    @for_each_connector
+    def fts_table(connector, path, table):
+        return connector.fts_table(path, table)
+
+    @staticmethod
+    @for_each_connector
+    def get_all_foreign_keys(connector, path):
+        return connector.get_all_foreign_keys(path)
+
+    @staticmethod
+    @for_each_connector
+    def table_counts(connector, path, *args, **kwargs):
+        return connector.table_counts(path, *args, **kwargs)
+
+
+class Connector:
+    @staticmethod
+    def table_names(path):
+        return []
+
+    @staticmethod
+    def hidden_table_names(path):
+        return []
+
+    @staticmethod
+    def view_names(path):
+        return []
+
+    @staticmethod
+    def table_columns(path, table):
+        return []
+
+    @staticmethod
+    def primary_keys(path, table):
+        return []
+
+    @staticmethod
+    def fts_table(path, table):
+        return None
+
+    @staticmethod
+    def get_all_foreign_keys(path):
+        return {}
+
+    @staticmethod
+    def table_counts(path, *args, **kwargs):
+        return {}
index e18175f38889f69b8eabc1590f9fc680421b24f9..6c656e45882f7efea8aa9bdcf1c3db0daa77fd56 100644 (file)
@@ -1,12 +1,11 @@
-import asyncio
-import datasette
-from datasette.app import connections
-from datasette.inspect import inspect_hash
-from datasette.utils import Results
-from pathlib import Path
+import threading
 import sqlite3
+import datasette.views.base
+from datasette.database import Database
 
-from . import connectors
+from .connectors import ConnectorList
+
+connections = threading.local()
 
 
 def patch_datasette():
@@ -14,74 +13,81 @@ def patch_datasette():
     Monkey patching for original Datasette
     """
 
-    def inspect(self):
-        " Inspect the database and return a dictionary of table metadata "
-        if self._inspect:
-            return self._inspect
-
-        _inspect = {}
-        files = self.files
-
-        for filename in files:
-            self.files = (filename,)
-            path = Path(filename)
-            name = path.stem
-            if name in _inspect:
-                raise Exception("Multiple files with the same stem %s" % name)
-            try:
-                _inspect[name] = self.original_inspect()[name]
-            except sqlite3.DatabaseError:
-                tables, views, dbtype = connectors.inspect(path)
-                _inspect[name] = {
-                    "hash": inspect_hash(path),
-                    "file": str(path),
-                    "dbtype": dbtype,
-                    "tables": tables,
-                    "views": views,
-                }
-
-        self.files = files
-        self._inspect = _inspect
-        return self._inspect
-
-    datasette.app.Datasette.original_inspect = datasette.app.Datasette.inspect
-    datasette.app.Datasette.inspect = inspect
-
-
-    async def execute(self, db_name, sql, params=None, truncate=False, custom_time_limit=None, page_size=None):
-        """Executes sql against db_name in a thread"""
-        page_size = page_size or self.page_size
-
-        def is_sqlite3_conn():
-            conn = getattr(connections, db_name, None)
-            if not conn:
-                info = self.inspect()[db_name]
-                return info.get('dbtype', 'sqlite3') == 'sqlite3'
-            else:
-                return isinstance(conn, sqlite3.Connection)
-
-        def sql_operation_in_thread():
-            conn = getattr(connections, db_name, None)
-            if not conn:
-                info = self.inspect()[db_name]
-                conn = connectors.connect(info['file'], info['dbtype'])
-                setattr(connections, db_name, conn)
-
-            rows, truncated, description = conn.execute(
-                sql,
-                params or {},
-                truncate=truncate,
-                page_size=page_size,
-                max_returned_rows=self.max_returned_rows,
-            )
-            return Results(rows, truncated, description)
-
-        if is_sqlite3_conn():
-            return await self.original_execute(db_name, sql, params=params, truncate=truncate, custom_time_limit=custom_time_limit, page_size=page_size)
-        else:
-            return await asyncio.get_event_loop().run_in_executor(
-                self.executor, sql_operation_in_thread
-            )
-
-    datasette.app.Datasette.original_execute = datasette.app.Datasette.execute
-    datasette.app.Datasette.execute = execute
+    async def table_names(self):
+        try:
+            return await self.original_table_names()
+        except sqlite3.DatabaseError:
+            return ConnectorList.table_names(self.path)
+
+    Database.original_table_names = Database.table_names
+    Database.table_names = table_names
+
+
+    async def hidden_table_names(self):
+        try:
+            return await self.original_hidden_table_names()
+        except sqlite3.DatabaseError:
+            return ConnectorList.hidden_table_names(self.path)
+
+    Database.original_hidden_table_names = Database.hidden_table_names
+    Database.hidden_table_names = hidden_table_names
+
+
+    async def view_names(self):
+        try:
+            return await self.original_view_names()
+        except sqlite3.DatabaseError:
+            return ConnectorList.view_names(self.path)
+
+    Database.original_view_names = Database.view_names
+    Database.view_names = view_names
+
+
+    async def table_columns(self, table):
+        try:
+            return await self.original_table_columns(table)
+        except sqlite3.DatabaseError:
+            return ConnectorList.table_columns(self.path, table)
+
+    Database.original_table_columns = Database.table_columns
+    Database.table_columns = table_columns
+
+
+    async def primary_keys(self, table):
+        try:
+            return await self.original_primary_keys(table)
+        except sqlite3.DatabaseError:
+            return ConnectorList.primary_keys(self.path, table)
+
+    Database.original_primary_keys = Database.primary_keys
+    Database.primary_keys = primary_keys
+
+
+    async def fts_table(self, table):
+        try:
+            return await self.original_fts_table(table)
+        except sqlite3.DatabaseError:
+            return ConnectorList.fts_table(self.path, table)
+
+    Database.original_fts_table = Database.fts_table
+    Database.fts_table = fts_table
+
+
+    async def get_all_foreign_keys(self):
+        try:
+            return await self.original_get_all_foreign_keys()
+        except sqlite3.DatabaseError:
+            return ConnectorList.get_all_foreign_keys(self.path)
+
+    Database.original_get_all_foreign_keys = Database.get_all_foreign_keys
+    Database.get_all_foreign_keys = get_all_foreign_keys
+
+
+    async def table_counts(self, *args, **kwargs):
+        counts = await self.original_table_counts(**kwargs)
+    # If all tables have None as counts, an error had occurred
+        if len(list(filter(lambda table_count: table_count is not None, counts.values()))) == 0:
+            return ConnectorList.table_counts(self.path, *args, **kwargs)
+
+    Database.original_table_counts = Database.table_counts
+    Database.table_counts = table_counts
index 916ac86d5b66666104a48908a2ffddde9a82cfb1..144666284ce9f58c4f819f51bad7490c2c1748dc 100644 (file)
--- a/setup.py
+++ b/setup.py
@@ -22,8 +22,12 @@ setup(
     url='https://github.com/pytables/datasette-connectors',
     license='Apache License, Version 2.0',
     packages=['datasette_connectors'],
-    install_requires=['datasette==0.46'],
-    tests_require=['pytest', 'aiohttp'],
+    install_requires=['datasette==0.48'],
+    tests_require=[
+        'pytest',
+        'aiohttp',
+        'asgiref',
+    ],
     entry_points='''
         [console_scripts]
         datasette=datasette_connectors.cli:cli
index b4ae1c097da4c9e200f59cea33d316eb0703b968..feadf0a2e1ed2e81c1e703769b5225149a391650 100644 (file)
@@ -1,7 +1,32 @@
 from datasette_connectors.row import Row
+from datasette_connectors.connectors import Connector
 
 
-_connector_type = 'dummy'
+class DummyConnector(Connector):
+    _connector_type = 'dummy'
+
+    @staticmethod
+    def table_names(path):
+        return ['table1', 'table2']
+
+    @staticmethod
+    def table_columns(path, table):
+        return ['c1', 'c2', 'c3']
+
+    @staticmethod
+    def get_all_foreign_keys(path):
+        return {
+            'table1': {'incoming': [], 'outgoing': []},
+            'table2': {'incoming': [], 'outgoing': []},
+        }
+
+    @staticmethod
+    def table_counts(path, *args, **kwargs):
+        return {
+            'table1': 2,
+            'table2': 2,
+        }
+
 
 def inspect(path):
     tables = {}
index 6b772c67ba258cc75f36b4bda7c8d963245dc4ff..70a59d841c772eaf95c79d884e403c7d2f19d543 100644 (file)
@@ -1,13 +1,15 @@
 from datasette_connectors import monkey; monkey.patch_datasette()
-from datasette_connectors import connectors
-from . import dummy
-connectors.db_connectors['dummy'] = dummy
+from datasette_connectors.connectors import ConnectorList
+from .dummy import DummyConnector
+ConnectorList.add_connector('dummy', DummyConnector)
 
 from datasette.app import Datasette
+from datasette.utils.testing import TestClient
 import os
 import pytest
 import tempfile
 
+
 @pytest.fixture(scope='session')
 def app_client(max_returned_rows=None):
     with tempfile.TemporaryDirectory() as tmpdir:
@@ -20,7 +22,7 @@ def app_client(max_returned_rows=None):
                 'max_returned_rows': max_returned_rows or 1000,
             }
         )
-        client = ds.app().test_client
+        client = TestClient(ds.app())
         client.ds = ds
         yield client
 
index 63555cddf7dc2e62b519708939c4e03ce30773d7..2d74c95f15fe2edd656d11c2226bc8e583d71beb 100644 (file)
@@ -2,7 +2,7 @@ from .fixtures import app_client
 from urllib.parse import urlencode
 
 def test_homepage(app_client):
-    _, response = app_client.get('/.json')
+    response = app_client.get('/.json')
     assert response.status == 200
     assert response.json.keys() == {'dummy_tables': 0}.keys()
     d = response.json['dummy_tables']
@@ -10,28 +10,12 @@ def test_homepage(app_client):
     assert d['tables_count'] == 2
 
 def test_database_page(app_client):
-    response = app_client.get('/dummy_tables.json', gather_request=False)
+    response = app_client.get('/dummy_tables.json')
     data = response.json
     assert 'dummy_tables' == data['database']
-    assert [{
-        'name': 'table1',
-        'columns': ['c1', 'c2', 'c3'],
-        'primary_keys': [],
-        'count': 2,
-        'label_column': None,
-        'hidden': False,
-        'fts_table': None,
-        'foreign_keys': {'incoming': [], 'outgoing': []}
-    }, {
-        'name': 'table2',
-        'columns': ['c1', 'c2', 'c3'],
-        'primary_keys': [],
-        'count': 2,
-        'label_column': None,
-        'hidden': False,
-        'fts_table': None,
-        'foreign_keys': {'incoming': [], 'outgoing': []}
-    }] == data['tables']
+    assert len(data['tables']) == 2
+    assert data['tables'][0]['count'] == 2
+    assert data['tables'][0]['columns'] == ['c1', 'c2', 'c3']
 
 def test_custom_sql(app_client):
     response = app_client.get(
@@ -39,7 +23,6 @@ def test_custom_sql(app_client):
             'sql': 'select c1 from table1',
             '_shape': 'objects'
         }),
-        gather_request=False
     )
     data = response.json
     assert {