FashionMNIST 데이터셋 이미지 생성 실험
들어가며
목표
- 모델을 활용하여 FashionMNIST 데이터셋의 각 패션 아이템(예: 티셔츠, 바지, 스니커즈 등)을 조건부로 생성하는 작업을 수행
- 각 클래스에 해당하는 이미지를 생성하는 cGAN (Conditional GAN) 모델을 직접 설계, 학습
사용 데이터셋 : 28 x 28의 이미지의 10개의 클래스를 가지고 있는 FashionMNIST 데이터셋
클래스 목록 :
- T-shirt/top
- Trouser
- Pullover
- Dress
- Coat
- Sandal
- Shirt
- Sneaker
- Bag
- Ankle boot
사용 모델

비교 목표
- One-Hot Encoding과 Embedding방식의 성능 차이 비교
- GAN과 LDM의 생성 이미지 차이 비교
FashionMNIST같은 간단한 데이터셋에서 조건 추가 방식 차이와 GAN vs LDM의 이미지 차이는 어떠한가 ?
파이프라인
1. 데이터 EDA
2. 데이터 전처리/ 로더생성
3. 모델링 - GAN
4. 모델 학습 / 조건 추가 방식 비교
5. LDM 모델링 / 성능지표, 생성 이미지 비교
6. 결론
Imports
!pip install torchmetrics[image]
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import os
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import time
import matplotlib.image as mpimg
from torchmetrics.image import PeakSignalNoiseRatio
from torchmetrics.image.fid import FrechetInceptionDistance
1. 데이터 EDA
1-1. 데이터 불러오기
dataset = datasets.FashionMNIST(
root='./data', # 데이터를 저장할 폴더 경로
train=True, # True = 학습용(60,000장) / False = 테스트용(10,000장)
download=True, # 해당 경로에 데이터 없으면 자동으로 다운로드
transform=None # 이미지에 적용할 전처리
)
#경로 지정
data_dir = './data/FashionMNIST'
raw_data_dir = os.path.join(data_dir, 'raw')
for file in os.listdir(raw_data_dir):
print(file)

● train-images-idx3-ubyte : 학습 이미지 60,000장
● train-labels-idx1-ubyte : 학습 라벨 60,000개
● t10k-images-idx3-ubyte : 테스트 이미지 10,000장
● t10k-labels-idx1-ubyte : 테스트 라벨 10,000개
본래 FashionMINST 데이터셋은 분류에 이용되는 데이터셋이기 때문에 test dataset이 따로 있지만, 이미지 생성 태스크에서는 필요하지 않기 때문에 train dataset 60000장만 사용하기로 했다.
# 데이터셋 전체 크기
print(f"데이터 수: {len(dataset)}")
# 이미지 한 장 크기
image, label = dataset[0]
print(f"이미지 shape: {image.size}")
#이미지 색상공간
print(f"이미지 색상공간 : {image.mode}")

1-2. 클래스 확인
# 클래스 리스트 확인
class_names = dataset.classes
print(f"클래스 목록: {class_names}")
# 클래스별 데이터 수 확인
labels = [label for _, label in dataset]
unique, counts = np.unique(labels, return_counts=True)
for name, count in zip(class_names, counts):
print(f"{name}: {count}장")
# 클래스 분포 시각화
plt.figure(figsize=(5, 5))
plt.barh(class_names, counts)
plt.title('Classes')
plt.xlabel('Data')
plt.ylabel('Classes')
plt.tight_layout()
plt.show()

균등하게 6000장씩 데이터가 분배되어있는걸 확인할 수 있었다.
1-3. 이미지 확인
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
for i, ax in enumerate(axes.flatten()):
# 각 클래스에서 이미지 하나씩 꺼내기
idx = next(j for j, (_, label) in enumerate(dataset) if label == i)
image, label = dataset[idx]
ax.imshow(image, cmap='gray')
ax.set_title(class_names[i])
ax.axis('off')
plt.suptitle('Sample Image(By Classes)', fontsize=14)
plt.tight_layout()
plt.show()

