Files
final-practice/20260609.4.py
2026-06-09 11:23:33 +08:00

76 lines
2.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import log_loss, precision_score
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
# 1. Genre vocabulary.
# Each genre's position in the tuple below is its integer class id
# (剧情 -> 0 ... 纪录 -> 9); models trained on these ids depend on this order.
genre_map = {
    genre: class_id
    for class_id, genre in enumerate(
        ("剧情", "喜剧", "科幻", "悬疑", "动作", "爱情", "动画", "犯罪", "奇幻", "纪录")
    )
}
# Inverse lookup (class id -> genre name), used to turn predictions back into text.
reverse_genre_map = dict(zip(genre_map.values(), genre_map.keys()))
# 2. Load the hand-labelled data (a JSON source would work as well).
#    Expected columns: quote, label -- where label is the genre name as text.
df = pd.read_csv("my_labels.csv")
# Strip stray whitespace so that e.g. "剧情 " still maps to its class id.
df["label"] = df["label"].astype(str).str.strip()
df["label_id"] = df["label"].map(genre_map)
# Series.map yields NaN for any label that is not a key of genre_map. Left
# unchecked, that surfaces much later as an obscure failure in the stratified
# split or in the id -> name lookups, so fail fast and name the bad values.
_unknown_labels = df.loc[df["label_id"].isna(), "label"].unique()
if len(_unknown_labels) > 0:
    raise ValueError(
        f"my_labels.csv contains labels not in genre_map: {list(_unknown_labels)}"
    )
# map() produced floats while NaNs were possible; class ids are integers.
df["label_id"] = df["label_id"].astype(int)
# 3. Split into train / validation / test at 8:1:1, stratified by genre.
X = df["quote"]
y = df["label_id"]
# First hold out 10% of all rows as the test set ...
X_train_val, X_test, y_train_val, y_test = train_test_split(
    X, y, test_size=0.1, random_state=42, stratify=y
)
# ... then take exactly as many rows again (10% of the total) from the
# remainder as the validation set. The former test_size=0.11 only took 9.9%
# of the total, i.e. an 80.1 : 9.9 : 10 split.
X_train, X_val, y_train, y_val = train_test_split(
    X_train_val,
    y_train_val,
    test_size=len(X_test),
    random_state=42,
    stratify=y_train_val,
)
# 4. TF-IDF text features.
# The default word analyzer (token_pattern r"(?u)\b\w\w+\b") relies on spaces
# between words. Unsegmented Chinese text therefore comes out as one token per
# run of characters, so almost no feature recurs across quotes. Character
# uni-/bi-grams need no word segmentation.
# NOTE(review): assumes the quotes are Chinese, like the genre labels; use
# the word analyzer instead if the quotes are space-separated (e.g. English).
tfidf = TfidfVectorizer(analyzer="char", ngram_range=(1, 2), max_features=1000)
# Vocabulary and IDF weights are learned from the training split only, so the
# validation/test features carry no information from those splits.
X_train_tfidf = tfidf.fit_transform(X_train)
X_val_tfidf = tfidf.transform(X_val)
X_test_tfidf = tfidf.transform(X_test)
# 5. Train the MLP, recording per-epoch training loss and validation loss.
#
# fit(early_stopping=True) is not used because:
#   * it carves its own validation slice (validation_fraction) out of X_train,
#     so the X_val split made above would never be used;
#   * the only per-epoch validation metric it keeps, validation_scores_, is
#     accuracy, not loss.
# Running one epoch per partial_fit call lets us measure cross-entropy on
# X_val after every epoch and early-stop on that instead.
mlp = MLPClassifier(
    hidden_layer_sizes=(64, 32),
    max_iter=100,  # used below as the upper bound on epochs
    # A RandomState instance (rather than the int 42) is not re-seeded on each
    # partial_fit call, so epochs get different shuffles while the whole run
    # stays reproducible.
    random_state=np.random.RandomState(42),
    verbose=True,
)
_classes = np.unique(y)  # every class id present in the data set
val_losses = []
_best_val_loss = np.inf
_best_weights = None
_stale_epochs = 0
for _ in range(mlp.max_iter):
    # One pass over the training data; this also appends the epoch's training
    # loss to mlp.loss_curve_. classes must be supplied on the first call.
    mlp.partial_fit(X_train_tfidf, y_train, classes=_classes)
    _val_loss = log_loss(y_val, mlp.predict_proba(X_val_tfidf), labels=mlp.classes_)
    val_losses.append(_val_loss)
    if _val_loss < _best_val_loss - mlp.tol:
        _best_val_loss = _val_loss
        _best_weights = (
            [w.copy() for w in mlp.coefs_],
            [b.copy() for b in mlp.intercepts_],
        )
        _stale_epochs = 0
    else:
        # Stop once the validation loss has not improved by at least mlp.tol
        # for more than mlp.n_iter_no_change consecutive epochs.
        _stale_epochs += 1
        if _stale_epochs > mlp.n_iter_no_change:
            break
# Keep the weights from the epoch with the lowest validation loss.
if _best_weights is not None:
    mlp.coefs_, mlp.intercepts_ = _best_weights

# Save the per-epoch loss curves. Both lists gain exactly one entry per epoch.
loss_data = pd.DataFrame({
    "epoch": range(1, len(mlp.loss_curve_) + 1),
    "train_loss": mlp.loss_curve_,
    "val_loss": val_losses,  # cross-entropy on the X_val split
})
loss_data.to_csv("loss.csv", index=False)
# 6. Evaluate on the held-out test split and export per-quote predictions.
y_pred = mlp.predict(X_test_tfidf)
# Macro averaging gives every genre the same weight regardless of its size.
precision = precision_score(y_test, y_pred, average="macro")

# Translate class ids back to genre names for the CSV. The quote column keeps
# X_test's index; the two label columns are plain lists aligned by position.
predictions_data = pd.DataFrame(
    {
        "quote": X_test,
        "true_label": list(map(reverse_genre_map.__getitem__, y_test)),
        "pred_label": list(map(reverse_genre_map.__getitem__, y_pred)),
    }
)
predictions_data.to_csv("predictions.csv", index=False, encoding="utf-8")
print(f"测试集macro precision: {precision:.4f}")