# mdz/pytorch/bert_cls/1_scripts/0_infer.py
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
import sys
sys.path.append(R'../0_Bert-Chinese-Text-Classification-Pytorch/')  # make the repo's models package importable
import torch
import numpy as np
from importlib import import_module
from transformers import BertTokenizer


class Config(object):
    """Configuration parameters."""

    def __init__(self, dataset):
        self.model_name = 'bert'
        self.train_path = dataset + '/data/train.txt'  # training set
        self.dev_path = dataset + '/data/dev.txt'      # validation set
        self.test_path = dataset + '/data/test.txt'    # test set
        self.class_list = [x.strip() for x in open(
            dataset + '/data/class.txt', encoding='utf-8').readlines()]  # class names
        self.save_path = dataset + '/saved_dict/' + self.model_name + '.ckpt'  # trained-model checkpoint
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # run on GPU when available
        self.require_improvement = 1000  # stop training early if no improvement for 1000 batches
        self.num_classes = len(self.class_list)  # number of classes
        self.num_epochs = 3  # number of epochs
        self.batch_size = 1  # mini-batch size
        self.pad_size = 32  # every sentence is padded/truncated to this length
        self.learning_rate = 5e-5  # learning rate
        self.bert_path = '../weights/bert_pretrain'
        self.tokenizer = BertTokenizer.from_pretrained(self.bert_path)
        self.hidden_size = 768


PAD, CLS = '[PAD]', '[CLS]'  # padding token; [CLS] carries sequence-level information in BERT

# index-to-label mapping used to print a human-readable prediction
key = {
    0: 'finance',
    1: 'realty',
    2: 'stocks',
    3: 'education',
    4: 'science',
    5: 'society',
    6: 'politics',
    7: 'sports',
    8: 'game',
    9: 'entertainment'
}


def tokenize_text(text, config, pad_size=32):
    """Convert a single text sample into the model's input format."""
    token = config.tokenizer.tokenize(text)
    token = [CLS] + token
    seq_len = len(token)
    mask = []
    token_ids = config.tokenizer.convert_tokens_to_ids(token)
    if pad_size:
        if len(token) < pad_size:
            # pad short sequences with zeros; the mask marks the real tokens
            mask = [1] * len(token_ids) + [0] * (pad_size - len(token))
            token_ids += [0] * (pad_size - len(token))
        else:
            # truncate long sequences to pad_size
            mask = [1] * pad_size
            token_ids = token_ids[:pad_size]
            seq_len = pad_size
    return {
        'input_ids': token_ids,
        'attention_mask': mask,
        'seq_len': seq_len
    }
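
# A quick sanity check (hypothetical call, assuming config has been built as in
# __main__ below): every sample comes back padded to exactly pad_size entries.
#   feats = tokenize_text("一起去学习啊", config)
#   assert len(feats['input_ids']) == 32 == len(feats['attention_mask'])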


def single_inference(config, model, text):
    # text is expected to be a single string
    tokenized_text = tokenize_text(text, config)
    # add a batch dimension and move the inputs to the target device
    input_ids = torch.tensor([tokenized_text['input_ids']]).to(config.device)
    attention_mask = torch.tensor([tokenized_text['attention_mask']]).to(config.device)
    # run inference
    model.eval()  # switch the model to evaluation mode
    with torch.no_grad():
        outputs = model(input_ids, attention_mask)
        predicted_class = torch.argmax(outputs[0], dim=1).item()
    return predicted_class
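

# Variant sketch (not part of the original script): the same forward pass as
# single_inference, but also returning the softmax probabilities. Like the
# function above, it assumes outputs[0] holds the [batch, num_classes] logits.
def single_inference_with_probs(config, model, text):
    feats = tokenize_text(text, config)
    input_ids = torch.tensor([feats['input_ids']]).to(config.device)
    attention_mask = torch.tensor([feats['attention_mask']]).to(config.device)
    model.eval()
    with torch.no_grad():
        logits = model(input_ids, attention_mask)[0]
        probs = torch.softmax(logits, dim=1)
    return torch.argmax(probs, dim=1).item(), probs.squeeze(0).tolist()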


if __name__ == '__main__':
    dataset = '../0_Bert-Chinese-Text-Classification-Pytorch/THUCNews/'  # dataset root
    model_name = 'bert'
    x = import_module('models.' + model_name)
    config = Config(dataset)
    # fix all random seeds so results are reproducible across runs
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True

    # build the model and load the fine-tuned weights
    model = x.Model(config).to(config.device)
    infer_path = '../weights/bert.ckpt'
    # map_location keeps the load working on CPU-only machines
    model.load_state_dict(torch.load(infer_path, map_location=config.device))

    # single-sample inference example
    text = "一起去学习啊"  # "Let's go study together"
    predicted_class = single_inference(config, model, text)
    print(f"Predicted class: {key[predicted_class]}")
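
    # Usage sketch for the hypothetical variant above: also report the
    # confidence of the predicted class.
    pred, probs = single_inference_with_probs(config, model, text)
    print(f"Predicted class: {key[pred]} (p={probs[pred]:.3f})")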