픽셀 수가 많지는 않아(28 x 28), 선명하게 보이진 않는다.
2. 데이터 전처리 / 로더 생성
Custom한 데이터셋을 만들 필요 없이, datasets.FashionMNIST를 이용하여 데이터로더를 생성했다.
#transform 정의
my_transform = transforms.Compose([
transforms.RandomHorizontalFlip(), # 데이터 증강 : 좌우 반전
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,)) #tanh위해서 [-1, 1]범위로 초기화
])
#본격적으로 transform 적용한 데이터셋 불러오기
ds = datasets.FashionMNIST(
root='./data',
train=True,
download=True,
transform=my_transform
)
loader = DataLoader(
ds,
batch_size=128,
shuffle=True,
num_workers = 2,
pin_memory = True
)
3. 모델링
3-1. Generator 생성
one-hot으로 인코딩하여 채널을 추가하는 버전과, nn.Embedding을 활용하여 채널을 추가하는 버전을 만들었다.
one-hot encoding 버전 : torch.cat을 활용해 라벨 데이터를 추가했고, one-hot encoding 되어서 총 차원 수는 클래스 수인 10개가 늘어나게 된다.
#생성자 정의 - one-hot 버전
class Generator_OneHot(nn.Module):
def __init__(self, noise_dim, num_classes=10):
super(Generator_OneHot, self).__init__()
# noise_dim + num_classes (100 + 10 = 110)
self.upsample = nn.Sequential(
#커널 크기를 7로 하여 출력 크기를 28 / 4 = 7로 맞
nn.ConvTranspose2d(noise_dim + num_classes, 128, 7, 1, 0, bias=False), # (128, 7, 7)
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False), # (64, 14, 14)
nn.BatchNorm2d(64),
nn.ReLU(True),
#최종 출력 : 흑백 이미지
nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False), # (1, 28, 28)
)
self.tanh = nn.Tanh()
def forward(self, noise, label):
# label → one-hot encoding
label_onehot = F.one_hot(label, num_classes=10) # (batch, 10)
label_onehot = label_onehot.view(-1, 10, 1, 1).float() # (batch, 10, 1, 1)
x = torch.cat([noise, label_onehot], dim=1) # (batch, 110, 1, 1)
x = self.upsample(x)
x = self.tanh(x)
return x
Embedding을 사용하는 버전에서는 one-hot인코딩을 하는 버전과 공평하게 입력 채널 수를 맞춰주기 위해서 embedded 차원을 10으로 특정했다.
#생성자 정의 - nn.Embedding 버전
class Generator_Embedding(nn.Module):
def __init__(self, noise_dim, num_classes=10, embed_dim=10):
super(Generator_Embedding, self).__init__()
# 라벨 임베딩
self.embedding = nn.Embedding(num_classes, embed_dim)
# noise_dim + embed_dim (100 + 10 = 110)
self.upsample = nn.Sequential(
nn.ConvTranspose2d(noise_dim + embed_dim, 128, 7, 1, 0, bias=False), # (128, 7, 7)
nn.BatchNorm2d(128),
nn.ReLU(True),
nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False), # (64, 14, 14)
nn.BatchNorm2d(64),
nn.ReLU(True),
nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False), # (1, 28, 28)
)
self.tanh = nn.Tanh()
def forward(self, noise, label):
#라벨 : nn.Embedding
embed = self.embedding(label) # (batch, 10)
embed = embed.view(-1, 10, 1, 1) # (batch, 10, 1, 1)
x = torch.cat([noise, embed], dim=1) # (batch, 110, 1, 1)
x = self.upsample(x)
x = self.tanh(x)
return x
Generator 출력 shape, 초기 이미지를 확인하여 올바르게 설계가 되었는지 확인했다.
# 테스트용 입력
noise = torch.randn(1, 100, 1, 1) # batch=1, noise_dim=100
label = torch.tensor([0]) # 클래스 0 (T-shirt/Top)
# one-hot Generator 확인
g_onehot = Generator_OneHot(noise_dim=100, num_classes=10)
fake_onehot = g_onehot(noise, label)
print(f'[one-hot] fake image shape : {fake_onehot.shape}')
# Embedding Generator 확인
g_embed = Generator_Embedding(noise_dim=100, num_classes=10, embed_dim=10)
fake_embed = g_embed(noise, label)
print(f'[Embedding] fake image shape : {fake_embed.shape}')

