From 52416a749fac092a032a8b5239e477dd68180dfa Mon Sep 17 00:00:00 2001 From: Javier Sancho Date: Sun, 23 Aug 2020 13:51:58 +0200 Subject: [PATCH] Adapting project to last datasette version (WIP) --- datasette_connectors/cli.py | 2 +- datasette_connectors/connectors.py | 115 ++++++++++++++++---- datasette_connectors/monkey.py | 162 +++++++++++++++-------------- setup.py | 8 +- tests/dummy.py | 27 ++++- tests/fixtures.py | 10 +- tests/test_api.py | 27 +---- 7 files changed, 225 insertions(+), 126 deletions(-) diff --git a/datasette_connectors/cli.py b/datasette_connectors/cli.py index 74375e0..0fe48dc 100644 --- a/datasette_connectors/cli.py +++ b/datasette_connectors/cli.py @@ -1,3 +1,3 @@ from .monkey import patch_datasette; patch_datasette() -from .connectors import load; load() +from .connectors import ConnectorList; ConnectorList.load() from datasette.cli import cli diff --git a/datasette_connectors/connectors.py b/datasette_connectors/connectors.py index e3d10ba..aa8f824 100644 --- a/datasette_connectors/connectors.py +++ b/datasette_connectors/connectors.py @@ -1,22 +1,101 @@ import pkg_resources +import functools db_connectors = {} -def load(): - for entry_point in pkg_resources.iter_entry_points('datasette.connectors'): - db_connectors[entry_point.name] = entry_point.load() - -def inspect(path): - for connector in db_connectors.values(): - try: - return connector.inspect(path) - except: - pass - else: - raise Exception("No database connector found for %s" % path) - -def connect(path, dbtype): - try: - return db_connectors[dbtype].Connection(path) - except: - raise Exception("No database connector found for %s" % path) +def for_each_connector(func): + @functools.wraps(func) + def wrapper_for_each_connector(*args, **kwargs): + for connector in db_connectors.values(): + try: + return func(connector, *args, **kwargs) + except: + pass + else: + raise Exception("No database connector found!!") + return wrapper_for_each_connector + + +class ConnectorList: + @staticmethod + def load(): + for entry_point in pkg_resources.iter_entry_points('datasette.connectors'): + db_connectors[entry_point.name] = entry_point.load() + + @staticmethod + def add_connector(name, connector): + db_connectors[name] = connector + + @staticmethod + @for_each_connector + def table_names(connector, path): + return connector.table_names(path) + + @staticmethod + @for_each_connector + def hidden_table_names(connector, path): + return connector.hidden_table_names(path) + + @staticmethod + @for_each_connector + def view_names(connector, path): + return connector.view_names(path) + + @staticmethod + @for_each_connector + def table_columns(connector, path, table): + return connector.table_columns(path, table) + + @staticmethod + @for_each_connector + def primary_keys(connector, path, table): + return connector.primary_keys(path, table) + + @staticmethod + @for_each_connector + def fts_table(connector, path, table): + return connector.fts_table(path, table) + + @staticmethod + @for_each_connector + def get_all_foreign_keys(connector, path): + return connector.get_all_foreign_keys(path) + + @staticmethod + @for_each_connector + def table_counts(connector, path, *args, **kwargs): + return connector.table_counts(path, *args, **kwargs) + + +class Connector: + @staticmethod + def table_names(path): + return [] + + @staticmethod + def hidden_table_names(path): + return [] + + @staticmethod + def view_names(path): + return [] + + @staticmethod + def table_columns(path, table): + return [] + + @staticmethod + def primary_keys(path, table): + return [] + + @staticmethod + def fts_table(path, table): + return None + + @staticmethod + def get_all_foreign_keys(path): + return {} + + @staticmethod + def table_counts(path, *args, **kwargs): + return {} diff --git a/datasette_connectors/monkey.py b/datasette_connectors/monkey.py index e18175f..6c656e4 100644 --- a/datasette_connectors/monkey.py +++ b/datasette_connectors/monkey.py @@ -1,12 +1,11 @@ -import asyncio -import datasette -from datasette.app import connections -from datasette.inspect import inspect_hash -from datasette.utils import Results -from pathlib import Path +import threading import sqlite3 +import datasette.views.base +from datasette.database import Database -from . import connectors +from .connectors import ConnectorList + +connections = threading.local() def patch_datasette(): @@ -14,74 +13,81 @@ def patch_datasette(): Monkey patching for original Datasette """ - def inspect(self): - " Inspect the database and return a dictionary of table metadata " - if self._inspect: - return self._inspect - - _inspect = {} - files = self.files - - for filename in files: - self.files = (filename,) - path = Path(filename) - name = path.stem - if name in _inspect: - raise Exception("Multiple files with the same stem %s" % name) - try: - _inspect[name] = self.original_inspect()[name] - except sqlite3.DatabaseError: - tables, views, dbtype = connectors.inspect(path) - _inspect[name] = { - "hash": inspect_hash(path), - "file": str(path), - "dbtype": dbtype, - "tables": tables, - "views": views, - } - - self.files = files - self._inspect = _inspect - return self._inspect - - datasette.app.Datasette.original_inspect = datasette.app.Datasette.inspect - datasette.app.Datasette.inspect = inspect - - - async def execute(self, db_name, sql, params=None, truncate=False, custom_time_limit=None, page_size=None): - """Executes sql against db_name in a thread""" - page_size = page_size or self.page_size - - def is_sqlite3_conn(): - conn = getattr(connections, db_name, None) - if not conn: - info = self.inspect()[db_name] - return info.get('dbtype', 'sqlite3') == 'sqlite3' - else: - return isinstance(conn, sqlite3.Connection) - - def sql_operation_in_thread(): - conn = getattr(connections, db_name, None) - if not conn: - info = self.inspect()[db_name] - conn = connectors.connect(info['file'], info['dbtype']) - setattr(connections, db_name, conn) - - rows, truncated, description = conn.execute( - sql, - params or {}, - truncate=truncate, - page_size=page_size, - max_returned_rows=self.max_returned_rows, - ) - return Results(rows, truncated, description) - - if is_sqlite3_conn(): - return await self.original_execute(db_name, sql, params=params, truncate=truncate, custom_time_limit=custom_time_limit, page_size=page_size) - else: - return await asyncio.get_event_loop().run_in_executor( - self.executor, sql_operation_in_thread - ) - - datasette.app.Datasette.original_execute = datasette.app.Datasette.execute - datasette.app.Datasette.execute = execute + async def table_names(self): + try: + return await self.original_table_names() + except sqlite3.DatabaseError: + return ConnectorList.table_names(self.path) + + Database.original_table_names = Database.table_names + Database.table_names = table_names + + + async def hidden_table_names(self): + try: + return await self.original_hidden_table_names() + except sqlite3.DatabaseError: + return ConnectorList.hidden_table_names(self.path) + + Database.original_hidden_table_names = Database.hidden_table_names + Database.hidden_table_names = hidden_table_names + + + async def view_names(self): + try: + return await self.original_view_names() + except sqlite3.DatabaseError: + return ConnectorList.view_names(self.path) + + Database.original_view_names = Database.view_names + Database.view_names = view_names + + + async def table_columns(self, table): + try: + return await self.original_table_columns(table) + except sqlite3.DatabaseError: + return ConnectorList.table_columns(self.path, table) + + Database.original_table_columns = Database.table_columns + Database.table_columns = table_columns + + + async def primary_keys(self, table): + try: + return await self.original_primary_keys(table) + except sqlite3.DatabaseError: + return ConnectorList.primary_keys(self.path, table) + + Database.original_primary_keys = Database.primary_keys + Database.primary_keys = primary_keys + + + async def fts_table(self, table): + try: + return await self.original_fts_table(table) + except sqlite3.DatabaseError: + return ConnectorList.fts_table(self.path, table) + + Database.original_fts_table = Database.fts_table + Database.fts_table = fts_table + + + async def get_all_foreign_keys(self): + try: + return await self.original_get_all_foreign_keys() + except sqlite3.DatabaseError: + return ConnectorList.get_all_foreign_keys(self.path) + + Database.original_get_all_foreign_keys = Database.get_all_foreign_keys + Database.get_all_foreign_keys = get_all_foreign_keys + + + async def table_counts(self, *args, **kwargs): + counts = await self.original_table_counts(**kwargs) + # If all tables has None as counts, an error had ocurred + if len(list(filter(lambda table_count: table_count is not None, counts.values()))) == 0: + return ConnectorList.table_counts(self.path, *args, **kwargs) + + Database.original_table_counts = Database.table_counts + Database.table_counts = table_counts diff --git a/setup.py b/setup.py index 916ac86..1446662 100644 --- a/setup.py +++ b/setup.py @@ -22,8 +22,12 @@ setup( url='https://github.com/pytables/datasette-connectors', license='Apache License, Version 2.0', packages=['datasette_connectors'], - install_requires=['datasette==0.46'], - tests_require=['pytest', 'aiohttp'], + install_requires=['datasette==0.48'], + tests_require=[ + 'pytest', + 'aiohttp', + 'asgiref', + ], entry_points=''' [console_scripts] datasette=datasette_connectors.cli:cli diff --git a/tests/dummy.py b/tests/dummy.py index b4ae1c0..feadf0a 100644 --- a/tests/dummy.py +++ b/tests/dummy.py @@ -1,7 +1,32 @@ from datasette_connectors.row import Row +from datasette_connectors.connectors import Connector -_connector_type = 'dummy' +class DummyConnector(Connector): + _connector_type = 'dummy' + + @staticmethod + def table_names(path): + return ['table1', 'table2'] + + @staticmethod + def table_columns(path, table): + return ['c1', 'c2', 'c3'] + + @staticmethod + def get_all_foreign_keys(path): + return { + 'table1': {'incoming': [], 'outgoing': []}, + 'table2': {'incoming': [], 'outgoing': []}, + } + + @staticmethod + def table_counts(path, *args, **kwargs): + return { + 'table1': 2, + 'table2': 2, + } + def inspect(path): tables = {} diff --git a/tests/fixtures.py b/tests/fixtures.py index 6b772c6..70a59d8 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -1,13 +1,15 @@ from datasette_connectors import monkey; monkey.patch_datasette() -from datasette_connectors import connectors -from . import dummy -connectors.db_connectors['dummy'] = dummy +from datasette_connectors.connectors import ConnectorList +from .dummy import DummyConnector +ConnectorList.add_connector('dummy', DummyConnector) from datasette.app import Datasette +from datasette.utils.testing import TestClient import os import pytest import tempfile + @pytest.fixture(scope='session') def app_client(max_returned_rows=None): with tempfile.TemporaryDirectory() as tmpdir: @@ -20,7 +22,7 @@ def app_client(max_returned_rows=None): 'max_returned_rows': max_returned_rows or 1000, } ) - client = ds.app().test_client + client = TestClient(ds.app()) client.ds = ds yield client diff --git a/tests/test_api.py b/tests/test_api.py index 63555cd..2d74c95 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -2,7 +2,7 @@ from .fixtures import app_client from urllib.parse import urlencode def test_homepage(app_client): - _, response = app_client.get('/.json') + response = app_client.get('/.json') assert response.status == 200 assert response.json.keys() == {'dummy_tables': 0}.keys() d = response.json['dummy_tables'] @@ -10,28 +10,12 @@ def test_homepage(app_client): assert d['tables_count'] == 2 def test_database_page(app_client): - response = app_client.get('/dummy_tables.json', gather_request=False) + response = app_client.get('/dummy_tables.json') data = response.json assert 'dummy_tables' == data['database'] - assert [{ - 'name': 'table1', - 'columns': ['c1', 'c2', 'c3'], - 'primary_keys': [], - 'count': 2, - 'label_column': None, - 'hidden': False, - 'fts_table': None, - 'foreign_keys': {'incoming': [], 'outgoing': []} - }, { - 'name': 'table2', - 'columns': ['c1', 'c2', 'c3'], - 'primary_keys': [], - 'count': 2, - 'label_column': None, - 'hidden': False, - 'fts_table': None, - 'foreign_keys': {'incoming': [], 'outgoing': []} - }] == data['tables'] + assert len(data['tables']) == 2 + assert data['tables'][0]['count'] == 2 + assert data['tables'][0]['columns'] == ['c1', 'c2', 'c3'] def test_custom_sql(app_client): response = app_client.get( @@ -39,7 +23,6 @@ def test_custom_sql(app_client): 'sql': 'select c1 from table1', '_shape': 'objects' }), - gather_request=False ) data = response.json assert { -- 2.39.5