# -*- coding: utf-8 -*-
"""
Test script - recognise handwritten digit images with a trained model.

Usage:
    python test_image.py path/to/image.png
    python test_image.py path/to/image.jpg
    python test_image.py path/to/folder/   # recognise every image in the folder

Dependencies:
    pip install numpy pillow

Image requirements:
    - Recommended size: 28x28 or larger (the program rescales automatically)
    - Preferably a white background with a black digit
    - Clear, unobstructed handwriting
"""

import sys
import os
import numpy as np
from PIL import Image

# Try to import the model; it lives next to this script, so the import only
# succeeds when the script is run from inside the project directory.
try:
    from model_numpy import MLP
except ImportError:
    print("错误:请在 digit_mlp_class 目录下运行此脚本")
    print(" cd digit_mlp_class")
    print(" python test_image.py your_image.png")
    sys.exit(1)


# Parameter arrays that together make up one saved model.
MODEL_PARTS = ('W1', 'b1', 'W2', 'b2')

# Lower-case file extensions treated as images when scanning a folder.
IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')


def find_latest_model():
    """Return the path prefix of the newest complete model in the cwd.

    Model files are expected to be named
    ``mnist_mlp_YYMMDD_HHMMSS_<part>.npy`` where ``<part>`` is one of
    ``W1``, ``b1``, ``W2``, ``b2``.  Files are grouped by everything before
    the final ``_<part>`` suffix, and a group only counts when all four parts
    are present.

    Returns:
        The prefix (e.g. ``mnist_mlp_240101_120000``) of the complete group
        that sorts last, or ``None`` when no complete model exists.
    """
    parts_by_prefix = {}
    for name in os.listdir('.'):
        if not (name.startswith('mnist_mlp_') and name.endswith('.npy')):
            continue
        # Strip ".npy", then split off the trailing "_W1" / "_b1" / ... part.
        prefix, _, part = name[:-len('.npy')].rpartition('_')
        parts_by_prefix.setdefault(prefix, set()).add(part)

    # Skip incomplete groups (e.g. an interrupted save) instead of giving up
    # as soon as the newest prefix turns out to be missing a file.
    complete = [
        prefix
        for prefix, parts in parts_by_prefix.items()
        if parts.issuperset(MODEL_PARTS)
    ]
    # The timestamp is zero-padded, so lexicographic order is chronological.
    return max(complete) if complete else None


def load_model():
    """Load the most recent trained model from the current directory.

    Returns:
        An ``MLP`` whose ``W1``, ``b1``, ``W2`` and ``b2`` attributes have
        been replaced with the saved arrays, or ``None`` (after printing a
        hint) when no complete model is found.
    """
    model_path = find_latest_model()
    if model_path is None:
        print("错误:未找到训练好的模型!")
        print(" 请先运行: python main.py")
        print(" 或确保当前目录有 mnist_mlp_*.npy 模型文件")
        return None

    print(f"加载模型: {model_path}")

    model = MLP(input_size=784, hidden_size=128, num_classes=10,
                learning_rate=0.1)
    # Overwrite the freshly constructed parameters with the saved ones.
    for part in MODEL_PARTS:
        setattr(model, part, np.load(f'{model_path}_{part}.npy'))

    print("模型加载成功!\n")
    return model


def preprocess_image(image_path):
    """Convert an image file into a flat, normalised 28x28 vector.

    Steps:
        1. Open the image and convert it to greyscale.
        2. Resize it to 28x28.
        3. Convert to a NumPy array scaled to [0, 1].
        4. Invert light-background images so the background is dark.
        5. Flatten to a 784-element vector.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A 1-D ``float32`` array of length 784 with values in [0, 1].

    Raises:
        ValueError: If the image cannot be opened or decoded.
    """
    try:
        # The context manager closes the underlying file as soon as the
        # greyscale copy has been made.
        with Image.open(image_path) as source:
            img = source.convert('L')
    except Exception as e:
        # Any failure here is reported uniformly as ValueError, matching what
        # callers of this function already expect.
        raise ValueError(f"无法打开图片: {e}") from e

    # Remember the original size for the diagnostic output below.
    orig_width, orig_height = img.size

    img = img.resize((28, 28), Image.LANCZOS)

    img_array = np.array(img, dtype=np.float32) / 255.0

    # MNIST digits are bright strokes on a dark background (0 = background,
    # 1 = ink).  A mean above 0.5 means the image is mostly bright, i.e. a
    # dark digit on a light background, so invert it to match.
    if img_array.mean() > 0.5:
        img_array = 1.0 - img_array

    img_vector = img_array.flatten()

    print(f" 图片: {os.path.basename(image_path)}")
    print(f" 原始尺寸: {orig_width}x{orig_height}")
    print(f" 处理后尺寸: 28x28")

    return img_vector


