Monkey patching for original Datasette
author Javier Sancho <jsf@jsancho.org>
Thu, 4 Oct 2018 09:16:04 +0000 (11:16 +0200)
committer Javier Sancho <jsf@jsancho.org>
Thu, 4 Oct 2018 09:16:04 +0000 (11:16 +0200)
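
The patched inspect() and execute() below fall back to a `connectors` helper (read from external connector packages further down in this module) whenever a database file is not SQLite. As a rough sketch, inferred only from the calls in this diff, that helper is expected to look something like the following; the function bodies and names here are placeholders, not part of the commit:

    # Hypothetical shape of the `connectors` helper used by the patched methods
    def inspect(path):
        # Return (tables, views, dbtype) metadata for the file at `path`
        raise NotImplementedError

    def connect(path, dbtype):
        # Return a connection object whose execute() yields
        # (rows, truncated, description) for a SQL query
        raise NotImplementedError
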
datasette_connectors/__init__.py

index 9cecd042fef40b949f8ecd1852e8ac94b79eddf5..43f11016001e7a53889cf9a2f1fed2b139f16dee 100644 (file)
@@ -1,14 +1,92 @@
+import asyncio
 import datasette
+from datasette.app import connections
 from datasette.cli import cli
+from datasette.inspect import inspect_hash
+from datasette.utils import Results
+from pathlib import Path
+import sqlite3
 
 
 # Monkey patching for original Datasette
-def init(self, *args, **kwargs):
-    print("Test")
-    self.original_init(*args, **kwargs)
+def inspect(self):
+    " Inspect the database and return a dictionary of table metadata "
+    if self._inspect:
+        return self._inspect
 
-datasette.app.Datasette.original_init = datasette.app.Datasette.__init__
-datasette.app.Datasette.__init__ = init
+    _inspect = {}
+    files = self.files
+
+    for filename in files:
+        # Narrow self.files so the original inspect sees only this file
+        self.files = (filename,)
+        path = Path(filename)
+        name = path.stem
+        if name in _inspect:
+            raise Exception("Multiple files with the same stem %s" % name)
+        try:
+            # Reset the inspect cache so original_inspect() re-inspects this
+            # file instead of returning the result cached for a previous one
+            self._inspect = None
+            _inspect[name] = self.original_inspect()[name]
+        except sqlite3.DatabaseError:
+            tables, views, dbtype = connectors.inspect(path)
+            _inspect[name] = {
+                "hash": inspect_hash(path),
+                "file": str(path),
+                "dbtype": dbtype,
+                "tables": tables,
+                "views": views,
+            }
+
+    self.files = files
+    self._inspect = _inspect
+    return self._inspect
+
+datasette.app.Datasette.original_inspect = datasette.app.Datasette.inspect
+datasette.app.Datasette.inspect = inspect
+
+
+async def execute(self, db_name, sql, params=None, truncate=False, custom_time_limit=None, page_size=None):
+    """Executes sql against db_name in a thread"""
+    page_size = page_size or self.page_size
+
+    # Decide whether this database is handled by SQLite or by an external connector
+    def is_sqlite3_conn():
+        conn = getattr(connections, db_name, None)
+        if not conn:
+            info = self.inspect()[db_name]
+            return info.get('dbtype', 'sqlite3') == 'sqlite3'
+        else:
+            return isinstance(conn, sqlite3.Connection)
+
+    # Run the query through the external connector; called from the thread pool
+    def sql_operation_in_thread():
+        conn = getattr(connections, db_name, None)
+        if not conn:
+            info = self.inspect()[db_name]
+            conn = connectors.connect(info['file'], info['dbtype'])
+            setattr(connections, db_name, conn)
+
+        rows, truncated, description = conn.execute(
+            sql,
+            params or {},
+            truncate=truncate,
+            page_size=page_size,
+            max_returned_rows=self.max_returned_rows,
+        )
+        return Results(rows, truncated, description)
+
+    # SQLite databases keep the original execution path; others use their connector
+    if is_sqlite3_conn():
+        return await self.original_execute(db_name, sql, params=params, truncate=truncate, custom_time_limit=custom_time_limit, page_size=page_size)
+    else:
+        return await asyncio.get_event_loop().run_in_executor(
+            self.executor, sql_operation_in_thread
+        )
+
+datasette.app.Datasette.original_execute = datasette.app.Datasette.execute
+datasette.app.Datasette.execute = execute
 
 
 # Read external database connectors
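
On the non-SQLite path, the patched execute() expects the connection returned by connectors.connect() to behave roughly like the sketch below: its execute() returns the (rows, truncated, description) triple that gets wrapped in datasette.utils.Results, and the call runs inside self.executor so the event loop is never blocked. The class name and body are illustrative only:

    class ExampleConnection:
        # Hypothetical connection object; only the keyword arguments actually
        # passed by the patched execute() are shown
        def execute(self, sql, params=None, truncate=False,
                    page_size=None, max_returned_rows=None):
            rows = []           # result rows
            truncated = False   # True if max_returned_rows cut the result short
            description = []    # column metadata, cursor.description-style
            return rows, truncated, description

Note that custom_time_limit is only forwarded on the SQLite branch; the connector path ignores it.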