rebased on master

This commit is contained in:
2025-04-23 10:14:49 +02:00
parent 2405289f79
commit b2676699a1
9 changed files with 396 additions and 0 deletions

85
train_model_ciedge.py Normal file
View File

@@ -0,0 +1,85 @@
import os
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets, models
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm # https://tqdm.github.io/
# https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html
# ======= Settings =======
DATA_DIR = "./CIEDGE"
BATCH_SIZE = 16
NUM_CLASSES = 2
NUM_EPOCHS = 10
LEARNING_RATE = 1e-4
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ======= Transforms =======
transform = transforms.Compose([
transforms.Resize((224, 224)), # Ensure fixed input size
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], # ImageNet mean
[0.229, 0.224, 0.225]) # ImageNet std
])
# ======= Load Data =======
train_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, "train"), transform=transform)
test_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, "test"), transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2, pin_memory=True)
# ======= Load Model =======
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
model = model.to(DEVICE)
# ======= Loss and Optimizer =======
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# ======= Training Loop =======
for epoch in range(NUM_EPOCHS):
model.train()
running_loss = 0.0
correct, total = 0, 0
loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
for inputs, labels in loop:
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
inputs = inputs.contiguous() # Ensure tensor layout for MIOpen
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
loop.set_postfix(loss=loss.item(), acc=100. * correct / total)
# ======= Evaluation =======
model.eval()
correct, total = 0, 0
with torch.no_grad():
for inputs, labels in test_loader:
inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
inputs = inputs.contiguous()
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f"Test Accuracy: {100 * correct / total:.2f}%")
# ======= Save Model =======
torch.save(model.state_dict(), "resnet18_ciedge.pth")
#