上传文件至 /
This commit is contained in:
42
visual_plot.py
Normal file
42
visual_plot.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
plt.rcParams["font.sans-serif"] = ["SimHei"]
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
|
||||
# ---------------------- 图1:Loss曲线(训练loss) ----------------------
|
||||
loss_df = pd.read_csv("loss.csv")
|
||||
plt.figure(figsize=(10,4))
|
||||
plt.plot(loss_df["epoch"], loss_df["loss"], color="#e74c3c", linewidth=2, label="训练Loss")
|
||||
plt.xlabel("Epoch 训练轮次")
|
||||
plt.ylabel("Loss 损失值")
|
||||
plt.title("MLP训练Loss变化曲线")
|
||||
plt.legend()
|
||||
plt.grid(alpha=0.3)
|
||||
plt.tight_layout()
|
||||
|
||||
# 创建images文件夹存放图片
|
||||
import os
|
||||
if not os.path.exists("images"):
|
||||
os.mkdir("images")
|
||||
plt.savefig("images/loss_curve.png", dpi=300)
|
||||
plt.close()
|
||||
|
||||
# ---------------------- 图2:10类别预测分布柱状图 ----------------------
|
||||
label_df = pd.read_csv("my_labels.csv")
|
||||
cate_count = label_df["label"].value_counts().sort_index()
|
||||
cate_names = ["剧情","喜剧","科幻","悬疑","动作","爱情","动画","犯罪","奇幻","记录"]
|
||||
|
||||
plt.figure(figsize=(10,4))
|
||||
bars = plt.bar([str(i) for i in range(10)], cate_count.values, color="#3498db")
|
||||
plt.xlabel("类别编号")
|
||||
plt.ylabel("样本数量")
|
||||
plt.title("10个电影类别样本分布柱状图")
|
||||
plt.xticks(range(10), cate_names, rotation=30)
|
||||
plt.grid(axis="y", alpha=0.3)
|
||||
plt.tight_layout()
|
||||
plt.savefig("images/category_bar.png", dpi=300)
|
||||
plt.close()
|
||||
|
||||
print("可视化绘图完成,图片保存在 images/ 文件夹")
|
||||
Reference in New Issue
Block a user