]> git.jsancho.org Git - datasette-connectors.git/blob - datasette_connectors/monkey.py
5fd04703bf676b047a04c6a493341ab16ade2abf
[datasette-connectors.git] / datasette_connectors / monkey.py
1 import asyncio
2 import threading
3 import sqlite3
4
5 import datasette.views.base
6 from datasette.tracer import trace
7 from datasette.database import Database
8 from datasette.database import Results
9
10 from .connectors import ConnectorList
11
# Thread-local cache of open connections. The patched ``execute_fn`` stores
# one connection per database name on it (``setattr(connections, name, conn)``),
# so each executor thread reuses its own connection.
connections = threading.local()
13
14
def patch_datasette():
    """
    Monkey patching for original Datasette

    Replaces ``table_columns``, ``primary_keys``, ``fts_table``, ``connect``
    and ``execute_fn`` on ``datasette.database.Database`` with wrappers that
    fall back to ``ConnectorList`` when the original SQLite code path raises
    ``sqlite3.DatabaseError``. Every replaced method stays reachable as
    ``Database.original_<name>``.

    The function is idempotent: once the patches are installed, further
    calls return immediately. Patching twice would store the wrappers
    themselves as ``original_<name>`` and make every wrapper call itself
    recursively.
    """
    if getattr(Database, "_datasette_connectors_patched", False):
        return

    async def table_columns(self, table):
        """Return the columns of ``table``, asking the connector if SQLite fails."""
        try:
            return await self.original_table_columns(table)
        except sqlite3.DatabaseError:
            return ConnectorList.table_columns(self.path, table)

    Database.original_table_columns = Database.table_columns
    Database.table_columns = table_columns

    async def primary_keys(self, table):
        """Return the primary keys of ``table``, asking the connector if SQLite fails."""
        try:
            return await self.original_primary_keys(table)
        except sqlite3.DatabaseError:
            return ConnectorList.primary_keys(self.path, table)

    Database.original_primary_keys = Database.primary_keys
    Database.primary_keys = primary_keys

    async def fts_table(self, table):
        """Return the FTS table for ``table``, asking the connector if SQLite fails."""
        try:
            return await self.original_fts_table(table)
        except sqlite3.DatabaseError:
            return ConnectorList.fts_table(self.path, table)

    Database.original_fts_table = Database.fts_table
    Database.fts_table = fts_table

    def connect(self, write=False):
        """
        Open a connection to the database file.

        The original ``connect`` is tried first and probed with a query on
        ``sqlite_master``. If either step raises ``sqlite3.DatabaseError``
        the connection comes from ``ConnectorList.connect(self.path)``
        instead; ``write`` is not forwarded on that path.
        """
        conn = None
        try:
            # Check if it's a sqlite database
            conn = self.original_connect(write=write)
            conn.execute("select name from sqlite_master where type='table'")
        except sqlite3.DatabaseError:
            # The SQLite handle may have been opened before the probe failed;
            # close it explicitly instead of leaking it.
            if conn is not None:
                conn.close()
            return ConnectorList.connect(self.path)
        return conn

    Database.original_connect = Database.connect
    Database.connect = connect

    async def execute_fn(self, fn):
        """
        Run ``fn(conn)`` in the Datasette executor and return its result.

        The connection is created on first use in each thread and cached on
        the module-level thread-local ``connections`` under ``self.name``.
        """

        def in_thread():
            conn = getattr(connections, self.name, None)
            if not conn:
                conn = self.connect()
                # Only real SQLite connections go through Datasette's
                # connection preparation; connector connections are used
                # as returned.
                if isinstance(conn, sqlite3.Connection):
                    self.ds._prepare_connection(conn, self.name)
                setattr(connections, self.name, conn)
            return fn(conn)

        return await asyncio.get_event_loop().run_in_executor(
            self.ds.executor, in_thread
        )

    Database.original_execute_fn = Database.execute_fn
    Database.execute_fn = execute_fn

    # Mark the class only after every patch is in place so a failure part-way
    # through does not block a later retry.
    Database._datasette_connectors_patched = True