Language/Python
[Python] RGB 채널에 대한 각각의 평균과 표준 편차 구하는 함수
ye_ju
2023. 11. 19. 14:01
def calculate_n(dataset):
# R,G,B 채널에 대한 각각의 평균 산출
mean_ = np.array([np.mean(x.numpy(), axis = (1, 2)) for x, _ in dataset]) # 전체 R,G,B의 평균
mean_r = mean_[:, 0].mean()
mean_g = mean_[:, 1].mean()
mean_b = mean_[:, 2].mean()
# R,G,B 채널에 대한 각각의 표준편차 산출
std_ = np.array([np.std(x.numpy(), axis = (1, 2)) for x, _ in dataset]) # 전체 R,G,B의 표준편차
std_r = std_[:, 0].std()
std_g = std_[:, 1].std()
std_b = std_[:, 2].std()
return (mean_r, mean_g, mean_b), (std_r, std_g, std_b)
mean_, std_ = calculate_n(train_loader)
print(f'평균(R,G,B): {mean_}\n표준편차(R,G,B): {std_}')