]> git.jsancho.org Git - datasette-connectors.git/blobdiff - datasette_connectors/connectors.py
Overwriting Connector class is enough to operate with
[datasette-connectors.git] / datasette_connectors / connectors.py
index aa8f824cf4d0a51234107c7801e0fd2d3fa8d7a4..f09727f93bc07f59e60343ea3fcf78867e6ac1da 100644 (file)
@@ -1,5 +1,10 @@
 import pkg_resources
 import functools
+import re
+import sqlite3
+
+from .row import Row
+
 
 db_connectors = {}
 
@@ -26,76 +31,216 @@ class ConnectorList:
     def add_connector(name, connector):
         db_connectors[name] = connector
 
-    @staticmethod
-    @for_each_connector
-    def table_names(connector, path):
-        return connector.table_names(path)
-
-    @staticmethod
-    @for_each_connector
-    def hidden_table_names(connector, path):
-        return connector.hidden_table_names(path)
+    class DatabaseNotSupported(Exception):
+        pass
 
     @staticmethod
-    @for_each_connector
-    def view_names(connector, path):
-        return connector.view_names(path)
+    def connect(path):
+        for connector in db_connectors.values():
+            try:
+                return connector.connect(path)
+            except:
+                pass
+        else:
+            raise ConnectorList.DatabaseNotSupported
+
+
+class Connection:
+    def __init__(self, path, connector):
+        self.path = path
+        self.connector = connector
+
+    def execute(self, *args, **kwargs):
+        cursor = Cursor(self)
+        cursor.execute(*args, **kwargs)
+        return cursor
+
+    def cursor(self):
+        return Cursor(self)
+
+    def set_progress_handler(self, handler, n):
+        pass
+
+
+class OperationalError(Exception):
+    pass
+
+
+class Cursor:
+    class QueryNotSupported(Exception):
+        pass
+
+    def __init__(self, conn):
+        self.conn = conn
+        self.connector = conn.connector
+        self.rows = []
+        self.description = ()
+
+    def execute(
+        self,
+        sql,
+        params=None,
+        truncate=False,
+        custom_time_limit=None,
+        page_size=None,
+        log_sql_errors=True,
+    ):
+        if params is None:
+            params = {}
+        results = []
+        truncated = False
+        description = ()
+
+        # Normalize sql
+        sql = sql.strip()
+        sql = ' '.join(sql.split())
+
+        if sql == "select name from sqlite_master where type='table'" or \
+           sql == "select name from sqlite_master where type=\"table\"":
+            results = [{'name': name} for name in self.connector.table_names()]
+        elif sql == "select name from sqlite_master where rootpage = 0 and sql like '%VIRTUAL TABLE%USING FTS%'":
+            results = [{'name': name} for name in self.connector.hidden_table_names()]
+        elif sql == 'select 1 from sqlite_master where tbl_name = "geometry_columns"':
+            if self.connector.detect_spatialite():
+                results = [{'1': '1'}]
+        elif sql == "select name from sqlite_master where type='view'":
+            results = [{'name': name} for name in self.connector.view_names()]
+        elif sql.startswith("select count(*) from ["):
+            match = re.search(r'select count\(\*\) from \[(.*)\]', sql)
+            results = [{'count(*)': self.connector.table_count(match.group(1))}]
+        elif sql.startswith("select count(*) from "):
+            match = re.search(r'select count\(\*\) from (.*)', sql)
+            results = [{'count(*)': self.connector.table_count(match.group(1))}]
+        elif sql.startswith("PRAGMA table_info("):
+            match = re.search(r'PRAGMA table_info\((.*)\)', sql)
+            results = self.connector.table_info(match.group(1))
+        elif sql.startswith("select name from sqlite_master where rootpage = 0 and ( sql like \'%VIRTUAL TABLE%USING FTS%content="):
+            match = re.search(r'select name from sqlite_master where rootpage = 0 and \( sql like \'%VIRTUAL TABLE%USING FTS%content="(.*)"', sql)
+            if self.connector.detect_fts(match.group(1)):
+                results = [{'name': match.group(1)}]
+        elif sql.startswith("PRAGMA foreign_key_list(["):
+            match = re.search(r'PRAGMA foreign_key_list\(\[(.*)\]\)', sql)
+            results = self.connector.foreign_keys(match.group(1))
+        elif sql == "select 1 from sqlite_master where type='table' and name=?":
+            if self.connector.table_exists(params[0]):
+                results = [{'1': '1'}]
+        elif sql == "select sql from sqlite_master where name = :n and type=:t":
+            results = [{'sql': self.connector.table_definition(params['t'], params['n'])}]
+        elif sql == "select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null":
+            results = [{'sql': sql} for sql in self.connector.indices_definition(params['n'])]
+        else:
+            try:
+                results, truncated, description = \
+                    self.connector.execute(
+                        sql,
+                        params=params,
+                        truncate=truncate,
+                        custom_time_limit=custom_time_limit,
+                        page_size=page_size,
+                        log_sql_errors=log_sql_errors,
+                    )
+            except OperationalError as ex:
+                raise sqlite3.OperationalError(*ex.args)
 
