Files
task-3-3-2-MLP/digit_mlp_class/dataset.py
2026-05-21 15:08:03 +08:00

179 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
数据集模块 - MNIST手写数字数据集加载
优先从本地data/目录加载如果文件不存在则从sklearn下载
支持两种格式:.gz官方格式和 .zip某些下载源
"""
import os
import struct
import gzip
import zipfile
import numpy as np
from config import *
def _read_idx_summary(filepath, header_len):
    """Return (uncompressed_size, leading header bytes) of a .gz or .zip IDX file.

    Only the first ``header_len`` bytes are decompressed, so this is cheap even
    for the 47 MB training-image file.

    - ``.zip``: the size is the first member's ``file_size`` from the archive
      directory. This is the same member the IDX parsers read (``namelist()[0]``).
    - Anything else is treated as gzip. The size is the ISIZE trailer (last
      4 bytes, little-endian, uncompressed length mod 2**32). That value is
      only reliable for single-member gzip files.
    """
    if filepath.endswith('.zip'):
        with zipfile.ZipFile(filepath, 'r') as zf:
            info = zf.infolist()[0]
            with zf.open(info) as f:
                header = f.read(header_len)
        return info.file_size, header
    with open(filepath, 'rb') as raw:
        raw.seek(-4, os.SEEK_END)
        (isize,) = struct.unpack('<I', raw.read(4))
    with gzip.open(filepath, 'rb') as f:
        header = f.read(header_len)
    return isize, header


def local_files_exist():
    """Check that the four MNIST IDX files exist in data/ and look complete.

    For every file, ``<name>.gz`` is preferred over ``<name>.zip``. A file
    passes when two things hold:

    - its uncompressed size equals ``header + count * item_bytes``;
    - its IDX header carries the expected magic number and item count.

    Content is validated rather than the compressed file size, because the
    compressed size depends on the compressor and on where the file was
    downloaded from.

    Returns:
        (ok, message): ``ok`` is True only when all four files pass.
        ``message`` is a console string describing the first problem found.
    """
    import zlib  # only needed to catch corrupt-stream errors below

    data_dir = os.path.join(os.path.dirname(__file__), 'data')
    # base name -> (IDX magic, item count, header bytes, bytes per item).
    # 2051 = 0x00000803 (unsigned-byte, 3-D: images);
    # 2049 = 0x00000801 (unsigned-byte, 1-D: labels).
    specs = {
        'train-images-idx3-ubyte': (2051, 60000, 16, 28 * 28),
        'train-labels-idx1-ubyte': (2049, 60000, 8, 1),
        't10k-images-idx3-ubyte': (2051, 10000, 16, 28 * 28),
        't10k-labels-idx1-ubyte': (2049, 10000, 8, 1),
    }
    found_files = {}
    missing = []
    for base_name in specs:
        gz_path = os.path.join(data_dir, base_name + '.gz')
        zip_path = os.path.join(data_dir, base_name + '.zip')
        if os.path.exists(gz_path):
            found_files[base_name] = gz_path
        elif os.path.exists(zip_path):
            found_files[base_name] = zip_path
        else:
            missing.append(base_name)
    if missing:
        return False, f"文件不存在: {', '.join(missing)}"

    for base_name, filepath in found_files.items():
        magic, count, header_len, item_size = specs[base_name]
        expected_size = header_len + count * item_size
        try:
            actual_size, header = _read_idx_summary(filepath, header_len)
            header_magic, header_count = struct.unpack('>II', header[:8])
        except (OSError, EOFError, IndexError, NotImplementedError,
                struct.error, zipfile.BadZipFile, zlib.error) as exc:
            # Unreadable or truncated archive: report it instead of crashing,
            # so the caller can fall back to another data source.
            return False, f"文件损坏: {base_name} ({exc})"
        if actual_size != expected_size:
            return False, f"文件大小错误: {base_name} (期望{expected_size}, 实际{actual_size})"
        if (header_magic, header_count) != (magic, count):
            return False, f"文件头错误: {base_name} (magic={header_magic}, count={header_count})"
    return True, "所有文件完整"
def parse_idx_images(filepath):
    """Parse an IDX3 unsigned-byte image file stored as .gz or .zip.

    Args:
        filepath: path ending in ``.zip`` (its first member is read). Any
            other path is opened as a gzip-compressed IDX file.

    Returns:
        Read-only uint8 array of shape (num, rows * cols): one flattened image
        per row, with num/rows/cols taken from the file header.

    Raises:
        ValueError: if the file is shorter than the 16-byte header, the magic
            number is not 2051 (3-D unsigned-byte IDX), or the pixel payload
            length does not equal num * rows * cols.
    """
    if filepath.endswith('.zip'):
        with zipfile.ZipFile(filepath, 'r') as zf:
            # The member inside the zip has no .gz suffix; read the first one.
            inner_name = zf.namelist()[0]
            with zf.open(inner_name) as f:
                raw = f.read()
    else:
        with gzip.open(filepath, 'rb') as f:
            raw = f.read()

    if len(raw) < 16:
        raise ValueError(f"{filepath}: 文件过短, 不是有效的IDX图像文件")
    magic, num, rows, cols = struct.unpack_from('>IIII', raw)
    if magic != 2051:
        raise ValueError(f"{filepath}: IDX magic 错误 (期望 2051, 实际 {magic})")
    payload_len = len(raw) - 16
    if payload_len != num * rows * cols:
        raise ValueError(
            f"{filepath}: 像素数量不符 (头部 {num}x{rows}x{cols}, 实际 {payload_len} 字节)"
        )
    # offset=16 skips the header without copying the pixel bytes.
    return np.frombuffer(raw, dtype=np.uint8, offset=16).reshape(num, rows * cols)
def parse_idx_labels(filepath):
    """Parse an IDX1 unsigned-byte label file stored as .gz or .zip.

    Args:
        filepath: path ending in ``.zip`` (its first member is read). Any
            other path is opened as a gzip-compressed IDX file.

    Returns:
        Read-only 1-D uint8 array with one label per sample.

    Raises:
        ValueError: if the file is shorter than the 8-byte header, the magic
            number is not 2049 (1-D unsigned-byte IDX), or the number of label
            bytes differs from the count stored in the header.
    """
    if filepath.endswith('.zip'):
        with zipfile.ZipFile(filepath, 'r') as zf:
            # The member inside the zip has no .gz suffix; read the first one.
            inner_name = zf.namelist()[0]
            with zf.open(inner_name) as f:
                raw = f.read()
    else:
        with gzip.open(filepath, 'rb') as f:
            raw = f.read()

    if len(raw) < 8:
        raise ValueError(f"{filepath}: 文件过短, 不是有效的IDX标签文件")
    magic, num = struct.unpack_from('>II', raw)
    if magic != 2049:
        raise ValueError(f"{filepath}: IDX magic 错误 (期望 2049, 实际 {magic})")
    payload_len = len(raw) - 8
    if payload_len != num:
        # A silent mismatch here would misalign labels with their images.
        raise ValueError(f"{filepath}: 标签数量不符 (头部 {num}, 实际 {payload_len})")
    return np.frombuffer(raw, dtype=np.uint8, offset=8)
def load_data_from_local():
    """Read the four MNIST IDX files from the data/ directory next to this module.

    Each file may be present as ``<name>.gz`` or ``<name>.zip``. When both
    exist, the ``.gz`` file is used.

    Returns:
        (X_train, y_train, X_test, y_test) as produced by parse_idx_images /
        parse_idx_labels.

    Raises:
        FileNotFoundError: if a file exists in neither format.
    """
    data_dir = os.path.join(os.path.dirname(__file__), 'data')

    def find_file(base_name):
        """Return the path of base_name.gz, else base_name.zip, else raise."""
        for suffix in ('.gz', '.zip'):
            candidate = os.path.join(data_dir, base_name + suffix)
            if os.path.exists(candidate):
                return candidate
        raise FileNotFoundError(f"找不到 {base_name} 的 .gz 或 .zip 文件")

    # Locate and parse one file at a time, train split first.
    train_images = parse_idx_images(find_file('train-images-idx3-ubyte'))
    train_labels = parse_idx_labels(find_file('train-labels-idx1-ubyte'))
    test_images = parse_idx_images(find_file('t10k-images-idx3-ubyte'))
    test_labels = parse_idx_labels(find_file('t10k-labels-idx1-ubyte'))
    return train_images, train_labels, test_images, test_labels
def load_data_from_sklearn():
    """Download MNIST from OpenML via scikit-learn (fallback data source).

    Pixels are returned unscaled (uint8, 0-255), the same scale and dtype the
    local IDX parser produces. Scaling to [0, 1] is left to the caller.
    Dividing by 255 here as well as in the caller squashed the fallback data
    into [0, 1/255].

    Returns:
        (X_train, y_train, X_test, y_test): the first 60000 OpenML rows are
        the training split and the remaining rows the test split. Images have
        shape (n, 784) uint8; labels are int arrays.
    """
    from sklearn.datasets import fetch_openml
    print(" 正在从OpenML下载数据首次可能需要1-2分钟...")
    mnist = fetch_openml('mnist_784', version=1, as_frame=False, parser='auto')
    # Keep the raw 0-255 pixel range; the caller is responsible for normalizing.
    X = mnist.data.astype(np.uint8)
    # OpenML delivers the targets as strings ('0'..'9').
    y = mnist.target.astype(int)
    n_train = 60000  # conventional MNIST train/test split point
    return X[:n_train], y[:n_train], X[n_train:], y[n_train:]
def one_hot_encode(y, num_classes=10):
    """Encode integer class labels as one-hot rows.

    Args:
        y: sequence or 1-D array of integer labels, one per sample.
        num_classes: number of output columns (default 10).

    Returns:
        float64 array of shape (len(y), num_classes). Row i holds 1.0 in
        column y[i] and 0.0 everywhere else.
    """
    n_samples = len(y)
    encoded = np.zeros((n_samples, num_classes))
    sample_idx = np.arange(n_samples)
    # Paired fancy indexing: (sample_idx[i], y[i]) selects one cell per row.
    encoded[sample_idx, y] = 1.0
    return encoded
def load_data():
    """Load MNIST and prepare it for training.

    Source selection: if local_files_exist() approves the files in data/,
    they are read with load_data_from_local(). Otherwise the dataset is
    fetched with load_data_from_sklearn(). If the fetch fails, a hint about
    the data/ directory is printed and the original exception is re-raised.

    Post-processing:
    - images are cast to float32 and divided by 255.0 (this assumes the
      chosen loader returns raw 0-255 pixel values);
    - labels are one-hot encoded with NUM_CLASSES columns (NUM_CLASSES comes
      from config).

    Returns:
        (X_train, y_train, X_test, y_test) after the steps above.
    """
    divider = "=" * 50
    print("\n" + divider)
    print("MNIST 数据集加载")
    print(divider)

    # Local files take priority; download only when they fail the check.
    usable, status = local_files_exist()
    if not usable:
        print(f"\n 本地文件: {status}")
        print(" 尝试从sklearn下载...")
        try:
            X_train, y_train, X_test, y_test = load_data_from_sklearn()
        except Exception as err:
            print(f"\n 下载失败: {err}")
            print("\n 请确保 data/ 目录下有完整的4个数据文件")
            raise
    else:
        print(f"\n ✓ 发现本地数据文件: {status}")
        X_train, y_train, X_test, y_test = load_data_from_local()

    # Scale pixels to floats and one-hot encode the labels.
    X_train, X_test = (split.astype(np.float32) / 255.0 for split in (X_train, X_test))
    y_train, y_test = (one_hot_encode(labels, NUM_CLASSES) for labels in (y_train, y_test))

    print("\n ✓ 完成!")
    print(f" 训练集: {X_train.shape[0]} 样本")
    print(f" 测试集: {X_test.shape[0]} 样本")
    print(f" 数值范围: [{X_train.min():.2f}, {X_train.max():.2f}]")
    return X_train, y_train, X_test, y_test
if __name__ == '__main__':
    # Manual smoke test: run the full loading pipeline (which prints its own
    # progress) and then show the shape of the training image matrix.
    X_train, y_train, X_test, y_test = load_data()
    print(f"\n训练数据: {X_train.shape}")