Upload predict.py
This commit is contained in:
264
predict.py
Normal file
264
predict.py
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
预测脚本 - 加载训练好的模型,测试自定义文本
|
||||||
|
|
||||||
|
使用方法:
|
||||||
|
python predict.py
|
||||||
|
|
||||||
|
程序会:
|
||||||
|
1. 列出已保存的模型
|
||||||
|
2. 让学生选择模型
|
||||||
|
3. 加载模型和向量化器
|
||||||
|
4. 学生输入文本,实时预测情感
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import sys
|
||||||
|
import jieba
|
||||||
|
import numpy as np
|
||||||
|
import math
|
||||||
|
import csv
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
|
# 添加父目录到路径以便导入
|
||||||
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
|
from config import DATA_DIR, MAX_FEATURES, MAX_SEQ_LEN, HIDDEN_SIZE
|
||||||
|
|
||||||
|
|
||||||
|
def find_saved_models():
|
||||||
|
"""查找已保存的模型"""
|
||||||
|
models = {}
|
||||||
|
for f in os.listdir('.'):
|
||||||
|
if not f.startswith('model_') or not f.endswith('.npy'):
|
||||||
|
continue
|
||||||
|
# MLP: model_mlp_tfidf_W1.npy -> model_mlp_tfidf
|
||||||
|
# LR: model_lr_bow_W.npy -> model_lr_bow
|
||||||
|
# 需要精确匹配,避免 tfidf 被截断
|
||||||
|
for suffix in ['_W1.npy', '_b1.npy', '_W2.npy', '_b2.npy', '_W.npy', '_b.npy']:
|
||||||
|
if f.endswith(suffix):
|
||||||
|
name = f[:-len(suffix)]
|
||||||
|
models[name] = True
|
||||||
|
break
|
||||||
|
return list(models.keys())
|
||||||
|
|
||||||
|
|
||||||
|
def tokenize(text):
|
||||||
|
"""中文分词"""
|
||||||
|
text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z]', ' ', text)
|
||||||
|
words = jieba.lcut(text)
|
||||||
|
return [w for w in words if len(w) > 1]
|
||||||
|
|
||||||
|
|
||||||
|
class BoWVectorizer:
|
||||||
|
def __init__(self, max_seq_len, max_features=3000):
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
self.max_features = max_features
|
||||||
|
self.vocab = {}
|
||||||
|
|
||||||
|
def fit(self, texts):
|
||||||
|
counter = Counter()
|
||||||
|
for text in texts:
|
||||||
|
words = tokenize(text)
|
||||||
|
counter.update(words)
|
||||||
|
most_common = counter.most_common(self.max_features)
|
||||||
|
self.vocab = {word: idx for idx, (word, _) in enumerate(most_common)}
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, text):
|
||||||
|
words = tokenize(text)
|
||||||
|
vec = [0] * self.max_seq_len
|
||||||
|
for i, word in enumerate(words[:self.max_seq_len]):
|
||||||
|
if word in self.vocab:
|
||||||
|
vec[i] = 1
|
||||||
|
return np.array([vec], dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
class TFIDFVectorizer:
|
||||||
|
def __init__(self, max_seq_len, max_features=3000):
|
||||||
|
self.max_seq_len = max_seq_len
|
||||||
|
self.max_features = max_features
|
||||||
|
self.vocab = {}
|
||||||
|
self.idf = {}
|
||||||
|
self.num_docs = 0
|
||||||
|
|
||||||
|
def fit(self, texts):
|
||||||
|
counter = Counter()
|
||||||
|
doc_counter = Counter()
|
||||||
|
for text in texts:
|
||||||
|
words = tokenize(text)
|
||||||
|
unique_words = set(words)
|
||||||
|
counter.update(words)
|
||||||
|
for w in unique_words:
|
||||||
|
doc_counter[w] += 1
|
||||||
|
|
||||||
|
self.num_docs = len(texts)
|
||||||
|
|
||||||
|
# 按词频取最高频的词(和训练时dataset.py一致)
|
||||||
|
most_common = counter.most_common(self.max_features)
|
||||||
|
self.vocab = {word: idx for idx, (word, _) in enumerate(most_common)}
|
||||||
|
|
||||||
|
# 计算IDF(只对词表中的词)
|
||||||
|
self.idf = {}
|
||||||
|
for word in self.vocab:
|
||||||
|
df = doc_counter.get(word, 1)
|
||||||
|
self.idf[word] = math.log(self.num_docs / (df + 1)) + 1
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def transform(self, text):
|
||||||
|
words = tokenize(text)
|
||||||
|
tf = Counter(words)
|
||||||
|
tf_sum = len(words) if words else 1
|
||||||
|
vec = [0.0] * self.max_seq_len
|
||||||
|
for i, word in enumerate(words[:self.max_seq_len]):
|
||||||
|
if word in self.vocab:
|
||||||
|
vec[i] = (tf[word] / tf_sum) * self.idf.get(word, 0)
|
||||||
|
return np.array([vec], dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def load_vectorizer(vectorizer_type):
|
||||||
|
"""加载向量化器"""
|
||||||
|
csv_path = os.path.join(DATA_DIR, 'ChnSentiCorp_htl_all.csv')
|
||||||
|
if not os.path.exists(csv_path):
|
||||||
|
print("数据文件不存在,请先运行 main.py 训练模型")
|
||||||
|
return None
|
||||||
|
|
||||||
|
print("正在加载数据构建词表...")
|
||||||
|
texts, labels = [], []
|
||||||
|
with open(csv_path, 'r', encoding='utf-8') as f:
|
||||||
|
reader = csv.reader(f)
|
||||||
|
for row in reader:
|
||||||
|
if len(row) < 2:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
labels.append(int(row[0]))
|
||||||
|
texts.append(row[1].strip())
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 判断向量化器类型
|
||||||
|
is_tfidf = 'tfidf' in vectorizer_type.lower()
|
||||||
|
|
||||||
|
if is_tfidf:
|
||||||
|
vectorizer = TFIDFVectorizer(MAX_SEQ_LEN)
|
||||||
|
vectorizer.fit(texts)
|
||||||
|
print(f" TF-IDF词表大小: {len(vectorizer.vocab)}")
|
||||||
|
else:
|
||||||
|
vectorizer = BoWVectorizer(MAX_SEQ_LEN)
|
||||||
|
vectorizer.fit(texts)
|
||||||
|
print(f" BoW词表大小: {len(vectorizer.vocab)}")
|
||||||
|
|
||||||
|
return vectorizer
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_name, model_type, input_size, hidden_size):
|
||||||
|
"""加载模型"""
|
||||||
|
from model_numpy import MLP, LogisticRegression
|
||||||
|
|
||||||
|
if model_type == 'lr':
|
||||||
|
model = LogisticRegression(input_size, 2, learning_rate=0.05)
|
||||||
|
model.W = np.load(model_name + '_W.npy')
|
||||||
|
model.b = np.load(model_name + '_b.npy')
|
||||||
|
else: # mlp
|
||||||
|
model = MLP(input_size, hidden_size, 2, learning_rate=0.05, keep_prob=1.0)
|
||||||
|
model.W1 = np.load(model_name + '_W1.npy')
|
||||||
|
model.b1 = np.load(model_name + '_b1.npy')
|
||||||
|
model.W2 = np.load(model_name + '_W2.npy')
|
||||||
|
model.b2 = np.load(model_name + '_b2.npy')
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def predict_text(model, vectorizer, text):
|
||||||
|
"""预测单条文本"""
|
||||||
|
vec = vectorizer.transform(text)
|
||||||
|
prob = model.forward(vec)[0]
|
||||||
|
pred = np.argmax(prob)
|
||||||
|
label = "正面" if pred == 1 else "负面"
|
||||||
|
confidence = prob[pred] * 100
|
||||||
|
return label, confidence, prob
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("文本情感预测 - 加载已训练模型")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 查找已保存的模型
|
||||||
|
models = find_saved_models()
|
||||||
|
|
||||||
|
if not models:
|
||||||
|
print("\n未找到已保存的模型!")
|
||||||
|
print("请先运行 python main.py 训练模型")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 让用户选择模型
|
||||||
|
print(f"\n已找到 {len(models)} 个模型:")
|
||||||
|
for i, name in enumerate(models, 1):
|
||||||
|
print(f" {i}. {name}")
|
||||||
|
|
||||||
|
print(f"\n请选择模型编号 (1-{len(models)}): ", end="", flush=True)
|
||||||
|
try:
|
||||||
|
choice = int(sys.stdin.readline().strip())
|
||||||
|
if choice < 1 or choice > len(models):
|
||||||
|
print("无效选择")
|
||||||
|
return
|
||||||
|
model_name = models[choice - 1]
|
||||||
|
except:
|
||||||
|
print("无效输入")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 解析模型名称获取类型
|
||||||
|
parts = model_name.split('_')
|
||||||
|
model_type = parts[1] # 'lr' 或 'mlp'
|
||||||
|
vectorizer_type = parts[2] # 'bow' 或 'tfidf'
|
||||||
|
|
||||||
|
print(f"\n选中的模型: {model_name}")
|
||||||
|
print(f"模型类型: {model_type.upper()}")
|
||||||
|
print(f"向量化方式: {vectorizer_type.upper()}")
|
||||||
|
|
||||||
|
# 加载向量化器
|
||||||
|
print("\n正在加载向量化器...")
|
||||||
|
vectorizer = load_vectorizer(vectorizer_type)
|
||||||
|
if vectorizer is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 加载模型
|
||||||
|
print("正在加载模型...")
|
||||||
|
model = load_model(model_name, model_type, MAX_SEQ_LEN, HIDDEN_SIZE)
|
||||||
|
print("模型加载成功!")
|
||||||
|
|
||||||
|
# 预测循环
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("开始预测(输入文本后按回车,q退出)")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
print("\n请输入评论文本: ", end="", flush=True)
|
||||||
|
text = sys.stdin.readline().strip()
|
||||||
|
|
||||||
|
if text.lower() == 'q':
|
||||||
|
print("再见!")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
continue
|
||||||
|
|
||||||
|
label, confidence, prob = predict_text(model, vectorizer, text)
|
||||||
|
|
||||||
|
print(f"\n预测结果: {label}")
|
||||||
|
print(f"置信度: {confidence:.1f}%")
|
||||||
|
print(f"详细: 正面概率={prob[1]*100:.1f}%, 负面概率={prob[0]*100:.1f}%")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\n\n再见!")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
print(f"预测出错: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user