From 3cc49f23a9f3c0e8cb2b7eb707382c6ae708c1f4 Mon Sep 17 00:00:00 2001 From: Javier Sancho Date: Fri, 18 Sep 2020 11:10:19 +0200 Subject: [PATCH] Clean code --- datasette_connectors/__init__.py | 3 +- datasette_connectors/connection.py | 18 ++++ datasette_connectors/connectors.py | 129 +---------------------------- datasette_connectors/cursor.py | 97 ++++++++++++++++++++++ datasette_connectors/monkey.py | 30 ------- 5 files changed, 118 insertions(+), 159 deletions(-) create mode 100644 datasette_connectors/connection.py create mode 100644 datasette_connectors/cursor.py diff --git a/datasette_connectors/__init__.py b/datasette_connectors/__init__.py index 21cf96b..834e398 100644 --- a/datasette_connectors/__init__.py +++ b/datasette_connectors/__init__.py @@ -1 +1,2 @@ -from .connectors import Connector, OperationalError +from .connectors import Connector +from .cursor import OperationalError diff --git a/datasette_connectors/connection.py b/datasette_connectors/connection.py new file mode 100644 index 0000000..7761506 --- /dev/null +++ b/datasette_connectors/connection.py @@ -0,0 +1,18 @@ +from .cursor import Cursor + + +class Connection: + def __init__(self, path, connector): + self.path = path + self.connector = connector + + def execute(self, *args, **kwargs): + cursor = Cursor(self) + cursor.execute(*args, **kwargs) + return cursor + + def cursor(self): + return Cursor(self) + + def set_progress_handler(self, handler, n): + pass diff --git a/datasette_connectors/connectors.py b/datasette_connectors/connectors.py index f09727f..bdba215 100644 --- a/datasette_connectors/connectors.py +++ b/datasette_connectors/connectors.py @@ -1,25 +1,8 @@ -import pkg_resources -import functools -import re -import sqlite3 - -from .row import Row +from .connection import Connection db_connectors = {} -def for_each_connector(func): - @functools.wraps(func) - def wrapper_for_each_connector(*args, **kwargs): - for connector in db_connectors.values(): - try: - return func(connector, *args, **kwargs) - except: - pass - else: - raise Exception("No database connector found!!") - return wrapper_for_each_connector - class ConnectorList: @staticmethod @@ -45,116 +28,6 @@ class ConnectorList: raise ConnectorList.DatabaseNotSupported -class Connection: - def __init__(self, path, connector): - self.path = path - self.connector = connector - - def execute(self, *args, **kwargs): - cursor = Cursor(self) - cursor.execute(*args, **kwargs) - return cursor - - def cursor(self): - return Cursor(self) - - def set_progress_handler(self, handler, n): - pass - - -class OperationalError(Exception): - pass - - -class Cursor: - class QueryNotSupported(Exception): - pass - - def __init__(self, conn): - self.conn = conn - self.connector = conn.connector - self.rows = [] - self.description = () - - def execute( - self, - sql, - params=None, - truncate=False, - custom_time_limit=None, - page_size=None, - log_sql_errors=True, - ): - if params is None: - params = {} - results = [] - truncated = False - description = () - - # Normalize sql - sql = sql.strip() - sql = ' '.join(sql.split()) - - if sql == "select name from sqlite_master where type='table'" or \ - sql == "select name from sqlite_master where type=\"table\"": - results = [{'name': name} for name in self.connector.table_names()] - elif sql == "select name from sqlite_master where rootpage = 0 and sql like '%VIRTUAL TABLE%USING FTS%'": - results = [{'name': name} for name in self.connector.hidden_table_names()] - elif sql == 'select 1 from sqlite_master where tbl_name = "geometry_columns"': - if self.connector.detect_spatialite(): - results = [{'1': '1'}] - elif sql == "select name from sqlite_master where type='view'": - results = [{'name': name} for name in self.connector.view_names()] - elif sql.startswith("select count(*) from ["): - match = re.search(r'select count\(\*\) from \[(.*)\]', sql) - results = [{'count(*)': self.connector.table_count(match.group(1))}] - elif sql.startswith("select count(*) from "): - match = re.search(r'select count\(\*\) from (.*)', sql) - results = [{'count(*)': self.connector.table_count(match.group(1))}] - elif sql.startswith("PRAGMA table_info("): - match = re.search(r'PRAGMA table_info\((.*)\)', sql) - results = self.connector.table_info(match.group(1)) - elif sql.startswith("select name from sqlite_master where rootpage = 0 and ( sql like \'%VIRTUAL TABLE%USING FTS%content="): - match = re.search(r'select name from sqlite_master where rootpage = 0 and \( sql like \'%VIRTUAL TABLE%USING FTS%content="(.*)"', sql) - if self.connector.detect_fts(match.group(1)): - results = [{'name': match.group(1)}] - elif sql.startswith("PRAGMA foreign_key_list(["): - match = re.search(r'PRAGMA foreign_key_list\(\[(.*)\]\)', sql) - results = self.connector.foreign_keys(match.group(1)) - elif sql == "select 1 from sqlite_master where type='table' and name=?": - if self.connector.table_exists(params[0]): - results = [{'1': '1'}] - elif sql == "select sql from sqlite_master where name = :n and type=:t": - results = [{'sql': self.connector.table_definition(params['t'], params['n'])}] - elif sql == "select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null": - results = [{'sql': sql} for sql in self.connector.indices_definition(params['n'])] - else: - try: - results, truncated, description = \ - self.connector.execute( - sql, - params=params, - truncate=truncate, - custom_time_limit=custom_time_limit, - page_size=page_size, - log_sql_errors=log_sql_errors, - ) - except OperationalError as ex: - raise sqlite3.OperationalError(*ex.args) - - self.rows = [Row(result) for result in results] - self.description = description - - def fetchall(self): - return self.rows - - def fetchmany(self, max): - return self.rows[:max] - - def __getitem__(self, index): - return self.rows[index] - - class Connector: connector_type = None connection_class = Connection diff --git a/datasette_connectors/cursor.py b/datasette_connectors/cursor.py new file mode 100644 index 0000000..6facf54 --- /dev/null +++ b/datasette_connectors/cursor.py @@ -0,0 +1,97 @@ +import re +import sqlite3 + +from .row import Row + + +class OperationalError(Exception): + pass + + +class Cursor: + class QueryNotSupported(Exception): + pass + + def __init__(self, conn): + self.conn = conn + self.connector = conn.connector + self.rows = [] + self.description = () + + def execute( + self, + sql, + params=None, + truncate=False, + custom_time_limit=None, + page_size=None, + log_sql_errors=True, + ): + if params is None: + params = {} + results = [] + truncated = False + description = () + + # Normalize sql + sql = sql.strip() + sql = ' '.join(sql.split()) + + if sql == "select name from sqlite_master where type='table'" or \ + sql == "select name from sqlite_master where type=\"table\"": + results = [{'name': name} for name in self.connector.table_names()] + elif sql == "select name from sqlite_master where rootpage = 0 and sql like '%VIRTUAL TABLE%USING FTS%'": + results = [{'name': name} for name in self.connector.hidden_table_names()] + elif sql == 'select 1 from sqlite_master where tbl_name = "geometry_columns"': + if self.connector.detect_spatialite(): + results = [{'1': '1'}] + elif sql == "select name from sqlite_master where type='view'": + results = [{'name': name} for name in self.connector.view_names()] + elif sql.startswith("select count(*) from ["): + match = re.search(r'select count\(\*\) from \[(.*)\]', sql) + results = [{'count(*)': self.connector.table_count(match.group(1))}] + elif sql.startswith("select count(*) from "): + match = re.search(r'select count\(\*\) from (.*)', sql) + results = [{'count(*)': self.connector.table_count(match.group(1))}] + elif sql.startswith("PRAGMA table_info("): + match = re.search(r'PRAGMA table_info\((.*)\)', sql) + results = self.connector.table_info(match.group(1)) + elif sql.startswith("select name from sqlite_master where rootpage = 0 and ( sql like \'%VIRTUAL TABLE%USING FTS%content="): + match = re.search(r'select name from sqlite_master where rootpage = 0 and \( sql like \'%VIRTUAL TABLE%USING FTS%content="(.*)"', sql) + if self.connector.detect_fts(match.group(1)): + results = [{'name': match.group(1)}] + elif sql.startswith("PRAGMA foreign_key_list(["): + match = re.search(r'PRAGMA foreign_key_list\(\[(.*)\]\)', sql) + results = self.connector.foreign_keys(match.group(1)) + elif sql == "select 1 from sqlite_master where type='table' and name=?": + if self.connector.table_exists(params[0]): + results = [{'1': '1'}] + elif sql == "select sql from sqlite_master where name = :n and type=:t": + results = [{'sql': self.connector.table_definition(params['t'], params['n'])}] + elif sql == "select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null": + results = [{'sql': sql} for sql in self.connector.indices_definition(params['n'])] + else: + try: + results, truncated, description = \ + self.connector.execute( + sql, + params=params, + truncate=truncate, + custom_time_limit=custom_time_limit, + page_size=page_size, + log_sql_errors=log_sql_errors, + ) + except OperationalError as ex: + raise sqlite3.OperationalError(*ex.args) + + self.rows = [Row(result) for result in results] + self.description = description + + def fetchall(self): + return self.rows + + def fetchmany(self, max): + return self.rows[:max] + + def __getitem__(self, index): + return self.rows[index] diff --git a/datasette_connectors/monkey.py b/datasette_connectors/monkey.py index 5fd0470..0fb4e1c 100644 --- a/datasette_connectors/monkey.py +++ b/datasette_connectors/monkey.py @@ -17,36 +17,6 @@ def patch_datasette(): Monkey patching for original Datasette """ - async def table_columns(self, table): - try: - return await self.original_table_columns(table) - except sqlite3.DatabaseError: - return ConnectorList.table_columns(self.path, table) - - Database.original_table_columns = Database.table_columns - Database.table_columns = table_columns - - - async def primary_keys(self, table): - try: - return await self.original_primary_keys(table) - except sqlite3.DatabaseError: - return ConnectorList.primary_keys(self.path, table) - - Database.original_primary_keys = Database.primary_keys - Database.primary_keys = primary_keys - - - async def fts_table(self, table): - try: - return await self.original_fts_table(table) - except sqlite3.DatabaseError: - return ConnectorList.fts_table(self.path, table) - - Database.original_fts_table = Database.fts_table - Database.fts_table = fts_table - - def connect(self, write=False): try: # Check if it's a sqlite database -- 2.39.5