Files
2026-04-27 21:46:20 +08:00

265 lines
7.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
预测脚本 - 加载训练好的模型,测试自定义文本
使用方法:
python predict.py
程序会:
1. 列出已保存的模型
2. 让学生选择模型
3. 加载模型和向量化器
4. 学生输入文本,实时预测情感
"""
import os
import re
import sys
import jieba
import numpy as np
import math
import csv
from collections import Counter
# 添加父目录到路径以便导入
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from config import DATA_DIR, MAX_FEATURES, MAX_SEQ_LEN, HIDDEN_SIZE
def find_saved_models():
    """Scan the current directory for saved model weight files.

    Weight files follow the pattern ``<base>_<param>.npy`` where the
    parameter suffix is one of W/b (logistic regression) or W1/b1/W2/b2
    (MLP), e.g. ``model_mlp_tfidf_W1.npy`` -> base ``model_mlp_tfidf``.

    Returns:
        list[str]: unique model base names, in directory-listing order.
    """
    # Longer suffixes first so 'model_mlp_tfidf_W1.npy' is not truncated
    # by the shorter '_W.npy' pattern.
    param_suffixes = ('_W1.npy', '_b1.npy', '_W2.npy', '_b2.npy', '_W.npy', '_b.npy')
    seen = {}  # dict preserves first-seen order, acting as an ordered set
    for fname in os.listdir('.'):
        if not (fname.startswith('model_') and fname.endswith('.npy')):
            continue
        for suffix in param_suffixes:
            if fname.endswith(suffix):
                seen[fname[:-len(suffix)]] = True
                break
    return list(seen)
def tokenize(text):
    """Segment text into Chinese/English words, dropping 1-char tokens.

    Non-CJK, non-letter characters are replaced by spaces before jieba
    segmentation, so punctuation and digits never reach the vocabulary.
    """
    cleaned = re.sub(r'[^\u4e00-\u9fa5a-zA-Z]', ' ', text)
    return [tok for tok in jieba.lcut(cleaned) if len(tok) > 1]
class BoWVectorizer:
    """Positional bag-of-words encoder.

    ``transform`` marks position i with 1.0 when the i-th token of the
    text is in the vocabulary (note: positional presence, not a count
    vector — this mirrors the train-time encoding in dataset.py).
    """

    def __init__(self, max_seq_len, max_features=3000):
        self.max_seq_len = max_seq_len      # fixed output vector length
        self.max_features = max_features    # vocabulary size cap
        self.vocab = {}                     # word -> index, filled by fit()

    def fit(self, texts):
        """Build the vocabulary from the most frequent words in *texts*."""
        freq = Counter()
        for doc in texts:
            freq.update(tokenize(doc))
        top_words = freq.most_common(self.max_features)
        self.vocab = {word: idx for idx, (word, _) in enumerate(top_words)}
        return self

    def transform(self, text):
        """Encode one text as a (1, max_seq_len) float32 0/1 vector."""
        tokens = tokenize(text)[:self.max_seq_len]
        vec = np.zeros((1, self.max_seq_len), dtype=np.float32)
        for pos, word in enumerate(tokens):
            if word in self.vocab:
                vec[0, pos] = 1.0
        return vec
class TFIDFVectorizer:
    """Positional TF-IDF encoder.

    ``transform`` places the TF-IDF weight of the i-th token at position
    i of a fixed-length vector (positional layout mirrors the train-time
    encoding in dataset.py).
    """

    def __init__(self, max_seq_len, max_features=3000):
        self.max_seq_len = max_seq_len      # fixed output vector length
        self.max_features = max_features    # vocabulary size cap
        self.vocab = {}                     # word -> index
        self.idf = {}                       # word -> smoothed IDF weight
        self.num_docs = 0                   # corpus size seen at fit()

    def fit(self, texts):
        """Build vocabulary and IDF table from the corpus *texts*."""
        term_freq = Counter()
        doc_freq = Counter()
        for doc in texts:
            tokens = tokenize(doc)
            term_freq.update(tokens)
            doc_freq.update(set(tokens))  # each doc counts a word once
        self.num_docs = len(texts)
        # Keep the highest-frequency words so the vocabulary matches the
        # one built by dataset.py during training.
        top_words = term_freq.most_common(self.max_features)
        self.vocab = {word: idx for idx, (word, _) in enumerate(top_words)}
        # Smoothed IDF, computed only for vocabulary words.
        self.idf = {
            word: math.log(self.num_docs / (doc_freq.get(word, 1) + 1)) + 1
            for word in self.vocab
        }
        return self

    def transform(self, text):
        """Encode one text as a (1, max_seq_len) float32 TF-IDF vector."""
        tokens = tokenize(text)
        counts = Counter(tokens)
        total = len(tokens) or 1  # avoid division by zero on empty input
        vec = np.zeros((1, self.max_seq_len), dtype=np.float32)
        for pos, word in enumerate(tokens[:self.max_seq_len]):
            if word in self.vocab:
                vec[0, pos] = (counts[word] / total) * self.idf.get(word, 0)
        return vec
def load_vectorizer(vectorizer_type):
    """Rebuild the train-time vectorizer by re-fitting on the training CSV.

    Args:
        vectorizer_type: string containing 'tfidf' (case-insensitive) for
            TF-IDF, anything else selects bag-of-words.

    Returns:
        A fitted BoWVectorizer or TFIDFVectorizer, or None when the
        training data file is missing.
    """
    csv_path = os.path.join(DATA_DIR, 'ChnSentiCorp_htl_all.csv')
    if not os.path.exists(csv_path):
        print("数据文件不存在,请先运行 main.py 训练模型")
        return None
    print("正在加载数据构建词表...")
    texts = []
    with open(csv_path, 'r', encoding='utf-8') as f:
        for row in csv.reader(f):
            if len(row) < 2:
                continue
            # The label column must be an integer; this also skips a
            # possible header row. Only ValueError is swallowed — the
            # previous bare `except:` hid every error including Ctrl-C.
            try:
                int(row[0])
            except ValueError:
                continue
            texts.append(row[1].strip())
    # BUG FIX: MAX_FEATURES from config was imported but never passed,
    # so the vectorizers silently used their hard-coded default of 3000.
    # If config.MAX_FEATURES differs, the rebuilt vocabulary would not
    # match the one the model was trained with.
    if 'tfidf' in vectorizer_type.lower():
        vectorizer = TFIDFVectorizer(MAX_SEQ_LEN, max_features=MAX_FEATURES)
        vectorizer.fit(texts)
        print(f" TF-IDF词表大小: {len(vectorizer.vocab)}")
    else:
        vectorizer = BoWVectorizer(MAX_SEQ_LEN, max_features=MAX_FEATURES)
        vectorizer.fit(texts)
        print(f" BoW词表大小: {len(vectorizer.vocab)}")
    return vectorizer
def load_model(model_name, model_type, input_size, hidden_size):
    """Instantiate a classifier and restore its weights from .npy files.

    Args:
        model_name: base file name, e.g. 'model_mlp_tfidf'.
        model_type: 'lr' for logistic regression; anything else -> MLP.
        input_size: feature vector length the model was trained with.
        hidden_size: MLP hidden layer width (ignored for 'lr').
    """
    from model_numpy import MLP, LogisticRegression

    def _weights(suffix):
        # Weight files are saved as '<model_name>_<suffix>.npy'.
        return np.load(f'{model_name}_{suffix}.npy')

    if model_type == 'lr':
        model = LogisticRegression(input_size, 2, learning_rate=0.05)
        model.W = _weights('W')
        model.b = _weights('b')
    else:
        # Non-'lr' types are treated as the two-layer MLP.
        model = MLP(input_size, hidden_size, 2, learning_rate=0.05, keep_prob=1.0)
        model.W1 = _weights('W1')
        model.b1 = _weights('b1')
        model.W2 = _weights('W2')
        model.b2 = _weights('b2')
    return model
def predict_text(model, vectorizer, text):
    """Vectorize one text and classify it.

    Returns:
        tuple: (label string, confidence percentage, class probability
        vector where index 0 = negative, index 1 = positive).
    """
    probs = model.forward(vectorizer.transform(text))[0]
    best = int(np.argmax(probs))
    sentiment = "正面" if best == 1 else "负面"
    return sentiment, probs[best] * 100, probs
def main():
    """Interactive entry point: pick a saved model, then predict in a loop."""
    print("\n" + "=" * 60)
    print("文本情感预测 - 加载已训练模型")
    print("=" * 60)
    # Discover saved weight files in the working directory.
    models = find_saved_models()
    if not models:
        print("\n未找到已保存的模型!")
        print("请先运行 python main.py 训练模型")
        return
    # Let the user pick a model by number.
    print(f"\n已找到 {len(models)} 个模型:")
    for i, name in enumerate(models, 1):
        print(f" {i}. {name}")
    print(f"\n请选择模型编号 (1-{len(models)}): ", end="", flush=True)
    try:
        choice = int(sys.stdin.readline().strip())
    except ValueError:  # was a bare `except:` that also hid Ctrl-C
        print("无效输入")
        return
    if choice < 1 or choice > len(models):
        print("无效选择")
        return
    model_name = models[choice - 1]
    # Model names look like model_<type>_<vectorizer>, e.g. model_mlp_tfidf.
    parts = model_name.split('_')
    if len(parts) < 3:
        # BUG FIX: a malformed weight file name (e.g. a stray model_W.npy)
        # previously crashed here with an unhandled IndexError.
        print("无效选择")
        return
    model_type = parts[1]        # 'lr' or 'mlp'
    vectorizer_type = parts[2]   # 'bow' or 'tfidf'
    print(f"\n选中的模型: {model_name}")
    print(f"模型类型: {model_type.upper()}")
    print(f"向量化方式: {vectorizer_type.upper()}")
    # Rebuild the vectorizer from the training data so the vocabulary
    # matches what the model was trained on.
    print("\n正在加载向量化器...")
    vectorizer = load_vectorizer(vectorizer_type)
    if vectorizer is None:
        return
    print("正在加载模型...")
    model = load_model(model_name, model_type, MAX_SEQ_LEN, HIDDEN_SIZE)
    print("模型加载成功!")
    # Read-eval-print loop: 'q' quits, Ctrl-C and EOF also exit cleanly.
    print("\n" + "=" * 60)
    print("开始预测输入文本后按回车q退出")
    print("=" * 60)
    while True:
        try:
            print("\n请输入评论文本: ", end="", flush=True)
            line = sys.stdin.readline()
            if not line:
                # BUG FIX: on EOF readline() returns '' forever; the old
                # `if not text: continue` made this loop spin infinitely.
                print("\n再见!")
                break
            text = line.strip()
            if text.lower() == 'q':
                print("再见!")
                break
            if not text:
                continue
            label, confidence, prob = predict_text(model, vectorizer, text)
            print(f"\n预测结果: {label}")
            print(f"置信度: {confidence:.1f}%")
            print(f"详细: 正面概率={prob[1]*100:.1f}%, 负面概率={prob[0]*100:.1f}%")
        except KeyboardInterrupt:
            print("\n\n再见!")
            break
        except Exception as e:
            # Keep the loop alive on per-item prediction errors.
            print(f"预测出错: {e}")
# Run the interactive predictor only when executed as a script.
if __name__ == '__main__':
    main()