| Pathology |
Wang et al. |
Yao et al. |
CheXNet |
Our Implemented CheXNet |
Our Improved Model |
| Atelectasis |
0.716 |
0.772 |
0.8094 |
0.8294 |
0.8311 |
| Cardiomegaly |
0.807 |
0.904 |
0.9248 |
0.9165 |
0.9220 |
| Effusion |
0.784 |
0.859 |
0.8638 |
0.8870 |
0.8891 |
| Infiltration |
0.609 |
0.695 |
0.7345 |
0.7143 |
0.7146 |
| Mass |
0.706 |
0.792 |
0.8676 |
0.8597 |
0.8627 |
| Nodule |
0.671 |
0.717 |
0.7802 |
0.7873 |
0.7883 |
| Pneumonia |
0.633 |
0.713 |
0.7680 |
0.7745 |
0.7820 |
| Pneumothorax |
0.806 |
0.841 |
0.8887 |
0.8726 |
0.8844 |
| Consolidation |
0.708 |
0.788 |
0.7901 |
0.8142 |
0.8148 |
| Edema |
0.835 |
0.882 |
0.8878 |
0.8932 |
0.8992 |
| Emphysema |
0.815 |
0.829 |
0.9371 |
0.9254 |
0.9343 |
| Fibrosis |
0.769 |
0.767 |
0.8047 |
0.8304 |
0.8385 |
| Pleural Thickening |
0.708 |
0.765 |
0.8062 |
0.7831 |
0.7914 |
| Hernia |
0.767 |
0.914 |
0.9164 |
0.9104 |
0.9206 |
# read_data
"""
Read images and corresponding labels.
"""
import torch
from torch.utils.data import Dataset
from PIL import Image
import os
class ChestXrayDataSet(Dataset):
def __init__(self, data_dir, image_list_file, transform=None):
"""
Args:
data_dir: path to image directory.
image_list_file: path to the file containing images
with corresponding labels.
transform: optional transform to be applied on a sample.
"""
image_names = []
labels = []
with open(image_list_file, "r") as f:
for line in f:
items = line.split()
image_name= items[0]
label = items[1:]
label = [int(i) for i in label]
image_name = os.path.join(data_dir, image_name)
image_names.append(image_name)
labels.append(label)
self.image_names = image_names
self.labels = labels
self.transform = transform
def __getitem__(self, index):
"""
Args:
index: the index of item
Returns:
image and its labels
"""
image_name = self.image_names[index]
image = Image.open(image_name).convert('RGB')
label = self.labels[index]
if self.transform is not None:
image = self.transform(image)
return image, torch.FloatTensor(label)
def __len__(self):
return len(self.image_names)
| 목적 |
설명 |
| 이미지 분류 |
주어진 X-ray 이미지를 보고 폐렴, 결핵 등 질병 여부를 판단하는 분류 모델 학습 |
| 멀티 라벨 분류 |
한 이미지에 여러 질병이 동시에 존재할 수 있는 경우 ([0,1,0,1,...]) |
| PyTorch DataLoader와 연동 |
이 클래스는 DataLoader로 감싸서 배치 단위 학습에 사용 가능 |
📌 실제 사용 예 (예시)
from torchvision import transforms
from torch.utils.data import DataLoader
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean, std)
])
train_dataset = ChestXrayDataSet(
data_dir='./data/images',
image_list_file='./data/train_labels.txt',
transform=transform
)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
이 코드는 유명한 CheXNet 스타일의 구현으로, DenseNet-121을 기반으로 한 흉부 X-ray 이미지에서 14가지 흉부 질환을 진단하는 멀티라벨 분류 모델입니다. NIH의 ChestX-ray14 데이터셋을 대상으로 작동하며, 각 이미지에 대해 14개 질병에 대한 존재 여부를 예측합니다.
🧠 모델 개요:
DenseNet121 기반 멀티라벨 분류
✅ 주요 특징
| 항목 |
설명 |
| 모델 구조 |
DenseNet-121 + Linear(num_ftrs, 14) + Sigmoid |
| 문제 유형 |
멀티라벨 이진 분류 (ex: 한 이미지에 질병 여러 개 동시 존재 가능) |
| 출력 |
[batch_size, 14] 크기의 확률 벡터 (각 질병 존재 확률) |
| 평가 지표 |
AUROC (클래스별 + 평균) |
🧱 코드 구성 요소 설명
1. DenseNet121 클래스 (모델 구조 정의)
self.densenet121 = torchvision.models.densenet121(pretrained=True)
num_ftrs = self.densenet121.classifier.in_features
self.densenet121.classifier = nn.Sequential(
nn.Linear(num_ftrs, out_size),
nn.Sigmoid()
)