import os

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets, models
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm  # https://tqdm.github.io/

# https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html

# ======= Settings =======
CIEDGE_DIR = "./CIEDGE"
CIFAKE_DIR = "./CIFAKE"
BATCH_SIZE = 16
NUM_CLASSES = 2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Evaluation stops after the first batch that pushes the number of scored
# samples above this cap (or when the loader runs out, whichever is first).
MAX_EVAL_SAMPLES = 400

# ======= Transforms =======
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Ensure fixed input size
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],   # ImageNet mean
                         [0.229, 0.224, 0.225]),  # ImageNet std
])


def make_test_loader(root_dir):
    """Build a shuffled DataLoader over the ImageFolder tree at ``<root_dir>/test``.

    Args:
        root_dir: Dataset root containing a ``test`` sub-directory with one
            folder per class.

    Returns:
        A ``DataLoader`` yielding ``(inputs, labels)`` batches of ``BATCH_SIZE``.
    """
    dataset = datasets.ImageFolder(os.path.join(root_dir, "test"), transform=transform)
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        # Shuffle so the capped evaluation samples across classes; ImageFolder
        # lists samples grouped by class, so an unshuffled prefix would be
        # dominated by the first class.
        shuffle=True,
        num_workers=2,
        # Pinned host memory only speeds up host->GPU copies; on CPU it just warns.
        pin_memory=DEVICE.type == "cuda",
    )


def load_model(checkpoint_path):
    """Create a ResNet-18 with a ``NUM_CLASSES``-way head and load a checkpoint.

    Args:
        checkpoint_path: Path to a ``state_dict`` saved for this architecture.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.

    Raises:
        RuntimeError: If the checkpoint's keys/shapes do not match the model
            (``load_state_dict`` is strict by default).
    """
    # No pretrained weights: the strict load below overwrites every parameter
    # and buffer, so downloading ImageNet weights first would be wasted work.
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. On torch >= 1.13, consider weights_only=True.
    state = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    model.load_state_dict(state)
    return model.to(DEVICE).eval()


def evaluate(model, loader, max_samples=MAX_EVAL_SAMPLES):
    """Return the top-1 accuracy of ``model`` on ``loader``, as a percentage.

    Scoring stops after the first batch that pushes the running sample count
    above ``max_samples``, or when the loader is exhausted. The result is
    returned in both cases; previously it was silently dropped for small
    datasets.

    Args:
        model: Classifier returning one logit per class.
        loader: Iterable of ``(inputs, labels)`` batches.
        max_samples: Soft cap on the number of samples scored.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            if total > max_samples:
                break
    if total == 0:
        raise ValueError("loader yielded no samples; cannot compute accuracy")
    return 100 * correct / total


def main():
    """Cross-evaluate the CIFAKE- and CIEDGE-trained models on both test sets."""
    loaders = {
        "CIFAKE": make_test_loader(CIFAKE_DIR),
        "CIEDGE": make_test_loader(CIEDGE_DIR),
    }
    # Named to avoid shadowing the torchvision ``models`` module.
    trained_models = {
        "CIFAKE": load_model("resnet18_cifake.pth"),
        "CIEDGE": load_model("resnet18_ciedge.pth"),
    }
    for model_name, model in trained_models.items():
        for data_name, loader in loaders.items():
            accuracy = evaluate(model, loader)
            print(f"{model_name} model evaluating {data_name} data: {accuracy:.2f}%")


# Required because DataLoader uses worker processes: under the "spawn" start
# method (default on Windows and macOS) each worker re-imports this module,
# and unguarded top-level work would run again in every worker.
if __name__ == "__main__":
    main()