diff --git a/graphdb/graphdb.go b/graphdb/graphdb.go index 6234203..c6f13ed 100644 --- a/graphdb/graphdb.go +++ b/graphdb/graphdb.go @@ -79,46 +79,43 @@ func NewDatabase(conn *sql.DB) (*Database, error) { } db := &Database{conn: conn} - if _, err := conn.Exec(createEntityTable); err != nil { - return nil, err - } - if _, err := conn.Exec(createEdgeTable); err != nil { - return nil, err - } - if _, err := conn.Exec(createEdgeIndices); err != nil { - return nil, err - } - - rollback := func() { - conn.Exec("ROLLBACK") - } - // Create root entities - if _, err := conn.Exec("BEGIN"); err != nil { + tx, err := conn.Begin() + if err != nil { return nil, err } - if _, err := conn.Exec("DELETE FROM entity where id = ?", "0"); err != nil { - rollback() + if _, err := tx.Exec(createEntityTable); err != nil { + return nil, err + } + if _, err := tx.Exec(createEdgeTable); err != nil { + return nil, err + } + if _, err := tx.Exec(createEdgeIndices); err != nil { return nil, err } - if _, err := conn.Exec("INSERT INTO entity (id) VALUES (?);", "0"); err != nil { - rollback() + if _, err := tx.Exec("DELETE FROM entity where id = ?", "0"); err != nil { + tx.Rollback() return nil, err } - if _, err := conn.Exec("DELETE FROM edge where entity_id=? and name=?", "0", "/"); err != nil { - rollback() + if _, err := tx.Exec("INSERT INTO entity (id) VALUES (?);", "0"); err != nil { + tx.Rollback() return nil, err } - if _, err := conn.Exec("INSERT INTO edge (entity_id, name) VALUES(?,?);", "0", "/"); err != nil { - rollback() + if _, err := tx.Exec("DELETE FROM edge where entity_id=? and name=?", "0", "/"); err != nil { + tx.Rollback() return nil, err } - if _, err := conn.Exec("COMMIT"); err != nil { + if _, err := tx.Exec("INSERT INTO edge (entity_id, name) VALUES(?,?);", "0", "/"); err != nil { + tx.Rollback() + return nil, err + } + + if err := tx.Commit(); err != nil { return nil, err } @@ -135,33 +132,32 @@ func (db *Database) Set(fullPath, id string) (*Entity, error) { db.mux.Lock() defer db.mux.Unlock() - rollback := func() { - db.conn.Exec("ROLLBACK") - } - if _, err := db.conn.Exec("BEGIN EXCLUSIVE"); err != nil { + tx, err := db.conn.Begin() + if err != nil { return nil, err } + var entityID string - if err := db.conn.QueryRow("SELECT id FROM entity WHERE id = ?;", id).Scan(&entityID); err != nil { + if err := tx.QueryRow("SELECT id FROM entity WHERE id = ?;", id).Scan(&entityID); err != nil { if err == sql.ErrNoRows { - if _, err := db.conn.Exec("INSERT INTO entity (id) VALUES(?);", id); err != nil { - rollback() + if _, err := tx.Exec("INSERT INTO entity (id) VALUES(?);", id); err != nil { + tx.Rollback() return nil, err } } else { - rollback() + tx.Rollback() return nil, err } } e := &Entity{id} parentPath, name := splitPath(fullPath) - if err := db.setEdge(parentPath, name, e); err != nil { - rollback() + if err := db.setEdge(parentPath, name, e, tx); err != nil { + tx.Rollback() return nil, err } - if _, err := db.conn.Exec("COMMIT"); err != nil { + if err := tx.Commit(); err != nil { return nil, err } return e, nil @@ -179,7 +175,7 @@ func (db *Database) Exists(name string) bool { return e != nil } -func (db *Database) setEdge(parentPath, name string, e *Entity) error { +func (db *Database) setEdge(parentPath, name string, e *Entity, tx *sql.Tx) error { parent, err := db.get(parentPath) if err != nil { return err @@ -188,7 +184,7 @@ func (db *Database) setEdge(parentPath, name string, e *Entity) error { return fmt.Errorf("Cannot set self as child") } - if _, err := db.conn.Exec("INSERT INTO edge (parent_id, name, entity_id) VALUES (?,?,?);", parent.id, name, e.id); err != nil { + if _, err := tx.Exec("INSERT INTO edge (parent_id, name, entity_id) VALUES (?,?,?);", parent.id, name, e.id); err != nil { return err } return nil @@ -371,18 +367,15 @@ func (db *Database) Purge(id string) (int, error) { db.mux.Lock() defer db.mux.Unlock() - rollback := func() { - db.conn.Exec("ROLLBACK") - } - - if _, err := db.conn.Exec("BEGIN"); err != nil { + tx, err := db.conn.Begin() + if err != nil { return -1, err } // Delete all edges - rows, err := db.conn.Exec("DELETE FROM edge WHERE entity_id = ?;", id) + rows, err := tx.Exec("DELETE FROM edge WHERE entity_id = ?;", id) if err != nil { - rollback() + tx.Rollback() return -1, err } @@ -392,14 +385,15 @@ func (db *Database) Purge(id string) (int, error) { } // Delete entity - if _, err := db.conn.Exec("DELETE FROM entity where id = ?;", id); err != nil { - rollback() + if _, err := tx.Exec("DELETE FROM entity where id = ?;", id); err != nil { + tx.Rollback() return -1, err } - if _, err := db.conn.Exec("COMMIT"); err != nil { + if err := tx.Commit(); err != nil { return -1, err } + return int(changes), nil }