Files
final-practice/visual_plot.py
2026-06-09 11:18:50 +08:00

42 lines
1.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# Use a CJK-capable sans-serif font so the Chinese titles/labels below render
# as text instead of empty boxes. matplotlib walks this list in order and uses
# the first family that is installed; SimHei stays first so machines that
# already have it render exactly as before, and the remaining entries are
# fallbacks for systems where SimHei is not installed.
plt.rcParams["font.sans-serif"] = [
    "SimHei",
    "Microsoft YaHei",
    "PingFang SC",
    "Noto Sans CJK SC",
    "WenQuanYi Micro Hei",
    "Arial Unicode MS",
]
# Draw negative tick labels with the ASCII hyphen-minus: the Unicode minus
# sign (U+2212) is missing from many CJK fonts.
plt.rcParams["axes.unicode_minus"] = False
# ---------------------- Figure 1: training loss curve ----------------------
# Reads loss.csv (the "epoch" and "loss" columns are used below) and writes
# the curve to images/loss_curve.png.
loss_df = pd.read_csv("loss.csv")

# Make sure the output folder exists before saving. makedirs(exist_ok=True)
# does nothing when the folder is already there, so no separate existence
# check (and no check-then-create race) is needed.
os.makedirs("images", exist_ok=True)

plt.figure(figsize=(10, 4))
plt.plot(
    loss_df["epoch"],
    loss_df["loss"],
    color="#e74c3c",
    linewidth=2,
    label="训练Loss",
)
plt.xlabel("Epoch 训练轮次")
plt.ylabel("Loss 损失值")
plt.title("MLP训练Loss变化曲线")
plt.legend()
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig("images/loss_curve.png", dpi=300)
# Close the figure so the next plot starts from a clean canvas.
plt.close()
# ---------------------- Figure 2: per-category sample-count bar chart ----------------------
# Reads my_labels.csv (the "label" column is used below) and writes the bar
# chart to images/category_bar.png.
label_df = pd.read_csv("my_labels.csv")
cate_names = ["剧情","喜剧","科幻","悬疑","动作","爱情","动画","犯罪","奇幻","记录"]

# Number of samples per label, ordered by label value.
cate_count = label_df["label"].value_counts().sort_index()
if len(cate_count) != len(cate_names):
    # value_counts() omits categories that have no samples, which would give
    # plt.bar() fewer heights than x positions and make it raise. Fill the
    # gaps with 0 so every category still gets a bar.
    # NOTE(review): assumes label ids are the integers 0..9, in the same
    # order as cate_names - confirm against the labelling code.
    cate_count = cate_count.reindex(range(len(cate_names)), fill_value=0)

plt.figure(figsize=(10, 4))
# String x values are categorical, so the bars sit at positions 0..9; the
# xticks() call below relabels those positions with the category names.
plt.bar(
    [str(i) for i in range(len(cate_names))],
    cate_count.values,
    color="#3498db",
)
plt.xlabel("类别编号")
plt.ylabel("样本数量")
plt.title("10个电影类别样本分布柱状图")
plt.xticks(range(len(cate_names)), cate_names, rotation=30)
plt.grid(axis="y", alpha=0.3)
plt.tight_layout()
plt.savefig("images/category_bar.png", dpi=300)
plt.close()
print("可视化绘图完成,图片保存在 images/ 文件夹")