TECH SCHOOL


A clean way to implement database transaction in Golang

Hi and welcome back!

In the previous lectures, we learned how to write Golang code to perform CRUD operations on each individual table of the simple bank database.

But in a real-world application, we often have to perform a transaction that combines several operations across multiple tables. Today we will learn a clean way to implement it in Golang.


Database transaction

Before we jump into coding, let’s talk a bit about transactions!

What is a DB transaction?

Well, basically, it’s a single unit of work that’s often made up of multiple database operations.


For example, in our simple bank, we want to transfer 10 USD from account 1 to account 2.


This transaction comprises 5 operations:

  1. We create a transfer record with amount equal to 10.
  2. We create an entry record for account 1 with amount equal to -10, since money is moving out of this account.
  3. We create another entry record for account 2, but with amount equal to 10, because money is moving into this account.
  4. Then we update the balance of account 1 by subtracting 10 from it.
  5. And finally we update the balance of account 2 by adding 10 to it.

This is the transaction that we’re going to implement in this article. We will come to that in a moment.

Why do we need to use DB transaction?


There are 2 main reasons:

  1. We want our unit of work to be reliable and consistent, even in case of system failure.
  2. We want to provide isolation between programs that access the database concurrently.

ACID properties

In order to achieve these 2 goals, a database transaction must satisfy the ACID properties, where:


  • A is Atomicity, which means either all operations of the transaction complete successfully, or the whole transaction fails, everything is rolled back, and the database is left unchanged.
  • C is Consistency, which means the database state should remain valid after the transaction is executed. More precisely, all data written to the database must be valid according to predefined rules, including constraints, cascades, and triggers.
  • I is Isolation, meaning all transactions that run concurrently should not affect each other. There are several levels of isolation that define when the changes made by 1 transaction become visible to others. We will learn more about it in another lecture.
  • The last property is D, which stands for Durability. It basically means that all data written by a successful transaction must stay in persistent storage and cannot be lost, even in case of system failure.

How to run a SQL DB transaction?


It’s pretty simple:

  • We start a transaction with the BEGIN statement.
  • Then we write a series of normal SQL queries (or operations).
  • If all of them are successful, we COMMIT the transaction to make it permanent, and the database is changed to a new state.
  • Otherwise, if any query fails, we ROLLBACK the transaction, so all changes made by previous queries of the transaction are discarded, and the database stays the same as it was before the transaction.
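
To make the commit-or-rollback idea concrete before touching a real database, here is a toy in-memory sketch. It is only a model of the semantics, not how a database engine works: the ledger type and the runAtomically and transfer helpers are names made up for this illustration.

```go
package main

import (
	"errors"
	"fmt"
)

// ledger is a toy stand-in for a database: account ID -> balance.
type ledger map[int64]int64

// runAtomically applies fn to a private copy of the ledger (like BEGIN).
// The copy is published only if fn succeeds (like COMMIT); otherwise it is
// discarded and the original is left untouched (like ROLLBACK).
func runAtomically(db ledger, fn func(work ledger) error) error {
	work := make(ledger, len(db))
	for id, balance := range db {
		work[id] = balance
	}
	if err := fn(work); err != nil {
		return err // rollback: drop the working copy
	}
	for id, balance := range work {
		db[id] = balance // commit: publish the changes
	}
	return nil
}

// transfer returns an operation that moves amount between two accounts.
// It deliberately credits the receiver first, so a failure afterwards
// leaves a partial change behind in the working copy.
func transfer(from, to, amount int64) func(ledger) error {
	return func(l ledger) error {
		l[to] += amount
		if l[from] < amount {
			return errors.New("insufficient funds")
		}
		l[from] -= amount
		return nil
	}
}

func main() {
	db := ledger{1: 100, 2: 0}
	fmt.Println(runAtomically(db, transfer(1, 2, 10)), db)
	fmt.Println(runAtomically(db, transfer(1, 2, 500)), db)
}
```

The first call commits and leaves balances of 90 and 10. The second call fails halfway through, and because the partial change only ever touched the working copy, the balances stay at 90 and 10.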

Alright, now we have a basic understanding of database transactions. Let’s learn how to implement them in Golang.

Implement DB transaction in Go

I’m gonna create a new file store.go inside the db/sqlc folder. In this file, let’s define a new Store struct.



type Store struct {
}



This store will provide all functions to run database queries individually, as well as their combination within a transaction.

Use composition to extend Queries' functionality

For individual queries, we already have the Queries struct generated by sqlc, which we learned about in previous lectures.

However, each query only does 1 operation on 1 specific table, so the Queries struct doesn’t support transactions. That’s why we have to extend its functionality by embedding it inside the Store struct like this:



type Store struct {
    *Queries
}



This is called composition, and it is the preferred way to extend struct functionality in Golang, since Go has no class inheritance. By embedding Queries inside Store, all individual query functions provided by Queries will be available to Store.
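
Here is a small standalone sketch of how embedding promotes methods. The Queries type below is a hypothetical stand-in with a single made-up GetAccount method, not the sqlc-generated code, and Describe is an invented extra method.

```go
package main

import "fmt"

// Queries is a hypothetical stand-in for the sqlc-generated struct.
type Queries struct{}

// GetAccount pretends to load an account and returns a label for it.
func (q *Queries) GetAccount(id int64) string {
	return fmt.Sprintf("account #%d", id)
}

// Store embeds *Queries, so the methods of Queries are promoted to Store.
type Store struct {
	*Queries
}

// Describe is an extra method that only Store has; it reuses a promoted method.
func (s *Store) Describe(id int64) string {
	return "store sees " + s.GetAccount(id)
}

func main() {
	store := &Store{Queries: &Queries{}}
	fmt.Println(store.GetAccount(1)) // promoted from Queries
	fmt.Println(store.Describe(2))   // defined on Store
}
```

Note that the embedded pointer must be set when the Store is built. Calling a promoted method that touches fields through a nil *Queries would panic at runtime.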

We can support transactions by adding more functions to this new struct. In order to do so, we need the Store to have a sql.DB object because it is required to create a new db transaction.



type Store struct {
    *Queries
    db *sql.DB
}



Create a new Store

OK, now let’s add a function to create a new Store object. It will take a sql.DB as input and return a Store. Inside, we just build a new Store object and return it.



func NewStore(db *sql.DB) *Store {
    return &Store{
        db:      db,
        Queries: New(db),
    }
}



Here db is the input sql.DB, and Queries is created by calling the New() function with that db object. The New() function was generated by sqlc, as we saw in the previous lectures. It creates and returns a Queries object.

Execute a generic DB transaction

Next, we will add a function to the Store to execute a generic database transaction.



func (store *Store) execTx(ctx context.Context, fn func(*Queries) error) error {
    ...
}



The idea is simple: it takes a context and a callback function as input, then it will start a new db transaction, create a new Queries object with that transaction, call the callback function with the created Queries, and finally commit or rollback the transaction based on the error returned by that function.

Let’s implement this!

First, to start a new transaction, we call store.db.BeginTx(), pass in the context, and optionally a sql.TxOptions.



tx, err := store.db.BeginTx(ctx, &sql.TxOptions{})



This option allows us to set a custom isolation level for this transaction.



type TxOptions struct {
    Isolation IsolationLevel
    ReadOnly  bool
}



If we don’t set it explicitly, then the default isolation level of the database server will be used, which is read-committed in case of Postgres.

We will learn more about this in another lecture. For now, let’s just pass nil here to use the default value.



func (store *Store) execTx(ctx context.Context, fn func(*Queries) error) error {
    tx, err := store.db.BeginTx(ctx, nil)
    if err != nil {
        return err
    }

    q := New(tx)
    ...
}



The BeginTx() function returns a transaction object or an error. If error is not nil, we just return it immediately. Otherwise, we call New() function with the created transaction tx, and get back a new Queries object.

This is the same New() function that we used in the NewStore() function. The only difference is, instead of passing in a sql.DB, we’re now passing in a sql.Tx object. This works because the New() function accepts a DBTX interface as we’ve seen in the previous lecture:



type DBTX interface {
    ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
    PrepareContext(context.Context, string) (*sql.Stmt, error)
    QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
    QueryRowContext(context.Context, string, ...interface{}) *sql.Row
}

func New(db DBTX) *Queries {
    return &Queries{db: db}
}



OK, now that we have queries that run within the transaction, we can call the input function with them and get back an error.

If the error is not nil, then we need to roll back the transaction by calling tx.Rollback(). It also returns a rollback error.

If the rollback error is also not nil, then we have to report 2 errors. So we should combine them into 1 single error using fmt.Errorf() before returning:



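func (store *Store) execTx(ctx context.Context, fn func(*Queries) error) error {
    // nil options: use the default isolation level
    tx, err := store.db.BeginTx(ctx, nil)
    if err != nil {
        return err
    }

    // q is bound to tx, so every query it runs is part of the transaction
    q := New(tx)
    err = fn(q)
    if err != nil {
        // the callback failed: roll back, and report both errors if that fails too
        if rbErr := tx.Rollback(); rbErr != nil {
            return fmt.Errorf("tx err: %v, rb err: %v", err, rbErr)
        }
        return err
    }

    return tx.Commit()
}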



In case the rollback is successful, we just return the original transaction error.

Finally, if all operations in the transaction are successful, we simply commit the transaction with tx.Commit(), and return its error to the caller.

And we’re done with the execTx() function. Note that this function is unexported (it starts with a lowercase letter e), because we don’t want external packages to call it directly. Instead, we will provide an exported function for each specific transaction.

Implement money transfer transaction

Now let’s go ahead and add a new TransferTx() function to perform the money transfer transaction example that we discussed at the beginning of the article.

To recall, it will create a new transfer record, add 2 new account entries, and update the 2 accounts’ balance within a single database transaction.

The input of this function will be a context and an argument object of type TransferTxParams. And it will return a TransferTxResult object or an error.



func (store *Store) TransferTx(ctx context.Context, arg TransferTxParams) (TransferTxResult, error) {
}



The TransferTxParams struct contains all necessary input parameters to transfer money between 2 accounts:



type TransferTxParams struct {
    FromAccountID int64 `json:"from_account_id"`
    ToAccountID   int64 `json:"to_account_id"`
    Amount        int64 `json:"amount"`
}


  • FromAccountID is the ID of the account where money will be sent from.
  • ToAccountID is the ID of the account where money will be sent to.
  • And the last field is the Amount of money to be sent.

The TransferTxResult struct contains the result of the transfer transaction. It has 5 fields:



type TransferTxResult struct {
    Transfer    Transfer `json:"transfer"`
    FromAccount Account  `json:"from_account"`
    ToAccount   Account  `json:"to_account"`
    FromEntry   Entry    `json:"from_entry"`
    ToEntry     Entry    `json:"to_entry"`
}


  • The created Transfer record.
  • The FromAccount after the amount is subtracted from its balance.
  • The ToAccount after the amount is added to its balance.
  • The FromEntry, which records that money is moving out of the FromAccount.
  • And the ToEntry, which records that money is moving into the ToAccount.

Alright, now we can implement the transfer transaction. First we create an empty result. Then we call the store.execTx() function that we’ve written before to create and run a new database transaction.



func (store *Store) TransferTx(ctx context.Context, arg TransferTxParams) (TransferTxResult, error) {
    var result TransferTxResult

    err := store.execTx(ctx, func(q *Queries) error {
        ...
        return nil
    })

    return result, err
}



We pass in the context and the callback function. For now let’s just return nil in this callback. Finally we return the result and the error of the execTx() call.

Now let’s come back to implement the callback function. Basically, we can use the queries object q to call any individual CRUD function that it provides.

Keep in mind that this queries object is created from 1 single database transaction, so all of its provided methods that we call will be run within that transaction.

Alright, let’s create the transfer record by calling q.CreateTransfer(), pass in the input context, and a CreateTransferParams argument:



func (store *Store) TransferTx(ctx context.Context, arg TransferTxParams) (TransferTxResult, error) {
    ...

    err := store.execTx(ctx, func(q *Queries) error {
        var err error

        result.Transfer, err = q.CreateTransfer(ctx, CreateTransferParams{
            FromAccountID: arg.FromAccountID,
            ToAccountID:   arg.ToAccountID,
            Amount:        arg.Amount,
        })
        if err != nil {
            return err
        }

        ...

        return nil
    })

    return result, err
}



The created transfer will be saved to result.Transfer, and any error to err. If the error is not nil, we just return it right away.

Here you can see that we’re accessing the result variable of the outer function from inside this callback function. The same goes for the arg variable.

This makes the callback function a closure. When this article was written, Go had no support for generic types (they were added later, in Go 1.18), so a closure was the usual way to get the result out of a callback function, because the callback’s signature cannot express the exact type of the result it should return.
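
Here is the closure idea in isolation, as a small sketch unrelated to the bank schema. The run and sumPositive functions are invented for this example: run only knows about errors, just like execTx, and the actual result leaves the callback through a captured variable.

```go
package main

import "fmt"

// run calls fn and reports only its error, like execTx does.
func run(fn func() error) error {
	return fn()
}

// sumPositive returns the sum of nums and rejects negative values.
// The result travels out of the callback through the captured variable total.
func sumPositive(nums []int) (int, error) {
	var total int
	err := run(func() error {
		for _, n := range nums {
			if n < 0 {
				return fmt.Errorf("negative value: %d", n)
			}
			total += n // writes to the outer function's variable
		}
		return nil
	})
	return total, err
}

func main() {
	fmt.Println(sumPositive([]int{1, 2, 3}))
	fmt.Println(sumPositive([]int{1, -2}))
}
```

Notice that on failure the captured variable may hold a partial value (1 in the second call), which is why callers should check the error before trusting the result. The same applies to the result returned by TransferTx().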

OK, so the 1st step, creating a transfer record, is done. The next step is to add 2 account entries: 1 for the FromAccount, and 1 for the ToAccount.

We call q.CreateEntry(), passing in the context and a CreateEntryParams, where AccountID is arg.FromAccountID, and Amount is -arg.Amount because money is moving out of this account.



func (store *Store) TransferTx(ctx context.Context, arg TransferTxParams) (TransferTxResult, error) {
    ...

    err := store.execTx(ctx, func(q *Queries) error {
        var err error

        result.Transfer, err = q.CreateTransfer(ctx, CreateTransferParams{
            FromAccountID: arg.FromAccountID,
            ToAccountID:   arg.ToAccountID,
            Amount:        arg.Amount,
        })
        if err != nil {
            return err
        }

        result.FromEntry, err = q.CreateEntry(ctx, CreateEntryParams{
            AccountID: arg.FromAccountID,
            Amount:    -arg.Amount,
        })
        if err != nil {
            return err
        }

        ...

        return nil
    })

    return result, err
}



And just like before, if error is not nil, we just return it so that the transaction will be rolled back.

We do a similar thing to create an account entry for the ToAccount. But this time, it is result.ToEntry, the AccountID is arg.ToAccountID, and the Amount is just arg.Amount since money is moving into this account.



func (store *Store) TransferTx(ctx context.Context, arg TransferTxParams) (TransferTxResult, error) {
    ...

    err := store.execTx(ctx, func(q *Queries) error {
        var err error

        result.Transfer, err = q.CreateTransfer(ctx, CreateTransferParams{
            FromAccountID: arg.FromAccountID,
            ToAccountID:   arg.ToAccountID,
            Amount:        arg.Amount,
        })
        if err != nil {
            return err
        }

        result.FromEntry, err = q.CreateEntry(ctx, CreateEntryParams{
            AccountID: arg.FromAccountID,
            Amount:    -arg.Amount,
        })
        if err != nil {
            return err
        }

        result.ToEntry, err = q.CreateEntry(ctx, CreateEntryParams{
            AccountID: arg.ToAccountID,
            Amount:    arg.Amount,
        })
        if err != nil {
            return err
        }

        // TODO: update accounts' balance

        return nil
    })

    return result, err
}



And we’re done with creating the account entries. The last step, updating the account balances, is more complicated because it involves locking and preventing potential deadlocks.

So I think it’s worth a separate lecture to talk about it in detail. Let's add a TODO comment here, and we will come back to implement it in the next one.

Test money transfer transaction

For now, let’s say our transfer transaction is done once 1 transfer record and 2 account entries are created. We have to test it to make sure that it’s working as expected.

I’m gonna create a new store_test.go file in the same db package as our store.go. Then let’s define a new unit test for the TransferTx() function.



func TestTransferTx(t *testing.T) {
    ...
}



First we need to create a new Store object. The NewStore() function requires a sql.DB object.

If you still remember, in the previous lecture, we already created a sql.DB object in the main_test.go file with this sql.Open() function call:



func TestMain(m *testing.M) {
    conn, err := sql.Open(dbDriver, dbSource)
    ...
}



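So in order to reuse it, instead of assigning the result to the local conn variable, we will declare a new global variable, testDB, and store the result of the sql.Open() call in it. Note the plain = assignment: writing testDB, err := ... would declare a new local testDB and leave the global one nil.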



var testQueries *Queries
var testDB *sql.DB

func TestMain(m *testing.M) {
    var err error
    testDB, err = sql.Open(dbDriver, dbSource)
    if err != nil {
        log.Fatal("cannot connect to db:", err)
    }

    testQueries = New(testDB)

    os.Exit(m.Run())
}


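
The difference between := and = here is easy to miss, and it is the cause of the nil pointer panic reported in the comments at the end of this article. The sketch below reproduces it without a database: conn plays the role of testDB, and open is a made-up stand-in for sql.Open().

```go
package main

import "fmt"

// conn is a package-level variable, playing the role of testDB.
var conn *string

// open is a hypothetical stand-in for sql.Open.
func open() (*string, error) {
	s := "connected"
	return &s, nil
}

// setupShadowed uses :=, which declares a NEW local conn.
// The package-level conn is never assigned and stays nil.
func setupShadowed() error {
	conn, err := open()
	_ = conn // use the local so the compiler accepts it
	return err
}

// setupCorrect declares err first and uses =, so the global conn is assigned.
func setupCorrect() error {
	var err error
	conn, err = open()
	return err
}

func main() {
	_ = setupShadowed()
	fmt.Println("global is nil after := ?", conn == nil)

	_ = setupCorrect()
	fmt.Println("global is nil after =  ?", conn == nil)
}
```

After the := version the global is still nil, so anything that uses it later will panic. After the = version it is set as intended.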

OK, now we can come back to our unit test and pass the testDB object into the NewStore() function to create a new Store:



func TestTransferTx(t *testing.T) {
    store := NewStore(testDB)

    account1 := createRandomAccount(t)
    account2 := createRandomAccount(t)

    ...
}



Next, we create 2 random accounts using the createRandomAccount() function we wrote in the previous lecture. We will send money from account 1 to account 2.

From my experience, writing database transactions is something we must always be very careful with. It can be easy to write, but can also easily become a nightmare if we don’t handle concurrency carefully.

So the best way to make sure that our transaction works well is to run it with several concurrent goroutines.

Let’s say I want to run n = 5 concurrent transfer transactions, and each of them will transfer an amount of 10 from account 1 to account 2. So I will use a simple for loop with n iterations:



func TestTransferTx(t *testing.T) {
    store := NewStore(testDB)

    account1 := createRandomAccount(t)
    account2 := createRandomAccount(t)

    n := 5
    amount := int64(10)

    // run n concurrent transfer transaction
    for i := 0; i < n; i++ {
        go func() {
            result, err := store.TransferTx(context.Background(), TransferTxParams{
                FromAccountID: account1.ID,
                ToAccountID:   account2.ID,
                Amount:        amount,
            })

            ...
        }()
    }

    ...
}



Inside the loop, we use the go keyword to start a new goroutine. Inside the goroutine, we call the store.TransferTx() function with a background context and a TransferTxParams object, where FromAccountID is account1.ID, ToAccountID is account2.ID, and Amount is 10, as we declared amount = 10 above.

This function returns a result or an error. We cannot just use testify/require to check them right here because this function is running inside a different goroutine from the one that our TestTransferTx function is running on. A failed require check stops the test by calling t.FailNow(), which only works from the goroutine running the test function, so there’s no guarantee that it will stop the whole test if a condition is not satisfied.

The correct way to verify the error and result is to send them back to the main goroutine that our test is running on, and check them from there.

To do that, we can use channels. A channel is designed to connect concurrent goroutines and allows them to safely share data with each other without explicit locking.

In our case, we need 1 channel to receive the errors, and 1 other channel to receive the TransferTxResult. We use the built-in make function to create the channels.



func TestTransferTx(t *testing.T) {
    ...

    // run n concurrent transfer transaction
    errs := make(chan error)
    results := make(chan TransferTxResult)

    for i := 0; i < n; i++ {
        go func() {
            result, err := store.TransferTx(context.Background(), TransferTxParams{
                FromAccountID: account1.ID,
                ToAccountID:   account2.ID,
                Amount:        amount,
            })

            errs <- err
            results <- result
        }()
    }
}



Now, inside the goroutine, we can send err to the errs channel using the arrow operator <-. The channel should be on the left, and the data to be sent should be on the right of the arrow operator.

Similarly, we send result to the results channel. Then, we will check these errors and results from outside.

We simply run a for loop of n iterations. To receive the error from the channel, we use the same arrow operator, but this time, the channel is on the right of the arrow, and the variable to store the received data is on the left.



func TestTransferTx(t *testing.T) {
    ...

    // run n concurrent transfer transaction
    ...

    // check results
    for i := 0; i < n; i++ {
        err := <-errs
        require.NoError(t, err)

        result := <-results
        require.NotEmpty(t, result)

        ...
    }
}



We require no errors here, which means the received err should be nil. Likewise, we receive result from the results channel and check that result is not an empty object.
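
If you want to play with this pattern outside of a test, here is a self-contained sketch. It is a variation on the approach above, not the same code: it uses a single channel carrying a small struct instead of 2 separate channels, and the work function is a made-up stand-in for store.TransferTx() that fails on odd inputs.

```go
package main

import (
	"errors"
	"fmt"
)

// work is a hypothetical stand-in for store.TransferTx: odd inputs fail.
func work(i int) (int, error) {
	if i%2 == 1 {
		return 0, errors.New("odd input")
	}
	return i * 10, nil
}

// outcome bundles a result with its error so one channel can carry both.
type outcome struct {
	value int
	err   error
}

// runConcurrently starts one goroutine per input and collects every outcome
// in the calling goroutine, which is where checks are safe to make.
func runConcurrently(inputs []int) (sum int, failures int) {
	outcomes := make(chan outcome)
	for _, in := range inputs {
		go func(in int) {
			v, err := work(in)
			outcomes <- outcome{value: v, err: err}
		}(in)
	}
	// Receive exactly one outcome per started goroutine.
	for range inputs {
		o := <-outcomes
		if o.err != nil {
			failures++
			continue
		}
		sum += o.value
	}
	return sum, failures
}

func main() {
	fmt.Println(runConcurrently([]int{0, 1, 2, 3, 4}))
}
```

Bundling the value and the error together keeps each pair from being split up, at the cost of one extra type. Either way, the receiving loop must read once per goroutine, otherwise the senders block forever on the unbuffered channel.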

As result contains several objects inside, let’s verify each of them. Start with the result.Transfer:



func TestTransferTx(t *testing.T) {
    ...

    // check results
    for i := 0; i < n; i++ {
        err := <-errs
        require.NoError(t, err)

        result := <-results
        require.NotEmpty(t, result)

        // check transfer
        transfer := result.Transfer
        require.NotEmpty(t, transfer)
        require.Equal(t, account1.ID, transfer.FromAccountID)
        require.Equal(t, account2.ID, transfer.ToAccountID)
        require.Equal(t, amount, transfer.Amount)
        require.NotZero(t, transfer.ID)
        require.NotZero(t, transfer.CreatedAt)

        ...
    }
}



We require this transfer object to be non-empty. Then the FromAccountID field of transfer should equal account1.ID, the ToAccountID field of transfer should equal account2.ID, and transfer.Amount should equal the input amount.

The ID field of transfer should not be zero because it’s an auto-increment field. And finally transfer.CreatedAt should not be a zero value because we expect the database to fill in the default value, which is the current timestamp.

Now to be sure that a transfer record is really created in the database, we should call store.GetTransfer() to find a record with ID equal to transfer.ID:



func TestTransferTx(t *testing.T) {
    ...

    // check results
    for i := 0; i < n; i++ {
        // check transfer
        ...

        _, err = store.GetTransfer(context.Background(), transfer.ID)
        require.NoError(t, err)

        ...
    }
}




Here you can see that, because the Queries object is embedded inside the Store, the GetTransfer() function of Queries is also available to the Store.

If the transfer really exists, this function should not return an error, so we require no error here.

Next we will check the account entries of the result. First, the FromEntry:



func TestTransferTx(t *testing.T) {
    ...

    // check results
    for i := 0; i < n; i++ {
        // check transfer
        ...

        // check entries
        fromEntry := result.FromEntry
        require.NotEmpty(t, fromEntry)
        require.Equal(t, account1.ID, fromEntry.AccountID)
        require.Equal(t, -amount, fromEntry.Amount)
        require.NotZero(t, fromEntry.ID)
        require.NotZero(t, fromEntry.CreatedAt)

        _, err = store.GetEntry(context.Background(), fromEntry.ID)
        require.NoError(t, err)

        ...
    }
}



Just like before, we check that it is not empty. The account ID should be account1.ID, and the Amount of the entry should equal -amount because money is going out. Finally, the ID and CreatedAt fields of the entry should not be zero.

We also try to get the entry from the database to make sure that it was really created.

Checking the ToEntry is similar. So I just copy this block of code and change the variable and field names to toEntry.



func TestTransferTx(t *testing.T) {
    ...

    // check results
    for i := 0; i < n; i++ {
        // check transfer
        ...

        // check entries
        ...

        toEntry := result.ToEntry
        require.NotEmpty(t, toEntry)
        require.Equal(t, account2.ID, toEntry.AccountID)
        require.Equal(t, amount, toEntry.Amount)
        require.NotZero(t, toEntry.ID)
        require.NotZero(t, toEntry.CreatedAt)

        _, err = store.GetEntry(context.Background(), toEntry.ID)
        require.NoError(t, err)

        // TODO: check accounts' balance
    }
}



The account ID should be account2.ID instead of account1.ID. And the Amount should be amount instead of -amount because money is going in.

In the end, we should get the toEntry record from the database instead of fromEntry.

Now keep in mind that we should also check the accounts’ balance. But since we haven’t implemented the part to update accounts’ balance yet, let’s just add a TODO comment here for now, and we will complete it in the next lecture.

Alright, now the test is ready. Let’s try to run it.


It passed! Excellent!

Let’s run the whole package tests.


All passed! The coverage is about 80%, which is very good.

And that wraps up today’s lecture about how to implement database transactions in Golang. I hope you enjoyed it.

You can try to implement the update account balance yourself while waiting for the next lecture.

Happy coding and see you in the next article!


If you like the article, please subscribe to our Youtube channel and follow us on Twitter for more tutorials in the future.



Top comments (4)

Ahmed Mohamed

I'm getting
=== RUN TestTransferTx
panic: test timed out after 30s
How do I go about it?

Long Lê Ngọc

You should replace this code in main_test.go:

testDB, err := sql.Open(dbDriver, dbSource) with testDB, err = sql.Open(dbDriver, dbSource).
Because testDB and err are already declared, you shouldn't use :=, just use =

Vo Loc Anh

I got this panic when running TestTransferTx:
panic: runtime error: invalid memory address or nil pointer dereference
How do I go about it?

Husna Ramadan

If you're still stuck here, you can fix it.
I got the same panic, and it was my own mistake. Here we go.
In main_test.go, in the TestMain function, I typed:

func TestMain(m *testing.M) {
...
    testDB, err := sql.Open(dbDriver, dbSource)
...
}

I should have typed it without the colon:

testDB, err = sql.Open(dbDriver, dbSource)