# File: ovt-projekt/eval_model_ciedge.py
# Last modified: 2025-04-23 10:18:58 +02:00
# (Repository viewer metadata: 122 lines, 3.8 KiB, Python)

import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets, models
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm # https://tqdm.github.io/
# Reference: https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html
# ======= Settings =======
CIEDGE_DIR = "./CIEDGE"  # root directory of the CIEDGE dataset
CIFAKE_DIR = "./CIFAKE"  # root directory of the CIFAKE dataset
BATCH_SIZE = 16          # images per evaluation batch
NUM_CLASSES = 2          # output width of the replaced ResNet classifier head
# Prefer the GPU when one is visible to PyTorch, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# ======= Transforms =======
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Ensure fixed input size
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],   # ImageNet mean
                         [0.229, 0.224, 0.225]),  # ImageNet std
])

# Seed for the shuffling samplers. The evaluation scores only the first few
# hundred samples of each shuffled loader, so a fixed seed keeps the reported
# accuracies reproducible from run to run. (Each pass over a loader still
# draws a fresh permutation from its generator.)
LOADER_SEED = 0


def _make_test_loader(root_dir):
    """Build the ``test`` split of *root_dir* and a shuffled loader over it.

    The split is read with ``ImageFolder`` (one sub-directory per class) and
    ``transform`` is applied to every image.

    Returns:
        ``(dataset, loader)`` for ``<root_dir>/test``.
    """
    dataset = datasets.ImageFolder(os.path.join(root_dir, "test"), transform=transform)
    loader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        # ImageFolder lists samples class by class; shuffling makes a prefix of
        # the loader contain a mix of classes.
        shuffle=True,
        num_workers=2,
        # Pinned host memory only helps host->GPU copies; on CPU-only machines
        # it just triggers a warning.
        pin_memory=torch.cuda.is_available(),
        generator=torch.Generator().manual_seed(LOADER_SEED),
    )
    return dataset, loader


cifake_dataset, cifake_loader = _make_test_loader(CIFAKE_DIR)
ciedge_dataset, ciedge_loader = _make_test_loader(CIEDGE_DIR)
# ======= Load Model =======
def _load_resnet18(checkpoint_path):
    """Rebuild a ResNet-18 with a ``NUM_CLASSES`` head and load *checkpoint_path*.

    The backbone is created with ``weights=None``. ``load_state_dict`` (strict by
    default) overwrites every parameter and buffer from the checkpoint anyway,
    so fetching the ImageNet weights first would be wasted work, including a
    download on first use.

    Returns:
        The model with the checkpoint weights, moved to ``DEVICE``.
    """
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
    # Load onto the CPU first so the checkpoint opens on machines without a GPU.
    # weights_only=True limits unpickling to tensors and plain containers, so a
    # tampered .pth file cannot execute arbitrary code.
    state = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state)
    return model.to(DEVICE)


cifake_model = _load_resnet18("resnet18_cifake.pth")
ciedge_model = _load_resnet18("resnet18_ciedge.pth")
# ======= Evaluate =======
# An evaluation stops once more than this many samples have been scored. The
# check runs after every batch, so up to one extra batch may be included.
MAX_EVAL_SAMPLES = 400


def evaluate(model, loader, model_name, data_name, max_samples=MAX_EVAL_SAMPLES):
    """Score *model* on (a prefix of) *loader*, print and return the accuracy.

    Batches are consumed until the running sample count exceeds *max_samples*
    or the loader is exhausted, whichever comes first. The result is printed as
    ``"<model_name> model evaluating <data_name> data: XX.XX%"``.

    Args:
        model: module whose output holds per-class scores along dim 1.
        loader: iterable of ``(inputs, labels)`` tensor batches.
        model_name: label for the model in the printed message.
        data_name: label for the data in the printed message.
        max_samples: stop once more than this many samples have been scored.

    Returns:
        Accuracy in percent, or ``nan`` if the loader yielded no samples.
    """
    # Run the batches on whatever device the model's parameters live on.
    device = next(model.parameters()).device
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            if total > max_samples:
                break
    # Report after the loop so a loader with fewer than max_samples samples
    # still produces a result; guard against an empty loader.
    accuracy = 100 * correct / total if total else float("nan")
    print(f"{model_name} model evaluating {data_name} data: {accuracy:.2f}%")
    return accuracy


if __name__ == "__main__":
    # DataLoader workers (num_workers > 0) re-import this module under the
    # "spawn" start method (Windows/macOS). This guard keeps them from
    # re-running the evaluation.
    for model_name, model in (("CIFAKE", cifake_model), ("CIEDGE", ciedge_model)):
        for data_name, loader in (("CIFAKE", cifake_loader), ("CIEDGE", ciedge_loader)):
            evaluate(model, loader, model_name, data_name)