From 31ef58a82b8f80fe0b29260f9170f10220c73714 Mon Sep 17 00:00:00 2001 From: Joel Natividad <1980690+jqnatividad@users.noreply.github.com> Date: Thu, 31 Oct 2024 05:40:30 -0400 Subject: [PATCH] `luau`: refactor stage management - easier to use as we abstract stage operations - better maintainability --- src/cmd/luau.rs | 70 ++++++++++++++++++++++++++++++------------------- 1 file changed, 43 insertions(+), 27 deletions(-) diff --git a/src/cmd/luau.rs b/src/cmd/luau.rs index 788230f68..9c48f1272 100644 --- a/src/cmd/luau.rs +++ b/src/cmd/luau.rs @@ -252,7 +252,6 @@ use log::{debug, info, log_enabled}; use mlua::{Lua, LuaSerdeExt, Value}; use serde::Deserialize; use simple_expand_tilde::expand_tilde; -use strum_macros::IntoStaticStr; use crate::{ config::{Config, Delimiter, DEFAULT_WTR_BUFFER_CAPACITY}, @@ -305,23 +304,40 @@ static QSV_V_LASTROW: &str = "_LASTROW"; static QSV_V_INDEX: &str = "_INDEX"; // there are 3 stages: 1-BEGIN, 2-MAIN, 3-END -#[repr(i8)] -#[derive(IntoStaticStr)] +#[derive(Copy, Clone, Debug, PartialEq)] enum Stage { - Begin = 1, - Main = 2, - End = 3, + Begin = 0, + Main = 1, + End = 2, } impl TryFrom for Stage { - type Error = (); - - fn try_from(v: i8) -> Result { - match v { - x if x == Stage::Begin as i8 => Ok(Stage::Begin), - x if x == Stage::Main as i8 => Ok(Stage::Main), - x if x == Stage::End as i8 => Ok(Stage::End), - _ => Err(()), + type Error = &'static str; + + fn try_from(value: i8) -> Result { + match value { + 0 => Ok(Stage::Begin), + 1 => Ok(Stage::Main), + 2 => Ok(Stage::End), + _ => Err("Invalid stage value"), + } + } +} + +impl Stage { + fn set_current(self) { + LUAU_STAGE.store(self as i8, Ordering::Relaxed); + } + + fn current() -> Option { + Stage::try_from(LUAU_STAGE.load(Ordering::Relaxed)).ok() + } + + const fn as_str(self) -> &'static str { + match self { + Stage::Begin => "BEGIN", + Stage::Main => "MAIN", + Stage::End => "END", } } } @@ -670,7 +686,7 @@ fn sequential_mode( globals.raw_set(QSV_V_ROWCOUNT, 0)?; if !begin_script.is_empty() { info!("Compiling and executing BEGIN script. _IDX: 0 _ROWCOUNT: 0"); - LUAU_STAGE.store(Stage::Begin as i8, Ordering::Relaxed); + Stage::Begin.set_current(); if let Err(e) = luau.load(begin_script).exec() { return fail_clierror!("BEGIN error: Failed to execute \"{begin_script}\".\n{e}"); @@ -721,7 +737,7 @@ fn sequential_mode( let mut idx = 0_u64; let mut error_count = 0_usize; - LUAU_STAGE.store(Stage::Main as i8, Ordering::Relaxed); + Stage::Main.set_current(); info!("Executing MAIN script."); let mut computed_value; @@ -864,7 +880,7 @@ fn sequential_mode( // should make for more readable END scripts. // Also, _ROWCOUNT is zero during the main script, and only set // to _IDX during the END script. - LUAU_STAGE.store(Stage::End as i8, Ordering::Relaxed); + Stage::End.set_current(); globals.raw_set(QSV_V_ROWCOUNT, idx)?; if !idx_used { // for perf reasons, we only updated _IDX in the main @@ -1000,7 +1016,7 @@ fn random_access_mode( "Compiling and executing BEGIN script. _ROWCOUNT: {row_count} _LASTROW: {}", row_count - 1 ); - LUAU_STAGE.store(Stage::Begin as i8, Ordering::Relaxed); + Stage::Begin.set_current(); if let Err(e) = luau.load(begin_script).exec() { return fail_clierror!("BEGIN error: Failed to execute \"{begin_script}\".\n{e}"); @@ -1067,7 +1083,7 @@ fn random_access_mode( progress.set_draw_target(ProgressDrawTarget::hidden()); } - LUAU_STAGE.store(Stage::Main as i8, Ordering::Relaxed); + Stage::Main.set_current(); info!( "Executing MAIN script. _INDEX: {curr_record} _ROWCOUNT: {row_count} _LASTROW: {}", row_count - 1 @@ -1216,7 +1232,7 @@ fn random_access_mode( if !end_script.is_empty() { info!("Compiling and executing END script. _ROWCOUNT: {row_count}"); - LUAU_STAGE.store(Stage::End as i8, Ordering::Relaxed); + Stage::End.set_current(); let end_value: Value = match luau.load(end_script).eval() { Ok(computed) => computed, @@ -1461,9 +1477,9 @@ fn setup_helpers( let mut log_msg = { // at which stage are we logging? // safety: this is safe to unwrap because we only set LUAU_STAGE using the Stage enum - let stage: Stage = LUAU_STAGE.load(Ordering::Relaxed).try_into().unwrap(); - let stage_str: &'static str = stage.into(); - format!("{}: ", stage_str.to_ascii_uppercase()) + let stage = Stage::current().unwrap_or(Stage::Main); + let stage_str = stage.as_str(); + format!("{stage_str}: ") }; let mut idx = 0_u8; let mut log_level = String::new(); @@ -1525,7 +1541,7 @@ fn setup_helpers( // Luau runtime error if called from END script // let qsv_break = luau.create_function(|luau, mut args: mlua::MultiValue| { - if LUAU_STAGE.load(Ordering::Relaxed) == Stage::End as i8 { + if Stage::current() == Some(Stage::End) { return helper_err!( "qsv_break", "qsv_break() can only be called from the BEGIN and MAIN scripts." @@ -1576,7 +1592,7 @@ fn setup_helpers( // or Luau runtime error if called from BEGIN or END scripts // let qsv_skip = luau.create_function(|_, ()| { - if LUAU_STAGE.load(Ordering::Relaxed) != Stage::Main as i8 { + if Stage::current() != Some(Stage::Main) { return helper_err!( "qsv_skip", "qsv_skip() can only be called from the MAIN script." @@ -1603,7 +1619,7 @@ fn setup_helpers( // A Luau runtime error is also returned if called from MAIN or END. // let qsv_autoindex = luau.create_function(|_, ()| { - if LUAU_STAGE.load(Ordering::Relaxed) != Stage::Begin as i8 { + if Stage::current() != Some(Stage::Begin) { return helper_err!( "qsv_autoindex", "qsv_autoindex() can only be called from the BEGIN script." @@ -2075,7 +2091,7 @@ fn setup_helpers( let qsv_register_lookup = luau.create_function(move |luau, (lookup_name, mut lookup_table_uri, cache_age_secs): (String, String, i64)| { const MSG_PREFIX: &str = "qsv_register_lookup() - "; - if LUAU_STAGE.load(Ordering::Relaxed) != Stage::Begin as i8 { + if Stage::current() != Some(Stage::Begin) { return helper_err!("qsv_register_lookup", "can only be called from the BEGIN script."); }