from .monkey import patch_datasette; patch_datasette()
-from .connectors import load; load()
+from .connectors import ConnectorList; ConnectorList.load()
from datasette.cli import cli
import pkg_resources
+import functools
db_connectors = {}
-def load():
- for entry_point in pkg_resources.iter_entry_points('datasette.connectors'):
- db_connectors[entry_point.name] = entry_point.load()
-
-def inspect(path):
- for connector in db_connectors.values():
- try:
- return connector.inspect(path)
- except:
- pass
- else:
- raise Exception("No database connector found for %s" % path)
-
-def connect(path, dbtype):
- try:
- return db_connectors[dbtype].Connection(path)
- except:
- raise Exception("No database connector found for %s" % path)
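+# Run a ConnectorList method against every registered connector until one of
+# them handles the request; raise if none can (or none are registered).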
+def for_each_connector(func):
+ @functools.wraps(func)
+ def wrapper_for_each_connector(*args, **kwargs):
+ for connector in db_connectors.values():
+ try:
+ return func(connector, *args, **kwargs)
+            except Exception:
+                pass
+        else:
+            raise Exception("No database connector found")
+ return wrapper_for_each_connector
+
+
+class ConnectorList:
+ @staticmethod
+ def load():
+ for entry_point in pkg_resources.iter_entry_points('datasette.connectors'):
+ db_connectors[entry_point.name] = entry_point.load()
+
+ @staticmethod
+ def add_connector(name, connector):
+ db_connectors[name] = connector
+
+ @staticmethod
+ @for_each_connector
+ def table_names(connector, path):
+ return connector.table_names(path)
+
+ @staticmethod
+ @for_each_connector
+ def hidden_table_names(connector, path):
+ return connector.hidden_table_names(path)
+
+ @staticmethod
+ @for_each_connector
+ def view_names(connector, path):
+ return connector.view_names(path)
+
+ @staticmethod
+ @for_each_connector
+ def table_columns(connector, path, table):
+ return connector.table_columns(path, table)
+
+ @staticmethod
+ @for_each_connector
+ def primary_keys(connector, path, table):
+ return connector.primary_keys(path, table)
+
+ @staticmethod
+ @for_each_connector
+ def fts_table(connector, path, table):
+ return connector.fts_table(path, table)
+
+ @staticmethod
+ @for_each_connector
+ def get_all_foreign_keys(connector, path):
+ return connector.get_all_foreign_keys(path)
+
+ @staticmethod
+ @for_each_connector
+ def table_counts(connector, path, *args, **kwargs):
+ return connector.table_counts(path, *args, **kwargs)
+
+
+class Connector:
+ @staticmethod
+ def table_names(path):
+ return []
+
+ @staticmethod
+ def hidden_table_names(path):
+ return []
+
+ @staticmethod
+ def view_names(path):
+ return []
+
+ @staticmethod
+ def table_columns(path, table):
+ return []
+
+ @staticmethod
+ def primary_keys(path, table):
+ return []
+
+ @staticmethod
+ def fts_table(path, table):
+ return None
+
+ @staticmethod
+ def get_all_foreign_keys(path):
+ return {}
+
+ @staticmethod
+ def table_counts(path, *args, **kwargs):
+ return {}
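+
+
+# A minimal sketch of how an external package could expose a connector to
+# ConnectorList.load() above: subclass Connector and override the static
+# methods it stubs out, then declare it in the 'datasette.connectors'
+# entry-point group scanned by pkg_resources. The package and connector
+# names below are hypothetical.
+#
+#   # setup.py of a hypothetical datasette-mydb package
+#   setup(
+#       name='datasette-mydb',
+#       ...,
+#       entry_points={
+#           'datasette.connectors': [
+#               'mydb = datasette_mydb:MyDBConnector',
+#           ],
+#       },
+#   )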
-import asyncio
-import datasette
-from datasette.app import connections
-from datasette.inspect import inspect_hash
-from datasette.utils import Results
-from pathlib import Path
+import threading
import sqlite3
+import datasette.views.base
+from datasette.database import Database
-from . import connectors
+from .connectors import ConnectorList
+
+connections = threading.local()
def patch_datasette():
"""
Monkey patching for original Datasette
"""
- def inspect(self):
- " Inspect the database and return a dictionary of table metadata "
- if self._inspect:
- return self._inspect
-
- _inspect = {}
- files = self.files
-
- for filename in files:
- self.files = (filename,)
- path = Path(filename)
- name = path.stem
- if name in _inspect:
- raise Exception("Multiple files with the same stem %s" % name)
- try:
- _inspect[name] = self.original_inspect()[name]
- except sqlite3.DatabaseError:
- tables, views, dbtype = connectors.inspect(path)
- _inspect[name] = {
- "hash": inspect_hash(path),
- "file": str(path),
- "dbtype": dbtype,
- "tables": tables,
- "views": views,
- }
-
- self.files = files
- self._inspect = _inspect
- return self._inspect
-
- datasette.app.Datasette.original_inspect = datasette.app.Datasette.inspect
- datasette.app.Datasette.inspect = inspect
-
-
- async def execute(self, db_name, sql, params=None, truncate=False, custom_time_limit=None, page_size=None):
- """Executes sql against db_name in a thread"""
- page_size = page_size or self.page_size
-
- def is_sqlite3_conn():
- conn = getattr(connections, db_name, None)
- if not conn:
- info = self.inspect()[db_name]
- return info.get('dbtype', 'sqlite3') == 'sqlite3'
- else:
- return isinstance(conn, sqlite3.Connection)
-
- def sql_operation_in_thread():
- conn = getattr(connections, db_name, None)
- if not conn:
- info = self.inspect()[db_name]
- conn = connectors.connect(info['file'], info['dbtype'])
- setattr(connections, db_name, conn)
-
- rows, truncated, description = conn.execute(
- sql,
- params or {},
- truncate=truncate,
- page_size=page_size,
- max_returned_rows=self.max_returned_rows,
- )
- return Results(rows, truncated, description)
-
- if is_sqlite3_conn():
- return await self.original_execute(db_name, sql, params=params, truncate=truncate, custom_time_limit=custom_time_limit, page_size=page_size)
- else:
- return await asyncio.get_event_loop().run_in_executor(
- self.executor, sql_operation_in_thread
- )
-
- datasette.app.Datasette.original_execute = datasette.app.Datasette.execute
- datasette.app.Datasette.execute = execute
+ async def table_names(self):
+ try:
+ return await self.original_table_names()
+ except sqlite3.DatabaseError:
+ return ConnectorList.table_names(self.path)
+
+ Database.original_table_names = Database.table_names
+ Database.table_names = table_names
+
+
+ async def hidden_table_names(self):
+ try:
+ return await self.original_hidden_table_names()
+ except sqlite3.DatabaseError:
+ return ConnectorList.hidden_table_names(self.path)
+
+ Database.original_hidden_table_names = Database.hidden_table_names
+ Database.hidden_table_names = hidden_table_names
+
+
+ async def view_names(self):
+ try:
+ return await self.original_view_names()
+ except sqlite3.DatabaseError:
+ return ConnectorList.view_names(self.path)
+
+ Database.original_view_names = Database.view_names
+ Database.view_names = view_names
+
+
+ async def table_columns(self, table):
+ try:
+ return await self.original_table_columns(table)
+ except sqlite3.DatabaseError:
+ return ConnectorList.table_columns(self.path, table)
+
+ Database.original_table_columns = Database.table_columns
+ Database.table_columns = table_columns
+
+
+ async def primary_keys(self, table):
+ try:
+ return await self.original_primary_keys(table)
+ except sqlite3.DatabaseError:
+ return ConnectorList.primary_keys(self.path, table)
+
+ Database.original_primary_keys = Database.primary_keys
+ Database.primary_keys = primary_keys
+
+
+ async def fts_table(self, table):
+ try:
+ return await self.original_fts_table(table)
+ except sqlite3.DatabaseError:
+ return ConnectorList.fts_table(self.path, table)
+
+ Database.original_fts_table = Database.fts_table
+ Database.fts_table = fts_table
+
+
+ async def get_all_foreign_keys(self):
+ try:
+ return await self.original_get_all_foreign_keys()
+ except sqlite3.DatabaseError:
+ return ConnectorList.get_all_foreign_keys(self.path)
+
+ Database.original_get_all_foreign_keys = Database.get_all_foreign_keys
+ Database.get_all_foreign_keys = get_all_foreign_keys
+
+
+ async def table_counts(self, *args, **kwargs):
+ counts = await self.original_table_counts(**kwargs)
+        # If every table count is None, the original implementation hit an
+        # error, so fall back to the registered connectors
+        if all(count is None for count in counts.values()):
+            return ConnectorList.table_counts(self.path, *args, **kwargs)
+        return counts
+
+ Database.original_table_counts = Database.table_counts
+ Database.table_counts = table_counts
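+
+
+# A minimal sketch of wiring the patches up by hand (outside the bundled CLI),
+# mirroring what datasette_connectors.cli and the test fixtures do: patch
+# Datasette first, then register connectors before the app is built.
+#
+#   from datasette_connectors import monkey; monkey.patch_datasette()
+#   from datasette_connectors.connectors import ConnectorList
+#   ConnectorList.load()  # discover connectors declared as entry points
+#   # or register one directly, e.g. a hypothetical MyDBConnector:
+#   # ConnectorList.add_connector('mydb', MyDBConnector)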
url='https://github.com/pytables/datasette-connectors',
license='Apache License, Version 2.0',
packages=['datasette_connectors'],
- install_requires=['datasette==0.46'],
- tests_require=['pytest', 'aiohttp'],
+ install_requires=['datasette==0.48'],
+ tests_require=[
+ 'pytest',
+ 'aiohttp',
+ 'asgiref',
+ ],
entry_points='''
[console_scripts]
datasette=datasette_connectors.cli:cli
from datasette_connectors.row import Row
+from datasette_connectors.connectors import Connector
-_connector_type = 'dummy'
+class DummyConnector(Connector):
+ _connector_type = 'dummy'
+
+ @staticmethod
+ def table_names(path):
+ return ['table1', 'table2']
+
+ @staticmethod
+ def table_columns(path, table):
+ return ['c1', 'c2', 'c3']
+
+ @staticmethod
+ def get_all_foreign_keys(path):
+ return {
+ 'table1': {'incoming': [], 'outgoing': []},
+ 'table2': {'incoming': [], 'outgoing': []},
+ }
+
+ @staticmethod
+ def table_counts(path, *args, **kwargs):
+ return {
+ 'table1': 2,
+ 'table2': 2,
+ }
+
def inspect(path):
tables = {}
from datasette_connectors import monkey; monkey.patch_datasette()
-from datasette_connectors import connectors
-from . import dummy
-connectors.db_connectors['dummy'] = dummy
+from datasette_connectors.connectors import ConnectorList
+from .dummy import DummyConnector
+ConnectorList.add_connector('dummy', DummyConnector)
from datasette.app import Datasette
+from datasette.utils.testing import TestClient
import os
import pytest
import tempfile
+
@pytest.fixture(scope='session')
def app_client(max_returned_rows=None):
with tempfile.TemporaryDirectory() as tmpdir:
'max_returned_rows': max_returned_rows or 1000,
}
)
- client = ds.app().test_client
+ client = TestClient(ds.app())
client.ds = ds
yield client
from urllib.parse import urlencode
def test_homepage(app_client):
- _, response = app_client.get('/.json')
+ response = app_client.get('/.json')
assert response.status == 200
assert response.json.keys() == {'dummy_tables': 0}.keys()
d = response.json['dummy_tables']
assert d['tables_count'] == 2
def test_database_page(app_client):
- response = app_client.get('/dummy_tables.json', gather_request=False)
+ response = app_client.get('/dummy_tables.json')
data = response.json
assert 'dummy_tables' == data['database']
- assert [{
- 'name': 'table1',
- 'columns': ['c1', 'c2', 'c3'],
- 'primary_keys': [],
- 'count': 2,
- 'label_column': None,
- 'hidden': False,
- 'fts_table': None,
- 'foreign_keys': {'incoming': [], 'outgoing': []}
- }, {
- 'name': 'table2',
- 'columns': ['c1', 'c2', 'c3'],
- 'primary_keys': [],
- 'count': 2,
- 'label_column': None,
- 'hidden': False,
- 'fts_table': None,
- 'foreign_keys': {'incoming': [], 'outgoing': []}
- }] == data['tables']
+ assert len(data['tables']) == 2
+ assert data['tables'][0]['count'] == 2
+ assert data['tables'][0]['columns'] == ['c1', 'c2', 'c3']
def test_custom_sql(app_client):
response = app_client.get(
'sql': 'select c1 from table1',
'_shape': 'objects'
}),
- gather_request=False
)
data = response.json
assert {