Files
task-3-2-2-text-classification/修改后main.py
2026-05-19 11:29:33 +08:00

215 lines
6.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
主程序 - 手写数字识别 MLP 纯NumPy实现
使用方法:
python main.py # 运行默认配置
python main.py --compare # 运行对比实验
依赖:
pip install numpy requests
"""
import numpy as np
import time
from datetime import datetime
from model_numpy import MLP
from dataset import load_data
from config import *
import cv2
from PIL import Image
def train_and_evaluate():
"""
训练并评估模型
"""
print("=" * 60)
print("手写数字识别 - 纯NumPy MLP实现")
print("=" * 60)
# ===== 加载数据 =====
try:
X_train, y_train, X_test, y_test = load_data()
except Exception as e:
print(f"\n错误: {e}")
print("\n请手动下载数据文件:")
print(" 1. 创建 data/ 目录")
print(" 2. 下载以下文件到 data/:")
print(" - train-images-idx3-ubyte.gz (9.9 MB)")
print(" - train-labels-idx1-ubyte.gz (28 KB)")
print(" - t10k-images-idx3-ubyte.gz (1.6 MB)")
print(" - t10k-labels-idx1-ubyte.gz (5 KB)")
print(" 下载地址: https://storage.googleapis.com/tensorflow/tf-keras-datasets/")
return None, None, None
# ===== 创建模型 =====
print("\n[2] 创建MLP模型...")
model = MLP(
input_size=INPUT_SIZE,
hidden_size=HIDDEN_SIZE,
num_classes=NUM_CLASSES,
learning_rate=LEARNING_RATE,
seed=SEED
)
# ===== 训练模型 =====
print("\n[3] 开始训练...")
start_time = time.time()
model.fit(
X_train, y_train,
X_val=X_test, y_val=y_test,
epochs=NUM_EPOCHS,
batch_size=BATCH_SIZE,
verbose=True
)
train_time = time.time() - start_time
# ===== 最终评估 =====
print("\n" + "=" * 60)
print("训练完成!")
print("=" * 60)
train_acc = model.accuracy(X_train, y_train)
test_acc = model.accuracy(X_test, y_test)
print(f"\n最终结果:")
print(f" 训练准确率: {train_acc:.4f} ({train_acc*100:.2f}%)")
print(f" 测试准确率: {test_acc:.4f} ({test_acc*100:.2f}%)")
print(f" 训练时间: {train_time:.2f}")
# ===== 保存模型 =====
timestamp = datetime.now().strftime("%m%d_%H%M%S")
model_path = f"mnist_mlp_{timestamp}"
model.save(model_path)
# ===== 预测示例 =====
print("\n[4] 预测示例:")
indices = np.random.choice(len(X_test), 5, replace=False)
for i, idx in enumerate(indices):
img = X_test[idx]
true_label = np.argmax(y_test[idx])
pred_label = model.predict(img.reshape(1, -1))[0]
prob = model.predict_proba(img.reshape(1, -1))[0]
status = '' if true_label == pred_label else ''
print(f" 样本{i+1}: 真实={true_label}, 预测={pred_label}, "
f"置信度={prob[pred_label]:.2f} {status}")
return model, train_acc, test_acc
def predict_custom_image(model, img_path):
"""
预测自己的手写数字图片
:param model: 训练好的MLP模型
:param img_path: 图片路径如test.png
"""
# 1. 读取图片,转为灰度图
img = Image.open(img_path).convert('L')
# 2. 缩放到28×28和MNIST一致
img = img.resize((28, 28), Image.Resampling.LANCZOS)
# 3. 转为numpy数组归一化0~1
img_array = np.array(img) / 255.0
# 4. 反转颜色:白底黑字 → 黑底白字和MNIST数据集一致
img_array = 1 - img_array
# 5. 展平为784维
x = img_array.reshape(1, -1)
# 6. 预测
pred = model.predict(x)[0]
prob = model.predict_proba(x)[0]
print(f"\n===== 自定义图片预测结果 =====")
print(f"预测数字:{pred}")
print(f"置信度:{prob[pred]:.4f}")
return pred
def run_comparison():
"""
运行对比实验
"""
print("\n" + "=" * 60)
print("超参数对比实验")
print("=" * 60)
# 加载数据
try:
X_train, y_train, X_test, y_test = load_data()
except Exception as e:
print(f"加载数据失败: {e}")
return
# 实验配置
experiments = [
{"hidden_size": 32, "lr": 0.1, "name": "小模型(32神经元)"},
{"hidden_size": 128, "lr": 0.1, "name": "标准(128神经元)"},
{"hidden_size": 256, "lr": 0.1, "name": "大模型(256神经元)"},
{"hidden_size": 128, "lr": 0.01, "name": "小学习率(0.01)"},
{"hidden_size": 128, "lr": 0.5, "name": "大学习率(0.5)"},
]
results = []
for exp in experiments:
print(f"\n实验: {exp['name']}")
print("-" * 40)
model = MLP(
input_size=INPUT_SIZE,
hidden_size=exp['hidden_size'],
num_classes=NUM_CLASSES,
learning_rate=exp['lr'],
seed=SEED
)
start_time = time.time()
model.fit(X_train, y_train, epochs=30, batch_size=BATCH_SIZE, verbose=False)
train_time = time.time() - start_time
train_acc = model.accuracy(X_train, y_train)
test_acc = model.accuracy(X_test, y_test)
results.append({
'name': exp['name'],
'hidden_size': exp['hidden_size'],
'lr': exp['lr'],
'train_acc': train_acc,
'test_acc': test_acc,
'train_time': train_time
})
print(f" 训练准确率: {train_acc:.4f} | 测试准确率: {test_acc:.4f} | 时间: {train_time:.1f}s")
# 汇总
print("\n" + "=" * 60)
print("实验结果汇总")
print("=" * 60)
print(f"\n{'配置':<25} {'训练准确率':<12} {'测试准确率':<12} {'时间':<8}")
print("-" * 60)
for r in results:
print(f"{r['name']:<25} {r['train_acc']:<12.4f} {r['test_acc']:<12.4f} {r['train_time']:<8.1f}s")
best = max(results, key=lambda x: x['test_acc'])
print(f"\n最佳配置: {best['name']}, 测试准确率: {best['test_acc']:.4f}")
def main():
"""主函数"""
if RUN_COMPARISON:
run_comparison()
else:
model, train_acc, test_acc = train_and_evaluate()
# ========== 新增:训练完成后,输入图片路径识别 ==========
while True:
img_path = input("\n请输入手写数字图片路径(输入q退出)")
if img_path.lower() == 'q':
break
try:
predict_custom_image(model, img_path)
except Exception as e:
print(f"图片读取失败:{e}")
print("\n" + "=" * 60)
print("程序结束!")
print("=" * 60)