# -*- coding: utf-8 -*-
"""
Dataset module - loading the MNIST handwritten-digit dataset.

The data is read from the local ``data/`` directory when all four files are
present with the expected sizes; otherwise it is downloaded through
scikit-learn (OpenML).  Two local archive formats are supported: ``.gz``
(the official format) and ``.zip`` (used by some download mirrors).
"""

import os
import struct
import gzip
import zipfile
import numpy as np
from config import *


# Expected on-disk size in bytes of every MNIST archive, per archive format.
# MNIST is officially distributed as .gz, but some mirrors ship .zip files.
_EXPECTED_SIZES = {
    'train-images-idx3-ubyte': {'gz': 9912422, 'zip': 9187390},
    'train-labels-idx1-ubyte': {'gz': 28881, 'zip': 28405},
    't10k-images-idx3-ubyte': {'gz': 1648877, 'zip': 1534055},
    't10k-labels-idx1-ubyte': {'gz': 5148, 'zip': 4563},
}

# IDX magic numbers: 0x00000803 (unsigned byte, 3 dimensions) for image files
# and 0x00000801 (unsigned byte, 1 dimension) for label files.
_IDX_IMAGE_MAGIC = 2051
_IDX_LABEL_MAGIC = 2049


def _data_dir():
    """Return the absolute path of the ``data/`` directory next to this file."""
    return os.path.join(os.path.dirname(os.path.abspath(__file__)), 'data')


def _find_data_file(data_dir, base_name):
    """Locate the archive for *base_name*, preferring ``.gz`` over ``.zip``.

    Returns:
        A ``(path, fmt)`` tuple where *fmt* is ``'gz'`` or ``'zip'``, or
        ``None`` when neither archive exists.
    """
    for fmt in ('gz', 'zip'):
        path = os.path.join(data_dir, f"{base_name}.{fmt}")
        if os.path.exists(path):
            return path, fmt
    return None


def _read_idx_bytes(filepath):
    """Return the fully decompressed content of a ``.gz`` or ``.zip`` IDX file.

    Raises:
        ValueError: if a zip archive contains no file entry.
    """
    if filepath.endswith('.zip'):
        with zipfile.ZipFile(filepath, 'r') as zf:
            # The member inside the zip carries no .gz suffix, so just take
            # the first real file entry (directory entries end with '/').
            members = [n for n in zf.namelist() if not n.endswith('/')]
            if not members:
                raise ValueError(f"zip archive contains no file: {filepath}")
            return zf.read(members[0])
    with gzip.open(filepath, 'rb') as f:
        return f.read()


def _parse_idx(filepath, header_fmt, expected_magic):
    """Read an IDX file and split it into header fields and a uint8 payload.

    Args:
        filepath: path of the ``.gz`` or ``.zip`` archive.
        header_fmt: ``struct`` format of the header (magic number first).
        expected_magic: magic number the file must start with.

    Returns:
        ``(dims, payload)`` - the header fields after the magic number and a
        read-only 1-D ``uint8`` array viewing the bytes after the header.

    Raises:
        ValueError: if the magic number does not match *expected_magic*.
        struct.error: if the file is shorter than the header.
    """
    raw = _read_idx_bytes(filepath)
    magic, *dims = struct.unpack_from(header_fmt, raw, 0)
    if magic != expected_magic:
        raise ValueError(
            f"unexpected IDX magic number {magic} in {filepath} "
            f"(expected {expected_magic})"
        )
    payload = np.frombuffer(
        raw, dtype=np.uint8, offset=struct.calcsize(header_fmt)
    )
    return dims, payload


def local_files_exist():
    """Check whether the four local data files exist and have the right size.

    Returns:
        ``(ok, message)`` - *ok* is True only when every file is found (as
        ``.gz`` or ``.zip``) and its size equals the expected size for that
        format; *message* describes the outcome.
    """
    data_dir = _data_dir()

    found = {}
    missing = []
    for base_name in _EXPECTED_SIZES:
        hit = _find_data_file(data_dir, base_name)
        if hit is None:
            missing.append(base_name)
        else:
            found[base_name] = hit

    if missing:
        return False, f"文件不存在: {', '.join(missing)}"

    # Verify the sizes; a mismatch usually means a truncated download.
    for base_name, (filepath, fmt) in found.items():
        expected_size = _EXPECTED_SIZES[base_name][fmt]
        actual_size = os.path.getsize(filepath)
        if actual_size != expected_size:
            return False, f"文件大小错误: {base_name} (期望{expected_size}, 实际{actual_size})"

    return True, "所有文件完整"


def parse_idx_images(filepath):
    """Parse an IDX image file (``.gz`` or ``.zip``).

    Returns:
        A read-only ``uint8`` array of shape ``(num, rows * cols)`` holding
        the raw pixel bytes, one flattened image per row.

    Raises:
        ValueError: if the magic number is wrong or the payload size does not
            match the dimensions declared in the header.
    """
    (num, rows, cols), pixels = _parse_idx(filepath, '>IIII', _IDX_IMAGE_MAGIC)
    return pixels.reshape(num, rows * cols)


