FewRel训练流程
介绍FewRel数据集的格式,模型的训练流程、输入输出、loss定义。
FewRel数据集为json格式,用python的json包载入后,结构如下:
1 | |
数据集处理
FewRelDataset中一个样本为:
1 | |
data_loader.get_loader方法返回iter(data_loader)
next(data_loader)得到的数据为:tuple(batch_support, batch_query, batch_label)
1 | |
模型训练
1 | |
指定数据集
| 参数 | 默认值 | 备注 |
|---|---|---|
| —train | train_wiki | |
| —val | val_wiki | |
| —test | test_wiki |
任务设置
| 参数 | 默认值 | 备注 |
|---|---|---|
| —trainN | 10 | 训练时的N |
| —N | 5 | N-way |
| —K | 5 | K-shot |
| —Q | 5 | 查询集中每个类别样本数量 |
训练设置
| 参数 | 默认值 | 备注 |
|---|---|---|
| —batch_size | 4 | |
| —train_iter | 30000 | 训练迭代次数 |
| —val_iter | 1000 | 验证迭代次数 |
| —test_iter | 10000 | 测试迭代次数 |
| —val_step | 2000 | 训练val_step步之后,进行一次验证 |
| —lr | -1 | 学习速率 |
| —weight_decay | 1e-5 | weight decay |
| —only_test | false | 只进行测试 |
每次迭代从data_loader获取一个batch,而每个batch使用的类别与样本是随机选取的。
模型设置
| 参数 | 默认值 | 备注 |
|---|---|---|
| —model | proto | 模型名字 |
| —encoder | cnn | 编码器:cnn或bert |
| —max_length | 128 | 句子最大长度 |
| —hidden_size | 230 | 隐藏层单元数量 |
| —dropout | 0.0 | |
| —grad_iter | 1 | 累积grad_iter次迭代的梯度 |
| —optim | sgd | sgd / adam / adamw |
保存/载入断点
| 参数 | 默认值 | 备注 |
|---|---|---|
| —load_ckpt | None | 断点文件路径, |
| —save_ckpt | None | 若不为None,将替换默认生成的名字 |
| —ckpt_name | ‘’ | 断点名字,加在默认生成的名字后 |
本博客所有文章除特别声明外,均采用 CC BY-SA 4.0 协议 ,转载请注明出处!