import torch
from transformers import BertTokenizer, BertForMaskedLM, BertModel
from typing import List, Optional, Tuple, Union
from torch.nn import CrossEntropyLoss
import numpy as np


# Patched BertForMaskedLM.forward: identical to the stock implementation except that a
# precomputed embedding_output is accepted as the first argument and forwarded to the
# patched BertModel.forward below, so the embedding lookup stays outside the traced graph.
def LM_forward(
    self,
    embedding_output: Optional[torch.Tensor] = None,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    labels: Optional[torch.Tensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
        config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
        the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    outputs = self.bert(
        embedding_output,
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    sequence_output = outputs[0]
    prediction_scores = self.cls(sequence_output)

    masked_lm_loss = None
    if labels is not None:
        loss_fct = CrossEntropyLoss()  # -100 index = padding token
        masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

    # The return_dict=True branch is omitted; the model is loaded and traced with return_dict=False.
    if not return_dict:
        output = (prediction_scores,) + outputs[2:]
        return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

# Patched BertModel.forward: the usual forward pass, except that self.embeddings(...) is
# skipped and the caller-supplied embedding_output is fed directly to the encoder.
def BERT_forward(
    self,
    embedding_output: Optional[torch.Tensor] = None,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    token_type_ids: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.Tensor] = None,
    head_mask: Optional[torch.Tensor] = None,
    inputs_embeds: Optional[torch.Tensor] = None,
    encoder_hidden_states: Optional[torch.Tensor] = None,
    encoder_attention_mask: Optional[torch.Tensor] = None,
    past_key_values: Optional[List[torch.FloatTensor]] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    r"""
    encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
        Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
        the model is configured as a decoder.
    encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
        the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:

        - 1 for tokens that are **not masked**,
        - 0 for tokens that are **masked**.
    past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
        Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.

        If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
        don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
        `decoder_input_ids` of shape `(batch_size, sequence_length)`.
    use_cache (`bool`, *optional*):
        If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
        `past_key_values`).
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if self.config.is_decoder:
        use_cache = use_cache if use_cache is not None else self.config.use_cache
    else:
        use_cache = False

    if input_ids is not None and inputs_embeds is not None:
        raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
    elif input_ids is not None:
        input_shape = input_ids.size()
        self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
    elif inputs_embeds is not None:
        input_shape = inputs_embeds.size()[:-1]
    else:
        raise ValueError("You have to specify either input_ids or inputs_embeds")

    batch_size, seq_length = input_shape
    device = input_ids.device if input_ids is not None else inputs_embeds.device

    # past_key_values_length
    past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0

    if attention_mask is None:
        attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

    if token_type_ids is None:
        if hasattr(self.embeddings, "token_type_ids"):
            buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
            buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
            token_type_ids = buffered_token_type_ids_expanded
        else:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

    # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
    # ourselves in which case we just need to make it broadcastable to all heads.
    extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)

    # If a 2D or 3D attention mask is provided for the cross-attention
    # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
    if self.config.is_decoder and encoder_hidden_states is not None:
        encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
        encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
        if encoder_attention_mask is None:
            encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
        encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
    else:
        encoder_extended_attention_mask = None

    # Prepare head mask if needed
    # 1.0 in head_mask indicates we keep the head
    # attention_probs has shape bsz x n_heads x N x N
    # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
    # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
    head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

    # The embedding lookup below is intentionally disabled: embedding_output is computed
    # by the caller (see the export code at the bottom of this script) and passed in directly.
    # embedding_output = self.embeddings(
    #     input_ids=input_ids,
    #     position_ids=position_ids,
    #     token_type_ids=token_type_ids,
    #     inputs_embeds=inputs_embeds,
    #     past_key_values_length=past_key_values_length,
    # )
    encoder_outputs = self.encoder(
        embedding_output,
        attention_mask=extended_attention_mask,
        head_mask=head_mask,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_extended_attention_mask,
        past_key_values=past_key_values,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    sequence_output = encoder_outputs[0]
    pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

    # The return_dict=True branch is omitted; the model is loaded and traced with return_dict=False.
    if not return_dict:
        return (sequence_output, pooled_output) + encoder_outputs[1:]


# Load the BERT model and tokenizer
model_path = R'../weights/bert-base-cased'
tokenizer_path = R'../weights/bert-base-cased'
# Remove the embedding layer from the traced graph by patching the forward methods
BertModel.forward = BERT_forward
BertForMaskedLM.forward = LM_forward
# Load the model with the patched, embedding-free forward
tokenizer = BertTokenizer.from_pretrained(tokenizer_path)
model = BertForMaskedLM.from_pretrained(model_path, return_dict=False)
model.eval()

# Example text
text = "I want to [MASK] a new car."

# Truncate or pad the text so the encoded length is fixed at 128
encoded_inputs = tokenizer.encode_plus(text, add_special_tokens=True, max_length=128, padding='max_length', truncation=True, return_tensors='pt')

# Get the input token tensor and the position of the [MASK] token
input_ids = encoded_inputs['input_ids']
mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1]

# Prepare the inputs for the forward pass: a fixed-shape all-ones attention mask and the
# embedding output computed outside the model
attention_mask = torch.ones((1, 128))

embedding_out = model.bert.embeddings(input_ids)
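
# Optional sanity check (a sketch, not part of the original export flow): run the patched
# model once in eager mode to confirm the embedding-external calling convention
# (embedding_output first, then input_ids, then attention_mask) still predicts a sensible
# token for the [MASK] position. The tuple unpacking assumes return_dict=False as above.
with torch.no_grad():
    eager_out = model(embedding_out, input_ids, attention_mask)
    eager_logits = eager_out[0] if isinstance(eager_out, (tuple, list)) else eager_out
    predicted_id = eager_logits[0, mask_token_index].argmax(dim=-1)
    print("eager [MASK] prediction:", tokenizer.decode(predicted_id))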

# Model export: dump the trace inputs as raw float32 files and save the traced TorchScript model
embedding_out.detach().numpy().astype(np.float32).tofile("../2_compile/qtset/bert/embedding_out.ftmp")
input_ids.detach().numpy().astype(np.float32).tofile("../2_compile/qtset/bert/input_ids.ftmp")
attention_mask.detach().numpy().astype(np.float32).tofile("../2_compile/qtset/bert/attention_mask.ftmp")
traced_model = torch.jit.trace(model, [embedding_out, input_ids, attention_mask], strict=False)
traced_model.save('../2_compile/fmodel/bert_mask_traced.pt')
print("trace model success !!!")