上传文件至 /
This commit is contained in:
195
hhh.py
Normal file
195
hhh.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import requests
|
||||
import json
|
||||
import csv
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from bs4 import BeautifulSoup
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.neural_network import MLPClassifier
|
||||
from sklearn.metrics import precision_score
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
|
||||
# ----------------------
|
||||
# 配置
|
||||
# ----------------------
|
||||
plt.rcParams["font.sans-serif"] = ["SimHei"]
|
||||
plt.rcParams["axes.unicode_minus"] = False
|
||||
|
||||
# 类别映射(题目要求的10类)
|
||||
LABEL_MAP = {
|
||||
"剧情": 0, "喜剧": 1, "科幻": 2, "悬疑": 3, "动作": 4,
|
||||
"爱情": 5, "动画": 6, "犯罪": 7, "奇幻": 8, "纪录": 9
|
||||
}
|
||||
REVERSE_LABEL_MAP = {v: k for k, v in LABEL_MAP.items()}
|
||||
|
||||
# ----------------------
|
||||
# 1. 数据采集:爬取豆瓣Top250前50部电影
|
||||
# ----------------------
|
||||
print("=" * 50)
|
||||
print("步骤1:爬取豆瓣Top250前50部电影数据")
|
||||
print("=" * 50)
|
||||
|
||||
base_url = "https://movie.douban.com/top250?start={}&filter="
|
||||
headers = {
|
||||
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
|
||||
}
|
||||
|
||||
movies = []
|
||||
for page in range(2): # 前2页,共50条
|
||||
url = base_url.format(page * 25)
|
||||
response = requests.get(url, headers=headers)
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
|
||||
for idx, item in enumerate(soup.select(".grid_view li")):
|
||||
rank = page * 25 + idx + 1
|
||||
title = item.select_one(".title").text.strip() if item.select_one(".title") else ""
|
||||
actors_tag = item.select_one(".bd p")
|
||||
actors = actors_tag.text.strip().split("\n")[0] if actors_tag else ""
|
||||
quote_tag = item.select_one(".quote")
|
||||
quote = quote_tag.text.strip() if quote_tag else ""
|
||||
|
||||
movies.append({
|
||||
"rank": rank,
|
||||
"title": title,
|
||||
"actors": actors,
|
||||
"quote": quote
|
||||
})
|
||||
|
||||
# 保存movies.json
|
||||
with open("movies.json", "w", encoding="utf-8") as f:
|
||||
json.dump(movies, f, ensure_ascii=False, indent=2)
|
||||
print("✅ movies.json 已生成")
|
||||
|
||||
# ----------------------
|
||||
# 2. 数据处理:过滤空短评,生成待标注文本
|
||||
# ----------------------
|
||||
print("\n步骤2:过滤有效短评,生成待标注文件")
|
||||
valid_quotes = [m["quote"] for m in movies if m.get("quote") and m["quote"].strip()]
|
||||
print(f"✅ 过滤出 {len(valid_quotes)} 条有效短评")
|
||||
|
||||
with open("quotes_processed.txt", "w", encoding="utf-8") as f:
|
||||
for q in valid_quotes:
|
||||
f.write(q + "\n")
|
||||
print("✅ quotes_processed.txt 已生成(可导入Label-Studio标注)")
|
||||
|
||||
# ----------------------
|
||||
# 3. 标注兼容处理(如果没有my_labels.csv,自动生成模拟数据测试流程)
|
||||
# ----------------------
|
||||
print("\n步骤3:读取标注数据(若不存在则生成模拟数据)")
|
||||
if not os.path.exists("my_labels.csv"):
|
||||
print("⚠️ 未找到my_labels.csv,生成模拟标注数据用于测试(提交作业请替换为真实标注数据)")
|
||||
df = pd.DataFrame({
|
||||
"text": valid_quotes,
|
||||
"label": np.random.choice(list(LABEL_MAP.keys()), size=len(valid_quotes))
|
||||
})
|
||||
df.to_csv("my_labels.csv", index=False, encoding="utf-8")
|
||||
else:
|
||||
print("✅ 读取已有的my_labels.csv标注数据")
|
||||
df = pd.read_csv("my_labels.csv")
|
||||
|
||||
df["label_id"] = df["label"].map(LABEL_MAP)
|
||||
|
||||
# ----------------------
|
||||
# 4. 模型训练:TF-IDF + MLP
|
||||
# ----------------------
|
||||
print("\n步骤4:模型训练(TF-IDF + MLP)")
|
||||
# 文本特征提取
|
||||
vectorizer = TfidfVectorizer()
|
||||
X = vectorizer.fit_transform(df["text"]).toarray()
|
||||
y = df["label_id"].values
|
||||
|
||||
# 划分训练/验证集
|
||||
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)
|
||||
|
||||
# 训练MLP并记录loss
|
||||
max_epochs = 50
|
||||
train_losses = []
|
||||
val_losses = []
|
||||
|
||||
model = MLPClassifier(
|
||||
hidden_layer_sizes=(64, 32),
|
||||
max_iter=1, warm_start=True,
|
||||
random_state=42
|
||||
)
|
||||
|
||||
for epoch in range(max_epochs):
|
||||
model.fit(X_train, y_train)
|
||||
train_loss = model.loss_
|
||||
val_acc = model.score(X_val, y_val)
|
||||
val_loss = 1 - val_acc
|
||||
|
||||
train_losses.append(train_loss)
|
||||
val_losses.append(val_loss)
|
||||
|
||||
# 保存loss.csv
|
||||
with open("loss.csv", "w", newline="", encoding="utf-8") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(["epoch", "train_loss", "val_loss"])
|
||||
for i in range(max_epochs):
|
||||
writer.writerow([i + 1, round(train_losses[i], 4), round(val_losses[i], 4)])
|
||||
print("✅ loss.csv 已生成")
|
||||
|
||||
# 预测并保存predictions.csv
|
||||
y_pred = model.predict(X_val)
|
||||
precision = precision_score(y_val, y_pred, average="macro")
|
||||
print(f"✅ 模型训练完成,Macro Precision: {precision:.4f}")
|
||||
|
||||
with open("predictions.csv", "w", newline="", encoding="utf-8") as f:
|
||||
writer = csv.writer(f)
|
||||
writer.writerow(["true_label", "pred_label", "label_name"])
|
||||
for true, pred in zip(y_val, y_pred):
|
||||
writer.writerow([true, pred, REVERSE_LABEL_MAP[pred]])
|
||||
print("✅ predictions.csv 已生成")
|
||||
|
||||
# ----------------------
|
||||
# 5. 可视化:loss曲线 + 类别分布柱状图
|
||||
# ----------------------
|
||||
print("\n步骤5:生成可视化图片")
|
||||
if not os.path.exists("images"):
|
||||
os.makedirs("images")
|
||||
|
||||
# 图1:loss曲线
|
||||
loss_df = pd.read_csv("loss.csv")
|
||||
plt.figure(figsize=(10, 5))
|
||||
plt.plot(loss_df["epoch"], loss_df["train_loss"], label="训练集Loss", color="#1f77b4")
|
||||
plt.plot(loss_df["epoch"], loss_df["val_loss"], label="验证集Loss", color="#ff7f0e")
|
||||
plt.xlabel("Epoch")
|
||||
plt.ylabel("Loss")
|
||||
plt.title("训练集与验证集Loss曲线")
|
||||
plt.legend()
|
||||
plt.grid(alpha=0.3)
|
||||
plt.savefig("images/loss_curve.png", dpi=300, bbox_inches="tight")
|
||||
plt.close()
|
||||
|
||||
# 图2:类别预测分布柱状图
|
||||
pred_df = pd.read_csv("predictions.csv")
|
||||
label_counts = pred_df["label_name"].value_counts()
|
||||
all_labels = list(LABEL_MAP.keys())
|
||||
counts = [label_counts.get(label, 0) for label in all_labels]
|
||||
|
||||
plt.figure(figsize=(12, 6))
|
||||
plt.bar(all_labels, counts, color="#2ca02c")
|
||||
plt.xlabel("电影类别")
|
||||
plt.ylabel("预测数量")
|
||||
plt.title("验证集电影类别预测分布")
|
||||
plt.xticks(rotation=45)
|
||||
plt.grid(axis="y", alpha=0.3)
|
||||
plt.tight_layout()
|
||||
plt.savefig("images/category_bar.png", dpi=300, bbox_inches="tight")
|
||||
plt.close()
|
||||
|
||||
print("✅ 可视化图片已保存到images/文件夹")
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("所有步骤完成!已生成以下文件:")
|
||||
print("- movies.json")
|
||||
print("- quotes_processed.txt")
|
||||
print("- my_labels.csv(若之前没有则为模拟数据)")
|
||||
print("- loss.csv")
|
||||
print("- predictions.csv")
|
||||
print("- images/loss_curve.png")
|
||||
print("- images/category_bar.png")
|
||||
print("=" * 50)
|
||||
Reference in New Issue
Block a user