# -*- coding: utf-8 -*-
"""
Prediction script - load a trained model and classify custom text.

Usage:
    python predict.py

The program will:
1. List the saved models found in the working directory
2. Let the student pick one
3. Rebuild the matching vectorizer and load the model weights
4. Read text from stdin and predict its sentiment interactively
"""

import os
import re
import sys
import jieba
import numpy as np
import math
import csv
from collections import Counter

# Make sibling modules (config, model_numpy) importable regardless of CWD.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from config import DATA_DIR, MAX_FEATURES, MAX_SEQ_LEN, HIDDEN_SIZE


def find_saved_models():
    """Scan the current directory for saved model weight files.

    Weight files are named ``<model_name><suffix>.npy``. The suffixes are
    checked longest-first so e.g. ``_W1.npy`` is never mistaken for
    ``_W.npy`` (which would truncate names like ``model_mlp_tfidf``).

    Returns:
        Sorted list of unique model base names, e.g. ``['model_lr_bow',
        'model_mlp_tfidf']``. Sorting keeps the numbered menu stable
        across platforms (``os.listdir`` order is arbitrary).
    """
    suffixes = ('_W1.npy', '_b1.npy', '_W2.npy', '_b2.npy', '_W.npy', '_b.npy')
    names = set()
    for fname in os.listdir('.'):
        if not fname.startswith('model_') or not fname.endswith('.npy'):
            continue
        for suffix in suffixes:
            if fname.endswith(suffix):
                names.add(fname[:-len(suffix)])
                break
    return sorted(names)


def tokenize(text):
    """Segment Chinese text into words.

    Non-CJK, non-ASCII-letter characters are replaced with spaces before
    segmentation; single-character tokens are dropped.
    """
    text = re.sub(r'[^\u4e00-\u9fa5a-zA-Z]', ' ', text)
    words = jieba.lcut(text)
    return [w for w in words if len(w) > 1]


class BoWVectorizer:
    """Positional binary encoder: position i of the output vector is 1 when
    the i-th token of the text is in the vocabulary (mirrors training)."""

    def __init__(self, max_seq_len, max_features=3000):
        self.max_seq_len = max_seq_len      # fixed output vector length
        self.max_features = max_features    # vocabulary size cap
        self.vocab = {}                     # word -> index (by frequency rank)

    def fit(self, texts):
        """Build the vocabulary from the most frequent words in *texts*."""
        counter = Counter()
        for text in texts:
            words = tokenize(text)
            counter.update(words)
        most_common = counter.most_common(self.max_features)
        self.vocab = {word: idx for idx, (word, _) in enumerate(most_common)}
        return self

    def transform(self, text):
        """Encode one text as a (1, max_seq_len) float32 array."""
        words = tokenize(text)
        vec = [0] * self.max_seq_len
        for i, word in enumerate(words[:self.max_seq_len]):
            if word in self.vocab:
                vec[i] = 1
        return np.array([vec], dtype=np.float32)


class TFIDFVectorizer:
    """Positional TF-IDF encoder: position i carries the TF-IDF weight of
    the i-th token when it is in the vocabulary (mirrors training)."""

    def __init__(self, max_seq_len, max_features=3000):
        self.max_seq_len = max_seq_len
        self.max_features = max_features
        self.vocab = {}     # word -> index (by frequency rank)
        self.idf = {}       # word -> inverse document frequency
        self.num_docs = 0

    def fit(self, texts):
        """Build the vocabulary and IDF table from *texts*."""
        counter = Counter()
        doc_counter = Counter()
        for text in texts:
            words = tokenize(text)
            unique_words = set(words)
            counter.update(words)
            for w in unique_words:
                doc_counter[w] += 1

        self.num_docs = len(texts)

        # Keep the highest-frequency words (consistent with dataset.py
        # at training time).
        most_common = counter.most_common(self.max_features)
        self.vocab = {word: idx for idx, (word, _) in enumerate(most_common)}

        # Smoothed IDF, computed only for in-vocabulary words.
        self.idf = {}
        for word in self.vocab:
            df = doc_counter.get(word, 1)
            self.idf[word] = math.log(self.num_docs / (df + 1)) + 1

        return self

    def transform(self, text):
        """Encode one text as a (1, max_seq_len) float32 array."""
        words = tokenize(text)
        tf = Counter(words)
        tf_sum = len(words) if words else 1  # guard against division by zero
        vec = [0.0] * self.max_seq_len
        for i, word in enumerate(words[:self.max_seq_len]):
            if word in self.vocab:
                vec[i] = (tf[word] / tf_sum) * self.idf.get(word, 0)
        return np.array([vec], dtype=np.float32)