-    @staticmethod
-    @for_each_connector
-    def table_columns(connector, path, table):
-        return connector.table_columns(path, table)
+        self.rows = [Row(result) for result in results]
+        self.description = description
 
-    @staticmethod
-    @for_each_connector
-    def primary_keys(connector, path, table):
-        return connector.primary_keys(path, table)
+    def fetchall(self):
+        return self.rows
 
-    @staticmethod
-    @for_each_connector
-    def fts_table(connector, path, table):
-        return connector.fts_table(path, table)
+    def fetchmany(self, max):
+        return self.rows[:max]
 
-    @staticmethod
-    @for_each_connector
-    def get_all_foreign_keys(connector, path):
-        return connector.get_all_foreign_keys(path)
-
-    @staticmethod
-    @for_each_connector
-    def table_counts(connector, path, *args, **kwargs):
-        return connector.table_counts(path, *args, **kwargs)
+    def __getitem__(self, index):
+        return self.rows[index]
 
 
 class Connector:
-    @staticmethod
-    def table_names(path):
-        return []
-
-    @staticmethod
-    def hidden_table_names(path):
-        return []
-
-    @staticmethod
-    def view_names(path):
-        return []
-
-    @staticmethod
-    def table_columns(path, table):
-        return []
-
-    @staticmethod
-    def primary_keys(path, table):
-        return []
-
-    @staticmethod
-    def fts_table(path, table):
-        return None
-
-    @staticmethod
-    def get_all_foreign_keys(path):
-        return {}
-
-    @staticmethod
-    def table_counts(path, *args, **kwargs):
-        return {}
+    connector_type = None
+    connection_class = Connection
+
+    def connect(self, path):
+        return self.connection_class(path, self)
+
+    def table_names(self):
+        """
+        Return a list of table names
+        """
+        raise NotImplementedError
+
+    def hidden_table_names(self):
+        raise NotImplementedError
+
+    def detect_spatialite(self):
+        """
+        Return boolean indicating if geometry_columns exists
+        """
+        raise NotImplementedError
+
+    def view_names(self):
+        """
+        Return a list of view names
+        """
+        raise NotImplementedError
+
+    def table_count(self, table_name):
+        """
+        Return an integer with the rows count of the table
+        """
+        raise NotImplementedError
+
+    def table_info(self, table_name):
+        """
+        Return a list of dictionaries with columns description, with format:
+        [
+            {
+                'idx': 0,
+                'name': 'column1',
+                'primary_key': False,
+            },
+            ...
+        ]
+        """
+        raise NotImplementedError
+
+    def detect_fts(self, table_name):
+        """
+        Return boolean indicating if table has a corresponding FTS virtual table
+        """
+        raise NotImplementedError
+
+    def foreign_keys(self, table_name):
+        """
+        Return a list of dictionaries with foreign keys description
+        id, seq, table_name, from_, to_, on_update, on_delete, match
+        """
+        raise NotImplementedError
+
+    def table_exists(self, table_name):
+        """
+        Return boolean indicating if table exists in the database
+        """
+        raise NotImplementedError
+
+    def table_definition(self, table_type, table_name):
+        """
+        Return string with a 'CREATE TABLE' sql definition
+        """
+        raise NotImplementedError
+
+    def indices_definition(self, table_name):
+        """
+        Return a list of strings with 'CREATE INDEX' sql definitions
+        """
+        raise NotImplementedError
+
+    def execute(
+        self,
+        sql,
+        params=None,
+        truncate=False,
+        custom_time_limit=None,
+        page_size=None,
+        log_sql_errors=True,
+    ):
+        raise NotImplementedError