diff --git a/README.md b/README.md index 1f5d6ad..d8e1a0a 100644 --- a/README.md +++ b/README.md @@ -3204,7 +3204,10 @@ mod foo 'PATH' Which loads the module's source file from `PATH`, instead of from the usual locations. A leading `~/` in `PATH` is replaced with the current user's home -directory. +directory. `PATH` may point to the module source file itself, or to a directory +containing the module source file with the name `mod.just`, `justfile`, or +`.justfile`. In the latter two cases, the module file may have any +capitalization. Environment files are only loaded for the root justfile, and loaded environment variables are available in submodules. Settings in submodules that affect diff --git a/src/compiler.rs b/src/compiler.rs index a977be5..119a0af 100644 --- a/src/compiler.rs +++ b/src/compiler.rs @@ -43,17 +43,12 @@ impl Compiler { } => { let parent = current.path.parent().unwrap(); - let import = if let Some(relative) = relative { - let path = parent.join(Self::expand_tilde(&relative.cooked)?); + let relative = relative + .as_ref() + .map(|relative| Self::expand_tilde(&relative.cooked)) + .transpose()?; - if path.is_file() { - Some(path) - } else { - None - } - } else { - Self::find_module_file(parent, *name)? - }; + let import = Self::find_module_file(parent, *name, relative.as_deref())?; if let Some(import) = import { if current.file_path.contains(&import) { @@ -111,19 +106,63 @@ impl Compiler { }) } - fn find_module_file<'src>(parent: &Path, module: Name<'src>) -> RunResult<'src, Option> { - let mut candidates = vec![format!("{module}.just"), format!("{module}/mod.just")] - .into_iter() - .filter(|path| parent.join(path).is_file()) - .collect::>(); + fn find_module_file<'src>( + parent: &Path, + module: Name<'src>, + path: Option<&Path>, + ) -> RunResult<'src, Option> { + let mut candidates = Vec::new(); - let directory = parent.join(module.lexeme()); + if let Some(path) = path { + let full = parent.join(path); - if directory.exists() { - let entries = fs::read_dir(&directory).map_err(|io_error| SearchError::Io { - io_error, - directory: directory.clone(), - })?; + if full.is_file() { + return Ok(Some(full)); + } + + candidates.push((path.join("mod.just"), true)); + + for name in search::JUSTFILE_NAMES { + candidates.push((path.join(name), false)); + } + } else { + candidates.push((format!("{module}.just").into(), true)); + candidates.push((format!("{module}/mod.just").into(), true)); + + for name in search::JUSTFILE_NAMES { + candidates.push((format!("{module}/{name}").into(), false)); + } + } + + let mut grouped = BTreeMap::>::new(); + + for (candidate, case_sensitive) in candidates { + let candidate = parent.join(candidate).lexiclean(); + grouped + .entry(candidate.parent().unwrap().into()) + .or_default() + .push((candidate, case_sensitive)); + } + + let mut found = Vec::new(); + + for (directory, candidates) in grouped { + let entries = match fs::read_dir(&directory) { + Ok(entries) => entries, + Err(io_error) => { + if io_error.kind() == io::ErrorKind::NotFound { + continue; + } + + return Err( + SearchError::Io { + io_error, + directory, + } + .into(), + ); + } + }; for entry in entries { let entry = entry.map_err(|io_error| SearchError::Io { @@ -132,22 +171,34 @@ impl Compiler { })?; if let Some(name) = entry.file_name().to_str() { - for justfile_name in search::JUSTFILE_NAMES { - if name.eq_ignore_ascii_case(justfile_name) { - candidates.push(format!("{module}/{name}")); + for (candidate, case_sensitive) in &candidates { + let candidate_name = candidate.file_name().unwrap().to_str().unwrap(); + + let eq = if *case_sensitive { + name == candidate_name + } else { + name.eq_ignore_ascii_case(candidate_name) + }; + + if eq { + found.push(candidate.parent().unwrap().join(name)); } } } } } - match candidates.as_slice() { - [] => Ok(None), - [file] => Ok(Some(parent.join(file).lexiclean())), - found => Err(Error::AmbiguousModuleFile { - found: found.into(), + if found.len() > 1 { + found.sort(); + Err(Error::AmbiguousModuleFile { + found: found + .into_iter() + .map(|found| found.strip_prefix(parent).unwrap().into()) + .collect(), module, - }), + }) + } else { + Ok(found.into_iter().next()) } } @@ -242,4 +293,84 @@ recipe_b: recipe_c import == tmp.path().join("justfile").lexiclean() ); } + + #[test] + fn find_module_file() { + #[track_caller] + fn case(path: Option<&str>, files: &[&str], expected: Result, &[&str]>) { + let module = Name { + token: Token { + column: 0, + kind: TokenKind::Identifier, + length: 3, + line: 0, + offset: 0, + path: Path::new(""), + src: "foo", + }, + }; + + let tempdir = tempfile::tempdir().unwrap(); + + for file in files { + if let Some(parent) = Path::new(file).parent() { + fs::create_dir_all(tempdir.path().join(parent)).unwrap(); + } + + fs::write(tempdir.path().join(file), "").unwrap(); + } + + let actual = Compiler::find_module_file(tempdir.path(), module, path.map(Path::new)); + + match expected { + Err(expected) => match actual.unwrap_err() { + Error::AmbiguousModuleFile { found, .. } => { + assert_eq!( + found, + expected + .iter() + .map(|expected| expected.replace('/', std::path::MAIN_SEPARATOR_STR).into()) + .collect::>() + ); + } + _ => panic!("unexpected error"), + }, + Ok(Some(expected)) => assert_eq!( + actual.unwrap().unwrap(), + tempdir + .path() + .join(expected.replace('/', std::path::MAIN_SEPARATOR_STR)) + ), + Ok(None) => assert_eq!(actual.unwrap(), None), + } + } + + case(None, &["foo.just"], Ok(Some("foo.just"))); + case(None, &["FOO.just"], Ok(None)); + case(None, &["foo/mod.just"], Ok(Some("foo/mod.just"))); + case(None, &["foo/MOD.just"], Ok(None)); + case(None, &["foo/justfile"], Ok(Some("foo/justfile"))); + case(None, &["foo/JUSTFILE"], Ok(Some("foo/JUSTFILE"))); + case(None, &["foo/.justfile"], Ok(Some("foo/.justfile"))); + case(None, &["foo/.JUSTFILE"], Ok(Some("foo/.JUSTFILE"))); + case( + None, + &["foo/.justfile", "foo/justfile"], + Err(&["foo/.justfile", "foo/justfile"]), + ); + case(None, &["foo/JUSTFILE"], Ok(Some("foo/JUSTFILE"))); + + case(Some("bar"), &["bar"], Ok(Some("bar"))); + case(Some("bar"), &["bar/mod.just"], Ok(Some("bar/mod.just"))); + case(Some("bar"), &["bar/justfile"], Ok(Some("bar/justfile"))); + case(Some("bar"), &["bar/JUSTFILE"], Ok(Some("bar/JUSTFILE"))); + case(Some("bar"), &["bar/.justfile"], Ok(Some("bar/.justfile"))); + case(Some("bar"), &["bar/.JUSTFILE"], Ok(Some("bar/.JUSTFILE"))); + + case( + Some("bar"), + &["bar/justfile", "bar/mod.just"], + Err(&["bar/justfile", "bar/mod.just"]), + ); + } } diff --git a/src/error.rs b/src/error.rs index d2e8704..563f21b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -4,7 +4,7 @@ use super::*; pub(crate) enum Error<'src> { AmbiguousModuleFile { module: Name<'src>, - found: Vec, + found: Vec, }, ArgumentCountMismatch { recipe: &'src str, @@ -262,7 +262,7 @@ impl<'src> ColorDisplay for Error<'src> { AmbiguousModuleFile { module, found } => write!(f, "Found multiple source files for module `{module}`: {}", - List::and_ticked(found), + List::and_ticked(found.iter().map(|path| path.display())), )?, ArgumentCountMismatch { recipe, found, min, max, .. } => { let count = Count("argument", *found); diff --git a/tests/lib.rs b/tests/lib.rs index 932a6c3..072574b 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -19,7 +19,7 @@ pub(crate) use { fs, io::Write, iter, - path::{Path, PathBuf, MAIN_SEPARATOR}, + path::{Path, PathBuf, MAIN_SEPARATOR, MAIN_SEPARATOR_STR}, process::{Command, Stdio}, str, }, diff --git a/tests/modules.rs b/tests/modules.rs index 1783942..46a06bc 100644 --- a/tests/modules.rs +++ b/tests/modules.rs @@ -439,12 +439,13 @@ fn modules_require_unambiguous_file() { .status(EXIT_FAILURE) .stderr( " - error: Found multiple source files for module `foo`: `foo.just` and `foo/justfile` + error: Found multiple source files for module `foo`: `foo/justfile` and `foo.just` ——▶ justfile:1:5 │ 1 │ mod foo │ ^^^ - ", + " + .replace('/', MAIN_SEPARATOR_STR), ) .run(); } @@ -564,6 +565,23 @@ fn modules_may_specify_path() { .run(); } +#[test] +fn modules_may_specify_path_to_directory() { + Test::new() + .write("commands/bar/mod.just", "foo:\n @echo FOO") + .justfile( + " + mod foo 'commands/bar' + ", + ) + .test_round_trip(false) + .arg("--unstable") + .arg("foo") + .arg("foo") + .stdout("FOO\n") + .run(); +} + #[test] fn modules_with_paths_are_dumped_correctly() { Test::new()