def load_vectorizer(vectorizer_type):
    """Rebuild the vectorizer by re-fitting on the training corpus.

    The vocabulary is not persisted with the model weights, so it must be
    reconstructed here exactly the way it was built at training time.

    Args:
        vectorizer_type: model-name fragment; contains 'tfidf' for TF-IDF,
            anything else selects bag-of-words.

    Returns:
        A fitted BoWVectorizer or TFIDFVectorizer, or None when the
        training CSV file is missing.
    """
    csv_path = os.path.join(DATA_DIR, 'ChnSentiCorp_htl_all.csv')
    if not os.path.exists(csv_path):
        print("数据文件不存在,请先运行 main.py 训练模型")
        return None

    print("正在加载数据构建词表...")
    texts = []
    with open(csv_path, 'r', encoding='utf-8') as f:
        for row in csv.reader(f):
            if len(row) < 2:
                continue
            try:
                # Label column must parse as int; this also skips the
                # header row. (Only the review text is needed here.)
                int(row[0])
            except ValueError:
                continue
            texts.append(row[1].strip())

    # NOTE(review): MAX_FEATURES is imported from config but the
    # vectorizers fall back to their default of 3000 — confirm this
    # matches the vocabulary size used by dataset.py during training.
    if 'tfidf' in vectorizer_type.lower():
        vectorizer = TFIDFVectorizer(MAX_SEQ_LEN)
        vectorizer.fit(texts)
        print(f" TF-IDF词表大小: {len(vectorizer.vocab)}")
    else:
        vectorizer = BoWVectorizer(MAX_SEQ_LEN)
        vectorizer.fit(texts)
        print(f" BoW词表大小: {len(vectorizer.vocab)}")

    return vectorizer


def load_model(model_name, model_type, input_size, hidden_size):
    """Instantiate a model and load its saved weight arrays.

    Args:
        model_name: base name of the weight files (no suffix).
        model_type: 'lr' for logistic regression, anything else for MLP.
        input_size: input vector length (must match the vectorizer).
        hidden_size: hidden-layer width (MLP only).

    Returns:
        The model with weights loaded from ``<model_name>_*.npy``.
    """
    from model_numpy import MLP, LogisticRegression

    if model_type == 'lr':
        model = LogisticRegression(input_size, 2, learning_rate=0.05)
        model.W = np.load(model_name + '_W.npy')
        model.b = np.load(model_name + '_b.npy')
    else:  # mlp
        model = MLP(input_size, hidden_size, 2, learning_rate=0.05, keep_prob=1.0)
        model.W1 = np.load(model_name + '_W1.npy')
        model.b1 = np.load(model_name + '_b1.npy')
        model.W2 = np.load(model_name + '_W2.npy')
        model.b2 = np.load(model_name + '_b2.npy')

    return model


def predict_text(model, vectorizer, text):
    """Predict the sentiment of a single text.

    Returns:
        (label, confidence, prob): human-readable label ("正面"/"负面"),
        confidence as a percentage, and the raw class-probability vector
        (index 0 = negative, index 1 = positive).
    """
    vec = vectorizer.transform(text)
    prob = model.forward(vec)[0]
    pred = np.argmax(prob)
    label = "正面" if pred == 1 else "负面"
    confidence = prob[pred] * 100
    return label, confidence, prob


def main():
    """Interactive entry point: pick a model, then predict from stdin."""
    print("\n" + "=" * 60)
    print("文本情感预测 - 加载已训练模型")
    print("=" * 60)

    # Discover the saved models in the current directory.
    models = find_saved_models()

    if not models:
        print("\n未找到已保存的模型!")
        print("请先运行 python main.py 训练模型")
        return

    # Present a numbered menu and read the user's choice.
    print(f"\n已找到 {len(models)} 个模型:")
    for i, name in enumerate(models, 1):
        print(f" {i}. {name}")

    print(f"\n请选择模型编号 (1-{len(models)}): ", end="", flush=True)
    try:
        choice = int(sys.stdin.readline().strip())
        if choice < 1 or choice > len(models):
            print("无效选择")
            return
        model_name = models[choice - 1]
    except ValueError:  # non-numeric input (also covers EOF -> int(''))
        print("无效输入")
        return

    # Model names follow model_<type>_<vectorizer>, e.g. model_mlp_tfidf.
    try:
        parts = model_name.split('_')
        model_type = parts[1]        # 'lr' or 'mlp'
        vectorizer_type = parts[2]   # 'bow' or 'tfidf'
    except IndexError:  # malformed file name on disk
        print("无效输入")
        return

    print(f"\n选中的模型: {model_name}")
    print(f"模型类型: {model_type.upper()}")
    print(f"向量化方式: {vectorizer_type.upper()}")

    # Rebuild the vectorizer (vocabulary is re-fit from the training CSV).
    print("\n正在加载向量化器...")
    vectorizer = load_vectorizer(vectorizer_type)
    if vectorizer is None:
        return

    # Load the model weights.
    print("正在加载模型...")
    model = load_model(model_name, model_type, MAX_SEQ_LEN, HIDDEN_SIZE)
    print("模型加载成功!")

    # Interactive prediction loop; 'q' or Ctrl-C exits.
    print("\n" + "=" * 60)
    print("开始预测(输入文本后按回车,q退出)")
    print("=" * 60)

    while True:
        try:
            print("\n请输入评论文本: ", end="", flush=True)
            text = sys.stdin.readline().strip()

            if text.lower() == 'q':
                print("再见!")
                break

            if not text:
                continue

            label, confidence, prob = predict_text(model, vectorizer, text)

            print(f"\n预测结果: {label}")
            print(f"置信度: {confidence:.1f}%")
            print(f"详细: 正面概率={prob[1]*100:.1f}%, 负面概率={prob[0]*100:.1f}%")

        except KeyboardInterrupt:
            print("\n\n再见!")
            break
        except Exception as e:
            # Top-level boundary: report and keep the loop alive.
            print(f"预测出错: {e}")


if __name__ == '__main__':
    main()