1 from collections import OrderedDict
5 _connector_type = 'pytables'
8 "Open file and return tables info"
11 h5file = tables.open_file(path)
13 for table in filter(lambda node: not(isinstance(node, tables.group.Group)), h5file):
15 if isinstance(table, tables.table.Table):
16 colnames = table.colnames
18 h5tables[table._v_pathname] = {
19 'name': table._v_pathname,
26 'foreign_keys': {'incoming': [], 'outgoing': []},
30 return h5tables, views, _connector_type
33 parsed = sqlparse.parse(sql)
37 for token in stmt.tokens:
39 if current_keyword in parsed_sql and parsed_sql[current_keyword] == '':
40 # Check composed keywords like 'order by'
41 del parsed_sql[current_keyword]
42 current_keyword += " " + str(token)
44 current_keyword = str(token)
45 parsed_sql[current_keyword] = ""
46 elif type(token) is sqlparse.sql.Where:
47 parsed_sql['where'] = token
49 if not token.is_whitespace:
50 parsed_sql[current_keyword] += str(token)
57 def _translate_condition(table, condition, params):
58 field = condition.left.get_real_name()
60 operator = list(filter(lambda t: t.ttype == sqlparse.tokens.Comparison, condition.tokens))[0]
61 if operator.value in _operators:
62 operator = _operators[operator.value]
64 operator = operator.value
66 value = condition.right.value
67 if value.startswith(':'):
68 # Value is a parameters
71 # Cast value to the column type
72 coltype = table.coltypes[field]
73 if coltype == 'string':
74 params[value] = str(params[value])
75 elif coltype.startswith('int'):
76 params[value] = int(params[value])
77 elif coltype.startswith('float'):
78 params[value] = float(params[value])
80 translated = "{left} {operator} {right}".format(left=field, operator=operator, right=value)
81 return translated, params
84 def __init__(self, path):
86 self.h5file = tables.open_file(path)
88 def execute(self, sql, params=None, truncate=False):
95 parsed_sql = _parse_sql(sql)
96 table = self.h5file.get_node(parsed_sql['from'][1:-1])
98 fields = parsed_sql['select'].split(',')
100 # Use 'where' statement or get all the rows
101 if 'where' in parsed_sql:
105 for condition in parsed_sql['where'].get_sublists():
106 if str(condition) == '"rowid"=:p0':
107 start = int(params['p0'])
110 translated, params = _translate_condition(table, condition, params)
111 query.append(translated)
113 query = ') & ('.join(query)
114 query = '(' + query + ')'
115 table_rows = table.where(query, params, start, end)
117 table_rows = table.iterrows(start, end)
119 table_rows = table.iterrows()
122 if len(fields) == 1 and fields[0] == 'count(*)':
123 rows.append(Row({fields[0]: table.nrows}))
125 for table_row in table_rows:
129 row[field] = table_row.nrow
131 for col in table.colnames:
132 value = table_row[col]
133 if type(value) is bytes:
134 value = value.decode('utf-8')
137 row[field] = table_row[field]
140 # Prepare query description
143 for col in table.colnames:
144 description.append((col,))
146 description.append((field,))
150 return rows, truncated, tuple(description)
154 class Row(OrderedDict):
155 def __getitem__(self, label):
156 if type(label) is int:
157 return super(OrderedDict, self).__getitem__(list(self.keys())[label])
159 return super(OrderedDict, self).__getitem__(label)
162 return self.values().__iter__()