# 초기 노이즈시각화
fig, axes = plt.subplots(1, 2, figsize=(8, 4))
axes[0].imshow(fake_onehot.squeeze().detach().numpy(), cmap='gray')
axes[0].set_title('one-hot Generator Noise')
axes[0].axis('off')
axes[1].imshow(fake_embed.squeeze().detach().numpy(), cmap='gray')
axes[1].set_title('Embedding Generator Noise')
axes[1].axis('off')
plt.suptitle('Generator Noise (Before)')
plt.tight_layout()
plt.show()

shape와 초기 노이즈를 확인했을 때, 올바르게 설계된 것을 확인할 수 있었다.
3-2. Discriminator 생성
One-Hot 버전 Discriminator
#One-hot 버전 판별자
class Discriminator_OneHot(nn.Module):
def __init__(self, num_classes = 10) :
super(Discriminator_OneHot, self).__init__()
self.downsample = nn.Sequential(
# (batch, 11, 28, 28) → (batch, 64, 14, 14)
nn.Conv2d(1 + num_classes, 64, 4, 2, 1, bias = False),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2, inplace = True),
# (batch, 64, 14, 14) → (batch, 128, 7, 7)
nn.Conv2d(64, 128, 4, 2, 1, bias = False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace = True),
# (batch, 128, 7, 7) → (batch, 1, 1, 1)
nn.Conv2d(128, 1, 7, 1, 0, bias = False)
)
self.sigmoid = nn.Sigmoid()
def forward(self, image, label) :
# label → one-hot → 이미지 크기로 펼치기
label_onehot = F.one_hot(label, num_classes=10) # (batch, 10)
label_onehot = label_onehot.view(-1, 10, 1, 1).float() # (batch, 10, 1, 1)
label_onehot = label_onehot.expand(-1, 10, 28, 28) # (batch, 10, 28, 28)
x = torch.cat([image, label_onehot], dim=1) # (batch, 11, 28, 28)
x = self.downsample(x) # (batch, 1, 1, 1)
x = self.sigmoid(x)
return x.view(-1, 1) # (batch, 1)
Embedding 버전 Discriminator
#Embedding 버전 판별
class Discriminator_Embedding(nn.Module):
def __init__(self, num_classes = 10) :
super(Discriminator_Embedding, self).__init__()
self.embedding = nn.Embedding(num_classes, num_classes)
self.downsample = nn.Sequential(
# (batch, 11, 28, 28) → (batch, 64, 14, 14)
nn.Conv2d(1 + num_classes, 64, 4, 2, 1, bias = False),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.2, inplace = True),
# (batch, 64, 14, 14) → (batch, 128, 7, 7)
nn.Conv2d(64, 128, 4, 2, 1, bias = False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.2, inplace = True),
# (batch, 128, 7, 7) → (batch, 1, 1, 1)
nn.Conv2d(128, 1, 7, 1, 0, bias = False)
)
self.sigmoid = nn.Sigmoid()
def forward(self, image, label) :
# label → one-hot → 이미지 크기로 펼치기
# Embedding 방식
emb = self.embedding(label) # (batch, 10)
emb = emb.view(-1, 10, 1, 1).float() # (batch, 10, 1, 1)
emb = emb.expand(-1, 10, 28, 28) # (batch, 10, 28, 28)
x = torch.cat([image, emb], dim=1) # (batch, 11, 28, 28)
x = self.downsample(x) # (batch, 1, 1, 1)
x = self.sigmoid(x)
return x.view(-1, 1) # (batch, 1)
4. 학습 / 평가, 시각화
FID 계산 위한 denorm()함수 정의
# -1~1사이 값을 0~255로 스케일 : fid 계산 위해서
def denorm(x):
x = (x + 1) / 2 # -1~1 → 0~1
x = x.repeat(1, 3, 1, 1) # 1채널 → 3채널
x = (x * 255).byte() # 0~1 → 0~255
return x
#device 설정
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
# 카테고리별 고정 noise (매번 같은 noise로 생성해서 비교)
fixed_noise = torch.randn(10, 100, 1, 1).to(device)
fixed_labels = torch.arange(0, 10).to(device) # 0~9 클래스 하나씩
#평가 지표 설정
psnr= PeakSignalNoiseRatio(data_range = 1.0).to(device)
fid = FrechetInceptionDistance().to(device)
4-1. One-Hot Encoding 버전 학습
손실함수로는, BCELOSS대신 보다 안정적인 LSGANLoss(MSELoss)를 사용했다. 또한 10 epoch마다 각 카테고리별 출력이미지를 저장해서 ,학습후 한번에 출력하여 관찰하여 학습이 되는 과정을 관찰할 수 있었다.
#One-Hot 버전 학습/평가
epochs = 30
#One-Hot모델 초기화
netG_OneHot = Generator_OneHot(noise_dim=100).to(device)
netD_OneHot = Discriminator_OneHot().to(device)
#손실함수 설정
criterion = nn.MSELoss()
#옵티마이저 설정
G_OneHot_optimizer = optim.Adam(netG_OneHot.parameters(), lr=0.0002)
D_OneHot_optimizer = optim.Adam(netD_OneHot.parameters(), lr=0.0001)
OneHot_history = {
'G_loss': [],
'D_loss': [],
'psnr' : [],
'fid' : [],
'time' : 0.0
}
best_fid = float('inf')
os.makedirs('results/OneHot', exist_ok=True)
os.makedirs('checkpoints', exist_ok = True)
print('학습 시작')
start_time = time.time()
for epoch in range(epochs):
G_loss_sum = 0
D_loss_sum = 0
psnr.reset()
fid.reset()
for step, (image, label) in enumerate(loader) :
image = image.to(device)
label = label.to(device)
#첫번째 이미지의 크기
batch_size = image.size(0)
# 진짜/가짜 라벨
real = torch.ones(batch_size, 1).to(device) # 1
fake = torch.zeros(batch_size, 1).to(device) # 0
#학습마다 랜덤 노이즈 생성
noise = torch.randn(batch_size, 100, 1, 1).to(device)
#판별자 학습
netD_OneHot.zero_grad()
#실제 이미지
real_image = netD_OneHot(image, label)
D_real_loss = criterion(real_image, real)
#가짜 이미지
fake_image = netG_OneHot(noise, label)
d_fake = netD_OneHot(fake_image.detach(), label)
D_fake_loss = criterion(d_fake, fake)
#판별자 손실 계산
D_onehot_loss = D_real_loss + D_fake_loss
#판별자 가중치 업데이트
D_onehot_loss.backward()
D_OneHot_optimizer.step()
#생성자 학습
netG_OneHot.zero_grad()
#가짜 이미지
d_fake = netD_OneHot(fake_image, label)
G_onehot_loss = criterion(d_fake, real) # 가짜를 진짜로 속이기
G_onehot_loss.backward()
G_OneHot_optimizer.step()
G_loss_sum += G_onehot_loss.item()
D_loss_sum += D_onehot_loss.item()
psnr.update((fake_image.detach() + 1) / 2, (image + 1) / 2)
#fid는 10 step마다 계산(효율성 때)
if (step+1) % 10 == 0 :
fid.update(denorm(image.detach()), real = True)
fid.update(denorm(fake_image.detach()), real=False)
print('.', end='', flush=True)
#history 저장 / epoch마다 출력
avg_G_loss = G_loss_sum / len(loader)
avg_D_loss = D_loss_sum / len(loader)
psnr_avg = psnr.compute().item()
fid_avg = fid.compute().item()
OneHot_history['G_loss'].append(avg_G_loss)
OneHot_history['D_loss'].append(avg_D_loss)
OneHot_history['psnr'].append(psnr_avg)
OneHot_history['fid'].append(fid_avg)
print(f'\nEpoch [{epoch+1}/{epochs}] ')
print(f'D_loss: {avg_D_loss:.4f}, G_loss: {avg_G_loss:.4f}')
print(f'PSNR : {psnr_avg}, Fid : {fid_avg}')
#10 epoch마다 카테고리별 생성된 이미지 출력
if (epoch + 1) % 10 == 0 :
netG_OneHot.eval()
with torch.no_grad():
fake_images = netG_OneHot(fixed_noise, fixed_labels) # (10, 1, 28, 28)
fig, axes = plt.subplots(1, 10, figsize=(20, 2))
for i, ax in enumerate(axes):
ax.imshow(fake_images[i].squeeze().cpu().numpy(), cmap='gray')
ax.set_title(class_names[i], fontsize=7)
ax.axis('off')
plt.suptitle(f'Epoch {epoch+1}')
plt.tight_layout()
plt.savefig(f'results/OneHot/epoch_{epoch+1}.png')
plt.close()
netG_OneHot.train()
if fid_avg < best_fid:
best_fid = fid_avg
torch.save(netG_OneHot.state_dict(), './checkpoints/best_netG_OneHot.pt')
torch.save(netD_OneHot.state_dict(), './checkpoints/best_netD_OneHot.pt')
print(f'best model saved! : Best Epoch : {epoch + 1}, Fid : {best_fid}')
end_time = time.time()
OneHot_history['time'] = end_time - start_time
print(f'학습 종료, 걸린 시간 : {OneHot_history["time"]:.2f}')

