
[PyTorch] Fine-tuning a Pre-trained Vision Transformer: Image Classifier

by je0nsye0n 2025. 3. 4.
🚀 Model
- Uses the ViT model provided by PyTorch (vit_b_16)
- Link: Pre-trained weights, GitHub code

🚀 Dataset
- Uses only the dog and cat images from Kaggle's Animal-10 dataset

🚀 Task
- Classify dog and cat images

 

1. Overall execution flow (Main)

The overall flow of the code is as follows.

def main():
    # Prepare the data
    FINE_TUNE_N = 1024   # number of images per class used for fine-tuning
    PREDICT_N = 100    # total number of images used for prediction
    prepare_fine_tuning_dataset(animal_dir, fine_tuning_dir, FINE_TUNE_N)
    prepare_prediction_dataset(animal_dir, predict_dir, PREDICT_N)
    
    # Fine-tune the model
    model = fine_tune_model(fine_tuning_dir, num_epochs=20, batch_size=32, learning_rate=1e-4)
    
    # Predict and save the log
    predict_model(model, predict_dir, batch_size=8)
    
    # Compare the logs and evaluate
    evaluate_predictions()

 

① Download the Kaggle dataset and randomly sample it into training data and test data

- The base model is pre-trained on ImageNet, so this data is needed to adapt it to this code's task of classifying dog and cat images

- The code is written so that images are drawn at random from the basic data for both training and testing

- The basic data consists of a cat folder and a dog folder, so training and evaluation rely on a log.txt that records the label of each image
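
One thing to watch with random sampling: if the training and test images are drawn independently from the same folders, the same image can end up in both sets, which would make the measured accuracy look better than it really is. Below is a minimal sketch of a disjoint split, assuming the images are available as a list of (path, label) pairs. The helper name `split_disjoint`, the fixed seed, and the file names are illustrative and not part of the original code; per-class balancing is left out for brevity.

```python
import random

def split_disjoint(items, n_train, n_test, seed=0):
    """Return (train, test) lists sampled from items with no overlap.

    items: list of hashable entries, e.g. (image_path, label) tuples.
    """
    if n_train + n_test > len(items):
        raise ValueError("not enough items for a disjoint split")
    rng = random.Random(seed)          # local RNG so the global state is untouched
    shuffled = items[:]                # copy; leave the caller's list unchanged
    rng.shuffle(shuffled)
    train = shuffled[:n_train]
    test = shuffled[n_train:n_train + n_test]
    return train, test

# Hypothetical file names, just to exercise the function.
items = [(f"cat/{i}.jpeg", "cat") for i in range(10)] + \
        [(f"dog/{i}.jpeg", "dog") for i in range(10)]
train, test = split_disjoint(items, n_train=12, n_test=4)
print(len(train), len(test), set(train) & set(test))
```

Shuffling once and slicing guarantees the two sets cannot share an image, which two independent `random.sample` calls do not.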

 

 

② Fine-tune with the training data

Load the pre-trained Vision Transformer and fine-tune it with the values set in Main.

 

③ Run the test data through the model, then measure accuracy
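
The accuracy measurement comes down to comparing two label files. Here is a slightly more defensive sketch, assuming each line has the `ImgX.jpeg: label` form that the full code in section 3 writes. It keys both files by filename, so a missing or reordered line cannot silently shift the comparison. The function names `parse_log` and `accuracy` and the sample lines are illustrative.

```python
def parse_log(lines):
    """Turn lines like 'Img1.jpeg: cat' into {'Img1.jpeg': 'cat'}."""
    labels = {}
    for line in lines:
        line = line.strip()
        if not line:
            continue
        filename, label = line.split(":", 1)
        labels[filename.strip()] = label.strip()
    return labels

def accuracy(true_labels, pred_labels):
    """Return (accuracy, misclassified files) over files present in both dicts."""
    common = sorted(set(true_labels) & set(pred_labels))
    if not common:
        return 0.0, []
    wrong = [f for f in common if true_labels[f] != pred_labels[f]]
    return 1 - len(wrong) / len(common), wrong

# Made-up log contents for illustration only.
truth = parse_log(["Img1.jpeg: cat", "Img2.jpeg: dog", "Img3.jpeg: dog", "Img4.jpeg: cat"])
preds = parse_log(["Img2.jpeg: dog", "Img1.jpeg: cat", "Img3.jpeg: cat", "Img4.jpeg: cat"])
acc, wrong = accuracy(truth, preds)
print(acc, wrong)
```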

 

2. Checking the results

I checked the classification results with train data = 1024 images per class and test data = 100 images.

Classification worked well in both runs, which indicates that the fine-tuning took effect.

(Accuracy was too low at first, so I gradually increased the amount of training data.)
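
Because both the training and test images are drawn at random, two runs will generally see different data. If runs need to be comparable, one option (an assumption here, not something the original code does) is to fix the seed before sampling. The sketch below only demonstrates that a seeded `random.Random` yields the same selection twice; the helper name `sample_files` and the file names are placeholders.

```python
import random

def sample_files(files, n, seed):
    """Pick n files reproducibly: the same seed yields the same selection."""
    rng = random.Random(seed)
    return rng.sample(files, n)

files = [f"Img{i}.jpeg" for i in range(1, 21)]   # placeholder names
first = sample_files(files, 5, seed=42)
second = sample_files(files, 5, seed=42)
print(first == second)
```

For the training loop itself, `torch.manual_seed` plays the analogous role for weight initialization and `DataLoader` shuffling.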

 

3. Full code

import os
import glob
import random
import shutil
import re
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import vit_b_16, ViT_B_16_Weights

# Create the folder if it does not exist
def create_dir(path):
    os.makedirs(path, exist_ok=True)

# [Step 3] Prepare the fine-tuning dataset (copy N images from each class)
def prepare_fine_tuning_dataset(animal_dir, fine_tuning_dir, N):
    categories = ["cat", "dog"]
    for category in categories:
        category_dir = os.path.join(animal_dir, category)
        images = glob.glob(os.path.join(category_dir, "*.jpeg"))
        if len(images) < N:
            print(f"Not enough images in {category_dir}. (needed: {N}, found: {len(images)})")
            continue
        selected = random.sample(images, N)
        for i, img_path in enumerate(selected, start=1):
            filename = f"{category}_Img{i}.jpeg"
            dest_path = os.path.join(fine_tuning_dir, filename)
            shutil.copy(img_path, dest_path)
    print("Fine-tuning dataset ready.")

# [Step 4] Prepare the prediction dataset and create DataLog.txt
def prepare_prediction_dataset(animal_dir, predict_dir, N):
    categories = ["cat", "dog"]
    all_images = []
    for category in categories:
        category_dir = os.path.join(animal_dir, category)
        images = glob.glob(os.path.join(category_dir, "*.jpeg"))
        for img_path in images:
            all_images.append((img_path, category))
    if len(all_images) < N:
        print(f"Not enough images in total. (needed: {N}, found: {len(all_images)})")
        N = len(all_images)
    selected = random.sample(all_images, N)
    random.shuffle(selected)  # shuffle the order
    data_log_path = os.path.join("Data", "DataLog.txt")
    with open(data_log_path, "w") as f:
        for i, (img_path, category) in enumerate(selected, start=1):
            filename = f"Img{i}.jpeg"
            dest_path = os.path.join(predict_dir, filename)
            shutil.copy(img_path, dest_path)
            f.write(f"{filename}: {category}\n")
    print("Prediction dataset ready; DataLog.txt created.")

# Custom dataset for fine-tuning (the label is extracted from the filename)
class FineTuneDataset(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.image_files = [f for f in os.listdir(root) if f.endswith(".jpeg")]
        
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        filename = self.image_files[idx]
        # Filename format: "cat_Img{i}.jpeg" or "dog_Img{i}.jpeg"
        label_str = filename.split("_")[0]
        label = 0 if label_str.lower() == "cat" else 1
        img_path = os.path.join(self.root, filename)
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label

# Custom dataset for prediction (files sorted by the number in the filename)
class PredictDataset(Dataset):
    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        self.image_files = [f for f in os.listdir(root) if f.endswith(".jpeg")]
        self.image_files.sort(key=lambda x: int(re.search(r"(\d+)", x).group(1)))
        
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        filename = self.image_files[idx]
        img_path = os.path.join(self.root, filename)
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, filename

def fine_tune_model(fine_tuning_dir, num_epochs=10, batch_size=8, learning_rate=1e-4):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    dataset = FineTuneDataset(fine_tuning_dir, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    
    # Load the Vision Transformer pre-trained on ImageNet
    weights = ViT_B_16_Weights.IMAGENET1K_V1
    model = vit_b_16(weights=weights)
    # Replace the classification head with a 2-class head (cat: 0, dog: 1)
    num_features = model.heads.head.in_features
    model.heads.head = nn.Linear(num_features, 2)
    model = model.to(device)
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item() * images.size(0)
        
        epoch_loss = running_loss / len(dataset)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
        
        # Stop training once the loss drops below 0.001
        if epoch_loss < 0.001:
            print("Loss threshold reached. Stopping training.")
            break

    print("Fine-tuning complete.")
    return model


# [Step 4 ~ 5] Run prediction and write PredictLog.txt
def predict_model(model, predict_dir, batch_size=8):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    dataset = PredictDataset(predict_dir, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    
    model.eval()
    predictions = {}
    with torch.no_grad():
        for images, filenames in dataloader:
            images = images.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            preds = preds.cpu().numpy()
            for filename, pred in zip(filenames, preds):
                # Convert the numeric label to its string label
                pred_label = "cat" if pred == 0 else "dog"
                predictions[filename] = pred_label
    predict_log_path = os.path.join("Data", "PredictLog.txt")
    with open(predict_log_path, "w") as f:
        # Write in filename order (Img1.jpeg, Img2.jpeg, …)
        filenames_sorted = sorted(predictions.keys(), key=lambda x: int(re.search(r"(\d+)", x).group(1)))
        for filename in filenames_sorted:
            f.write(f"{filename}: {predictions[filename]}\n")
    print("Prediction complete; PredictLog.txt created.")
    
# [Step 6] Compare DataLog.txt with PredictLog.txt, report accuracy, and list misclassified image files
def evaluate_predictions():
    data_log_path = os.path.join("Data", "DataLog.txt")
    predict_log_path = os.path.join("Data", "PredictLog.txt")
    
    with open(data_log_path, "r") as f:
        data_lines = f.readlines()
    with open(predict_log_path, "r") as f:
        predict_lines = f.readlines()
    
    if len(data_lines) != len(predict_lines):
        print("Warning: DataLog and PredictLog have different numbers of entries.")
    
    total = min(len(data_lines), len(predict_lines))
    correct = 0
    misclassified = []
    
    for i in range(total):
        # Each line has the form "ImgX.jpeg: label"
        data_parts = data_lines[i].strip().split(":")
        predict_parts = predict_lines[i].strip().split(":")
        filename = data_parts[0].strip()
        true_label = data_parts[1].strip()
        pred_label = predict_parts[1].strip()
        
        if true_label == pred_label:
            correct += 1
        else:
            misclassified.append((filename, true_label, pred_label))
    
    print(f"Accuracy: {correct} / {total} (number of matches)")
    
    if misclassified:
        print("\nMisclassified image files:")
        for filename, true_label, pred_label in misclassified:
            print(f"{filename}: actual = {true_label}, predicted = {pred_label}")
    else:
        print("\nAll images were predicted correctly.")

# Overall execution flow
def main():
    # Directory setup
    base_dir = "Data"
    animal_dir = os.path.join(base_dir, "Animal-10")
    fine_tuning_dir = os.path.join(base_dir, "Fine_tuning")
    predict_dir = os.path.join(base_dir, "Predict")
    
    # Create the folders (if missing)
    create_dir(fine_tuning_dir)
    create_dir(predict_dir)
    
    # Number of images to use (adjust as needed)
    FINE_TUNE_N = 1024   # number of images per class used for fine-tuning
    PREDICT_N = 100    # total number of images used for prediction
    
    # Prepare the data
    prepare_fine_tuning_dataset(animal_dir, fine_tuning_dir, FINE_TUNE_N)
    prepare_prediction_dataset(animal_dir, predict_dir, PREDICT_N)
    
    # Fine-tune the model
    model = fine_tune_model(fine_tuning_dir, num_epochs=20, batch_size=32, learning_rate=1e-4)
    
    # Predict and save the log
    predict_model(model, predict_dir, batch_size=8)
    
    # Compare the logs and evaluate
    evaluate_predictions()

if __name__ == '__main__':
    main()