FewRel训练流程
介绍FewRel数据集的格式,模型的训练流程、输入输出、loss定义。
FewRel数据集为json格式,用python的json包载入后,结构如下:
1 |
|
数据集处理
FewRelDataset
中一个样本为:
1 |
|
data_loader.get_loader
方法返回iter(data_loader)
next(data_loader)
得到的数据为:tuple(batch_support, batch_query, batch_label)
1 |
|
模型训练
1 |
|
指定数据集
参数 | 默认值 | 备注 |
---|---|---|
—train | train_wiki | |
—val | val_wiki | |
—test | test_wiki |
任务设置
参数 | 默认值 | 备注 |
---|---|---|
—trainN | 10 | 训练时的N |
—N | 5 | N-way |
—K | 5 | K-shot |
—Q | 5 | 查询集中每个类别样本数量 |
训练设置
参数 | 默认值 | 备注 |
---|---|---|
—batch_size | 4 | |
—train_iter | 30000 | 训练迭代次数 |
—val_iter | 1000 | 验证迭代次数 |
—test_iter | 10000 | 测试迭代次数 |
—val_step | 2000 | 训练val_step步之后,进行一次验证 |
—lr | -1 | 学习速率 |
—weight_decay | 1e-5 | weight decay |
—only_test | false | 只进行测试 |
每次迭代从data_loader获取一个batch,而每个batch使用的类别与样本是随机选取的。
模型设置
参数 | 默认值 | 备注 |
---|---|---|
—model | proto | 模型名字 |
—encoder | cnn | 编码器:cnn或bert |
—max_length | 128 | 句子最大长度 |
—hidden_size | 230 | 隐藏层单元数量 |
—dropout | 0.0 | |
—grad_iter | 1 | 累积grad_iter次迭代的梯度 |
—optim | sgd | sgd / adam / adamw |
保存/载入断点
参数 | 默认值 | 备注 |
---|---|---|
—load_ckpt | None | 断点文件路径, |
—save_ckpt | None | 若不为None,将替换默认生成的名字 |
—ckpt_name | ‘’ | 断点名字,加在默认生成的名字后 |
本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!