Skip to content

Commit

Permalink
upload all
Browse files Browse the repository at this point in the history
  • Loading branch information
jcwang587 committed Feb 18, 2024
1 parent abd98c7 commit ba34a0a
Showing 1 changed file with 12 additions and 20 deletions.
32 changes: 12 additions & 20 deletions xdatbus/fml01_xml2xyz.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,31 @@
from ase.io import read, write
from xdatbus.utils import filter_files

# load the vasprun.xml file
train_set = read("run1/vasprun.xml", index="::")
test_set = read("run2/vasprun.xml", index="::")

# write the atoms object to extended xyz file with forces
write("train.xyz", train_set)
write("test.xyz", test_set)


def xml2xyz(xml_dir="./", output_path="./", train_ratio=1.0):
try:
raw_list = os.listdir(xml_dir)
xml_list = filter_files(raw_list, "vasprun.xml")
xml_list = filter_files(raw_list, "vasprun")

if len(xml_list) == 0:
raise ValueError("No vasprun.xml file found in the directory.")
raise ValueError("No vasprun file found in the directory.")

xml_list_sort = sorted(xml_list, key=lambda x: int(re.findall(r"\d+", x)[0]))

data_set = []
for xml_file in xml_list_sort:
print(f"xdatbus-func | xml2xyz: Processing {xml_file}")
xml_path = os.path.join(xml_dir, xml_file)
xml_set = read(xml_path, index="::")
if train_ratio < 1.0:
train_set = xml_set[: int(len(xml_set) * train_ratio)]
test_set = xml_set[int(len(xml_set) * train_ratio) :]
train_path = os.path.join(output_path, f"{xml_file}_train.xyz")
test_path = os.path.join(output_path, f"{xml_file}_test.xyz")
write(train_path, train_set)
write(test_path, test_set)
else:
xml_path = os.path.join(output_path, f"{xml_file}.xyz")
write(xml_path, xml_set)
data_set.append(xml_set)

if train_ratio < 1.0:
train_set = data_set[: int(len(data_set) * train_ratio)]
test_set = data_set[int(len(data_set) * train_ratio) :]
write(os.path.join(output_path, "train.xyz"), train_set)
write(os.path.join(output_path, "test.xyz"), test_set)
else:
write(os.path.join(output_path, "data.xyz"), data_set)

print("xdatbus-func | xml2xyz: Done!")

Expand Down

0 comments on commit ba34a0a

Please sign in to comment.