Porocilo + slike
This commit is contained in:
@@ -16,7 +16,7 @@ BATCH_SIZE = 16
|
||||
NUM_CLASSES = 2
|
||||
NUM_EPOCHS = 10
|
||||
LEARNING_RATE = 1e-4
|
||||
DEVICE = "cpu"
|
||||
DEVICE = "cuda"
|
||||
|
||||
# ======= Transforms =======
|
||||
transform = transforms.Compose([
|
||||
@@ -33,7 +33,7 @@ test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, num_
|
||||
# ======= Load Model =======
|
||||
model = models.resnet18(pretrained=True)
|
||||
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
|
||||
state = torch.load("resnet18_cifake.pth", map_location=torch.device('cpu'))
|
||||
state = torch.load("resnet18_cifake.pth", map_location=torch.device('cuda'))
|
||||
model.load_state_dict(state)
|
||||
model = model.to(DEVICE)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user