Porocilo + slike
This commit is contained in:
114
eval_model.py
Normal file
114
eval_model.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import os
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import datasets, models
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from sklearn.metrics import roc_curve, auc
|
||||
from tqdm import tqdm
|
||||
|
||||
# ======= Settings =======
|
||||
CIEDGE_DIR = "./CIEDGE"
|
||||
CIFAKE_DIR = "./CIFAKE"
|
||||
BATCH_SIZE = 16
|
||||
NUM_CLASSES = 2
|
||||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# ======= Transforms =======
|
||||
transform = transforms.Compose([
|
||||
transforms.Resize((224, 224)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
# ======= Load Datasets =======
|
||||
cifake_dataset = datasets.ImageFolder(os.path.join(CIFAKE_DIR, "test"), transform=transform)
|
||||
ciedge_dataset = datasets.ImageFolder(os.path.join(CIEDGE_DIR, "test"), transform=transform)
|
||||
|
||||
cifake_loader = DataLoader(cifake_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
|
||||
ciedge_loader = DataLoader(ciedge_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
|
||||
|
||||
# ======= Load Models =======
|
||||
def load_model(path):
|
||||
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
|
||||
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
|
||||
model.load_state_dict(torch.load(path, map_location="cpu"))
|
||||
return model.to(DEVICE)
|
||||
|
||||
cifake_model = load_model("resnet18_cifake.pth")
|
||||
ciedge_model = load_model("resnet18_ciedge.pth")
|
||||
|
||||
# ======= ROC Helper Functions =======
|
||||
def get_probs_and_labels(model, dataloader):
|
||||
model.eval()
|
||||
all_probs = []
|
||||
all_labels = []
|
||||
|
||||
with torch.no_grad():
|
||||
for inputs, labels in tqdm(dataloader):
|
||||
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
|
||||
outputs = model(inputs)
|
||||
|
||||
probs = torch.softmax(outputs, dim=1)[:, 1] # class 1 probability
|
||||
all_probs.extend(probs.cpu().numpy())
|
||||
all_labels.extend(labels.cpu().numpy())
|
||||
|
||||
return np.array(all_probs), np.array(all_labels)
|
||||
|
||||
def plot_roc(probs, labels, title="ROC Curve"):
|
||||
fpr, tpr, _ = roc_curve(labels, probs)
|
||||
roc_auc = auc(fpr, tpr)
|
||||
|
||||
plt.figure()
|
||||
plt.plot(fpr, tpr, color='blue', lw=2, label=f"AUC = {roc_auc:.2f}")
|
||||
plt.plot([0, 1], [0, 1], color='gray', linestyle='--')
|
||||
plt.xlabel("False Positive Rate")
|
||||
plt.ylabel("True Positive Rate")
|
||||
plt.title(title)
|
||||
plt.legend(loc="lower right")
|
||||
plt.grid()
|
||||
plt.show()
|
||||
|
||||
# ======= Evaluate with Accuracy =======
|
||||
def evaluate_accuracy(model, dataloader, label):
|
||||
correct, total = 0, 0
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for inputs, labels in tqdm(dataloader):
|
||||
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
|
||||
outputs = model(inputs)
|
||||
_, predicted = torch.max(outputs.data, 1)
|
||||
total += labels.size(0)
|
||||
correct += (predicted == labels).sum().item()
|
||||
print(f"{label}: {100 * correct / total:.2f}%")
|
||||
|
||||
# ======= Run All Evaluations =======
|
||||
# Accuracy evaluations
|
||||
evaluate_accuracy(cifake_model, cifake_loader, "CIFAKE model on CIFAKE data")
|
||||
evaluate_accuracy(cifake_model, ciedge_loader, "CIFAKE model on CIEDGE data")
|
||||
evaluate_accuracy(ciedge_model, cifake_loader, "CIEDGE model on CIFAKE data")
|
||||
evaluate_accuracy(ciedge_model, ciedge_loader, "CIEDGE model on CIEDGE data")
|
||||
|
||||
# ROC curve evaluations
|
||||
probs, labels = get_probs_and_labels(cifake_model, cifake_loader)
|
||||
plot_roc(probs, labels, "ROC: CIFAKE model on CIFAKE data")
|
||||
|
||||
probs, labels = get_probs_and_labels(cifake_model, ciedge_loader)
|
||||
plot_roc(probs, labels, "ROC: CIFAKE model on CIEDGE data")
|
||||
|
||||
probs, labels = get_probs_and_labels(ciedge_model, cifake_loader)
|
||||
plot_roc(probs, labels, "ROC: CIEDGE model on CIFAKE data")
|
||||
|
||||
probs, labels = get_probs_and_labels(ciedge_model, ciedge_loader)
|
||||
plot_roc(probs, labels, "ROC: CIEDGE model on CIEDGE data")
|
||||
|
||||
# 100%|██████████████████████████████████████████████████████████████████████████| 1250/1250 [00:15<00:00, 81.85it/s]
|
||||
# CIFAKE model on CIFAKE data: 97.55%
|
||||
# 100%|██████████████████████████████████████████████████████████████████████████| 1250/1250 [00:14<00:00, 83.71it/s]
|
||||
# CIFAKE model on CIEDGE data: 62.73%
|
||||
# 100%|██████████████████████████████████████████████████████████████████████████| 1250/1250 [00:15<00:00, 83.19it/s]
|
||||
# CIEDGE model on CIFAKE data: 50.09%
|
||||
# 100%|██████████████████████████████████████████████████████████████████████████| 1250/1250 [00:15<00:00, 83.08it/s]
|
||||
# CIEDGE model on CIEDGE data: 93.19%
|
||||
Reference in New Issue
Block a user