Porocilo + slike

This commit is contained in:
2025-05-10 21:06:27 +02:00
parent b2676699a1
commit d933d92667
7 changed files with 467 additions and 2 deletions

114
eval_model.py Normal file
View File

@@ -0,0 +1,114 @@
import os
import torch
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets, models
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc
from tqdm import tqdm
# ======= Settings =======
CIEDGE_DIR = "./CIEDGE"
CIFAKE_DIR = "./CIFAKE"
BATCH_SIZE = 16
NUM_CLASSES = 2
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ======= Transforms =======
transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# ======= Load Datasets =======
cifake_dataset = datasets.ImageFolder(os.path.join(CIFAKE_DIR, "test"), transform=transform)
ciedge_dataset = datasets.ImageFolder(os.path.join(CIEDGE_DIR, "test"), transform=transform)
cifake_loader = DataLoader(cifake_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
ciedge_loader = DataLoader(ciedge_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
# ======= Load Models =======
def load_model(path):
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
model.load_state_dict(torch.load(path, map_location="cpu"))
return model.to(DEVICE)
cifake_model = load_model("resnet18_cifake.pth")
ciedge_model = load_model("resnet18_ciedge.pth")
# ======= ROC Helper Functions =======
def get_probs_and_labels(model, dataloader):
model.eval()
all_probs = []
all_labels = []
with torch.no_grad():
for inputs, labels in tqdm(dataloader):
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
outputs = model(inputs)
probs = torch.softmax(outputs, dim=1)[:, 1] # class 1 probability
all_probs.extend(probs.cpu().numpy())
all_labels.extend(labels.cpu().numpy())
return np.array(all_probs), np.array(all_labels)
def plot_roc(probs, labels, title="ROC Curve"):
fpr, tpr, _ = roc_curve(labels, probs)
roc_auc = auc(fpr, tpr)
plt.figure()
plt.plot(fpr, tpr, color='blue', lw=2, label=f"AUC = {roc_auc:.2f}")
plt.plot([0, 1], [0, 1], color='gray', linestyle='--')
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title(title)
plt.legend(loc="lower right")
plt.grid()
plt.show()
# ======= Evaluate with Accuracy =======
def evaluate_accuracy(model, dataloader, label):
correct, total = 0, 0
model.eval()
with torch.no_grad():
for inputs, labels in tqdm(dataloader):
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f"{label}: {100 * correct / total:.2f}%")
# ======= Run All Evaluations =======
# Accuracy evaluations
evaluate_accuracy(cifake_model, cifake_loader, "CIFAKE model on CIFAKE data")
evaluate_accuracy(cifake_model, ciedge_loader, "CIFAKE model on CIEDGE data")
evaluate_accuracy(ciedge_model, cifake_loader, "CIEDGE model on CIFAKE data")
evaluate_accuracy(ciedge_model, ciedge_loader, "CIEDGE model on CIEDGE data")
# ROC curve evaluations
probs, labels = get_probs_and_labels(cifake_model, cifake_loader)
plot_roc(probs, labels, "ROC: CIFAKE model on CIFAKE data")
probs, labels = get_probs_and_labels(cifake_model, ciedge_loader)
plot_roc(probs, labels, "ROC: CIFAKE model on CIEDGE data")
probs, labels = get_probs_and_labels(ciedge_model, cifake_loader)
plot_roc(probs, labels, "ROC: CIEDGE model on CIFAKE data")
probs, labels = get_probs_and_labels(ciedge_model, ciedge_loader)
plot_roc(probs, labels, "ROC: CIEDGE model on CIEDGE data")
# 100%|██████████████████████████████████████████████████████████████████████████| 1250/1250 [00:15<00:00, 81.85it/s]
# CIFAKE model on CIFAKE data: 97.55%
# 100%|██████████████████████████████████████████████████████████████████████████| 1250/1250 [00:14<00:00, 83.71it/s]
# CIFAKE model on CIEDGE data: 62.73%
# 100%|██████████████████████████████████████████████████████████████████████████| 1250/1250 [00:15<00:00, 83.19it/s]
# CIEDGE model on CIFAKE data: 50.09%
# 100%|██████████████████████████████████████████████████████████████████████████| 1250/1250 [00:15<00:00, 83.08it/s]
# CIEDGE model on CIEDGE data: 93.19%