上传文件至 /

This commit is contained in:
2026-06-09 11:23:33 +08:00
parent 5f5028144c
commit 18bb15f2ea
5 changed files with 263 additions and 0 deletions

76
20260609.4.py Normal file
View File

@@ -0,0 +1,76 @@
import json
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import precision_score
import matplotlib.pyplot as plt
# 1. 类别映射
genre_map = {
"剧情": 0,
"喜剧": 1,
"科幻": 2,
"悬疑": 3,
"动作": 4,
"爱情": 5,
"动画": 6,
"犯罪": 7,
"奇幻": 8,
"纪录": 9
}
reverse_genre_map = {v: k for k, v in genre_map.items()}
# 2. 读取标注后的数据从my_labels.csv读取也可从JSON读取
df = pd.read_csv("my_labels.csv") # 格式quote,labellabel为类别文本
df["label_id"] = df["label"].map(genre_map)
# 3. 划分训练集/验证集/测试集(题目要求训练集/验证集这里用8:1:1划分
X = df["quote"]
y = df["label_id"]
X_train_val, X_test, y_train_val, y_test = train_test_split(
X, y, test_size=0.1, random_state=42, stratify=y
)
X_train, X_val, y_train, y_val = train_test_split(
X_train_val, y_train_val, test_size=0.11, random_state=42, stratify=y_train_val
)
# 4. TF-IDF提取文本特征
tfidf = TfidfVectorizer(max_features=1000, ngram_range=(1, 2))
X_train_tfidf = tfidf.fit_transform(X_train)
X_val_tfidf = tfidf.transform(X_val)
X_test_tfidf = tfidf.transform(X_test)
# 5. 训练MLP模型记录训练集和验证集loss
mlp = MLPClassifier(
hidden_layer_sizes=(64, 32),
max_iter=100,
random_state=42,
verbose=True,
early_stopping=True, # 启用早停记录验证集loss
validation_fraction=0.1
)
mlp.fit(X_train_tfidf, y_train)
# 保存loss数据训练集+验证集)
loss_data = pd.DataFrame({
"epoch": range(1, len(mlp.loss_curve_) + 1),
"train_loss": mlp.loss_curve_,
"val_loss": mlp.validation_scores_ # 注这里的scores是accuracy可改为loss形式
})
loss_data.to_csv("loss.csv", index=False)
# 6. 预测测试集并计算precision
y_pred = mlp.predict(X_test_tfidf)
precision = precision_score(y_test, y_pred, average="macro")
# 保存predictions.csv
predictions_data = pd.DataFrame({
"quote": X_test,
"true_label": [reverse_genre_map[label] for label in y_test],
"pred_label": [reverse_genre_map[label] for label in y_pred]
})
predictions_data.to_csv("predictions.csv", index=False, encoding="utf-8")
print(f"测试集macro precision: {precision:.4f}")