def parse_idx_labels(filepath):
    """Parse an IDX label file (``.gz`` or ``.zip``).

    Returns:
        A read-only 1-D ``uint8`` array with one label per item.

    Raises:
        ValueError: if the magic number is wrong.
    """
    _dims, labels = _parse_idx(filepath, '>II', _IDX_LABEL_MAGIC)
    return labels


def load_data_from_local():
    """Load MNIST from the local files (auto-detecting ``.gz`` or ``.zip``).

    Returns:
        ``(X_train, y_train, X_test, y_test)`` as raw ``uint8`` arrays; the
        pixel values are NOT normalized here.

    Raises:
        FileNotFoundError: if any of the four files is missing.
    """
    data_dir = _data_dir()

    def find_file(base_name):
        """Return the archive path for *base_name* (.gz or .zip)."""
        hit = _find_data_file(data_dir, base_name)
        if hit is None:
            raise FileNotFoundError(f"找不到 {base_name} 的 .gz 或 .zip 文件")
        return hit[0]

    X_train = parse_idx_images(find_file('train-images-idx3-ubyte'))
    y_train = parse_idx_labels(find_file('train-labels-idx1-ubyte'))
    X_test = parse_idx_images(find_file('t10k-images-idx3-ubyte'))
    y_test = parse_idx_labels(find_file('t10k-labels-idx1-ubyte'))

    return X_train, y_train, X_test, y_test


def load_data_from_sklearn():
    """Load MNIST through scikit-learn / OpenML (fallback path).

    Returns:
        ``(X_train, y_train, X_test, y_test)`` where the images are
        ``float32`` and ALREADY divided by 255, and the labels are integers.
        The first 60000 samples form the training set, the rest the test set.
    """
    # Imported lazily so scikit-learn is only needed when the fallback runs.
    from sklearn.datasets import fetch_openml

    print(" 正在从OpenML下载数据(首次可能需要1-2分钟)...")
    mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
    X = mnist.data.astype(np.float32)
    y = mnist.target.astype(int)

    X_train = X[:60000] / 255.0
    X_test = X[60000:] / 255.0
    y_train = y[:60000]
    y_test = y[60000:]

    return X_train, y_train, X_test, y_test


def one_hot_encode(y, num_classes=10):
    """Convert integer class labels into a one-hot matrix.

    Args:
        y: sequence of integer labels, each in ``[0, num_classes)``.
        num_classes: number of columns of the result.

    Returns:
        An array of shape ``(len(y), num_classes)`` with a single 1 per row.
    """
    one_hot = np.zeros((len(y), num_classes))
    one_hot[np.arange(len(y)), y] = 1
    return one_hot


def load_data():
    """Load the MNIST dataset, normalized and one-hot encoded.

    The local ``data/`` directory is tried first; when its files are missing
    or have an unexpected size, the data is downloaded through scikit-learn.

    Returns:
        ``(X_train, y_train, X_test, y_test)`` - ``float32`` images scaled to
        ``[0, 1]`` and one-hot encoded labels.

    Raises:
        Exception: whatever the scikit-learn download raised, re-raised after
            printing a hint, when the fallback path fails.
    """
    print("\n" + "=" * 50)
    print("MNIST 数据集加载")
    print("=" * 50)

    # Prefer the local files.
    exists, msg = local_files_exist()

    if exists:
        print(f"\n ✓ 发现本地数据文件: {msg}")
        X_train, y_train, X_test, y_test = load_data_from_local()
        # Local IDX pixels are raw bytes in 0..255: scale them to [0, 1].
        X_train = X_train.astype(np.float32) / 255.0
        X_test = X_test.astype(np.float32) / 255.0
    else:
        print(f"\n 本地文件: {msg}")
        print(" 尝试从sklearn下载...")
        try:
            X_train, y_train, X_test, y_test = load_data_from_sklearn()
        except Exception as e:
            print(f"\n 下载失败: {e}")
            print("\n 请确保 data/ 目录下有完整的4个数据文件!")
            raise
        # load_data_from_sklearn() already divides by 255; dividing again
        # here would squash the data into [0, 1/255].  Only fix the dtype.
        X_train = X_train.astype(np.float32, copy=False)
        X_test = X_test.astype(np.float32, copy=False)

    # One-hot encode the labels.
    # NOTE(review): NUM_CLASSES is presumably provided by the star import
    # from config - confirm.
    y_train = one_hot_encode(y_train, NUM_CLASSES)
    y_test = one_hot_encode(y_test, NUM_CLASSES)

    print("\n ✓ 完成!")
    print(f" 训练集: {X_train.shape[0]} 样本")
    print(f" 测试集: {X_test.shape[0]} 样本")
    print(f" 数值范围: [{X_train.min():.2f}, {X_train.max():.2f}]")

    return X_train, y_train, X_test, y_test


if __name__ == '__main__':
    X_train, y_train, X_test, y_test = load_data()
    print(f"\n训练数据: {X_train.shape}")