# -*- coding: utf-8 -*- """ 手写数字识别 - 超参数配置 纯NumPy实现的两层全连接神经网络 """ # ===== 数据参数 ===== ONE_HOT = True # 标签是否使用One-Hot编码 # ===== 模型结构 ===== INPUT_SIZE = 784 # 28x28 = 784 像素 HIDDEN_SIZE = 128 # 隐藏层神经元数量 NUM_CLASSES = 10 # 0-9 十个数字 KEEP_PROB = 1.0 # Dropout保留比例(1.0=不使用Dropout) # ===== 训练参数 ===== LEARNING_RATE = 0.1 # 学习率 NUM_EPOCHS = 50 # 训练轮数 BATCH_SIZE = 64 # 批大小 # ===== 随机种子(保证可复现) ===== SEED = 42 # ===== 实验配置 ===== RUN_COMPARISON = False # 是否运行对比实验 # ===== 依赖说明 ===== # 本项目需要以下库: # numpy - 数值计算 # scikit-learn - 加载MNIST数据集(会自动下载) # pandas - sklearn的依赖 # # 安装命令: # pip install numpy scikit-learn pandas # # 数据说明: # 首次运行时会自动从OpenML下载MNIST数据集(约12MB) # 下载后会自动缓存,后续运行直接使用缓存数据