mdz/pytorch/TextCNN/1_scripts/2_save_infer.py

90 lines
2.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# !/usr/bin/env python
# -*- coding: UTF-8 -*-
import sys
sys.path.append(R'../0_Chinese-Text-Classification-Pytorch/')
import torch
import torch.nn as nn
import pickle as pkl
import numpy as np
from importlib import import_module
# Index -> category label for the classifier's output classes;
# predict_trace() looks up the arg-max output index here.
key = dict(enumerate((
    'finance',
    'realty',
    'stocks',
    'education',
    'science',
    'society',
    'politics',
    'sports',
    'game',
    'entertainment',
)))
class Predict:
    """Classify raw text with a TextCNN whose head is a TorchScript module.

    The eager model built from ``models.<model_name>`` is used only for its
    ``embedding`` layer. The embedding output is then fed to a ``torch.jit``
    module loaded from disk (presumably the rest of the network, traced by
    the compile step -- confirm against the compile script).
    """

    # Fallback TorchScript path used when neither an explicit ``trace_path``
    # nor a module-level ``TRACE_PATH`` is available.
    DEFAULT_TRACE_PATH = R'../2_compile/fmodel/TextCNN_traced.pt'

    def __init__(self, model_name='TextCNN',
                 dataset='../0_Chinese-Text-Classification-Pytorch/THUCNews/',
                 embedding='embedding_SougouNews.npz', use_word=False):
        """Load config, vocabulary and weights for ``model_name``.

        Args:
            model_name: Module name under ``models`` that provides ``Config``
                and ``Model``.
            dataset: Dataset directory passed to ``Config``.
            embedding: Embedding file name passed to ``Config``.
            use_word: If True, tokenize by splitting on single spaces
                (word level); otherwise one token per character (char level).
        """
        if use_word:
            self.tokenizer = lambda x: x.split(' ')  # word-level: space-separated tokens
        else:
            self.tokenizer = list  # char-level: one token per character
        self.x = import_module('models.' + model_name)
        self.config = self.x.Config(dataset, embedding)
        # NOTE: pickle and torch.load can execute code embedded in the file;
        # only point these paths at trusted artifacts.
        with open(self.config.vocab_path, 'rb') as f:
            self.vocab = pkl.load(f)
        self.pad_size = self.config.pad_size
        self.model = self.x.Model(self.config).to('cpu')
        self.model.load_state_dict(torch.load(self.config.save_path, map_location='cpu'))
        # Inference only: switch layers such as dropout (if any) to eval behaviour.
        self.model.eval()

    def build_predict_text(self, texts):
        """Convert raw texts into token-id tensors.

        Each text is tokenized and mapped through ``self.vocab``. Unknown
        tokens map to the ``'<UNK>'`` id. When ``pad_size`` is truthy, each
        sequence is padded with ``'<PAD>'`` or truncated to exactly
        ``pad_size`` tokens. When it is falsy no padding happens, so all
        texts must then tokenize to the same length for the tensor to build.

        Args:
            texts: Iterable of strings.

        Returns:
            ``(ids, seq_lens)``: a LongTensor with one row of token ids per
            text, and a LongTensor of each text's token count before padding
            (capped at ``pad_size`` when truncated).
        """
        unk_id = self.vocab.get('<UNK>')  # same fallback for every OOV token
        words_lines = []
        seq_lens = []
        for text in texts:
            token = self.tokenizer(text)
            seq_len = len(token)
            if self.pad_size:
                if seq_len < self.pad_size:
                    token.extend(['<PAD>'] * (self.pad_size - seq_len))
                else:
                    token = token[:self.pad_size]
                    seq_len = self.pad_size
            words_lines.append([self.vocab.get(word, unk_id) for word in token])
            seq_lens.append(seq_len)
        return torch.LongTensor(words_lines), torch.LongTensor(seq_lens)

    def predict_trace(self, query, trace_path=None):
        """Classify one text with the traced model and return its label.

        Prints the raw traced-model output as a side effect.

        Args:
            query: Raw input string.
            trace_path: Path to the TorchScript module. If omitted, the
                module-level ``TRACE_PATH`` is used when defined (as the
                script entry point does); otherwise ``DEFAULT_TRACE_PATH``.

        Returns:
            The label from ``key`` for the arg-max output index.
        """
        ids, _ = self.build_predict_text([query])
        with torch.no_grad():
            embedding_out = self.model.embedding(ids)
            trace_model = self._load_trace(self._resolve_trace_path(trace_path))
            outputs_trace = trace_model(embedding_out)
            print(outputs_trace)
            # Flat argmax over the whole output; valid because the batch holds one query.
            num_trace = torch.argmax(outputs_trace)
        return key[int(num_trace)]

    def _resolve_trace_path(self, trace_path):
        """Return the TorchScript path ``predict_trace`` should load."""
        if trace_path is not None:
            return trace_path
        # TRACE_PATH is only assigned under ``__main__``. Reading it through
        # globals() keeps the script working and lets importers fall back.
        return globals().get('TRACE_PATH', self.DEFAULT_TRACE_PATH)

    def _load_trace(self, path):
        """Load the TorchScript module at ``path``, caching it per instance."""
        cache = self.__dict__.setdefault('_trace_cache', {})
        if path not in cache:
            cache[path] = torch.jit.load(path)
        return cache[path]
if __name__ == "__main__":
    # Settings. predict_trace() reads TRACE_PATH as a module-level name,
    # so it must stay a global assigned here.
    TRACE_PATH = R'../2_compile/fmodel/TextCNN_traced.pt'
    # Example queries; change the index below to try a different one.
    sample_queries = (
        "学费太贵怎么办?",
        "金融怎么样",
        "今天股票涨了吗?",
        "昨天游戏通关了吗?",
        "明天打球去啊",
    )
    predictor = Predict('TextCNN')
    # Predict a single query.
    query = sample_queries[3]
    print(predictor.predict_trace(query))