mdz/pytorch/ERes2Net/1_scripts/configs/eres2net.yml

126 lines
3.9 KiB
YAML
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 数据集参数
dataset_conf:
# 过滤最短的音频长度
min_duration: 0.5
# 最长的音频长度,大于这个长度会裁剪掉
max_duration: 3
# 是否裁剪静音片段
do_vad: False
# 音频的采样率
sample_rate: 16000
# 是否对音频进行音量归一化
use_dB_normalization: True
# 对音频进行音量归一化的音量分贝值
target_dB: -20
# 训练数据的数据列表路径
train_list: 'dataset/train_list.txt'
# 评估注册的数据列表路径
enroll_list: 'dataset/cn-celeb-test/enroll_list.txt'
# 评估检验的数据列表路径
trials_list: 'dataset/cn-celeb-test/trials_list.txt'
# 是否使用PKSampler该Sampler可以保证每个说话人都有sample_per_id个样本
is_use_pksampler: False
# 使用PKSampler时设置样本数量
sample_per_id: 4
# 评估的数据要特殊处理
eval_conf:
# 评估的批量大小
batch_size: 4
# 最长的音频长度
max_duration: 20
# 数据加载器参数
dataLoader:
# 训练的批量大小
batch_size: 32
# 读取数据的线程数量
num_workers: 4
# 是否丢弃最后一个样本
drop_last: True
# 数据增强参数
aug_conf:
# 是否使用语速扰动增强
speed_perturb: True
# 使用语速增强是否分类大小翻三倍
speed_perturb_3_class: False
# 是否使用音量增强
volume_perturb: False
# 音量增强概率
volume_aug_prob: 0.2
# 噪声增强的噪声文件夹
noise_dir: 'dataset/noise'
# 噪声增强概率
noise_aug_prob: 0.2
# 是否使用SpecAug
use_spec_aug: True
# Spec增强参数
spec_aug_args:
# 随机频谱掩码大小
freq_mask_width: [ 0, 8 ]
# 随机时间掩码大小
time_mask_width: [ 0, 10 ]
# 数据预处理参数
preprocess_conf:
# 是否使用HF上的Wav2Vec2类似模型提取音频特征
use_hf_model: False
# 音频预处理方法,也可以叫特征提取方法
# 当use_hf_model为False时支持MelSpectrogram、Spectrogram、MFCC、Fbank
# 当use_hf_model为True时指定的是HuggingFace的模型或者本地路径比如facebook/w2v-bert-2.0或者./feature_models/w2v-bert-2.0
feature_method: 'Fbank'
# 当use_hf_model为False时设置API参数更参数查看对应API不清楚的可以直接删除该部分直接使用默认值。
# 当use_hf_model为True时可以设置参数use_gpu指定是否使用GPU提取特征
method_args:
sample_frequency: 16000
num_mel_bins: 80
optimizer_conf:
# 优化方法支持Adam、AdamW、SGD
optimizer: 'Adam'
# 初始学习率的大小
learning_rate: 0.0005
weight_decay: !!float 1e-5
# 学习率衰减函数支持WarmupCosineSchedulerLR、CosineAnnealingLR
scheduler: 'WarmupCosineSchedulerLR'
# 学习率衰减函数参数
scheduler_args:
min_lr: !!float 5e-6
max_lr: 0.0005
warmup_epoch: 5
model_conf:
backbone:
# 所使用的池化层支持TAP、TSTP
pooling_type: 'TSTP'
embd_dim: 192
classifier:
# 说话人数量,即分类大小
num_speakers: 2796
num_blocks: 0
loss_conf:
# 所使用的损失函数支持AAMLoss、SphereFace2、AMLoss、ARMLoss、CELoss、SubCenterLoss、TripletAngularMarginLoss
use_loss: 'AAMLoss'
# 损失函数参数
args:
margin: 0.2
scale: 32
easy_margin: False
# 是否使用损失函数margin调度器
use_margin_scheduler: True
# margin调度器参数
margin_scheduler_args:
initial_margin: 0.0
final_margin: 0.3
train_conf:
# 是否开启自动混合精度
enable_amp: False
# 是否使用Pytorch2.0的编译器
use_compile: False
# 训练的轮数
max_epoch: 60
log_interval: 100
# 所使用的模型支持模型ERes2Net、ERes2NetV2
use_model: 'ERes2Net'