Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor split function with tests #811

Merged
merged 18 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import pytest

pytest_plugins = ("jupyter_server.pytest_plugin",)
Expand All @@ -6,3 +8,22 @@
@pytest.fixture
def jp_server_config(jp_server_config):
    """Override the server config fixture so the `jupyter_ai` extension is on.

    The incoming `jp_server_config` value is requested but not used; the
    returned mapping replaces it entirely.
    """
    enabled_extensions = {"jupyter_ai": True}
    return {"ServerApp": {"jpserver_extensions": enabled_extensions}}


@pytest.fixture(scope="session")
def static_test_files_dir() -> Path:
    """Return the absolute path of the directory holding static test files.

    The path is built relative to the directory that contains this file:
    `packages/jupyter-ai/jupyter_ai/tests/static`.
    """
    here = Path(__file__).parent.resolve()
    return here.joinpath("packages", "jupyter-ai", "jupyter_ai", "tests", "static")


@pytest.fixture
def jp_ai_staging_dir(jp_data_dir: Path) -> Path:
    """Create and return a `scheduler_staging_area` directory in `jp_data_dir`.

    `mkdir()` is called without `exist_ok`, so the directory must not
    already exist when the fixture runs.
    """
    target = jp_data_dir.joinpath("scheduler_staging_area")
    target.mkdir()
    return target
25 changes: 16 additions & 9 deletions packages/jupyter-ai/jupyter_ai/document_loaders/directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,9 +109,13 @@ def flatten(*chunk_lists):
return list(itertools.chain(*chunk_lists))


def split(path, all_files: bool, splitter):
chunks = []

def collect_filepaths(path, all_files: bool):
"""Selects eligible files, i.e.,
1. Files not in excluded directories, and
2. Files that are in the valid file extensions list
Called from the `split` function.
Returns all the filepaths to eligible files.
"""
# Check if the path points to a single file
if os.path.isfile(path):
andrii-i marked this conversation as resolved.
Show resolved Hide resolved
filepaths = [Path(path)]
andrii-i marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -125,17 +129,20 @@ def split(path, all_files: bool, splitter):
d for d in subdirs if not (d[0] == "." or d in EXCLUDE_DIRS)
]
filenames = [f for f in filenames if not f[0] == "."]
filepaths += [Path(os.path.join(dir, filename)) for filename in filenames]
filepaths.extend([Path(dir) / filename for filename in filenames])
VALID_EXTS = {j.lower() for j in SUPPORTED_EXTS}
andrii-i marked this conversation as resolved.
Show resolved Hide resolved
filepaths = [fp for fp in filepaths if fp.suffix.lower() in VALID_EXTS]
return filepaths

for filepath in filepaths:
# Lower case everything to make sure file extension comparisons are not case sensitive
if filepath.suffix.lower() not in {j.lower() for j in SUPPORTED_EXTS}:
continue

def split(path, all_files: bool, splitter):
    """Build a lazy dask task that splits every eligible file into chunks.

    Eligible files are selected by `collect_filepaths(path, all_files)`.
    For each one, a delayed `path_to_doc` call feeds a delayed
    `split_document` call that also receives `splitter`.

    Returns a single delayed `flatten` call over all per-file chunk
    results; nothing is computed until the caller evaluates it.
    """
    # Wrap the functions per file, exactly one delayed pair per filepath.
    per_file_chunks = [
        dask.delayed(split_document)(dask.delayed(path_to_doc)(filepath), splitter)
        for filepath in collect_filepaths(path, all_files)
    ]
    return dask.delayed(flatten)(*per_file_chunks)

Expand Down
Binary file not shown.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Hidden temp text file.
10 changes: 10 additions & 0 deletions packages/jupyter-ai/jupyter_ai/tests/static/file0.html
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
<!DOCTYPE html>
<html>
<head><meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Notebook</title>
</head>
<body>
<div>This is the notebook content</div>
</body>
</html>
1 change: 1 addition & 0 deletions packages/jupyter-ai/jupyter_ai/tests/static/file1.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
This is a temp test text file.
3 changes: 3 additions & 0 deletions packages/jupyter-ai/jupyter_ai/tests/static/file2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import os

print("Hello World")
2 changes: 2 additions & 0 deletions packages/jupyter-ai/jupyter_ai/tests/static/file3.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Column1, Column2
Test1, test2
Empty file.
Binary file not shown.
56 changes: 56 additions & 0 deletions packages/jupyter-ai/jupyter_ai/tests/test_directory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os
import shutil
from pathlib import Path
from typing import Tuple

import pytest
from jupyter_ai.document_loaders.directory import collect_filepaths


@pytest.fixture
def staging_dir(static_test_files_dir, jp_ai_staging_dir) -> Path:
    """Populate a `TestDir` tree with static test files and return its root.

    Layout created under `jp_ai_staging_dir / "TestDir"`:
      - the root itself,
      - a regular subdirectory `subdir`,
      - a hidden subdirectory `.hidden_dir`.
    Each static file is copied with `shutil.copy2` into one of these.
    """
    root = jp_ai_staging_dir / "TestDir"
    subdir = root / "subdir"
    hidden_dir = root / ".hidden_dir"
    # The root must be created first since the other two live inside it.
    for directory in (root, subdir, hidden_dir):
        directory.mkdir()

    # (static file name, destination directory), listed in copy order.
    layout = [
        (".hidden_file.pdf", root),
        (".hidden_file.txt", subdir),
        ("file0.html", root),
        ("file1.txt", subdir),
        ("file2.py", subdir),
        ("file3.csv", hidden_dir),
        ("file3.xyz", subdir),
        ("file4.pdf", hidden_dir),
    ]
    for filename, destination in layout:
        shutil.copy2(static_test_files_dir / filename, destination)

    return root


def test_collect_filepaths(staging_dir):
    """
    Check which files `collect_filepaths` selects for `/learn`.

    With `all_files` disabled, the function should only return files that
    1. are not in the excluded directories, and
    2. have an extension in the valid file extensions list.
    """
    # Call the function under test with hidden/excluded filtering enabled
    collected = collect_filepaths(staging_dir, False)

    # Exactly three of the staged files are expected to be eligible
    assert len(collected) == 3

    collected_names = [fp.name for fp in collected]
    # A valid, visible file is included
    assert "file0.html" in collected_names
    # A file with an unsupported extension is excluded
    assert "file3.xyz" not in collected_names
Loading