Online Computation of Mean and Standard Deviation of an Image Set in Machine Learning

Standard definitions of the mean $m$ and the (population) standard deviation $\sigma$ of $n$ values $x_1, \dots, x_n$:

$$ m = \frac{\sum_{i=1}^n x_i}{n} $$

$$ \sigma = \sqrt{\frac{\sum_{i=1}^{n}(x_i - m)^2}{n}} $$
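As a quick reference point, here is a minimal sketch of these one-shot definitions on a small toy list (the data and the helper name `batch_mean_std` are illustrative only). Note the division by $n$, not $n - 1$: this is the population standard deviation, which is what Python's `statistics.pstdev` and NumPy's `np.std` (default `ddof=0`) compute.

```python
import math
import statistics

def batch_mean_std(xs):
    """One-shot mean and population standard deviation (divide by n)."""
    n = len(xs)
    m = sum(xs) / n
    sigma = math.sqrt(sum((x - m) ** 2 for x in xs) / n)
    return m, sigma

xs = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]  # toy data
m, sigma = batch_mean_std(xs)
print(m, sigma)  # 5.0 2.0

# statistics.pstdev uses the same 1/n normalization
assert math.isclose(sigma, statistics.pstdev(xs))
```

This requires every $x_i$ in memory at once, which is exactly what the online version below avoids.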

Image sets for machine (deep) learning are usually too large to load into memory at once, so computing $(m, \sigma)$ in one shot is impractical. Instead, we update $(m, \sigma)$ in an online fashion, i.e., each time a new sample $x_t$ arrives, using only the current running statistics and the number $n_{t-1}$ of samples seen so far. That is,

To update the mean $m$ (with $n_t = n_{t-1} + 1$):

$$m_t = \frac{m_{t-1} n_{t-1} + x_t}{n_{t-1} + 1}$$
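A minimal sketch of this update applied to a stream of scalars, assuming a toy list stands in for incoming samples (`update_mean` is a hypothetical helper name). Only `m` and `n` are kept between steps, yet the final value matches the one-shot mean.

```python
import math

def update_mean(m_prev, n_prev, x):
    """One online step: m_t = (m_{t-1} * n_{t-1} + x_t) / (n_{t-1} + 1)."""
    return (m_prev * n_prev + x) / (n_prev + 1)

stream = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]  # toy data standing in for incoming samples
m, n = 0.0, 0  # empty state; with n_prev = 0 the first step simply returns x_1
for x in stream:
    m = update_mean(m, n, x)
    n += 1  # n_t = n_{t-1} + 1

# the running mean agrees with the one-shot mean up to floating-point rounding
assert math.isclose(m, sum(stream) / len(stream))
```

The same step can be written equivalently as $m_t = m_{t-1} + (x_t - m_{t-1}) / n_t$, which avoids multiplying the running mean by an ever-growing count.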

To update the standard deviation $\sigma$, note that

$$\sigma^2 = \text{Var}(X) = \mathbb{E}[X^2] - \mathbb{E}[X]^2$$

We therefore update $\mathbb{E}[X^2]$ with the same rule as the mean, and then compute $\text{Var}(X)$ (or $\sigma$) from it:

$$\mathbb{E}[X^2]_t = \frac{\mathbb{E}[X^2]_{t-1} n_{t-1} + x^2_t}{n_{t-1} + 1}$$

$$\sigma^2_{t} = \text{Var}(X)_{t} = \mathbb{E}[X^2]_{t} - \mathbb{E}[X]_{t}^2 = \mathbb{E}[X^2]_t - m_t^2$$

where $m_t = \mathbb{E}[X]_t$.
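A minimal sketch of the two running moments on the same kind of toy stream (`update_moments` is a hypothetical helper name). After every step, $\sigma_t$ is available from $\mathbb{E}[X^2]_t - m_t^2$; at the end it matches the one-shot population standard deviation.

```python
import math
import statistics

def update_moments(m_prev, m2_prev, n_prev, x):
    """One online step for E[X] and E[X^2], following the two update formulas above."""
    n = n_prev + 1
    m = (m_prev * n_prev + x) / n
    m2 = (m2_prev * n_prev + x * x) / n
    return m, m2, n

stream = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]  # toy data
m, m2, n = 0.0, 0.0, 0
for x in stream:
    m, m2, n = update_moments(m, m2, n, x)
    # sigma_t^2 = E[X^2]_t - m_t^2; clip tiny negative values caused by rounding
    sigma = math.sqrt(max(m2 - m * m, 0.0))

assert math.isclose(m, statistics.mean(stream))
assert math.isclose(sigma, statistics.pstdev(stream))
```

Because $\sigma^2$ is obtained as the difference of two running averages, precision can be lost when $\mathbb{E}[X]^2$ is much larger than $\text{Var}(X)$ (catastrophic cancellation). For pixel values scaled to $[0, 1]$ in float64 this is typically harmless; Welford's algorithm, which updates the sum of squared deviations directly, is the usual more stable alternative.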

Sample code

```python
import os

import numpy as np
from PIL import Image
from tqdm.notebook import tqdm  # notebook progress bar; use `from tqdm import tqdm` in a plain script
```
```python
# get the image file names (the folder is assumed to contain only image files)
image_folder = './image_data/'
image_list = sorted(os.listdir(image_folder))

# number of image channels:
# 3 for natural RGB images; remote sensing (e.g., hyperspectral) data may have other counts
channels = 3  # for RGB

total_pxls = 0              # number of pixels (per channel) processed so far
mean = np.zeros(channels)   # running E[X] per channel
mean2 = np.zeros(channels)  # running E[X^2] per channel, used for the variance

for fn in tqdm(image_list, ascii=True):
    # read a single image; convert('RGB') guards against grayscale/RGBA/palette files (RGB data only)
    img = Image.open(os.path.join(image_folder, fn)).convert('RGB')
    num_pxls = img.size[0] * img.size[1]  # pixels in the current image (width * height)
    img = np.asarray(img, dtype=np.float64) / 255.0  # optional: scale 0-255 to 0-1

    # online update of E[X]: undo the previous averaging, add this image's per-channel sum
    sum1 = np.sum(img, axis=(0, 1))
    mean = mean * total_pxls + sum1
    # online update of E[X^2] in the same way
    sum2 = np.sum(img**2, axis=(0, 1))
    mean2 = mean2 * total_pxls + sum2
    # update the total pixel count, then re-normalize both running means
    total_pxls += num_pxls
    mean /= total_pxls
    mean2 /= total_pxls

# std = sqrt(E[X^2] - E[X]^2); clip tiny negative values caused by floating-point rounding
std = np.sqrt(np.maximum(mean2 - mean**2, 0.0))

print('mean: {}'.format(mean))
print('std: {}'.format(std))
```
Output:

```
mean: [0.31567586 0.34446269 0.25939408]
std: [0.16605957 0.14458521 0.13643346]
```
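The loop above applies the update formulas to a whole image at a time rather than a single pixel: with $k$ new pixels whose per-channel sums are $S_1 = \sum x$ and $S_2 = \sum x^2$, the mean becomes $(m_{t-1} n_{t-1} + S_1) / (n_{t-1} + k)$, and likewise for $\mathbb{E}[X^2]$ with $S_2$. The sketch below packages that step as a hypothetical helper `update_stats` and checks it on small random arrays (stand-ins for images, not the original data set) against the statistics of all pixels pooled together.

```python
import numpy as np

def update_stats(mean, mean2, total, batch):
    """Merge one (H, W, C) array into running per-channel E[X] and E[X^2]."""
    k = batch.shape[0] * batch.shape[1]       # number of new pixels
    s1 = batch.sum(axis=(0, 1))               # per-channel sum of x
    s2 = (batch ** 2).sum(axis=(0, 1))        # per-channel sum of x^2
    new_total = total + k
    return (mean * total + s1) / new_total, (mean2 * total + s2) / new_total, new_total

# hypothetical stand-in data: random 3-channel "images" of different sizes
rng = np.random.default_rng(0)
images = [rng.random((h, w, 3)) for h, w in [(4, 5), (8, 3), (6, 6)]]

mean, mean2, total = np.zeros(3), np.zeros(3), 0
for img in images:
    mean, mean2, total = update_stats(mean, mean2, total, img)
std = np.sqrt(np.maximum(mean2 - mean ** 2, 0.0))

# compare against one-shot statistics over all pixels pooled together
pooled = np.concatenate([img.reshape(-1, 3) for img in images])
assert np.allclose(mean, pooled.mean(axis=0))
assert np.allclose(std, pooled.std(axis=0))
```

The arrays deliberately have different sizes: weighting each image by its pixel count is what makes the result equal to the pooled statistics, whereas simply averaging per-image means or standard deviations would not, in general.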