231 lines
6.4 KiB
Python
231 lines
6.4 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
测试脚本 - 用训练好的模型识别学生手写数字图片
|
||
|
||
使用方法:
|
||
python test_image.py path/to/image.png
|
||
python test_image.py path/to/image.jpg
|
||
python test_image.py path/to/folder/ # 识别文件夹内所有图片
|
||
|
||
依赖:
|
||
pip install numpy pillow
|
||
|
||
图片要求:
|
||
- 建议尺寸:28x28 或更大(程序会自动缩放)
|
||
- 背景最好为白色,数字为黑色
|
||
- 手写数字清晰、无遮挡
|
||
"""
|
||
|
||
import sys
|
||
import os
|
||
import numpy as np
|
||
from PIL import Image
|
||
|
||
# 尝试导入模型
|
||
try:
|
||
from model_numpy import MLP
|
||
except ImportError:
|
||
print("错误:请在 digit_mlp_class 目录下运行此脚本")
|
||
print(" cd digit_mlp_class")
|
||
print(" python test_image.py your_image.png")
|
||
sys.exit(1)
|
||
|
||
|
||
def find_latest_model():
|
||
"""查找最新的模型文件"""
|
||
model_files = [f for f in os.listdir('.') if f.startswith('mnist_mlp_') and f.endswith('.npy')]
|
||
if not model_files:
|
||
return None
|
||
|
||
# 按时间戳分组
|
||
timestamps = set()
|
||
for f in model_files:
|
||
# 文件格式: mnist_mlp_YYMMDD_HHMMSS_suffix.npy
|
||
# 去掉后缀 _W1.npy, _b1.npy 等
|
||
base = f.rsplit('_', 1)[0]
|
||
timestamps.add(base)
|
||
|
||
# 选择最新的
|
||
latest = sorted(timestamps)[-1]
|
||
|
||
# 检查完整模型
|
||
required = ['W1', 'b1', 'W2', 'b2']
|
||
for r in required:
|
||
if not os.path.exists(f'{latest}_{r}.npy'):
|
||
return None
|
||
|
||
return latest
|
||
|
||
|
||
def load_model():
|
||
"""加载训练好的模型"""
|
||
# 查找最新模型
|
||
model_path = find_latest_model()
|
||
|
||
if model_path is None:
|
||
print("错误:未找到训练好的模型!")
|
||
print(" 请先运行: python main.py")
|
||
print(" 或确保当前目录有 mnist_mlp_*.npy 模型文件")
|
||
return None
|
||
|
||
print(f"加载模型: {model_path}")
|
||
|
||
model = MLP(input_size=784, hidden_size=128, num_classes=10, learning_rate=0.1)
|
||
|
||
model.W1 = np.load(f'{model_path}_W1.npy')
|
||
model.b1 = np.load(f'{model_path}_b1.npy')
|
||
model.W2 = np.load(f'{model_path}_W2.npy')
|
||
model.b2 = np.load(f'{model_path}_b2.npy')
|
||
|
||
print("模型加载成功!\n")
|
||
return model
|
||
|
||
|
||
def preprocess_image(image_path):
|
||
"""
|
||
图片预处理:将任意图片转为 28x28 归一化向量
|
||
|
||
处理流程:
|
||
1. 打开图片并转为灰度
|
||
2. 调整大小为 28x28
|
||
3. 转为NumPy数组并归一化到 [0, 1]
|
||
4. 展平为 784 维向量
|
||
"""
|
||
try:
|
||
img = Image.open(image_path).convert('L') # 转灰度
|
||
except Exception as e:
|
||
raise ValueError(f"无法打开图片: {e}")
|
||
|
||
# 保存原始尺寸用于调试
|
||
orig_width, orig_height = img.size
|
||
|
||
# 缩放到 28x28
|
||
img = img.resize((28, 28), Image.LANCZOS)
|
||
|
||
# 转为NumPy数组并归一化
|
||
img_array = np.array(img, dtype=np.float32) / 255.0
|
||
|
||
# MNIST是白底黑字(0=背景, 1=数字),如果图片是黑底白字需要反转
|
||
# 检查是否是黑底白字:背景均值 > 0.5
|
||
if img_array.mean() > 0.5:
|
||
img_array = 1.0 - img_array
|
||
|
||
# 展平为向量
|
||
img_vector = img_array.flatten()
|
||
|
||
print(f" 图片: {os.path.basename(image_path)}")
|
||
print(f" 原始尺寸: {orig_width}x{orig_height}")
|
||
print(f" 处理后尺寸: 28x28")
|
||
|
||
return img_vector
|
||
|
||
|
||
def predict_image(model, image_path):
|
||
"""
|
||
识别单张图片
|
||
|
||
返回:
|
||
predicted_digit: 预测的数字 (0-9)
|
||
confidence: 置信度 (0-1)
|
||
all_probs: 所有数字的概率
|
||
"""
|
||
# 预处理
|
||
img_vector = preprocess_image(image_path)
|
||
|
||
# 预测
|
||
probs = model.predict_proba(img_vector.reshape(1, -1))[0]
|
||
predicted_digit = np.argmax(probs)
|
||
confidence = probs[predicted_digit]
|
||
|
||
return predicted_digit, confidence, probs
|
||
|
||
|
||
def print_results(digit, confidence, probs):
|
||
"""打印识别结果"""
|
||
print(f" 预测结果: {digit}")
|
||
print(f" 置信度: {confidence:.2%}")
|
||
|
||
# 显示各数字概率
|
||
print(f"\n 各数字概率:")
|
||
print(f" ", end="")
|
||
for i in range(10):
|
||
bar_len = int(probs[i] * 20)
|
||
print(f" {i}:{'█' * bar_len}{'░' * (20-bar_len)} {probs[i]:.1%}")
|
||
|
||
print()
|
||
|
||
|
||
def main():
|
||
"""主函数"""
|
||
print("\n" + "=" * 60)
|
||
print("手写数字识别 - 图片测试")
|
||
print("=" * 60 + "\n")
|
||
|
||
# 加载模型
|
||
model = load_model()
|
||
if model is None:
|
||
sys.exit(1)
|
||
|
||
# 检查命令行参数
|
||
if len(sys.argv) < 2:
|
||
print("使用方法:")
|
||
print(" python test_image.py path/to/image.png")
|
||
print(" python test_image.py path/to/image.jpg")
|
||
print(" python test_image.py path/to/folder/")
|
||
print()
|
||
print("示例:")
|
||
print(" python test_image.py my_digit.png")
|
||
print(" python test_image.py ./test_images/")
|
||
sys.exit(1)
|
||
|
||
target = sys.argv[1]
|
||
|
||
# 收集所有要处理的图片
|
||
image_paths = []
|
||
|
||
if os.path.isdir(target):
|
||
# 文件夹:收集所有图片
|
||
extensions = ['.png', '.jpg', '.jpeg', '.bmp', '.gif']
|
||
for f in os.listdir(target):
|
||
if any(f.lower().endswith(ext) for ext in extensions):
|
||
image_paths.append(os.path.join(target, f))
|
||
image_paths.sort()
|
||
elif os.path.isfile(target):
|
||
image_paths = [target]
|
||
else:
|
||
print(f"错误:文件或目录不存在: {target}")
|
||
sys.exit(1)
|
||
|
||
if not image_paths:
|
||
print(f"错误:在 {target} 中未找到图片文件")
|
||
sys.exit(1)
|
||
|
||
print(f"找到 {len(image_paths)} 张图片,开始识别...\n")
|
||
|
||
# 批量识别
|
||
results = []
|
||
for path in image_paths:
|
||
try:
|
||
digit, confidence, probs = predict_image(model, path)
|
||
results.append((path, digit, confidence))
|
||
print_results(digit, confidence, probs)
|
||
except Exception as e:
|
||
print(f" 识别失败: {e}\n")
|
||
|
||
# 汇总结果
|
||
print("=" * 60)
|
||
print(f"识别完成!共 {len(results)} 张图片")
|
||
print("=" * 60)
|
||
print(f"\n{'文件名':<30} {'预测':<6} {'置信度':<10}")
|
||
print("-" * 50)
|
||
|
||
for path, digit, confidence in results:
|
||
filename = os.path.basename(path)
|
||
print(f"{filename:<30} {digit:<6} {confidence:.1%}")
|
||
|
||
# 如果有真实标签(文件夹命名中包含),可以计算准确率
|
||
print()
|
||
|
||
|
||
if __name__ == '__main__':
|
||
main() |