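"""Cross-evaluate two ResNet-18 classifiers on the CIFAKE and CIEDGE test sets.

Prints the accuracy of each model on each dataset and plots the matching
ROC curves.
"""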
import os
|
|
import torch
|
|
import torchvision.transforms as transforms
|
|
from torch.utils.data import DataLoader
|
|
from torchvision import datasets, models
|
|
import torch.nn as nn
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from sklearn.metrics import roc_curve, auc
|
|
from tqdm import tqdm
|
|
|
|
# ======= Settings =======
# Root folders of the two image datasets.
CIFAKE_DIR = "./CIFAKE"
CIEDGE_DIR = "./CIEDGE"

# Number of outputs of each classifier head, and the evaluation batch size.
NUM_CLASSES = 2
BATCH_SIZE = 16

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
# ======= Transforms =======
# Per-channel mean / std applied after ToTensor (the ImageNet statistics
# expected by torchvision's pretrained models).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Resize to 224x224, convert to a tensor, then normalise each channel.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
transform = transforms.Compose(_preprocess_steps)
|
|
|
|
# ======= Load Datasets =======
# Both datasets are read from the "test" sub-folder of their root directory
# and share the same preprocessing pipeline.
cifake_dataset = datasets.ImageFolder(os.path.join(CIFAKE_DIR, "test"), transform=transform)
ciedge_dataset = datasets.ImageFolder(os.path.join(CIEDGE_DIR, "test"), transform=transform)

# Options shared by both evaluation loaders.
_LOADER_OPTIONS = {
    "batch_size": BATCH_SIZE,
    # Accuracy and ROC do not depend on sample order, so evaluation does not
    # shuffle; this keeps the iteration order identical from run to run.
    "shuffle": False,
    "num_workers": 2,
    # Pinned host memory only speeds up host-to-GPU copies; asking for it
    # on a CPU-only machine is useless and makes PyTorch emit a warning.
    "pin_memory": DEVICE.type == "cuda",
}

cifake_loader = DataLoader(cifake_dataset, **_LOADER_OPTIONS)
ciedge_loader = DataLoader(ciedge_dataset, **_LOADER_OPTIONS)
|
|
|
|
# ======= Load Models =======
|
|
def load_model(path):
    """Build a ResNet-18 classifier and load its weights from a checkpoint.

    Args:
        path: Path to a file holding the ``state_dict`` of a ResNet-18 whose
            final fully-connected layer has ``NUM_CLASSES`` outputs.

    Returns:
        The model with the checkpoint weights loaded, moved to ``DEVICE``.
        The model is not switched to eval mode here; callers do that.
    """
    # load_state_dict (strict by default) replaces every parameter and buffer,
    # so fetching the ImageNet weights first would only cost a download.
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

    # NOTE(review): torch.load unpickles the file. Only load checkpoints from
    # a trusted source, or pass weights_only=True where PyTorch supports it.
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model.to(DEVICE)
|
|
|
|
# Load the two checkpoints, CIFAKE first and CIEDGE second.
cifake_model, ciedge_model = (
    load_model(checkpoint)
    for checkpoint in ("resnet18_cifake.pth", "resnet18_ciedge.pth")
)
|
|
|
|
# ======= ROC Helper Functions =======
|
|
def get_probs_and_labels(model, dataloader):
    """Collect class-1 probabilities and labels for every sample of a loader.

    The model is switched to eval mode and run without gradient tracking.

    Args:
        model: Classifier called on each input batch.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.

    Returns:
        A ``(probs, labels)`` tuple of NumPy arrays with one entry per sample:
        the softmax probability of class index 1, and the label as yielded by
        the loader.
    """
    model.eval()
    positive_scores, targets = [], []

    with torch.no_grad():
        for batch, batch_targets in tqdm(dataloader):
            batch = batch.to(DEVICE)
            batch_targets = batch_targets.to(DEVICE)

            logits = model(batch)
            # Column 1 of the row-wise softmax is the class-1 probability.
            class_one = torch.softmax(logits, dim=1)[:, 1]

            positive_scores.extend(class_one.cpu().numpy())
            targets.extend(batch_targets.cpu().numpy())

    return np.array(positive_scores), np.array(targets)
|
|
|
|
def plot_roc(probs, labels, title="ROC Curve"):
    """Plot the ROC curve of a set of scores and show it.

    Args:
        probs: Score of the positive class for each sample.
        labels: True label of each sample.
        title: Title drawn above the plot.
    """
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(labels, probs)
    area = auc(false_pos_rate, true_pos_rate)
    curve_label = f"AUC = {area:.2f}"

    plt.figure()

    # Axis decoration first; it is independent of the plotted data.
    plt.title(title)
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.grid()

    # The ROC curve itself, plus the diagonal as a visual reference.
    plt.plot(false_pos_rate, true_pos_rate, color="blue", lw=2, label=curve_label)
    plt.plot([0, 1], [0, 1], color="gray", linestyle="--")

    # The legend is built from the labelled curve, so it comes after plotting.
    plt.legend(loc="lower right")
    plt.show()
|
|
|
|
# ======= Evaluate with Accuracy =======
|
|
def evaluate_accuracy(model, dataloader, label):
    """Print the accuracy of ``model`` over all samples of ``dataloader``.

    The model is switched to eval mode and run without gradient tracking.
    The predicted class of a sample is the index of its highest output.

    Args:
        model: Classifier returning one score per class for each input.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        label: Text printed in front of the accuracy figure.
    """
    correct, total = 0, 0
    model.eval()
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            # argmax over the class dimension yields the same indices as
            # torch.max(outputs.data, 1) without the legacy .data accessor.
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        # An empty loader would otherwise divide by zero below.
        print(f"{label}: n/a (no samples)")
        return
    print(f"{label}: {100 * correct / total:.2f}%")
|
|
|
|
# ======= Run All Evaluations =======
# The DataLoaders use worker processes. When those are started with the
# "spawn" method they re-import this module, so the evaluation is guarded
# to run only in the process that was launched as the script.
if __name__ == "__main__":
    _named_models = (("CIFAKE", cifake_model), ("CIEDGE", ciedge_model))
    _named_loaders = (("CIFAKE", cifake_loader), ("CIEDGE", ciedge_loader))

    # Every model is scored on both test sets; the order is
    # CIFAKE/CIFAKE, CIFAKE/CIEDGE, CIEDGE/CIFAKE, CIEDGE/CIEDGE.
    _evaluation_pairs = [
        (f"{model_name} model on {data_name} data", model, loader)
        for model_name, model in _named_models
        for data_name, loader in _named_loaders
    ]

    # Accuracy evaluations
    for description, model, loader in _evaluation_pairs:
        evaluate_accuracy(model, loader, description)

    # ROC curve evaluations
    for description, model, loader in _evaluation_pairs:
        probs, labels = get_probs_and_labels(model, loader)
        plot_roc(probs, labels, f"ROC: {description}")
|
|
|
|
# Output recorded from a previous run (tqdm progress bar, then the accuracy line):
# 100%|██████████████████████████████████████████████████████████████████████████| 1250/1250 [00:15<00:00, 81.85it/s]
# CIFAKE model on CIFAKE data: 97.55%
# 100%|██████████████████████████████████████████████████████████████████████████| 1250/1250 [00:14<00:00, 83.71it/s]
# CIFAKE model on CIEDGE data: 62.73%
# 100%|██████████████████████████████████████████████████████████████████████████| 1250/1250 [00:15<00:00, 83.19it/s]
# CIEDGE model on CIFAKE data: 50.09%
# 100%|██████████████████████████████████████████████████████████████████████████| 1250/1250 [00:15<00:00, 83.08it/s]
# CIEDGE model on CIEDGE data: 93.19%
|