import os import torch import torchvision.transforms as transforms from torch.utils.data import DataLoader from torchvision import datasets, models import torch.nn as nn import numpy as np import matplotlib.pyplot as plt from sklearn.metrics import roc_curve, auc from tqdm import tqdm # ======= Settings ======= CIEDGE_DIR = "./CIEDGE" CIFAKE_DIR = "./CIFAKE" BATCH_SIZE = 16 NUM_CLASSES = 2 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") # ======= Transforms ======= transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) # ======= Load Datasets ======= cifake_dataset = datasets.ImageFolder(os.path.join(CIFAKE_DIR, "test"), transform=transform) ciedge_dataset = datasets.ImageFolder(os.path.join(CIEDGE_DIR, "test"), transform=transform) cifake_loader = DataLoader(cifake_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True) ciedge_loader = DataLoader(ciedge_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True) # ======= Load Models ======= def load_model(path): model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1) model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES) model.load_state_dict(torch.load(path, map_location="cpu")) return model.to(DEVICE) cifake_model = load_model("resnet18_cifake.pth") ciedge_model = load_model("resnet18_ciedge.pth") # ======= ROC Helper Functions ======= def get_probs_and_labels(model, dataloader): model.eval() all_probs = [] all_labels = [] with torch.no_grad(): for inputs, labels in tqdm(dataloader): inputs, labels = inputs.to(DEVICE), labels.to(DEVICE) outputs = model(inputs) probs = torch.softmax(outputs, dim=1)[:, 1] # class 1 probability all_probs.extend(probs.cpu().numpy()) all_labels.extend(labels.cpu().numpy()) return np.array(all_probs), np.array(all_labels) def plot_roc(probs, labels, title="ROC Curve"): fpr, tpr, _ = roc_curve(labels, probs) roc_auc = auc(fpr, tpr) plt.figure() plt.plot(fpr, tpr, color='blue', lw=2, label=f"AUC = {roc_auc:.2f}") plt.plot([0, 1], [0, 1], color='gray', linestyle='--') plt.xlabel("False Positive Rate") plt.ylabel("True Positive Rate") plt.title(title) plt.legend(loc="lower right") plt.grid() plt.show() # ======= Evaluate with Accuracy ======= def evaluate_accuracy(model, dataloader, label): correct, total = 0, 0 model.eval() with torch.no_grad(): for inputs, labels in tqdm(dataloader): inputs, labels = inputs.to(DEVICE), labels.to(DEVICE) outputs = model(inputs) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print(f"{label}: {100 * correct / total:.2f}%") # ======= Run All Evaluations ======= # Accuracy evaluations evaluate_accuracy(cifake_model, cifake_loader, "CIFAKE model on CIFAKE data") evaluate_accuracy(cifake_model, ciedge_loader, "CIFAKE model on CIEDGE data") evaluate_accuracy(ciedge_model, cifake_loader, "CIEDGE model on CIFAKE data") evaluate_accuracy(ciedge_model, ciedge_loader, "CIEDGE model on CIEDGE data") # ROC curve evaluations probs, labels = get_probs_and_labels(cifake_model, cifake_loader) plot_roc(probs, labels, "ROC: CIFAKE model on CIFAKE data") probs, labels = get_probs_and_labels(cifake_model, ciedge_loader) plot_roc(probs, labels, "ROC: CIFAKE model on CIEDGE data") probs, labels = get_probs_and_labels(ciedge_model, cifake_loader) plot_roc(probs, labels, "ROC: CIEDGE model on CIFAKE data") probs, labels = get_probs_and_labels(ciedge_model, ciedge_loader) plot_roc(probs, labels, "ROC: CIEDGE model on CIEDGE data") # 100%|██████████████████████████████████████████████████████████████████████████| 1250/1250 [00:15<00:00, 81.85it/s] # CIFAKE model on CIFAKE data: 97.55% # 100%|██████████████████████████████████████████████████████████████████████████| 1250/1250 [00:14<00:00, 83.71it/s] # CIFAKE model on CIEDGE data: 62.73% # 100%|██████████████████████████████████████████████████████████████████████████| 1250/1250 [00:15<00:00, 83.19it/s] # CIEDGE model on CIFAKE data: 50.09% # 100%|██████████████████████████████████████████████████████████████████████████| 1250/1250 [00:15<00:00, 83.08it/s] # CIEDGE model on CIEDGE data: 93.19%