4-2. Embedding 버전 학습
# Embedding 버전 학습/평가
epochs = 30
# Embedding 모델 초기화
netG_Embed = Generator_Embedding(noise_dim=100).to(device)
netD_Embed = Discriminator_Embedding().to(device)
# 손실함수 설정
criterion = nn.MSELoss()
# 옵티마이저 설정
G_Embed_optimizer = optim.Adam(netG_Embed.parameters(), lr=0.0002)
D_Embed_optimizer = optim.Adam(netD_Embed.parameters(), lr=0.0001)
Embed_history = {
'G_loss': [],
'D_loss': [],
'psnr' : [],
'fid' : [],
'time' : 0.0
}
best_fid = float('inf')
os.makedirs('results/Embed', exist_ok=True)
os.makedirs('checkpoints', exist_ok=True)
print('학습 시작')
start_time = time.time()
for epoch in range(epochs):
G_loss_sum = 0
D_loss_sum = 0
psnr.reset()
fid.reset()
for step, (image, label) in enumerate(loader):
image = image.to(device)
label = label.to(device)
batch_size = image.size(0)
real = torch.ones(batch_size, 1).to(device)
fake = torch.zeros(batch_size, 1).to(device)
noise = torch.randn(batch_size, 100, 1, 1).to(device)
# 판별자 학습
netD_Embed.zero_grad()
real_image = netD_Embed(image, label)
D_real_loss = criterion(real_image, real)
fake_image = netG_Embed(noise, label)
d_fake = netD_Embed(fake_image.detach(), label)
D_fake_loss = criterion(d_fake, fake)
D_embed_loss = D_real_loss + D_fake_loss
D_embed_loss.backward()
D_Embed_optimizer.step()
# 생성자 학습
netG_Embed.zero_grad()
d_fake = netD_Embed(fake_image, label)
G_embed_loss = criterion(d_fake, real)
G_embed_loss.backward()
G_Embed_optimizer.step()
G_loss_sum += G_embed_loss.item()
D_loss_sum += D_embed_loss.item()
psnr.update((fake_image.detach() + 1) / 2, (image + 1) / 2)
if (step+1) % 10 == 0:
fid.update(denorm(image.detach()), real=True)
fid.update(denorm(fake_image.detach()), real=False)
print('.', end='', flush=True)
avg_G_loss = G_loss_sum / len(loader)
avg_D_loss = D_loss_sum / len(loader)
psnr_avg = psnr.compute().item()
fid_avg = fid.compute().item()
Embed_history['G_loss'].append(avg_G_loss)
Embed_history['D_loss'].append(avg_D_loss)
Embed_history['psnr'].append(psnr_avg)
Embed_history['fid'].append(fid_avg)
print(f'\nEpoch [{epoch+1}/{epochs}] ')
print(f'D_loss: {avg_D_loss:.4f}, G_loss: {avg_G_loss:.4f}')
print(f'PSNR : {psnr_avg}, Fid : {fid_avg}')
if (epoch + 1) % 10 == 0:
netG_Embed.eval()
with torch.no_grad():
fake_images = netG_Embed(fixed_noise, fixed_labels)
fig, axes = plt.subplots(1, 10, figsize=(20, 2))
for i, ax in enumerate(axes):
ax.imshow(fake_images[i].squeeze().cpu().numpy(), cmap='gray')
ax.set_title(class_names[i], fontsize=7)
ax.axis('off')
plt.tight_layout()
plt.savefig(f'results/Embed/epoch_{epoch+1}.png')
plt.close()
netG_Embed.train()
if fid_avg < best_fid:
best_fid = fid_avg
torch.save(netG_Embed.state_dict(), './checkpoints/best_netG_Embed.pt')
torch.save(netD_Embed.state_dict(), './checkpoints/best_netD_Embed.pt')
print(f'best model saved! : Best Epoch : {epoch}, Fid : {best_fid}')
end_time = time.time()
Embed_history['time'] = end_time - start_time
print(f'학습 종료, 걸린 시간 : {Embed_history["time"]:.2f}초')

