Files
task-3-3-2-MLP/digit_mlp_class/test_image.py
2026-05-21 15:08:03 +08:00

231 lines
6.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
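Test script - recognise students' handwritten digit images with the trained model.

Usage:
    python test_image.py path/to/image.png
    python test_image.py path/to/image.jpg
    python test_image.py path/to/folder/   # recognise every image in the folder

Dependencies:
    pip install numpy pillow

Image requirements:
    - Recommended size is 28x28 or larger (the program rescales automatically)
    - Preferably a white background with a black digit
    - The handwritten digit should be clear and unobstructed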
"""
import sys
import os
import numpy as np
from PIL import Image
# Try to import the project-local model class. If the import fails, tell the
# user to run the script from the digit_mlp_class directory and exit with
# status 1 instead of showing a raw traceback.
try:
    from model_numpy import MLP
except ImportError:
    print("错误:请在 digit_mlp_class 目录下运行此脚本")
    print(" cd digit_mlp_class")
    print(" python test_image.py your_image.png")
    sys.exit(1)
def find_latest_model():
    """Return the file-name prefix of the newest complete model, or None.

    The current working directory is scanned for files named
    ``mnist_mlp_<timestamp>_<part>.npy``.  A model counts as complete when all
    four parts (``W1``, ``b1``, ``W2``, ``b2``) exist for the same prefix.

    Prefixes are compared as plain strings, newest (largest) first.  This
    matches chronological order as long as the timestamp has the fixed-width
    ``YYMMDD_HHMMSS`` layout.  If the newest prefix is incomplete (for example
    an interrupted save or an unrelated file), older prefixes are tried
    instead of giving up.

    Returns:
        The prefix string such as ``mnist_mlp_240101_120000`` (without the
        ``_<part>.npy`` suffix), or ``None`` when no complete model exists.
    """
    required_parts = ('W1', 'b1', 'W2', 'b2')

    # Strip the trailing "_<part>.npy" so the files of one model share a key.
    prefixes = {
        name.rsplit('_', 1)[0]
        for name in os.listdir('.')
        if name.startswith('mnist_mlp_') and name.endswith('.npy')
    }

    # Walk from newest to oldest and stop at the first complete set.
    for prefix in sorted(prefixes, reverse=True):
        if all(os.path.exists(f'{prefix}_{part}.npy') for part in required_parts):
            return prefix
    return None
def load_model():
    """Create an MLP and fill it with the most recently saved weights.

    The weight files are located through ``find_latest_model()``.  When no
    model is found, a hint is printed and ``None`` is returned.  Otherwise the
    four arrays ``W1``, ``b1``, ``W2`` and ``b2`` are read with ``np.load``
    and assigned onto a freshly constructed ``MLP``.

    Returns:
        The populated ``MLP`` instance, or ``None`` if no model was found.
    """
    prefix = find_latest_model()
    if prefix is None:
        print("错误:未找到训练好的模型!")
        print(" 请先运行: python main.py")
        print(" 或确保当前目录有 mnist_mlp_*.npy 模型文件")
        return None

    print(f"加载模型: {prefix}")
    net = MLP(input_size=784, hidden_size=128, num_classes=10, learning_rate=0.1)
    # Overwrite the freshly initialised parameters with the stored ones.
    for attr in ('W1', 'b1', 'W2', 'b2'):
        setattr(net, attr, np.load(f'{prefix}_{attr}.npy'))
    print("模型加载成功!\n")
    return net
def preprocess_image(image_path):
    """Convert an image file into a flat, normalised 784-element vector.

    Steps:
        1. Open the image and convert it to 8-bit greyscale.
        2. Resize it to 28x28 with Lanczos resampling.
        3. Scale pixel values to [0, 1] as float32.
        4. Invert the image when it is mostly bright, so the background ends
           up dark and the strokes bright.
        5. Flatten to a vector of length 784.

    A short summary (file name, original size, processed size) is printed.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A float32 NumPy array of shape ``(784,)``.

    Raises:
        ValueError: If the image cannot be opened or converted.  The original
            exception is attached as ``__cause__``.
    """
    try:
        # The context manager releases the file handle.  convert() returns a
        # separate in-memory image, so it stays usable after the source closes.
        with Image.open(image_path) as source:
            orig_width, orig_height = source.size
            img = source.convert('L')
    except Exception as e:
        # Deliberately broad: any decoder failure is reported as one
        # ValueError type, which is what callers of this function expect.
        raise ValueError(f"无法打开图片: {e}") from e

    img = img.resize((28, 28), Image.LANCZOS)

    img_array = np.array(img, dtype=np.float32) / 255.0

    # A mean above 0.5 means the picture is mostly bright, i.e. dark ink on a
    # light background.  Invert it so the background is near 0 and the digit
    # near 1.  NOTE(review): presumably this matches the polarity of the
    # training data (MNIST-style); confirm against the training pipeline.
    if img_array.mean() > 0.5:
        img_array = 1.0 - img_array

    img_vector = img_array.flatten()

    print(f" 图片: {os.path.basename(image_path)}")
    print(f" 原始尺寸: {orig_width}x{orig_height}")
    print(f" 处理后尺寸: 28x28")
    return img_vector
def predict_image(model, image_path):
    """Classify a single image file.

    The image is run through ``preprocess_image`` and passed as a one-row
    batch to ``model.predict_proba``.

    Args:
        model: Object exposing ``predict_proba(batch)``.
        image_path: Path of the image to classify.

    Returns:
        A tuple ``(predicted_digit, confidence, all_probs)``: the index of the
        highest probability, that probability, and the full probability row.
    """
    features = preprocess_image(image_path)
    # The model is given a batch of one row; keep only that row of the result.
    batch = features.reshape(1, -1)
    probs = model.predict_proba(batch)[0]
    best = np.argmax(probs)
    return best, probs[best], probs
def print_results(digit, confidence, probs):
    """Print the prediction, its confidence and a bar chart of all classes.

    Each class gets one row: its index, a 20-cell bar whose filled part is
    proportional to the probability, and the probability as a percentage.

    Args:
        digit: Predicted class index.
        confidence: Probability of the predicted class (formatted as a percentage).
        probs: Sequence of per-class probabilities; one row is printed per entry.
    """
    bar_width = 20
    filled_cell = "\u2588"  # full block
    empty_cell = "\u2591"   # light shade

    print(f" 预测结果: {digit}")
    print(f" 置信度: {confidence:.2%}")

    print(f"\n 各数字概率:")
    for label, prob in enumerate(probs):
        # Clamp so rounding noise outside [0, 1] cannot give a negative or
        # over-long bar.
        filled = min(bar_width, max(0, int(prob * bar_width)))
        bar = filled_cell * filled + empty_cell * (bar_width - filled)
        print(f" {label}:{bar} {prob:.1%}")
    print()
def _collect_image_paths(target):
    """Return the image files denoted by *target*, or None if it does not exist.

    A directory yields the sorted paths of its regular files whose names end
    in a known image extension (case-insensitive).  A file yields a
    one-element list.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')
    if os.path.isdir(target):
        candidates = (
            os.path.join(target, name)
            for name in os.listdir(target)
            if name.lower().endswith(image_suffixes)
        )
        # Skip sub-directories that merely have an image-like name.
        return sorted(path for path in candidates if os.path.isfile(path))
    if os.path.isfile(target):
        return [target]
    return None


def main():
    """Command-line entry point: classify one image or every image in a folder.

    Reads the target path from ``sys.argv[1]``.  The arguments and the target
    are validated before the model is loaded, so usage errors are reported
    even when no trained model is present.  Each image is classified and its
    result printed, followed by a summary table of the successful ones.

    Exits with status 1 when the argument is missing, the target does not
    exist, no image is found, or no model can be loaded.
    """
    print("\n" + "=" * 60)
    print("手写数字识别 - 图片测试")
    print("=" * 60 + "\n")

    if len(sys.argv) < 2:
        print("使用方法:")
        print(" python test_image.py path/to/image.png")
        print(" python test_image.py path/to/image.jpg")
        print(" python test_image.py path/to/folder/")
        print()
        print("示例:")
        print(" python test_image.py my_digit.png")
        print(" python test_image.py ./test_images/")
        sys.exit(1)

    target = sys.argv[1]
    image_paths = _collect_image_paths(target)
    if image_paths is None:
        print(f"错误:文件或目录不存在: {target}")
        sys.exit(1)
    if not image_paths:
        print(f"错误:在 {target} 中未找到图片文件")
        sys.exit(1)

    model = load_model()
    if model is None:
        sys.exit(1)

    print(f"找到 {len(image_paths)} 张图片,开始识别...\n")

    results = []
    for path in image_paths:
        try:
            digit, confidence, probs = predict_image(model, path)
            results.append((path, digit, confidence))
            print_results(digit, confidence, probs)
        except Exception as e:
            # Best effort at the top level: one unreadable image must not
            # abort the rest of the batch.
            print(f" 识别失败: {e}\n")

    # Summary of the images that were classified successfully.
    print("=" * 60)
    print(f"识别完成!共 {len(results)} 张图片")
    print("=" * 60)
    print(f"\n{'文件名':<30} {'预测':<6} {'置信度':<10}")
    print("-" * 50)
    for path, digit, confidence in results:
        filename = os.path.basename(path)
        print(f"{filename:<30} {digit:<6} {confidence:.1%}")
    # Accuracy could be computed here if ground-truth labels were available
    # (e.g. encoded in folder names); that is not implemented.
    print()
# Run the command-line entry point only when the file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()