上传文件至 /
This commit is contained in:
48
train_mlp.py
Normal file
48
train_mlp.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.neural_network import MLPClassifier
|
||||
from sklearn.metrics import precision_score
|
||||
import csv
|
||||
|
||||
# 1. Load the labelled data.
df = pd.read_csv("my_labels.csv")
# Rows missing a text or a label cannot be used for training: an empty cell
# is read back as NaN, which TfidfVectorizer refuses to treat as a document.
df = df.dropna(subset=["text", "label"])
# Cast to str so that purely numeric texts (which pandas parses as numbers)
# are still valid documents for the vectorizer.
texts = df["text"].astype(str).tolist()
labels = df["label"].tolist()
|
||||
|
||||
# 2./3. Split first, then extract TF-IDF features.
# The split is done on the raw texts so that the vectorizer's vocabulary and
# IDF weights are learned from the training portion only; fitting it on the
# whole corpus would leak validation data into the features and inflate the
# validation precision.
y = np.array(labels)
texts_train, texts_val, y_train, y_val = train_test_split(
    texts, y, test_size=0.2, random_state=42
)

tfidf = TfidfVectorizer()
X_train = tfidf.fit_transform(texts_train)
# Validation texts are only transformed with the already-fitted vectorizer.
X_val = tfidf.transform(texts_val)
|
||||
|
||||
# 4. Train the MLP one pass at a time, recording the loss and the validation
#    precision after every pass.
mlp = MLPClassifier(hidden_layer_sizes=(128, 64), max_iter=100, random_state=42)
train_loss_list = []
val_precision_list = []

# The full label set is handed to partial_fit; it is the same on every pass.
all_classes = np.unique(y)

for epoch in range(1, mlp.max_iter + 1):
    mlp.partial_fit(X_train, y_train, classes=all_classes)

    # Training loss reported by the model after this pass.
    train_loss_list.append({"epoch": epoch, "loss": mlp.loss_})

    # Macro-averaged precision on the validation set.
    val_prec = precision_score(
        y_val, mlp.predict(X_val), average="macro", zero_division=0
    )
    val_precision_list.append({"epoch": epoch, "precision": val_prec})
|
||||
|
||||
# Write the per-epoch training loss to loss.csv (header row + one row per epoch).
with open("loss.csv", "w", newline="", encoding="utf-8") as loss_file:
    loss_writer = csv.writer(loss_file)
    loss_writer.writerow(["epoch", "loss"])
    loss_writer.writerows(
        (record["epoch"], record["loss"]) for record in train_loss_list
    )
|
||||
|
||||
# Write the per-epoch validation precision to predictions.csv, then report
# completion on stdout.
precision_columns = ("epoch", "precision")
with open("predictions.csv", "w", newline="", encoding="utf-8") as precision_file:
    precision_writer = csv.DictWriter(precision_file, fieldnames=precision_columns)
    precision_writer.writeheader()
    for record in val_precision_list:
        precision_writer.writerow(record)

print("模型训练完成,已输出 loss.csv、predictions.csv")
|
||||
Reference in New Issue
Block a user