4-3. One-Hot Encoding, Embedding 비교분석
a) Loss 비교
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
# One-Hot G_loss, D_loss
axes[0].plot(OneHot_history['G_loss'], label='G_loss', color='blue')
axes[0].plot(OneHot_history['D_loss'], label='D_loss', color='red')
axes[0].set_title('One-Hot Loss')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].legend()
# Embedding G_loss, D_loss
axes[1].plot(Embed_history['G_loss'], label='G_loss', color='blue')
axes[1].plot(Embed_history['D_loss'], label='D_loss', color='red')
axes[1].set_title('Embedding Loss')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Loss')
axes[1].legend()
plt.suptitle('One-Hot vs Embedding Loss', fontsize=14)
plt.tight_layout()
plt.show()

One-Hot Encoding으로 라벨을 추가한 경우에는 ,G_Loss과 D_loss의 합이 작아지면서 수렴하는 그래프의 형태가 보였는데, Embedding으로 라벨링을 진행했을 때에는 오른쪽으로 갈수록 다시 발산하는 파형을 보여, 학습이 잘 되지 않음을 관찰할 수 있었다.
b) PSNR 비교
PSNR : 원본과 복원 이미지 간의 평균제곱오차(MSE)에 기반해 신호 대 잡음 비율을 데시벨 단위로 나타낸 품질 지표
-> 값이 클수록 품질이 복원 품질이 좋음, But GAN에서 크게 의미있지는 않아 참고용
#PSNR 비교
# best PSNR 값 찾기
best_onehot_psnr = max(OneHot_history['psnr']) # 높을수록 좋음
best_embed_psnr = max(Embed_history['psnr'])
plt.figure(figsize=(8, 5))
plt.plot(OneHot_history['psnr'], label=f'One-Hot (best: {best_onehot_psnr:.2f})', color='blue')
plt.plot(Embed_history['psnr'], label=f'Embedding (best: {best_embed_psnr:.2f})', color='red')
plt.title('One-Hot vs Embedding PSNR(higher is better)')
plt.xlabel('Epoch')
plt.ylabel('PSNR (dB)')
plt.legend()
plt.show()

