]> git.jsancho.org Git - datasette-connectors.git/blobdiff - datasette_connectors/monkey.py
Overwriting Connector class is enough to operate with
[datasette-connectors.git] / datasette_connectors / monkey.py
index e18175f38889f69b8eabc1590f9fc680421b24f9..5fd04703bf676b047a04c6a493341ab16ade2abf 100644 (file)
@@ -1,12 +1,15 @@
 import asyncio
-import datasette
-from datasette.app import connections
-from datasette.inspect import inspect_hash
-from datasette.utils import Results
-from pathlib import Path
+import threading
 import sqlite3
 
-from . import connectors
+import datasette.views.base
+from datasette.tracer import trace
+from datasette.database import Database
+from datasette.database import Results
+
+from .connectors import ConnectorList
+
+connections = threading.local()
 
 
 def patch_datasette():
@@ -14,74 +17,63 @@ def patch_datasette():
     Monkey patching for original Datasette
     """
 
-    def inspect(self):
-        " Inspect the database and return a dictionary of table metadata "
-        if self._inspect:
-            return self._inspect
-
-        _inspect = {}
-        files = self.files
-
-        for filename in files:
-            self.files = (filename,)
-            path = Path(filename)
-            name = path.stem
-            if name in _inspect:
-                raise Exception("Multiple files with the same stem %s" % name)
-            try:
-                _inspect[name] = self.original_inspect()[name]
-            except sqlite3.DatabaseError:
-                tables, views, dbtype = connectors.inspect(path)
-                _inspect[name] = {
-                    "hash": inspect_hash(path),
-                    "file": str(path),
-                    "dbtype": dbtype,
-                    "tables": tables,
-                    "views": views,
-                }
-
-        self.files = files
-        self._inspect = _inspect
-        return self._inspect
-
-    datasette.app.Datasette.original_inspect = datasette.app.Datasette.inspect
-    datasette.app.Datasette.inspect = inspect
-
-
-    async def execute(self, db_name, sql, params=None, truncate=False, custom_time_limit=None, page_size=None):
-        """Executes sql against db_name in a thread"""
-        page_size = page_size or self.page_size
-
-        def is_sqlite3_conn():
-            conn = getattr(connections, db_name, None)
-            if not conn:
-                info = self.inspect()[db_name]
-                return info.get('dbtype', 'sqlite3') == 'sqlite3'
-            else:
-                return isinstance(conn, sqlite3.Connection)
+    async def table_columns(self, table):
+        try:
+            return await self.original_table_columns(table)
+        except sqlite3.DatabaseError:
+            return ConnectorList.table_columns(self.path, table)
+
+    Database.original_table_columns = Database.table_columns
+    Database.table_columns = table_columns
+
+
+    async def primary_keys(self, table):
+        try:
+            return await self.original_primary_keys(table)
+        except sqlite3.DatabaseError:
+            return ConnectorList.primary_keys(self.path, table)
+
+    Database.original_primary_keys = Database.primary_keys
+    Database.primary_keys = primary_keys
+
 
-        def sql_operation_in_thread():
-            conn = getattr(connections, db_name, None)
+    async def fts_table(self, table):
+        try:
+            return await self.original_fts_table(table)
+        except sqlite3.DatabaseError:
+            return ConnectorList.fts_table(self.path, table)
+
+    Database.original_fts_table = Database.fts_table
+    Database.fts_table = fts_table
+
+
+    def connect(self, write=False):
+        try:
+            # Check if it's a sqlite database
+            conn = self.original_connect(write=write)
+            conn.execute("select name from sqlite_master where type='table'")
+            return conn
+        except sqlite3.DatabaseError:
+            conn = ConnectorList.connect(self.path)
+            return conn
+
+    Database.original_connect = Database.connect
+    Database.connect = connect
+
+
+    async def execute_fn(self, fn):
+        def in_thread():
+            conn = getattr(connections, self.name, None)
             if not conn:
-                info = self.inspect()[db_name]
-                conn = connectors.connect(info['file'], info['dbtype'])
-                setattr(connections, db_name, conn)
-
-            rows, truncated, description = conn.execute(
-                sql,
-                params or {},
-                truncate=truncate,
-                page_size=page_size,
-                max_returned_rows=self.max_returned_rows,
-            )
-            return Results(rows, truncated, description)
-
-        if is_sqlite3_conn():
-            return await self.original_execute(db_name, sql, params=params, truncate=truncate, custom_time_limit=custom_time_limit, page_size=page_size)
-        else:
-            return await asyncio.get_event_loop().run_in_executor(
-                self.executor, sql_operation_in_thread
-            )
-
-    datasette.app.Datasette.original_execute = datasette.app.Datasette.execute
-    datasette.app.Datasette.execute = execute
+                conn = self.connect()
+                if isinstance(conn, sqlite3.Connection):
+                    self.ds._prepare_connection(conn, self.name)
+                setattr(connections, self.name, conn)
+            return fn(conn)
+
+        return await asyncio.get_event_loop().run_in_executor(
+            self.ds.executor, in_thread
+        )
+
+    Database.original_execute_fn = Database.execute_fn
+    Database.execute_fn = execute_fn