이미지 Dataset에 대한 평균과 표준편차 구하기

사진 이미지는 촬영된 주변 환경에 따라 그 명도나 채도 등이 서로 다릅니다. 이 사진 이미지를 대상으로 하는 머신러닝을 수행하기 전에 이미지들을 동일한 환경으로 맞춰주는 후처리로 전체 이미지에 대한 화소값의 평균과 표준편차를 구해 이 값을 이미지들에 일괄적으로 적용합니다.

아래의 코드는 PyTorch에서 Dataset에 대한 평균과 표준편차를 구하기 위한 코드입니다.

transform = transforms.Compose([
    transforms.ToTensor()
])

dataset = torchvision.datasets.CIFAR10(root='./data/cifar10', train=True, download=True, transform=transform) 

mean = dataset.train_data.mean(axis=(0,1,2))
std = dataset.train_data.std(axis=(0,1,2))

mean = mean / 255
std = std / 255

실제 mean과 std의 값은 각각 [0.4913, 0.4821, 0.4465], [0.2470, 0.2434, 0.2615]와 유사한데, 실제 CIFAR10 데이터를 이용한 딥러닝 예제 코드에서 상수값으로 입력되는 바로 그 값입니다. 실제로 이 평균과 편차는 다음 코드 예시를 통해 적용됩니다.

train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

train_dataset = torchvision.datasets.CIFAR10(root='./data/cifar10', train=True, download=True, transform=train_transform) 

“이미지 Dataset에 대한 평균과 표준편차 구하기”에 대한 5개의 댓글

  1. 안녕하세요. 이미지 dataset에서 평균과 표준편차를 찾는 과정이 궁금해서 오게 되었는데요.

    해당 코드를 돌려보니, ‘CIFAR10’ object has no attribute ‘train_data’ 라는 오류가 발생하는데 저 코드가 구동되는 것이 맞는지요?

    1. 안녕하세요..
      이 글을 올릴당시엔 정상적으로 수행되었던 코드일겁니다.
      디버깅에서 train_data 속성이 없다는 객체의 속성을 살펴보면 이와 유사한 놈이 있을겁니다.

    2. 안녕하세요? 저도 이 코드에서 에러가 나서 찾아보았는데요.
      dataset = torchvision.datasets.CIFAR10(root ~~~(생략) )으로 dataset 객체를 만들어주시고, 그 후에 print(dir(dataset))을 이용하면 해당 dataset 객체가 가지고 있는 attribute를 출력하실 수 있습니다.
      제가 확인해보니 train_data라고 하는 attribute가 없는게 맞네요.
      아마 torchvision 버전이 업그레이드 되면서 제외되었거나 그런게 아닌가? 하는 생각이 듭니다.(저도 쌩초보라 자세히는 모릅니다 ㅠㅠ)
      그래서 attribute를 찾아보니, ‘data’라고 하는 attribute가 있어서 해봤는데, 이게 위 글에서의 train_data와 같은 역할을 하는 것으로 보입니다.

      print(dataset.data)로 출력해보시면 CIFAR10의 데이터가 [0, 255] 형태로 출력되는 것을 확인하실 수 있을겁니다. 도움이 되셨으면 좋겠네요.

  2. mean = dataset.train_data.mean(axis=(0,1,2))
    std = dataset.train_data.std(axis=(0,1,2))
    mean = mean / 255
    std = std / 255
    이부분에서 왜 평균과 표준편차를 255로 나누었는지와
    axis=(0,1,2)의 의미가 궁금합니다!

    1. 256으로 나누는 이유는 정규화, 즉 0~1 사이의 값으로 만들어 주기 위함입니다.
      그리고 axis=(0,1,2)는 평균이나 표준편차를 낼 데이터들을 지정하기 위함입니다.
      axis에 대한 보다 더 명확한 의미는 구글링을 통해 관련 함수의 API 문서를 보시는 것을 권해 드립니다.

답글 남기기

이메일 주소는 공개되지 않습니다. 필수 필드는 *로 표시됩니다