60 lines
1.8 KiB
Python
60 lines
1.8 KiB
Python
import numpy as np
|
||
from dataset import load_data, BoWVectorizer, TFIDFVectorizer
|
||
from train import train
|
||
import config as cfg
|
||
import pickle
|
||
import time
|
||
|
||
# Load the corpus. Labels are converted to an ndarray so they can be indexed
# with the shuffled index arrays below.
texts, labels = load_data()
labels = np.array(labels)

# Seeded shuffle-and-cut split: the first 80% of a random permutation of the
# sample indices is the training set, the rest is the test set.
# NOTE: this seeds NumPy's global RNG, so it is not local to the split.
np.random.seed(42)
n_samples = len(texts)
indices = np.random.permutation(n_samples)
split = int(0.8 * n_samples)
train_idx = indices[:split]
test_idx = indices[split:]

# texts is indexed element by element; labels is sliced with the index arrays.
train_texts = [texts[pos] for pos in train_idx]
test_texts = [texts[pos] for pos in test_idx]
y_train = labels[train_idx]
y_test = labels[test_idx]
|
||
|
||
# Choose the vectorizer class from the config: "bow" selects BoWVectorizer,
# every other value falls through to TFIDFVectorizer.
vec = (BoWVectorizer if cfg.VECTORIZER_TYPE == "bow" else TFIDFVectorizer)(
    cfg.MAX_FEATURES
)

# Fit on the training texts only, then transform each text of both splits
# individually and stack the results into arrays.
vec.fit(train_texts)
X_train = np.array([vec.transform(doc) for doc in train_texts])
X_test = np.array([vec.transform(doc) for doc in test_texts])
|
||
|
||
# Print a banner with the model type, vectorizer type and learning rate.
# NOTE(review): the label text inside the f-string appears mis-encoded
# (non-ASCII text decoded with the wrong codec); it is reproduced unchanged
# here and should be restored from the original source file.
banner = "=" * 50
print(banner)
print(f"ѵ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>:\n ģ<><C4A3>: {cfg.MODEL_TYPE}\n <20><><EFBFBD><EFBFBD>: {cfg.VECTORIZER_TYPE}\n ѧϰ<D1A7><CFB0>: {cfg.LEARNING_RATE}")
print(banner)

# Collect the configured hyper-parameters and hand them, together with both
# data splits, to train(). train() returns a pair: the model and a second
# value `t` that is not used anywhere in the visible script.
hyperparams = {
    "model_type": cfg.MODEL_TYPE,
    "lr": cfg.LEARNING_RATE,
    "epochs": cfg.NUM_EPOCHS,
    "use_weight": cfg.USE_CLASS_WEIGHT,
}
model, t = train(X_train, y_train, X_test, y_test, **hyperparams)
|
||
|
||
# Build a common file-name prefix from the model type, vectorizer type,
# class-weighting flag and a month/day_hour/minute/second timestamp.
ts = time.strftime("%m%d_%H%M%S")
name = f"model_{cfg.MODEL_TYPE}_{cfg.VECTORIZER_TYPE}_{'weighted' if cfg.USE_CLASS_WEIGHT else 'raw'}_{ts}"

# The "lr" model exposes W and b; any other model type is expected to expose
# W1, b1, W2 and b2. Each parameter goes to its own .npy file, in this order.
param_names = ("W", "b") if cfg.MODEL_TYPE == "lr" else ("W1", "b1", "W2", "b2")
for param in param_names:
    np.save(f"{name}_{param}.npy", getattr(model, param))

# Persist the fitted vectorizer next to the weights.
with open(f"{name}_vec.pkl", "wb") as vec_file:
    pickle.dump(vec, vec_file)

# NOTE(review): the message text below appears mis-encoded; it is reproduced
# unchanged and should be restored from the original source file.
print(f"\nģ<EFBFBD><EFBFBD><EFBFBD>ѱ<EFBFBD><EFBFBD><EFBFBD>: {name}_*.npy/*.pkl")