X-Git-Url: https://git.jsancho.org/?p=datasette-connectors.git;a=blobdiff_plain;f=datasette_connectors%2Fmonkey.py;fp=datasette_connectors%2Fmonkey.py;h=6c656e45882f7efea8aa9bdcf1c3db0daa77fd56;hp=e18175f38889f69b8eabc1590f9fc680421b24f9;hb=52416a749fac092a032a8b5239e477dd68180dfa;hpb=1a34f766bbcada99da81fabdc93b802e4ff8fb2a

diff --git a/datasette_connectors/monkey.py b/datasette_connectors/monkey.py
index e18175f..6c656e4 100644
--- a/datasette_connectors/monkey.py
+++ b/datasette_connectors/monkey.py
@@ -1,12 +1,11 @@
-import asyncio
-import datasette
-from datasette.app import connections
-from datasette.inspect import inspect_hash
-from datasette.utils import Results
-from pathlib import Path
+import threading
 import sqlite3
+import datasette.views.base
+from datasette.database import Database
 
-from . import connectors
+from .connectors import ConnectorList
+
+connections = threading.local()
 
 
 def patch_datasette():
@@ -14,74 +13,81 @@ def patch_datasette():
     Monkey patching for original Datasette
     """
 
-    def inspect(self):
-        " Inspect the database and return a dictionary of table metadata "
-        if self._inspect:
-            return self._inspect
-
-        _inspect = {}
-        files = self.files
-
-        for filename in files:
-            self.files = (filename,)
-            path = Path(filename)
-            name = path.stem
-            if name in _inspect:
-                raise Exception("Multiple files with the same stem %s" % name)
-            try:
-                _inspect[name] = self.original_inspect()[name]
-            except sqlite3.DatabaseError:
-                tables, views, dbtype = connectors.inspect(path)
-                _inspect[name] = {
-                    "hash": inspect_hash(path),
-                    "file": str(path),
-                    "dbtype": dbtype,
-                    "tables": tables,
-                    "views": views,
-                }
-
-        self.files = files
-        self._inspect = _inspect
-        return self._inspect
-
-    datasette.app.Datasette.original_inspect = datasette.app.Datasette.inspect
-    datasette.app.Datasette.inspect = inspect
-
-
-    async def execute(self, db_name, sql, params=None, truncate=False, custom_time_limit=None, page_size=None):
-        """Executes sql against db_name in a thread"""
-        page_size = page_size or self.page_size
-
-        def is_sqlite3_conn():
-            conn = getattr(connections, db_name, None)
-            if not conn:
-                info = self.inspect()[db_name]
-                return info.get('dbtype', 'sqlite3') == 'sqlite3'
-            else:
-                return isinstance(conn, sqlite3.Connection)
-
-        def sql_operation_in_thread():
-            conn = getattr(connections, db_name, None)
-            if not conn:
-                info = self.inspect()[db_name]
-                conn = connectors.connect(info['file'], info['dbtype'])
-                setattr(connections, db_name, conn)
-
-            rows, truncated, description = conn.execute(
-                sql,
-                params or {},
-                truncate=truncate,
-                page_size=page_size,
-                max_returned_rows=self.max_returned_rows,
-            )
-            return Results(rows, truncated, description)
-
-        if is_sqlite3_conn():
-            return await self.original_execute(db_name, sql, params=params, truncate=truncate, custom_time_limit=custom_time_limit, page_size=page_size)
-        else:
-            return await asyncio.get_event_loop().run_in_executor(
-                self.executor, sql_operation_in_thread
-            )
-
-    datasette.app.Datasette.original_execute = datasette.app.Datasette.execute
-    datasette.app.Datasette.execute = execute
+    async def table_names(self):
+        try:
+            return await self.original_table_names()
+        except sqlite3.DatabaseError:
+            return ConnectorList.table_names(self.path)
+
+    Database.original_table_names = Database.table_names
+    Database.table_names = table_names
+
+
+    async def hidden_table_names(self):
+        try:
+            return await self.original_hidden_table_names()
+        except sqlite3.DatabaseError:
+            return ConnectorList.hidden_table_names(self.path)
+
+    Database.original_hidden_table_names = Database.hidden_table_names
+    Database.hidden_table_names = hidden_table_names
+
+
+    async def view_names(self):
+        try:
+            return await self.original_view_names()
+        except sqlite3.DatabaseError:
+            return ConnectorList.view_names(self.path)
+
+    Database.original_view_names = Database.view_names
+    Database.view_names = view_names
+
+
+    async def table_columns(self, table):
+        try:
+            return await self.original_table_columns(table)
+        except sqlite3.DatabaseError:
+            return ConnectorList.table_columns(self.path, table)
+
+    Database.original_table_columns = Database.table_columns
+    Database.table_columns = table_columns
+
+
+    async def primary_keys(self, table):
+        try:
+            return await self.original_primary_keys(table)
+        except sqlite3.DatabaseError:
+            return ConnectorList.primary_keys(self.path, table)
+
+    Database.original_primary_keys = Database.primary_keys
+    Database.primary_keys = primary_keys
+
+
+    async def fts_table(self, table):
+        try:
+            return await self.original_fts_table(table)
+        except sqlite3.DatabaseError:
+            return ConnectorList.fts_table(self.path, table)
+
+    Database.original_fts_table = Database.fts_table
+    Database.fts_table = fts_table
+
+
+    async def get_all_foreign_keys(self):
+        try:
+            return await self.original_get_all_foreign_keys()
+        except sqlite3.DatabaseError:
+            return ConnectorList.get_all_foreign_keys(self.path)
+
+    Database.original_get_all_foreign_keys = Database.get_all_foreign_keys
+    Database.get_all_foreign_keys = get_all_foreign_keys
+
+
+    async def table_counts(self, *args, **kwargs):
+        counts = await self.original_table_counts(**kwargs)
+        # If all tables have None as counts, an error has occurred
+        if len(list(filter(lambda table_count: table_count is not None, counts.values()))) == 0:
+            return ConnectorList.table_counts(self.path, *args, **kwargs)
+
+    Database.original_table_counts = Database.table_counts
+    Database.table_counts = table_counts