diff --git a/2509165016.png b/2509165016.png new file mode 100644 index 0000000..09ea932 Binary files /dev/null and b/2509165016.png differ diff --git a/修改后main.py b/修改后main.py new file mode 100644 index 0000000..6ef47f0 --- /dev/null +++ b/修改后main.py @@ -0,0 +1,215 @@ +# -*- coding: utf-8 -*- +""" +主程序 - 手写数字识别 MLP 纯NumPy实现 + +使用方法: + python main.py # 运行默认配置 + python main.py --compare # 运行对比实验 + +依赖: + pip install numpy requests +""" + +import numpy as np +import time +from datetime import datetime +from model_numpy import MLP +from dataset import load_data +from config import * +import cv2 +from PIL import Image + + +def train_and_evaluate(): + """ + 训练并评估模型 + """ + print("=" * 60) + print("手写数字识别 - 纯NumPy MLP实现") + print("=" * 60) + + # ===== 加载数据 ===== + try: + X_train, y_train, X_test, y_test = load_data() + except Exception as e: + print(f"\n错误: {e}") + print("\n请手动下载数据文件:") + print(" 1. 创建 data/ 目录") + print(" 2. 下载以下文件到 data/:") + print(" - train-images-idx3-ubyte.gz (9.9 MB)") + print(" - train-labels-idx1-ubyte.gz (28 KB)") + print(" - t10k-images-idx3-ubyte.gz (1.6 MB)") + print(" - t10k-labels-idx1-ubyte.gz (5 KB)") + print(" 下载地址: https://storage.googleapis.com/tensorflow/tf-keras-datasets/") + return None, None, None + + # ===== 创建模型 ===== + print("\n[2] 创建MLP模型...") + model = MLP( + input_size=INPUT_SIZE, + hidden_size=HIDDEN_SIZE, + num_classes=NUM_CLASSES, + learning_rate=LEARNING_RATE, + seed=SEED + ) + + # ===== 训练模型 ===== + print("\n[3] 开始训练...") + start_time = time.time() + + model.fit( + X_train, y_train, + X_val=X_test, y_val=y_test, + epochs=NUM_EPOCHS, + batch_size=BATCH_SIZE, + verbose=True + ) + + train_time = time.time() - start_time + + # ===== 最终评估 ===== + print("\n" + "=" * 60) + print("训练完成!") + print("=" * 60) + + train_acc = model.accuracy(X_train, y_train) + test_acc = model.accuracy(X_test, y_test) + + print(f"\n最终结果:") + print(f" 训练准确率: {train_acc:.4f} ({train_acc*100:.2f}%)") + print(f" 测试准确率: {test_acc:.4f} ({test_acc*100:.2f}%)") + print(f" 训练时间: {train_time:.2f} 秒") + + # ===== 保存模型 ===== + timestamp = datetime.now().strftime("%m%d_%H%M%S") + model_path = f"mnist_mlp_{timestamp}" + model.save(model_path) + + # ===== 预测示例 ===== + print("\n[4] 预测示例:") + indices = np.random.choice(len(X_test), 5, replace=False) + + for i, idx in enumerate(indices): + img = X_test[idx] + true_label = np.argmax(y_test[idx]) + pred_label = model.predict(img.reshape(1, -1))[0] + prob = model.predict_proba(img.reshape(1, -1))[0] + + status = '✓' if true_label == pred_label else '✗' + print(f" 样本{i+1}: 真实={true_label}, 预测={pred_label}, " + f"置信度={prob[pred_label]:.2f} {status}") + + return model, train_acc, test_acc + def predict_custom_image(model, img_path): + """ + 预测自己的手写数字图片 + :param model: 训练好的MLP模型 + :param img_path: 图片路径(如test.png) + """ + # 1. 读取图片,转为灰度图 + img = Image.open(img_path).convert('L') + # 2. 缩放到28×28,和MNIST一致 + img = img.resize((28, 28), Image.Resampling.LANCZOS) + # 3. 转为numpy数组,归一化0~1 + img_array = np.array(img) / 255.0 + # 4. 反转颜色:白底黑字 → 黑底白字(和MNIST数据集一致) + img_array = 1 - img_array + # 5. 展平为784维 + x = img_array.reshape(1, -1) + # 6. 预测 + pred = model.predict(x)[0] + prob = model.predict_proba(x)[0] + print(f"\n===== 自定义图片预测结果 =====") + print(f"预测数字:{pred}") + print(f"置信度:{prob[pred]:.4f}") + return pred + +def run_comparison(): + """ + 运行对比实验 + """ + print("\n" + "=" * 60) + print("超参数对比实验") + print("=" * 60) + + # 加载数据 + try: + X_train, y_train, X_test, y_test = load_data() + except Exception as e: + print(f"加载数据失败: {e}") + return + + # 实验配置 + experiments = [ + {"hidden_size": 32, "lr": 0.1, "name": "小模型(32神经元)"}, + {"hidden_size": 128, "lr": 0.1, "name": "标准(128神经元)"}, + {"hidden_size": 256, "lr": 0.1, "name": "大模型(256神经元)"}, + {"hidden_size": 128, "lr": 0.01, "name": "小学习率(0.01)"}, + {"hidden_size": 128, "lr": 0.5, "name": "大学习率(0.5)"}, + ] + + results = [] + + for exp in experiments: + print(f"\n实验: {exp['name']}") + print("-" * 40) + + model = MLP( + input_size=INPUT_SIZE, + hidden_size=exp['hidden_size'], + num_classes=NUM_CLASSES, + learning_rate=exp['lr'], + seed=SEED + ) + + start_time = time.time() + model.fit(X_train, y_train, epochs=30, batch_size=BATCH_SIZE, verbose=False) + train_time = time.time() - start_time + + train_acc = model.accuracy(X_train, y_train) + test_acc = model.accuracy(X_test, y_test) + + results.append({ + 'name': exp['name'], + 'hidden_size': exp['hidden_size'], + 'lr': exp['lr'], + 'train_acc': train_acc, + 'test_acc': test_acc, + 'train_time': train_time + }) + + print(f" 训练准确率: {train_acc:.4f} | 测试准确率: {test_acc:.4f} | 时间: {train_time:.1f}s") + + # 汇总 + print("\n" + "=" * 60) + print("实验结果汇总") + print("=" * 60) + print(f"\n{'配置':<25} {'训练准确率':<12} {'测试准确率':<12} {'时间':<8}") + print("-" * 60) + + for r in results: + print(f"{r['name']:<25} {r['train_acc']:<12.4f} {r['test_acc']:<12.4f} {r['train_time']:<8.1f}s") + + best = max(results, key=lambda x: x['test_acc']) + print(f"\n最佳配置: {best['name']}, 测试准确率: {best['test_acc']:.4f}") + + +def main(): + """主函数""" + if RUN_COMPARISON: + run_comparison() + else: + model, train_acc, test_acc = train_and_evaluate() + # ========== 新增:训练完成后,输入图片路径识别 ========== + while True: + img_path = input("\n请输入手写数字图片路径(输入q退出):") + if img_path.lower() == 'q': + break + try: + predict_custom_image(model, img_path) + except Exception as e: + print(f"图片读取失败:{e}") + + print("\n" + "=" * 60) + print("程序结束!") + print("=" * 60) \ No newline at end of file