diff --git a/0-9/0-9(big)/0.jpg b/0-9/0-9(big)/0.jpg new file mode 100644 index 0000000..497b97f Binary files /dev/null and b/0-9/0-9(big)/0.jpg differ diff --git a/0-9/0-9(big)/1.jpg b/0-9/0-9(big)/1.jpg new file mode 100644 index 0000000..8bbbfa2 Binary files /dev/null and b/0-9/0-9(big)/1.jpg differ diff --git a/0-9/0-9(big)/2.jpg b/0-9/0-9(big)/2.jpg new file mode 100644 index 0000000..81a9ae3 Binary files /dev/null and b/0-9/0-9(big)/2.jpg differ diff --git a/0-9/0-9(big)/3.jpg b/0-9/0-9(big)/3.jpg new file mode 100644 index 0000000..a862f79 Binary files /dev/null and b/0-9/0-9(big)/3.jpg differ diff --git a/0-9/0-9(big)/4.jpg b/0-9/0-9(big)/4.jpg new file mode 100644 index 0000000..07495e8 Binary files /dev/null and b/0-9/0-9(big)/4.jpg differ diff --git a/0-9/0-9(big)/5.jpg b/0-9/0-9(big)/5.jpg new file mode 100644 index 0000000..2b50293 Binary files /dev/null and b/0-9/0-9(big)/5.jpg differ diff --git a/0-9/0-9(big)/6.jpg b/0-9/0-9(big)/6.jpg new file mode 100644 index 0000000..4f7384e Binary files /dev/null and b/0-9/0-9(big)/6.jpg differ diff --git a/0-9/0-9(big)/7.jpg b/0-9/0-9(big)/7.jpg new file mode 100644 index 0000000..536baf4 Binary files /dev/null and b/0-9/0-9(big)/7.jpg differ diff --git a/0-9/0-9(big)/8.jpg b/0-9/0-9(big)/8.jpg new file mode 100644 index 0000000..37dc720 Binary files /dev/null and b/0-9/0-9(big)/8.jpg differ diff --git a/0-9/0-9(big)/9.jpg b/0-9/0-9(big)/9.jpg new file mode 100644 index 0000000..89ca3d7 Binary files /dev/null and b/0-9/0-9(big)/9.jpg differ diff --git a/0-9/0-9(middle)/0.jpg b/0-9/0-9(middle)/0.jpg new file mode 100644 index 0000000..5b9e888 Binary files /dev/null and b/0-9/0-9(middle)/0.jpg differ diff --git a/0-9/0-9(middle)/1.jpg b/0-9/0-9(middle)/1.jpg new file mode 100644 index 0000000..ab79997 Binary files /dev/null and b/0-9/0-9(middle)/1.jpg differ diff --git a/0-9/0-9(middle)/2.jpg b/0-9/0-9(middle)/2.jpg new file mode 100644 index 0000000..7b57d36 Binary files /dev/null and b/0-9/0-9(middle)/2.jpg differ diff --git a/0-9/0-9(middle)/3.jpg b/0-9/0-9(middle)/3.jpg new file mode 100644 index 0000000..763cc41 Binary files /dev/null and b/0-9/0-9(middle)/3.jpg differ diff --git a/0-9/0-9(middle)/4.jpg b/0-9/0-9(middle)/4.jpg new file mode 100644 index 0000000..92bceae Binary files /dev/null and b/0-9/0-9(middle)/4.jpg differ diff --git a/0-9/0-9(middle)/5.jpg b/0-9/0-9(middle)/5.jpg new file mode 100644 index 0000000..8b607ef Binary files /dev/null and b/0-9/0-9(middle)/5.jpg differ diff --git a/0-9/0-9(middle)/6.jpg b/0-9/0-9(middle)/6.jpg new file mode 100644 index 0000000..3a16dfb Binary files /dev/null and b/0-9/0-9(middle)/6.jpg differ diff --git a/0-9/0-9(middle)/7.jpg b/0-9/0-9(middle)/7.jpg new file mode 100644 index 0000000..dbbd8f5 Binary files /dev/null and b/0-9/0-9(middle)/7.jpg differ diff --git a/0-9/0-9(middle)/8.jpg b/0-9/0-9(middle)/8.jpg new file mode 100644 index 0000000..3e29589 Binary files /dev/null and b/0-9/0-9(middle)/8.jpg differ diff --git a/0-9/0-9(middle)/9.jpg b/0-9/0-9(middle)/9.jpg new file mode 100644 index 0000000..bc15e28 Binary files /dev/null and b/0-9/0-9(middle)/9.jpg differ diff --git a/0-9/0-9(small)/0.jpg b/0-9/0-9(small)/0.jpg new file mode 100644 index 0000000..d46041b Binary files /dev/null and b/0-9/0-9(small)/0.jpg differ diff --git a/0-9/0-9(small)/1.jpg b/0-9/0-9(small)/1.jpg new file mode 100644 index 0000000..7e7fa1c Binary files /dev/null and b/0-9/0-9(small)/1.jpg differ diff --git a/0-9/0-9(small)/2.jpg b/0-9/0-9(small)/2.jpg new file mode 100644 index 0000000..6865a6d Binary files /dev/null and b/0-9/0-9(small)/2.jpg differ diff --git a/0-9/0-9(small)/3.jpg b/0-9/0-9(small)/3.jpg new file mode 100644 index 0000000..028ff45 Binary files /dev/null and b/0-9/0-9(small)/3.jpg differ diff --git a/0-9/0-9(small)/4.jpg b/0-9/0-9(small)/4.jpg new file mode 100644 index 0000000..65ae2eb Binary files /dev/null and b/0-9/0-9(small)/4.jpg differ diff --git a/0-9/0-9(small)/5.jpg b/0-9/0-9(small)/5.jpg new file mode 100644 index 0000000..eaaf351 Binary files /dev/null and b/0-9/0-9(small)/5.jpg differ diff --git a/0-9/0-9(small)/6.jpg b/0-9/0-9(small)/6.jpg new file mode 100644 index 0000000..6ba9950 Binary files /dev/null and b/0-9/0-9(small)/6.jpg differ diff --git a/0-9/0-9(small)/7.jpg b/0-9/0-9(small)/7.jpg new file mode 100644 index 0000000..3b7b45a Binary files /dev/null and b/0-9/0-9(small)/7.jpg differ diff --git a/0-9/0-9(small)/8.jpg b/0-9/0-9(small)/8.jpg new file mode 100644 index 0000000..dfb4747 Binary files /dev/null and b/0-9/0-9(small)/8.jpg differ diff --git a/0-9/0-9(small)/9.jpg b/0-9/0-9(small)/9.jpg new file mode 100644 index 0000000..b701ef0 Binary files /dev/null and b/0-9/0-9(small)/9.jpg differ diff --git a/0521.xlsx b/0521.xlsx new file mode 100644 index 0000000..9efcde2 Binary files /dev/null and b/0521.xlsx differ diff --git a/digit_mlp_class/0-9/0.jpg b/digit_mlp_class/0-9/0.jpg new file mode 100644 index 0000000..497b97f Binary files /dev/null and b/digit_mlp_class/0-9/0.jpg differ diff --git a/digit_mlp_class/0-9/1.jpg b/digit_mlp_class/0-9/1.jpg new file mode 100644 index 0000000..8bbbfa2 Binary files /dev/null and b/digit_mlp_class/0-9/1.jpg differ diff --git a/digit_mlp_class/0-9/2.jpg b/digit_mlp_class/0-9/2.jpg new file mode 100644 index 0000000..81a9ae3 Binary files /dev/null and b/digit_mlp_class/0-9/2.jpg differ diff --git a/digit_mlp_class/0-9/3.jpg b/digit_mlp_class/0-9/3.jpg new file mode 100644 index 0000000..a862f79 Binary files /dev/null and b/digit_mlp_class/0-9/3.jpg differ diff --git a/digit_mlp_class/0-9/4.jpg b/digit_mlp_class/0-9/4.jpg new file mode 100644 index 0000000..07495e8 Binary files /dev/null and b/digit_mlp_class/0-9/4.jpg differ diff --git a/digit_mlp_class/0-9/5.jpg b/digit_mlp_class/0-9/5.jpg new file mode 100644 index 0000000..2b50293 Binary files /dev/null and b/digit_mlp_class/0-9/5.jpg differ diff --git a/digit_mlp_class/0-9/6.jpg b/digit_mlp_class/0-9/6.jpg new file mode 100644 index 0000000..4f7384e Binary files /dev/null and b/digit_mlp_class/0-9/6.jpg differ diff --git a/digit_mlp_class/0-9/7.jpg b/digit_mlp_class/0-9/7.jpg new file mode 100644 index 0000000..536baf4 Binary files /dev/null and b/digit_mlp_class/0-9/7.jpg differ diff --git a/digit_mlp_class/0-9/8.jpg b/digit_mlp_class/0-9/8.jpg new file mode 100644 index 0000000..37dc720 Binary files /dev/null and b/digit_mlp_class/0-9/8.jpg differ diff --git a/digit_mlp_class/0-9/9.jpg b/digit_mlp_class/0-9/9.jpg new file mode 100644 index 0000000..89ca3d7 Binary files /dev/null and b/digit_mlp_class/0-9/9.jpg differ diff --git a/digit_mlp_class/README.md b/digit_mlp_class/README.md new file mode 100644 index 0000000..1f48f09 --- /dev/null +++ b/digit_mlp_class/README.md @@ -0,0 +1,172 @@ +# 手写数字识别 - 纯NumPy MLP实现 + +## 项目简介 + +使用纯NumPy实现的两层全连接神经网络(MLP),在MNIST数据集上进行手写数字识别。 + +**零深度学习框架依赖**,只需 `numpy`。 + +## 网络结构 + +``` +输入层(784) → 隐藏层(128) + ReLU → 输出层(10) + Softmax +``` + +- **输入**: 28×28=784 像素值,归一化到 [0, 1] +- **隐藏层**: 128 神经元,ReLU激活函数 +- **输出层**: 10 神经元(数字0-9),Softmax输出概率 + +## 文件结构 + +``` +digit_mlp_class/ +├── main.py # 主程序(训练/评估/对比实验) +├── model_numpy.py # MLP模型(纯NumPy实现) +├── dataset.py # MNIST数据集加载 +├── config.py # 超参数配置 +├── data/ # MNIST数据文件 +│ ├── train-images-idx3-ubyte.gz +│ ├── train-labels-idx1-ubyte.gz +│ ├── t10k-images-idx3-ubyte.gz +│ └── t10k-labels-idx1-ubyte.gz +└── README.md +``` + +## 依赖 + +``` +numpy +``` + +## 使用方法 + +### 1. 下载MNIST数据集 + +如果 `data/` 目录下没有数据文件,运行: + +```bash +python dataset.py +``` + +或手动下载: + +```bash +cd data/ +curl -LO https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz +curl -LO https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz +curl -LO https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz +curl -LO https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz +``` + +### 2. 训练模型 + +```bash +python main.py +``` + +### 3. 运行对比实验 + +```bash +python main.py --compare +``` + +## 代码设计 + +### model_numpy.py - MLP模型 + +核心实现: +- **前向传播**: 矩阵乘法 + ReLU + Softmax +- **反向传播**: 手动梯度计算 + 梯度下降 +- **权重初始化**: Xavier初始化(适合ReLU) + +```python +class MLP: + def __init__(self, input_size=784, hidden_size=128, num_classes=10) + def forward(self, X): # 前向传播 + def backward(self, X, y): # 反向传播 + def fit(self, X, y): # 训练 + def predict(self, X): # 预测 +``` + +### dataset.py - 数据加载 + +- 自动检测 `data/` 目录下的MNIST文件 +- 解析IDX格式(MNIST标准格式) +- 归一化像素值到 [0, 1] +- 支持One-Hot编码标签 + +### main.py - 主程序 + +两种运行模式: +1. **默认模式**: 训练一个模型并评估 +2. **对比模式** (`--compare`): 对比不同超参数的效果 + +## 数学原理 + +### 前向传播 + +``` +z1 = X @ W1 + b1 # 第一层线性变换 +a1 = ReLU(z1) # 第一层激活 + +z2 = a1 @ W2 + b2 # 第二层线性变换 +probs = softmax(z2) # 输出概率 +``` + +### 反向传播 + +``` +d_z2 = probs - y # 输出层梯度 +d_W2 = a1.T @ d_z2 # 第二层权重梯度 +d_z1 = d_z2 @ W2.T * relu_derivative(z1) # 隐藏层梯度 +d_W1 = X.T @ d_z1 # 第一层权重梯度 + +W1 -= lr * d_W1 / batch_size # 梯度下降更新 +W2 -= lr * d_W2 / batch_size +``` + +### 激活函数 + +**ReLU**: +``` +ReLU(x) = max(0, x) +ReLU'(x) = 1 if x > 0 else 0 +``` + +**Softmax**: +``` +softmax(x_i) = exp(x_i) / sum(exp(x_j)) +``` + +## 超参数 + +| 参数 | 默认值 | 说明 | +|------|--------|------| +| hidden_size | 128 | 隐藏层神经元数量 | +| learning_rate | 0.1 | 学习率 | +| epochs | 50 | 训练轮数 | +| batch_size | 64 | 批大小 | +| seed | 42 | 随机种子 | + +## 预期结果 + +- 训练准确率: ~98% +- 测试准确率: ~95-97% + +训练时间: 约 5-10 分钟(取决于硬件) + +## 扩展实验 + +1. **改变隐藏层大小**: 32 / 64 / 128 / 256 +2. **改变学习率**: 0.01 / 0.1 / 0.5 +3. **添加Dropout**: 防止过拟合 +4. **增加隐藏层数**: 784 → 256 → 128 → 10 + +## 教学用途 + +本项目适合用于讲解: +- 神经网络基本结构 +- 前向传播与反向传播原理 +- 梯度下降优化 +- NumPy矩阵操作 +- MNIST数据集处理 \ No newline at end of file diff --git a/digit_mlp_class/__pycache__/config.cpython-313.pyc b/digit_mlp_class/__pycache__/config.cpython-313.pyc new file mode 100644 index 0000000..3917855 Binary files /dev/null and b/digit_mlp_class/__pycache__/config.cpython-313.pyc differ diff --git a/digit_mlp_class/__pycache__/dataset.cpython-313.pyc b/digit_mlp_class/__pycache__/dataset.cpython-313.pyc new file mode 100644 index 0000000..1072f7f Binary files /dev/null and b/digit_mlp_class/__pycache__/dataset.cpython-313.pyc differ diff --git a/digit_mlp_class/__pycache__/model_numpy.cpython-313.pyc b/digit_mlp_class/__pycache__/model_numpy.cpython-313.pyc new file mode 100644 index 0000000..5bd5ce2 Binary files /dev/null and b/digit_mlp_class/__pycache__/model_numpy.cpython-313.pyc differ diff --git a/digit_mlp_class/config.py b/digit_mlp_class/config.py new file mode 100644 index 0000000..ba4d6b6 --- /dev/null +++ b/digit_mlp_class/config.py @@ -0,0 +1,39 @@ +# -*- coding: utf-8 -*- +""" +手写数字识别 - 超参数配置 + +纯NumPy实现的两层全连接神经网络 +""" + +# ===== 数据参数 ===== +ONE_HOT = True # 标签是否使用One-Hot编码 + +# ===== 模型结构 ===== +INPUT_SIZE = 784 # 28x28 = 784 像素 +HIDDEN_SIZE = 128 # 隐藏层神经元数量 +NUM_CLASSES = 10 # 0-9 十个数字 +KEEP_PROB = 1.0 # Dropout保留比例(1.0=不使用Dropout) + +# ===== 训练参数 ===== +LEARNING_RATE = 0.005 # 学习率 +NUM_EPOCHS = 120 # 训练轮数 +BATCH_SIZE = 64 # 批大小 + +# ===== 随机种子(保证可复现) ===== +SEED = 42 + +# ===== 实验配置 ===== +RUN_COMPARISON = False # 是否运行对比实验 + +# ===== 依赖说明 ===== +# 本项目需要以下库: +# numpy - 数值计算 +# scikit-learn - 加载MNIST数据集(会自动下载) +# pandas - sklearn的依赖 +# +# 安装命令: +# pip install numpy scikit-learn pandas +# +# 数据说明: +# 首次运行时会自动从OpenML下载MNIST数据集(约12MB) +# 下载后会自动缓存,后续运行直接使用缓存数据 \ No newline at end of file diff --git a/digit_mlp_class/data/t10k-images-idx3-ubyte.zip b/digit_mlp_class/data/t10k-images-idx3-ubyte.zip new file mode 100644 index 0000000..6681e7a Binary files /dev/null and b/digit_mlp_class/data/t10k-images-idx3-ubyte.zip differ diff --git a/digit_mlp_class/data/t10k-labels-idx1-ubyte.zip b/digit_mlp_class/data/t10k-labels-idx1-ubyte.zip new file mode 100644 index 0000000..b893497 Binary files /dev/null and b/digit_mlp_class/data/t10k-labels-idx1-ubyte.zip differ diff --git a/digit_mlp_class/data/train-images-idx3-ubyte.zip b/digit_mlp_class/data/train-images-idx3-ubyte.zip new file mode 100644 index 0000000..c904af1 Binary files /dev/null and b/digit_mlp_class/data/train-images-idx3-ubyte.zip differ diff --git a/digit_mlp_class/data/train-labels-idx1-ubyte.zip b/digit_mlp_class/data/train-labels-idx1-ubyte.zip new file mode 100644 index 0000000..71e1d3f Binary files /dev/null and b/digit_mlp_class/data/train-labels-idx1-ubyte.zip differ diff --git a/digit_mlp_class/dataset.py b/digit_mlp_class/dataset.py new file mode 100644 index 0000000..2d0867e --- /dev/null +++ b/digit_mlp_class/dataset.py @@ -0,0 +1,179 @@ +# -*- coding: utf-8 -*- +""" +数据集模块 - MNIST手写数字数据集加载 + +优先从本地data/目录加载,如果文件不存在则从sklearn下载 +支持两种格式:.gz(官方格式)和 .zip(某些下载源) +""" + +import os +import struct +import gzip +import zipfile +import numpy as np +from config import * + + +def local_files_exist(): + """检查本地数据文件是否存在且完整""" + data_dir = os.path.join(os.path.dirname(__file__), 'data') + + # 支持 .gz 和 .zip 格式(MNIST官方用.gz,但有些下载是zip) + files = { + 'train-images-idx3-ubyte': {'gz': 9912422, 'zip': 9187390}, + 'train-labels-idx1-ubyte': {'gz': 28881, 'zip': 28405}, + 't10k-images-idx3-ubyte': {'gz': 1648877, 'zip': 1534055}, + 't10k-labels-idx1-ubyte': {'gz': 5148, 'zip': 4563}, + } + + found_files = {} + missing = [] + + for base_name, sizes in files.items(): + gz_path = os.path.join(data_dir, base_name + '.gz') + zip_path = os.path.join(data_dir, base_name + '.zip') + + if os.path.exists(gz_path): + found_files[base_name] = (gz_path, sizes['gz'], 'gz') + elif os.path.exists(zip_path): + found_files[base_name] = (zip_path, sizes['zip'], 'zip') + else: + missing.append(base_name) + + if missing: + return False, f"文件不存在: {', '.join(missing)}" + + # 检查大小是否正确 + for base_name, (filepath, expected_size, fmt) in found_files.items(): + actual_size = os.path.getsize(filepath) + if actual_size != expected_size: + return False, f"文件大小错误: {base_name} (期望{expected_size}, 实际{actual_size})" + + return True, "所有文件完整" + + +def parse_idx_images(filepath): + """解析IDX格式图像(支持.gz和.zip)""" + if filepath.endswith('.zip'): + with zipfile.ZipFile(filepath, 'r') as zf: + # zip内的文件名没有.gz后缀 + inner_name = zf.namelist()[0] + with zf.open(inner_name) as f: + magic, num, rows, cols = struct.unpack('>IIII', f.read(16)) + images = np.frombuffer(f.read(), dtype=np.uint8) + images = images.reshape(num, rows * cols) + return images + else: + with gzip.open(filepath, 'rb') as f: + magic, num, rows, cols = struct.unpack('>IIII', f.read(16)) + images = np.frombuffer(f.read(), dtype=np.uint8) + images = images.reshape(num, rows * cols) + return images + + +def parse_idx_labels(filepath): + """解析IDX格式标签(支持.gz和.zip)""" + if filepath.endswith('.zip'): + with zipfile.ZipFile(filepath, 'r') as zf: + # zip内的文件名没有.gz后缀 + inner_name = zf.namelist()[0] + with zf.open(inner_name) as f: + magic, num = struct.unpack('>II', f.read(8)) + labels = np.frombuffer(f.read(), dtype=np.uint8) + return labels + else: + with gzip.open(filepath, 'rb') as f: + magic, num = struct.unpack('>II', f.read(8)) + labels = np.frombuffer(f.read(), dtype=np.uint8) + return labels + + +def load_data_from_local(): + """从本地文件加载MNIST(自动检测.gz或.zip格式)""" + data_dir = os.path.join(os.path.dirname(__file__), 'data') + + def find_file(base_name): + """自动找文件,支持.gz和.zip""" + gz_path = os.path.join(data_dir, base_name + '.gz') + zip_path = os.path.join(data_dir, base_name + '.zip') + if os.path.exists(gz_path): + return gz_path + elif os.path.exists(zip_path): + return zip_path + else: + raise FileNotFoundError(f"找不到 {base_name} 的 .gz 或 .zip 文件") + + X_train = parse_idx_images(find_file('train-images-idx3-ubyte')) + y_train = parse_idx_labels(find_file('train-labels-idx1-ubyte')) + X_test = parse_idx_images(find_file('t10k-images-idx3-ubyte')) + y_test = parse_idx_labels(find_file('t10k-labels-idx1-ubyte')) + + return X_train, y_train, X_test, y_test + + +def load_data_from_sklearn(): + """从sklearn加载MNIST(备选方案)""" + from sklearn.datasets import fetch_openml + + print(" 正在从OpenML下载数据(首次可能需要1-2分钟)...") + + mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto') + X = mnist.data.astype(np.float32) + y = mnist.target.astype(int) + + X_train = X[:60000] / 255.0 + X_test = X[60000:] / 255.0 + y_train = y[:60000] + y_test = y[60000:] + + return X_train, y_train, X_test, y_test + + +def one_hot_encode(y, num_classes=10): + one_hot = np.zeros((len(y), num_classes)) + one_hot[np.arange(len(y)), y] = 1 + return one_hot + + +def load_data(): + """ + 加载MNIST数据集 + + 优先从本地data/目录加载,如果文件不完整则从sklearn下载 + """ + print("\n" + "=" * 50) + print("MNIST 数据集加载") + print("=" * 50) + + # 优先检查本地文件 + exists, msg = local_files_exist() + if exists: + print(f"\n ✓ 发现本地数据文件: {msg}") + X_train, y_train, X_test, y_test = load_data_from_local() + else: + print(f"\n 本地文件: {msg}") + print(" 尝试从sklearn下载...") + try: + X_train, y_train, X_test, y_test = load_data_from_sklearn() + except Exception as e: + print(f"\n 下载失败: {e}") + print("\n 请确保 data/ 目录下有完整的4个数据文件!") + raise + + # 归一化和One-Hot + X_train = X_train.astype(np.float32) / 255.0 + X_test = X_test.astype(np.float32) / 255.0 + y_train = one_hot_encode(y_train, NUM_CLASSES) + y_test = one_hot_encode(y_test, NUM_CLASSES) + + print(f"\n ✓ 完成!") + print(f" 训练集: {X_train.shape[0]} 样本") + print(f" 测试集: {X_test.shape[0]} 样本") + print(f" 数值范围: [{X_train.min():.2f}, {X_train.max():.2f}]") + + return X_train, y_train, X_test, y_test + + +if __name__ == '__main__': + X_train, y_train, X_test, y_test = load_data() + print(f"\n训练数据: {X_train.shape}") \ No newline at end of file diff --git a/digit_mlp_class/main.py b/digit_mlp_class/main.py new file mode 100644 index 0000000..9d0fecf --- /dev/null +++ b/digit_mlp_class/main.py @@ -0,0 +1,191 @@ +# -*- coding: utf-8 -*- +""" +主程序 - 手写数字识别 MLP 纯NumPy实现 + +使用方法: + python main.py # 运行默认配置 + python main.py --compare # 运行对比实验 + +依赖: + pip install numpy requests +""" + +import numpy as np +import time +from datetime import datetime +from model_numpy import MLP +from dataset import load_data +from config import * + + +def train_and_evaluate(): + """ + 训练并评估模型 + """ + print("=" * 60) + print("手写数字识别 - 纯NumPy MLP实现") + print("=" * 60) + + # ===== 加载数据 ===== + try: + X_train, y_train, X_test, y_test = load_data() + except Exception as e: + print(f"\n错误: {e}") + print("\n请手动下载数据文件:") + print(" 1. 创建 data/ 目录") + print(" 2. 下载以下文件到 data/:") + print(" - train-images-idx3-ubyte.gz (9.9 MB)") + print(" - train-labels-idx1-ubyte.gz (28 KB)") + print(" - t10k-images-idx3-ubyte.gz (1.6 MB)") + print(" - t10k-labels-idx1-ubyte.gz (5 KB)") + print(" 下载地址: https://storage.googleapis.com/tensorflow/tf-keras-datasets/") + return None, None, None + + # ===== 创建模型 ===== + print("\n[2] 创建MLP模型...") + model = MLP( + input_size=INPUT_SIZE, + hidden_size=HIDDEN_SIZE, + num_classes=NUM_CLASSES, + learning_rate=LEARNING_RATE, + seed=SEED + ) + + # ===== 训练模型 ===== + print("\n[3] 开始训练...") + start_time = time.time() + + model.fit( + X_train, y_train, + X_val=X_test, y_val=y_test, + epochs=NUM_EPOCHS, + batch_size=BATCH_SIZE, + verbose=True + ) + + train_time = time.time() - start_time + + # ===== 最终评估 ===== + print("\n" + "=" * 60) + print("训练完成!") + print("=" * 60) + + train_acc = model.accuracy(X_train, y_train) + test_acc = model.accuracy(X_test, y_test) + + print(f"\n最终结果:") + print(f" 训练准确率: {train_acc:.4f} ({train_acc*100:.2f}%)") + print(f" 测试准确率: {test_acc:.4f} ({test_acc*100:.2f}%)") + print(f" 训练时间: {train_time:.2f} 秒") + + # ===== 保存模型 ===== + timestamp = datetime.now().strftime("%m%d_%H%M%S") + model_path = f"mnist_mlp_{timestamp}" + model.save(model_path) + + # ===== 预测示例 ===== + print("\n[4] 预测示例:") + indices = np.random.choice(len(X_test), 5, replace=False) + + for i, idx in enumerate(indices): + img = X_test[idx] + true_label = np.argmax(y_test[idx]) + pred_label = model.predict(img.reshape(1, -1))[0] + prob = model.predict_proba(img.reshape(1, -1))[0] + + status = '✓' if true_label == pred_label else '✗' + print(f" 样本{i+1}: 真实={true_label}, 预测={pred_label}, " + f"置信度={prob[pred_label]:.2f} {status}") + + return model, train_acc, test_acc + + +def run_comparison(): + """ + 运行对比实验 + """ + print("\n" + "=" * 60) + print("超参数对比实验") + print("=" * 60) + + # 加载数据 + try: + X_train, y_train, X_test, y_test = load_data() + except Exception as e: + print(f"加载数据失败: {e}") + return + + # 实验配置 + experiments = [ + {"hidden_size": 32, "lr": 0.1, "name": "小模型(32神经元)"}, + {"hidden_size": 128, "lr": 0.1, "name": "标准(128神经元)"}, + {"hidden_size": 256, "lr": 0.1, "name": "大模型(256神经元)"}, + {"hidden_size": 128, "lr": 0.01, "name": "小学习率(0.01)"}, + {"hidden_size": 128, "lr": 0.5, "name": "大学习率(0.5)"}, + ] + + results = [] + + for exp in experiments: + print(f"\n实验: {exp['name']}") + print("-" * 40) + + model = MLP( + input_size=INPUT_SIZE, + hidden_size=exp['hidden_size'], + num_classes=NUM_CLASSES, + learning_rate=exp['lr'], + seed=SEED + ) + + start_time = time.time() + model.fit(X_train, y_train, epochs=30, batch_size=BATCH_SIZE, verbose=False) + train_time = time.time() - start_time + + train_acc = model.accuracy(X_train, y_train) + test_acc = model.accuracy(X_test, y_test) + + results.append({ + 'name': exp['name'], + 'hidden_size': exp['hidden_size'], + 'lr': exp['lr'], + 'train_acc': train_acc, + 'test_acc': test_acc, + 'train_time': train_time + }) + + print(f" 训练准确率: {train_acc:.4f} | 测试准确率: {test_acc:.4f} | 时间: {train_time:.1f}s") + + # 汇总 + print("\n" + "=" * 60) + print("实验结果汇总") + print("=" * 60) + print(f"\n{'配置':<25} {'训练准确率':<12} {'测试准确率':<12} {'时间':<8}") + print("-" * 60) + + for r in results: + print(f"{r['name']:<25} {r['train_acc']:<12.4f} {r['test_acc']:<12.4f} {r['train_time']:<8.1f}s") + + best = max(results, key=lambda x: x['test_acc']) + print(f"\n最佳配置: {best['name']}, 测试准确率: {best['test_acc']:.4f}") + + +def main(): + """主函数""" + if RUN_COMPARISON: + run_comparison() + else: + train_and_evaluate() + + print("\n" + "=" * 60) + print("程序结束!") + print("=" * 60) + + +if __name__ == '__main__': + import sys + + if '--compare' in sys.argv: + RUN_COMPARISON = True + + main() \ No newline at end of file diff --git a/digit_mlp_class/mnist_mlp_0518_192248_W1.npy b/digit_mlp_class/mnist_mlp_0518_192248_W1.npy new file mode 100644 index 0000000..d232334 Binary files /dev/null and b/digit_mlp_class/mnist_mlp_0518_192248_W1.npy differ diff --git a/digit_mlp_class/mnist_mlp_0518_192248_W2.npy b/digit_mlp_class/mnist_mlp_0518_192248_W2.npy new file mode 100644 index 0000000..6e2e479 Binary files /dev/null and b/digit_mlp_class/mnist_mlp_0518_192248_W2.npy differ diff --git a/digit_mlp_class/mnist_mlp_0518_192248_b1.npy b/digit_mlp_class/mnist_mlp_0518_192248_b1.npy new file mode 100644 index 0000000..55bb987 Binary files /dev/null and b/digit_mlp_class/mnist_mlp_0518_192248_b1.npy differ diff --git a/digit_mlp_class/mnist_mlp_0518_192248_b2.npy b/digit_mlp_class/mnist_mlp_0518_192248_b2.npy new file mode 100644 index 0000000..6a22c8b Binary files /dev/null and b/digit_mlp_class/mnist_mlp_0518_192248_b2.npy differ diff --git a/digit_mlp_class/mnist_mlp_0521_135935_W1.npy b/digit_mlp_class/mnist_mlp_0521_135935_W1.npy new file mode 100644 index 0000000..b5c49c0 Binary files /dev/null and b/digit_mlp_class/mnist_mlp_0521_135935_W1.npy differ diff --git a/digit_mlp_class/mnist_mlp_0521_135935_W2.npy b/digit_mlp_class/mnist_mlp_0521_135935_W2.npy new file mode 100644 index 0000000..1bb169d Binary files /dev/null and b/digit_mlp_class/mnist_mlp_0521_135935_W2.npy differ diff --git a/digit_mlp_class/mnist_mlp_0521_135935_b1.npy b/digit_mlp_class/mnist_mlp_0521_135935_b1.npy new file mode 100644 index 0000000..242ae3b Binary files /dev/null and b/digit_mlp_class/mnist_mlp_0521_135935_b1.npy differ diff --git a/digit_mlp_class/mnist_mlp_0521_135935_b2.npy b/digit_mlp_class/mnist_mlp_0521_135935_b2.npy new file mode 100644 index 0000000..b5e9149 Binary files /dev/null and b/digit_mlp_class/mnist_mlp_0521_135935_b2.npy differ diff --git a/digit_mlp_class/mnist_mlp_0521_141704_W1.npy b/digit_mlp_class/mnist_mlp_0521_141704_W1.npy new file mode 100644 index 0000000..733c5d7 Binary files /dev/null and b/digit_mlp_class/mnist_mlp_0521_141704_W1.npy differ diff --git a/digit_mlp_class/mnist_mlp_0521_141704_W2.npy b/digit_mlp_class/mnist_mlp_0521_141704_W2.npy new file mode 100644 index 0000000..babbafc Binary files /dev/null and b/digit_mlp_class/mnist_mlp_0521_141704_W2.npy differ diff --git a/digit_mlp_class/mnist_mlp_0521_141704_b1.npy b/digit_mlp_class/mnist_mlp_0521_141704_b1.npy new file mode 100644 index 0000000..ad8ed05 Binary files /dev/null and b/digit_mlp_class/mnist_mlp_0521_141704_b1.npy differ diff --git a/digit_mlp_class/mnist_mlp_0521_141704_b2.npy b/digit_mlp_class/mnist_mlp_0521_141704_b2.npy new file mode 100644 index 0000000..4c931aa Binary files /dev/null and b/digit_mlp_class/mnist_mlp_0521_141704_b2.npy differ diff --git a/digit_mlp_class/mnist_mlp_0521_142058_W1.npy b/digit_mlp_class/mnist_mlp_0521_142058_W1.npy new file mode 100644 index 0000000..fdfb2d9 Binary files /dev/null and b/digit_mlp_class/mnist_mlp_0521_142058_W1.npy differ diff --git a/digit_mlp_class/mnist_mlp_0521_142058_W2.npy b/digit_mlp_class/mnist_mlp_0521_142058_W2.npy new file mode 100644 index 0000000..4e7f950 Binary files /dev/null and b/digit_mlp_class/mnist_mlp_0521_142058_W2.npy differ diff --git a/digit_mlp_class/mnist_mlp_0521_142058_b1.npy b/digit_mlp_class/mnist_mlp_0521_142058_b1.npy new file mode 100644 index 0000000..1dd6bf4 Binary files /dev/null and b/digit_mlp_class/mnist_mlp_0521_142058_b1.npy differ diff --git a/digit_mlp_class/mnist_mlp_0521_142058_b2.npy b/digit_mlp_class/mnist_mlp_0521_142058_b2.npy new file mode 100644 index 0000000..1c4bc91 Binary files /dev/null and b/digit_mlp_class/mnist_mlp_0521_142058_b2.npy differ diff --git a/digit_mlp_class/mnist_mlp_0521_142555_W1.npy b/digit_mlp_class/mnist_mlp_0521_142555_W1.npy new file mode 100644 index 0000000..e101231 Binary files /dev/null and b/digit_mlp_class/mnist_mlp_0521_142555_W1.npy differ diff --git a/digit_mlp_class/mnist_mlp_0521_142555_W2.npy b/digit_mlp_class/mnist_mlp_0521_142555_W2.npy new file mode 100644 index 0000000..931a767 Binary files /dev/null and b/digit_mlp_class/mnist_mlp_0521_142555_W2.npy differ diff --git a/digit_mlp_class/mnist_mlp_0521_142555_b1.npy b/digit_mlp_class/mnist_mlp_0521_142555_b1.npy new file mode 100644 index 0000000..69bbe48 Binary files /dev/null and b/digit_mlp_class/mnist_mlp_0521_142555_b1.npy differ diff --git a/digit_mlp_class/mnist_mlp_0521_142555_b2.npy b/digit_mlp_class/mnist_mlp_0521_142555_b2.npy new file mode 100644 index 0000000..6d09e00 Binary files /dev/null and b/digit_mlp_class/mnist_mlp_0521_142555_b2.npy differ diff --git a/digit_mlp_class/model_numpy.py b/digit_mlp_class/model_numpy.py new file mode 100644 index 0000000..df78f06 --- /dev/null +++ b/digit_mlp_class/model_numpy.py @@ -0,0 +1,305 @@ +# -*- coding: utf-8 -*- +""" +模型模块 - 纯NumPy实现手写数字识别MLP + +网络结构: 784 → 128 → 10 +- 输入层: 784 像素值 (28x28 展平) +- 隐藏层: 128 神经元 + ReLU激活 +- 输出层: 10 数字 (0-9) + Softmax + +纯NumPy实现,无任何深度学习框架依赖 +只需: numpy +""" + +import numpy as np + + +class MLP: + """ + 多层感知机(神经网络) + + 结构: + 输入(784) → 线性变换 → ReLU → 线性变换 → Softmax → 输出(10) + + 参数量: + W1: 784 × 128 = 100,352 + b1: 128 + W2: 128 × 10 = 1,280 + b2: 10 + 总计: ~101,770 参数 + """ + + def __init__(self, input_size=784, hidden_size=128, num_classes=10, + learning_rate=0.1, seed=42): + np.random.seed(seed) + + # ===== 第一层: 输入 → 隐藏层 ===== + # 权重: (input_size, hidden_size) + # Xavier初始化,适合ReLU + self.W1 = np.random.randn(input_size, hidden_size) * np.sqrt(2.0 / input_size) + self.b1 = np.zeros(hidden_size) + + # ===== 第二层: 隐藏层 → 输出 ===== + # 权重: (hidden_size, num_classes) + self.W2 = np.random.randn(hidden_size, num_classes) * np.sqrt(2.0 / hidden_size) + self.b2 = np.zeros(num_classes) + + # 保存超参数 + self.lr = learning_rate + self.input_size = input_size + self.hidden_size = hidden_size + self.num_classes = num_classes + + # 打印模型信息 + total_params = (input_size * hidden_size + hidden_size + + hidden_size * num_classes + num_classes) + print(f"\n{'='*50}") + print(f"MLP 网络结构:") + print(f" 输入层: {input_size} 神经元") + print(f" 隐藏层: {hidden_size} 神经元 + ReLU") + print(f" 输出层: {num_classes} 神经元 + Softmax") + print(f" 参数量: {total_params:,}") + print(f"{'='*50}") + + + def relu(self, x): + """ReLU激活函数: max(0, x)""" + return np.maximum(0, x) + + + def relu_derivative(self, x): + """ReLU导数: x > 0 时为1,否则为0""" + return (x > 0).astype(float) + + + def softmax(self, x): + """ + Softmax函数: 将数值转换为概率分布 + + softmax(x_i) = exp(x_i) / sum(exp(x_j)) + + 技巧: 减去最大值避免数值溢出 + """ + x_shifted = x - np.max(x, axis=1, keepdims=True) + exp_x = np.exp(x_shifted) + return exp_x / np.sum(exp_x, axis=1, keepdims=True) + + + def forward(self, X): + """ + 前向传播 + + Args: + X: (batch_size, 784) 图像像素值 + + Returns: + probs: (batch_size, 10) 每个类的概率 + """ + # ===== 第一层计算 ===== + # z1 = X @ W1 + b1 + # a1 = relu(z1) + self.z1 = X @ self.W1 + self.b1 # (batch, 784) @ (784, 128) = (batch, 128) + self.a1 = self.relu(self.z1) # (batch, 128) + + # ===== 第二层计算 ===== + # z2 = a1 @ W2 + b2 + # probs = softmax(z2) + self.z2 = self.a1 @ self.W2 + self.b2 # (batch, 128) @ (128, 10) = (batch, 10) + self.probs = self.softmax(self.z2) # (batch, 10) + + return self.probs + + + def backward(self, X, y): + """ + 反向传播(梯度下降) + + Args: + X: (batch_size, 784) 图像 + y: (batch_size, 10) One-Hot标签 + """ + batch_size = X.shape[0] + + # ===== 输出层梯度 ===== + # Softmax + 交叉熵的梯度简化为: p - y + d_z2 = self.probs - y # (batch, 10) + + # ===== 第二层梯度 ===== + d_W2 = self.a1.T @ d_z2 # (128, 10) + d_b2 = np.sum(d_z2, axis=0) # (10,) + + # ===== 隐藏层梯度 ===== + d_a1 = d_z2 @ self.W2.T # (batch, 128) + d_z1 = d_a1 * self.relu_derivative(self.z1) # (batch, 128) + + # ===== 第一层梯度 ===== + d_W1 = X.T @ d_z1 # (784, 128) + d_b1 = np.sum(d_z1, axis=0) # (128,) + + # ===== 梯度裁剪(防止梯度爆炸) ===== + max_grad = 1.0 + d_W1 = np.clip(d_W1, -max_grad, max_grad) + d_W2 = np.clip(d_W2, -max_grad, max_grad) + d_b1 = np.clip(d_b1, -max_grad, max_grad) + d_b2 = np.clip(d_b2, -max_grad, max_grad) + + # ===== 更新权重(梯度下降) ===== + self.W1 -= self.lr * d_W1 / batch_size + self.b1 -= self.lr * d_b1 / batch_size + self.W2 -= self.lr * d_W2 / batch_size + self.b2 -= self.lr * d_b2 / batch_size + + + def cross_entropy_loss(self, probs, y): + """ + 交叉熵损失 + + L = -sum(y * log(p)) / N + """ + # 取真实类别的概率 + correct_probs = probs[np.arange(len(y)), y.argmax(axis=1)] + # 避免log(0) + loss = -np.mean(np.log(np.clip(correct_probs, 1e-10, 1.0))) + return loss + + + def fit(self, X_train, y_train, X_val=None, y_val=None, + epochs=50, batch_size=64, verbose=True): + """ + 训练模型 + + Args: + X_train: 训练数据 (N, 784) + y_train: 训练标签 (N, 10) One-Hot + X_val: 验证数据(可选) + y_val: 验证标签(可选) + epochs: 训练轮数 + batch_size: 批大小 + verbose: 是否打印进度 + """ + N = len(X_train) + num_batches = (N + batch_size - 1) // batch_size + + for epoch in range(epochs): + # ===== 打乱数据 ===== + indices = np.random.permutation(N) + X_shuffled = X_train[indices] + y_shuffled = y_train[indices] + + epoch_loss = 0 + + # ===== 批训练 ===== + for batch_idx in range(num_batches): + start = batch_idx * batch_size + end = min(start + batch_size, N) + X_batch = X_shuffled[start:end] + y_batch = y_shuffled[start:end] + + # 前向传播 + probs = self.forward(X_batch) + + # 反向传播 + self.backward(X_batch, y_batch) + + # 计算损失 + loss = self.cross_entropy_loss(probs, y_batch) + epoch_loss += loss + + # ===== 打印进度 ===== + if verbose and (epoch + 1) % 5 == 0: + train_acc = self.accuracy(X_train, y_train) + msg = f"Epoch {epoch+1:3d}/{epochs} | Loss: {epoch_loss/num_batches:.4f} | 训练准确率: {train_acc:.4f}" + + if X_val is not None: + val_acc = self.accuracy(X_val, y_val) + msg += f" | 测试准确率: {val_acc:.4f}" + + print(msg) + + return self + + + def predict(self, X): + """ + 预测类别 + + Args: + X: (N, 784) 图像 + + Returns: + predictions: (N,) 预测的类别标签 (0-9) + """ + probs = self.forward(X) + return np.argmax(probs, axis=1) + + + def predict_proba(self, X): + """ + 预测概率 + + Returns: + probs: (N, 10) 每个类的概率 + """ + return self.forward(X) + + + def accuracy(self, X, y): + """ + 计算准确率 + + Args: + X: (N, 784) 图像 + y: (N,) 或 (N, 10) 标签 + """ + if len(y.shape) > 1: + y = np.argmax(y, axis=1) + predictions = self.predict(X) + return np.mean(predictions == y) + + + def save(self, filepath): + """保存模型权重""" + np.save(filepath + '_W1.npy', self.W1) + np.save(filepath + '_b1.npy', self.b1) + np.save(filepath + '_W2.npy', self.W2) + np.save(filepath + '_b2.npy', self.b2) + print(f"\n模型已保存: {filepath}") + + + @staticmethod + def load(filepath, input_size=784, hidden_size=128, num_classes=10, learning_rate=0.1): + """加载模型权重""" + model = MLP(input_size, hidden_size, num_classes, learning_rate) + model.W1 = np.load(filepath + '_W1.npy') + model.b1 = np.load(filepath + '_b1.npy') + model.W2 = np.load(filepath + '_W2.npy') + model.b2 = np.load(filepath + '_b2.npy') + print(f"\n模型已加载: {filepath}") + return model + + +# ===== 测试代码 ===== +if __name__ == '__main__': + # 简单测试 + print("测试MLP模型...") + + model = MLP(input_size=784, hidden_size=128, num_classes=10, learning_rate=0.1) + + # 模拟数据 + X_test = np.random.randn(32, 784) + y_test = np.zeros((32, 10)) + for i in range(32): + y_test[i, i % 10] = 1 + + # 前向传播测试 + probs = model.forward(X_test) + print(f"输出概率形状: {probs.shape}") + print(f"概率和: {probs[0].sum():.4f} (应该接近1)") + + # 反向传播测试 + model.backward(X_test, y_test) + print("反向传播测试通过!") + + # 预测测试 + preds = model.predict(X_test) + print(f"预测结果: {preds}") \ No newline at end of file diff --git a/digit_mlp_class/test_digit_5.png b/digit_mlp_class/test_digit_5.png new file mode 100644 index 0000000..2ff5b37 Binary files /dev/null and b/digit_mlp_class/test_digit_5.png differ diff --git a/digit_mlp_class/test_image.py b/digit_mlp_class/test_image.py new file mode 100644 index 0000000..89235b8 --- /dev/null +++ b/digit_mlp_class/test_image.py @@ -0,0 +1,231 @@ +# -*- coding: utf-8 -*- +""" +测试脚本 - 用训练好的模型识别学生手写数字图片 + +使用方法: + python test_image.py path/to/image.png + python test_image.py path/to/image.jpg + python test_image.py path/to/folder/ # 识别文件夹内所有图片 + +依赖: + pip install numpy pillow + +图片要求: + - 建议尺寸:28x28 或更大(程序会自动缩放) + - 背景最好为白色,数字为黑色 + - 手写数字清晰、无遮挡 +""" + +import sys +import os +import numpy as np +from PIL import Image + +# 尝试导入模型 +try: + from model_numpy import MLP +except ImportError: + print("错误:请在 digit_mlp_class 目录下运行此脚本") + print(" cd digit_mlp_class") + print(" python test_image.py your_image.png") + sys.exit(1) + + +def find_latest_model(): + """查找最新的模型文件""" + model_files = [f for f in os.listdir('.') if f.startswith('mnist_mlp_') and f.endswith('.npy')] + if not model_files: + return None + + # 按时间戳分组 + timestamps = set() + for f in model_files: + # 文件格式: mnist_mlp_YYMMDD_HHMMSS_suffix.npy + # 去掉后缀 _W1.npy, _b1.npy 等 + base = f.rsplit('_', 1)[0] + timestamps.add(base) + + # 选择最新的 + latest = sorted(timestamps)[-1] + + # 检查完整模型 + required = ['W1', 'b1', 'W2', 'b2'] + for r in required: + if not os.path.exists(f'{latest}_{r}.npy'): + return None + + return latest + + +def load_model(): + """加载训练好的模型""" + # 查找最新模型 + model_path = find_latest_model() + + if model_path is None: + print("错误:未找到训练好的模型!") + print(" 请先运行: python main.py") + print(" 或确保当前目录有 mnist_mlp_*.npy 模型文件") + return None + + print(f"加载模型: {model_path}") + + model = MLP(input_size=784, hidden_size=128, num_classes=10, learning_rate=0.1) + + model.W1 = np.load(f'{model_path}_W1.npy') + model.b1 = np.load(f'{model_path}_b1.npy') + model.W2 = np.load(f'{model_path}_W2.npy') + model.b2 = np.load(f'{model_path}_b2.npy') + + print("模型加载成功!\n") + return model + + +def preprocess_image(image_path): + """ + 图片预处理:将任意图片转为 28x28 归一化向量 + + 处理流程: + 1. 打开图片并转为灰度 + 2. 调整大小为 28x28 + 3. 转为NumPy数组并归一化到 [0, 1] + 4. 展平为 784 维向量 + """ + try: + img = Image.open(image_path).convert('L') # 转灰度 + except Exception as e: + raise ValueError(f"无法打开图片: {e}") + + # 保存原始尺寸用于调试 + orig_width, orig_height = img.size + + # 缩放到 28x28 + img = img.resize((28, 28), Image.LANCZOS) + + # 转为NumPy数组并归一化 + img_array = np.array(img, dtype=np.float32) / 255.0 + + # MNIST是白底黑字(0=背景, 1=数字),如果图片是黑底白字需要反转 + # 检查是否是黑底白字:背景均值 > 0.5 + if img_array.mean() > 0.5: + img_array = 1.0 - img_array + + # 展平为向量 + img_vector = img_array.flatten() + + print(f" 图片: {os.path.basename(image_path)}") + print(f" 原始尺寸: {orig_width}x{orig_height}") + print(f" 处理后尺寸: 28x28") + + return img_vector + + +def predict_image(model, image_path): + """ + 识别单张图片 + + 返回: + predicted_digit: 预测的数字 (0-9) + confidence: 置信度 (0-1) + all_probs: 所有数字的概率 + """ + # 预处理 + img_vector = preprocess_image(image_path) + + # 预测 + probs = model.predict_proba(img_vector.reshape(1, -1))[0] + predicted_digit = np.argmax(probs) + confidence = probs[predicted_digit] + + return predicted_digit, confidence, probs + + +def print_results(digit, confidence, probs): + """打印识别结果""" + print(f" 预测结果: {digit}") + print(f" 置信度: {confidence:.2%}") + + # 显示各数字概率 + print(f"\n 各数字概率:") + print(f" ", end="") + for i in range(10): + bar_len = int(probs[i] * 20) + print(f" {i}:{'█' * bar_len}{'░' * (20-bar_len)} {probs[i]:.1%}") + + print() + + +def main(): + """主函数""" + print("\n" + "=" * 60) + print("手写数字识别 - 图片测试") + print("=" * 60 + "\n") + + # 加载模型 + model = load_model() + if model is None: + sys.exit(1) + + # 检查命令行参数 + if len(sys.argv) < 2: + print("使用方法:") + print(" python test_image.py path/to/image.png") + print(" python test_image.py path/to/image.jpg") + print(" python test_image.py path/to/folder/") + print() + print("示例:") + print(" python test_image.py my_digit.png") + print(" python test_image.py ./test_images/") + sys.exit(1) + + target = sys.argv[1] + + # 收集所有要处理的图片 + image_paths = [] + + if os.path.isdir(target): + # 文件夹:收集所有图片 + extensions = ['.png', '.jpg', '.jpeg', '.bmp', '.gif'] + for f in os.listdir(target): + if any(f.lower().endswith(ext) for ext in extensions): + image_paths.append(os.path.join(target, f)) + image_paths.sort() + elif os.path.isfile(target): + image_paths = [target] + else: + print(f"错误:文件或目录不存在: {target}") + sys.exit(1) + + if not image_paths: + print(f"错误:在 {target} 中未找到图片文件") + sys.exit(1) + + print(f"找到 {len(image_paths)} 张图片,开始识别...\n") + + # 批量识别 + results = [] + for path in image_paths: + try: + digit, confidence, probs = predict_image(model, path) + results.append((path, digit, confidence)) + print_results(digit, confidence, probs) + except Exception as e: + print(f" 识别失败: {e}\n") + + # 汇总结果 + print("=" * 60) + print(f"识别完成!共 {len(results)} 张图片") + print("=" * 60) + print(f"\n{'文件名':<30} {'预测':<6} {'置信度':<10}") + print("-" * 50) + + for path, digit, confidence in results: + filename = os.path.basename(path) + print(f"{filename:<30} {digit:<6} {confidence:.1%}") + + # 如果有真实标签(文件夹命名中包含),可以计算准确率 + print() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/digit_mlp_class/visualize.py b/digit_mlp_class/visualize.py new file mode 100644 index 0000000..6701887 --- /dev/null +++ b/digit_mlp_class/visualize.py @@ -0,0 +1,350 @@ +# -*- coding: utf-8 -*- +""" +可视化工具 - 展示神经网络各层的输出 + +用于课堂教学,让学生直观理解: +1. 输入图像长什么样 +2. 第一层隐藏层学到了什么特征 +3. 各层激活值的变化 + +使用方法: + python visualize.py # 可视化测试集前5张 + python visualize.py --single # 可视化单张图片 +""" + +import numpy as np +from PIL import Image +import os +import sys +import matplotlib +matplotlib.use('Agg') # 无头模式,不显示图形 +import matplotlib.pyplot as plt + + +def visualize_input_image(img_vector, save_path='visualizations/input.png'): + """把784维向量还原成28x28图像并保存""" + img = img_vector.reshape(28, 28) * 255 + img = img.astype(np.uint8) + Image.fromarray(img).save(save_path) + return save_path + + +def visualize_activations(model, img_vector, save_dir='visualizations'): + """ + 可视化网络各层的激活值 + """ + os.makedirs(save_dir, exist_ok=True) + + # 前向传播获取各层激活值 + model.forward(img_vector.reshape(1, -1)) + + # 1. 保存输入图像 + visualize_input_image(img_vector, os.path.join(save_dir, '01_input.png')) + + # 2. 可视化第一层激活(隐藏层) + hidden_activations = model.a1[0] # (128,) + visualize_hidden_layer(hidden_activations, os.path.join(save_dir, '02_hidden.png')) + + # 3. 可视化输出层概率 + output_probs = model.probs[0] # (10,) + visualize_output_prob(output_probs, os.path.join(save_dir, '03_output_prob.png')) + + # 4. 生成汇总图 + create_summary_image(img_vector, hidden_activations, output_probs, save_dir) + + return save_dir + + +def visualize_hidden_layer(activations, save_path): + """ + 可视化隐藏层激活值 + 把128个神经元的激活值排成8x16网格显示 + """ + grid_cols = 16 + grid_rows = 8 + cell_size = 24 + + img_h = grid_rows * cell_size + img_w = grid_cols * cell_size + grid = np.ones((img_h, img_w)) * 255 + + for i, act in enumerate(activations): + row = i // grid_cols + col = i % grid_cols + intensity = max(0, min(1, act * 2)) + color = int(255 * (1 - intensity * 0.7)) + grid[row*cell_size:(row+1)*cell_size-1, col*cell_size:(col+1)*cell_size-1] = color + + Image.fromarray(grid.astype(np.uint8)).save(save_path) + + +def visualize_output_prob(probs, save_path): + """可视化输出层概率分布""" + fig, ax = plt.subplots(figsize=(8, 4)) + + digits = list(range(10)) + colors = ['#3498db' if i != np.argmax(probs) else '#e74c3c' for i in digits] + + bars = ax.bar(digits, probs, color=colors) + ax.set_xlabel('数字', fontsize=12) + ax.set_ylabel('概率', fontsize=12) + ax.set_title('输出层:各数字的预测概率', fontsize=14) + ax.set_xticks(digits) + ax.set_ylim(0, 1) + + max_idx = np.argmax(probs) + ax.annotate(f'{probs[max_idx]:.1%}', + xy=(max_idx, probs[max_idx]), + ha='center', va='bottom', fontsize=10, color='#e74c3c', fontweight='bold') + + plt.tight_layout() + plt.savefig(save_path, dpi=100, bbox_inches='tight') + plt.close() + + +def create_summary_image(img_vector, hidden_activations, output_probs, save_dir): + """创建汇总图""" + fig = plt.figure(figsize=(14, 6)) + + # 1. 输入图像 + ax1 = fig.add_subplot(2, 4, 1) + ax1.imshow(img_vector.reshape(28, 28), cmap='gray') + ax1.set_title('(1) Input Image\n(28x28 pixels)', fontsize=11) + ax1.axis('off') + + # 2. 像素值分布 + ax2 = fig.add_subplot(2, 4, 2) + ax2.hist(img_vector, bins=30, color='#3498db', alpha=0.7, edgecolor='white') + ax2.set_title('(2) Pixel Value Distribution\n(normalized 0~1)', fontsize=11) + ax2.set_xlabel('像素值') + ax2.set_ylabel('频数') + + # 3. 隐藏层激活(热力图) + ax3 = fig.add_subplot(2, 4, 3) + # 128 = 8 × 16 + act_2d = hidden_activations.reshape(8, 16) + im = ax3.imshow(act_2d, cmap='Blues', aspect='auto') + ax3.set_title(f'(3) Hidden Layer\n(128 neurons)', fontsize=11) + ax3.axis('off') + plt.colorbar(im, ax=ax3, shrink=0.6) + + # 4. 隐藏层激活(条形图) + ax4 = fig.add_subplot(2, 4, 4) + ax4.bar(range(len(hidden_activations)), hidden_activations, color='#3498db', alpha=0.7) + ax4.set_title('(4) Neuron Activations', fontsize=11) + ax4.set_xlabel('神经元编号') + ax4.set_ylabel('激活强度') + + # 5. 输出概率 + ax5 = fig.add_subplot(2, 4, 5) + digits = list(range(10)) + colors = ['#3498db' if i != np.argmax(output_probs) else '#e74c3c' for i in digits] + ax5.bar(digits, output_probs, color=colors) + ax5.set_title('(5) Output Probabilities', fontsize=11) + ax5.set_xlabel('数字') + ax5.set_ylabel('概率') + ax5.set_ylim(0, 1) + ax5.set_xticks(digits) + + # 6. 最大概率 + ax6 = fig.add_subplot(2, 4, 6) + ax6.axis('off') + predicted = np.argmax(output_probs) + confidence = output_probs[predicted] + result_text = f'预测: {predicted}\n置信度: {confidence:.1%}' + ax6.text(0.5, 0.5, result_text, fontsize=24, ha='center', va='center', + bbox=dict(boxstyle='round', facecolor='#2ecc71', alpha=0.9), + transform=ax6.transAxes, color='white', fontweight='bold') + ax6.set_title('(6) Recognition Result', fontsize=11) + + # 7. 网络结构 + ax7 = fig.add_subplot(2, 4, 7) + ax7.axis('off') + structure_text = ( + '┌─────────────────┐\n' + '│ 输入层 784 │\n' + '│ (28×28展平) │\n' + '└────────┬────────┘\n' + ' │\n' + ' 线性变换+ReLU\n' + ' │\n' + '┌────────┴────────┐\n' + '│ 隐藏层 128 │\n' + '│ (特征提取) │\n' + '└────────┬────────┘\n' + ' │\n' + ' 线性变换+Softmax\n' + ' │\n' + '┌────────┴────────┐\n' + '│ 输出层 10 │\n' + '│ (数字0~9概率) │\n' + '└─────────────────┘' + ) + ax7.text(0.1, 0.95, structure_text, fontsize=9, va='top', + family='monospace', transform=ax7.transAxes, + bbox=dict(boxstyle='round', facecolor='#f8f9fa', alpha=0.9)) + ax7.set_title('(7) Network Structure', fontsize=11) + + # 8. 参数量说明 + ax8 = fig.add_subplot(2, 4, 8) + ax8.axis('off') + params_text = ( + 'MLP 参数量计算:\n\n' + 'W1: 784 × 128 = 100,352\n' + 'b1: 128\n\n' + 'W2: 128 × 10 = 1,280\n' + 'b2: 10\n\n' + '─────────────────\n' + '总计: 101,770 参数\n\n' + '全部用 NumPy 实现\n' + '无需任何深度学习框架!' + ) + ax8.text(0.1, 0.95, params_text, fontsize=10, va='top', + family='monospace', transform=ax8.transAxes) + ax8.set_title('(8) Parameters', fontsize=11) + + plt.suptitle('MLP Feature Maps Visualization - Handwritten Digits', fontsize=16, fontweight='bold', y=1.02) + plt.tight_layout() + plt.savefig(os.path.join(save_dir, 'summary.png'), dpi=120, bbox_inches='tight') + plt.close() + + +def load_model_or_train(): + """加载已训练的模型,如果没有则训练一个""" + import glob + model_files = glob.glob('mnist_mlp_*.npy') + + if model_files: + timestamps = sorted(set( + f.replace('mnist_mlp_', '').replace('_W1.npy', '') + for f in model_files if '_W1.npy' in f + )) + if timestamps: + model_path = 'mnist_mlp_' + timestamps[-1] + print(f"加载模型: {model_path}") + + from model_numpy import MLP + model = MLP(input_size=784, hidden_size=128, num_classes=10, learning_rate=0.1) + model.W1 = np.load(f'{model_path}_W1.npy') + model.b1 = np.load(f'{model_path}_b1.npy') + model.W2 = np.load(f'{model_path}_W2.npy') + model.b2 = np.load(f'{model_path}_b2.npy') + return model + + # 没有模型,用sklearn快速训练一个用于演示 + print("未找到已训练模型,使用sklearn数据快速训练演示模型...") + from sklearn.datasets import fetch_openml + from sklearn.model_selection import train_test_split + + mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto') + X = mnist.data[:10000].astype(np.float32) / 255.0 + y = mnist.target[:10000].astype(int) + + # 用sklearn的MLP代替 + from sklearn.neural_network import MLPClassifier + model = MLPClassifier( + hidden_layer_sizes=(128,), + max_iter=20, + alpha=1e-4, + solver='sgd', + learning_rate_init=0.1, + random_state=42, + verbose=False + ) + model.fit(X, y) + print("sklearn模型训练完成(仅用于可视化演示)") + return model + + +def main(): + """主函数""" + os.makedirs('visualizations', exist_ok=True) + + # 加载数据 + from dataset import load_data + print("加载MNIST数据集...") + X_train, y_train, X_test, y_test = load_data() + + # 加载模型 + model = load_model_or_train() + + # 获取真实标签 + if len(y_test.shape) > 1: + y_test_labels = np.argmax(y_test, axis=1) + else: + y_test_labels = y_test + + # 可视化测试集前5张 + print("\n可视化测试集前5张图片...") + for i in range(5): + img = X_test[i] + true_label = y_test_labels[i] + + sub_dir = f'visualizations/sample_{i}_true{true_label}' + os.makedirs(sub_dir, exist_ok=True) + + if hasattr(model, 'predict_proba'): + probs = model.predict_proba(img.reshape(1, -1))[0] + predicted = np.argmax(probs) + else: + predicted = model.predict(img.reshape(1, -1))[0] + + print(f" 样本{i}: 真实={true_label}, 预测={predicted}") + visualize_activations(model, img, sub_dir) + + # 创建对比汇总 + create_comparison_summary(X_test[:5], y_test_labels[:5], model, 'visualizations') + + print("\n✅ 可视化完成!") + print(" 查看 visualizations/ 目录下的图片和汇总图") + + +def create_comparison_summary(X_samples, y_true, model, save_dir): + """创建多个样本的对比汇总图""" + n_samples = len(X_samples) + + fig = plt.figure(figsize=(4 * n_samples, 8)) + + for i in range(n_samples): + img = X_samples[i] + true_label = y_true[i] + + if hasattr(model, 'predict_proba'): + probs = model.predict_proba(img.reshape(1, -1))[0] + predicted = np.argmax(probs) + else: + predicted = model.predict(img.reshape(1, -1))[0] + + # 输入图像 + ax = fig.add_subplot(3, n_samples, i + 1) + ax.imshow(img.reshape(28, 28), cmap='gray') + ax.set_title(f'真实: {true_label}', fontsize=12) + ax.axis('off') + + # 隐藏层激活 + ax = fig.add_subplot(3, n_samples, i + 1 + n_samples) + model.forward(img.reshape(1, -1)) + # 128 = 8 × 16 + hidden = model.a1[0].reshape(8, 16) + ax.imshow(hidden, cmap='Blues', aspect='auto') + ax.set_title(f'隐藏层激活', fontsize=10) + ax.axis('off') + + # 输出概率 + ax = fig.add_subplot(3, n_samples, i + 1 + 2*n_samples) + digits = list(range(10)) + colors = ['#e74c3c' if d == predicted else '#3498db' for d in digits] + ax.bar(digits, probs, color=colors) + ax.set_title(f'预测: {predicted}', fontsize=12) + ax.set_ylim(0, 1) + ax.set_xticks(digits) + ax.tick_params(labelsize=8) + + plt.suptitle('多样本特征图对比', fontsize=16, fontweight='bold') + plt.tight_layout() + plt.savefig(os.path.join(save_dir, 'comparison.png'), dpi=120, bbox_inches='tight') + plt.close() + + +if __name__ == '__main__': + main() \ No newline at end of file