def predict_image(model, image_path):
    """Classify a single image.

    Args:
        model: Object exposing ``predict_proba``; it is called with a
            ``(1, 784)`` array and the first row of its result is used.
        image_path: Path of the image file to classify.

    Returns:
        A tuple ``(predicted_digit, confidence, all_probs)``:
        ``predicted_digit`` is the index of the highest probability as a
        plain ``int``, ``confidence`` is that probability as a plain
        ``float``, and ``all_probs`` is the full probability row.

    Raises:
        ValueError: If the image cannot be opened (see ``preprocess_image``).
    """
    img_vector = preprocess_image(image_path)

    probs = model.predict_proba(img_vector.reshape(1, -1))[0]
    # Convert NumPy scalars to built-in types so callers can format or
    # serialise them without depending on NumPy.
    predicted_digit = int(np.argmax(probs))
    confidence = float(probs[predicted_digit])

    return predicted_digit, confidence, probs


def print_results(digit, confidence, probs):
    """Print the prediction and a bar chart of the per-class probabilities.

    Args:
        digit: The predicted class.
        confidence: Probability of the predicted class, in [0, 1].
        probs: Sequence of per-class probabilities; one row is printed per
            entry, labelled with its index.
    """
    print(f" 预测结果: {digit}")
    print(f" 置信度: {confidence:.2%}")

    print("\n 各数字概率:")
    for i, prob in enumerate(probs):
        # Each bar is 20 cells wide; filled cells are proportional to prob.
        bar_len = int(prob * 20)
        print(f" {i}:{'█' * bar_len}{'░' * (20 - bar_len)} {prob:.1%}")
    print()


def _collect_image_paths(target):
    """Return the image paths named by *target*, or None if it is missing.

    A directory yields the sorted paths of its entries whose names end in one
    of ``IMAGE_EXTENSIONS`` (case-insensitive); a regular file yields itself.
    """
    if os.path.isdir(target):
        return sorted(
            os.path.join(target, name)
            for name in os.listdir(target)
            if name.lower().endswith(IMAGE_EXTENSIONS)
        )
    if os.path.isfile(target):
        return [target]
    return None


def main():
    """Command-line entry point: classify one image or a folder of images.

    Reads the target path from ``sys.argv[1]``.  Exits with status 1 when the
    argument is missing, the path does not exist, the folder contains no
    images, or no trained model can be loaded.
    """
    print("\n" + "=" * 60)
    print("手写数字识别 - 图片测试")
    print("=" * 60 + "\n")

    # Validate the cheap things (arguments, paths) before loading the model,
    # so a missing argument always produces the usage text.
    if len(sys.argv) < 2:
        print("使用方法:")
        print(" python test_image.py path/to/image.png")
        print(" python test_image.py path/to/image.jpg")
        print(" python test_image.py path/to/folder/")
        print()
        print("示例:")
        print(" python test_image.py my_digit.png")
        print(" python test_image.py ./test_images/")
        sys.exit(1)

    target = sys.argv[1]

    image_paths = _collect_image_paths(target)
    if image_paths is None:
        print(f"错误:文件或目录不存在: {target}")
        sys.exit(1)
    if not image_paths:
        print(f"错误:在 {target} 中未找到图片文件")
        sys.exit(1)

    model = load_model()
    if model is None:
        sys.exit(1)

    print(f"找到 {len(image_paths)} 张图片,开始识别...\n")

    # Classify every image; a failure on one file must not stop the batch.
    results = []
    for path in image_paths:
        try:
            digit, confidence, probs = predict_image(model, path)
            results.append((path, digit, confidence))
            print_results(digit, confidence, probs)
        except Exception as e:
            print(f" 识别失败: {e}\n")

    # Summary table of the successfully classified images.
    print("=" * 60)
    print(f"识别完成!共 {len(results)} 张图片")
    print("=" * 60)
    print(f"\n{'文件名':<30} {'预测':<6} {'置信度':<10}")
    print("-" * 50)
    for path, digit, confidence in results:
        filename = os.path.basename(path)
        print(f"{filename:<30} {digit:<6} {confidence:.1%}")

    # If ground-truth labels were available (e.g. encoded in folder names),
    # accuracy could be computed here.
    print()


if __name__ == '__main__':
    main()