
Commit

modify dpo prompt
Yu Wang committed Nov 12, 2024
1 parent 3318c92 commit faa2398
Showing 13 changed files with 3,240 additions and 3,273 deletions.
12 changes: 7 additions & 5 deletions cli/rm_data_trans.py
@@ -3,7 +3,7 @@
 import random

 if __name__ == "__main__":
-    df = pd.read_json("data/gan/test_data_sft.json").to_dict(orient="records")
+    df = pd.read_json("data/dpo/test_data_sft.json").to_dict(orient="records")
     random.shuffle(df)
     res = []
     for item in df:
@@ -12,15 +12,17 @@
                 "conversations": [
                     {
                         "from": "human",
-                        "value": f"""Your task is to generate a 'lung cancer' '{item['report_type']}' from the medical entities: {json.dumps(item['units'],ensure_ascii=False)}.\
-The rules from medical entities to the material are as follows: 1. A key in the medical entities is a medical entity type, and its value is the concrete value for that type; a date value means an examination or medication took place on that date;
-2. The generated material must contain every non-empty, non-NA key-value pair from the medical entities"""
+                        "value": f"""Your task is to generate a 'lung cancer' '{item['report_type']}' from the medical entity dictionary: {json.dumps(item['units'],ensure_ascii=False)}.\
+The generation rules from medical entities to the {item['report_type']} are as follows:\
+1. A key in the medical entity dictionary is a medical entity type, and its value is the concrete value for that type;\
+2. The generated material must contain all of the values in the medical entities;\
+3. Key-value pairs whose value is NA or empty are ignored""",
                     }
                 ],
                 "chosen": {"from": "gpt", "value": item["ocr"]},
                 "rejected": {"from": "gpt", "value": item["sft_ocr"]},
             }
         )

-    with open("data/gan/test_data_rm.json", "w", encoding="utf-8") as f:
+    with open("data/dpo/test_data_rm.json", "w", encoding="utf-8") as f:
         json.dump(res, f, ensure_ascii=False, indent=4)
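
Each element appended to res is thus one preference pair in ShareGPT ranking format: the human turn carries the generation prompt, the original report (item["ocr"]) becomes the chosen response, and the earlier SFT model's generation (item["sft_ocr"]) becomes the rejected one. One record then looks roughly like this (the values are placeholders, not real data):

{
    "conversations": [
        {"from": "human", "value": "Your task is to generate a 'lung cancer' ..."}
    ],
    "chosen": {"from": "gpt", "value": "<ground-truth report text>"},
    "rejected": {"from": "gpt", "value": "<SFT model's generated report>"}
}
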
2 changes: 1 addition & 1 deletion data/dataset_info.json
@@ -678,7 +678,7 @@
       "file_name": "gan/train_data.json"
     },
     "reward": {
-      "file_name": "gan/test_data_rm.json",
+      "file_name": "dpo/test_data_rm.json",
       "ranking": true,
       "formatting": "sharegpt",
       "columns": {
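
The "columns" block is cut off above, so it is worth noting how LLaMA-Factory consumes a ranking dataset: "ranking": true marks the file as preference data, and the columns mapping ties the JSON fields written by cli/rm_data_trans.py to the ShareGPT schema. The full entry plausibly reads as follows (a sketch based on LLaMA-Factory's documented dataset_info conventions, not the verbatim file):

"reward": {
  "file_name": "dpo/test_data_rm.json",
  "ranking": true,
  "formatting": "sharegpt",
  "columns": {
    "messages": "conversations",
    "chosen": "chosen",
    "rejected": "rejected"
  }
}
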
File renamed without changes.
6,380 changes: 3,190 additions & 3,190 deletions data/gan/test_data_rm.json → data/dpo/test_data_rm.json

Large diffs are not rendered by default.

File renamed without changes.
File renamed without changes.
File renamed without changes.
@@ -1,13 +1,13 @@
 ### model
-model_name_or_path: /mnt/windows/Users/Admin/LLM/models/qwen/qwen-rlhf/sft
+model_name_or_path: ../wwyuuuu/Qwen2___5-SFT

 ### method
 stage: dpo
 do_train: true
 finetuning_type: lora
-lora_target: all
+lora_target: q_proj,v_proj
 pref_beta: 0.1
-pref_loss: orpo # [sigmoid (dpo), orpo, simpo]
+pref_loss: sigmoid # [sigmoid (dpo), orpo, simpo]

 ### dataset
 dataset: reward
@@ -18,7 +18,7 @@ overwrite_cache: true
 preprocessing_num_workers: 16

 ### output
-output_dir: /mnt/windows/Users/Admin/LLM/models/qwen/qwen-rlhf/dpo
+output_dir: ../wwyuuuu/dpo
 logging_steps: 100
 save_steps: 400
 plot_loss: true
@@ -27,8 +27,8 @@ overwrite_output_dir: true
 ### train
 per_device_train_batch_size: 2
 gradient_accumulation_steps: 1
-learning_rate: 5.0e-6
-num_train_epochs: 5.0
+learning_rate: 1.0e-5
+num_train_epochs: 4.0
 lr_scheduler_type: cosine
 warmup_ratio: 0.1
 bf16: true
@@ -40,3 +40,5 @@ val_size: 0.1
 per_device_eval_batch_size: 1
 eval_strategy: steps
 eval_steps: 100
+lora_rank: 32
+lora_dropout: 0.05
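
Two of these changes define the training objective itself: switching pref_loss from orpo to sigmoid selects the standard DPO loss, in which pref_beta is the β that scales the implicit reward margin between the chosen response y_w and the rejected response y_l:

\mathcal{L}_{\mathrm{DPO}}(\theta) = -\log \sigma\left(\beta\left[\log\frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \log\frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right]\right)

With LoRA fine-tuning, the frozen base model can serve as the reference policy π_ref, and narrowing lora_target from all to q_proj,v_proj at lora_rank: 32 restricts the trainable parameters to the attention query and value projections.
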
File renamed without changes.
File renamed without changes.
File renamed without changes.
50 changes: 0 additions & 50 deletions examples/train_lora/internlm_full_dose_sft.yaml

This file was deleted.

55 changes: 34 additions & 21 deletions tests/eval/gan.py
@@ -1,4 +1,3 @@
-# Model loading
 from datetime import datetime
 from llamafactory.chat import ChatModel
 from llamafactory.extras.misc import torch_gc
@@ -7,48 +6,62 @@
 from data_augmentation import sft_prompt
 import pandas as pd

-# 20B = 20 billion tokens; 1M Chinese characters (Dream of the Red Chamber)
-
-print("*****************Running evaluation test************************")
+print("********************Running evaluation test************************")
 # the internlm model works better
 args = dict(
     do_sample=True,
-    model_name_or_path="/mnt/windows/Users/Admin/LLM/models/qwen/qwen-rlhf/sft",
-    # adapter_name_or_path="/mnt/windows/Users/Admin/LLM/models/qwen/gan/Qwen2_5-7B-gan",  # load the previously saved LoRA adapter
+    model_name_or_path="..wwyuuuu/Qwen2___5-SFT",
+    adapter_name_or_path="../DPO/wwyuuuu/dpo",  # load the previously saved LoRA adapter
     template="qwen",  # keep consistent with training
     finetuning_type="lora",  # keep consistent with training
     # quantization_bit=4,
-    temperature=0.95,
+    temperature=0.3,
     top_p=0.7,
     max_new_tokens=1000,
     repetition_penalty=1.2,
 )

 if __name__ == "__main__":
     torch_gc()
     chat_model = ChatModel(args)
-    results = []
-    path = "/root/LLM/LLaMA-Factory/data/gan/test_data_sft.json"
-    output_file_path = "/root/LLM/LLaMA-Factory/data/gan/dpo.json"
-    ori_data = pd.read_json(path).to_dict(orient="records")

+    results = pd.read_json("/DPO/LLaMA-Factory/data/dpo/test_data_sft.json").to_dict(
+        orient="records"
+    )
+
+    output_file_path = "/DPO/LLaMA-Factory/data/dpo/test_data_dpo.json"
+
+    # Initialize, or load any results saved so far
+    try:
+        with open(output_file_path, 'r', encoding='utf-8') as file:
+            saved_results = json.load(file)
+    except (FileNotFoundError, json.JSONDecodeError):
+        saved_results = []
+
     i = 0
-    for item in ori_data:
+    for item in results:
         i += 1
         print(i)
-        # Correctly wrap into a messages list
-        instruction = f"Your task is to generate a report of type: {item['report_type']} from the given medical entities; the report must contain all values in the medical entities. Medical entities: {json.dumps(item['units'],ensure_ascii=False)}"
-
-        # Wrap the instruction string into a message list and pass it to chat_model
+        # Correctly wrap into a message list
+        instruction = f"""Your task is to generate a 'lung cancer' '{item['report_type']}' from the medical entity dictionary: {json.dumps(item['units'],ensure_ascii=False)}.\
+The generation rules from medical entities to the {item['report_type']} are as follows:\
+1. A key in the medical entity dictionary is a medical entity type, and its value is the concrete value for that type;\
+2. The generated material must contain all of the values in the medical entities;\
+3. Key-value pairs whose value is NA or empty are ignored"""
         messages_2 = [{"role": "user", "content": instruction}]
         response = ""

         # Generate the model output
         for new_text in chat_model.stream_chat(messages_2):
             print(new_text, end="")
             response += new_text
         print()

-        # Store the generated OCR result in the new field sft_ocr
+        # Store the generated result in the new field sft_dpo
+        item["dpo_ocr"] = response
+        saved_results.append(item)

-        results.append(item)
         # Save the results after each iteration to the JSON file in real time
-        with open(output_file_path, "w", encoding="utf-8") as file:
-            json.dump(results, file, ensure_ascii=False, indent=4)
+        with open(output_file_path, 'w', encoding='utf-8') as file:
+            json.dump(saved_results, file, ensure_ascii=False, indent=4)

     print("Run finished; results were saved in real time.")
