上传文件至 /
This commit is contained in:
225
main.py
225
main.py
@@ -1,34 +1,191 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
主程序入口
|
||||
|
||||
使用方式:
|
||||
|
||||
1. 运行单个模型(默认):
|
||||
python main.py
|
||||
|
||||
修改 config.py 中的 MODEL_TYPE 和 VECTORIZER_TYPE 来切换配置
|
||||
|
||||
2. 运行对比实验:
|
||||
修改 config.py 中 RUN_COMPARISON = True
|
||||
|
||||
这会依次运行:
|
||||
- 实验1: BoW vs TF-IDF (固定LR模型)
|
||||
- 实验2: LR vs MLP (固定TF-IDF)
|
||||
- 实验3: 不同学习率对比
|
||||
- 实验4: 不同隐藏层大小对比
|
||||
|
||||
最后输出汇总报告
|
||||
"""
|
||||
|
||||
from train import main
|
||||
|
||||
if __name__ == '__main__':
    # Assemble the start-up banner as a list of lines and emit it in a single
    # print; the resulting output is the same as printing each line in turn.
    rule = "=" * 70
    banner = [
        "\n" + rule,
        "文本分类实验 - 纯NumPy实现",
        "数据集: ChnSentiCorp (中文酒店评论)",
        "模型: Logistic Regression / MLP",
        "向量化: BoW / TF-IDF",
        rule + "\n",
    ]
    print("\n".join(banner))

    # Hand control to the training pipeline imported from train.py.
    main()
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
主程序 - 手写数字识别 MLP 纯NumPy实现
|
||||
|
||||
使用方法:
|
||||
python main.py # 运行默认配置
|
||||
python main.py --compare # 运行对比实验
|
||||
|
||||
依赖:
|
||||
pip install numpy requests
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import time
|
||||
from datetime import datetime
|
||||
from model_numpy import MLP
|
||||
from dataset import load_data
|
||||
from config import *
|
||||
|
||||
|
||||
def train_and_evaluate():
    """Train one MLP with the settings from ``config`` and report accuracy.

    Steps, in order: load the data with ``load_data()``, build an ``MLP``,
    fit it (the test split is passed as the validation set), print the final
    train/test accuracy and the wall-clock training time, call
    ``model.save`` with a timestamped name, and print predictions for a few
    randomly chosen test samples.

    Returns:
        ``(model, train_acc, test_acc)`` on success, or
        ``(None, None, None)`` when loading the data raised an exception.
    """
    print("=" * 60)
    print("手写数字识别 - 纯NumPy MLP实现")
    print("=" * 60)

    # ===== Load data =====
    try:
        X_train, y_train, X_test, y_test = load_data()
    except Exception as e:
        # Deliberate best effort: instead of crashing, tell the user how to
        # fetch the dataset by hand and signal failure through the return
        # value.
        print(f"\n错误: {e}")
        manual_download_help = (
            "\n请手动下载数据文件:",
            " 1. 创建 data/ 目录",
            " 2. 下载以下文件到 data/:",
            " - train-images-idx3-ubyte.gz (9.9 MB)",
            " - train-labels-idx1-ubyte.gz (28 KB)",
            " - t10k-images-idx3-ubyte.gz (1.6 MB)",
            " - t10k-labels-idx1-ubyte.gz (5 KB)",
            " 下载地址: https://storage.googleapis.com/tensorflow/tf-keras-datasets/",
        )
        for help_line in manual_download_help:
            print(help_line)
        return None, None, None

    # ===== Build the model =====
    print("\n[2] 创建MLP模型...")
    model = MLP(
        input_size=INPUT_SIZE,
        hidden_size=HIDDEN_SIZE,
        num_classes=NUM_CLASSES,
        learning_rate=LEARNING_RATE,
        seed=SEED,
    )

    # ===== Train =====
    print("\n[3] 开始训练...")
    start_time = time.time()

    model.fit(
        X_train, y_train,
        X_val=X_test, y_val=y_test,
        epochs=NUM_EPOCHS,
        batch_size=BATCH_SIZE,
        verbose=True,
    )

    train_time = time.time() - start_time

    # ===== Final evaluation =====
    print("\n" + "=" * 60)
    print("训练完成!")
    print("=" * 60)

    train_acc = model.accuracy(X_train, y_train)
    test_acc = model.accuracy(X_test, y_test)

    print("\n最终结果:")
    print(f" 训练准确率: {train_acc:.4f} ({train_acc*100:.2f}%)")
    print(f" 测试准确率: {test_acc:.4f} ({test_acc*100:.2f}%)")
    print(f" 训练时间: {train_time:.2f} 秒")

    # ===== Save the model =====
    timestamp = datetime.now().strftime("%m%d_%H%M%S")
    model_path = f"mnist_mlp_{timestamp}"
    model.save(model_path)

    # ===== Prediction examples =====
    print("\n[4] 预测示例:")
    # Sampling without replacement cannot draw more items than exist, so cap
    # the number of examples at the size of the test set.
    sample_count = min(5, len(X_test))
    indices = np.random.choice(len(X_test), sample_count, replace=False)

    for i, idx in enumerate(indices):
        # The model is called with a single-row 2-D batch.
        sample = X_test[idx].reshape(1, -1)
        # argmax over the label row -- presumably labels are one-hot encoded;
        # TODO confirm against load_data().
        true_label = np.argmax(y_test[idx])
        pred_label = model.predict(sample)[0]
        prob = model.predict_proba(sample)[0]

        status = '✓' if true_label == pred_label else '✗'
        print(f" 样本{i+1}: 真实={true_label}, 预测={pred_label}, "
              f"置信度={prob[pred_label]:.2f} {status}")

    return model, train_acc, test_acc
|
||||
|
||||
|
||||
def run_comparison():
    """Train several MLP configurations and print a comparison table.

    Each experiment builds a fresh ``MLP`` with its own hidden-layer size and
    learning rate, trains it for 30 epochs with ``verbose=False``, and
    records train accuracy, test accuracy and wall-clock training time.
    After all experiments a summary table is printed, followed by the
    configuration with the highest test accuracy.

    Returns:
        None. When ``load_data()`` raises, the error is printed and the
        function returns without running any experiment.
    """
    print("\n" + "=" * 60)
    print("超参数对比实验")
    print("=" * 60)

    # Load data
    try:
        X_train, y_train, X_test, y_test = load_data()
    except Exception as e:
        # Deliberate best effort: report the problem and skip the experiments.
        print(f"加载数据失败: {e}")
        return

    # Experiment configurations
    experiments = [
        {"hidden_size": 32, "lr": 0.1, "name": "小模型(32神经元)"},
        {"hidden_size": 128, "lr": 0.1, "name": "标准(128神经元)"},
        {"hidden_size": 256, "lr": 0.1, "name": "大模型(256神经元)"},
        {"hidden_size": 128, "lr": 0.01, "name": "小学习率(0.01)"},
        {"hidden_size": 128, "lr": 0.5, "name": "大学习率(0.5)"},
    ]

    results = []

    for exp in experiments:
        print(f"\n实验: {exp['name']}")
        print("-" * 40)

        model = MLP(
            input_size=INPUT_SIZE,
            hidden_size=exp['hidden_size'],
            num_classes=NUM_CLASSES,
            learning_rate=exp['lr'],
            seed=SEED,
        )

        start_time = time.time()
        model.fit(X_train, y_train, epochs=30, batch_size=BATCH_SIZE, verbose=False)
        train_time = time.time() - start_time

        train_acc = model.accuracy(X_train, y_train)
        test_acc = model.accuracy(X_test, y_test)

        results.append({
            'name': exp['name'],
            'hidden_size': exp['hidden_size'],
            'lr': exp['lr'],
            'train_acc': train_acc,
            'test_acc': test_acc,
            'train_time': train_time,
        })

        print(f" 训练准确率: {train_acc:.4f} | 测试准确率: {test_acc:.4f} | 时间: {train_time:.1f}s")

    # Summary
    print("\n" + "=" * 60)
    print("实验结果汇总")
    print("=" * 60)
    print(f"\n{'配置':<25} {'训练准确率':<12} {'测试准确率':<12} {'时间':<8}")
    print("-" * 60)

    for r in results:
        # Attach the unit to the number before placing it in the row, so no
        # column padding ends up between the value and its "s" suffix.
        elapsed = f"{r['train_time']:.1f}s"
        print(f"{r['name']:<25} {r['train_acc']:<12.4f} {r['test_acc']:<12.4f} {elapsed}")

    best = max(results, key=lambda x: x['test_acc'])
    print(f"\n最佳配置: {best['name']}, 测试准确率: {best['test_acc']:.4f}")
|
||||
|
||||
|
||||
def main():
    """Run the selected workflow, then print the closing banner.

    When the module-level ``RUN_COMPARISON`` flag is truthy the comparison
    experiments are executed; otherwise a single model is trained and
    evaluated.
    """
    # Choose the workflow first, then invoke it.
    workflow = run_comparison if RUN_COMPARISON else train_and_evaluate
    workflow()

    rule = "=" * 60
    print("\n" + rule)
    print("程序结束!")
    print(rule)
|
||||
|
||||
|
||||
if __name__ == '__main__':
    import sys

    # A --compare argument anywhere on the command line rebinds the
    # module-level RUN_COMPARISON flag that main() reads.
    if any(arg == '--compare' for arg in sys.argv):
        RUN_COMPARISON = True

    main()
|
||||
Reference in New Issue
Block a user