-
Notifications
You must be signed in to change notification settings - Fork 648
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add dsin 【论文复现赛第六期】 #750
Merged
Merged
Add dsin 【论文复现赛第六期】 #750
Changes from 10 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
931416a
add rank model DSIN's codes
Li-fAngyU 38387cb
add sample data of dsin
Li-fAngyU 0ea3ebc
Merge branch 'PaddlePaddle:master' into master
Li-fAngyU a3be402
updata dsin config file name
Li-fAngyU 979f289
add dsin TIPC config file paddle_infer.py
Li-fAngyU 854e073
Merge branch 'master' of /~https://github.com/Li-fAngyU/PaddleRec_DSIN
Li-fAngyU 0dd564e
delete redundant config file of rank model dsin.
Li-fAngyU 89598c3
update test_tipc/configs/dsin/train_infer_python.txt file
Li-fAngyU dc7d19b
Merge branch 'PaddlePaddle:master' into master
Li-fAngyU 56cb7bb
modify file test_tipc/configs/dsin/paddle_infer.py, in order to fix m…
Li-fAngyU 2da219e
Merge branch 'master' into master
Li-fAngyU 85e5b7f
updata readme file and fix config.yaml of dsin model.
Li-fAngyU File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
mkdir raw_data | ||
cd raw_data | ||
wget https://paddlerec.bj.bcebos.com/datasets/dmr/user_profile.csv.tar.gz | ||
tar -zxvf user_profile.csv.tar.gz | ||
wget https://paddlerec.bj.bcebos.com/datasets/dmr/raw_sample.csv.tar.gz | ||
tar -zxvf raw_sample.csv.tar.gz | ||
wget https://paddlerec.bj.bcebos.com/datasets/dmr/behavior_log.csv.tar.gz | ||
tar -zxvf behavior_log.csv.tar.gz | ||
wget https://paddlerec.bj.bcebos.com/datasets/dmr/ad_feature.csv.tar.gz | ||
tar -zxvf ad_feature.csv.tar.gz |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Ali_Display_Ad_Click数据集 | ||
[Ali_Display_Ad_Click](https://tianchi.aliyun.com/dataset/dataDetail?dataId=56)是阿里巴巴提供的一个淘宝展示广告点击率预估数据集 | ||
|
||
## 原始数据集介绍 | ||
- 原始样本骨架raw_sample:淘宝网站中随机抽样了114万用户8天内的广告展示/点击日志(2600万条记录),构成原始的样本骨架 | ||
1. user:脱敏过的用户ID; | ||
2. adgroup_id:脱敏过的广告单元ID; | ||
3. time_stamp:时间戳; | ||
4. pid:资源位; | ||
5. nonclk:为1代表没有点击;为0代表点击; | ||
6. clk:为0代表没有点击;为1代表点击; | ||
|
||
``` | ||
user,time_stamp,adgroup_id,pid,nonclk,clk | ||
581738,1494137644,1,430548_1007,1,0 | ||
``` | ||
|
||
- 广告基本信息表ad_feature:本数据集涵盖了raw_sample中全部广告的基本信息 | ||
1. adgroup_id:脱敏过的广告ID; | ||
2. cate_id:脱敏过的商品类目ID; | ||
3. campaign_id:脱敏过的广告计划ID; | ||
4. customer: 脱敏过的广告主ID; | ||
5. brand:脱敏过的品牌ID; | ||
6. price: 宝贝的价格 | ||
``` | ||
adgroup_id,cate_id,campaign_id,customer,brand,price | ||
63133,6406,83237,1,95471,170.0 | ||
``` | ||
|
||
- 用户基本信息表user_profile:本数据集涵盖了raw_sample中全部用户的基本信息 | ||
1. userid:脱敏过的用户ID; | ||
2. cms_segid:微群ID; | ||
3. cms_group_id:cms_group_id; | ||
4. final_gender_code:性别 1:男,2:女; | ||
5. age_level:年龄层次; 1234 | ||
6. pvalue_level:消费档次,1:低档,2:中档,3:高档; | ||
7. shopping_level:购物深度,1:浅层用户,2:中度用户,3:深度用户 | ||
8. occupation:是否大学生 ,1:是,0:否 | ||
9. new_user_class_level:城市层级 | ||
``` | ||
userid,cms_segid,cms_group_id,final_gender_code,age_level,pvalue_level,shopping_level,occupation,new_user_class_level | ||
234,0,5,2,5,,3,0,3 | ||
``` | ||
|
||
- 用户的行为日志behavior_log:本数据集涵盖了raw_sample中全部用户22天内的购物行为 | ||
1. user:脱敏过的用户ID; | ||
2. time_stamp:时间戳; | ||
3. btag:行为类型, 包括以下四种:(pv:浏览),(cart:加入购物车),(fav:喜欢),(buy:购买) | ||
4. cate:脱敏过的商品类目id; | ||
5. brand: 脱敏过的品牌id; | ||
``` | ||
user,time_stamp,btag,cate,brand | ||
558157,1493741625,pv,6250,91286 | ||
``` | ||
|
||
## 预处理数据集介绍 | ||
对原始数据集中的四个文件,参考[原论文的数据预处理过程](/~https://github.com/shenweichen/DSIN/tree/master/code)对数据进行处理,形成满足DSIN论文条件且可以被reader直接读取的数据集。 | ||
数据集共有八个pkl文件,训练集和测试集各自拥有四个,以训练集为例,这四个文件为train_feat_input.pkl、train_sess_input、train_sess_length和train_label.pkl。各自存储了按0.25的采样比进行采样后的user及item特征输入,用户会话特征输入、用户会话长度和标签数据。 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
mkdir big_train | ||
mkdir big_test | ||
wget -O model_input.tar.gz https://bj.bcebos.com/v1/ai-studio-online/53e61a9bcfc54e0581044883d0f876d9841cb4d0a68848f1a1d568a84591da6f?responseContentDisposition=attachment%3B%20filename%3Dmodel_input.tar.gz&authorization=bce-auth-v1%2F0ef6765c1e494918bc0d4c3ca3e5c6d1%2F2022-04-21T01%3A43%3A00Z%2F-1%2F%2F665a728726f0569e1ef9dd423adfa40a2a5e798f86a8d5d68804a2f21cc03624 | ||
tar -zxvf model_input.tar.gz | ||
mv model_input/test_feat_input.pkl big_test/ | ||
mv model_input/test_label.pkl big_test/ | ||
mv model_input/test_sess_input.pkl big_test/ | ||
mv model_input/test_session_length.pkl big_test/ | ||
mv model_input/train_feat_input.pkl big_train/ | ||
mv model_input/train_label.pkl big_train/ | ||
mv model_input/train_sess_input.pkl big_train/ | ||
mv model_input/train_session_length.pkl big_train/ |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
runner: | ||
train_data_dir: "data/sample_data" | ||
train_reader_path: "dsin_reader" # importlib format | ||
use_gpu: True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. demo数据下应关闭gpu |
||
use_auc: True | ||
train_batch_size: 64 | ||
epochs: 1 | ||
print_interval: 10 | ||
# model_init_path: "output_model_dmr/0" # init model | ||
model_save_path: "output_model_dsin" | ||
test_data_dir: "data/sample_data" | ||
infer_reader_path: "dsin_reader" # importlib format | ||
infer_batch_size: 64 | ||
infer_load_path: "output_model_dsin" | ||
infer_start_epoch: 0 | ||
infer_end_epoch: 1 | ||
|
||
# hyper parameters of user-defined network | ||
hyper_parameters: | ||
# optimizer config | ||
optimizer: | ||
class: Adam | ||
learning_rate: 0.002 | ||
# user feature size | ||
user_size: 265442 | ||
cms_segid_size: 97 | ||
cms_group_size: 13 | ||
final_gender_size: 2 | ||
age_level_size: 7 | ||
pvalue_level_size: 4 | ||
shopping_level_size: 3 | ||
occupation_size: 2 | ||
new_user_class_level_size: 5 | ||
|
||
# item feature size | ||
adgroup_size: 512431 | ||
cate_size: 12974 #max value + 1 | ||
campaign_size: 309448 | ||
customer_size: 195841 | ||
brand_size: 461499 #max value + 1 | ||
|
||
# context feature size | ||
pid_size: 2 | ||
|
||
# embedding size | ||
feat_embed_size: 4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
runner: | ||
train_data_dir: "../../../datasets/Ali_Display_Ad_Click_DSIN/big_train" | ||
train_reader_path: "dsin_reader" # importlib format | ||
use_gpu: True | ||
use_auc: True | ||
train_batch_size: 4096 | ||
epochs: 1 | ||
print_interval: 50 | ||
|
||
model_save_path: "output_model_all_dsin" | ||
test_data_dir: "../../../datasets/Ali_Display_Ad_Click_DSIN/big_test" | ||
infer_reader_path: "dsin_reader" # importlib format | ||
infer_batch_size: 16384 # 2**14 | ||
infer_load_path: "output_model_all_dsin" | ||
infer_start_epoch: 0 | ||
infer_end_epoch: 1 | ||
|
||
# hyper parameters of user-defined network | ||
hyper_parameters: | ||
# optimizer config | ||
optimizer: | ||
class: Adam | ||
learning_rate: 0.00235 | ||
# user feature size | ||
user_size: 265442 | ||
cms_segid_size: 97 | ||
cms_group_size: 13 | ||
final_gender_size: 2 | ||
age_level_size: 7 | ||
pvalue_level_size: 4 | ||
shopping_level_size: 3 | ||
occupation_size: 2 | ||
new_user_class_level_size: 5 | ||
|
||
# item feature size | ||
adgroup_size: 512431 | ||
cate_size: 11859 #max value + 1 | ||
campaign_size: 309448 | ||
customer_size: 195841 | ||
brand_size: 362855 #max value + 1 | ||
|
||
# context feature size | ||
pid_size: 2 | ||
|
||
# embedding size | ||
feat_embed_size: 4 |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from __future__ import print_function | ||
import numpy as np | ||
|
||
from paddle.io import IterableDataset | ||
import pandas as pd | ||
|
||
sparse_features = [ | ||
'userid', 'adgroup_id', 'pid', 'cms_segid', 'cms_group_id', | ||
'final_gender_code', 'age_level', 'pvalue_level', 'shopping_level', | ||
'occupation', 'new_user_class_level ', 'campaign_id', 'customer', | ||
'cate_id', 'brand' | ||
] | ||
|
||
dense_features = ['price'] | ||
|
||
|
||
class RecDataset(IterableDataset): | ||
def __init__(self, file_list, config): | ||
super().__init__() | ||
self.file_list = file_list | ||
data_file = [f.split('/')[-1] for f in file_list] | ||
mode = data_file[0].split('_')[0] | ||
data_dir = file_list[0].split(data_file[0])[0] | ||
assert (mode == 'train' or mode == 'test' or mode == 'sample' | ||
), f"mode must be 'train' or 'test', but get '{mode}'" | ||
feat_input = pd.read_pickle(data_dir + mode + '_feat_input.pkl') | ||
self.sess_input = pd.read_pickle(data_dir + mode + '_sess_input.pkl') | ||
self.sess_length = pd.read_pickle(data_dir + mode + | ||
'_session_length.pkl') | ||
self.label = pd.read_pickle(data_dir + mode + '_label.pkl') | ||
if str(type(self.label)).split("'")[1] != 'numpy.ndarray': | ||
self.label = self.label.to_numpy() | ||
self.label = self.label.astype('int64') | ||
self.num_samples = self.label.shape[0] | ||
self.sparse_input = feat_input[sparse_features].to_numpy().astype( | ||
'int64') | ||
self.dense_input = feat_input[dense_features].to_numpy().reshape( | ||
-1).astype('float32') | ||
|
||
def __iter__(self): | ||
for i in range(self.num_samples): | ||
yield [ | ||
self.sparse_input[i, :], self.dense_input[i], | ||
self.sess_input[i, :, :], self.sess_length[i], self.label[i] | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import paddle | ||
import paddle.nn as nn | ||
import paddle.nn.functional as F | ||
import math | ||
|
||
import net | ||
|
||
|
||
class DygraphModel(): | ||
# define model | ||
def create_model(self, config): | ||
user_size = config.get("hyper_parameters.user_size") | ||
cms_segid_size = config.get("hyper_parameters.cms_segid_size") | ||
cms_group_size = config.get("hyper_parameters.cms_group_size") | ||
final_gender_size = config.get("hyper_parameters.final_gender_size") | ||
age_level_size = config.get("hyper_parameters.age_level_size") | ||
pvalue_level_size = config.get("hyper_parameters.pvalue_level_size") | ||
shopping_level_size = config.get( | ||
"hyper_parameters.shopping_level_size") | ||
occupation_size = config.get("hyper_parameters.occupation_size") | ||
new_user_class_level_size = config.get( | ||
"hyper_parameters.new_user_class_level_size") | ||
adgroup_size = config.get("hyper_parameters.adgroup_size") | ||
cate_size = config.get("hyper_parameters.cate_size") | ||
campaign_size = config.get("hyper_parameters.campaign_size") | ||
customer_size = config.get("hyper_parameters.customer_size") | ||
brand_size = config.get("hyper_parameters.brand_size") | ||
pid_size = config.get("hyper_parameters.pid_size") | ||
feat_embed_size = config.get("hyper_parameters.feat_embed_size") | ||
|
||
dsin_model = net.DSIN_layer( | ||
user_size, | ||
adgroup_size, | ||
pid_size, | ||
cms_segid_size, | ||
cms_group_size, | ||
final_gender_size, | ||
age_level_size, | ||
pvalue_level_size, | ||
shopping_level_size, | ||
occupation_size, | ||
new_user_class_level_size, | ||
campaign_size, | ||
customer_size, | ||
cate_size, | ||
brand_size, | ||
sparse_embed_size=feat_embed_size, | ||
l2_reg_embedding=1e-6) | ||
|
||
return dsin_model | ||
|
||
# define loss function by predicts and label | ||
def create_loss(self, pred, label): | ||
return paddle.nn.BCELoss()(pred, label) | ||
|
||
# define feeds which convert numpy of batch data to paddle.tensor | ||
def create_feeds(self, batch_data, config): | ||
data, label = (batch_data[0], batch_data[1], batch_data[2], | ||
batch_data[3]), batch_data[-1] | ||
#data, label = batch_data[0], batch_data[1] | ||
label = label.reshape([-1, 1]) | ||
return label, data | ||
|
||
# define optimizer | ||
def create_optimizer(self, dy_model, config): | ||
lr = config.get("hyper_parameters.optimizer.learning_rate", 0.001) | ||
optimizer = paddle.optimizer.Adam( | ||
learning_rate=lr, parameters=dy_model.parameters()) | ||
return optimizer | ||
|
||
# define metrics such as auc/acc | ||
# multi-task need to define multi metric | ||
def create_metrics(self): | ||
metrics_list_name = ["auc"] | ||
auc_metric = paddle.metric.Auc("ROC") | ||
metrics_list = [auc_metric] | ||
return metrics_list, metrics_list_name | ||
|
||
# construct train forward phase | ||
def train_forward(self, dy_model, metrics_list, batch_data, config): | ||
label, input_tensor = self.create_feeds(batch_data, config) | ||
|
||
pred = dy_model.forward(input_tensor) | ||
# update metrics | ||
predict_2d = paddle.concat(x=[1 - pred, pred], axis=1) | ||
metrics_list[0].update(preds=predict_2d.numpy(), labels=label.numpy()) | ||
loss = self.create_loss(pred, paddle.cast(label, "float32")) | ||
print_dict = {'loss': loss} | ||
# print_dict = None | ||
return loss, metrics_list, print_dict | ||
|
||
def infer_forward(self, dy_model, metrics_list, batch_data, config): | ||
label, input_tensor = self.create_feeds(batch_data, config) | ||
|
||
pred = dy_model.forward(input_tensor) | ||
# update metrics | ||
predict_2d = paddle.concat(x=[1 - pred, pred], axis=1) | ||
metrics_list[0].update(preds=predict_2d.numpy(), labels=label.numpy()) | ||
|
||
return metrics_list, None |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
文件可以改名为run.sh,和其他数据集保持一致。同时记得修改readme中的运行方式
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里我是和DMR model 的dataset:Ali_Display_Ad_Click对齐的,因为数据集一致。