Skip to content

Commit

Permalink
[feat]: modify document structure (#226)
Browse files Browse the repository at this point in the history
  • Loading branch information
yangxudong authored Jul 1, 2022
1 parent d314a30 commit 5cd039b
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 89 deletions.
1 change: 0 additions & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@ Welcome to easy_rec's documentation!
:maxdepth: 2
:caption: TRAIN & EVAL & EXPORT

loss
train
incremental_train
eval
Expand Down
81 changes: 0 additions & 81 deletions docs/source/loss.md

This file was deleted.

92 changes: 87 additions & 5 deletions docs/source/train.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# 训练

### train_config
## train_config

- log_step_count_steps: 200 # 每200轮打印一行log

Expand Down Expand Up @@ -74,9 +74,91 @@

- 更多参数请参考[easy_rec/python/protos/train.proto](./reference.md)

### 训练命令
### 损失函数

#### Local
EasyRec支持两种损失函数配置方式:1)使用单个损失函数;2)使用多个损失函数。

#### 使用单个损失函数

| 损失函数 | 说明 |
| ------------------------------------------ | ---------------------------------------------------------- |
| CLASSIFICATION | 分类Loss,二分类为sigmoid_cross_entropy;多分类为softmax_cross_entropy |
| L2_LOSS | 平方损失 |
| SIGMOID_L2_LOSS | 对sigmoid函数的结果计算平方损失 |
| CROSS_ENTROPY_LOSS | log loss 负对数损失 |
| CIRCLE_LOSS | CoMetricLearningI2I模型专用 |
| MULTI_SIMILARITY_LOSS | CoMetricLearningI2I模型专用 |
| SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING | 自动负采样版本的多分类为softmax_cross_entropy,用在二分类任务中 |
| PAIR_WISE_LOSS | 以优化全局AUC为目标的rank loss |
| F1_REWEIGHTED_LOSS | 可以调整二分类召回率和准确率相对权重的损失函数,可有效对抗正负样本不平衡问题 |

##### 配置

通过`loss_type`配置项指定使用哪个具体的损失函数,默认值为`CLASSIFICATION`

```protobuf
{
loss_type: L2_LOSS
}
```

- F1_REWEIGHTED_LOSS 的参数配置

可以调节二分类模型recall/precision相对权重的损失函数,配置如下:

```
{
loss_type: F1_REWEIGHTED_LOSS
f1_reweight_loss {
f1_beta_square: 0.5625
}
}
```

- f1_beta_square: 大于1的值会导致模型更关注recall,小于1的值会导致模型更关注precision

- F1 分数,又称平衡F分数(balanced F Score),它被定义为精确率和召回率的调和平均数。

![f1 score](../images/other/f1_score.svg)

- 更一般的,我们定义 F_beta 分数为:

![f_beta score](../images/other/f_beta_score.svg)

- f1_beta_square 即为 上述公式中的 beta 系数的平方。

* SOFTMAX_CROSS_ENTROPY_WITH_NEGATIVE_MINING
- 支持参数配置,升级为 [support vector guided softmax loss](https://128.84.21.199/abs/1812.11317)
- 目前只在DropoutNet模型中可用,可参考《 [冷启动推荐模型DropoutNet深度解析与改进](https://zhuanlan.zhihu.com/p/475117993) 》。

#### 使用多个损失函数

目前所有排序模型和部分召回模型(如DropoutNet)支持同时使用多个损失函数,并且可以为每个损失函数配置不同的权重。

##### 配置

下面的配置可以同时使用`F1_REWEIGHTED_LOSS``PAIR_WISE_LOSS`,总的loss为这两个损失函数的加权求和。

```
losses {
loss_type: F1_REWEIGHTED_LOSS
weight: 1.0
}
losses {
loss_type: PAIR_WISE_LOSS
weight: 1.0
}
f1_reweight_loss {
f1_beta_square: 0.5625
}
```

排序模型同时使用多个损失函数的完整示例:
[cmbf_with_multi_loss.config](https://github.com/alibaba/EasyRec/blob/master/samples/model_config/cmbf_with_multi_loss.config)

## 训练命令

### Local

```bash
python -m easy_rec.python.train_eval --pipeline_config_path dwd_avazu_ctr_deepmodel.config
Expand All @@ -90,7 +172,7 @@ python -m easy_rec.python.train_eval --pipeline_config_path dwd_avazu_ctr_deepmo
--edit_config_json='{"train_config.fine_tune_checkpoint": "oss://easyrec/model.ckpt-50"}'
```

#### On PAI
### On PAI

```sql
pai -name easy_rec_ext -project algo_public
Expand Down Expand Up @@ -126,7 +208,7 @@ pai -name easy_rec_ext -project algo_public
- 如果是pai内部版,则不需要指定arn和ossHost, arn和ossHost放在-Dbuckets里面
- -Dbuckets=oss://easyrec/?role_arn=acs:ram::xxx:role/ev-ext-test-oss&host=oss-cn-beijing-internal.aliyuncs.com

#### On EMR
### On EMR

单机单卡模式:

Expand Down
4 changes: 2 additions & 2 deletions docs/source/vector_retrieve.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pai -name easy_rec_ext -project algo_public_dev
```sql
create table doc_table(pk BIGINT,vector string) partitioned by (pt string);

INSERT OVERWRITE TABLE query_table PARTITION(pt='20190410')
INSERT OVERWRITE TABLE doc_table PARTITION(pt='20190410')
VALUES
(1, '0.1,0.2,-0.4,0.5'),
(2, '-0.1,0.8,0.4,0.5'),
Expand All @@ -58,7 +58,7 @@ VALUES
```sql
create table query_table(pk BIGINT,vector string) partitioned by (pt string);

INSERT OVERWRITE TABLE doc_table PARTITION(pt='20190410')
INSERT OVERWRITE TABLE query_table PARTITION(pt='20190410')
VALUES
(1, '0.1,0.2,0.4,0.5'),
(2, '-0.1,0.2,0.4,0.5'),
Expand Down

0 comments on commit 5cd039b

Please sign in to comment.