Clean code
authorJavier Sancho <jsf@jsancho.org>
Fri, 18 Sep 2020 09:10:19 +0000 (11:10 +0200)
committerJavier Sancho <jsf@jsancho.org>
Fri, 18 Sep 2020 09:10:19 +0000 (11:10 +0200)
datasette_connectors/__init__.py
datasette_connectors/connection.py [new file with mode: 0644]
datasette_connectors/connectors.py
datasette_connectors/cursor.py [new file with mode: 0644]
datasette_connectors/monkey.py

index 21cf96b66f2dfa9918c082f14beb4d61a71f0a04..834e398f662cac1022aee137d479de7e8a10b098 100644 (file)
@@ -1 +1,2 @@
-from .connectors import Connector, OperationalError
+from .connectors import Connector
+from .cursor import OperationalError
diff --git a/datasette_connectors/connection.py b/datasette_connectors/connection.py
new file mode 100644 (file)
index 0000000..7761506
--- /dev/null
@@ -0,0 +1,18 @@
+from .cursor import Cursor
+
+
+class Connection:
+    def __init__(self, path, connector):
+        self.path = path
+        self.connector = connector
+
+    def execute(self, *args, **kwargs):
+        cursor = Cursor(self)
+        cursor.execute(*args, **kwargs)
+        return cursor
+
+    def cursor(self):
+        return Cursor(self)
+
+    def set_progress_handler(self, handler, n):
+        pass
index f09727f93bc07f59e60343ea3fcf78867e6ac1da..bdba2158c3fb958c6b4aadbfb3c2aaa255ce16f5 100644 (file)
@@ -1,25 +1,8 @@
-import pkg_resources
-import functools
-import re
-import sqlite3
-
-from .row import Row
+from .connection import Connection
 
 
 db_connectors = {}
 
-def for_each_connector(func):
-    @functools.wraps(func)
-    def wrapper_for_each_connector(*args, **kwargs):
-        for connector in db_connectors.values():
-            try:
-                return func(connector, *args, **kwargs)
-            except:
-                pass
-        else:
-            raise Exception("No database connector found!!")
-    return wrapper_for_each_connector
-
 
 class ConnectorList:
     @staticmethod
@@ -45,116 +28,6 @@ class ConnectorList:
             raise ConnectorList.DatabaseNotSupported
 
 
