+import asyncio
import threading
import sqlite3
+
import datasette.views.base
+from datasette.tracer import trace
from datasette.database import Database
+from datasette.database import Results
from .connectors import ConnectorList
Monkey patching for original Datasette
"""
- async def table_names(self):
- try:
- return await self.original_table_names()
- except sqlite3.DatabaseError:
- return ConnectorList.table_names(self.path)
-
- Database.original_table_names = Database.table_names
- Database.table_names = table_names
-
-
- async def hidden_table_names(self):
- try:
- return await self.original_hidden_table_names()
- except sqlite3.DatabaseError:
- return ConnectorList.hidden_table_names(self.path)
-
- Database.original_hidden_table_names = Database.hidden_table_names
- Database.hidden_table_names = hidden_table_names
-
-
- async def view_names(self):
- try:
- return await self.original_view_names()
- except sqlite3.DatabaseError:
- return ConnectorList.view_names(self.path)
-
- Database.original_view_names = Database.view_names
- Database.view_names = view_names
-
-
async def table_columns(self, table):
try:
return await self.original_table_columns(table)
Database.fts_table = fts_table
- async def get_all_foreign_keys(self):
+ def connect(self, write=False):
try:
- return await self.original_get_all_foreign_keys()
+ # Check if it's a sqlite database
+ conn = self.original_connect(write=write)
+ conn.execute("select name from sqlite_master where type='table'")
+ return conn
except sqlite3.DatabaseError:
- return ConnectorList.get_all_foreign_keys(self.path)
+ conn = ConnectorList.connect(self.path)
+ return conn
+
+ Database.original_connect = Database.connect
+ Database.connect = connect
- Database.original_get_all_foreign_keys = Database.get_all_foreign_keys
- Database.get_all_foreign_keys = get_all_foreign_keys
+ async def execute_fn(self, fn):
+ def in_thread():
+ conn = getattr(connections, self.name, None)
+ if not conn:
+ conn = self.connect()
+ if isinstance(conn, sqlite3.Connection):
+ self.ds._prepare_connection(conn, self.name)
+ setattr(connections, self.name, conn)
+ return fn(conn)
- async def table_counts(self, *args, **kwargs):
- counts = await self.original_table_counts(**kwargs)
- # If all tables has None as counts, an error had ocurred
- if len(list(filter(lambda table_count: table_count is not None, counts.values()))) == 0:
- return ConnectorList.table_counts(self.path, *args, **kwargs)
+ return await asyncio.get_event_loop().run_in_executor(
+ self.ds.executor, in_thread
+ )
- Database.original_table_counts = Database.table_counts
- Database.table_counts = table_counts
+ Database.original_execute_fn = Database.execute_fn
+ Database.execute_fn = execute_fn