Files
final-practice/train_mlp.py
2026-06-09 11:18:50 +08:00

48 lines
1.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import precision_score
import csv
# Train an MLP text classifier on TF-IDF features and record, per epoch, the
# training loss and the macro precision on a held-out validation split.
# Outputs: loss.csv (epoch, loss) and predictions.csv (epoch, precision).


def _load_labeled_texts(path):
    """Read the labelled CSV at *path* and return ``(texts, labels)``.

    The file must have ``text`` and ``label`` columns.  Missing texts are
    replaced by ``""`` because TfidfVectorizer raises on NaN documents.
    ``labels`` is returned as a NumPy array.
    """
    df = pd.read_csv(path)
    texts = df["text"].fillna("").astype(str).tolist()
    labels = np.array(df["label"].tolist())
    return texts, labels


def _write_epoch_csv(path, fieldnames, rows):
    """Write *rows* (dicts keyed by *fieldnames*) to *path* as UTF-8 CSV with a header."""
    with open(path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(rows)


def main(data_path="my_labels.csv", loss_path="loss.csv",
         pred_path="predictions.csv", epochs=100):
    """Train the MLP and write per-epoch loss / validation precision CSVs.

    Args:
        data_path: labelled CSV with ``text`` and ``label`` columns.
        loss_path: output CSV of training loss per epoch.
        pred_path: output CSV of validation macro precision per epoch.
        epochs: number of ``partial_fit`` passes over the training split.
    """
    # 1. Load the labelled data.
    texts, y = _load_labeled_texts(data_path)

    # 2. Split train/validation BEFORE fitting TF-IDF. Fitting the vectorizer
    # on all texts would leak validation vocabulary and document frequencies
    # into the training features. The same random_state yields the same row split.
    texts_train, texts_val, y_train, y_val = train_test_split(
        texts, y, test_size=0.2, random_state=42
    )

    # 3. TF-IDF features: fit on training texts only, then reuse that
    # vocabulary/IDF for validation.
    # NOTE(review): the default token_pattern splits on word boundaries, so
    # unsegmented Chinese text becomes one token per character run. If the
    # texts are Chinese, consider analyzer="char" or a word segmenter.
    # TODO confirm the corpus language.
    tfidf = TfidfVectorizer()
    X_train = tfidf.fit_transform(texts_train)
    X_val = tfidf.transform(texts_val)

    # 4. Train the MLP one pass at a time so loss/precision can be logged per epoch.
    mlp = MLPClassifier(hidden_layer_sizes=(128, 64), max_iter=epochs, random_state=42)
    # Pass all labels (not just the training ones) so partial_fit knows every class.
    classes = np.unique(y)
    train_loss_list = []
    val_precision_list = []
    for epoch in range(1, epochs + 1):
        mlp.partial_fit(X_train, y_train, classes=classes)
        # Training loss after this pass.
        train_loss_list.append({"epoch": epoch, "loss": mlp.loss_})
        # Macro precision on the validation split.
        y_pred_val = mlp.predict(X_val)
        val_prec = precision_score(y_val, y_pred_val, average="macro", zero_division=0)
        val_precision_list.append({"epoch": epoch, "precision": val_prec})

    # Save loss.csv.
    _write_epoch_csv(loss_path, ["epoch", "loss"], train_loss_list)
    # Save predictions.csv. Despite its name, it holds per-epoch validation
    # precision, not per-sample predictions.
    _write_epoch_csv(pred_path, ["epoch", "precision"], val_precision_list)
    print(f"模型训练完成,已输出 {loss_path}、{pred_path}")


if __name__ == "__main__":
    main()