Pathology Wang et al. Yao et al. CheXNet Our Implemented CheXNet Our Improved Model
Atelectasis 0.716 0.772 0.8094 0.8294 0.8311
Cardiomegaly 0.807 0.904 0.9248 0.9165 0.9220
Effusion 0.784 0.859 0.8638 0.8870 0.8891
Infiltration 0.609 0.695 0.7345 0.7143 0.7146
Mass 0.706 0.792 0.8676 0.8597 0.8627
Nodule 0.671 0.717 0.7802 0.7873 0.7883
Pneumonia 0.633 0.713 0.7680 0.7745 0.7820
Pneumothorax 0.806 0.841 0.8887 0.8726 0.8844
Consolidation 0.708 0.788 0.7901 0.8142 0.8148
Edema 0.835 0.882 0.8878 0.8932 0.8992
Emphysema 0.815 0.829 0.9371 0.9254 0.9343
Fibrosis 0.769 0.767 0.8047 0.8304 0.8385
Pleural Thickening 0.708 0.765 0.8062 0.7831 0.7914
Hernia 0.767 0.914 0.9164 0.9104 0.9206

# read_data

"""
Read images and corresponding labels.
"""

import torch
from torch.utils.data import Dataset
from PIL import Image
import os

class ChestXrayDataSet(Dataset):
    def __init__(self, data_dir, image_list_file, transform=None):
        """
        Args:
            data_dir: path to image directory.
            image_list_file: path to the file containing images
                with corresponding labels.
            transform: optional transform to be applied on a sample.
        """
        image_names = []
        labels = []
        with open(image_list_file, "r") as f:
            for line in f:
                items = line.split()
                image_name= items[0]
                label = items[1:]
                label = [int(i) for i in label]
                image_name = os.path.join(data_dir, image_name)
                image_names.append(image_name)
                labels.append(label)

        self.image_names = image_names
        self.labels = labels
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index: the index of item

        Returns:
            image and its labels
        """
        image_name = self.image_names[index]
        image = Image.open(image_name).convert('RGB')
        label = self.labels[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, torch.FloatTensor(label)

    def __len__(self):
        return len(self.image_names)
목적 설명
이미지 분류 주어진 X-ray 이미지를 보고 폐렴, 결핵 등 질병 여부를 판단하는 분류 모델 학습
멀티 라벨 분류 한 이미지에 여러 질병이 동시에 존재할 수 있는 경우 ([0,1,0,1,...])
PyTorch DataLoader와 연동 이 클래스는 DataLoader로 감싸서 배치 단위 학습에 사용 가능

📌 실제 사용 예 (예시)

from torchvision import transforms
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

train_dataset = ChestXrayDataSet(
    data_dir='./data/images',
    image_list_file='./data/train_labels.txt',
    transform=transform
)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

이 코드는 유명한 CheXNet 스타일의 구현으로, DenseNet-121을 기반으로 한 흉부 X-ray 이미지에서 14가지 흉부 질환을 진단하는 멀티라벨 분류 모델입니다. NIH의 ChestX-ray14 데이터셋을 대상으로 작동하며, 각 이미지에 대해 14개 질병에 대한 존재 여부를 예측합니다.

🧠 모델 개요:

DenseNet121 기반 멀티라벨 분류

✅ 주요 특징

항목 설명
모델 구조 DenseNet-121 + Linear(num_ftrs, 14) + Sigmoid
문제 유형 멀티라벨 이진 분류 (ex: 한 이미지에 질병 여러 개 동시 존재 가능)
출력 [batch_size, 14] 크기의 확률 벡터 (각 질병 존재 확률)
평가 지표 AUROC (클래스별 + 평균)

🧱 코드 구성 요소 설명

1. DenseNet121 클래스 (모델 구조 정의)

self.densenet121 = torchvision.models.densenet121(pretrained=True)
num_ftrs = self.densenet121.classifier.in_features
self.densenet121.classifier = nn.Sequential(
    nn.Linear(num_ftrs, out_size),
    nn.Sigmoid()
)