X-Git-Url: https://git.jsancho.org/?p=datasette-connectors.git;a=blobdiff_plain;f=datasette_connectors%2Fmonkey.py;h=0fb4e1c7b8d583d4e5e353952bb325122c679c52;hp=6c656e45882f7efea8aa9bdcf1c3db0daa77fd56;hb=3cc49f23a9f3c0e8cb2b7eb707382c6ae708c1f4;hpb=52416a749fac092a032a8b5239e477dd68180dfa diff --git a/datasette_connectors/monkey.py b/datasette_connectors/monkey.py index 6c656e4..0fb4e1c 100644 --- a/datasette_connectors/monkey.py +++ b/datasette_connectors/monkey.py @@ -1,7 +1,11 @@ +import asyncio import threading import sqlite3 + import datasette.views.base +from datasette.tracer import trace from datasette.database import Database +from datasette.database import Results from .connectors import ConnectorList @@ -13,81 +17,33 @@ def patch_datasette(): Monkey patching for original Datasette """ - async def table_names(self): - try: - return await self.original_table_names() - except sqlite3.DatabaseError: - return ConnectorList.table_names(self.path) - - Database.original_table_names = Database.table_names - Database.table_names = table_names - - - async def hidden_table_names(self): - try: - return await self.original_hidden_table_names() - except sqlite3.DatabaseError: - return ConnectorList.hidden_table_names(self.path) - - Database.original_hidden_table_names = Database.hidden_table_names - Database.hidden_table_names = hidden_table_names - - - async def view_names(self): - try: - return await self.original_view_names() - except sqlite3.DatabaseError: - return ConnectorList.view_names(self.path) - - Database.original_view_names = Database.view_names - Database.view_names = view_names - - - async def table_columns(self, table): - try: - return await self.original_table_columns(table) - except sqlite3.DatabaseError: - return ConnectorList.table_columns(self.path, table) - - Database.original_table_columns = Database.table_columns - Database.table_columns = table_columns - - - async def primary_keys(self, table): + def connect(self, write=False): try: - return await self.original_primary_keys(table) + # Check if it's a sqlite database + conn = self.original_connect(write=write) + conn.execute("select name from sqlite_master where type='table'") + return conn except sqlite3.DatabaseError: - return ConnectorList.primary_keys(self.path, table) - - Database.original_primary_keys = Database.primary_keys - Database.primary_keys = primary_keys - + conn = ConnectorList.connect(self.path) + return conn - async def fts_table(self, table): - try: - return await self.original_fts_table(table) - except sqlite3.DatabaseError: - return ConnectorList.fts_table(self.path, table) - - Database.original_fts_table = Database.fts_table - Database.fts_table = fts_table - - - async def get_all_foreign_keys(self): - try: - return await self.original_get_all_foreign_keys() - except sqlite3.DatabaseError: - return ConnectorList.get_all_foreign_keys(self.path) + Database.original_connect = Database.connect + Database.connect = connect - Database.original_get_all_foreign_keys = Database.get_all_foreign_keys - Database.get_all_foreign_keys = get_all_foreign_keys + async def execute_fn(self, fn): + def in_thread(): + conn = getattr(connections, self.name, None) + if not conn: + conn = self.connect() + if isinstance(conn, sqlite3.Connection): + self.ds._prepare_connection(conn, self.name) + setattr(connections, self.name, conn) + return fn(conn) - async def table_counts(self, *args, **kwargs): - counts = await self.original_table_counts(**kwargs) - # If all tables has None as counts, an error had ocurred - if len(list(filter(lambda table_count: table_count is not None, counts.values()))) == 0: - return ConnectorList.table_counts(self.path, *args, **kwargs) + return await asyncio.get_event_loop().run_in_executor( + self.ds.executor, in_thread + ) - Database.original_table_counts = Database.table_counts - Database.table_counts = table_counts + Database.original_execute_fn = Database.execute_fn + Database.execute_fn = execute_fn