Files
task-3-2-2-text-classification/dataset.py
2026-04-27 21:55:04 +08:00

287 lines
8.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
数据加载与向量化模块
支持两种向量化方法:
1. BoW (Bag of Words) - 词频向量
2. TF-IDF - 词频-逆文档频率向量
TF-IDF 的优势:
- 降低常见词(如 '的'、'是')的权重
- 提升罕见词的信息量
- 通常效果优于简单BoW
"""
import os
import re
import csv
import math
import jieba
import numpy as np
from collections import Counter
try:
import urllib.request
import ssl
DOWNLOAD_AVAILABLE = True
except ImportError:
DOWNLOAD_AVAILABLE = False
DATASET_URL = "https://raw.githubusercontent.com/SophonPlus/ChineseNlpCorpus/master/datasets/ChnSentiCorp_htl_all/ChnSentiCorp_htl_all.csv"
def download_dataset(data_dir):
    """Make sure the ChnSentiCorp hotel-review CSV is present in ``data_dir``.

    If ``ChnSentiCorp_htl_all.csv`` already exists it is left untouched.
    Otherwise the file is fetched from ``DATASET_URL`` and stored under
    ``data_dir`` (the directory is created when needed).

    Args:
        data_dir: Directory that holds (or will hold) the CSV file.

    Returns:
        True when the file already exists or was downloaded successfully;
        False when downloading is unavailable or the download failed. Errors
        are printed, never raised.
    """
    csv_path = os.path.join(data_dir, 'ChnSentiCorp_htl_all.csv')
    if os.path.exists(csv_path):
        print(f"数据已存在: {csv_path}")
        return True

    if not DOWNLOAD_AVAILABLE:
        return False

    print("正在下载数据集...")
    # NOTE(review): certificate and hostname verification are switched off,
    # so the downloaded bytes are not authenticated. Kept as-is to preserve
    # the existing behaviour; consider enabling verification.
    ssl_context = ssl.create_default_context()
    ssl_context.check_hostname = False
    ssl_context.verify_mode = ssl.CERT_NONE

    # Write to a sibling temp file first: a half-written CSV at the final
    # path would be mistaken for a complete dataset on the next call.
    partial_path = csv_path + '.part'
    try:
        request = urllib.request.Request(DATASET_URL, headers={'User-Agent': 'Mozilla/5.0'})
        with urllib.request.urlopen(request, timeout=120, context=ssl_context) as response:
            payload = response.read()
        os.makedirs(data_dir, exist_ok=True)
        with open(partial_path, 'wb') as f:
            f.write(payload)
        os.replace(partial_path, csv_path)
    except Exception as e:
        print(f"下载失败: {e}")
        # Best-effort cleanup of the temp file; it may not exist at all.
        try:
            os.remove(partial_path)
        except OSError:
            pass
        return False

    print(f"下载完成: {csv_path}")
    return True
def load_raw_data(data_dir):
    """Read ``ChnSentiCorp_htl_all.csv`` from ``data_dir``.

    Each usable row has an integer label in column 0 and a non-blank review
    in column 1. Rows with fewer than two columns, a non-integer label (for
    example a header line) or a blank review are skipped.

    Args:
        data_dir: Directory containing ``ChnSentiCorp_htl_all.csv``.

    Returns:
        A ``(texts, labels)`` tuple: ``texts`` is a list of stripped review
        strings and ``labels`` is a 1-D integer NumPy array of equal length.

    Raises:
        OSError: If the CSV file cannot be opened.
    """
    csv_path = os.path.join(data_dir, 'ChnSentiCorp_htl_all.csv')
    texts, labels = [], []
    # newline='' lets the csv module handle line breaks inside quoted fields
    # itself; utf-8-sig drops a leading BOM that would otherwise make the
    # first row's label unparsable.
    with open(csv_path, 'r', encoding='utf-8-sig', newline='') as f:
        for row in csv.reader(f):
            if len(row) < 2:
                continue
            try:
                label = int(row[0])
            except ValueError:
                # Header line or malformed label.
                continue
            review = row[1].strip()
            if review:
                texts.append(review)
                labels.append(label)
    # Explicit integer dtype so an empty dataset does not become float64.
    return texts, np.array(labels, dtype=int)
def tokenize(text):
    """Split ``text`` into jieba tokens longer than one character.

    Every character outside the range U+4E00-U+9FA5 and outside ASCII
    letters is replaced by a space before segmentation, so digits and
    punctuation never reach the tokenizer.

    Args:
        text: Raw review string.

    Returns:
        List of tokens from ``jieba.lcut`` whose length exceeds 1.
    """
    cleaned = re.compile(r'[^\u4e00-\u9fa5a-zA-Z]').sub(' ', text)
    kept = []
    for token in jieba.lcut(cleaned):
        # Single-character tokens (including the inserted spaces) are dropped.
        if len(token) <= 1:
            continue
        kept.append(token)
    return kept
# ==================== 向量化器 ====================
class BaseVectorizer:
    """Common base for the vectorizers.

    All three methods are placeholders: they ignore their argument and
    return ``None``. Subclasses are expected to override them.
    """

    def fit(self, texts):
        """Placeholder; does nothing and returns None."""
        return None

    def transform(self, texts):
        """Placeholder; does nothing and returns None."""
        return None

    def fit_transform(self, texts):
        """Placeholder; does nothing and returns None."""
        return None
class BoWVectorizer(BaseVectorizer):
    """Vocabulary-based binary vectorizer (named "bag of words").

    ``fit`` keeps the ``max_features`` most frequent tokens as the
    vocabulary. ``transform`` emits one row per text with ``max_seq_len``
    columns: column ``i`` is 1.0 when the i-th token of the text is in the
    vocabulary and 0.0 otherwise (texts are truncated to ``max_seq_len``
    tokens and shorter texts are zero-padded).

    NOTE(review): this is a positional in-vocabulary indicator, not a
    conventional bag-of-words count vector of width ``len(vocab)``. The
    layout is kept because changing it would alter the output shape that
    callers receive.
    """

    def __init__(self, max_features, max_seq_len):
        """Store the vocabulary size limit and the output row width."""
        self.max_features = max_features
        self.max_seq_len = max_seq_len
        self.vocab = {}      # token -> rank by corpus frequency
        self.doc_freq = {}   # token -> number of documents containing it
        self.num_docs = 0

    def fit(self, texts):
        """Build the vocabulary from the most frequent tokens in ``texts``.

        Args:
            texts: Iterable of raw strings.

        Returns:
            ``self``, to allow chaining.
        """
        texts = list(texts)  # accept any iterable; len() is needed below
        counter = Counter()
        doc_counter = Counter()  # number of documents containing each token
        for text in texts:
            words = tokenize(text)
            counter.update(words)
            doc_counter.update(set(words))
        self.num_docs = len(texts)

        # Keep the most frequent tokens.
        most_common = counter.most_common(self.max_features)
        self.vocab = {word: idx for idx, (word, _) in enumerate(most_common)}
        # Document frequencies are stored but not read anywhere in this class.
        self.doc_freq = {w: doc_counter[w] for w in self.vocab}
        print(f" BoW词表大小: {len(self.vocab)}")
        return self

    def transform(self, texts):
        """Encode ``texts`` as positional in-vocabulary indicators.

        Args:
            texts: Iterable of raw strings.

        Returns:
            float32 array of shape ``(len(texts), max_seq_len)``; this shape
            also holds for empty input.
        """
        texts = list(texts)
        # Pre-allocating keeps the result 2-D even when ``texts`` is empty.
        matrix = np.zeros((len(texts), self.max_seq_len), dtype=np.float32)
        for row, text in enumerate(texts):
            for pos, word in enumerate(tokenize(text)[:self.max_seq_len]):
                if word in self.vocab:
                    matrix[row, pos] = 1.0  # binary: present=1, absent=0
        return matrix

    def fit_transform(self, texts):
        """Fit on ``texts`` and return their encoded matrix."""
        texts = list(texts)  # materialise once so both passes see the data
        self.fit(texts)
        return self.transform(texts)
class TFIDFVectorizer(BaseVectorizer):
    """TF-IDF weighted positional vectorizer.

    As implemented here:
    - TF  = occurrences of the token in the text / number of tokens in it
    - IDF = log(num_docs / (doc_freq + 1)) + 1
    - weight = TF * IDF

    ``fit`` keeps the ``max_features`` tokens with the highest IDF, i.e. the
    ones that occur in the fewest documents. ``transform`` emits one row per
    text with ``max_seq_len`` columns: column ``i`` holds the weight of the
    i-th token when that token is in the vocabulary, else 0.0.

    NOTE(review): the output is indexed by token position, not by vocabulary
    index, and the vocabulary favours the rarest tokens. Both are kept
    because changing them would alter the values and shape callers receive.
    """

    def __init__(self, max_features, max_seq_len):
        """Store the vocabulary size limit and the output row width."""
        self.max_features = max_features
        self.max_seq_len = max_seq_len
        self.vocab = {}   # token -> rank by descending IDF
        self.idf = {}     # token -> IDF value, for vocabulary tokens only
        self.num_docs = 0

    def fit(self, texts):
        """Build the vocabulary and the IDF table from ``texts``.

        Args:
            texts: Iterable of raw strings.

        Returns:
            ``self``, to allow chaining.
        """
        texts = list(texts)  # accept any iterable; len() is needed below
        doc_counter = Counter()
        for text in texts:
            # dict.fromkeys de-duplicates while preserving first-occurrence
            # order. Iterating a set of strings would make the insertion
            # order, and therefore the tie-breaking of the stable sort
            # below, vary between interpreter runs (hash randomization).
            for w in dict.fromkeys(tokenize(text)):
                doc_counter[w] += 1
        self.num_docs = len(texts)

        # The +1 in the denominator and the trailing +1 smooth the value.
        idf_values = {
            word: math.log(self.num_docs / (df + 1)) + 1
            for word, df in doc_counter.items()
        }
        # Keep the tokens with the highest IDF; ties keep first-seen order.
        sorted_words = sorted(idf_values.items(), key=lambda x: x[1], reverse=True)
        self.vocab = {word: idx for idx, (word, _) in enumerate(sorted_words[:self.max_features])}
        self.idf = {word: idf_values[word] for word in self.vocab}

        print(f" TF-IDF词表大小: {len(self.vocab)}")
        # np.mean of an empty list warns and yields nan; produce the nan
        # directly so the printed output is the same without the warning.
        mean_idf = np.mean(list(self.idf.values())) if self.idf else float('nan')
        print(f" 平均IDF: {mean_idf:.3f}")
        return self

    def transform(self, texts):
        """Encode ``texts`` as positional TF-IDF weights.

        Args:
            texts: Iterable of raw strings.

        Returns:
            float32 array of shape ``(len(texts), max_seq_len)``; this shape
            also holds for empty input.
        """
        texts = list(texts)
        # Pre-allocating keeps the result 2-D even when ``texts`` is empty.
        matrix = np.zeros((len(texts), self.max_seq_len), dtype=np.float32)
        for row, text in enumerate(texts):
            words = tokenize(text)
            tf = Counter(words)
            tf_sum = len(words) if words else 1  # avoid division by zero
            for pos, word in enumerate(words[:self.max_seq_len]):
                if word in self.vocab:
                    # TF x IDF
                    matrix[row, pos] = (tf[word] / tf_sum) * self.idf.get(word, 0)
        return matrix

    def fit_transform(self, texts):
        """Fit on ``texts`` and return their encoded matrix."""
        texts = list(texts)  # materialise once so both passes see the data
        self.fit(texts)
        return self.transform(texts)
def load_data(data_dir, max_features, max_seq_len, vectorizer_type='tfidf'):
    """Download (if needed), load, split and vectorize the review dataset.

    The corpus is shuffled with a fixed seed and split 80/20. The vectorizer
    is fitted on the training portion only, so no vocabulary or IDF
    statistics leak from the test portion into the features.

    Args:
        data_dir: Directory holding (or receiving) the dataset CSV.
        max_features: Vocabulary size limit passed to the vectorizer.
        max_seq_len: Row width passed to the vectorizer.
        vectorizer_type: ``'tfidf'`` selects ``TFIDFVectorizer``; any other
            value selects ``BoWVectorizer``.

    Returns:
        ``(X_train, y_train, X_test, y_test, vectorizer)`` where the ``X``
        values are the vectorizer outputs, the ``y`` values are the matching
        label arrays and ``vectorizer`` is the fitted instance.

    Raises:
        RuntimeError: If the dataset is missing and cannot be downloaded.
    """
    if not download_dataset(data_dir):
        raise RuntimeError("数据加载失败,请检查网络或手动下载数据集")

    print("正在加载数据...")
    texts, labels = load_raw_data(data_dir)
    print(f"总评论数: {len(texts)}, 正面: {sum(labels)}, 负面: {len(labels) - sum(labels)}")

    # Pick the vectorizer.
    if vectorizer_type == 'tfidf':
        vectorizer = TFIDFVectorizer(max_features, max_seq_len)
        vec_name = "TF-IDF"
    else:
        vectorizer = BoWVectorizer(max_features, max_seq_len)
        vec_name = "BoW"

    # Shuffle and split before vectorizing. Seeding the global NumPy RNG is
    # kept on purpose: it fixes the split and is a side effect later code
    # may depend on.
    np.random.seed(42)
    indices = np.random.permutation(len(texts))
    split_idx = int(len(texts) * 0.8)
    train_idx, test_idx = indices[:split_idx], indices[split_idx:]
    train_texts = [texts[i] for i in train_idx]
    test_texts = [texts[i] for i in test_idx]

    print(f"正在使用{vec_name}向量化...")
    # Fit on the training texts only; the test texts are merely transformed.
    X_train = vectorizer.fit_transform(train_texts)
    X_test = vectorizer.transform(test_texts)
    y_train, y_test = labels[train_idx], labels[test_idx]

    print(f"训练集: {len(X_train)}条, 测试集: {len(X_test)}")
    return X_train, y_train, X_test, y_test, vectorizer
if __name__ == '__main__':
    # Manual smoke test: run the TF-IDF pipeline and show the result shape
    # together with the first five features of the first training row.
    separator = "=" * 60
    print(separator)
    print("测试 TF-IDF 向量化")
    print(separator)

    X_train, y_train, X_test, y_test, vec = load_data(
        'data/ChnSentiCorp', 3000, 100, 'tfidf'
    )

    print(f"\nX_train shape: {X_train.shape}")
    print(f"X_train sample (前5个特征): {X_train[0][:5]}")