179 lines
6.1 KiB
Python
179 lines
6.1 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
数据集模块 - MNIST手写数字数据集加载
|
||
|
||
优先从本地data/目录加载,如果文件不存在则从sklearn下载
|
||
支持两种格式:.gz(官方格式)和 .zip(某些下载源)
|
||
"""
|
||
|
||
import os
|
||
import struct
|
||
import gzip
|
||
import zipfile
|
||
import numpy as np
|
||
from config import *
|
||
|
||
|
||
def local_files_exist():
    """Check whether the four MNIST data files exist locally with an accepted size.

    Looks in the ``data`` directory next to this module. For each dataset file
    ``<name>.gz`` is preferred over ``<name>.zip``.

    Returns:
        tuple[bool, str]: ``(True, "所有文件完整")`` when every file is present
        and its byte size is one of the accepted sizes for its format;
        otherwise ``(False, reason)`` naming the missing files or the first
        file whose size is not accepted.
    """
    data_dir = os.path.join(os.path.dirname(__file__), 'data')

    # Accepted byte sizes per container format (MNIST is officially shipped as
    # .gz; some download sources repackage it as .zip). The official
    # t10k-labels-idx1-ubyte.gz is 4542 bytes; 5148 is the value previously
    # hard-coded here and is still accepted so existing setups keep working.
    files = {
        'train-images-idx3-ubyte': {'gz': (9912422,), 'zip': (9187390,)},
        'train-labels-idx1-ubyte': {'gz': (28881,), 'zip': (28405,)},
        't10k-images-idx3-ubyte': {'gz': (1648877,), 'zip': (1534055,)},
        't10k-labels-idx1-ubyte': {'gz': (4542, 5148), 'zip': (4563,)},
    }

    found_files = {}
    missing = []
    for base_name, sizes in files.items():
        # Prefer .gz over .zip; the for/else records the file as missing
        # only when neither variant exists.
        for fmt in ('gz', 'zip'):
            path = os.path.join(data_dir, f"{base_name}.{fmt}")
            if os.path.exists(path):
                found_files[base_name] = (path, sizes[fmt])
                break
        else:
            missing.append(base_name)

    if missing:
        return False, f"文件不存在: {', '.join(missing)}"

    # Size check: catches truncated / partial downloads.
    for base_name, (filepath, accepted_sizes) in found_files.items():
        actual_size = os.path.getsize(filepath)
        if actual_size not in accepted_sizes:
            expected = '/'.join(str(size) for size in accepted_sizes)
            return False, f"文件大小错误: {base_name} (期望{expected}, 实际{actual_size})"

    return True, "所有文件完整"
|
||
|
||
|
||
def parse_idx_images(filepath):
    """Parse an IDX3 (unsigned-byte) image file compressed as ``.gz`` or ``.zip``.

    For ``.zip`` archives, the first member that is neither a directory entry
    nor ``__MACOSX/`` metadata is read (inside the zip the IDX file carries no
    ``.gz`` suffix).

    Args:
        filepath: Path to the compressed file. Paths ending in ``.zip`` are
            opened as zip archives; anything else is opened with gzip.

    Returns:
        np.ndarray: read-only ``uint8`` array of shape ``(num, rows * cols)``
        with the raw pixel values as stored in the file.

    Raises:
        ValueError: If the zip has no data member, the file is shorter than
            the 16-byte header, the magic number is not 2051, or the payload
            length does not match ``num * rows * cols``.
    """
    if filepath.endswith('.zip'):
        with zipfile.ZipFile(filepath, 'r') as zf:
            # Some archivers put directory entries or macOS resource-fork
            # metadata first; skip them so we read the actual IDX data.
            members = [
                name for name in zf.namelist()
                if not name.endswith('/') and not name.startswith('__MACOSX/')
            ]
            if not members:
                raise ValueError(f"{filepath}: zip archive contains no data file")
            with zf.open(members[0]) as f:
                raw = f.read()
    else:
        with gzip.open(filepath, 'rb') as f:
            raw = f.read()

    if len(raw) < 16:
        raise ValueError(f"{filepath}: file too short for an IDX image header")
    # Header: magic, image count, rows, cols — four big-endian uint32.
    magic, num, rows, cols = struct.unpack_from('>IIII', raw, 0)
    if magic != 2051:
        raise ValueError(
            f"{filepath}: bad IDX image magic number {magic} (expected 2051)")

    images = np.frombuffer(raw, dtype=np.uint8, offset=16)
    if images.size != num * rows * cols:
        raise ValueError(
            f"{filepath}: expected {num * rows * cols} pixel bytes, got {images.size}")
    return images.reshape(num, rows * cols)
|
||
|
||
|
||
def parse_idx_labels(filepath):
    """Parse an IDX1 (unsigned-byte) label file compressed as ``.gz`` or ``.zip``.

    For ``.zip`` archives, the first member that is neither a directory entry
    nor ``__MACOSX/`` metadata is read (inside the zip the IDX file carries no
    ``.gz`` suffix).

    Args:
        filepath: Path to the compressed file. Paths ending in ``.zip`` are
            opened as zip archives; anything else is opened with gzip.

    Returns:
        np.ndarray: read-only 1-D ``uint8`` array with one entry per label.

    Raises:
        ValueError: If the zip has no data member, the file is shorter than
            the 8-byte header, the magic number is not 2049, or the number of
            label bytes does not match the count in the header.
    """
    if filepath.endswith('.zip'):
        with zipfile.ZipFile(filepath, 'r') as zf:
            # Skip directory entries and macOS metadata that some archivers
            # place before the real data member.
            members = [
                name for name in zf.namelist()
                if not name.endswith('/') and not name.startswith('__MACOSX/')
            ]
            if not members:
                raise ValueError(f"{filepath}: zip archive contains no data file")
            with zf.open(members[0]) as f:
                raw = f.read()
    else:
        with gzip.open(filepath, 'rb') as f:
            raw = f.read()

    if len(raw) < 8:
        raise ValueError(f"{filepath}: file too short for an IDX label header")
    # Header: magic, label count — two big-endian uint32.
    magic, num = struct.unpack_from('>II', raw, 0)
    if magic != 2049:
        raise ValueError(
            f"{filepath}: bad IDX label magic number {magic} (expected 2049)")

    labels = np.frombuffer(raw, dtype=np.uint8, offset=8)
    if labels.size != num:
        raise ValueError(f"{filepath}: expected {num} labels, got {labels.size}")
    return labels
|
||
|
||
|
||
def load_data_from_local():
    """Load the MNIST train/test splits from the ``data`` directory next to this module.

    For each of the four dataset files the ``.gz`` variant is used when it
    exists, otherwise the ``.zip`` variant.

    Returns:
        tuple: ``(X_train, y_train, X_test, y_test)`` exactly as produced by
        ``parse_idx_images`` / ``parse_idx_labels``.

    Raises:
        FileNotFoundError: If neither the ``.gz`` nor the ``.zip`` file exists
            for one of the dataset files.
    """
    base_dir = os.path.join(os.path.dirname(__file__), 'data')

    def locate(stem):
        """Return the path of ``stem`` + ``.gz`` or ``.zip``, preferring ``.gz``."""
        for suffix in ('.gz', '.zip'):
            candidate = os.path.join(base_dir, stem + suffix)
            if os.path.exists(candidate):
                return candidate
        raise FileNotFoundError(f"找不到 {stem} 的 .gz 或 .zip 文件")

    train_images = parse_idx_images(locate('train-images-idx3-ubyte'))
    train_labels = parse_idx_labels(locate('train-labels-idx1-ubyte'))
    test_images = parse_idx_images(locate('t10k-images-idx3-ubyte'))
    test_labels = parse_idx_labels(locate('t10k-labels-idx1-ubyte'))

    return train_images, train_labels, test_images, test_labels
|
||
|
||
|
||
def load_data_from_sklearn():
    """Fetch ``mnist_784`` from OpenML through scikit-learn (fallback source).

    Pixels are cast to ``float32`` and divided by 255 here, so the returned
    images are already scaled. Targets are converted with ``astype(int)``.
    The first 60000 rows form the training split and the remaining rows form
    the test split.

    Returns:
        tuple: ``(X_train, y_train, X_test, y_test)``.
    """
    from sklearn.datasets import fetch_openml

    print(" 正在从OpenML下载数据(首次可能需要1-2分钟)...")

    bunch = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
    pixels = bunch.data.astype(np.float32)
    digits = bunch.target.astype(int)

    split = 60000
    train_pixels = pixels[:split] / 255.0
    test_pixels = pixels[split:] / 255.0

    return train_pixels, digits[:split], test_pixels, digits[split:]
|
||
|
||
|
||
def one_hot_encode(y, num_classes=10):
    """Return a ``(len(y), num_classes)`` float64 one-hot matrix for integer labels ``y``."""
    row_index = np.arange(len(y))
    encoded = np.zeros((row_index.size, num_classes))
    # Set a single 1 per row, in the column named by that row's label.
    encoded[row_index, y] = 1
    return encoded
|
||
|
||
|
||
def load_data():
    """
    Load MNIST, with images scaled to [0, 1] and labels one-hot encoded.

    Uses the local ``data/`` directory when ``local_files_exist()`` reports the
    files as complete; otherwise falls back to ``load_data_from_sklearn()``.

    Returns:
        tuple: ``(X_train, y_train, X_test, y_test)``. Image arrays are
        ``float32`` in [0, 1]; label arrays are one-hot encoded with
        ``NUM_CLASSES`` columns.

    Raises:
        Exception: Any error from the sklearn fallback is printed together
            with a hint and then re-raised unchanged.
    """
    print("\n" + "=" * 50)
    print("MNIST 数据集加载")
    print("=" * 50)

    # Prefer the local files.
    exists, msg = local_files_exist()
    if exists:
        print(f"\n ✓ 发现本地数据文件: {msg}")
        X_train, y_train, X_test, y_test = load_data_from_local()
    else:
        print(f"\n 本地文件: {msg}")
        print(" 尝试从sklearn下载...")
        try:
            X_train, y_train, X_test, y_test = load_data_from_sklearn()
        except Exception as e:
            print(f"\n 下载失败: {e}")
            print("\n 请确保 data/ 目录下有完整的4个数据文件!")
            raise

    def to_unit_range(images):
        """Cast to float32, dividing by 255 only for integer (0-255) pixel arrays."""
        images = np.asarray(images)
        if np.issubdtype(images.dtype, np.integer):
            return images.astype(np.float32) / 255.0
        # Float input (the sklearn path) is already scaled by its loader;
        # dividing again would shrink every pixel to at most ~0.0039.
        return images.astype(np.float32, copy=False)

    # Normalization and one-hot encoding.
    X_train = to_unit_range(X_train)
    X_test = to_unit_range(X_test)
    y_train = one_hot_encode(y_train, NUM_CLASSES)
    y_test = one_hot_encode(y_test, NUM_CLASSES)

    print("\n ✓ 完成!")
    print(f" 训练集: {X_train.shape[0]} 样本")
    print(f" 测试集: {X_test.shape[0]} 样本")
    print(f" 数值范围: [{X_train.min():.2f}, {X_train.max():.2f}]")

    return X_train, y_train, X_test, y_test
|
||
|
||
|
||
if __name__ == '__main__':
    # Smoke run: load everything and report the training-image array shape.
    train_images, _, _, _ = load_data()
    print(f"\n训练数据: {train_images.shape}")