导包
1 | import os |
定义参数
1 | # Device configuration |
图片预处理
1 | transform = transforms.Compose([ |
导入CIFAR10数据集,定义数据加载器
1 | # CIFAR-10 dataset |
定义网络
1 | # 3x3 convolution |
定义损失函数和优化器
1 | criterion = nn.CrossEntropyLoss() |
更新学习率
1 | def update_lr(optimizer, lr): |
训练模型
1 | total_step = len(train_loader) |
测试模型
1 | model.eval() |