rebased on master

This commit is contained in:
2025-04-23 10:14:49 +02:00
parent 2405289f79
commit b2676699a1
9 changed files with 396 additions and 0 deletions

121
eval_model_ciedge.py Normal file
View File

@@ -0,0 +1,121 @@
import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets, models
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm # https://tqdm.github.io/
# https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html
# ======= Settings =======
CIEDGE_DIR = "./CIEDGE"
CIFAKE_DIR = "./CIFAKE"
BATCH_SIZE = 16
NUM_CLASSES = 2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ======= Transforms =======
transform = transforms.Compose([
transforms.Resize((224, 224)), # Ensure fixed input size
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], # ImageNet mean
[0.229, 0.224, 0.225]) # ImageNet std
])
cifake_dataset = datasets.ImageFolder(os.path.join(CIFAKE_DIR, "test"), transform=transform)
ciedge_dataset = datasets.ImageFolder(os.path.join(CIEDGE_DIR, "test"), transform=transform)
cifake_loader = DataLoader(cifake_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
ciedge_loader = DataLoader(ciedge_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
# ======= Load Model =======
cifake_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
ciedge_model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
cifake_model.fc = nn.Linear(cifake_model.fc.in_features, NUM_CLASSES)
ciedge_model.fc = nn.Linear(ciedge_model.fc.in_features, NUM_CLASSES)
ciedge_state = torch.load("resnet18_ciedge.pth", map_location=torch.device('cpu'))
cifake_state = torch.load("resnet18_cifake.pth", map_location=torch.device('cpu'))
cifake_model.load_state_dict(cifake_state)
ciedge_model.load_state_dict(ciedge_state)
cifake_model = cifake_model.to(DEVICE)
ciedge_model = ciedge_model.to(DEVICE)
cifake_model.eval()
correct, total = 0, 0
with torch.no_grad():
for inputs, labels in cifake_loader:
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
inputs = inputs.contiguous()
outputs = cifake_model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
if total > 400:
print(f"CIFAKE model evaluating CIFAKE data: {100 * correct / total:.2f}%")
break
cifake_model.eval()
correct, total = 0, 0
with torch.no_grad():
for inputs, labels in ciedge_loader:
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
inputs = inputs.contiguous()
outputs = cifake_model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
if total > 400:
print(f"CIFAKE model evaluating CIEDGE data: {100 * correct / total:.2f}%")
break
ciedge_model.eval()
correct, total = 0, 0
with torch.no_grad():
for inputs, labels in cifake_loader:
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
inputs = inputs.contiguous()
outputs = ciedge_model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
if total > 400:
print(f"CIEDGE model evaluating CIFAKE data: {100 * correct / total:.2f}%")
break
ciedge_model.eval()
correct, total = 0, 0
with torch.no_grad():
for inputs, labels in ciedge_loader:
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
inputs = inputs.contiguous()
outputs = ciedge_model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
if total > 400:
print(f"CIEDGE model evaluating CIEDGE data: {100 * correct / total:.2f}%")
break