Merge pull request #439 from will-am/chinese_poetry
Add preprocessor for generating Chinese poetry.
Showing 6 changed files with 217 additions and 14 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,111 @@ | ||
[TBD] | ||
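# Chinese Classical Poetry Generation | ||
|
||
## Introduction | ||
This example uses an encoder-decoder neural network trained on the Complete Tang Poems (全唐诗) in a line-to-line (sequence to sequence) fashion: given one line of a poem, the model generates the next line. | ||
|
||
Both the encoder and the decoder are stacked bi-directional LSTMs, 3 layers deep by default, with an attention unit. | ||
|
||
The directory structure of this example is outlined below: | ||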
|
||
```text | ||
. | ||
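├── data # training data and dictionary | ||
│ └── download.sh # downloads the raw data | ||
├── README.md # documentation | ||
├── index.html # documentation (HTML format) | ||
├── preprocess.py # raw data preprocessing | ||
├── generate.py # poem line generation script | ||
├── network_conf.py # model definition | ||
├── reader.py # data reading interface | ||
├── train.py # training script | ||
└── utils.py # utility functions | ||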
``` | ||
|
||
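## Data Processing | ||
### Raw Data Source | ||
This example uses the Complete Tang Poems collected in the [chinese-poetry database](https://github.com/chinese-poetry/chinese-poetry) as training data, about 54,000 poems in total. | ||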
|
||
### Downloading the Raw Data | ||
```bash | ||
cd data && ./download.sh && cd .. | ||
``` | ||
### Data Preprocessing | ||
```bash | ||
python preprocess.py --datadir data/raw --outfile data/poems.txt --dictfile data/dict.txt | ||
``` | ||
|
||
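Once the script finishes, it produces the processed training data `poems.txt` and the dictionary `dict.txt`. The dictionary is built at the character level and contains only characters that occur at least 10 times. | ||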
|
||
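Each line of `poems.txt` describes one Tang poem in three tab-separated columns: title, author, and poem content. Within the poem content, lines are separated by `.`. | ||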
|
||
Example training data: | ||
```text | ||
登鸛雀樓 王之渙 白日依山盡.黃河入海流.欲窮千里目.更上一層樓 | ||
觀獵 李白 太守耀清威.乘閑弄晚暉.江沙橫獵騎.山火遶行圍.箭逐雲鴻落.鷹隨月兔飛.不知白日暮.歡賞夜方歸 | ||
晦日重宴 陳嘉言 高門引冠蓋.下客抱支離.綺席珍羞滿.文場翰藻摛.蓂華彫上月.柳色藹春池.日斜歸戚里.連騎勒金羈 | ||
``` | ||
|
||
During training, each line of a poem serves as the model input and the following line as the prediction target. | ||
|
||
|
||
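The pairing described above can be sketched as follows. This is a minimal illustration, not the `reader.py` from this commit (which is not shown here): the function name `row_to_pairs` is hypothetical, and the row layout (tab-separated title, author, content, with poem lines joined by `.`) is assumed from the README's description of `poems.txt`.

```python
# -*- coding: utf-8 -*-


def row_to_pairs(row):
    """Turn one poems.txt-style row into (input_line, target_line) pairs.

    Assumes three tab-separated columns (title, author, content), with the
    lines of the poem joined by "." in the content column.
    """
    columns = row.rstrip("\n").split("\t")
    if len(columns) != 3:
        return []  # malformed row: nothing to pair
    lines = [s for s in columns[2].split(".") if s]
    # Pair every line with the line that follows it.
    return list(zip(lines, lines[1:]))


if __name__ == "__main__":
    sample = u"登鸛雀樓\t王之渙\t白日依山盡.黃河入海流.欲窮千里目.更上一層樓"
    for source, target in row_to_pairs(sample):
        print(source + u" -> " + target)
```

A poem with n lines yields n - 1 pairs under this scheme, so the four-line sample above contributes three training examples.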
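## Model Training | ||
The command-line arguments of the training script [train.py](./train.py) can be listed with `python train.py --help`. The main arguments are: | ||
- `num_passes`: number of training passes | ||
- `batch_size`: batch size | ||
- `use_gpu`: whether to use the GPU | ||
- `trainer_count`: number of trainers, 1 by default | ||
- `save_dir_path`: directory for saving models, `models` under the current directory by default | ||
- `encoder_depth`: depth of the encoder LSTM, 3 by default | ||
- `decoder_depth`: depth of the decoder LSTM, 3 by default | ||
- `train_data_path`: path to the training data | ||
- `word_dict_path`: path to the dictionary | ||
- `init_model_path`: path to an initial model; not needed when training from scratch | ||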
|
||
### Running Training | ||
```bash | ||
python train.py \ | ||
--num_passes 50 \ | ||
--batch_size 256 \ | ||
--use_gpu True \ | ||
--trainer_count 1 \ | ||
--save_dir_path models \ | ||
--train_data_path data/poems.txt \ | ||
--word_dict_path data/dict.txt \ | ||
2>&1 | tee train.log | ||
``` | ||
After each pass, the model parameters are saved in the `models` directory. The training log is saved in `train.log`. | ||
|
||
### Optimal Model Parameters | ||
Find the pass with the lowest cost and use its model parameters for subsequent prediction. | ||
```bash | ||
python -c 'import utils; utils.find_optiaml_pass("./train.log")' | ||
``` | ||
|
||
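## Generating Poem Lines | ||
Use the [generate.py](./generate.py) script to generate the next line for an input line; its command-line arguments can be listed with `python generate.py --help`. | ||
The main arguments are: | ||
- `model_path`: trained model parameter file | ||
- `word_dict_path`: path to the dictionary | ||
- `test_data_path`: path to the input data | ||
- `batch_size`: batch size, 1 by default | ||
- `beam_size`: beam width used in beam search, 5 by default | ||
- `save_file`: path for saving the output | ||
- `use_gpu`: whether to use the GPU | ||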
|
||
### Running Generation | ||
For example, save the line `孤帆遠影碧空盡` in the file `input.txt` as the input for predicting the next line, then run: | ||
```bash | ||
python generate.py \ | ||
--model_path models/pass_00049.tar.gz \ | ||
--word_dict_path data/dict.txt \ | ||
--test_data_path input.txt \ | ||
--save_file output.txt | ||
``` | ||
The results are saved in `output.txt`. For the example input above, the generated lines are: | ||
```text | ||
-9.6987 萬 壑 清 風 黃 葉 多 | ||
-10.0737 萬 里 遠 山 紅 葉 深 | ||
-10.4233 萬 壑 清 波 紅 一 流 | ||
-10.4802 萬 壑 清 風 黃 葉 深 | ||
-10.9060 萬 壑 清 風 紅 葉 多 | ||
``` |
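The sample output above can be read back into (score, line) pairs with a few lines of Python. This is only a sketch under assumptions: the layout (a numeric score, whitespace, then whitespace-separated characters) is inferred from the sample shown, the helper name `parse_candidates` is hypothetical, and treating a higher score as better is an assumption based on the order in which the sample lists its candidates.

```python
# -*- coding: utf-8 -*-


def parse_candidates(text):
    """Parse generated output rows into (score, line) tuples, best first.

    Each row is assumed to look like "<score> <char> <char> ...".
    """
    candidates = []
    for row in text.splitlines():
        parts = row.split(None, 1)  # split off the leading score
        if len(parts) != 2:
            continue  # skip blank or malformed rows
        score, tokens = parts
        # Drop the spaces between characters to recover the line.
        candidates.append((float(score), u"".join(tokens.split())))
    # Assumption: a higher score marks a better candidate.
    return sorted(candidates, key=lambda c: c[0], reverse=True)


if __name__ == "__main__":
    sample = u"-10.0737 萬 里 遠 山 紅 葉 深\n-9.6987 萬 壑 清 風 黃 葉 多\n"
    for score, line in parse_candidates(sample):
        print(u"%.4f %s" % (score, line))
```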
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
#!/bin/bash | ||
|
||
git clone https://github.com/chinese-poetry/chinese-poetry.git | ||
|
||
if [ ! -d raw ] | ||
then | ||
    mkdir raw | ||
fi | ||
|
||
mv chinese-poetry/json/poet.tang.* raw/ | ||
rm -rf chinese-poetry |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# -*- coding: utf-8 -*- | ||
import os | ||
import io | ||
import re | ||
import json | ||
import click | ||
import collections | ||
|
||
|
||
def build_vocabulary(dataset, cutoff=0): | ||
    # Count how often each character occurs across all poem lines. | ||
    dictionary = collections.defaultdict(int) | ||
    for data in dataset: | ||
        for sent in data[2]: | ||
            for char in sent: | ||
                dictionary[char] += 1 | ||
    # Keep characters seen at least `cutoff` times, most frequent first. | ||
    dictionary = filter(lambda x: x[1] >= cutoff, dictionary.items()) | ||
    dictionary = sorted(dictionary, key=lambda x: (-x[1], x[0])) | ||
    vocab, _ = list(zip(*dictionary)) | ||
    return (u"<s>", u"<e>", u"<unk>") + vocab | ||
|
||
|
||
@click.command("preprocess") | ||
@click.option("--datadir", type=str, help="Path to raw data") | ||
@click.option("--outfile", type=str, help="Path to save the training data") | ||
@click.option("--dictfile", type=str, help="Path to save the dictionary file") | ||
def preprocess(datadir, outfile, dictfile): | ||
    dataset = [] | ||
    # Patterns for annotations and editorial notes to strip from the poem text. | ||
    note_pattern1 = re.compile(u"（.*?）", re.U) | ||
    note_pattern2 = re.compile(u"〖.*?〗", re.U) | ||
    note_pattern3 = re.compile(u"-.*?-。?", re.U) | ||
    note_pattern4 = re.compile(u"（.*$", re.U) | ||
    note_pattern5 = re.compile(u"。。.*）$", re.U) | ||
    note_pattern6 = re.compile(u"。。", re.U) | ||
    note_pattern7 = re.compile(u"[《》「」\[\]]", re.U) | ||
    print("Load raw data...") | ||
    for fn in os.listdir(datadir): | ||
        with io.open(os.path.join(datadir, fn), "r", encoding="utf8") as f: | ||
            for data in json.load(f): | ||
                title = data['title'] | ||
                author = data['author'] | ||
                p = "".join(data['paragraphs']) | ||
                p = "".join(p.split()) | ||
                p = note_pattern1.sub(u"", p) | ||
                p = note_pattern2.sub(u"", p) | ||
                p = note_pattern3.sub(u"", p) | ||
                p = note_pattern4.sub(u"", p) | ||
                p = note_pattern5.sub(u"。", p) | ||
                p = note_pattern6.sub(u"。", p) | ||
                p = note_pattern7.sub(u"", p) | ||
                # Skip poems that are empty or still contain unusual punctuation or placeholder glyphs. | ||
                if (p == u"" or u"{" in p or u"}" in p or u"｛" in p or | ||
                        u"｝" in p or u"、" in p or u"：" in p or u"；" in p or | ||
                        u"！" in p or u"？" in p or u"●" in p or u"□" in p or | ||
                        u"囗" in p or u"）" in p): | ||
                    continue | ||
                # Split into individual lines on full-width periods and commas. | ||
                paragraphs = re.split(u"。|，", p) | ||
                paragraphs = [x for x in paragraphs if x] | ||
                if len(paragraphs) > 1: | ||
                    dataset.append((title, author, paragraphs)) | ||
|
||
    print("Construct vocabularies...") | ||
    vocab = build_vocabulary(dataset, cutoff=10) | ||
    with io.open(dictfile, "w", encoding="utf8") as f: | ||
        for v in vocab: | ||
            f.write(v + "\n") | ||
|
||
    print("Write processed data...") | ||
    with io.open(outfile, "w", encoding="utf8") as f: | ||
        for data in dataset: | ||
            title = data[0] | ||
            author = data[1] | ||
            paragraphs = ".".join(data[2]) | ||
            f.write("\t".join((title, author, paragraphs)) + "\n") | ||
|
||
|
||
if __name__ == "__main__": | ||
    preprocess() |
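The frequency cutoff and ordering used by `build_vocabulary` can be tried on a toy input. The sketch below is not part of this commit: `build_vocab` is a hypothetical stand-in that takes a flat list of lines instead of the `(title, author, paragraphs)` tuples, and it uses `collections.Counter` rather than a `defaultdict`. The ordering rule (descending count, ties broken by character) and the three special tokens follow the code above.

```python
# -*- coding: utf-8 -*-
import collections


def build_vocab(lines, cutoff=0):
    """Return special tokens followed by characters seen >= cutoff times."""
    # Count every character across all lines.
    counts = collections.Counter(ch for line in lines for ch in line)
    kept = [(ch, n) for ch, n in counts.items() if n >= cutoff]
    # Most frequent first; ties are broken by the character itself.
    kept.sort(key=lambda item: (-item[1], item[0]))
    return ["<s>", "<e>", "<unk>"] + [ch for ch, _ in kept]


if __name__ == "__main__":
    # "a" occurs 3 times, "b" twice, "c" once; cutoff=2 drops "c".
    print(build_vocab(["aab", "abc"], cutoff=2))
```

Because the kept characters are collected in a list rather than unpacked from `zip`, this variant returns only the three special tokens when no character reaches the cutoff.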