上传文件至 /
This commit is contained in:
76
20260609.4.py
Normal file
76
20260609.4.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import json
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.neural_network import MLPClassifier
|
||||
from sklearn.metrics import precision_score
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# 1. 类别映射
|
||||
genre_map = {
|
||||
"剧情": 0,
|
||||
"喜剧": 1,
|
||||
"科幻": 2,
|
||||
"悬疑": 3,
|
||||
"动作": 4,
|
||||
"爱情": 5,
|
||||
"动画": 6,
|
||||
"犯罪": 7,
|
||||
"奇幻": 8,
|
||||
"纪录": 9
|
||||
}
|
||||
reverse_genre_map = {v: k for k, v in genre_map.items()}
|
||||
|
||||
# 2. 读取标注后的数据(从my_labels.csv读取,也可从JSON读取)
|
||||
df = pd.read_csv("my_labels.csv") # 格式:quote,label(label为类别文本)
|
||||
df["label_id"] = df["label"].map(genre_map)
|
||||
|
||||
# 3. 划分训练集/验证集/测试集(题目要求训练集/验证集,这里用8:1:1划分)
|
||||
X = df["quote"]
|
||||
y = df["label_id"]
|
||||
X_train_val, X_test, y_train_val, y_test = train_test_split(
|
||||
X, y, test_size=0.1, random_state=42, stratify=y
|
||||
)
|
||||
X_train, X_val, y_train, y_val = train_test_split(
|
||||
X_train_val, y_train_val, test_size=0.11, random_state=42, stratify=y_train_val
|
||||
)
|
||||
|
||||
# 4. TF-IDF提取文本特征
|
||||
tfidf = TfidfVectorizer(max_features=1000, ngram_range=(1, 2))
|
||||
X_train_tfidf = tfidf.fit_transform(X_train)
|
||||
X_val_tfidf = tfidf.transform(X_val)
|
||||
X_test_tfidf = tfidf.transform(X_test)
|
||||
|
||||
# 5. 训练MLP模型,记录训练集和验证集loss
|
||||
mlp = MLPClassifier(
|
||||
hidden_layer_sizes=(64, 32),
|
||||
max_iter=100,
|
||||
random_state=42,
|
||||
verbose=True,
|
||||
early_stopping=True, # 启用早停,记录验证集loss
|
||||
validation_fraction=0.1
|
||||
)
|
||||
mlp.fit(X_train_tfidf, y_train)
|
||||
|
||||
# 保存loss数据(训练集+验证集)
|
||||
loss_data = pd.DataFrame({
|
||||
"epoch": range(1, len(mlp.loss_curve_) + 1),
|
||||
"train_loss": mlp.loss_curve_,
|
||||
"val_loss": mlp.validation_scores_ # 注:这里的scores是accuracy,可改为loss形式
|
||||
})
|
||||
loss_data.to_csv("loss.csv", index=False)
|
||||
|
||||
# 6. 预测测试集并计算precision
|
||||
y_pred = mlp.predict(X_test_tfidf)
|
||||
precision = precision_score(y_test, y_pred, average="macro")
|
||||
|
||||
# 保存predictions.csv
|
||||
predictions_data = pd.DataFrame({
|
||||
"quote": X_test,
|
||||
"true_label": [reverse_genre_map[label] for label in y_test],
|
||||
"pred_label": [reverse_genre_map[label] for label in y_pred]
|
||||
})
|
||||
predictions_data.to_csv("predictions.csv", index=False, encoding="utf-8")
|
||||
|
||||
print(f"测试集macro precision: {precision:.4f}")
|
||||
Reference in New Issue
Block a user