From 03608c5c39b1b837fc4cf6248311b8ddad1dda2e Mon Sep 17 00:00:00 2001
From: Javier Sancho <jsf@jsancho.org>
Date: Thu, 20 Mar 2014 16:06:44 +0100
Subject: [PATCH] Savepoints to protect integrity of documents

---
 MySQL.py      | 12 ++++++++++++
 collection.py | 16 +++++++++++-----
 connection.py |  9 +++++++++
 3 files changed, 32 insertions(+), 5 deletions(-)

diff --git a/MySQL.py b/MySQL.py
index ba00c33..03969f7 100644
--- a/MySQL.py
+++ b/MySQL.py
@@ -171,3 +171,15 @@ class Connection(connection.Connection):
 
     def rollback(self):
         self._db_con.rollback()
+
+    def savepoint(self, name):
+        self.execute("SAVEPOINT %s" % name)
+        return True
+
+    def commit_savepoint(self, name):
+        self.execute("RELEASE SAVEPOINT %s" % name)
+        return True
+
+    def rollback_savepoint(self, name):
+        self.execute("ROLLBACK TO %s" % name)
+        return True
diff --git a/collection.py b/collection.py
index b8d51a0..8041353 100644
--- a/collection.py
+++ b/collection.py
@@ -75,11 +75,17 @@ class Collection(object):
 
     def _insert_document(self, doc_id, doc):
         fields = self._get_fields()
-        for f in doc:
-            if not f in fields:
-                self._create_field(f)
-            table_f = '%s$%s' % (self.table_name, f)
-            self._insert_field(doc_id, table_f, f, doc[f])
+        self.database.connection.savepoint("insert_document")
+        try:
+            for f in doc:
+                if not f in fields:
+                    self._create_field(f)
+                table_f = '%s$%s' % (self.table_name, f)
+                self._insert_field(doc_id, table_f, f, doc[f])
+            self.database.connection.commit_savepoint("insert_document")
+        except:
+            self.database.connection.rollback_savepoint("insert_document")
+            raise
 
     def _insert_field(self, doc_id, field_table, field_name, field_value):
         values = {
diff --git a/connection.py b/connection.py
index 7623047..c3f5647 100644
--- a/connection.py
+++ b/connection.py
@@ -93,3 +93,12 @@ class Connection(object):
 
     def rollback(self):
         pass
+
+    def savepoint(self, name):
+        pass
+
+    def commit_savepoint(self, name):
+        pass
+
+    def rollback_savepoint(self, name):
+        pass
-- 
2.39.5