]> git.jsancho.org Git - datasette-connectors.git/blobdiff - datasette_connectors/monkey.py
Overwriting Connector class is enough to operate with
[datasette-connectors.git] / datasette_connectors / monkey.py
index 6c656e45882f7efea8aa9bdcf1c3db0daa77fd56..5fd04703bf676b047a04c6a493341ab16ade2abf 100644 (file)
@@ -1,7 +1,11 @@
+import asyncio
 import threading
 import sqlite3
+
 import datasette.views.base
+from datasette.tracer import trace
 from datasette.database import Database
+from datasette.database import Results
 
 from .connectors import ConnectorList
 
@@ -13,36 +17,6 @@ def patch_datasette():
     Monkey patching for original Datasette
     """
 
-    async def table_names(self):
-        try:
-            return await self.original_table_names()
-        except sqlite3.DatabaseError:
-            return ConnectorList.table_names(self.path)
-
-    Database.original_table_names = Database.table_names
-    Database.table_names = table_names
-
-
-    async def hidden_table_names(self):
-        try:
-            return await self.original_hidden_table_names()
-        except sqlite3.DatabaseError:
-            return ConnectorList.hidden_table_names(self.path)
-
-    Database.original_hidden_table_names = Database.hidden_table_names
-    Database.hidden_table_names = hidden_table_names
-
-
-    async def view_names(self):
-        try:
-            return await self.original_view_names()
-        except sqlite3.DatabaseError:
-            return ConnectorList.view_names(self.path)
-
-    Database.original_view_names = Database.view_names
-    Database.view_names = view_names
-
-
     async def table_columns(self, table):
         try:
             return await self.original_table_columns(table)
@@ -73,21 +47,33 @@ def patch_datasette():
     Database.fts_table = fts_table
 
 
-    async def get_all_foreign_keys(self):
+    def connect(self, write=False):
         try:
-            return await self.original_get_all_foreign_keys()
+            # Check if it's a sqlite database
+            conn = self.original_connect(write=write)
+            conn.execute("select name from sqlite_master where type='table'")
+            return conn
         except sqlite3.DatabaseError:
-            return ConnectorList.get_all_foreign_keys(self.path)
+            conn = ConnectorList.connect(self.path)
+            return conn
+
+    Database.original_connect = Database.connect
+    Database.connect = connect
 
-    Database.original_get_all_foreign_keys = Database.get_all_foreign_keys
-    Database.get_all_foreign_keys = get_all_foreign_keys
 
+    async def execute_fn(self, fn):
+        def in_thread():
+            conn = getattr(connections, self.name, None)
+            if not conn:
+                conn = self.connect()
+                if isinstance(conn, sqlite3.Connection):
+                    self.ds._prepare_connection(conn, self.name)
+                setattr(connections, self.name, conn)
+            return fn(conn)
 
-    async def table_counts(self, *args, **kwargs):
-        counts = await self.original_table_counts(**kwargs)
-        # If all tables has None as counts, an error had ocurred
-        if len(list(filter(lambda table_count: table_count is not None, counts.values()))) == 0:
-            return ConnectorList.table_counts(self.path, *args, **kwargs)
+        return await asyncio.get_event_loop().run_in_executor(
+            self.ds.executor, in_thread
+        )
 
-    Database.original_table_counts = Database.table_counts
-    Database.table_counts = table_counts
+    Database.original_execute_fn = Database.execute_fn
+    Database.execute_fn = execute_fn