X-Git-Url: https://git.jsancho.org/?p=datasette-connectors.git;a=blobdiff_plain;f=datasette_connectors%2Fconnectors.py;h=f09727f93bc07f59e60343ea3fcf78867e6ac1da;hp=aa8f824cf4d0a51234107c7801e0fd2d3fa8d7a4;hb=5c00383b9044ca27de9c51a511962ffad65ed5f3;hpb=52416a749fac092a032a8b5239e477dd68180dfa diff --git a/datasette_connectors/connectors.py b/datasette_connectors/connectors.py index aa8f824..f09727f 100644 --- a/datasette_connectors/connectors.py +++ b/datasette_connectors/connectors.py @@ -1,5 +1,10 @@ import pkg_resources import functools +import re +import sqlite3 + +from .row import Row + db_connectors = {} @@ -26,76 +31,216 @@ class ConnectorList: def add_connector(name, connector): db_connectors[name] = connector - @staticmethod - @for_each_connector - def table_names(connector, path): - return connector.table_names(path) - - @staticmethod - @for_each_connector - def hidden_table_names(connector, path): - return connector.hidden_table_names(path) + class DatabaseNotSupported(Exception): + pass @staticmethod - @for_each_connector - def view_names(connector, path): - return connector.view_names(path) + def connect(path): + for connector in db_connectors.values(): + try: + return connector.connect(path) + except: + pass + else: + raise ConnectorList.DatabaseNotSupported + + +class Connection: + def __init__(self, path, connector): + self.path = path + self.connector = connector + + def execute(self, *args, **kwargs): + cursor = Cursor(self) + cursor.execute(*args, **kwargs) + return cursor + + def cursor(self): + return Cursor(self) + + def set_progress_handler(self, handler, n): + pass + + +class OperationalError(Exception): + pass + + +class Cursor: + class QueryNotSupported(Exception): + pass + + def __init__(self, conn): + self.conn = conn + self.connector = conn.connector + self.rows = [] + self.description = () + + def execute( + self, + sql, + params=None, + truncate=False, + custom_time_limit=None, + page_size=None, + log_sql_errors=True, + ): + if params is None: + params = {} + results = [] + truncated = False + description = () + + # Normalize sql + sql = sql.strip() + sql = ' '.join(sql.split()) + + if sql == "select name from sqlite_master where type='table'" or \ + sql == "select name from sqlite_master where type=\"table\"": + results = [{'name': name} for name in self.connector.table_names()] + elif sql == "select name from sqlite_master where rootpage = 0 and sql like '%VIRTUAL TABLE%USING FTS%'": + results = [{'name': name} for name in self.connector.hidden_table_names()] + elif sql == 'select 1 from sqlite_master where tbl_name = "geometry_columns"': + if self.connector.detect_spatialite(): + results = [{'1': '1'}] + elif sql == "select name from sqlite_master where type='view'": + results = [{'name': name} for name in self.connector.view_names()] + elif sql.startswith("select count(*) from ["): + match = re.search(r'select count\(\*\) from \[(.*)\]', sql) + results = [{'count(*)': self.connector.table_count(match.group(1))}] + elif sql.startswith("select count(*) from "): + match = re.search(r'select count\(\*\) from (.*)', sql) + results = [{'count(*)': self.connector.table_count(match.group(1))}] + elif sql.startswith("PRAGMA table_info("): + match = re.search(r'PRAGMA table_info\((.*)\)', sql) + results = self.connector.table_info(match.group(1)) + elif sql.startswith("select name from sqlite_master where rootpage = 0 and ( sql like \'%VIRTUAL TABLE%USING FTS%content="): + match = re.search(r'select name from sqlite_master where rootpage = 0 and \( sql like \'%VIRTUAL TABLE%USING FTS%content="(.*)"', sql) + if self.connector.detect_fts(match.group(1)): + results = [{'name': match.group(1)}] + elif sql.startswith("PRAGMA foreign_key_list(["): + match = re.search(r'PRAGMA foreign_key_list\(\[(.*)\]\)', sql) + results = self.connector.foreign_keys(match.group(1)) + elif sql == "select 1 from sqlite_master where type='table' and name=?": + if self.connector.table_exists(params[0]): + results = [{'1': '1'}] + elif sql == "select sql from sqlite_master where name = :n and type=:t": + results = [{'sql': self.connector.table_definition(params['t'], params['n'])}] + elif sql == "select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null": + results = [{'sql': sql} for sql in self.connector.indices_definition(params['n'])] + else: + try: + results, truncated, description = \ + self.connector.execute( + sql, + params=params, + truncate=truncate, + custom_time_limit=custom_time_limit, + page_size=page_size, + log_sql_errors=log_sql_errors, + ) + except OperationalError as ex: + raise sqlite3.OperationalError(*ex.args) - @staticmethod - @for_each_connector - def table_columns(connector, path, table): - return connector.table_columns(path, table) + self.rows = [Row(result) for result in results] + self.description = description - @staticmethod - @for_each_connector - def primary_keys(connector, path, table): - return connector.primary_keys(path, table) + def fetchall(self): + return self.rows - @staticmethod - @for_each_connector - def fts_table(connector, path, table): - return connector.fts_table(path, table) + def fetchmany(self, max): + return self.rows[:max] - @staticmethod - @for_each_connector - def get_all_foreign_keys(connector, path): - return connector.get_all_foreign_keys(path) - - @staticmethod - @for_each_connector - def table_counts(connector, path, *args, **kwargs): - return connector.table_counts(path, *args, **kwargs) + def __getitem__(self, index): + return self.rows[index] class Connector: - @staticmethod - def table_names(path): - return [] - - @staticmethod - def hidden_table_names(path): - return [] - - @staticmethod - def view_names(path): - return [] - - @staticmethod - def table_columns(path, table): - return [] - - @staticmethod - def primary_keys(path, table): - return [] - - @staticmethod - def fts_table(path, table): - return None - - @staticmethod - def get_all_foreign_keys(path): - return {} - - @staticmethod - def table_counts(path, *args, **kwargs): - return {} + connector_type = None + connection_class = Connection + + def connect(self, path): + return self.connection_class(path, self) + + def table_names(self): + """ + Return a list of table names + """ + raise NotImplementedError + + def hidden_table_names(self): + raise NotImplementedError + + def detect_spatialite(self): + """ + Return boolean indicating if geometry_columns exists + """ + raise NotImplementedError + + def view_names(self): + """ + Return a list of view names + """ + raise NotImplementedError + + def table_count(self, table_name): + """ + Return an integer with the rows count of the table + """ + raise NotImplementedError + + def table_info(self, table_name): + """ + Return a list of dictionaries with columns description, with format: + [ + { + 'idx': 0, + 'name': 'column1', + 'primary_key': False, + }, + ... + ] + """ + raise NotImplementedError + + def detect_fts(self, table_name): + """ + Return boolean indicating if table has a corresponding FTS virtual table + """ + raise NotImplementedError + + def foreign_keys(self, table_name): + """ + Return a list of dictionaries with foreign keys description + id, seq, table_name, from_, to_, on_update, on_delete, match + """ + raise NotImplementedError + + def table_exists(self, table_name): + """ + Return boolean indicating if table exists in the database + """ + raise NotImplementedError + + def table_definition(self, table_type, table_name): + """ + Return string with a 'CREATE TABLE' sql definition + """ + raise NotImplementedError + + def indices_definition(self, table_name): + """ + Return a list of strings with 'CREATE INDEX' sql definitions + """ + raise NotImplementedError + + def execute( + self, + sql, + params=None, + truncate=False, + custom_time_limit=None, + page_size=None, + log_sql_errors=True, + ): + raise NotImplementedError