39 lines
1.2 KiB
Python
39 lines
1.2 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
手写数字识别 - 超参数配置
|
||
|
||
纯NumPy实现的两层全连接神经网络
|
||
"""
|
||
|
||
# ===== 数据参数 =====
|
||
ONE_HOT = True # 标签是否使用One-Hot编码
|
||
|
||
# ===== 模型结构 =====
|
||
INPUT_SIZE = 784 # 28x28 = 784 像素
|
||
HIDDEN_SIZE = 128 # 隐藏层神经元数量
|
||
NUM_CLASSES = 10 # 0-9 十个数字
|
||
KEEP_PROB = 1.0 # Dropout保留比例(1.0=不使用Dropout)
|
||
|
||
# ===== 训练参数 =====
|
||
LEARNING_RATE = 0.005 # 学习率
|
||
NUM_EPOCHS = 120 # 训练轮数
|
||
BATCH_SIZE = 64 # 批大小
|
||
|
||
# ===== 随机种子(保证可复现) =====
|
||
SEED = 42
|
||
|
||
# ===== 实验配置 =====
|
||
RUN_COMPARISON = False # 是否运行对比实验
|
||
|
||
# ===== 依赖说明 =====
|
||
# 本项目需要以下库:
|
||
# numpy - 数值计算
|
||
# scikit-learn - 加载MNIST数据集(会自动下载)
|
||
# pandas - sklearn的依赖
|
||
#
|
||
# 安装命令:
|
||
# pip install numpy scikit-learn pandas
|
||
#
|
||
# 数据说明:
|
||
# 首次运行时会自动从OpenML下载MNIST数据集(约12MB)
|
||
# 下载后会自动缓存,后续运行直接使用缓存数据 |