forked from PaddlePaddle/models
Commit
Merge pull request PaddlePaddle#477 from wanghaoshuang/fix_text_classification

fix data reader error of text classification.
Showing 10 changed files with 441 additions and 2 deletions.
@@ -0,0 +1,5 @@
# DeepFM: a click-through rate prediction model based on deep factorization machines

## Introduction

[TBD]
@@ -0,0 +1,5 @@
#!/bin/bash

# Download the Criteo Display Advertising Challenge dataset and unpack it.
wget https://s3-eu-west-1.amazonaws.com/criteo-labs/dac.tar.gz
tar zxf dac.tar.gz
rm -f dac.tar.gz
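The archive unpacks to train.txt and test.txt, which the preprocessing script later in this commit expects to find in the directory passed as --datadir.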
(Two files in this commit could not be rendered by the diff viewer.)
@@ -0,0 +1,63 @@
import gzip
import argparse
import itertools

import paddle.v2 as paddle

from network_conf import DeepFM
import reader


def parse_args():
    parser = argparse.ArgumentParser(description="PaddlePaddle DeepFM example")
    parser.add_argument(
        '--model_gz_path',
        type=str,
        required=True,
        help="path of the model parameters gz file")
    parser.add_argument(
        '--data_path',
        type=str,
        required=True,
        help="path of the dataset to infer")
    parser.add_argument(
        '--prediction_output_path',
        type=str,
        required=True,
        help="path to output the prediction")
    parser.add_argument(
        '--factor_size',
        type=int,
        default=10,
        help="the factor size for the factorization machine (default: 10)")

    return parser.parse_args()


def infer():
    args = parse_args()

    paddle.init(use_gpu=False, trainer_count=1)

    # Build the network in inference mode so it returns the prediction
    # layer instead of the cost layer.
    model = DeepFM(args.factor_size, infer=True)

    # Load the trained parameters from the gzipped tar archive.
    parameters = paddle.parameters.Parameters.from_tar(
        gzip.open(args.model_gz_path, 'r'))

    inferer = paddle.inference.Inference(
        output_layer=model, parameters=parameters)

    dataset = reader.Dataset()

    infer_reader = paddle.batch(dataset.infer(args.data_path), batch_size=1000)

    # Write one prediction per line.
    with open(args.prediction_output_path, 'w') as out:
        for batch in infer_reader():
            res = inferer.infer(input=batch)
            predictions = list(itertools.chain.from_iterable(res))
            out.write('\n'.join(map(str, predictions)) + '\n')


if __name__ == '__main__':
    infer()
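Given the required arguments defined in parse_args above, a typical invocation looks like the following; the concrete file paths are illustrative, not part of the commit:

python infer.py \
    --model_gz_path models/deepfm-pass-0.tar.gz \
    --data_path data/test.txt \
    --prediction_output_path predictions.txt \
    --factor_size 10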
@@ -0,0 +1,73 @@
import paddle.v2 as paddle

dense_feature_dim = 13
sparse_feature_dim = 117568


def fm_layer(input, factor_size, fm_param_attr):
    # First-order term: a linear transform of the raw features.
    first_order = paddle.layer.fc(
        input=input, size=1, act=paddle.activation.Linear())
    # Second-order term: pairwise feature interactions modeled with
    # latent factor vectors of length factor_size.
    second_order = paddle.layer.factorization_machine(
        input=input,
        factor_size=factor_size,
        act=paddle.activation.Linear(),
        param_attr=fm_param_attr)
    out = paddle.layer.addto(
        input=[first_order, second_order],
        act=paddle.activation.Sigmoid(),
        bias_attr=False)
    return out


def DeepFM(factor_size, infer=False):
    dense_input = paddle.layer.data(
        name="dense_input",
        type=paddle.data_type.dense_vector(dense_feature_dim))
    sparse_input = paddle.layer.data(
        name="sparse_input",
        type=paddle.data_type.sparse_binary_vector(sparse_feature_dim))
    sparse_input_ids = [
        paddle.layer.data(
            name="C" + str(i),
            type=paddle.data_type.integer_value(sparse_feature_dim))
        for i in range(1, 27)
    ]

    dense_fm = fm_layer(
        dense_input,
        factor_size,
        fm_param_attr=paddle.attr.Param(name="DenseFeatFactors"))
    sparse_fm = fm_layer(
        sparse_input,
        factor_size,
        fm_param_attr=paddle.attr.Param(name="SparseFeatFactors"))

    # The embedding table shares its parameters with the sparse FM's
    # factor matrix ("SparseFeatFactors"), so the deep part and the FM
    # part learn a common representation of the categorical features.
    def embedding_layer(input):
        return paddle.layer.embedding(
            input=input,
            size=factor_size,
            param_attr=paddle.attr.Param(name="SparseFeatFactors"))

    sparse_embed_seq = map(embedding_layer, sparse_input_ids)
    sparse_embed = paddle.layer.concat(sparse_embed_seq)

    # Deep part: a three-layer MLP over the concatenated embeddings and
    # the dense features.
    fc1 = paddle.layer.fc(
        input=[sparse_embed, dense_input],
        size=400,
        act=paddle.activation.Relu())
    fc2 = paddle.layer.fc(input=fc1, size=400, act=paddle.activation.Relu())
    fc3 = paddle.layer.fc(input=fc2, size=400, act=paddle.activation.Relu())

    predict = paddle.layer.fc(
        input=[dense_fm, sparse_fm, fc3],
        size=1,
        act=paddle.activation.Sigmoid())

    if not infer:
        label = paddle.layer.data(
            name="label", type=paddle.data_type.dense_vector(1))
        cost = paddle.layer.multi_binary_label_cross_entropy_cost(
            input=predict, label=label)
        return cost
    else:
        return predict
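For reference, fm_layer follows the standard factorization machine score; this equation is editorial background rather than text from the commit, and the sigmoid applied by the addto layer is omitted:

\[
y(\mathbf{x}) = w_0 + \sum_{i=1}^{n} w_i x_i + \sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle \mathbf{v}_i, \mathbf{v}_j \rangle \, x_i x_j
\]

The first two terms correspond to the size-1 fc layer (first_order, with w_0 as its bias), and the pairwise term is computed by paddle.layer.factorization_machine, whose latent vectors v_i of length factor_size live in the shared param_attr matrix.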
@@ -0,0 +1,147 @@
"""
Preprocess Criteo dataset. This dataset was used for the Display Advertising
Challenge (https://www.kaggle.com/c/criteo-display-ad-challenge).
"""
import os
import sys
import click
import collections

# There are 13 integer features and 26 categorical features
continous_features = range(1, 14)
categorial_features = range(14, 40)


class CategoryDictGenerator:
    """
    Generate a dictionary for each of the categorical features
    """

    def __init__(self, num_feature):
        self.dicts = []
        self.num_feature = num_feature
        for i in range(0, num_feature):
            self.dicts.append(collections.defaultdict(int))

    def build(self, datafile, categorial_features, cutoff=0):
        with open(datafile, 'r') as f:
            for line in f:
                features = line.rstrip('\n').split('\t')
                for i in range(0, self.num_feature):
                    if features[categorial_features[i]] != '':
                        self.dicts[i][features[categorial_features[i]]] += 1
        # Keep only categories seen at least `cutoff` times, ranked by
        # frequency; id 0 is reserved for unseen categories.
        for i in range(0, self.num_feature):
            self.dicts[i] = filter(lambda x: x[1] >= cutoff,
                                   self.dicts[i].items())
            self.dicts[i] = sorted(self.dicts[i], key=lambda x: (-x[1], x[0]))
            vocabs, _ = list(zip(*self.dicts[i]))
            self.dicts[i] = dict(zip(vocabs, range(1, len(vocabs) + 1)))
            self.dicts[i]['<unk>'] = 0

    def gen(self, idx, key):
        if key not in self.dicts[idx]:
            res = self.dicts[idx]['<unk>']
        else:
            res = self.dicts[idx][key]
        return res

    def dicts_sizes(self):
        return map(len, self.dicts)


class ContinuousFeatureGenerator:
    """
    Normalize the integer features to [0, 1] by min-max normalization
    """

    def __init__(self, num_feature):
        self.num_feature = num_feature
        self.min = [sys.maxint] * num_feature
        self.max = [-sys.maxint] * num_feature

    def build(self, datafile, continous_features):
        with open(datafile, 'r') as f:
            for line in f:
                features = line.rstrip('\n').split('\t')
                for i in range(0, self.num_feature):
                    val = features[continous_features[i]]
                    if val != '':
                        val = int(val)
                        self.min[i] = min(self.min[i], val)
                        self.max[i] = max(self.max[i], val)

    def gen(self, idx, val):
        if val == '':
            return 0
        val = float(val)
        return (val - self.min[idx]) / (self.max[idx] - self.min[idx])


@click.command("preprocess")
@click.option("--datadir", type=str, help="Path to the raw Criteo dataset")
@click.option("--outdir", type=str, help="Path to save the processed data")
def preprocess(datadir, outdir):
    """
    All 13 integer features are normalized to continuous values, and these
    continuous features are combined into one vector with dimension 13.
    Each of the 26 categorical features is one-hot encoded, and all the
    one-hot vectors are combined into one sparse binary vector.
    """
    dists = ContinuousFeatureGenerator(len(continous_features))
    dists.build(os.path.join(datadir, 'train.txt'), continous_features)

    dicts = CategoryDictGenerator(len(categorial_features))
    dicts.build(
        os.path.join(datadir, 'train.txt'), categorial_features, cutoff=200)

    # Offsets shift the per-feature category ids into one shared index
    # space, so all 26 features map into a single sparse binary vector.
    dict_sizes = dicts.dicts_sizes()
    categorial_feature_offset = [0]
    for i in range(1, len(categorial_features)):
        offset = categorial_feature_offset[i - 1] + dict_sizes[i - 1]
        categorial_feature_offset.append(offset)

    with open(os.path.join(outdir, 'train.txt'), 'w') as out:
        with open(os.path.join(datadir, 'train.txt'), 'r') as f:
            for line in f:
                features = line.rstrip('\n').split('\t')

                continous_vals = []
                for i in range(0, len(continous_features)):
                    val = dists.gen(i, features[continous_features[i]])
                    continous_vals.append(str(val))
                categorial_vals = []
                for i in range(0, len(categorial_features)):
                    val = (dicts.gen(i, features[categorial_features[i]]) +
                           categorial_feature_offset[i])
                    categorial_vals.append(str(val))

                continous_vals = ','.join(continous_vals)
                categorial_vals = ','.join(categorial_vals)
                label = features[0]
                out.write('\t'.join([continous_vals, categorial_vals, label])
                          + '\n')

    # The test file has no label column, so every feature index shifts
    # left by one compared with the training file.
    with open(os.path.join(outdir, 'test.txt'), 'w') as out:
        with open(os.path.join(datadir, 'test.txt'), 'r') as f:
            for line in f:
                features = line.rstrip('\n').split('\t')

                continous_vals = []
                for i in range(0, len(continous_features)):
                    val = dists.gen(i, features[continous_features[i] - 1])
                    continous_vals.append(str(val))
                categorial_vals = []
                for i in range(0, len(categorial_features)):
                    val = (dicts.gen(i, features[categorial_features[i] - 1])
                           + categorial_feature_offset[i])
                    categorial_vals.append(str(val))

                continous_vals = ','.join(continous_vals)
                categorial_vals = ','.join(categorial_vals)
                out.write('\t'.join([continous_vals, categorial_vals]) + '\n')


if __name__ == "__main__":
    preprocess()
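Since preprocess is exposed as a click command, it runs directly from the shell; the directory names below are illustrative only:

python preprocess.py --datadir ./data/raw --outdir ./data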
@@ -0,0 +1,55 @@ | ||
class Dataset:
    def _reader_creator(self, path, is_infer):
        def reader():
            with open(path, 'r') as f:
                for line in f:
                    features = line.rstrip('\n').split('\t')
                    dense_feature = map(float, features[0].split(','))
                    sparse_feature = map(int, features[1].split(','))
                    # Each sample carries the dense vector, the sparse
                    # binary vector, the 26 individual category ids
                    # (C1..C26), and, for training data, the label.
                    if not is_infer:
                        label = [float(features[2])]
                        yield ([dense_feature, sparse_feature] +
                               sparse_feature + [label])
                    else:
                        yield [dense_feature, sparse_feature] + sparse_feature

        return reader

    def train(self, path):
        return self._reader_creator(path, False)

    def infer(self, path):
        return self._reader_creator(path, True)


# Maps each data layer name to the field index in the samples yielded above.
feeding = {
    'dense_input': 0,
    'sparse_input': 1,
    'C1': 2,
    'C2': 3,
    'C3': 4,
    'C4': 5,
    'C5': 6,
    'C6': 7,
    'C7': 8,
    'C8': 9,
    'C9': 10,
    'C10': 11,
    'C11': 12,
    'C12': 13,
    'C13': 14,
    'C14': 15,
    'C15': 16,
    'C16': 17,
    'C17': 18,
    'C18': 19,
    'C19': 20,
    'C20': 21,
    'C21': 22,
    'C22': 23,
    'C23': 24,
    'C24': 25,
    'C25': 26,
    'C26': 27,
    'label': 28
}
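For context, a minimal training script wiring Dataset, feeding, and DeepFM together might look like the sketch below; the optimizer choice, batch size, pass count, and data path are assumptions for illustration, not part of this commit:

import paddle.v2 as paddle

from network_conf import DeepFM
import reader

paddle.init(use_gpu=False, trainer_count=1)

# Build the cost layer and create fresh parameters for it.
cost = DeepFM(factor_size=10)
parameters = paddle.parameters.create(cost)
optimizer = paddle.optimizer.Adam(learning_rate=1e-3)  # assumed optimizer

trainer = paddle.trainer.SGD(
    cost=cost, parameters=parameters, update_equation=optimizer)

dataset = reader.Dataset()
trainer.train(
    reader=paddle.batch(dataset.train('./data/train.txt'), batch_size=1000),
    feeding=reader.feeding,  # layer-name -> sample-field mapping from above
    num_passes=10)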