-class Connection:
-    def __init__(self, path, connector):
-        self.path = path
-        self.connector = connector
-
-    def execute(self, *args, **kwargs):
-        cursor = Cursor(self)
-        cursor.execute(*args, **kwargs)
-        return cursor
-
-    def cursor(self):
-        return Cursor(self)
-
-    def set_progress_handler(self, handler, n):
-        pass
-
-
-class OperationalError(Exception):
-    pass
-
-
-class Cursor:
-    class QueryNotSupported(Exception):
-        pass
-
-    def __init__(self, conn):
-        self.conn = conn
-        self.connector = conn.connector
-        self.rows = []
-        self.description = ()
-
-    def execute(
-        self,
-        sql,
-        params=None,
-        truncate=False,
-        custom_time_limit=None,
-        page_size=None,
-        log_sql_errors=True,
-    ):
-        if params is None:
-            params = {}
-        results = []
-        truncated = False
-        description = ()
-
-        # Normalize sql
-        sql = sql.strip()
-        sql = ' '.join(sql.split())
-
-        if sql == "select name from sqlite_master where type='table'" or \
-           sql == "select name from sqlite_master where type=\"table\"":
-            results = [{'name': name} for name in self.connector.table_names()]
-        elif sql == "select name from sqlite_master where rootpage = 0 and sql like '%VIRTUAL TABLE%USING FTS%'":
-            results = [{'name': name} for name in self.connector.hidden_table_names()]
-        elif sql == 'select 1 from sqlite_master where tbl_name = "geometry_columns"':
-            if self.connector.detect_spatialite():
-                results = [{'1': '1'}]
-        elif sql == "select name from sqlite_master where type='view'":
-            results = [{'name': name} for name in self.connector.view_names()]
-        elif sql.startswith("select count(*) from ["):
-            match = re.search(r'select count\(\*\) from \[(.*)\]', sql)
-            results = [{'count(*)': self.connector.table_count(match.group(1))}]
-        elif sql.startswith("select count(*) from "):
-            match = re.search(r'select count\(\*\) from (.*)', sql)
-            results = [{'count(*)': self.connector.table_count(match.group(1))}]
-        elif sql.startswith("PRAGMA table_info("):
-            match = re.search(r'PRAGMA table_info\((.*)\)', sql)
-            results = self.connector.table_info(match.group(1))
-        elif sql.startswith("select name from sqlite_master where rootpage = 0 and ( sql like \'%VIRTUAL TABLE%USING FTS%content="):
-            match = re.search(r'select name from sqlite_master where rootpage = 0 and \( sql like \'%VIRTUAL TABLE%USING FTS%content="(.*)"', sql)
-            if self.connector.detect_fts(match.group(1)):
-                results = [{'name': match.group(1)}]
-        elif sql.startswith("PRAGMA foreign_key_list(["):
-            match = re.search(r'PRAGMA foreign_key_list\(\[(.*)\]\)', sql)
-            results = self.connector.foreign_keys(match.group(1))
-        elif sql == "select 1 from sqlite_master where type='table' and name=?":
-            if self.connector.table_exists(params[0]):
-                results = [{'1': '1'}]
-        elif sql == "select sql from sqlite_master where name = :n and type=:t":
-            results = [{'sql': self.connector.table_definition(params['t'], params['n'])}]
-        elif sql == "select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null":
-            results = [{'sql': sql} for sql in self.connector.indices_definition(params['n'])]
-        else:
-            try:
-                results, truncated, description = \
-                    self.connector.execute(
-                        sql,
-                        params=params,
-                        truncate=truncate,
-                        custom_time_limit=custom_time_limit,
-                        page_size=page_size,
-                        log_sql_errors=log_sql_errors,
-                    )
-            except OperationalError as ex:
-                raise sqlite3.OperationalError(*ex.args)
-
-        self.rows = [Row(result) for result in results]
-        self.description = description
-
-    def fetchall(self):
-        return self.rows
-
-    def fetchmany(self, max):
-        return self.rows[:max]
-
-    def __getitem__(self, index):
-        return self.rows[index]
-
-
 class Connector:
     connector_type = None
     connection_class = Connection