PSNR의 경우, One-Hot방식과 Embdedding 방식의 지표가 비슷하게 나왔다.
c) FID 비교
Fid : 실제 이미지 분포 vs 생성 이미지 분포의 거리
-> 거리가 가까울수록 = 생성 이미지가 실제와 비슷 : 낮을수록 좋음
- FID = 0 → 완벽하게 같은 분포 (현실적으로 불가능)
- FID < 50 → 꽤 좋은 편
- FID < 20 → 매우 좋은 편
- FID > 100 → 학습이 잘 안된 것
#Fid 비교
#best Fid값 찾기
best_onehot_fid = min(OneHot_history['fid']) #낮을수록 좋음
best_embed_fid = min(Embed_history['fid'])
plt.figure(figsize=(8, 5))
plt.plot(OneHot_history['fid'], label=f'One-Hot (best: {best_onehot_fid:.2f})', color='blue')
plt.plot(Embed_history['fid'], label=f'Embedding (best: {best_embed_fid:.2f})', color='red')
plt.title('One-Hot vs Embedding Fid(Lower is better)')
plt.xlabel('Epoch')
plt.ylabel('FID')
plt.legend()
plt.show()

Embedding 방식의 경우, Fid값이 100 가까이 나온것으로 보아 학습이 잘 되지 않은것으로 보인다.
d) 생성 이미지 비교
# 에폭당 2개 이미지 + 1개 구분선 (마지막 제외)
saved_epochs = range(10, 31, 10)
total_rows = len(saved_epochs) * 2 + (len(saved_epochs) - 1)
height_ratios = []
for i in range(len(saved_epochs)):
height_ratios.extend([4, 4]) # 이미지 2개
if i < len(saved_epochs) - 1:
height_ratios.append(0.1) # 구분선
fig, axes = plt.subplots(total_rows, 1, figsize=(20, sum(height_ratios)),
gridspec_kw={'height_ratios': height_ratios})
row = 0
for i, epoch in enumerate(saved_epochs):
img_onehot = mpimg.imread(f'results/OneHot/epoch_{epoch}.png')
axes[row].imshow(img_onehot)
axes[row].set_title(f'[Epoch {epoch}] One-Hot', fontsize=15, fontweight='bold')
axes[row].axis('off')
row += 1
img_embed = mpimg.imread(f'results/Embed/epoch_{epoch}.png')
axes[row].imshow(img_embed)
axes[row].set_title(f'[Epoch {epoch}] Embedding', fontsize=15, fontweight='bold')
axes[row].axis('off')
row += 1
# 구분선 axes
if i < len(saved_epochs) - 1:
axes[row].axhline(y=0.5, color='black', linewidth=2)
axes[row].axis('off')
row += 1
plt.suptitle('One-Hot vs Embedding Generated Images by Epochs', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

One-Hot encoding의 경우가 더 선명한 사진이 나오는 것을 확인할 수있다.
e) 학습시간 비교
#학습시간 비교
plt.figure(figsize=(6, 4))
models = ['One-Hot', 'Embedding']
times = [OneHot_history['time'], Embed_history['time']]
plt.bar(models, times, color=['blue', 'red'])
plt.title('Learning Time by Labeling')
plt.xlabel('Labeling')
plt.ylabel('Time(sec)')
plt.yscale('log')
# 막대 위에 값 표시
for i, t in enumerate(times): # 'time' 변수 이름을 't_val'로 변경
plt.text(i, t, f'{t:.2f}(sec)', ha='center', va='bottom')
plt.tight_layout()
plt.show()

결론 : Embedding의 경우에는 One-hot Encoding보다 시간도 오래걸렸고, 학습도 제대로 되지 않았다. 차원의 개수의 문제일 수 있지만, 공정하게 10차원으로 라벨링을 진행했기 때문에, MNIST데이터셋에서 Embedding 보다 One-Hot으로 라벨링을 진행하는 것이 효과적이라고 결론이 나왔다
one-hot → 클래스 구분이 명확 (0 아니면 1) :단순한 데이터에 더 적합
Embedding → 클래스간 관계를 학습, 복잡한 데이터일수록 강점이 드러남
요약 : FashionMNIST처럼 단순한 데이터셋에서는 one-hot이 Embedding보다 효과적이다 !