Files
task-3-2-2-text-classification/config.py
2026-05-19 11:31:06 +08:00

39 lines
1.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
手写数字识别 - 超参数配置
纯NumPy实现的两层全连接神经网络
"""
# ===== 数据参数 =====
ONE_HOT = True # 标签是否使用One-Hot编码
# ===== 模型结构 =====
INPUT_SIZE = 784 # 28x28 = 784 像素
HIDDEN_SIZE = 128 # 隐藏层神经元数量
NUM_CLASSES = 10 # 0-9 十个数字
KEEP_PROB = 1.0 # Dropout保留比例1.0=不使用Dropout
# ===== 训练参数 =====
LEARNING_RATE = 0.1 # 学习率
NUM_EPOCHS = 50 # 训练轮数
BATCH_SIZE = 64 # 批大小
# ===== 随机种子(保证可复现) =====
SEED = 42
# ===== 实验配置 =====
RUN_COMPARISON = False # 是否运行对比实验
# ===== 依赖说明 =====
# 本项目需要以下库:
# numpy - 数值计算
# scikit-learn - 加载MNIST数据集会自动下载
# pandas - sklearn的依赖
#
# 安装命令:
# pip install numpy scikit-learn pandas
#
# 数据说明:
# 首次运行时会自动从OpenML下载MNIST数据集约12MB
# 下载后会自动缓存,后续运行直接使用缓存数据