Files
task-3-2-2-text-classification/train.py
2026-04-27 21:43:12 +08:00

266 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
训练与对比实验模块
功能:
1. 单模型训练与评估
2. 模型对比实验 (LR vs MLP)
3. 向量化对比实验 (BoW vs TF-IDF)
4. 超参数对比实验
"""
import numpy as np
import time
from datetime import datetime
from model_numpy import create_model
from dataset import load_data
from config import *
def train_single_model(X_train, y_train, X_test, y_test,
                       model_type, vectorizer_type, epochs, batch_size, lr, hidden_size):
    """Train a single model configuration and return its metrics.

    NOTE(review): the first four data arguments are accepted for interface
    compatibility but are IGNORED — the function always reloads the dataset
    via ``load_data`` so that the vectorization matches ``vectorizer_type``.
    Callers in this file pass ``None`` for all four.

    Args:
        X_train, y_train, X_test, y_test: unused (see note above).
        model_type: 'lr' or 'mlp' — passed through to ``create_model``.
        vectorizer_type: 'bow' or 'tfidf' — controls how data is re-vectorized.
        epochs: number of training epochs.
        batch_size: mini-batch size for ``model.fit``.
        lr: learning rate.
        hidden_size: hidden-layer width (only meaningful for MLP).

    Returns:
        dict with keys 'model_type', 'vectorizer_type', 'train_acc',
        'test_acc', 'train_time', 'lr', 'hidden_size'.

    Side effects:
        Prints a progress report and saves model weights to
        ``model_<type>_<vec>_<weighted|noweight>_<timestamp>_*.npy``.
    """
    print(f"\n{'='*60}")
    print(f"训练配置:")
    print(f" 模型: {model_type.upper()}")
    print(f" 向量: {vectorizer_type.upper()}")
    print(f" 学习率: {lr}")
    if model_type == 'mlp':
        print(f" 隐藏层大小: {hidden_size}")
    print(f" 训练轮数: {epochs}")
    print(f"{'='*60}")

    # Reload the data so the vectorization method is guaranteed to match.
    X_tr, y_tr, X_te, y_te, _ = load_data(
        DATA_DIR, MAX_FEATURES, MAX_SEQ_LEN, vectorizer_type
    )

    # Compute class weights to counter class imbalance.
    # Assumes binary labels in {0, 1} where 1 = positive — TODO confirm
    # against dataset.load_data.
    pos_count = sum(y_tr)
    neg_count = len(y_tr) - pos_count
    n_samples = len(y_tr)
    n_classes = 2
    if USE_CLASS_WEIGHT:
        # weight_i = n_samples / (n_classes * n_class_i): the rarer class
        # (here presumably negative reviews) receives the larger weight.
        weight_pos = n_samples / (n_classes * pos_count)
        weight_neg = n_samples / (n_classes * neg_count)
        class_weight = {0: weight_neg, 1: weight_pos}
        print(f" 类别权重: 正面={weight_pos:.2f}, 负面={weight_neg:.2f}")
    else:
        class_weight = None
        print(f" 类别权重: 不使用")

    # Build the model (pure-NumPy implementation, see model_numpy).
    model = create_model(
        model_type=model_type,
        input_size=X_tr.shape[1],
        hidden_size=hidden_size,
        num_classes=NUM_CLASSES,
        learning_rate=lr,
        keep_prob=KEEP_PROB,
        class_weight=class_weight
    )

    # Train and time the fit.
    start_time = time.time()
    model.fit(X_tr, y_tr, X_te, y_te, epochs=epochs, batch_size=batch_size, verbose=True)
    train_time = time.time() - start_time

    # Final evaluation on both splits.
    train_acc = model.accuracy(X_tr, y_tr)
    test_acc = model.accuracy(X_te, y_te)
    print(f"\n最终结果:")
    print(f" 训练准确率: {train_acc:.4f}")
    print(f" 测试准确率: {test_acc:.4f}")
    print(f" 训练时间: {train_time:.2f}")

    # Save model weights under a name that encodes the configuration and a
    # timestamp, so repeated runs never overwrite each other.
    weight_suffix = "weighted" if USE_CLASS_WEIGHT else "noweight"
    timestamp = datetime.now().strftime("%m%d_%H%M%S")
    model_path = f"model_{model_type}_{vectorizer_type}_{weight_suffix}_{timestamp}"
    model.save(model_path)
    print(f" 权重文件: {model_path}_*.npy")

    return {
        'model_type': model_type,
        'vectorizer_type': vectorizer_type,
        'train_acc': train_acc,
        'test_acc': test_acc,
        'train_time': train_time,
        'lr': lr,
        'hidden_size': hidden_size
    }
def run_comparison_experiments():
    """Run the full grid of comparison experiments and print a summary.

    Experiment design:
    1. Vectorization comparison (BoW vs TF-IDF), model fixed to LR.
    2. Model-capacity comparison (LR vs MLP), vectorizer fixed to TF-IDF.
    3. Learning-rate comparison for the MLP.
    4. Hidden-layer-size comparison for the MLP.

    Returns:
        list[dict]: one result dict per trained configuration, as produced
        by ``train_single_model``.
    """
    print("\n" + "=" * 70)
    print("文本分类对比实验")
    print("=" * 70)
    results = []

    # ===== Experiment 1: vectorization method (model fixed to LR) =====
    # BUGFIX: the separator was ``"" * 20`` (a stripped Unicode character),
    # which printed an empty line; restored to a visible "=" ruler.
    print("\n\n" + "=" * 20)
    print("实验1: 向量化方法对比 (模型=LR, 学习率=0.05)")
    print("=" * 20)
    for vec_type in COMPARE_VECTORS:
        result = train_single_model(
            None, None, None, None,
            model_type='lr',
            vectorizer_type=vec_type,
            epochs=50,
            batch_size=BATCH_SIZE,
            lr=0.05,
            hidden_size=HIDDEN_SIZE
        )
        results.append(result)

    # ===== Experiment 2: model capacity (vectorizer fixed to TF-IDF) =====
    print("\n\n" + "=" * 20)
    print("实验2: 模型复杂度对比 (向量=TF-IDF)")
    print("=" * 20)
    for model_type in COMPARE_MODELS:
        result = train_single_model(
            None, None, None, None,
            model_type=model_type,
            vectorizer_type='tfidf',
            epochs=50,
            batch_size=BATCH_SIZE,
            lr=LEARNING_RATE,
            hidden_size=HIDDEN_SIZE
        )
        results.append(result)

    # ===== Experiment 3: learning rate (MLP) =====
    print("\n\n" + "=" * 20)
    print("实验3: 学习率对比 (模型=MLP, 向量=TF-IDF)")
    print("=" * 20)
    for lr in [0.01, 0.1]:
        result = train_single_model(
            None, None, None, None,
            model_type='mlp',
            vectorizer_type='tfidf',
            epochs=50,
            batch_size=BATCH_SIZE,
            lr=lr,
            hidden_size=64
        )
        results.append(result)

    # ===== Experiment 4: hidden layer size (MLP) =====
    print("\n\n" + "=" * 20)
    print("实验4: 隐藏层大小对比 (模型=MLP, 向量=TF-IDF)")
    print("=" * 20)
    for hidden_size in [32, 128]:
        result = train_single_model(
            None, None, None, None,
            model_type='mlp',
            vectorizer_type='tfidf',
            epochs=50,
            batch_size=BATCH_SIZE,
            lr=LEARNING_RATE,
            hidden_size=hidden_size
        )
        results.append(result)

    # ===== Summary table =====
    print("\n\n" + "=" * 70)
    print("实验结果汇总")
    print("=" * 70)
    print(f"\n{'配置':<40} {'训练准确率':<12} {'测试准确率':<12} {'训练时间(秒)':<10}")
    print("-" * 76)
    for r in results:
        config = f"{r['model_type'].upper()} + {r['vectorizer_type'].upper()}"
        # Only annotate hyperparameters that differ from the defaults.
        if r['lr'] != LEARNING_RATE:
            config += f", lr={r['lr']}"
        if r['hidden_size'] != HIDDEN_SIZE:
            config += f", h={r['hidden_size']}"
        print(f"{config:<40} {r['train_acc']:<12.4f} {r['test_acc']:<12.4f} {r['train_time']:<10.2f}")

    # ===== Analysis =====
    print("\n" + "=" * 70)
    print("结果分析")
    print("=" * 70)

    # Best configuration by test accuracy.
    best = max(results, key=lambda x: x['test_acc'])
    print(f"\n最佳测试准确率: {best['test_acc']:.4f}")
    print(f"最佳配置: {best['model_type'].upper()} + {best['vectorizer_type'].upper()}")

    # BoW vs TF-IDF (both with the LR model from experiment 1).
    bow_results = [r for r in results if r['vectorizer_type'] == 'bow' and r['model_type'] == 'lr']
    tfidf_results = [r for r in results if r['vectorizer_type'] == 'tfidf' and r['model_type'] == 'lr']
    if bow_results and tfidf_results:
        bow_acc = bow_results[0]['test_acc']
        tfidf_acc = tfidf_results[0]['test_acc']
        print(f"\n向量化方法影响:")
        print(f" BoW测试准确率: {bow_acc:.4f}")
        print(f" TF-IDF测试准确率: {tfidf_acc:.4f}")
        print(f" 差异: {abs(tfidf_acc - bow_acc):.4f} ({'TF-IDF更优' if tfidf_acc > bow_acc else 'BoW更优'})")

    # LR vs MLP (both on TF-IDF from experiment 2).
    lr_results = [r for r in results if r['model_type'] == 'lr' and r['vectorizer_type'] == 'tfidf']
    mlp_results = [r for r in results if r['model_type'] == 'mlp' and r['vectorizer_type'] == 'tfidf']
    if lr_results and mlp_results:
        lr_acc = lr_results[0]['test_acc']
        mlp_acc = mlp_results[0]['test_acc']
        print(f"\n模型复杂度影响:")
        print(f" LR测试准确率: {lr_acc:.4f}")
        print(f" MLP测试准确率: {mlp_acc:.4f}")
        print(f" 差异: {abs(mlp_acc - lr_acc):.4f} ({'MLP更优' if mlp_acc > lr_acc else 'LR更优'})")

    return results
def main():
    """Entry point: run either the comparison grid or a single training run.

    The branch is controlled by the ``RUN_COMPARISON`` flag from config;
    single-run hyperparameters also come from config constants.
    """
    print("=" * 70)
    print("文本分类实验 - 纯NumPy实现")
    print("数据集: ChnSentiCorp (中文酒店评论)")
    print("=" * 70)

    if RUN_COMPARISON:
        # Full comparison experiment suite.
        run_comparison_experiments()
    else:
        # Single model with configuration taken entirely from config.py.
        train_single_model(
            None, None, None, None,
            model_type=MODEL_TYPE,
            vectorizer_type=VECTORIZER_TYPE,
            epochs=NUM_EPOCHS,
            batch_size=BATCH_SIZE,
            lr=LEARNING_RATE,
            hidden_size=HIDDEN_SIZE
        )

    print("\n" + "=" * 70)
    print("实验完成!")
    print("=" * 70)


if __name__ == '__main__':
    main()