From 53779bd31eeae8217d619d4953722fa8418516f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9D=8E=E4=BD=B3=E8=B1=AA?= <2509165033@student.example.com> Date: Tue, 19 May 2026 11:31:06 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E6=96=87=E4=BB=B6=E8=87=B3?= =?UTF-8?q?=20/?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.py | 79 +++--- dataset.py | 465 ++++++++++++++--------------------- main.py | 225 ++++++++++++++--- mnist_mlp_0518_191820_b1.npy | Bin 0 -> 1152 bytes mnist_mlp_0518_191820_b2.npy | Bin 0 -> 208 bytes 5 files changed, 409 insertions(+), 360 deletions(-) create mode 100644 mnist_mlp_0518_191820_b1.npy create mode 100644 mnist_mlp_0518_191820_b2.npy diff --git a/config.py b/config.py index fe984df..9c59379 100644 --- a/config.py +++ b/config.py @@ -1,40 +1,39 @@ -# -*- coding: utf-8 -*- -""" -配置文件 - 所有超参数集中管理 - -设计思路: -将超参数分门别类,学生可以单独修改某一类而不会影响其他 -""" - -# ==================== 数据相关 ==================== -DATA_DIR = 'data/ChnSentiCorp' # 数据集路径 -MAX_FEATURES = 3000 # 词表最大容量 -MAX_SEQ_LEN = 100 # 句子最大长度(词数) -VECTORIZER_TYPE = 'tfidf' # 'tfidf' 或 'bow'(向量化方式) - -# ==================== 模型相关 ==================== -MODEL_TYPE = 'mlp' # 'mlp' 或 'lr'(模型类型) -HIDDEN_SIZE = 64 # MLP隐藏层大小(LR忽略) -NUM_CLASSES = 2 # 类别数(正面/负面二分类) -KEEP_PROB = 1.0 # Dropout保留概率(LR忽略,设为1即可) - -# ==================== 训练相关 ==================== -LEARNING_RATE = 0.06 # 学习率 -NUM_EPOCHS = 101 # 训练轮数 -BATCH_SIZE = 65 # 批次大小 - -# ==================== 类别权重(解决数据不平衡问题)==================== -USE_CLASS_WEIGHT = True # True=启用类别权重, False=不启用(对比用) -# 权重计算公式: n_samples / (n_classes * n_class_i) -# 正面评论多所以权重小,负面评论少所以权重大 -CLASS_WEIGHT_POS = 0.85 # 正面类权重(自动计算) -CLASS_WEIGHT_NEG = 1.75 # 负面类权重(自动计算) - -# ==================== 实验相关 ==================== -RUN_COMPARISON = False # True=运行对比实验, False=运行单个模型 -COMPARE_MODELS = ['lr', 'mlp'] # 要对比的模型列表 -COMPARE_VECTORS = ['bow', 'tfidf'] # 要对比的向量化方式 - -# ==================== 其他 ==================== -RANDOM_SEED = 42 # 随机种子(保证可复现) -VERBOSE = True # 打印详细日志 +# -*- coding: utf-8 -*- +""" +手写数字识别 - 超参数配置 + +纯NumPy实现的两层全连接神经网络 +""" + +# ===== 数据参数 ===== +ONE_HOT = True # 标签是否使用One-Hot编码 + +# ===== 模型结构 ===== +INPUT_SIZE = 784 # 28x28 = 784 像素 +HIDDEN_SIZE = 128 # 隐藏层神经元数量 +NUM_CLASSES = 10 # 0-9 十个数字 +KEEP_PROB = 1.0 # Dropout保留比例(1.0=不使用Dropout) + +# ===== 训练参数 ===== +LEARNING_RATE = 0.1 # 学习率 +NUM_EPOCHS = 50 # 训练轮数 +BATCH_SIZE = 64 # 批大小 + +# ===== 随机种子(保证可复现) ===== +SEED = 42 + +# ===== 实验配置 ===== +RUN_COMPARISON = False # 是否运行对比实验 + +# ===== 依赖说明 ===== +# 本项目需要以下库: +# numpy - 数值计算 +# scikit-learn - 加载MNIST数据集(会自动下载) +# pandas - sklearn的依赖 +# +# 安装命令: +# pip install numpy scikit-learn pandas +# +# 数据说明: +# 首次运行时会自动从OpenML下载MNIST数据集(约12MB) +# 下载后会自动缓存,后续运行直接使用缓存数据 \ No newline at end of file diff --git a/dataset.py b/dataset.py index 4f4163e..2d0867e 100644 --- a/dataset.py +++ b/dataset.py @@ -1,286 +1,179 @@ -# -*- coding: utf-8 -*- -""" -数据加载与向量化模块 - -支持两种向量化方法: -1. BoW (Bag of Words) - 词频向量 -2. TF-IDF - 词频-逆文档频率向量 - -TF-IDF 的优势: -- 降低常见词(如"的"、"是")的权重 -- 提升罕见词的信息量 -- 通常效果优于简单BoW -""" - -import os -import re -import csv -import math -import jieba -import numpy as np -from collections import Counter - -try: - import urllib.request - import ssl - DOWNLOAD_AVAILABLE = True -except ImportError: - DOWNLOAD_AVAILABLE = False - - -DATASET_URL = "https://raw.githubusercontent.com/SophonPlus/ChineseNlpCorpus/master/datasets/ChnSentiCorp_htl_all/ChnSentiCorp_htl_all.csv" - - -def download_dataset(data_dir): - """下载数据集(如果不存在)""" - csv_path = os.path.join(data_dir, 'ChnSentiCorp_htl_all.csv') - - if os.path.exists(csv_path): - print(f"数据已存在: {csv_path}") - return True - - if not DOWNLOAD_AVAILABLE: - return False - - print("正在下载数据集...") - ssl_context = ssl.create_default_context() - ssl_context.check_hostname = False - ssl_context.verify_mode = ssl.CERT_NONE - - try: - request = urllib.request.Request(DATASET_URL, headers={'User-Agent': 'Mozilla/5.0'}) - response = urllib.request.urlopen(request, timeout=120, context=ssl_context) - os.makedirs(data_dir, exist_ok=True) - with open(csv_path, 'wb') as f: - f.write(response.read()) - print(f"下载完成: {csv_path}") - return True - except Exception as e: - print(f"下载失败: {e}") - return False - - -def load_raw_data(data_dir): - """加载原始数据""" - csv_path = os.path.join(data_dir, 'ChnSentiCorp_htl_all.csv') - texts, labels = [], [] - - with open(csv_path, 'r', encoding='utf-8') as f: - reader = csv.reader(f) - for row in reader: - if len(row) < 2: - continue - try: - label = int(row[0]) - review = row[1].strip() - if review: - texts.append(review) - labels.append(label) - except (ValueError, IndexError): - continue - - return texts, np.array(labels) - - -def tokenize(text): - """中文分词""" - text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z]', ' ', text) - words = jieba.lcut(text) - return [w for w in words if len(w) > 1] - - -# ==================== 向量化器 ==================== - -class BaseVectorizer: - """向量化器基类""" - def fit(self, texts): pass - def transform(self, texts): pass - def fit_transform(self, texts): pass - - -class BoWVectorizer(BaseVectorizer): - """ - 词袋模型 (Bag of Words) - - 原理:统计每个词在文本中出现的次数 - 向量维度 = 词表大小 - 每个维度 = 该词在本文本中出现的次数 - """ - - def __init__(self, max_features, max_seq_len): - self.max_features = max_features - self.max_seq_len = max_seq_len - self.vocab = {} - self.doc_freq = {} # 文档频率 - self.num_docs = 0 - - def fit(self, texts): - """构建词表(基于词频)""" - counter = Counter() - doc_counter = Counter() # 统计包含该词的文档数 - - for text in texts: - words = tokenize(text) - unique_words = set(words) - counter.update(words) - for w in unique_words: - doc_counter[w] += 1 - - self.num_docs = len(texts) - - # 取最高频的词 - most_common = counter.most_common(self.max_features) - self.vocab = {word: idx for idx, (word, _) in enumerate(most_common)} - - # 记录文档频率(用于TF-IDF) - self.doc_freq = {w: doc_counter[w] for w in self.vocab} - - print(f" BoW词表大小: {len(self.vocab)}") - return self - - def transform(self, texts): - """将文本转换为词频向量""" - vectors = [] - for text in texts: - words = tokenize(text) - freq = [0] * self.max_seq_len - for i, word in enumerate(words[:self.max_seq_len]): - if word in self.vocab: - freq[i] = 1 # 二值(出现=1,不出现=0) - vectors.append(freq) - return np.array(vectors, dtype=np.float32) - - def fit_transform(self, texts): - self.fit(texts) - return self.transform(texts) - - -class TFIDFVectorizer(BaseVectorizer): - """ - TF-IDF 向量器 - - 原理: - - TF(词频) = 词在本文本中出现的次数 - - IDF(逆文档频率) = log(总文档数 / 包含该词的文档数) - - TF-IDF = TF × IDF - - 优势: - - 降低常见无意义词的权重(如"的"、"是") - - 提升罕见但有信息量的词 - """ - - def __init__(self, max_features, max_seq_len): - self.max_features = max_features - self.max_seq_len = max_seq_len - self.vocab = {} - self.idf = {} # 存储每个词的IDF值 - self.num_docs = 0 - - def fit(self, texts): - """构建词表并计算IDF""" - counter = Counter() - doc_counter = Counter() - - for text in texts: - words = tokenize(text) - unique_words = set(words) - counter.update(words) - for w in unique_words: - doc_counter[w] += 1 - - self.num_docs = len(texts) - - # 计算每个词的IDF - # IDF = log(总文档数 / 包含该词的文档数) - idf_values = {} - for word, df in doc_counter.items(): - idf_values[word] = math.log(self.num_docs / (df + 1)) + 1 # 加1防零 - - # 取IDF值最高的词(信息量最大的词) - sorted_words = sorted(idf_values.items(), key=lambda x: x[1], reverse=True) - self.vocab = {word: idx for idx, (word, _) in enumerate(sorted_words[:self.max_features])} - - # 保存IDF值 - self.idf = {word: idf_values[word] for word in self.vocab} - - print(f" TF-IDF词表大小: {len(self.vocab)}") - print(f" 平均IDF: {np.mean(list(self.idf.values())):.3f}") - return self - - def transform(self, texts): - """将文本转换为TF-IDF向量""" - vectors = [] - for text in texts: - words = tokenize(text) - - # 计算TF - tf = Counter(words) - tf_sum = len(words) if words else 1 - - # 生成向量 - vec = [0.0] * self.max_seq_len - for i, word in enumerate(words[:self.max_seq_len]): - if word in self.vocab: - # TF × IDF - vec[i] = (tf[word] / tf_sum) * self.idf.get(word, 0) - vectors.append(vec) - - return np.array(vectors, dtype=np.float32) - - def fit_transform(self, texts): - self.fit(texts) - return self.transform(texts) - - -def load_data(data_dir, max_features, max_seq_len, vectorizer_type='tfidf'): - """ - 加载并向量化数据 - - 参数: - - vectorizer_type: 'tfidf' 或 'bow' - """ - if not download_dataset(data_dir): - raise RuntimeError("数据加载失败,请检查网络或手动下载数据集") - - print("正在加载数据...") - texts, labels = load_raw_data(data_dir) - print(f"总评论数: {len(texts)}, 正面: {sum(labels)}, 负面: {len(labels) - sum(labels)}") - - # 选择向量化器 - if vectorizer_type == 'tfidf': - vectorizer = TFIDFVectorizer(max_features, max_seq_len) - vec_name = "TF-IDF" - else: - vectorizer = BoWVectorizer(max_features, max_seq_len) - vec_name = "BoW" - - print(f"正在使用{vec_name}向量化...") - X = vectorizer.fit_transform(texts) - y = labels - - # 打乱并划分 - np.random.seed(42) - indices = np.random.permutation(len(X)) - X = X[indices] - y = y[indices] - - split_idx = int(len(X) * 0.8) - X_train, X_test = X[:split_idx], X[split_idx:] - y_train, y_test = y[:split_idx], y[split_idx:] - - print(f"训练集: {len(X_train)}条, 测试集: {len(X_test)}条") - - return X_train, y_train, X_test, y_test, vectorizer - - -if __name__ == '__main__': - # 测试 - print("=" * 60) - print("测试 TF-IDF 向量化") - print("=" * 60) - X_train, y_train, X_test, y_test, vec = load_data( - 'data/ChnSentiCorp', max_features=3000, max_seq_len=100, - vectorizer_type='tfidf' - ) - print(f"\nX_train shape: {X_train.shape}") - print(f"X_train sample (前5个特征): {X_train[0][:5]}") +# -*- coding: utf-8 -*- +""" +数据集模块 - MNIST手写数字数据集加载 + +优先从本地data/目录加载,如果文件不存在则从sklearn下载 +支持两种格式:.gz(官方格式)和 .zip(某些下载源) +""" + +import os +import struct +import gzip +import zipfile +import numpy as np +from config import * + + +def local_files_exist(): + """检查本地数据文件是否存在且完整""" + data_dir = os.path.join(os.path.dirname(__file__), 'data') + + # 支持 .gz 和 .zip 格式(MNIST官方用.gz,但有些下载是zip) + files = { + 'train-images-idx3-ubyte': {'gz': 9912422, 'zip': 9187390}, + 'train-labels-idx1-ubyte': {'gz': 28881, 'zip': 28405}, + 't10k-images-idx3-ubyte': {'gz': 1648877, 'zip': 1534055}, + 't10k-labels-idx1-ubyte': {'gz': 5148, 'zip': 4563}, + } + + found_files = {} + missing = [] + + for base_name, sizes in files.items(): + gz_path = os.path.join(data_dir, base_name + '.gz') + zip_path = os.path.join(data_dir, base_name + '.zip') + + if os.path.exists(gz_path): + found_files[base_name] = (gz_path, sizes['gz'], 'gz') + elif os.path.exists(zip_path): + found_files[base_name] = (zip_path, sizes['zip'], 'zip') + else: + missing.append(base_name) + + if missing: + return False, f"文件不存在: {', '.join(missing)}" + + # 检查大小是否正确 + for base_name, (filepath, expected_size, fmt) in found_files.items(): + actual_size = os.path.getsize(filepath) + if actual_size != expected_size: + return False, f"文件大小错误: {base_name} (期望{expected_size}, 实际{actual_size})" + + return True, "所有文件完整" + + +def parse_idx_images(filepath): + """解析IDX格式图像(支持.gz和.zip)""" + if filepath.endswith('.zip'): + with zipfile.ZipFile(filepath, 'r') as zf: + # zip内的文件名没有.gz后缀 + inner_name = zf.namelist()[0] + with zf.open(inner_name) as f: + magic, num, rows, cols = struct.unpack('>IIII', f.read(16)) + images = np.frombuffer(f.read(), dtype=np.uint8) + images = images.reshape(num, rows * cols) + return images + else: + with gzip.open(filepath, 'rb') as f: + magic, num, rows, cols = struct.unpack('>IIII', f.read(16)) + images = np.frombuffer(f.read(), dtype=np.uint8) + images = images.reshape(num, rows * cols) + return images + + +def parse_idx_labels(filepath): + """解析IDX格式标签(支持.gz和.zip)""" + if filepath.endswith('.zip'): + with zipfile.ZipFile(filepath, 'r') as zf: + # zip内的文件名没有.gz后缀 + inner_name = zf.namelist()[0] + with zf.open(inner_name) as f: + magic, num = struct.unpack('>II', f.read(8)) + labels = np.frombuffer(f.read(), dtype=np.uint8) + return labels + else: + with gzip.open(filepath, 'rb') as f: + magic, num = struct.unpack('>II', f.read(8)) + labels = np.frombuffer(f.read(), dtype=np.uint8) + return labels + + +def load_data_from_local(): + """从本地文件加载MNIST(自动检测.gz或.zip格式)""" + data_dir = os.path.join(os.path.dirname(__file__), 'data') + + def find_file(base_name): + """自动找文件,支持.gz和.zip""" + gz_path = os.path.join(data_dir, base_name + '.gz') + zip_path = os.path.join(data_dir, base_name + '.zip') + if os.path.exists(gz_path): + return gz_path + elif os.path.exists(zip_path): + return zip_path + else: + raise FileNotFoundError(f"找不到 {base_name} 的 .gz 或 .zip 文件") + + X_train = parse_idx_images(find_file('train-images-idx3-ubyte')) + y_train = parse_idx_labels(find_file('train-labels-idx1-ubyte')) + X_test = parse_idx_images(find_file('t10k-images-idx3-ubyte')) + y_test = parse_idx_labels(find_file('t10k-labels-idx1-ubyte')) + + return X_train, y_train, X_test, y_test + + +def load_data_from_sklearn(): + """从sklearn加载MNIST(备选方案)""" + from sklearn.datasets import fetch_openml + + print(" 正在从OpenML下载数据(首次可能需要1-2分钟)...") + + mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto') + X = mnist.data.astype(np.float32) + y = mnist.target.astype(int) + + X_train = X[:60000] / 255.0 + X_test = X[60000:] / 255.0 + y_train = y[:60000] + y_test = y[60000:] + + return X_train, y_train, X_test, y_test + + +def one_hot_encode(y, num_classes=10): + one_hot = np.zeros((len(y), num_classes)) + one_hot[np.arange(len(y)), y] = 1 + return one_hot + + +def load_data(): + """ + 加载MNIST数据集 + + 优先从本地data/目录加载,如果文件不完整则从sklearn下载 + """ + print("\n" + "=" * 50) + print("MNIST 数据集加载") + print("=" * 50) + + # 优先检查本地文件 + exists, msg = local_files_exist() + if exists: + print(f"\n ✓ 发现本地数据文件: {msg}") + X_train, y_train, X_test, y_test = load_data_from_local() + else: + print(f"\n 本地文件: {msg}") + print(" 尝试从sklearn下载...") + try: + X_train, y_train, X_test, y_test = load_data_from_sklearn() + except Exception as e: + print(f"\n 下载失败: {e}") + print("\n 请确保 data/ 目录下有完整的4个数据文件!") + raise + + # 归一化和One-Hot + X_train = X_train.astype(np.float32) / 255.0 + X_test = X_test.astype(np.float32) / 255.0 + y_train = one_hot_encode(y_train, NUM_CLASSES) + y_test = one_hot_encode(y_test, NUM_CLASSES) + + print(f"\n ✓ 完成!") + print(f" 训练集: {X_train.shape[0]} 样本") + print(f" 测试集: {X_test.shape[0]} 样本") + print(f" 数值范围: [{X_train.min():.2f}, {X_train.max():.2f}]") + + return X_train, y_train, X_test, y_test + + +if __name__ == '__main__': + X_train, y_train, X_test, y_test = load_data() + print(f"\n训练数据: {X_train.shape}") \ No newline at end of file diff --git a/main.py b/main.py index 2dbbe1c..9d0fecf 100644 --- a/main.py +++ b/main.py @@ -1,34 +1,191 @@ -# -*- coding: utf-8 -*- -""" -主程序入口 - -使用方式: - -1. 运行单个模型(默认): - python main.py - - 修改 config.py 中的 MODEL_TYPE 和 VECTORIZER_TYPE 来切换配置 - -2. 运行对比实验: - 修改 config.py 中 RUN_COMPARISON = True - - 这会依次运行: - - 实验1: BoW vs TF-IDF (固定LR模型) - - 实验2: LR vs MLP (固定TF-IDF) - - 实验3: 不同学习率对比 - - 实验4: 不同隐藏层大小对比 - - 最后输出汇总报告 -""" - -from train import main - -if __name__ == '__main__': - print("\n" + "=" * 70) - print("文本分类实验 - 纯NumPy实现") - print("数据集: ChnSentiCorp (中文酒店评论)") - print("模型: Logistic Regression / MLP") - print("向量化: BoW / TF-IDF") - print("=" * 70 + "\n") - - main() +# -*- coding: utf-8 -*- +""" +主程序 - 手写数字识别 MLP 纯NumPy实现 + +使用方法: + python main.py # 运行默认配置 + python main.py --compare # 运行对比实验 + +依赖: + pip install numpy requests +""" + +import numpy as np +import time +from datetime import datetime +from model_numpy import MLP +from dataset import load_data +from config import * + + +def train_and_evaluate(): + """ + 训练并评估模型 + """ + print("=" * 60) + print("手写数字识别 - 纯NumPy MLP实现") + print("=" * 60) + + # ===== 加载数据 ===== + try: + X_train, y_train, X_test, y_test = load_data() + except Exception as e: + print(f"\n错误: {e}") + print("\n请手动下载数据文件:") + print(" 1. 创建 data/ 目录") + print(" 2. 下载以下文件到 data/:") + print(" - train-images-idx3-ubyte.gz (9.9 MB)") + print(" - train-labels-idx1-ubyte.gz (28 KB)") + print(" - t10k-images-idx3-ubyte.gz (1.6 MB)") + print(" - t10k-labels-idx1-ubyte.gz (5 KB)") + print(" 下载地址: https://storage.googleapis.com/tensorflow/tf-keras-datasets/") + return None, None, None + + # ===== 创建模型 ===== + print("\n[2] 创建MLP模型...") + model = MLP( + input_size=INPUT_SIZE, + hidden_size=HIDDEN_SIZE, + num_classes=NUM_CLASSES, + learning_rate=LEARNING_RATE, + seed=SEED + ) + + # ===== 训练模型 ===== + print("\n[3] 开始训练...") + start_time = time.time() + + model.fit( + X_train, y_train, + X_val=X_test, y_val=y_test, + epochs=NUM_EPOCHS, + batch_size=BATCH_SIZE, + verbose=True + ) + + train_time = time.time() - start_time + + # ===== 最终评估 ===== + print("\n" + "=" * 60) + print("训练完成!") + print("=" * 60) + + train_acc = model.accuracy(X_train, y_train) + test_acc = model.accuracy(X_test, y_test) + + print(f"\n最终结果:") + print(f" 训练准确率: {train_acc:.4f} ({train_acc*100:.2f}%)") + print(f" 测试准确率: {test_acc:.4f} ({test_acc*100:.2f}%)") + print(f" 训练时间: {train_time:.2f} 秒") + + # ===== 保存模型 ===== + timestamp = datetime.now().strftime("%m%d_%H%M%S") + model_path = f"mnist_mlp_{timestamp}" + model.save(model_path) + + # ===== 预测示例 ===== + print("\n[4] 预测示例:") + indices = np.random.choice(len(X_test), 5, replace=False) + + for i, idx in enumerate(indices): + img = X_test[idx] + true_label = np.argmax(y_test[idx]) + pred_label = model.predict(img.reshape(1, -1))[0] + prob = model.predict_proba(img.reshape(1, -1))[0] + + status = '✓' if true_label == pred_label else '✗' + print(f" 样本{i+1}: 真实={true_label}, 预测={pred_label}, " + f"置信度={prob[pred_label]:.2f} {status}") + + return model, train_acc, test_acc + + +def run_comparison(): + """ + 运行对比实验 + """ + print("\n" + "=" * 60) + print("超参数对比实验") + print("=" * 60) + + # 加载数据 + try: + X_train, y_train, X_test, y_test = load_data() + except Exception as e: + print(f"加载数据失败: {e}") + return + + # 实验配置 + experiments = [ + {"hidden_size": 32, "lr": 0.1, "name": "小模型(32神经元)"}, + {"hidden_size": 128, "lr": 0.1, "name": "标准(128神经元)"}, + {"hidden_size": 256, "lr": 0.1, "name": "大模型(256神经元)"}, + {"hidden_size": 128, "lr": 0.01, "name": "小学习率(0.01)"}, + {"hidden_size": 128, "lr": 0.5, "name": "大学习率(0.5)"}, + ] + + results = [] + + for exp in experiments: + print(f"\n实验: {exp['name']}") + print("-" * 40) + + model = MLP( + input_size=INPUT_SIZE, + hidden_size=exp['hidden_size'], + num_classes=NUM_CLASSES, + learning_rate=exp['lr'], + seed=SEED + ) + + start_time = time.time() + model.fit(X_train, y_train, epochs=30, batch_size=BATCH_SIZE, verbose=False) + train_time = time.time() - start_time + + train_acc = model.accuracy(X_train, y_train) + test_acc = model.accuracy(X_test, y_test) + + results.append({ + 'name': exp['name'], + 'hidden_size': exp['hidden_size'], + 'lr': exp['lr'], + 'train_acc': train_acc, + 'test_acc': test_acc, + 'train_time': train_time + }) + + print(f" 训练准确率: {train_acc:.4f} | 测试准确率: {test_acc:.4f} | 时间: {train_time:.1f}s") + + # 汇总 + print("\n" + "=" * 60) + print("实验结果汇总") + print("=" * 60) + print(f"\n{'配置':<25} {'训练准确率':<12} {'测试准确率':<12} {'时间':<8}") + print("-" * 60) + + for r in results: + print(f"{r['name']:<25} {r['train_acc']:<12.4f} {r['test_acc']:<12.4f} {r['train_time']:<8.1f}s") + + best = max(results, key=lambda x: x['test_acc']) + print(f"\n最佳配置: {best['name']}, 测试准确率: {best['test_acc']:.4f}") + + +def main(): + """主函数""" + if RUN_COMPARISON: + run_comparison() + else: + train_and_evaluate() + + print("\n" + "=" * 60) + print("程序结束!") + print("=" * 60) + + +if __name__ == '__main__': + import sys + + if '--compare' in sys.argv: + RUN_COMPARISON = True + + main() \ No newline at end of file diff --git a/mnist_mlp_0518_191820_b1.npy b/mnist_mlp_0518_191820_b1.npy new file mode 100644 index 0000000000000000000000000000000000000000..55bb987f7335c0c929de7cead799de886fd9a9e1 GIT binary patch literal 1152 zcmbWr|3B0R9KdlUZN5CFEt78*vs}sA>Z+CEO$-r9;^>RS`BII1PSVPQlydsw%V|z^ zx^J5A&Ydr(E$=(d@i=$MG3HCY#u%B@lt!)og}r`zJsDnyJiP;!XvAq;AcazCAq2_R zoU}V^2g$~q6c$ay62VbH(L^Zqqq_v3rcpmQEh6{~^~0~Y1{?DP^Z%XBhrHMgUY}0%u@R@=wBUK z*h#j*N(J6nsI@A+AO%rD@%V&yKE&+d`Q2hEVZ_7I zF?IfDWX|lfi+0b$b3W{Q25$=?!1rYNrd5y7mXHUw4=9mV#;NPhy9WK5((aobzd^>u zsT*2-Y}oIzuN0+iV3HdWa@siwERuUvgzz>x{?+GV)L8+|+o}vpa%z#EaQO4S6(_Jc zte(z3D}*;;&+QNU#zVni?D9xj7UUlhFC-4|fJRQV;Kp(w>8ZeTXm1T((_lR&J|^wqd93GiQPZC)|kf_MD>e$AL}gdE@bjhXIJxYudXPcs&Q zGP^+CGM9@Uz06+8WCsxH7rP=$YT<%LQhcwf4!6v2xR|Lfgm{OQ)5e*#aOx;I*fhKu zzNtthzgKplFPmKyF)YJfJBD{%*(8I~VXZiSdoea_)%@OyBS2lfS+i5K7jaMhuk%s2 zd63QyoMf%(Kob?mf_%LWPAralr+M&UO&aT*!o3PrtFVdW&PCJ8^1$oDXjo<~5WYYj z&^KB#zpO5R7)34Q5ZOMPv;k0ImQ+U6>v0G#7)GhT6o z5c@^fXaZFQvKBUBqPUBEnjhPO*#PvwKytf2e@P=aQ1zY8q%=T=*|>^#+bna=~&;Z`A~hYwFe^cLRBQ^3$kto2H#1|#+N6Td_- zz*NyF_OWZijNFweOQ}ENMChKUk7rd7dUx3FS`K?CsTB#fY#;Cq2 W%bGwOlQ3c8SO-`Ahs%;XxbPo%Bsa4F literal 0 HcmV?d00001 diff --git a/mnist_mlp_0518_191820_b2.npy b/mnist_mlp_0518_191820_b2.npy new file mode 100644 index 0000000000000000000000000000000000000000..6a22c8bff26d2d0b28517a984ea153cc6cebe90f GIT binary patch literal 208 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1ZlV+i=qoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-20EHL3bhL41Fk)DLai%KKG?4_WvQ;;r(^r4OkkOOUf|IFXR}@?X}vsP z-+1D3`rC|6`I%jzty0Bk-cgd>2jcfOZPk3`*;p*=F E01i`0P5=M^ literal 0 HcmV?d00001