上传文件至 /
This commit is contained in:
231
test_image.py
Normal file
231
test_image.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
测试脚本 - 用训练好的模型识别学生手写数字图片
|
||||
|
||||
使用方法:
|
||||
python test_image.py path/to/image.png
|
||||
python test_image.py path/to/image.jpg
|
||||
python test_image.py path/to/folder/ # 识别文件夹内所有图片
|
||||
|
||||
依赖:
|
||||
pip install numpy pillow
|
||||
|
||||
图片要求:
|
||||
- 建议尺寸:28x28 或更大(程序会自动缩放)
|
||||
- 背景最好为白色,数字为黑色
|
||||
- 手写数字清晰、无遮挡
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
# 尝试导入模型
|
||||
try:
|
||||
from model_numpy import MLP
|
||||
except ImportError:
|
||||
print("错误:请在 digit_mlp_class 目录下运行此脚本")
|
||||
print(" cd digit_mlp_class")
|
||||
print(" python test_image.py your_image.png")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def find_latest_model():
|
||||
"""查找最新的模型文件"""
|
||||
model_files = [f for f in os.listdir('.') if f.startswith('mnist_mlp_') and f.endswith('.npy')]
|
||||
if not model_files:
|
||||
return None
|
||||
|
||||
# 按时间戳分组
|
||||
timestamps = set()
|
||||
for f in model_files:
|
||||
# 文件格式: mnist_mlp_YYMMDD_HHMMSS_suffix.npy
|
||||
# 去掉后缀 _W1.npy, _b1.npy 等
|
||||
base = f.rsplit('_', 1)[0]
|
||||
timestamps.add(base)
|
||||
|
||||
# 选择最新的
|
||||
latest = sorted(timestamps)[-1]
|
||||
|
||||
# 检查完整模型
|
||||
required = ['W1', 'b1', 'W2', 'b2']
|
||||
for r in required:
|
||||
if not os.path.exists(f'{latest}_{r}.npy'):
|
||||
return None
|
||||
|
||||
return latest
|
||||
|
||||
|
||||
def load_model():
|
||||
"""加载训练好的模型"""
|
||||
# 查找最新模型
|
||||
model_path = find_latest_model()
|
||||
|
||||
if model_path is None:
|
||||
print("错误:未找到训练好的模型!")
|
||||
print(" 请先运行: python main.py")
|
||||
print(" 或确保当前目录有 mnist_mlp_*.npy 模型文件")
|
||||
return None
|
||||
|
||||
print(f"加载模型: {model_path}")
|
||||
|
||||
model = MLP(input_size=784, hidden_size=128, num_classes=10, learning_rate=0.1)
|
||||
|
||||
model.W1 = np.load(f'{model_path}_W1.npy')
|
||||
model.b1 = np.load(f'{model_path}_b1.npy')
|
||||
model.W2 = np.load(f'{model_path}_W2.npy')
|
||||
model.b2 = np.load(f'{model_path}_b2.npy')
|
||||
|
||||
print("模型加载成功!\n")
|
||||
return model
|
||||
|
||||
|
||||
def preprocess_image(image_path):
|
||||
"""
|
||||
图片预处理:将任意图片转为 28x28 归一化向量
|
||||
|
||||
处理流程:
|
||||
1. 打开图片并转为灰度
|
||||
2. 调整大小为 28x28
|
||||
3. 转为NumPy数组并归一化到 [0, 1]
|
||||
4. 展平为 784 维向量
|
||||
"""
|
||||
try:
|
||||
img = Image.open(image_path).convert('L') # 转灰度
|
||||
except Exception as e:
|
||||
raise ValueError(f"无法打开图片: {e}")
|
||||
|
||||
# 保存原始尺寸用于调试
|
||||
orig_width, orig_height = img.size
|
||||
|
||||
# 缩放到 28x28
|
||||
img = img.resize((28, 28), Image.LANCZOS)
|
||||
|
||||
# 转为NumPy数组并归一化
|
||||
img_array = np.array(img, dtype=np.float32) / 255.0
|
||||
|
||||
# MNIST是白底黑字(0=背景, 1=数字),如果图片是黑底白字需要反转
|
||||
# 检查是否是黑底白字:背景均值 > 0.5
|
||||
if img_array.mean() > 0.5:
|
||||
img_array = 1.0 - img_array
|
||||
|
||||
# 展平为向量
|
||||
img_vector = img_array.flatten()
|
||||
|
||||
print(f" 图片: {os.path.basename(image_path)}")
|
||||
print(f" 原始尺寸: {orig_width}x{orig_height}")
|
||||
print(f" 处理后尺寸: 28x28")
|
||||
|
||||
return img_vector
|
||||
|
||||
|
||||
def predict_image(model, image_path):
|
||||
"""
|
||||
识别单张图片
|
||||
|
||||
返回:
|
||||
predicted_digit: 预测的数字 (0-9)
|
||||
confidence: 置信度 (0-1)
|
||||
all_probs: 所有数字的概率
|
||||
"""
|
||||
# 预处理
|
||||
img_vector = preprocess_image(image_path)
|
||||
|
||||
# 预测
|
||||
probs = model.predict_proba(img_vector.reshape(1, -1))[0]
|
||||
predicted_digit = np.argmax(probs)
|
||||
confidence = probs[predicted_digit]
|
||||
|
||||
return predicted_digit, confidence, probs
|
||||
|
||||
|
||||
def print_results(digit, confidence, probs):
|
||||
"""打印识别结果"""
|
||||
print(f" 预测结果: {digit}")
|
||||
print(f" 置信度: {confidence:.2%}")
|
||||
|
||||
# 显示各数字概率
|
||||
print(f"\n 各数字概率:")
|
||||
print(f" ", end="")
|
||||
for i in range(10):
|
||||
bar_len = int(probs[i] * 20)
|
||||
print(f" {i}:{'█' * bar_len}{'░' * (20-bar_len)} {probs[i]:.1%}")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
print("\n" + "=" * 60)
|
||||
print("手写数字识别 - 图片测试")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
# 加载模型
|
||||
model = load_model()
|
||||
if model is None:
|
||||
sys.exit(1)
|
||||
|
||||
# 检查命令行参数
|
||||
if len(sys.argv) < 2:
|
||||
print("使用方法:")
|
||||
print(" python test_image.py path/to/image.png")
|
||||
print(" python test_image.py path/to/image.jpg")
|
||||
print(" python test_image.py path/to/folder/")
|
||||
print()
|
||||
print("示例:")
|
||||
print(" python test_image.py my_digit.png")
|
||||
print(" python test_image.py ./test_images/")
|
||||
sys.exit(1)
|
||||
|
||||
target = sys.argv[1]
|
||||
|
||||
# 收集所有要处理的图片
|
||||
image_paths = []
|
||||
|
||||
if os.path.isdir(target):
|
||||
# 文件夹:收集所有图片
|
||||
extensions = ['.png', '.jpg', '.jpeg', '.bmp', '.gif']
|
||||
for f in os.listdir(target):
|
||||
if any(f.lower().endswith(ext) for ext in extensions):
|
||||
image_paths.append(os.path.join(target, f))
|
||||
image_paths.sort()
|
||||
elif os.path.isfile(target):
|
||||
image_paths = [target]
|
||||
else:
|
||||
print(f"错误:文件或目录不存在: {target}")
|
||||
sys.exit(1)
|
||||
|
||||
if not image_paths:
|
||||
print(f"错误:在 {target} 中未找到图片文件")
|
||||
sys.exit(1)
|
||||
|
||||
print(f"找到 {len(image_paths)} 张图片,开始识别...\n")
|
||||
|
||||
# 批量识别
|
||||
results = []
|
||||
for path in image_paths:
|
||||
try:
|
||||
digit, confidence, probs = predict_image(model, path)
|
||||
results.append((path, digit, confidence))
|
||||
print_results(digit, confidence, probs)
|
||||
except Exception as e:
|
||||
print(f" 识别失败: {e}\n")
|
||||
|
||||
# 汇总结果
|
||||
print("=" * 60)
|
||||
print(f"识别完成!共 {len(results)} 张图片")
|
||||
print("=" * 60)
|
||||
print(f"\n{'文件名':<30} {'预测':<6} {'置信度':<10}")
|
||||
print("-" * 50)
|
||||
|
||||
for path, digit, confidence in results:
|
||||
filename = os.path.basename(path)
|
||||
print(f"{filename:<30} {digit:<6} {confidence:.1%}")
|
||||
|
||||
# 如果有真实标签(文件夹命名中包含),可以计算准确率
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
350
visualize.py
Normal file
350
visualize.py
Normal file
@@ -0,0 +1,350 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
可视化工具 - 展示神经网络各层的输出
|
||||
|
||||
用于课堂教学,让学生直观理解:
|
||||
1. 输入图像长什么样
|
||||
2. 第一层隐藏层学到了什么特征
|
||||
3. 各层激活值的变化
|
||||
|
||||
使用方法:
|
||||
python visualize.py # 可视化测试集前5张
|
||||
python visualize.py --single # 可视化单张图片
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import os
|
||||
import sys
|
||||
import matplotlib
|
||||
matplotlib.use('Agg') # 无头模式,不显示图形
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
def visualize_input_image(img_vector, save_path='visualizations/input.png'):
|
||||
"""把784维向量还原成28x28图像并保存"""
|
||||
img = img_vector.reshape(28, 28) * 255
|
||||
img = img.astype(np.uint8)
|
||||
Image.fromarray(img).save(save_path)
|
||||
return save_path
|
||||
|
||||
|
||||
def visualize_activations(model, img_vector, save_dir='visualizations'):
|
||||
"""
|
||||
可视化网络各层的激活值
|
||||
"""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# 前向传播获取各层激活值
|
||||
model.forward(img_vector.reshape(1, -1))
|
||||
|
||||
# 1. 保存输入图像
|
||||
visualize_input_image(img_vector, os.path.join(save_dir, '01_input.png'))
|
||||
|
||||
# 2. 可视化第一层激活(隐藏层)
|
||||
hidden_activations = model.a1[0] # (128,)
|
||||
visualize_hidden_layer(hidden_activations, os.path.join(save_dir, '02_hidden.png'))
|
||||
|
||||
# 3. 可视化输出层概率
|
||||
output_probs = model.probs[0] # (10,)
|
||||
visualize_output_prob(output_probs, os.path.join(save_dir, '03_output_prob.png'))
|
||||
|
||||
# 4. 生成汇总图
|
||||
create_summary_image(img_vector, hidden_activations, output_probs, save_dir)
|
||||
|
||||
return save_dir
|
||||
|
||||
|
||||
def visualize_hidden_layer(activations, save_path):
|
||||
"""
|
||||
可视化隐藏层激活值
|
||||
把128个神经元的激活值排成8x16网格显示
|
||||
"""
|
||||
grid_cols = 16
|
||||
grid_rows = 8
|
||||
cell_size = 24
|
||||
|
||||
img_h = grid_rows * cell_size
|
||||
img_w = grid_cols * cell_size
|
||||
grid = np.ones((img_h, img_w)) * 255
|
||||
|
||||
for i, act in enumerate(activations):
|
||||
row = i // grid_cols
|
||||
col = i % grid_cols
|
||||
intensity = max(0, min(1, act * 2))
|
||||
color = int(255 * (1 - intensity * 0.7))
|
||||
grid[row*cell_size:(row+1)*cell_size-1, col*cell_size:(col+1)*cell_size-1] = color
|
||||
|
||||
Image.fromarray(grid.astype(np.uint8)).save(save_path)
|
||||
|
||||
|
||||
def visualize_output_prob(probs, save_path):
|
||||
"""可视化输出层概率分布"""
|
||||
fig, ax = plt.subplots(figsize=(8, 4))
|
||||
|
||||
digits = list(range(10))
|
||||
colors = ['#3498db' if i != np.argmax(probs) else '#e74c3c' for i in digits]
|
||||
|
||||
bars = ax.bar(digits, probs, color=colors)
|
||||
ax.set_xlabel('数字', fontsize=12)
|
||||
ax.set_ylabel('概率', fontsize=12)
|
||||
ax.set_title('输出层:各数字的预测概率', fontsize=14)
|
||||
ax.set_xticks(digits)
|
||||
ax.set_ylim(0, 1)
|
||||
|
||||
max_idx = np.argmax(probs)
|
||||
ax.annotate(f'{probs[max_idx]:.1%}',
|
||||
xy=(max_idx, probs[max_idx]),
|
||||
ha='center', va='bottom', fontsize=10, color='#e74c3c', fontweight='bold')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(save_path, dpi=100, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
|
||||
def create_summary_image(img_vector, hidden_activations, output_probs, save_dir):
|
||||
"""创建汇总图"""
|
||||
fig = plt.figure(figsize=(14, 6))
|
||||
|
||||
# 1. 输入图像
|
||||
ax1 = fig.add_subplot(2, 4, 1)
|
||||
ax1.imshow(img_vector.reshape(28, 28), cmap='gray')
|
||||
ax1.set_title('(1) Input Image\n(28x28 pixels)', fontsize=11)
|
||||
ax1.axis('off')
|
||||
|
||||
# 2. 像素值分布
|
||||
ax2 = fig.add_subplot(2, 4, 2)
|
||||
ax2.hist(img_vector, bins=30, color='#3498db', alpha=0.7, edgecolor='white')
|
||||
ax2.set_title('(2) Pixel Value Distribution\n(normalized 0~1)', fontsize=11)
|
||||
ax2.set_xlabel('像素值')
|
||||
ax2.set_ylabel('频数')
|
||||
|
||||
# 3. 隐藏层激活(热力图)
|
||||
ax3 = fig.add_subplot(2, 4, 3)
|
||||
# 128 = 8 × 16
|
||||
act_2d = hidden_activations.reshape(8, 16)
|
||||
im = ax3.imshow(act_2d, cmap='Blues', aspect='auto')
|
||||
ax3.set_title(f'(3) Hidden Layer\n(128 neurons)', fontsize=11)
|
||||
ax3.axis('off')
|
||||
plt.colorbar(im, ax=ax3, shrink=0.6)
|
||||
|
||||
# 4. 隐藏层激活(条形图)
|
||||
ax4 = fig.add_subplot(2, 4, 4)
|
||||
ax4.bar(range(len(hidden_activations)), hidden_activations, color='#3498db', alpha=0.7)
|
||||
ax4.set_title('(4) Neuron Activations', fontsize=11)
|
||||
ax4.set_xlabel('神经元编号')
|
||||
ax4.set_ylabel('激活强度')
|
||||
|
||||
# 5. 输出概率
|
||||
ax5 = fig.add_subplot(2, 4, 5)
|
||||
digits = list(range(10))
|
||||
colors = ['#3498db' if i != np.argmax(output_probs) else '#e74c3c' for i in digits]
|
||||
ax5.bar(digits, output_probs, color=colors)
|
||||
ax5.set_title('(5) Output Probabilities', fontsize=11)
|
||||
ax5.set_xlabel('数字')
|
||||
ax5.set_ylabel('概率')
|
||||
ax5.set_ylim(0, 1)
|
||||
ax5.set_xticks(digits)
|
||||
|
||||
# 6. 最大概率
|
||||
ax6 = fig.add_subplot(2, 4, 6)
|
||||
ax6.axis('off')
|
||||
predicted = np.argmax(output_probs)
|
||||
confidence = output_probs[predicted]
|
||||
result_text = f'预测: {predicted}\n置信度: {confidence:.1%}'
|
||||
ax6.text(0.5, 0.5, result_text, fontsize=24, ha='center', va='center',
|
||||
bbox=dict(boxstyle='round', facecolor='#2ecc71', alpha=0.9),
|
||||
transform=ax6.transAxes, color='white', fontweight='bold')
|
||||
ax6.set_title('(6) Recognition Result', fontsize=11)
|
||||
|
||||
# 7. 网络结构
|
||||
ax7 = fig.add_subplot(2, 4, 7)
|
||||
ax7.axis('off')
|
||||
structure_text = (
|
||||
'┌─────────────────┐\n'
|
||||
'│ 输入层 784 │\n'
|
||||
'│ (28×28展平) │\n'
|
||||
'└────────┬────────┘\n'
|
||||
' │\n'
|
||||
' 线性变换+ReLU\n'
|
||||
' │\n'
|
||||
'┌────────┴────────┐\n'
|
||||
'│ 隐藏层 128 │\n'
|
||||
'│ (特征提取) │\n'
|
||||
'└────────┬────────┘\n'
|
||||
' │\n'
|
||||
' 线性变换+Softmax\n'
|
||||
' │\n'
|
||||
'┌────────┴────────┐\n'
|
||||
'│ 输出层 10 │\n'
|
||||
'│ (数字0~9概率) │\n'
|
||||
'└─────────────────┘'
|
||||
)
|
||||
ax7.text(0.1, 0.95, structure_text, fontsize=9, va='top',
|
||||
family='monospace', transform=ax7.transAxes,
|
||||
bbox=dict(boxstyle='round', facecolor='#f8f9fa', alpha=0.9))
|
||||
ax7.set_title('(7) Network Structure', fontsize=11)
|
||||
|
||||
# 8. 参数量说明
|
||||
ax8 = fig.add_subplot(2, 4, 8)
|
||||
ax8.axis('off')
|
||||
params_text = (
|
||||
'MLP 参数量计算:\n\n'
|
||||
'W1: 784 × 128 = 100,352\n'
|
||||
'b1: 128\n\n'
|
||||
'W2: 128 × 10 = 1,280\n'
|
||||
'b2: 10\n\n'
|
||||
'─────────────────\n'
|
||||
'总计: 101,770 参数\n\n'
|
||||
'全部用 NumPy 实现\n'
|
||||
'无需任何深度学习框架!'
|
||||
)
|
||||
ax8.text(0.1, 0.95, params_text, fontsize=10, va='top',
|
||||
family='monospace', transform=ax8.transAxes)
|
||||
ax8.set_title('(8) Parameters', fontsize=11)
|
||||
|
||||
plt.suptitle('MLP Feature Maps Visualization - Handwritten Digits', fontsize=16, fontweight='bold', y=1.02)
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(save_dir, 'summary.png'), dpi=120, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
|
||||
def load_model_or_train():
|
||||
"""加载已训练的模型,如果没有则训练一个"""
|
||||
import glob
|
||||
model_files = glob.glob('mnist_mlp_*.npy')
|
||||
|
||||
if model_files:
|
||||
timestamps = sorted(set(
|
||||
f.replace('mnist_mlp_', '').replace('_W1.npy', '')
|
||||
for f in model_files if '_W1.npy' in f
|
||||
))
|
||||
if timestamps:
|
||||
model_path = 'mnist_mlp_' + timestamps[-1]
|
||||
print(f"加载模型: {model_path}")
|
||||
|
||||
from model_numpy import MLP
|
||||
model = MLP(input_size=784, hidden_size=128, num_classes=10, learning_rate=0.1)
|
||||
model.W1 = np.load(f'{model_path}_W1.npy')
|
||||
model.b1 = np.load(f'{model_path}_b1.npy')
|
||||
model.W2 = np.load(f'{model_path}_W2.npy')
|
||||
model.b2 = np.load(f'{model_path}_b2.npy')
|
||||
return model
|
||||
|
||||
# 没有模型,用sklearn快速训练一个用于演示
|
||||
print("未找到已训练模型,使用sklearn数据快速训练演示模型...")
|
||||
from sklearn.datasets import fetch_openml
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
|
||||
X = mnist.data[:10000].astype(np.float32) / 255.0
|
||||
y = mnist.target[:10000].astype(int)
|
||||
|
||||
# 用sklearn的MLP代替
|
||||
from sklearn.neural_network import MLPClassifier
|
||||
model = MLPClassifier(
|
||||
hidden_layer_sizes=(128,),
|
||||
max_iter=20,
|
||||
alpha=1e-4,
|
||||
solver='sgd',
|
||||
learning_rate_init=0.1,
|
||||
random_state=42,
|
||||
verbose=False
|
||||
)
|
||||
model.fit(X, y)
|
||||
print("sklearn模型训练完成(仅用于可视化演示)")
|
||||
return model
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
os.makedirs('visualizations', exist_ok=True)
|
||||
|
||||
# 加载数据
|
||||
from dataset import load_data
|
||||
print("加载MNIST数据集...")
|
||||
X_train, y_train, X_test, y_test = load_data()
|
||||
|
||||
# 加载模型
|
||||
model = load_model_or_train()
|
||||
|
||||
# 获取真实标签
|
||||
if len(y_test.shape) > 1:
|
||||
y_test_labels = np.argmax(y_test, axis=1)
|
||||
else:
|
||||
y_test_labels = y_test
|
||||
|
||||
# 可视化测试集前5张
|
||||
print("\n可视化测试集前5张图片...")
|
||||
for i in range(5):
|
||||
img = X_test[i]
|
||||
true_label = y_test_labels[i]
|
||||
|
||||
sub_dir = f'visualizations/sample_{i}_true{true_label}'
|
||||
os.makedirs(sub_dir, exist_ok=True)
|
||||
|
||||
if hasattr(model, 'predict_proba'):
|
||||
probs = model.predict_proba(img.reshape(1, -1))[0]
|
||||
predicted = np.argmax(probs)
|
||||
else:
|
||||
predicted = model.predict(img.reshape(1, -1))[0]
|
||||
|
||||
print(f" 样本{i}: 真实={true_label}, 预测={predicted}")
|
||||
visualize_activations(model, img, sub_dir)
|
||||
|
||||
# 创建对比汇总
|
||||
create_comparison_summary(X_test[:5], y_test_labels[:5], model, 'visualizations')
|
||||
|
||||
print("\n✅ 可视化完成!")
|
||||
print(" 查看 visualizations/ 目录下的图片和汇总图")
|
||||
|
||||
|
||||
def create_comparison_summary(X_samples, y_true, model, save_dir):
|
||||
"""创建多个样本的对比汇总图"""
|
||||
n_samples = len(X_samples)
|
||||
|
||||
fig = plt.figure(figsize=(4 * n_samples, 8))
|
||||
|
||||
for i in range(n_samples):
|
||||
img = X_samples[i]
|
||||
true_label = y_true[i]
|
||||
|
||||
if hasattr(model, 'predict_proba'):
|
||||
probs = model.predict_proba(img.reshape(1, -1))[0]
|
||||
predicted = np.argmax(probs)
|
||||
else:
|
||||
predicted = model.predict(img.reshape(1, -1))[0]
|
||||
|
||||
# 输入图像
|
||||
ax = fig.add_subplot(3, n_samples, i + 1)
|
||||
ax.imshow(img.reshape(28, 28), cmap='gray')
|
||||
ax.set_title(f'真实: {true_label}', fontsize=12)
|
||||
ax.axis('off')
|
||||
|
||||
# 隐藏层激活
|
||||
ax = fig.add_subplot(3, n_samples, i + 1 + n_samples)
|
||||
model.forward(img.reshape(1, -1))
|
||||
# 128 = 8 × 16
|
||||
hidden = model.a1[0].reshape(8, 16)
|
||||
ax.imshow(hidden, cmap='Blues', aspect='auto')
|
||||
ax.set_title(f'隐藏层激活', fontsize=10)
|
||||
ax.axis('off')
|
||||
|
||||
# 输出概率
|
||||
ax = fig.add_subplot(3, n_samples, i + 1 + 2*n_samples)
|
||||
digits = list(range(10))
|
||||
colors = ['#e74c3c' if d == predicted else '#3498db' for d in digits]
|
||||
ax.bar(digits, probs, color=colors)
|
||||
ax.set_title(f'预测: {predicted}', fontsize=12)
|
||||
ax.set_ylim(0, 1)
|
||||
ax.set_xticks(digits)
|
||||
ax.tick_params(labelsize=8)
|
||||
|
||||
plt.suptitle('多样本特征图对比', fontsize=16, fontweight='bold')
|
||||
plt.tight_layout()
|
||||
plt.savefig(os.path.join(save_dir, 'comparison.png'), dpi=120, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
Reference in New Issue
Block a user