Skip to content

Commit

Permalink
Refactor test infra (#53)
Browse files Browse the repository at this point in the history
* Refactor async code of test goldens fetcher

* ci: checkout single branch
  • Loading branch information
ochafik authored Feb 9, 2025
1 parent 7eb5202 commit ff6f9a0
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 32 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ jobs:
- name: Clone
uses: actions/checkout@v4
with:
fetch-depth: 0
fetch-depth: 1
single-branch: true

- name: ccache
uses: hendrikmuhs/ccache-action@v1.2.11
Expand Down
65 changes: 36 additions & 29 deletions scripts/fetch_templates_and_goldens.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,9 @@ def needs_polyfills(self, context):
)) \
or caps.requires_typed_content

def apply(self, context):
def apply(self, context: dict):
assert isinstance(context, dict)
context = json.loads(json.dumps(context))

caps = self.original_caps
has_tools = 'tools' in context
Expand Down Expand Up @@ -349,10 +351,14 @@ def apply(self, context):
logger.info(f" ERROR: {e2} (after first error: {e1})")
return f"ERROR: {e2}"

@dataclass
class Context:
    """A single test context loaded from a JSON file.

    Bundles the identifying information of a context file together with
    its parsed contents so that callers can pass one object around
    instead of re-reading the file.
    """

    # Short identifier for the context (used when naming output files).
    name: str
    # Path of the JSON file the context was read from.
    file: str
    # Parsed JSON contents, i.e. the variables handed to the template.
    bindings: dict



async def handle_chat_template(output_folder, model_id, variant, template_src, context_files):
async def handle_chat_template(output_folder, model_id, variant, template_src, contexts: list[Context]):
if '{% generation %}' in template_src:
print('Removing {% generation %} blocks from template', file=sys.stderr)
template_src = template_src.replace('{% generation %}', '').replace('{% endgeneration %}', '')
Expand All @@ -376,33 +382,32 @@ async def handle_chat_template(output_folder, model_id, variant, template_src, c
})
caps = template.original_caps

if not context_files:
if not contexts:
print(f"{template_file} {caps_file} n/a {template_file}")
return

async with aiofiles.open(caps_file, 'w') as f:
await f.write(caps.to_json())

for context_file in context_files:
context_name = os.path.basename(context_file).replace(".json", "")
async with aiofiles.open(context_file, 'r') as f:
context = json.loads(await f.read())

if not caps.supports_tool_calls and context.get('tools') is not None:
print(f'Skipping {context_name} test as tools seem unsupported by template {template_file}', file=sys.stderr)
assert isinstance(contexts, list)
for context in contexts:
assert isinstance(context, Context)
assert isinstance(context.bindings, dict)
if not caps.supports_tool_calls and context.bindings.get('tools') is not None:
print(f'Skipping {context.name} test as tools seem unsupported by template {template_file}', file=sys.stderr)
continue

needs_tools_in_system = len(context.get('tools', [])) > 0 and not caps.supports_tools
if not caps.supports_system_role and (any(m['role'] == 'system' for m in context['messages']) or needs_tools_in_system):
needs_tools_in_system = len(context.bindings.get('tools', [])) > 0 and not caps.supports_tools
if not caps.supports_system_role and (any(m['role'] == 'system' for m in context.bindings['messages']) or needs_tools_in_system):
continue

output_file = join_cmake_path(output_folder, f'{base_name}-{context_name}.txt')
output_file = join_cmake_path(output_folder, f'{base_name}-{context.name}.txt')

output = template.apply(context)
output = template.apply(context.bindings)
async with aiofiles.open(output_file, 'w') as f:
await f.write(output)

print(f"{template_file} {caps_file} {context_file} {output_file}")
print(f"{template_file} {caps_file} {context.file} {output_file}")

async def async_hf_download(repo_id: str, filename: str) -> str:
headers = build_hf_headers()
Expand All @@ -412,8 +417,9 @@ async def async_hf_download(repo_id: str, filename: str) -> str:
response.raise_for_status()
return await response.text()

async def process_model(output_folder: str, model_id: str, context_files: list):
async def process_model(output_folder: str, model_id: str, contexts: list[Context]):
try:
print(f"Processing model {model_id}...", file=sys.stderr)
config_str = await async_hf_download(model_id, "tokenizer_config.json")

try:
Expand All @@ -424,14 +430,16 @@ async def process_model(output_folder: str, model_id: str, context_files: list):
assert 'chat_template' in config, 'No "chat_template" entry in tokenizer_config.json!'
chat_template = config['chat_template']
if isinstance(chat_template, str):
await handle_chat_template(output_folder, model_id, None, chat_template, context_files)
await handle_chat_template(output_folder, model_id, None, chat_template, contexts)
else:
await asyncio.gather(*[
handle_chat_template(output_folder, model_id, ct['name'], ct['template'], context_files)
handle_chat_template(output_folder, model_id, ct['name'], ct['template'], contexts)
for ct in chat_template
])
except Exception as e:
logger.error(f"Error processing model {model_id}: {e}")
# import traceback
# traceback.print_exc()
await handle_chat_template(output_folder, model_id, None, str(e), [])

async def async_copy_file(src: str, dst: str):
Expand All @@ -445,27 +453,26 @@ async def main():
parser.add_argument("json_context_files_or_model_ids", nargs="+", help="List of context JSON files or HuggingFace model IDs")
args = parser.parse_args()

context_files = []
contexts: list[Context] = []
model_ids = []
for file in args.json_context_files_or_model_ids:
if file.endswith('.json'):
context_files.append(file)
async with aiofiles.open(file, 'r') as f:
contexts.append(Context(
name=os.path.basename(file).replace(".json", ""),
file=file,
bindings=json.loads(await f.read())))
else:
model_ids.append(file)

output_folder = args.output_folder
if not os.path.isdir(output_folder):
os.makedirs(output_folder)

# Copy context files to the output folder asynchronously
await asyncio.gather(*[
async_copy_file(context_file, os.path.join(output_folder, os.path.basename(context_file)))
for context_file in context_files
])

# Process models concurrently
# for model_id in model_ids:
# await process_model(output_folder, model_id, contexts)
await asyncio.gather(*[
process_model(output_folder, model_id, context_files)
process_model(output_folder, model_id, contexts)
for model_id in model_ids
])

Expand Down
4 changes: 2 additions & 2 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -152,8 +152,8 @@ foreach(test_case ${CHAT_TEMPLATE_TEST_CASES})
separate_arguments(test_args UNIX_COMMAND "${test_case}")
list(GET test_args -1 last_arg)
string(REGEX REPLACE "^[^ ]+/([^ /\\]+)\\.[^.]+$" "\\1" test_name "${last_arg}")
add_test(NAME ${test_name} COMMAND $<TARGET_FILE:test-supported-template> ${test_args})
set_tests_properties(${test_name} PROPERTIES SKIP_RETURN_CODE 127)
add_test(NAME test-supported-template-${test_name} COMMAND $<TARGET_FILE:test-supported-template> ${test_args})
set_tests_properties(test-supported-template-${test_name} PROPERTIES SKIP_RETURN_CODE 127)
endforeach()

if (MINJA_FUZZTEST_ENABLED)
Expand Down

0 comments on commit ff6f9a0

Please sign in to comment.