上传文件至 /
This commit is contained in:
79
config.py
79
config.py
@@ -1,40 +1,39 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
配置文件 - 所有超参数集中管理
|
||||
|
||||
设计思路:
|
||||
将超参数分门别类,学生可以单独修改某一类而不会影响其他
|
||||
"""
|
||||
|
||||
# ==================== 数据相关 ====================
|
||||
DATA_DIR = 'data/ChnSentiCorp' # 数据集路径
|
||||
MAX_FEATURES = 3000 # 词表最大容量
|
||||
MAX_SEQ_LEN = 100 # 句子最大长度(词数)
|
||||
VECTORIZER_TYPE = 'tfidf' # 'tfidf' 或 'bow'(向量化方式)
|
||||
|
||||
# ==================== 模型相关 ====================
|
||||
MODEL_TYPE = 'mlp' # 'mlp' 或 'lr'(模型类型)
|
||||
HIDDEN_SIZE = 64 # MLP隐藏层大小(LR忽略)
|
||||
NUM_CLASSES = 2 # 类别数(正面/负面二分类)
|
||||
KEEP_PROB = 1.0 # Dropout保留概率(LR忽略,设为1即可)
|
||||
|
||||
# ==================== 训练相关 ====================
|
||||
LEARNING_RATE = 0.06 # 学习率
|
||||
NUM_EPOCHS = 101 # 训练轮数
|
||||
BATCH_SIZE = 65 # 批次大小
|
||||
|
||||
# ==================== 类别权重(解决数据不平衡问题)====================
|
||||
USE_CLASS_WEIGHT = True # True=启用类别权重, False=不启用(对比用)
|
||||
# 权重计算公式: n_samples / (n_classes * n_class_i)
|
||||
# 正面评论多所以权重小,负面评论少所以权重大
|
||||
CLASS_WEIGHT_POS = 0.85 # 正面类权重(自动计算)
|
||||
CLASS_WEIGHT_NEG = 1.75 # 负面类权重(自动计算)
|
||||
|
||||
# ==================== 实验相关 ====================
|
||||
RUN_COMPARISON = False # True=运行对比实验, False=运行单个模型
|
||||
COMPARE_MODELS = ['lr', 'mlp'] # 要对比的模型列表
|
||||
COMPARE_VECTORS = ['bow', 'tfidf'] # 要对比的向量化方式
|
||||
|
||||
# ==================== 其他 ====================
|
||||
RANDOM_SEED = 42 # 随机种子(保证可复现)
|
||||
VERBOSE = True # 打印详细日志
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
手写数字识别 - 超参数配置
|
||||
|
||||
纯NumPy实现的两层全连接神经网络
|
||||
"""
|
||||
|
||||
# ===== 数据参数 =====
|
||||
ONE_HOT = True # 标签是否使用One-Hot编码
|
||||
|
||||
# ===== 模型结构 =====
|
||||
INPUT_SIZE = 784 # 28x28 = 784 像素
|
||||
HIDDEN_SIZE = 128 # 隐藏层神经元数量
|
||||
NUM_CLASSES = 10 # 0-9 十个数字
|
||||
KEEP_PROB = 1.0 # Dropout保留比例(1.0=不使用Dropout)
|
||||
|
||||
# ===== 训练参数 =====
|
||||
LEARNING_RATE = 0.1 # 学习率
|
||||
NUM_EPOCHS = 50 # 训练轮数
|
||||
BATCH_SIZE = 64 # 批大小
|
||||
|
||||
# ===== 随机种子(保证可复现) =====
|
||||
SEED = 42
|
||||
|
||||
# ===== 实验配置 =====
|
||||
RUN_COMPARISON = False # 是否运行对比实验
|
||||
|
||||
# ===== 依赖说明 =====
|
||||
# 本项目需要以下库:
|
||||
# numpy - 数值计算
|
||||
# scikit-learn - 加载MNIST数据集(会自动下载)
|
||||
# pandas - sklearn的依赖
|
||||
#
|
||||
# 安装命令:
|
||||
# pip install numpy scikit-learn pandas
|
||||
#
|
||||
# 数据说明:
|
||||
# 首次运行时会自动从OpenML下载MNIST数据集(约12MB)
|
||||
# 下载后会自动缓存,后续运行直接使用缓存数据
|
||||
465
dataset.py
465
dataset.py
@@ -1,286 +1,179 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据加载与向量化模块
|
||||
|
||||
支持两种向量化方法:
|
||||
1. BoW (Bag of Words) - 词频向量
|
||||
2. TF-IDF - 词频-逆文档频率向量
|
||||
|
||||
TF-IDF 的优势:
|
||||
- 降低常见词(如"的"、"是")的权重
|
||||
- 提升罕见词的信息量
|
||||
- 通常效果优于简单BoW
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import csv
|
||||
import math
|
||||
import jieba
|
||||
import numpy as np
|
||||
from collections import Counter
|
||||
|
||||
try:
|
||||
import urllib.request
|
||||
import ssl
|
||||
DOWNLOAD_AVAILABLE = True
|
||||
except ImportError:
|
||||
DOWNLOAD_AVAILABLE = False
|
||||
|
||||
|
||||
DATASET_URL = "https://raw.githubusercontent.com/SophonPlus/ChineseNlpCorpus/master/datasets/ChnSentiCorp_htl_all/ChnSentiCorp_htl_all.csv"
|
||||
|
||||
|
||||
def download_dataset(data_dir):
|
||||
"""下载数据集(如果不存在)"""
|
||||
csv_path = os.path.join(data_dir, 'ChnSentiCorp_htl_all.csv')
|
||||
|
||||
if os.path.exists(csv_path):
|
||||
print(f"数据已存在: {csv_path}")
|
||||
return True
|
||||
|
||||
if not DOWNLOAD_AVAILABLE:
|
||||
return False
|
||||
|
||||
print("正在下载数据集...")
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context.verify_mode = ssl.CERT_NONE
|
||||
|
||||
try:
|
||||
request = urllib.request.Request(DATASET_URL, headers={'User-Agent': 'Mozilla/5.0'})
|
||||
response = urllib.request.urlopen(request, timeout=120, context=ssl_context)
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
with open(csv_path, 'wb') as f:
|
||||
f.write(response.read())
|
||||
print(f"下载完成: {csv_path}")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"下载失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def load_raw_data(data_dir):
|
||||
"""加载原始数据"""
|
||||
csv_path = os.path.join(data_dir, 'ChnSentiCorp_htl_all.csv')
|
||||
texts, labels = [], []
|
||||
|
||||
with open(csv_path, 'r', encoding='utf-8') as f:
|
||||
reader = csv.reader(f)
|
||||
for row in reader:
|
||||
if len(row) < 2:
|
||||
continue
|
||||
try:
|
||||
label = int(row[0])
|
||||
review = row[1].strip()
|
||||
if review:
|
||||
texts.append(review)
|
||||
labels.append(label)
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
|
||||
return texts, np.array(labels)
|
||||
|
||||
|
||||
def tokenize(text):
|
||||
"""中文分词"""
|
||||
text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z]', ' ', text)
|
||||
words = jieba.lcut(text)
|
||||
return [w for w in words if len(w) > 1]
|
||||
|
||||
|
||||
# ==================== 向量化器 ====================
|
||||
|
||||
class BaseVectorizer:
|
||||
"""向量化器基类"""
|
||||
def fit(self, texts): pass
|
||||
def transform(self, texts): pass
|
||||
def fit_transform(self, texts): pass
|
||||
|
||||
|
||||
class BoWVectorizer(BaseVectorizer):
|
||||
"""
|
||||
词袋模型 (Bag of Words)
|
||||
|
||||
原理:统计每个词在文本中出现的次数
|
||||
向量维度 = 词表大小
|
||||
每个维度 = 该词在本文本中出现的次数
|
||||
"""
|
||||
|
||||
def __init__(self, max_features, max_seq_len):
|
||||
self.max_features = max_features
|
||||
self.max_seq_len = max_seq_len
|
||||
self.vocab = {}
|
||||
self.doc_freq = {} # 文档频率
|
||||
self.num_docs = 0
|
||||
|
||||
def fit(self, texts):
|
||||
"""构建词表(基于词频)"""
|
||||
counter = Counter()
|
||||
doc_counter = Counter() # 统计包含该词的文档数
|
||||
|
||||
for text in texts:
|
||||
words = tokenize(text)
|
||||
unique_words = set(words)
|
||||
counter.update(words)
|
||||
for w in unique_words:
|
||||
doc_counter[w] += 1
|
||||
|
||||
self.num_docs = len(texts)
|
||||
|
||||
# 取最高频的词
|
||||
most_common = counter.most_common(self.max_features)
|
||||
self.vocab = {word: idx for idx, (word, _) in enumerate(most_common)}
|
||||
|
||||
# 记录文档频率(用于TF-IDF)
|
||||
self.doc_freq = {w: doc_counter[w] for w in self.vocab}
|
||||
|
||||
print(f" BoW词表大小: {len(self.vocab)}")
|
||||
return self
|
||||
|
||||
def transform(self, texts):
|
||||
"""将文本转换为词频向量"""
|
||||
vectors = []
|
||||
for text in texts:
|
||||
words = tokenize(text)
|
||||
freq = [0] * self.max_seq_len
|
||||
for i, word in enumerate(words[:self.max_seq_len]):
|
||||
if word in self.vocab:
|
||||
freq[i] = 1 # 二值(出现=1,不出现=0)
|
||||
vectors.append(freq)
|
||||
return np.array(vectors, dtype=np.float32)
|
||||
|
||||
def fit_transform(self, texts):
|
||||
self.fit(texts)
|
||||
return self.transform(texts)
|
||||
|
||||
|
||||
class TFIDFVectorizer(BaseVectorizer):
|
||||
"""
|
||||
TF-IDF 向量器
|
||||
|
||||
原理:
|
||||
- TF(词频) = 词在本文本中出现的次数
|
||||
- IDF(逆文档频率) = log(总文档数 / 包含该词的文档数)
|
||||
- TF-IDF = TF × IDF
|
||||
|
||||
优势:
|
||||
- 降低常见无意义词的权重(如"的"、"是")
|
||||
- 提升罕见但有信息量的词
|
||||
"""
|
||||
|
||||
def __init__(self, max_features, max_seq_len):
|
||||
self.max_features = max_features
|
||||
self.max_seq_len = max_seq_len
|
||||
self.vocab = {}
|
||||
self.idf = {} # 存储每个词的IDF值
|
||||
self.num_docs = 0
|
||||
|
||||
def fit(self, texts):
|
||||
"""构建词表并计算IDF"""
|
||||
counter = Counter()
|
||||
doc_counter = Counter()
|
||||
|
||||
for text in texts:
|
||||
words = tokenize(text)
|
||||
unique_words = set(words)
|
||||
counter.update(words)
|
||||
for w in unique_words:
|
||||
doc_counter[w] += 1
|
||||
|
||||
self.num_docs = len(texts)
|
||||
|
||||
# 计算每个词的IDF
|
||||
# IDF = log(总文档数 / 包含该词的文档数)
|
||||
idf_values = {}
|
||||
for word, df in doc_counter.items():
|
||||
idf_values[word] = math.log(self.num_docs / (df + 1)) + 1 # 加1防零
|
||||
|
||||
# 取IDF值最高的词(信息量最大的词)
|
||||
sorted_words = sorted(idf_values.items(), key=lambda x: x[1], reverse=True)
|
||||
self.vocab = {word: idx for idx, (word, _) in enumerate(sorted_words[:self.max_features])}
|
||||
|
||||
# 保存IDF值
|
||||
self.idf = {word: idf_values[word] for word in self.vocab}
|
||||
|
||||
print(f" TF-IDF词表大小: {len(self.vocab)}")
|
||||
print(f" 平均IDF: {np.mean(list(self.idf.values())):.3f}")
|
||||
return self
|
||||
|
||||
def transform(self, texts):
|
||||
"""将文本转换为TF-IDF向量"""
|
||||
vectors = []
|
||||
for text in texts:
|
||||
words = tokenize(text)
|
||||
|
||||
# 计算TF
|
||||
tf = Counter(words)
|
||||
tf_sum = len(words) if words else 1
|
||||
|
||||
# 生成向量
|
||||
vec = [0.0] * self.max_seq_len
|
||||
for i, word in enumerate(words[:self.max_seq_len]):
|
||||
if word in self.vocab:
|
||||
# TF × IDF
|
||||
vec[i] = (tf[word] / tf_sum) * self.idf.get(word, 0)
|
||||
vectors.append(vec)
|
||||
|
||||
return np.array(vectors, dtype=np.float32)
|
||||
|
||||
def fit_transform(self, texts):
|
||||
self.fit(texts)
|
||||
return self.transform(texts)
|
||||
|
||||
|
||||
def load_data(data_dir, max_features, max_seq_len, vectorizer_type='tfidf'):
|
||||
"""
|
||||
加载并向量化数据
|
||||
|
||||
参数:
|
||||
- vectorizer_type: 'tfidf' 或 'bow'
|
||||
"""
|
||||
if not download_dataset(data_dir):
|
||||
raise RuntimeError("数据加载失败,请检查网络或手动下载数据集")
|
||||
|
||||
print("正在加载数据...")
|
||||
texts, labels = load_raw_data(data_dir)
|
||||
print(f"总评论数: {len(texts)}, 正面: {sum(labels)}, 负面: {len(labels) - sum(labels)}")
|
||||
|
||||
# 选择向量化器
|
||||
if vectorizer_type == 'tfidf':
|
||||
vectorizer = TFIDFVectorizer(max_features, max_seq_len)
|
||||
vec_name = "TF-IDF"
|
||||
else:
|
||||
vectorizer = BoWVectorizer(max_features, max_seq_len)
|
||||
vec_name = "BoW"
|
||||
|
||||
print(f"正在使用{vec_name}向量化...")
|
||||
X = vectorizer.fit_transform(texts)
|
||||
y = labels
|
||||
|
||||
# 打乱并划分
|
||||
np.random.seed(42)
|
||||
indices = np.random.permutation(len(X))
|
||||
X = X[indices]
|
||||
y = y[indices]
|
||||
|
||||
split_idx = int(len(X) * 0.8)
|
||||
X_train, X_test = X[:split_idx], X[split_idx:]
|
||||
y_train, y_test = y[:split_idx], y[split_idx:]
|
||||
|
||||
print(f"训练集: {len(X_train)}条, 测试集: {len(X_test)}条")
|
||||
|
||||
return X_train, y_train, X_test, y_test, vectorizer
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 测试
|
||||
print("=" * 60)
|
||||
print("测试 TF-IDF 向量化")
|
||||
print("=" * 60)
|
||||
X_train, y_train, X_test, y_test, vec = load_data(
|
||||
'data/ChnSentiCorp', max_features=3000, max_seq_len=100,
|
||||
vectorizer_type='tfidf'
|
||||
)
|
||||
print(f"\nX_train shape: {X_train.shape}")
|
||||
print(f"X_train sample (前5个特征): {X_train[0][:5]}")
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
数据集模块 - MNIST手写数字数据集加载
|
||||
|
||||
优先从本地data/目录加载,如果文件不存在则从sklearn下载
|
||||
支持两种格式:.gz(官方格式)和 .zip(某些下载源)
|
||||
"""
|
||||
|
||||
import os
|
||||
import struct
|
||||
import gzip
|
||||
import zipfile
|
||||
import numpy as np
|
||||
from config import *
|
||||
|
||||
|
||||
def local_files_exist():
|
||||
"""检查本地数据文件是否存在且完整"""
|
||||
data_dir = os.path.join(os.path.dirname(__file__), 'data')
|
||||
|
||||
# 支持 .gz 和 .zip 格式(MNIST官方用.gz,但有些下载是zip)
|
||||
files = {
|
||||
'train-images-idx3-ubyte': {'gz': 9912422, 'zip': 9187390},
|
||||
'train-labels-idx1-ubyte': {'gz': 28881, 'zip': 28405},
|
||||
't10k-images-idx3-ubyte': {'gz': 1648877, 'zip': 1534055},
|
||||
't10k-labels-idx1-ubyte': {'gz': 5148, 'zip': 4563},
|
||||
}
|
||||
|
||||
found_files = {}
|
||||
missing = []
|
||||
|
||||
for base_name, sizes in files.items():
|
||||
gz_path = os.path.join(data_dir, base_name + '.gz')
|
||||
zip_path = os.path.join(data_dir, base_name + '.zip')
|
||||
|
||||
if os.path.exists(gz_path):
|
||||
found_files[base_name] = (gz_path, sizes['gz'], 'gz')
|
||||
elif os.path.exists(zip_path):
|
||||
found_files[base_name] = (zip_path, sizes['zip'], 'zip')
|
||||
else:
|
||||
missing.append(base_name)
|
||||
|
||||
if missing:
|
||||
return False, f"文件不存在: {', '.join(missing)}"
|
||||
|
||||
# 检查大小是否正确
|
||||
for base_name, (filepath, expected_size, fmt) in found_files.items():
|
||||
actual_size = os.path.getsize(filepath)
|
||||
if actual_size != expected_size:
|
||||
return False, f"文件大小错误: {base_name} (期望{expected_size}, 实际{actual_size})"
|
||||
|
||||
return True, "所有文件完整"
|
||||
|
||||
|
||||
def parse_idx_images(filepath):
|
||||
"""解析IDX格式图像(支持.gz和.zip)"""
|
||||
if filepath.endswith('.zip'):
|
||||
with zipfile.ZipFile(filepath, 'r') as zf:
|
||||
# zip内的文件名没有.gz后缀
|
||||
inner_name = zf.namelist()[0]
|
||||
with zf.open(inner_name) as f:
|
||||
magic, num, rows, cols = struct.unpack('>IIII', f.read(16))
|
||||
images = np.frombuffer(f.read(), dtype=np.uint8)
|
||||
images = images.reshape(num, rows * cols)
|
||||
return images
|
||||
else:
|
||||
with gzip.open(filepath, 'rb') as f:
|
||||
magic, num, rows, cols = struct.unpack('>IIII', f.read(16))
|
||||
images = np.frombuffer(f.read(), dtype=np.uint8)
|
||||
images = images.reshape(num, rows * cols)
|
||||
return images
|
||||
|
||||
|
||||
def parse_idx_labels(filepath):
|
||||
"""解析IDX格式标签(支持.gz和.zip)"""
|
||||
if filepath.endswith('.zip'):
|
||||
with zipfile.ZipFile(filepath, 'r') as zf:
|
||||
# zip内的文件名没有.gz后缀
|
||||
inner_name = zf.namelist()[0]
|
||||
with zf.open(inner_name) as f:
|
||||
magic, num = struct.unpack('>II', f.read(8))
|
||||
labels = np.frombuffer(f.read(), dtype=np.uint8)
|
||||
return labels
|
||||
else:
|
||||
with gzip.open(filepath, 'rb') as f:
|
||||
magic, num = struct.unpack('>II', f.read(8))
|
||||
labels = np.frombuffer(f.read(), dtype=np.uint8)
|
||||
return labels
|
||||
|
||||
|
||||
def load_data_from_local():
|
||||
"""从本地文件加载MNIST(自动检测.gz或.zip格式)"""
|
||||
data_dir = os.path.join(os.path.dirname(__file__), 'data')
|
||||
|
||||
def find_file(base_name):
|
||||
"""自动找文件,支持.gz和.zip"""
|
||||
gz_path = os.path.join(data_dir, base_name + '.gz')
|
||||
zip_path = os.path.join(data_dir, base_name + '.zip')
|
||||
if os.path.exists(gz_path):
|
||||
return gz_path
|
||||
elif os.path.exists(zip_path):
|
||||
return zip_path
|
||||
else:
|
||||
raise FileNotFoundError(f"找不到 {base_name} 的 .gz 或 .zip 文件")
|
||||
|
||||
X_train = parse_idx_images(find_file('train-images-idx3-ubyte'))
|
||||
y_train = parse_idx_labels(find_file('train-labels-idx1-ubyte'))
|
||||
X_test = parse_idx_images(find_file('t10k-images-idx3-ubyte'))
|
||||
y_test = parse_idx_labels(find_file('t10k-labels-idx1-ubyte'))
|
||||
|
||||
return X_train, y_train, X_test, y_test
|
||||
|
||||
|
||||
def load_data_from_sklearn():
|
||||
"""从sklearn加载MNIST(备选方案)"""
|
||||
from sklearn.datasets import fetch_openml
|
||||
|
||||
print(" 正在从OpenML下载数据(首次可能需要1-2分钟)...")
|
||||
|
||||
mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
|
||||
X = mnist.data.astype(np.float32)
|
||||
y = mnist.target.astype(int)
|
||||
|
||||
X_train = X[:60000] / 255.0
|
||||
X_test = X[60000:] / 255.0
|
||||
y_train = y[:60000]
|
||||
y_test = y[60000:]
|
||||
|
||||
return X_train, y_train, X_test, y_test
|
||||
|
||||
|
||||
def one_hot_encode(y, num_classes=10):
|
||||
one_hot = np.zeros((len(y), num_classes))
|
||||
one_hot[np.arange(len(y)), y] = 1
|
||||
return one_hot
|
||||
|
||||
|
||||
def load_data():
|
||||
"""
|
||||
加载MNIST数据集
|
||||
|
||||
优先从本地data/目录加载,如果文件不完整则从sklearn下载
|
||||
"""
|
||||
print("\n" + "=" * 50)
|
||||
print("MNIST 数据集加载")
|
||||
print("=" * 50)
|
||||
|
||||
# 优先检查本地文件
|
||||
exists, msg = local_files_exist()
|
||||
if exists:
|
||||
print(f"\n ✓ 发现本地数据文件: {msg}")
|
||||
X_train, y_train, X_test, y_test = load_data_from_local()
|
||||
else:
|
||||
print(f"\n 本地文件: {msg}")
|
||||
print(" 尝试从sklearn下载...")
|
||||
try:
|
||||
X_train, y_train, X_test, y_test = load_data_from_sklearn()
|
||||
except Exception as e:
|
||||
print(f"\n 下载失败: {e}")
|
||||
print("\n 请确保 data/ 目录下有完整的4个数据文件!")
|
||||
raise
|
||||
|
||||
# 归一化和One-Hot
|
||||
X_train = X_train.astype(np.float32) / 255.0
|
||||
X_test = X_test.astype(np.float32) / 255.0
|
||||
y_train = one_hot_encode(y_train, NUM_CLASSES)
|
||||
y_test = one_hot_encode(y_test, NUM_CLASSES)
|
||||
|
||||
print(f"\n ✓ 完成!")
|
||||
print(f" 训练集: {X_train.shape[0]} 样本")
|
||||
print(f" 测试集: {X_test.shape[0]} 样本")
|
||||
print(f" 数值范围: [{X_train.min():.2f}, {X_train.max():.2f}]")
|
||||
|
||||
return X_train, y_train, X_test, y_test
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
X_train, y_train, X_test, y_test = load_data()
|
||||
print(f"\n训练数据: {X_train.shape}")
|
||||
225
main.py
225
main.py
@@ -1,34 +1,191 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
主程序入口
|
||||
|
||||
使用方式:
|
||||
|
||||
1. 运行单个模型(默认):
|
||||
python main.py
|
||||
|
||||
修改 config.py 中的 MODEL_TYPE 和 VECTORIZER_TYPE 来切换配置
|
||||
|
||||
2. 运行对比实验:
|
||||
修改 config.py 中 RUN_COMPARISON = True
|
||||
|
||||
这会依次运行:
|
||||
- 实验1: BoW vs TF-IDF (固定LR模型)
|
||||
- 实验2: LR vs MLP (固定TF-IDF)
|
||||
- 实验3: 不同学习率对比
|
||||
- 实验4: 不同隐藏层大小对比
|
||||
|
||||
最后输出汇总报告
|
||||
"""
|
||||
|
||||
from train import main
|
||||
|
||||
if __name__ == '__main__':
|
||||
print("\n" + "=" * 70)
|
||||
print("文本分类实验 - 纯NumPy实现")
|
||||
print("数据集: ChnSentiCorp (中文酒店评论)")
|
||||
print("模型: Logistic Regression / MLP")
|
||||
print("向量化: BoW / TF-IDF")
|
||||
print("=" * 70 + "\n")
|
||||
|
||||
main()
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
主程序 - 手写数字识别 MLP 纯NumPy实现
|
||||
|
||||
使用方法:
|
||||
python main.py # 运行默认配置
|
||||
python main.py --compare # 运行对比实验
|
||||
|
||||
依赖:
|
||||
pip install numpy requests
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import time
|
||||
from datetime import datetime
|
||||
from model_numpy import MLP
|
||||
from dataset import load_data
|
||||
from config import *
|
||||
|
||||
|
||||
def train_and_evaluate():
|
||||
"""
|
||||
训练并评估模型
|
||||
"""
|
||||
print("=" * 60)
|
||||
print("手写数字识别 - 纯NumPy MLP实现")
|
||||
print("=" * 60)
|
||||
|
||||
# ===== 加载数据 =====
|
||||
try:
|
||||
X_train, y_train, X_test, y_test = load_data()
|
||||
except Exception as e:
|
||||
print(f"\n错误: {e}")
|
||||
print("\n请手动下载数据文件:")
|
||||
print(" 1. 创建 data/ 目录")
|
||||
print(" 2. 下载以下文件到 data/:")
|
||||
print(" - train-images-idx3-ubyte.gz (9.9 MB)")
|
||||
print(" - train-labels-idx1-ubyte.gz (28 KB)")
|
||||
print(" - t10k-images-idx3-ubyte.gz (1.6 MB)")
|
||||
print(" - t10k-labels-idx1-ubyte.gz (5 KB)")
|
||||
print(" 下载地址: https://storage.googleapis.com/tensorflow/tf-keras-datasets/")
|
||||
return None, None, None
|
||||
|
||||
# ===== 创建模型 =====
|
||||
print("\n[2] 创建MLP模型...")
|
||||
model = MLP(
|
||||
input_size=INPUT_SIZE,
|
||||
hidden_size=HIDDEN_SIZE,
|
||||
num_classes=NUM_CLASSES,
|
||||
learning_rate=LEARNING_RATE,
|
||||
seed=SEED
|
||||
)
|
||||
|
||||
# ===== 训练模型 =====
|
||||
print("\n[3] 开始训练...")
|
||||
start_time = time.time()
|
||||
|
||||
model.fit(
|
||||
X_train, y_train,
|
||||
X_val=X_test, y_val=y_test,
|
||||
epochs=NUM_EPOCHS,
|
||||
batch_size=BATCH_SIZE,
|
||||
verbose=True
|
||||
)
|
||||
|
||||
train_time = time.time() - start_time
|
||||
|
||||
# ===== 最终评估 =====
|
||||
print("\n" + "=" * 60)
|
||||
print("训练完成!")
|
||||
print("=" * 60)
|
||||
|
||||
train_acc = model.accuracy(X_train, y_train)
|
||||
test_acc = model.accuracy(X_test, y_test)
|
||||
|
||||
print(f"\n最终结果:")
|
||||
print(f" 训练准确率: {train_acc:.4f} ({train_acc*100:.2f}%)")
|
||||
print(f" 测试准确率: {test_acc:.4f} ({test_acc*100:.2f}%)")
|
||||
print(f" 训练时间: {train_time:.2f} 秒")
|
||||
|
||||
# ===== 保存模型 =====
|
||||
timestamp = datetime.now().strftime("%m%d_%H%M%S")
|
||||
model_path = f"mnist_mlp_{timestamp}"
|
||||
model.save(model_path)
|
||||
|
||||
# ===== 预测示例 =====
|
||||
print("\n[4] 预测示例:")
|
||||
indices = np.random.choice(len(X_test), 5, replace=False)
|
||||
|
||||
for i, idx in enumerate(indices):
|
||||
img = X_test[idx]
|
||||
true_label = np.argmax(y_test[idx])
|
||||
pred_label = model.predict(img.reshape(1, -1))[0]
|
||||
prob = model.predict_proba(img.reshape(1, -1))[0]
|
||||
|
||||
status = '✓' if true_label == pred_label else '✗'
|
||||
print(f" 样本{i+1}: 真实={true_label}, 预测={pred_label}, "
|
||||
f"置信度={prob[pred_label]:.2f} {status}")
|
||||
|
||||
return model, train_acc, test_acc
|
||||
|
||||
|
||||
def run_comparison():
|
||||
"""
|
||||
运行对比实验
|
||||
"""
|
||||
print("\n" + "=" * 60)
|
||||
print("超参数对比实验")
|
||||
print("=" * 60)
|
||||
|
||||
# 加载数据
|
||||
try:
|
||||
X_train, y_train, X_test, y_test = load_data()
|
||||
except Exception as e:
|
||||
print(f"加载数据失败: {e}")
|
||||
return
|
||||
|
||||
# 实验配置
|
||||
experiments = [
|
||||
{"hidden_size": 32, "lr": 0.1, "name": "小模型(32神经元)"},
|
||||
{"hidden_size": 128, "lr": 0.1, "name": "标准(128神经元)"},
|
||||
{"hidden_size": 256, "lr": 0.1, "name": "大模型(256神经元)"},
|
||||
{"hidden_size": 128, "lr": 0.01, "name": "小学习率(0.01)"},
|
||||
{"hidden_size": 128, "lr": 0.5, "name": "大学习率(0.5)"},
|
||||
]
|
||||
|
||||
results = []
|
||||
|
||||
for exp in experiments:
|
||||
print(f"\n实验: {exp['name']}")
|
||||
print("-" * 40)
|
||||
|
||||
model = MLP(
|
||||
input_size=INPUT_SIZE,
|
||||
hidden_size=exp['hidden_size'],
|
||||
num_classes=NUM_CLASSES,
|
||||
learning_rate=exp['lr'],
|
||||
seed=SEED
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
model.fit(X_train, y_train, epochs=30, batch_size=BATCH_SIZE, verbose=False)
|
||||
train_time = time.time() - start_time
|
||||
|
||||
train_acc = model.accuracy(X_train, y_train)
|
||||
test_acc = model.accuracy(X_test, y_test)
|
||||
|
||||
results.append({
|
||||
'name': exp['name'],
|
||||
'hidden_size': exp['hidden_size'],
|
||||
'lr': exp['lr'],
|
||||
'train_acc': train_acc,
|
||||
'test_acc': test_acc,
|
||||
'train_time': train_time
|
||||
})
|
||||
|
||||
print(f" 训练准确率: {train_acc:.4f} | 测试准确率: {test_acc:.4f} | 时间: {train_time:.1f}s")
|
||||
|
||||
# 汇总
|
||||
print("\n" + "=" * 60)
|
||||
print("实验结果汇总")
|
||||
print("=" * 60)
|
||||
print(f"\n{'配置':<25} {'训练准确率':<12} {'测试准确率':<12} {'时间':<8}")
|
||||
print("-" * 60)
|
||||
|
||||
for r in results:
|
||||
print(f"{r['name']:<25} {r['train_acc']:<12.4f} {r['test_acc']:<12.4f} {r['train_time']:<8.1f}s")
|
||||
|
||||
best = max(results, key=lambda x: x['test_acc'])
|
||||
print(f"\n最佳配置: {best['name']}, 测试准确率: {best['test_acc']:.4f}")
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数"""
|
||||
if RUN_COMPARISON:
|
||||
run_comparison()
|
||||
else:
|
||||
train_and_evaluate()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("程序结束!")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import sys
|
||||
|
||||
if '--compare' in sys.argv:
|
||||
RUN_COMPARISON = True
|
||||
|
||||
main()
|
||||
BIN
mnist_mlp_0518_191820_b1.npy
Normal file
BIN
mnist_mlp_0518_191820_b1.npy
Normal file
Binary file not shown.
BIN
mnist_mlp_0518_191820_b2.npy
Normal file
BIN
mnist_mlp_0518_191820_b2.npy
Normal file
Binary file not shown.
Reference in New Issue
Block a user