diff --git a/datasette_connectors/cursor.py b/datasette_connectors/cursor.py
new file mode 100644 (file)
index 0000000..6facf54
--- /dev/null
@@ -0,0 +1,97 @@
+import re
+import sqlite3
+
+from .row import Row
+
+
+class OperationalError(Exception):
+    pass
+
+
+class Cursor:
+    class QueryNotSupported(Exception):
+        pass
+
+    def __init__(self, conn):
+        self.conn = conn
+        self.connector = conn.connector
+        self.rows = []
+        self.description = ()
+
+    def execute(
+        self,
+        sql,
+        params=None,
+        truncate=False,
+        custom_time_limit=None,
+        page_size=None,
+        log_sql_errors=True,
+    ):
+        if params is None:
+            params = {}
+        results = []
+        truncated = False
+        description = ()
+
+        # Normalize sql
+        sql = sql.strip()
+        sql = ' '.join(sql.split())
+
+        if sql == "select name from sqlite_master where type='table'" or \
+           sql == "select name from sqlite_master where type=\"table\"":
+            results = [{'name': name} for name in self.connector.table_names()]
+        elif sql == "select name from sqlite_master where rootpage = 0 and sql like '%VIRTUAL TABLE%USING FTS%'":
+            results = [{'name': name} for name in self.connector.hidden_table_names()]
+        elif sql == 'select 1 from sqlite_master where tbl_name = "geometry_columns"':
+            if self.connector.detect_spatialite():
+                results = [{'1': '1'}]
+        elif sql == "select name from sqlite_master where type='view'":
+            results = [{'name': name} for name in self.connector.view_names()]
+        elif sql.startswith("select count(*) from ["):
+            match = re.search(r'select count\(\*\) from \[(.*)\]', sql)
+            results = [{'count(*)': self.connector.table_count(match.group(1))}]
+        elif sql.startswith("select count(*) from "):
+            match = re.search(r'select count\(\*\) from (.*)', sql)
+            results = [{'count(*)': self.connector.table_count(match.group(1))}]
+        elif sql.startswith("PRAGMA table_info("):
+            match = re.search(r'PRAGMA table_info\((.*)\)', sql)
+            results = self.connector.table_info(match.group(1))
+        elif sql.startswith("select name from sqlite_master where rootpage = 0 and ( sql like \'%VIRTUAL TABLE%USING FTS%content="):
+            match = re.search(r'select name from sqlite_master where rootpage = 0 and \( sql like \'%VIRTUAL TABLE%USING FTS%content="(.*)"', sql)
+            if self.connector.detect_fts(match.group(1)):
+                results = [{'name': match.group(1)}]
+        elif sql.startswith("PRAGMA foreign_key_list(["):
+            match = re.search(r'PRAGMA foreign_key_list\(\[(.*)\]\)', sql)
+            results = self.connector.foreign_keys(match.group(1))
+        elif sql == "select 1 from sqlite_master where type='table' and name=?":
+            if self.connector.table_exists(params[0]):
+                results = [{'1': '1'}]
+        elif sql == "select sql from sqlite_master where name = :n and type=:t":
+            results = [{'sql': self.connector.table_definition(params['t'], params['n'])}]
+        elif sql == "select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null":
+            results = [{'sql': sql} for sql in self.connector.indices_definition(params['n'])]
+        else:
+            try:
+                results, truncated, description = \
+                    self.connector.execute(
+                        sql,
+                        params=params,
+                        truncate=truncate,
+                        custom_time_limit=custom_time_limit,
+                        page_size=page_size,
+                        log_sql_errors=log_sql_errors,
+                    )
+            except OperationalError as ex:
+                raise sqlite3.OperationalError(*ex.args)
+
+        self.rows = [Row(result) for result in results]
+        self.description = description
+
+    def fetchall(self):
+        return self.rows
+
+    def fetchmany(self, max):
+        return self.rows[:max]
+
+    def __getitem__(self, index):
+        return self.rows[index]
index 5fd04703bf676b047a04c6a493341ab16ade2abf..0fb4e1c7b8d583d4e5e353952bb325122c679c52 100644 (file)
@@ -17,36 +17,6 @@ def patch_datasette():
     Monkey patching for original Datasette
     """
 
-    async def table_columns(self, table):
-        try:
-            return await self.original_table_columns(table)
-        except sqlite3.DatabaseError:
-            return ConnectorList.table_columns(self.path, table)
-
-    Database.original_table_columns = Database.table_columns
-    Database.table_columns = table_columns
-
-
-    async def primary_keys(self, table):
-        try:
-            return await self.original_primary_keys(table)
-        except sqlite3.DatabaseError:
-            return ConnectorList.primary_keys(self.path, table)
-
-    Database.original_primary_keys = Database.primary_keys
-    Database.primary_keys = primary_keys
-
-
-    async def fts_table(self, table):
-        try:
-            return await self.original_fts_table(table)
-        except sqlite3.DatabaseError:
-            return ConnectorList.fts_table(self.path, table)
-
-    Database.original_fts_table = Database.fts_table
-    Database.fts_table = fts_table
-
-
     def connect(self, write=False):
         try:
             # Check if it's a sqlite database