git.jsancho.org Git - datasette-connectors.git/blob - datasette_connectors/cursor.py
Clean code
[datasette-connectors.git] / datasette_connectors / cursor.py
1 import re
2 import sqlite3
3
4 from .row import Row
5
6
class OperationalError(Exception):
    """Error a connector may raise while executing a statement.

    ``Cursor.execute`` catches this type around ``connector.execute`` and
    re-raises it as ``sqlite3.OperationalError`` with the same ``args``.
    """
9
10
class Cursor:
    """Cursor that answers a fixed set of SQLite catalog queries from a connector.

    ``execute`` recognises the literal ``sqlite_master`` / ``PRAGMA`` /
    ``select count(*)`` statements listed in ``_run_catalog_query`` and serves
    them through methods of ``conn.connector``. Every other statement is
    handed to ``connector.execute``. Results are stored on the instance as
    ``rows`` (a list of ``Row`` objects) and ``description``.
    """

    class QueryNotSupported(Exception):
        """Exception type exposed on the cursor; nothing in this class raises it."""

    # Returned by ``_run_catalog_query`` when the statement is not one of the
    # recognised catalog queries. A dedicated object is used (rather than
    # ``None``) so that whatever a connector method returns is never mistaken
    # for "not handled".
    _NOT_HANDLED = object()

    # Bracketed identifiers cannot contain ``]``, so the name is captured up to
    # the first closing bracket and nothing may follow it.
    _COUNT_BRACKETED_RE = re.compile(r'select count\(\*\) from \[([^\]]*)\]')
    # Unbracketed form: a single whitespace-free token and nothing after it.
    _COUNT_PLAIN_RE = re.compile(r'select count\(\*\) from (\S+)')
    _TABLE_INFO_RE = re.compile(r'PRAGMA table_info\((.*)\)')
    _FOREIGN_KEYS_RE = re.compile(r'PRAGMA foreign_key_list\(\[(.*)\]\)')
    # Non-greedy capture: the statement may contain further double-quoted
    # text after ``content="<table>"``, and a greedy ``.*`` would swallow
    # everything up to the last double quote.
    _FTS_CONTENT_RE = re.compile(
        r'select name from sqlite_master where rootpage = 0 and '
        r'\( sql like \'%VIRTUAL TABLE%USING FTS%content="(.*?)"'
    )

    def __init__(self, conn):
        """Bind the cursor to ``conn`` and to ``conn.connector``; start with no rows."""
        self.conn = conn
        self.connector = conn.connector
        self.rows = []
        self.description = ()

    def execute(
        self,
        sql,
        params=None,
        truncate=False,
        custom_time_limit=None,
        page_size=None,
        log_sql_errors=True,
    ):
        """Run ``sql`` and store the outcome in ``self.rows`` / ``self.description``.

        The statement is whitespace-normalised first. Recognised catalog
        queries are answered by connector methods and leave ``description``
        as ``()``. Anything else is forwarded, with all keyword options, to
        ``connector.execute``, which must return a
        ``(results, truncated, description)`` triple.

        Args:
            sql: SQL text.
            params: Statement parameters; ``None`` is treated as ``{}``. The
                catalog queries read ``params[0]`` or ``params['n']`` /
                ``params['t']`` depending on the statement.
            truncate, custom_time_limit, page_size, log_sql_errors: Passed
                through unchanged to ``connector.execute``.

        Raises:
            sqlite3.OperationalError: When ``connector.execute`` raises this
                module's ``OperationalError``; the original is kept as the
                cause.
        """
        if params is None:
            params = {}

        # Collapse every run of whitespace (and trim the ends) so the literal
        # comparisons below do not depend on how the caller formatted the SQL.
        sql = ' '.join(sql.split())

        description = ()
        results = self._run_catalog_query(sql, params)
        if results is self._NOT_HANDLED:
            try:
                results, _truncated, description = self.connector.execute(
                    sql,
                    params=params,
                    truncate=truncate,
                    custom_time_limit=custom_time_limit,
                    page_size=page_size,
                    log_sql_errors=log_sql_errors,
                )
            except OperationalError as ex:
                raise sqlite3.OperationalError(*ex.args) from ex

        self.rows = [Row(result) for result in results]
        self.description = description

    @classmethod
    def _count_target(cls, sql):
        """Return the table named by a bare ``select count(*) from T``, else ``None``.

        Statements with anything after the table name (``where ...``, joins,
        aliases) are not bare counts and yield ``None`` so that they are
        executed by the connector instead of being answered with the
        unfiltered table count.
        """
        if sql.startswith("select count(*) from ["):
            match = cls._COUNT_BRACKETED_RE.fullmatch(sql)
        else:
            match = cls._COUNT_PLAIN_RE.fullmatch(sql)
        return match.group(1) if match else None

    def _run_catalog_query(self, sql, params):
        """Answer ``sql`` from the connector if it is a recognised catalog query.

        Returns the result rows (an iterable of dicts as produced below or by
        the connector), or ``_NOT_HANDLED`` when ``sql`` is not recognised or
        does not have the expected shape.
        """
        connector = self.connector

        if sql in (
            "select name from sqlite_master where type='table'",
            'select name from sqlite_master where type="table"',
        ):
            return [{'name': name} for name in connector.table_names()]

        if sql == "select name from sqlite_master where rootpage = 0 and sql like '%VIRTUAL TABLE%USING FTS%'":
            return [{'name': name} for name in connector.hidden_table_names()]

        if sql == 'select 1 from sqlite_master where tbl_name = "geometry_columns"':
            return [{'1': '1'}] if connector.detect_spatialite() else []

        if sql == "select name from sqlite_master where type='view'":
            return [{'name': name} for name in connector.view_names()]

        if sql.startswith("select count(*) from "):
            table = self._count_target(sql)
            if table is None:
                return self._NOT_HANDLED
            return [{'count(*)': connector.table_count(table)}]

        if sql.startswith("PRAGMA table_info("):
            match = self._TABLE_INFO_RE.search(sql)
            if match is None:
                return self._NOT_HANDLED
            return connector.table_info(match.group(1))

        if sql.startswith("select name from sqlite_master where rootpage = 0 and ( sql like '%VIRTUAL TABLE%USING FTS%content="):
            match = self._FTS_CONTENT_RE.search(sql)
            if match is None:
                return self._NOT_HANDLED
            table = match.group(1)
            return [{'name': table}] if connector.detect_fts(table) else []

        if sql.startswith("PRAGMA foreign_key_list(["):
            match = self._FOREIGN_KEYS_RE.search(sql)
            if match is None:
                return self._NOT_HANDLED
            return connector.foreign_keys(match.group(1))

        if sql == "select 1 from sqlite_master where type='table' and name=?":
            # Positional parameter: the table name is the first one.
            return [{'1': '1'}] if connector.table_exists(params[0]) else []

        if sql == "select sql from sqlite_master where name = :n and type=:t":
            return [{'sql': connector.table_definition(params['t'], params['n'])}]

        if sql == "select sql from sqlite_master where tbl_name = :n and type='index' and sql is not null":
            return [
                {'sql': definition}
                for definition in connector.indices_definition(params['n'])
            ]

        return self._NOT_HANDLED

    def fetchall(self):
        """Return the list of rows produced by the last ``execute``."""
        return self.rows

    def fetchmany(self, max):
        """Return the first ``max`` rows of the last result.

        No read position is kept, so repeated calls return the same slice.
        (The parameter name shadows the builtin but is kept for callers that
        pass it by keyword.)
        """
        return self.rows[:max]

    def __getitem__(self, index):
        """Index or slice directly into the stored rows."""
        return self.rows[index]