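# ovt-projekt/eval_model_cifake.py
# Snapshot: 2025-05-10 21:06:27 +02:00 (58 lines, 1.8 KiB, Python)
#
# Evaluates the fine-tuned ResNet-18 checkpoint resnet18_cifake.pth on the
# CIFAKE test split and prints the top-1 accuracy.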

import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets, models
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm # https://tqdm.github.io/
# https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html
# ======= Settings =======
DATA_DIR = "./CIFAKE"  # relative to the current working directory; test images live in DATA_DIR/test
BATCH_SIZE = 16
NUM_CLASSES = 2  # width of the replacement fc head; must match the saved checkpoint
NUM_EPOCHS = 10  # not referenced by this evaluation script (presumably left over from training)
LEARNING_RATE = 1e-4  # not referenced by this evaluation script (presumably left over from training)
# Fall back to CPU so the script still runs on machines without a CUDA GPU;
# a hard-coded "cuda" fails as soon as the model or a tensor is moved there.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ======= Transforms =======
# Test-time preprocessing. Presumably this must mirror the transform used when
# resnet18_cifake.pth was trained — confirm against the training script.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]  # per-channel RGB mean (ImageNet)
_IMAGENET_STD = [0.229, 0.224, 0.225]  # per-channel RGB std (ImageNet)
_test_steps = [
    transforms.Resize((224, 224)),  # force a fixed 224x224 input size
    transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
]
transform = transforms.Compose(_test_steps)
# ======= Load Data =======
# Seed for the test-loader shuffle. The evaluation may stop after a fixed number
# of samples, so the shuffle decides which images get scored. Presumably shuffle
# is on because ImageFolder lists samples grouped by class folder, so an
# unshuffled prefix would see only one class. Seeding the shuffle makes the
# scored subset, and therefore the reported accuracy, reproducible across runs.
EVAL_SEED = 0
test_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, "test"), transform=transform)
test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(EVAL_SEED),
    num_workers=2,
    # Pinned host memory only speeds up host-to-GPU copies; skip it on CPU.
    pin_memory=(DEVICE == "cuda"),
)
# ======= Load Model =======
# Build the ResNet-18 architecture with a NUM_CLASSES-way head, then fill it
# from the fine-tuned checkpoint. No pretrained weights are requested:
# load_state_dict (strict by default) overwrites every parameter and buffer, so
# downloading the ImageNet weights first (`pretrained=True`, also deprecated)
# was wasted work that also required network access.
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
# map_location=DEVICE lets a checkpoint saved on GPU be loaded on a CPU-only machine.
state = torch.load("resnet18_cifake.pth", map_location=DEVICE)
model.load_state_dict(state)
model = model.to(DEVICE)
# ======= Evaluation =======
# Stop after at least this many test images have been scored (None = whole split).
# With a batch size that does not divide it, the last batch may push the count
# slightly past the limit.
MAX_EVAL_SAMPLES = 400


def _count_correct(net, loader, device, max_samples=None):
    """Run *net* over *loader* in eval mode and count top-1 hits.

    Args:
        net: classifier returning one row of class scores per input.
        loader: iterable of ``(inputs, labels)`` batches.
        device: device each batch is moved to; should match *net*'s device.
        max_samples: stop once at least this many samples have been scored;
            ``None`` scores every batch.

    Returns:
        ``(correct, total)``: number of correct predictions and samples scored.
    """
    net.eval()
    hits, seen = 0, 0
    with torch.no_grad():
        for inputs, labels in tqdm(loader):
            inputs, labels = inputs.to(device), labels.to(device)
            logits = net(inputs.contiguous())
            predicted = logits.argmax(dim=1)
            seen += labels.size(0)
            hits += (predicted == labels).sum().item()
            if max_samples is not None and seen >= max_samples:
                break
    return hits, seen


correct, total = _count_correct(model, test_loader, DEVICE, MAX_EVAL_SAMPLES)
if total:
    print(f"Test Accuracy: {100 * correct / total:.2f}%")
else:
    # An empty test split would otherwise raise ZeroDivisionError here.
    print("Test Accuracy: n/a (no test samples were evaluated)")