Synchronization Patterns in Go

By Noam Yadgar [Tue, 21 Mar 2023 06:12:09 GMT] (#8)
noam.g4@gmail.com

This article is a bit advanced; a more basic lesson is available separately.

The Go programming language is all about concurrency. Spawning multiple
goroutines truly reveals the power of Go programs, as it lets us divide
our program into concurrent (“parallel”) processes.
But as you may know, opening the door to multithreaded programming comes
with a whole set of new issues we, as developers and software engineers,
have to consider. Most of them are related to shared memory and synchronization.

From identifying the issues to applying the right pattern,
here are some things we can do in Go
to write better multithreaded programs.

Mutex (Mutual Exclusion)

Plenty of places online formally explain what a Mutex is,
so let’s not repeat that here.
Instead, let’s try to walk this path from the other side. We’ll derive the need
for such a mechanism, and hopefully, this will help you understand what a Mutex is
all about. Check out this code:

package main

import (
	"fmt"
)

var counter = 0

func main() {
	for i := 0; i < 1000; i++ {
		go func() {
			counter++
		}()
	}

	fmt.Println(counter)
}

This main function spawns 1000 goroutines, each incrementing the counter variable by one.
The main function finishes by printing the counter value. So…

  • 1000 goroutines
  • Each is incrementing the counter by one (starting from zero)
  • Printing the counter value

We should see the output 1000, right?

> go run main.go 
973

Wrong! I think you know why. This is not synchronous code.
Our 1000 goroutines are all running concurrently, which means that the line
fmt.Println(counter) may have been called before all 1000 goroutines have finished.

Let’s fix it by using a sync.WaitGroup, so that
fmt.Println(counter) will wait for all goroutines to finish:

package main

import (
	"fmt"
	"sync"
)

var (
	n       = 1000
	counter = 0
	wg      = &sync.WaitGroup{}
)

func main() {
	wg.Add(n)
	for i := 0; i < n; i++ {
		go func() {
			defer wg.Done()
			counter++
		}()
	}

	wg.Wait()
	fmt.Println(counter)
}

wg.Wait() will block the main goroutine until all 1000 goroutines have
called wg.Done(). Cool, let’s run it:

> go run main.go 
896

What the…? Still? Shouldn’t counter have the value of 1000?
Why is waiting for all goroutines not enough?

My friend, this is the most classic form of a Race Condition.
You see, goroutines don’t know about each other, even if they’re all accessing
the same resources (such as the counter variable). Multiple goroutines
are accessing the counter variable simultaneously.

Since counter++ is not a single atomic step (it reads counter, adds one, and
writes the result back, just like counter = counter + 1),
some goroutines will read the same value from counter and therefore
write back the same incremented value. One of those increments is lost.
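To make the lost update concrete, here is a minimal sketch that replays one unlucky interleaving by hand. It runs in a single goroutine, so the result is deterministic. The function lostUpdate and the variables a and b are made up for illustration; they stand in for two goroutines that both read counter before either one writes it back.

```go
package main

import "fmt"

// lostUpdate replays, step by step, what two goroutines may do when they
// both execute counter++ at the same time. It is a hypothetical helper,
// written sequentially so the outcome is always the same.
func lostUpdate() int {
	counter := 0

	a := counter // goroutine A reads 0
	b := counter // goroutine B reads 0, before A has written anything

	counter = a + 1 // A writes 1
	counter = b + 1 // B also writes 1, overwriting A's increment

	return counter
}

func main() {
	fmt.Println(lostUpdate()) // two increments, but the counter is only 1
}
```

In the real program the scheduler decides the interleaving, which is why the printed value changes from run to run.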

Solution

We need a mechanism that allows goroutines to get hold of a certain “flag”
using an atomic operation. Being atomic ensures that only one
goroutine can acquire the flag at a time. Whoever is holding this “flag”
will block others until the flag is released. Technically speaking,
a goroutine that invokes mutex.Lock() while another goroutine holds the mutex
will be blocked until the mutex “flag” is released. Once it is released
(with mutex.Unlock()), a blocked goroutine will be able to acquire the “flag” and proceed.
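The “flag” behavior can be observed without blocking anything. The sketch below assumes Go 1.18 or newer, where sync.Mutex has a TryLock method that reports whether the lock could be taken. The helper flagStates is a made-up name, and TryLock is used here only to peek at the state; ordinary code would normally just call Lock.

```go
package main

import (
	"fmt"
	"sync"
)

// flagStates is a hypothetical helper that reports whether the mutex could
// be acquired while it is already held, and again after it is released.
func flagStates() (whileHeld, afterRelease bool) {
	var mu sync.Mutex

	mu.Lock()                // take the "flag"
	whileHeld = mu.TryLock() // false: the flag is already taken
	mu.Unlock()              // release the "flag"

	afterRelease = mu.TryLock() // true: the flag was free
	if afterRelease {
		mu.Unlock()
	}
	return whileHeld, afterRelease
}

func main() {
	fmt.Println(flagStates()) // false true
}
```

Where TryLock returns false, a plain Lock call would have waited until the holder called Unlock.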

If we put a Mutex lock in front of the line counter++, we can ensure that
only a single goroutine can access the counter variable at a time.

package main

import (
	"fmt"
	"sync"
)

var (
	n       = 1000
	counter = 0
	wg      = &sync.WaitGroup{}
	mutex   = &sync.Mutex{}
)

func main() {
	wg.Add(n)
	for i := 0; i < n; i++ {
		go func() {
			defer wg.Done()
			mutex.Lock()
			defer mutex.Unlock()
			counter++
		}()
	}

	wg.Wait()
	fmt.Println(counter)
}

> go run main.go
1000

This program will print 1000 every time, guaranteed.
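As a side note, for a plain counter like this one, the standard library’s sync/atomic package is another option. This is only a sketch of an alternative, not a replacement for understanding the Mutex; the helper name countAtomically is made up for illustration.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// countAtomically is a hypothetical helper: it starts n goroutines that each
// add one to a shared counter using an atomic add instead of a mutex.
func countAtomically(n int) int64 {
	var (
		counter int64
		wg      sync.WaitGroup
	)

	wg.Add(n)
	for i := 0; i < n; i++ {
		go func() {
			defer wg.Done()
			atomic.AddInt64(&counter, 1) // read, add, and write as one step
		}()
	}

	wg.Wait()
	return atomic.LoadInt64(&counter)
}

func main() {
	fmt.Println(countAtomically(1000)) // 1000
}
```

A Mutex remains the more general tool, since it can protect any block of code and not just a single number.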

Mutex is fairly simple, let’s dive deeper

Goroutines aren’t free.
Writing super fast, multithreaded programs can be very exciting,
so exciting that it’s sometimes easy to forget that
our code is running in a resource-limited environment.

Suppose we have a veryExpensive() function with the following specs:

  • It takes about 2 seconds to complete
  • It uses about 1 MB of memory
  • It runs in an environment with 1 GB of memory

Invoking go veryExpensive() a thousand times can easily reach
the memory limit (a thousand times 1 MB is roughly the whole 1 GB). Therefore,
it’s important to know how to control the number of goroutines running at a time.

Semaphore

Check out this code…

package main

import (
	"fmt"
	"runtime"
	"time"
)

const n = 10

func main() {

	for i := 0; i < n; i++ {
		go func() {
			time.Sleep(time.Second)
		}()

		// print the number of running goroutines without the main goroutine
		fmt.Printf("currently running: %d goroutines\n", runtime.NumGoroutine()-1)
	}
}

We are spawning 10 goroutines, each taking a second to complete, and
by using runtime.NumGoroutine()-1 we are printing the number of currently
running goroutines. How many goroutines will run simultaneously at the most?

> go run main.go 
currently running: 1 goroutines
currently running: 2 goroutines
currently running: 3 goroutines
currently running: 4 goroutines
currently running: 5 goroutines
currently running: 6 goroutines
currently running: 7 goroutines
currently running: 8 goroutines
currently running: 9 goroutines
currently running: 10 goroutines

If you thought 10 (or more accurately n), well done.
A second is a pretty long time for most computers. Without restricting the
number of goroutines, you can easily reach n goroutines. In fact,
n can probably be a lot bigger before our for loop takes longer than
a single goroutine to complete.

But what if we don’t know the size of n in advance? How can we make sure we’re
not reaching the resource limit?

package main

import (
	"context"
	"fmt"
	"runtime"
	"time"

	"golang.org/x/sync/semaphore"
)

const (
	n          = 10
	goroutines = 2
)

var sm = semaphore.NewWeighted(goroutines)

func main() {

	ctx := context.Background()
	for i := 0; i < n; i++ {
		sm.Acquire(ctx, 1)
		go func() {
			defer sm.Release(1)
			time.Sleep(time.Second)
		}()

		// print the number of running goroutines without the main goroutine
		fmt.Printf("currently running: %d goroutines\n", runtime.NumGoroutine()-1)

	}
}

A Semaphore is very similar to a Mutex, but with a key difference:

To put it simply, while a Mutex has a single “flag”, a Semaphore has a
number of “seats” (usually represented as a primitive type such as an integer counter).
So if const goroutines is set to 2, the third call to
sm.Acquire(ctx, 1) will block until one of the “seats” becomes available.
Technically speaking, only 2 calls to sm.Acquire(ctx, 1) can succeed and proceed
at a time. The rest will have to wait until sm.Release(1) is invoked.

> go run main.go 
currently running: 1 goroutines
currently running: 2 goroutines
currently running: 2 goroutines
currently running: 2 goroutines
currently running: 2 goroutines
currently running: 2 goroutines
currently running: 2 goroutines
currently running: 2 goroutines
currently running: 2 goroutines
currently running: 2 goroutines

If we waited for every goroutine to finish, the first version of this program
would take about a second, while this version would take about 5 seconds
(10 tasks, 2 at a time, 1 second each).
We are paying a price in performance for program stability and resiliency.
In this kind of trade-off, it’s important to find a good balance.

Note that there’s a subtle design choice here.
I’ve decided to call the sm.Acquire function outside of the goroutine,
right from the main goroutine. This blocks the main goroutine from
spawning more goroutines. We could also choose to spawn the goroutine first
and block inside it.
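Here is a rough sketch of that second choice, where every goroutine is spawned immediately and waits for a “seat” from the inside. To keep it runnable with only the standard library, it swaps semaphore.NewWeighted for a buffered channel used as a counting semaphore, which is a different technique from the one above. The helper maxConcurrent and its bookkeeping variables are made up for illustration.

```go
package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"
)

// maxConcurrent is a hypothetical helper. It spawns n goroutines right away,
// lets at most limit of them work at once, and returns the highest number of
// goroutines that were observed working at the same time.
func maxConcurrent(n, limit int) int64 {
	seats := make(chan struct{}, limit) // one buffer slot per "seat"
	var (
		wg      sync.WaitGroup
		running int64
		peak    int64
	)

	wg.Add(n)
	for i := 0; i < n; i++ {
		go func() {
			defer wg.Done()

			seats <- struct{}{}        // take a seat; blocks when all are taken
			defer func() { <-seats }() // give the seat back on return

			now := atomic.AddInt64(&running, 1)
			// Record the peak with a compare-and-swap loop.
			for {
				old := atomic.LoadInt64(&peak)
				if now <= old || atomic.CompareAndSwapInt64(&peak, old, now) {
					break
				}
			}

			time.Sleep(10 * time.Millisecond) // stand-in for real work
			atomic.AddInt64(&running, -1)
		}()
	}

	wg.Wait()
	return atomic.LoadInt64(&peak)
}

func main() {
	fmt.Println(maxConcurrent(10, 2) <= 2) // true
}
```

The trade-off is that all n goroutines exist at once, even though only 2 of them are doing work, so this variant limits the work but not the number of goroutines.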

Go Channels

Mutex and Semaphore are general concepts in computer science, which Go provides through packages.
Channels, in contrast, are built into the Go language itself and are designed specifically
for communication and synchronization between goroutines.

While it’s quite easy to mimic the behavior of a Mutex, a Semaphore and a WaitGroup
using channels, there are a lot of cases where using channels may eliminate the need
for such patterns, as channels direct us towards simpler yet effective designs.

package main

import "fmt"

const n = 1000
var counter = 0

func main() {
	c := make(chan int)
	defer close(c)

	for i := 0; i < n; i++ {
		go func() {
			c <- 1
		}()
	}

	for counter < n {
		counter += <-c
	}

	fmt.Println(counter)
}

This process is similar to the Mutex example from earlier,
but arguably better in design:

  • We don’t need a Mutex lock
  • We don’t need a WaitGroup
  • Goroutines are not reading and reassigning shared memory
  • Fewer dependencies and steps

Channels have a very efficient blocking mechanism. Whenever a goroutine
tries to send to a full channel (or to an unbuffered channel like c, which has
no room at all), it is blocked until some other goroutine
receives a value from the channel. It’s kind of like a pipe: if we close one end
of the pipe, the pipe fills up until the flow stops.
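To see when a send would block, here is a small sketch that uses select with a default case to attempt a send without waiting. The helper trySend is a made-up name, and the buffered channel with capacity 1 is chosen only to keep the example short.

```go
package main

import "fmt"

// trySend is a hypothetical helper: it attempts a send without blocking and
// reports whether the channel accepted the value.
func trySend(c chan int, v int) bool {
	select {
	case c <- v:
		return true
	default:
		return false // a plain send would have blocked here
	}
}

func main() {
	c := make(chan int, 1) // buffered channel with room for one value

	fmt.Println(trySend(c, 1)) // true: the buffer had a free slot
	fmt.Println(trySend(c, 2)) // false: the buffer is full
	fmt.Println(<-c)           // 1: receiving frees the slot
	fmt.Println(trySend(c, 3)) // true: there is room again
}
```

With an unbuffered channel and no goroutine waiting to receive, trySend would return false on the very first call.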

Let’s have some fun with Channels

We’ll write a program that requests a bunch of web pages and sums up the number
of times the word “the” shows up across all the pages.

But instead of requesting and processing each web page synchronously,
we will put every request in its own goroutine and accumulate
the results in our main goroutine.

We’ll start by defining what word we’re counting (in this case “the”),
defining a set of URLs we would like to request (2 of them are invalid),
and writing a function to get the web page’s body as a string:

package main

import (
	"fmt"
	"io"
	"net/http"
)

const word = "the"

var urls = []string{
	"https://code-pilot.me/how-ive-made-this-platform",
	"https://code-pilot.me/making-a-beautiful-error-handler-in-go",
	"https://code-pilot.me/mastering-goroutines-and-channels",
	"https://code-pilot.me/why-should-you-curry",
	"https://code-pilot.me/not-a-real-page",
	"not@validURL",
}

func get(url string) (string, error) {
	res, err := http.Get(url)
	if err != nil {
		return "", err
	}
	// Close the body on every return path so the connection can be reused.
	defer res.Body.Close()

	if res.StatusCode == 404 {
		return "", fmt.Errorf("not found: %s", url)
	}

	body, err := io.ReadAll(res.Body)
	if err != nil {
		return "", err
	}

	return string(body), nil
}

Now let’s declare 3 channels:

  • resChan to receive the sum of occurrences on each page
  • errChan to receive errors if there are any
  • sigChan to receive OS signals such as INTERRUPT (Ctrl+C)

func main() {
	resChan := make(chan int)
	errChan := make(chan error)
	sigChan := make(chan os.Signal, 1)

	signal.Notify(sigChan, os.Interrupt)

}

Let’s spawn each get(url) in a separate goroutine.
If there’s an error, we’ll send it to errChan;
if not, we’ll send the count to resChan:

func main() {
  /* .
     .
     .  */
	for _, url := range urls {
		go func(url string) {
			body, err := get(url)
			if err != nil {
				errChan <- err
				return
			}

			y := strings.Count(strings.ToLower(body), word)
			fmt.Printf("found %d occurrences of the word \"%s\" in %s\n", y, word, url)
			resChan <- y
		}(url)
	}
}
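One caveat worth knowing: strings.Count counts substring matches, so words such as “other” or “there” also add to the total. If whole-word matches are what we want, a sketch along these lines could be used instead. The helper countWord is hypothetical and is not part of the program in this article.

```go
package main

import (
	"fmt"
	"strings"
	"unicode"
)

// countWord is a hypothetical helper that counts whole-word,
// case-insensitive matches of word in text.
func countWord(text, word string) int {
	// Split on anything that is not a letter or a digit.
	tokens := strings.FieldsFunc(strings.ToLower(text), func(r rune) bool {
		return !unicode.IsLetter(r) && !unicode.IsDigit(r)
	})

	count := 0
	for _, t := range tokens {
		if t == word {
			count++
		}
	}
	return count
}

func main() {
	text := "The other day, the cat sat there."
	fmt.Println(strings.Count(strings.ToLower(text), "the")) // 4
	fmt.Println(countWord(text, "the"))                       // 2
}
```

Since get returns the raw page body, either approach also looks inside the HTML markup, so the numbers are best read as rough counts.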

Finally, let’s make our main goroutine listen and handle all incoming data from all channels:

  • If data is coming from resChan -> Accumulate it into a sum variable
  • If an error is coming from errChan -> Print the error
  • If sigChan receives an INTERRUPT signal -> Print the total sum and exit with status 0

This can be achieved using a select block:

func main() {
  /* .
     .
     .  */
	sum := 0
	for {
		select {
		case x := <-resChan:
			sum += x
		case err := <-errChan:
			fmt.Println(err.Error())
		case <-sigChan:
			fmt.Printf("\nfound %d occurrences of the word \"%s\" in total\n", sum, word)
			os.Exit(0)
		}
	}
}
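The loop above runs until we press Ctrl+C. Since every goroutine sends exactly one message (a result or an error), a possible variation is to receive exactly as many messages as goroutines were started, and then stop. The sketch below shows the idea with made-up stand-ins: collect is a hypothetical helper, the jobs are plain integers instead of URLs, and negative numbers play the role of failed requests.

```go
package main

import (
	"errors"
	"fmt"
)

// collect is a hypothetical helper. It starts one goroutine per job and
// receives exactly len(jobs) messages, so the loop ends without a signal.
// Negative numbers are treated as failures to stand in for failed requests.
func collect(jobs []int) (sum, failures int) {
	resChan := make(chan int)
	errChan := make(chan error)

	for _, job := range jobs {
		go func(job int) {
			if job < 0 {
				errChan <- errors.New("job failed")
				return
			}
			resChan <- job
		}(job)
	}

	// Every goroutine sends exactly one message on one of the two channels.
	for range jobs {
		select {
		case x := <-resChan:
			sum += x
		case <-errChan:
			failures++
		}
	}
	return sum, failures
}

func main() {
	fmt.Println(collect([]int{45, 51, -1, 126, 44, -1})) // 266 2
}
```

Keeping sigChan in the select is still useful when requests can hang, since it lets us bail out early.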

Here’s the whole thing, let’s run it…

package main

import (
	"fmt"
	"io"
	"net/http"
	"os"
	"os/signal"
	"strings"
)

const word = "the"

var urls = []string{
	"https://code-pilot.me/how-ive-made-this-platform",
	"https://code-pilot.me/making-a-beautiful-error-handler-in-go",
	"https://code-pilot.me/mastering-goroutines-and-channels",
	"https://code-pilot.me/why-should-you-curry",
	"https://code-pilot.me/not-a-real-page",
	"not@validURL",
}

func get(url string) (string, error) {
	res, err := http.Get(url)
	if err != nil {
		return "", err
	}
	// Close the body on every return path so the connection can be reused.
	defer res.Body.Close()

	if res.StatusCode == 404 {
		return "", fmt.Errorf("not found: %s", url)
	}

	body, err := io.ReadAll(res.Body)
	if err != nil {
		return "", err
	}

	return string(body), nil
}

func main() {

	resChan := make(chan int)
	errChan := make(chan error)
	sigChan := make(chan os.Signal, 1)

	signal.Notify(sigChan, os.Interrupt)

	for _, url := range urls {
		go func(url string) {
			body, err := get(url)
			if err != nil {
				errChan <- err
				return
			}

			y := strings.Count(strings.ToLower(body), word)
			fmt.Printf("found %d occurrences of the word \"%s\" in %s\n", y, word, url)
			resChan <- y
		}(url)
	}

	sum := 0
	for {
		select {
		case x := <-resChan:
			sum += x
		case err := <-errChan:
			fmt.Println(err.Error())
		case <-sigChan:
			fmt.Printf("\nfound %d occurrences of the word \"%s\" in total\n", sum, word)
			os.Exit(0)
		}
	}
}

> go run main.go
Get "not@validURL": unsupported protocol scheme ""
not found: https://code-pilot.me/not-a-real-page
found 51 occurrences of the word "the" in https://code-pilot.me/making-a-beautiful-error-handler-in-go
found 44 occurrences of the word "the" in https://code-pilot.me/why-should-you-curry
found 126 occurrences of the word "the" in https://code-pilot.me/mastering-goroutines-and-channels
found 45 occurrences of the word "the" in https://code-pilot.me/how-ive-made-this-platform
^C
found 266 occurrences of the word "the" in total

Wow, 266 is more than what I expected :) This is where I’m going to end. I hope you found this article interesting.
If you found any mistakes or have any suggestions, you can contact me (email above).
Thank you for reading.