import uuid

# Note: these savepoint classes are implemented upstream in peewee, but not in the peewee version we use, so they were copied here from https://github.com/coleifer/peewee/blob/b657d08a14e4cdafee417111ccba62ede9344222/peewee.py

class savepoint(object):
    """Context manager that runs its body inside a SQL savepoint.

    On entry, autocommit is switched off on ``db`` and ``SAVEPOINT <sid>``
    is issued. On a clean exit the savepoint is released; if the body
    raised, or if the release itself fails, the savepoint is rolled back
    to instead. The autocommit setting observed on entry is restored
    afterwards in every case. Exceptions are never suppressed.

    :param db: database object providing ``compiler()``, ``execute_sql()``,
        ``get_autocommit()`` and ``set_autocommit()``.
    :param sid: optional savepoint name; a random one is generated when
        omitted (or falsy).
    """

    def __init__(self, db, sid=None):
        self.db = db
        # ``UUID.hex`` is available on both Python 2 and Python 3, whereas
        # ``get_hex()`` only exists on Python 2. The leading 's' makes sure
        # the generated name starts with a letter rather than a digit.
        self.sid = sid or 's' + uuid.uuid4().hex
        self.quoted_sid = db.compiler().quote(self.sid)

    def _execute(self, query):
        """Run ``query`` on the database with ``require_commit=False``."""
        self.db.execute_sql(query, require_commit=False)

    def commit(self):
        """Release the savepoint."""
        self._execute('RELEASE SAVEPOINT %s;' % self.quoted_sid)

    def rollback(self):
        """Roll back to the savepoint."""
        self._execute('ROLLBACK TO SAVEPOINT %s;' % self.quoted_sid)

    def __enter__(self):
        """Disable autocommit, create the savepoint and return ``self``."""
        self._orig_autocommit = self.db.get_autocommit()
        self.db.set_autocommit(False)
        try:
            self._execute('SAVEPOINT %s;' % self.quoted_sid)
        except:
            # ``__exit__`` is not called when ``__enter__`` raises, so put
            # autocommit back here before propagating the error. The bare
            # ``except`` is deliberate: the exception is always re-raised.
            self.db.set_autocommit(self._orig_autocommit)
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Release or roll back the savepoint, then restore autocommit.

        Returns ``None`` so that an exception raised by the body propagates.
        """
        try:
            if exc_type:
                self.rollback()
            else:
                try:
                    self.commit()
                except:
                    # The release failed: undo the work done since the
                    # savepoint, then surface the original error. The bare
                    # ``except`` is deliberate because it always re-raises.
                    self.rollback()
                    raise
        finally:
            self.db.set_autocommit(self._orig_autocommit)


class savepoint_sqlite(savepoint):
    """Savepoint that also manages the connection's ``isolation_level``.

    On entry the connection's ``isolation_level`` is set to ``None`` if it
    is not already; on exit the previous value is put back, but only if
    this instance was the one that changed it.
    """

    def __enter__(self):
        """Force ``isolation_level`` to ``None``, then enter the savepoint.

        Returns the value of the parent ``__enter__`` (the savepoint itself)
        so that ``with savepoint_sqlite(db) as sp`` binds a usable object.
        """
        conn = self.db.get_conn()
        # For sqlite, the connection's isolation_level *must* be set to None.
        # The act of setting it, though, will break any existing savepoints,
        # so only write to it if necessary.
        if conn.isolation_level is not None:
            self._orig_isolation_level = conn.isolation_level
            conn.isolation_level = None
        else:
            # Nothing was changed, so there is nothing to restore on exit.
            self._orig_isolation_level = None
        return super(savepoint_sqlite, self).__enter__()

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Exit the savepoint, then restore the saved ``isolation_level``."""
        try:
            return super(savepoint_sqlite, self).__exit__(
                exc_type, exc_val, exc_tb)
        finally:
            # Only write the attribute back if __enter__ actually changed it.
            if self._orig_isolation_level is not None:
                self.db.get_conn().isolation_level = self._orig_isolation_level