上传文件至 /

This commit is contained in:
2026-05-19 11:29:33 +08:00
parent 0c6c48f7a7
commit da18184960
2 changed files with 215 additions and 0 deletions

215
修改后main.py Normal file
View File

@@ -0,0 +1,215 @@
# -*- coding: utf-8 -*-
"""
主程序 - 手写数字识别 MLP 纯NumPy实现
使用方法:
python main.py # 运行默认配置
python main.py --compare # 运行对比实验
依赖:
pip install numpy requests
"""
import numpy as np
import time
from datetime import datetime
from model_numpy import MLP
from dataset import load_data
from config import *
import cv2
from PIL import Image
def train_and_evaluate():
"""
训练并评估模型
"""
print("=" * 60)
print("手写数字识别 - 纯NumPy MLP实现")
print("=" * 60)
# ===== 加载数据 =====
try:
X_train, y_train, X_test, y_test = load_data()
except Exception as e:
print(f"\n错误: {e}")
print("\n请手动下载数据文件:")
print(" 1. 创建 data/ 目录")
print(" 2. 下载以下文件到 data/:")
print(" - train-images-idx3-ubyte.gz (9.9 MB)")
print(" - train-labels-idx1-ubyte.gz (28 KB)")
print(" - t10k-images-idx3-ubyte.gz (1.6 MB)")
print(" - t10k-labels-idx1-ubyte.gz (5 KB)")
print(" 下载地址: https://storage.googleapis.com/tensorflow/tf-keras-datasets/")
return None, None, None
# ===== 创建模型 =====
print("\n[2] 创建MLP模型...")
model = MLP(
input_size=INPUT_SIZE,
hidden_size=HIDDEN_SIZE,
num_classes=NUM_CLASSES,
learning_rate=LEARNING_RATE,
seed=SEED
)
# ===== 训练模型 =====
print("\n[3] 开始训练...")
start_time = time.time()
model.fit(
X_train, y_train,
X_val=X_test, y_val=y_test,
epochs=NUM_EPOCHS,
batch_size=BATCH_SIZE,
verbose=True
)
train_time = time.time() - start_time
# ===== 最终评估 =====
print("\n" + "=" * 60)
print("训练完成!")
print("=" * 60)
train_acc = model.accuracy(X_train, y_train)
test_acc = model.accuracy(X_test, y_test)
print(f"\n最终结果:")
print(f" 训练准确率: {train_acc:.4f} ({train_acc*100:.2f}%)")
print(f" 测试准确率: {test_acc:.4f} ({test_acc*100:.2f}%)")
print(f" 训练时间: {train_time:.2f}")
# ===== 保存模型 =====
timestamp = datetime.now().strftime("%m%d_%H%M%S")
model_path = f"mnist_mlp_{timestamp}"
model.save(model_path)
# ===== 预测示例 =====
print("\n[4] 预测示例:")
indices = np.random.choice(len(X_test), 5, replace=False)
for i, idx in enumerate(indices):
img = X_test[idx]
true_label = np.argmax(y_test[idx])
pred_label = model.predict(img.reshape(1, -1))[0]
prob = model.predict_proba(img.reshape(1, -1))[0]
status = '' if true_label == pred_label else ''
print(f" 样本{i+1}: 真实={true_label}, 预测={pred_label}, "
f"置信度={prob[pred_label]:.2f} {status}")
return model, train_acc, test_acc
def predict_custom_image(model, img_path):
"""
预测自己的手写数字图片
:param model: 训练好的MLP模型
:param img_path: 图片路径如test.png
"""
# 1. 读取图片,转为灰度图
img = Image.open(img_path).convert('L')
# 2. 缩放到28×28和MNIST一致
img = img.resize((28, 28), Image.Resampling.LANCZOS)
# 3. 转为numpy数组归一化0~1
img_array = np.array(img) / 255.0
# 4. 反转颜色:白底黑字 → 黑底白字和MNIST数据集一致
img_array = 1 - img_array
# 5. 展平为784维
x = img_array.reshape(1, -1)
# 6. 预测
pred = model.predict(x)[0]
prob = model.predict_proba(x)[0]
print(f"\n===== 自定义图片预测结果 =====")
print(f"预测数字:{pred}")
print(f"置信度:{prob[pred]:.4f}")
return pred
def run_comparison():
"""
运行对比实验
"""
print("\n" + "=" * 60)
print("超参数对比实验")
print("=" * 60)
# 加载数据
try:
X_train, y_train, X_test, y_test = load_data()
except Exception as e:
print(f"加载数据失败: {e}")
return
# 实验配置
experiments = [
{"hidden_size": 32, "lr": 0.1, "name": "小模型(32神经元)"},
{"hidden_size": 128, "lr": 0.1, "name": "标准(128神经元)"},
{"hidden_size": 256, "lr": 0.1, "name": "大模型(256神经元)"},
{"hidden_size": 128, "lr": 0.01, "name": "小学习率(0.01)"},
{"hidden_size": 128, "lr": 0.5, "name": "大学习率(0.5)"},
]
results = []
for exp in experiments:
print(f"\n实验: {exp['name']}")
print("-" * 40)
model = MLP(
input_size=INPUT_SIZE,
hidden_size=exp['hidden_size'],
num_classes=NUM_CLASSES,
learning_rate=exp['lr'],
seed=SEED
)
start_time = time.time()
model.fit(X_train, y_train, epochs=30, batch_size=BATCH_SIZE, verbose=False)
train_time = time.time() - start_time
train_acc = model.accuracy(X_train, y_train)
test_acc = model.accuracy(X_test, y_test)
results.append({
'name': exp['name'],
'hidden_size': exp['hidden_size'],
'lr': exp['lr'],
'train_acc': train_acc,
'test_acc': test_acc,
'train_time': train_time
})
print(f" 训练准确率: {train_acc:.4f} | 测试准确率: {test_acc:.4f} | 时间: {train_time:.1f}s")
# 汇总
print("\n" + "=" * 60)
print("实验结果汇总")
print("=" * 60)
print(f"\n{'配置':<25} {'训练准确率':<12} {'测试准确率':<12} {'时间':<8}")
print("-" * 60)
for r in results:
print(f"{r['name']:<25} {r['train_acc']:<12.4f} {r['test_acc']:<12.4f} {r['train_time']:<8.1f}s")
best = max(results, key=lambda x: x['test_acc'])
print(f"\n最佳配置: {best['name']}, 测试准确率: {best['test_acc']:.4f}")
def main():
"""主函数"""
if RUN_COMPARISON:
run_comparison()
else:
model, train_acc, test_acc = train_and_evaluate()
# ========== 新增:训练完成后,输入图片路径识别 ==========
while True:
img_path = input("\n请输入手写数字图片路径(输入q退出)")
if img_path.lower() == 'q':
break
try:
predict_custom_image(model, img_path)
except Exception as e:
print(f"图片读取失败:{e}")
print("\n" + "=" * 60)
print("程序结束!")
print("=" * 60)