X-Git-Url: https://git.jsancho.org/?p=datasette-connectors.git;a=blobdiff_plain;f=datasette_connectors%2Fmonkey.py;h=5fd04703bf676b047a04c6a493341ab16ade2abf;hp=6c656e45882f7efea8aa9bdcf1c3db0daa77fd56;hb=5c00383b9044ca27de9c51a511962ffad65ed5f3;hpb=52416a749fac092a032a8b5239e477dd68180dfa diff --git a/datasette_connectors/monkey.py b/datasette_connectors/monkey.py index 6c656e4..5fd0470 100644 --- a/datasette_connectors/monkey.py +++ b/datasette_connectors/monkey.py @@ -1,7 +1,11 @@ +import asyncio import threading import sqlite3 + import datasette.views.base +from datasette.tracer import trace from datasette.database import Database +from datasette.database import Results from .connectors import ConnectorList @@ -13,36 +17,6 @@ def patch_datasette(): Monkey patching for original Datasette """ - async def table_names(self): - try: - return await self.original_table_names() - except sqlite3.DatabaseError: - return ConnectorList.table_names(self.path) - - Database.original_table_names = Database.table_names - Database.table_names = table_names - - - async def hidden_table_names(self): - try: - return await self.original_hidden_table_names() - except sqlite3.DatabaseError: - return ConnectorList.hidden_table_names(self.path) - - Database.original_hidden_table_names = Database.hidden_table_names - Database.hidden_table_names = hidden_table_names - - - async def view_names(self): - try: - return await self.original_view_names() - except sqlite3.DatabaseError: - return ConnectorList.view_names(self.path) - - Database.original_view_names = Database.view_names - Database.view_names = view_names - - async def table_columns(self, table): try: return await self.original_table_columns(table) @@ -73,21 +47,33 @@ def patch_datasette(): Database.fts_table = fts_table - async def get_all_foreign_keys(self): + def connect(self, write=False): try: - return await self.original_get_all_foreign_keys() + # Check if it's a sqlite database + conn = self.original_connect(write=write) + conn.execute("select name from sqlite_master where type='table'") + return conn except sqlite3.DatabaseError: - return ConnectorList.get_all_foreign_keys(self.path) + conn = ConnectorList.connect(self.path) + return conn + + Database.original_connect = Database.connect + Database.connect = connect - Database.original_get_all_foreign_keys = Database.get_all_foreign_keys - Database.get_all_foreign_keys = get_all_foreign_keys + async def execute_fn(self, fn): + def in_thread(): + conn = getattr(connections, self.name, None) + if not conn: + conn = self.connect() + if isinstance(conn, sqlite3.Connection): + self.ds._prepare_connection(conn, self.name) + setattr(connections, self.name, conn) + return fn(conn) - async def table_counts(self, *args, **kwargs): - counts = await self.original_table_counts(**kwargs) - # If all tables has None as counts, an error had ocurred - if len(list(filter(lambda table_count: table_count is not None, counts.values()))) == 0: - return ConnectorList.table_counts(self.path, *args, **kwargs) + return await asyncio.get_event_loop().run_in_executor( + self.ds.executor, in_thread + ) - Database.original_table_counts = Database.table_counts - Database.table_counts = table_counts + Database.original_execute_fn = Database.execute_fn + Database.execute_fn = execute_fn