diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..d406517 --- /dev/null +++ b/Makefile @@ -0,0 +1,15 @@ +.PHONY: build test unit integration clean + +build: + go build -o shout . + +test: unit integration + +unit: + go test ./... + +integration: build + ./shout test test/ + +clean: + rm -f shout diff --git a/cmd.go b/cmd.go new file mode 100644 index 0000000..6cc9f7f --- /dev/null +++ b/cmd.go @@ -0,0 +1,361 @@ +package main + +import ( + "fmt" + "io/fs" + "os" + "path/filepath" + "sort" + "strconv" + "strings" + "sync" + "time" + + "github.com/spf13/cobra" +) + +var version = "dev" + +var rootCmd = &cobra.Command{ + Use: "shout", + Short: "$ shell output tester", +} + +func init() { + rootCmd.AddCommand(testCmd()) + rootCmd.AddCommand(versionCmd()) + rootCmd.AddCommand(exampleCmd()) +} + +func testCmd() *cobra.Command { + var ( + update bool + keep bool + cleanEnv bool + pathDirs []string + timeout string + verbose bool + portFrom int + parallel bool + ) + + cmd := &cobra.Command{ + Use: "test [files...]", + Short: "Run .shout test files", + Args: cobra.ArbitraryArgs, + RunE: func(cmd *cobra.Command, args []string) error { + timeoutDur, err := parseDuration(timeout) + if err != nil { + return err + } + + paths := args + if len(paths) == 0 { + paths = []string{"."} + } + + files, err := findShoutFiles(paths) + if err != nil { + return err + } + if len(files) == 0 { + fmt.Fprintln(os.Stderr, "No .shout files found") + os.Exit(1) + } + + start := time.Now() + var results []TestResult + cwd, _ := os.Getwd() + nextPort := portFrom + + runOne := func(filePath string, port int) TestResult { + content, err := os.ReadFile(filePath) + if err != nil { + relPath, _ := filepath.Rel(cwd, filePath) + return TestResult{Path: relPath, Error: err.Error()} + } + + relPath, _ := filepath.Rel(cwd, filePath) + parsed, err := parse(relPath, string(content)) + if err != nil { + return TestResult{Path: relPath, Error: err.Error()} + } + + // Resolve directives + envVars := make(map[string]string) + setupEnvVars := make(map[string]string) + userEnvVars := make(map[string]string) + var setupCommands []Command + + for _, d := range parsed.Directives { + switch d.Type { + case "setup": + setupPath := filepath.Join(filepath.Dir(filePath), d.Path) + setupContent, err := os.ReadFile(setupPath) + if err != nil { + return TestResult{Path: parsed.Path, Error: err.Error()} + } + setupRelPath, _ := filepath.Rel(cwd, setupPath) + setupParsed, err := parse(setupRelPath, string(setupContent)) + if err != nil { + return TestResult{Path: parsed.Path, Error: err.Error()} + } + for _, sd := range setupParsed.Directives { + if sd.Type == "setup" { + return TestResult{ + Path: parsed.Path, + Error: fmt.Sprintf("%s: @setup not allowed in setup files", setupRelPath), + } + } + if sd.Type == "env" { + setupEnvVars[sd.Key] = sd.Value + } + } + setupCommands = append(setupCommands, setupParsed.Commands...) + + case "env": + userEnvVars[d.Key] = d.Value + } + } + + // Setup env < user env + for k, v := range setupEnvVars { + envVars[k] = v + } + for k, v := range userEnvVars { + envVars[k] = v + } + if port > 0 { + if _, ok := userEnvVars["PORT"]; !ok { + if _, ok := setupEnvVars["PORT"]; !ok { + envVars["PORT"] = strconv.Itoa(port) + } + } + } + + merged := ShoutFile{ + Path: parsed.Path, + Commands: append(setupCommands, parsed.Commands...), + Directives: parsed.Directives, + } + + var onCommand func(Command) + if verbose { + onCommand = func(c Command) { + fmt.Fprintf(os.Stderr, dim(" $ %s\n"), c.Cmd) + } + } + + fileResult := runFile(merged, RunOptions{ + CleanEnv: cleanEnv, + PathDirs: pathDirs, + EnvVars: envVars, + Timeout: timeoutDur, + Verbose: verbose, + OnCommand: onCommand, + }) + + // Check setup commands for failures + for i := 0; i < len(setupCommands) && i < len(fileResult.Results); i++ { + r := fileResult.Results[i] + sc := setupCommands[i] + ok := false + switch sc.ExitCodeType { + case ExitCodeNone: + ok = r.ExitCode == 0 + case ExitCodeWildcard: + ok = r.ExitCode != 0 + case ExitCodeExact: + ok = r.ExitCode == sc.ExitCodeValue + } + if !ok { + if keep { + fmt.Fprintln(os.Stderr, fileResult.TmpDir) + } else { + cleanupTmpDir(fileResult.TmpDir) + } + return evaluateFile( + parsed.Path, + nil, + fmt.Sprintf("setup command failed (exit %d): $ %s", r.ExitCode, sc.Cmd), + ) + } + } + + fileOwnResults := fileResult.Results + if len(setupCommands) > 0 && len(fileResult.Results) >= len(setupCommands) { + fileOwnResults = fileResult.Results[len(setupCommands):] + } + + testResult := evaluateFile(parsed.Path, fileOwnResults, fileResult.Error) + + if update && len(fileOwnResults) > 0 { + updated := rewriteFile(parsed, fileOwnResults, string(content)) + if updated != string(content) { + _ = os.WriteFile(filePath, []byte(updated), 0o644) + } + } + + if keep { + fmt.Fprintln(os.Stderr, fileResult.TmpDir) + } else { + cleanupTmpDir(fileResult.TmpDir) + } + + return testResult + } + + printDots := func(r TestResult) { + if r.Error != "" { + fmt.Print(red("F")) + return + } + passed := r.CommandCount - len(r.Failures) + for i := 0; i < passed; i++ { + fmt.Print(green(".")) + } + for i := 0; i < len(r.Failures); i++ { + fmt.Print(red("F")) + } + } + + if parallel { + allResults := make([]TestResult, len(files)) + var wg sync.WaitGroup + for idx, f := range files { + wg.Add(1) + go func(i int, filePath string, port int) { + defer wg.Done() + allResults[i] = runOne(filePath, port) + }(idx, f, nextPort) + if nextPort > 0 { + nextPort++ + } + } + wg.Wait() + for _, r := range allResults { + printDots(r) + results = append(results, r) + } + fmt.Println() + } else { + for _, filePath := range files { + port := 0 + if nextPort > 0 { + port = nextPort + nextPort++ + } + r := runOne(filePath, port) + printDots(r) + results = append(results, r) + } + fmt.Println() + } + + // Print failures + var failures []TestResult + for _, r := range results { + if !r.Passed { + failures = append(failures, r) + } + } + if len(failures) > 0 { + fmt.Println() + for _, f := range failures { + fmt.Println(formatFailure(f)) + fmt.Println() + } + } + + elapsed := time.Since(start) + fmt.Println(formatSummary(results, elapsed)) + + if len(failures) > 0 { + os.Exit(1) + } + return nil + }, + } + + cmd.Flags().BoolVarP(&update, "update", "u", false, "Rewrite expected output in-place with actual output") + cmd.Flags().BoolVarP(&keep, "keep", "k", false, "Keep temp directories after run") + cmd.Flags().BoolVar(&cleanEnv, "clean-env", false, "Start with empty environment") + cmd.Flags().StringArrayVar(&pathDirs, "path", nil, "Prepend to PATH (repeatable)") + cmd.Flags().StringVar(&timeout, "timeout", "10s", "Per-command timeout") + cmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "Print each command as it runs") + cmd.Flags().IntVar(&portFrom, "port-from", 0, "Auto-assign $PORT starting from ") + cmd.Flags().BoolVar(¶llel, "parallel", false, "Run files in parallel") + + return cmd +} + +func versionCmd() *cobra.Command { + return &cobra.Command{ + Use: "version", + Short: "Print the version", + Run: func(cmd *cobra.Command, args []string) { + fmt.Println(version) + }, + } +} + +func exampleCmd() *cobra.Command { + return &cobra.Command{ + Use: "example", + Short: "Print an example .shout file", + Run: func(cmd *cobra.Command, args []string) { + fmt.Print(`# Example .shout file +$ echo hello +hello + +$ echo "one"; echo "two"; echo "three" +one +... +three + +$ cat nonexistent +cat: nonexistent: ... +[1] + +$ true +[0] +`) + }, + } +} + +func findShoutFiles(paths []string) ([]string, error) { + var files []string + + for _, p := range paths { + abs, err := filepath.Abs(p) + if err != nil { + continue + } + + info, err := os.Stat(abs) + if err != nil { + if strings.HasSuffix(abs, ".shout") { + files = append(files, abs) + } + continue + } + + if info.IsDir() { + _ = filepath.WalkDir(abs, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return nil + } + if !d.IsDir() && strings.HasSuffix(path, ".shout") { + files = append(files, path) + } + return nil + }) + } else if strings.HasSuffix(abs, ".shout") { + files = append(files, abs) + } + } + + sort.Strings(files) + return files, nil +} diff --git a/color.go b/color.go new file mode 100644 index 0000000..806422a --- /dev/null +++ b/color.go @@ -0,0 +1,40 @@ +package main + +import "os" + +var colorEnabled = detectColor() + +func detectColor() bool { + if os.Getenv("NO_COLOR") != "" { + return false + } + if os.Getenv("TERM") == "dumb" { + return false + } + fi, err := os.Stdout.Stat() + if err != nil { + return false + } + return fi.Mode()&os.ModeCharDevice != 0 +} + +func red(s string) string { + if !colorEnabled { + return s + } + return "\033[31m" + s + "\033[0m" +} + +func green(s string) string { + if !colorEnabled { + return s + } + return "\033[32m" + s + "\033[0m" +} + +func dim(s string) string { + if !colorEnabled { + return s + } + return "\033[2m" + s + "\033[0m" +} diff --git a/duration.go b/duration.go new file mode 100644 index 0000000..9afbaf5 --- /dev/null +++ b/duration.go @@ -0,0 +1,29 @@ +package main + +import ( + "fmt" + "regexp" + "strconv" + "time" +) + +var durationRe = regexp.MustCompile(`^(\d+(?:\.\d+)?)(ms|s|m)$`) + +func parseDuration(s string) (time.Duration, error) { + m := durationRe.FindStringSubmatch(s) + if m == nil { + return 0, fmt.Errorf("invalid duration: %s", s) + } + + value, _ := strconv.ParseFloat(m[1], 64) + switch m[2] { + case "ms": + return time.Duration(value * float64(time.Millisecond)), nil + case "s": + return time.Duration(value * float64(time.Second)), nil + case "m": + return time.Duration(value * float64(time.Minute)), nil + default: + return 0, fmt.Errorf("unknown unit: %s", m[2]) + } +} diff --git a/duration_test.go b/duration_test.go new file mode 100644 index 0000000..914e830 --- /dev/null +++ b/duration_test.go @@ -0,0 +1,60 @@ +package main + +import ( + "testing" + "time" +) + +func TestParseDurationMs(t *testing.T) { + d, err := parseDuration("500ms") + if err != nil { + t.Fatal(err) + } + if d != 500*time.Millisecond { + t.Errorf("got %v, want 500ms", d) + } +} + +func TestParseDurationSeconds(t *testing.T) { + d, err := parseDuration("10s") + if err != nil { + t.Fatal(err) + } + if d != 10*time.Second { + t.Errorf("got %v, want 10s", d) + } +} + +func TestParseDurationDecimal(t *testing.T) { + d, err := parseDuration("1.5s") + if err != nil { + t.Fatal(err) + } + if d != 1500*time.Millisecond { + t.Errorf("got %v, want 1.5s", d) + } +} + +func TestParseDurationMinutes(t *testing.T) { + d, err := parseDuration("1m") + if err != nil { + t.Fatal(err) + } + if d != time.Minute { + t.Errorf("got %v, want 1m", d) + } +} + +func TestParseDurationInvalid(t *testing.T) { + _, err := parseDuration("abc") + if err == nil { + t.Error("expected error for invalid duration") + } +} + +func TestParseDurationNoUnit(t *testing.T) { + _, err := parseDuration("10") + if err == nil { + t.Error("expected error for missing unit") + } +} diff --git a/format.go b/format.go new file mode 100644 index 0000000..4c984c1 --- /dev/null +++ b/format.go @@ -0,0 +1,139 @@ +package main + +import ( + "fmt" + "strings" + "time" +) + +func evaluateFile(path string, results []CommandResult, fileError string) TestResult { + if fileError != "" { + return TestResult{ + Path: path, + Passed: false, + CommandCount: len(results), + Error: fileError, + } + } + + var failures []FailedCommand + + for _, r := range results { + cmd := r.Command + outputMatches := matchOutput(cmd.Expected, r.Actual) + + var exitCodeMismatch bool + switch cmd.ExitCodeType { + case ExitCodeNone: + exitCodeMismatch = r.ExitCode != 0 + case ExitCodeWildcard: + exitCodeMismatch = r.ExitCode == 0 + case ExitCodeExact: + exitCodeMismatch = r.ExitCode != cmd.ExitCodeValue + } + + if !outputMatches || exitCodeMismatch { + var diffLines []DiffLine + if !outputMatches { + diffLines = diff(cmd.Expected, r.Actual) + } + failures = append(failures, FailedCommand{ + Result: r, + DiffLines: diffLines, + ExitCodeMismatch: exitCodeMismatch, + }) + } + } + + return TestResult{ + Path: path, + Passed: len(failures) == 0, + CommandCount: len(results), + Failures: failures, + } +} + +func formatFailure(t TestResult) string { + var lines []string + + lines = append(lines, red("FAIL "+t.Path)) + + if t.Error != "" { + lines = append(lines, " "+red(t.Error)) + return strings.Join(lines, "\n") + } + + for _, f := range t.Failures { + lines = append(lines, "") + lines = append(lines, " "+dim("$")+" "+f.Result.Command.Cmd) + + if len(f.DiffLines) > 0 { + lines = append(lines, red(" expected:")) + for _, dl := range f.DiffLines { + switch dl.Kind { + case "expected": + lines = append(lines, red(" > ")+dl.Text) + case "equal": + lines = append(lines, " "+dl.Text) + case "context": + lines = append(lines, " "+dim(dl.Text)) + } + } + lines = append(lines, green(" actual:")) + for _, dl := range f.DiffLines { + switch dl.Kind { + case "actual": + lines = append(lines, green(" > ")+dl.Text) + case "equal": + lines = append(lines, " "+dl.Text) + case "context": + lines = append(lines, " "+dim(dl.Text)) + } + } + } + + if f.ExitCodeMismatch { + cmd := f.Result.Command + var expectedStr string + switch cmd.ExitCodeType { + case ExitCodeNone: + expectedStr = "0" + case ExitCodeWildcard: + expectedStr = "non-zero" + case ExitCodeExact: + expectedStr = fmt.Sprintf("%d", cmd.ExitCodeValue) + } + lines = append(lines, red(fmt.Sprintf(" expected exit code: %s", expectedStr))) + lines = append(lines, green(fmt.Sprintf(" actual exit code: %d", f.Result.ExitCode))) + } + } + + return strings.Join(lines, "\n") +} + +func formatSummary(results []TestResult, elapsed time.Duration) string { + totalCommands := 0 + failedCommands := 0 + for _, r := range results { + totalCommands += r.CommandCount + failedCommands += len(r.Failures) + } + passedCommands := totalCommands - failedCommands + + var parts []string + if passedCommands > 0 { + parts = append(parts, green(fmt.Sprintf("%d passed", passedCommands))) + } + if failedCommands > 0 { + parts = append(parts, red(fmt.Sprintf("%d failed", failedCommands))) + } + + var timeStr string + if elapsed < time.Second { + timeStr = fmt.Sprintf("%dms", elapsed.Milliseconds()) + } else { + timeStr = fmt.Sprintf("%.1fs", elapsed.Seconds()) + } + + return fmt.Sprintf("%s in %s", strings.Join(parts, ", "), timeStr) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..437126a --- /dev/null +++ b/go.mod @@ -0,0 +1,10 @@ +module shout + +go 1.24.1 + +require github.com/spf13/cobra v1.10.2 + +require ( + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/spf13/pflag v1.0.9 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..a6ee3e0 --- /dev/null +++ b/go.sum @@ -0,0 +1,10 @@ +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= +github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/main.go b/main.go new file mode 100644 index 0000000..a5a50d9 --- /dev/null +++ b/main.go @@ -0,0 +1,13 @@ +package main + +import ( + "fmt" + "os" +) + +func main() { + if err := rootCmd.Execute(); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} diff --git a/match.go b/match.go new file mode 100644 index 0000000..e00d996 --- /dev/null +++ b/match.go @@ -0,0 +1,103 @@ +package main + +import ( + "regexp" + "strings" +) + +func matchLine(pattern, actual string) bool { + if !strings.Contains(pattern, "...") { + return pattern == actual + } + + parts := strings.Split(pattern, "...") + escaped := make([]string, len(parts)) + for i, p := range parts { + escaped[i] = regexp.QuoteMeta(p) + } + re := regexp.MustCompile("^" + strings.Join(escaped, ".*") + "$") + return re.MatchString(actual) +} + +func matchOutput(expected, actual []string) bool { + return doMatch(expected, 0, actual, 0) +} + +func doMatch(expected []string, ei int, actual []string, ai int) bool { + if ei == len(expected) && ai == len(actual) { + return true + } + if ei == len(expected) { + return false + } + + exp := expected[ei] + + // Multi-line wildcard + if exp == "..." { + for skip := ai; skip <= len(actual); skip++ { + if doMatch(expected, ei+1, actual, skip) { + return true + } + } + return false + } + + if ai == len(actual) { + return false + } + + if matchLine(exp, actual[ai]) { + return doMatch(expected, ei+1, actual, ai+1) + } + + return false +} + +func diff(expected, actual []string) []DiffLine { + var result []DiffLine + ei, ai := 0, 0 + + for ei < len(expected) || ai < len(actual) { + if ei < len(expected) && expected[ei] == "..." { + nextExp := "" + hasNext := ei+1 < len(expected) + if hasNext { + nextExp = expected[ei+1] + } + + if !hasNext { + result = append(result, DiffLine{Kind: "context", Text: "..."}) + break + } + + result = append(result, DiffLine{Kind: "context", Text: "..."}) + ei++ + for ai < len(actual) && !matchLine(nextExp, actual[ai]) { + ai++ + } + continue + } + + if ei < len(expected) && ai < len(actual) { + if matchLine(expected[ei], actual[ai]) { + result = append(result, DiffLine{Kind: "equal", Text: actual[ai]}) + ei++ + ai++ + } else { + result = append(result, DiffLine{Kind: "expected", Text: expected[ei]}) + result = append(result, DiffLine{Kind: "actual", Text: actual[ai]}) + ei++ + ai++ + } + } else if ei < len(expected) { + result = append(result, DiffLine{Kind: "expected", Text: expected[ei]}) + ei++ + } else { + result = append(result, DiffLine{Kind: "actual", Text: actual[ai]}) + ai++ + } + } + + return result +} diff --git a/match_test.go b/match_test.go new file mode 100644 index 0000000..26ae213 --- /dev/null +++ b/match_test.go @@ -0,0 +1,127 @@ +package main + +import "testing" + +func TestMatchLineExact(t *testing.T) { + if !matchLine("hello", "hello") { + t.Error("exact match should pass") + } + if matchLine("hello", "world") { + t.Error("exact mismatch should fail") + } +} + +func TestMatchLineInlineWildcard(t *testing.T) { + if !matchLine("Homebrew 5...", "Homebrew 5.1.0") { + t.Error("trailing wildcard should match") + } + if !matchLine("...world", "hello world") { + t.Error("leading wildcard should match") + } + if !matchLine("a...b...c", "aXXbYYc") { + t.Error("multiple wildcards should match") + } + if matchLine("a...c", "aXXd") { + t.Error("should not match when suffix differs") + } +} + +func TestMatchLinePreservesLiteralDots(t *testing.T) { + if !matchLine("match ...", "match ...") { + t.Error("literal ... should match itself") + } +} + +func TestMatchOutputEmpty(t *testing.T) { + if !matchOutput(nil, nil) { + t.Error("both empty should match") + } + if !matchOutput([]string{}, []string{}) { + t.Error("both empty slices should match") + } +} + +func TestMatchOutputExact(t *testing.T) { + if !matchOutput([]string{"hello", "world"}, []string{"hello", "world"}) { + t.Error("exact match should pass") + } + if matchOutput([]string{"hello"}, []string{"world"}) { + t.Error("mismatch should fail") + } +} + +func TestMatchOutputExtraActual(t *testing.T) { + if matchOutput([]string{"hello"}, []string{"hello", "world"}) { + t.Error("extra actual lines should fail") + } +} + +func TestMatchOutputExtraExpected(t *testing.T) { + if matchOutput([]string{"hello", "world"}, []string{"hello"}) { + t.Error("extra expected lines should fail") + } +} + +func TestMatchOutputMultilineWildcard(t *testing.T) { + if !matchOutput([]string{"first", "...", "last"}, []string{"first", "a", "b", "c", "last"}) { + t.Error("multiline wildcard should match multiple lines") + } + if !matchOutput([]string{"first", "...", "last"}, []string{"first", "last"}) { + t.Error("multiline wildcard should match zero lines") + } + if !matchOutput([]string{"..."}, []string{"a", "b", "c"}) { + t.Error("standalone wildcard should match anything") + } + if !matchOutput([]string{"..."}, nil) { + t.Error("standalone wildcard should match empty") + } +} + +func TestMatchOutputMixed(t *testing.T) { + expected := []string{"one", "...", "three"} + actual := []string{"one", "two", "three"} + if !matchOutput(expected, actual) { + t.Error("mixed wildcard should match") + } +} + +func TestMatchOutputWildcardAtEnd(t *testing.T) { + if !matchOutput([]string{"start", "..."}, []string{"start", "a", "b"}) { + t.Error("trailing wildcard should match remaining lines") + } +} + +func TestMatchOutputInlineAndMultiline(t *testing.T) { + expected := []string{"Homebrew 5...", "..."} + actual := []string{"Homebrew 5.1.0", "extra line"} + if !matchOutput(expected, actual) { + t.Error("inline + multiline wildcards should work together") + } +} + +func TestDiffBasic(t *testing.T) { + d := diff([]string{"hello"}, []string{"world"}) + if len(d) != 2 { + t.Fatalf("expected 2 diff lines, got %d", len(d)) + } + if d[0].Kind != "expected" || d[0].Text != "hello" { + t.Errorf("d[0] = %+v", d[0]) + } + if d[1].Kind != "actual" || d[1].Text != "world" { + t.Errorf("d[1] = %+v", d[1]) + } +} + +func TestDiffEqual(t *testing.T) { + d := diff([]string{"hello"}, []string{"hello"}) + if len(d) != 1 || d[0].Kind != "equal" { + t.Errorf("expected equal, got %+v", d) + } +} + +func TestDiffContext(t *testing.T) { + d := diff([]string{"..."}, []string{"a", "b"}) + if len(d) != 1 || d[0].Kind != "context" { + t.Errorf("expected context, got %+v", d) + } +} diff --git a/parse.go b/parse.go new file mode 100644 index 0000000..623211b --- /dev/null +++ b/parse.go @@ -0,0 +1,128 @@ +package main + +import ( + "fmt" + "strconv" + "strings" +) + +func stripComment(line string) string { + inSingle := false + inDouble := false + for i := 0; i < len(line); i++ { + ch := line[i] + switch { + case ch == '\'' && !inDouble: + inSingle = !inSingle + case ch == '"' && !inSingle: + inDouble = !inDouble + case ch == '#' && !inSingle && !inDouble: + return strings.TrimRight(line[:i], " \t") + } + } + return line +} + +func parseExitCode(lines []string) ([]string, ExitCodeType, int) { + if len(lines) == 0 { + return lines, ExitCodeNone, 0 + } + + last := lines[len(lines)-1] + if len(last) < 3 || last[0] != '[' || last[len(last)-1] != ']' { + return lines, ExitCodeNone, 0 + } + + inner := last[1 : len(last)-1] + if inner == "*" { + return lines[:len(lines)-1], ExitCodeWildcard, 0 + } + + code, err := strconv.Atoi(inner) + if err != nil { + return lines, ExitCodeNone, 0 + } + return lines[:len(lines)-1], ExitCodeExact, code +} + +func trimTrailingEmpty(lines []string) []string { + end := len(lines) + for end > 0 && lines[end-1] == "" { + end-- + } + return lines[:end] +} + +func parse(path, content string) (ShoutFile, error) { + rawLines := strings.Split(content, "\n") + + // Remove trailing newline + if len(rawLines) > 0 && rawLines[len(rawLines)-1] == "" { + rawLines = rawLines[:len(rawLines)-1] + } + + var commands []Command + var directives []Directive + var current *Command + seenCommand := false + + for i, line := range rawLines { + lineNum := i + 1 + + if !seenCommand && strings.HasPrefix(line, "@") { + if strings.HasPrefix(line, "@setup ") { + setupPath := strings.TrimSpace(line[7:]) + if setupPath == "" { + return ShoutFile{}, fmt.Errorf("%s:%d: @setup requires a file path", path, lineNum) + } + directives = append(directives, Directive{ + Type: "setup", Path: setupPath, Line: lineNum, + }) + } else if strings.HasPrefix(line, "@env ") { + rest := strings.TrimSpace(line[5:]) + eq := strings.Index(rest, "=") + if eq <= 0 { + return ShoutFile{}, fmt.Errorf("%s:%d: malformed @env directive (expected KEY=VALUE): %s", path, lineNum, line) + } + directives = append(directives, Directive{ + Type: "env", Key: rest[:eq], Value: rest[eq+1:], Line: lineNum, + }) + } else { + return ShoutFile{}, fmt.Errorf("%s:%d: unknown directive: %s", path, lineNum, line) + } + continue + } + + if strings.HasPrefix(line, "$ ") { + seenCommand = true + if current != nil { + trimmed := trimTrailingEmpty(current.Expected) + remaining, ecType, ecVal := parseExitCode(trimmed) + current.Expected = trimTrailingEmpty(remaining) + current.ExitCodeType = ecType + current.ExitCodeValue = ecVal + commands = append(commands, *current) + } + + current = &Command{ + Line: lineNum, + Raw: line, + Cmd: stripComment(line[2:]), + Expected: nil, + } + } else if current != nil { + current.Expected = append(current.Expected, line) + } + } + + if current != nil { + trimmed := trimTrailingEmpty(current.Expected) + remaining, ecType, ecVal := parseExitCode(trimmed) + current.Expected = trimTrailingEmpty(remaining) + current.ExitCodeType = ecType + current.ExitCodeValue = ecVal + commands = append(commands, *current) + } + + return ShoutFile{Path: path, Commands: commands, Directives: directives}, nil +} diff --git a/parse_test.go b/parse_test.go new file mode 100644 index 0000000..6eb7ede --- /dev/null +++ b/parse_test.go @@ -0,0 +1,167 @@ +package main + +import ( + "testing" +) + +func TestParseBasicCommand(t *testing.T) { + sf, err := parse("test.shout", "$ echo hello\nhello\n") + if err != nil { + t.Fatal(err) + } + if len(sf.Commands) != 1 { + t.Fatalf("expected 1 command, got %d", len(sf.Commands)) + } + c := sf.Commands[0] + if c.Cmd != "echo hello" { + t.Errorf("cmd = %q, want %q", c.Cmd, "echo hello") + } + if len(c.Expected) != 1 || c.Expected[0] != "hello" { + t.Errorf("expected = %v, want [hello]", c.Expected) + } + if c.ExitCodeType != ExitCodeNone { + t.Errorf("exit code type = %d, want ExitCodeNone", c.ExitCodeType) + } +} + +func TestParseMultipleCommands(t *testing.T) { + sf, err := parse("test.shout", "$ echo one\none\n\n$ echo two\ntwo\n") + if err != nil { + t.Fatal(err) + } + if len(sf.Commands) != 2 { + t.Fatalf("expected 2 commands, got %d", len(sf.Commands)) + } + if sf.Commands[0].Cmd != "echo one" { + t.Errorf("cmd[0] = %q", sf.Commands[0].Cmd) + } + if sf.Commands[1].Cmd != "echo two" { + t.Errorf("cmd[1] = %q", sf.Commands[1].Cmd) + } +} + +func TestParseExitCode(t *testing.T) { + sf, err := parse("test.shout", "$ false\n[1]\n") + if err != nil { + t.Fatal(err) + } + c := sf.Commands[0] + if c.ExitCodeType != ExitCodeExact || c.ExitCodeValue != 1 { + t.Errorf("exit code = (%d, %d), want (Exact, 1)", c.ExitCodeType, c.ExitCodeValue) + } + if len(c.Expected) != 0 { + t.Errorf("expected = %v, want []", c.Expected) + } +} + +func TestParseExitCodeWildcard(t *testing.T) { + sf, err := parse("test.shout", "$ false\n[*]\n") + if err != nil { + t.Fatal(err) + } + c := sf.Commands[0] + if c.ExitCodeType != ExitCodeWildcard { + t.Errorf("exit code type = %d, want Wildcard", c.ExitCodeType) + } +} + +func TestParseExitCodeWithOutput(t *testing.T) { + sf, err := parse("test.shout", "$ sh -c \"echo oops && exit 1\"\noops\n[*]\n") + if err != nil { + t.Fatal(err) + } + c := sf.Commands[0] + if len(c.Expected) != 1 || c.Expected[0] != "oops" { + t.Errorf("expected = %v, want [oops]", c.Expected) + } + if c.ExitCodeType != ExitCodeWildcard { + t.Errorf("exit code type = %d, want Wildcard", c.ExitCodeType) + } +} + +func TestParseComment(t *testing.T) { + sf, err := parse("test.shout", "$ echo hello # this is a comment\nhello\n") + if err != nil { + t.Fatal(err) + } + if sf.Commands[0].Cmd != "echo hello" { + t.Errorf("cmd = %q, want %q", sf.Commands[0].Cmd, "echo hello") + } +} + +func TestParseCommentInQuotes(t *testing.T) { + sf, err := parse("test.shout", "$ echo \"keep # this\"\nkeep # this\n") + if err != nil { + t.Fatal(err) + } + if sf.Commands[0].Cmd != `echo "keep # this"` { + t.Errorf("cmd = %q", sf.Commands[0].Cmd) + } +} + +func TestParseEnvDirective(t *testing.T) { + sf, err := parse("test.shout", "@env GREETING=hello\n\n$ echo $GREETING\nhello\n") + if err != nil { + t.Fatal(err) + } + if len(sf.Directives) != 1 { + t.Fatalf("expected 1 directive, got %d", len(sf.Directives)) + } + d := sf.Directives[0] + if d.Type != "env" || d.Key != "GREETING" || d.Value != "hello" { + t.Errorf("directive = %+v", d) + } +} + +func TestParseSetupDirective(t *testing.T) { + sf, err := parse("test.shout", "@setup shared.shout\n\n$ echo hi\nhi\n") + if err != nil { + t.Fatal(err) + } + if len(sf.Directives) != 1 { + t.Fatalf("expected 1 directive, got %d", len(sf.Directives)) + } + d := sf.Directives[0] + if d.Type != "setup" || d.Path != "shared.shout" { + t.Errorf("directive = %+v", d) + } +} + +func TestParseNoExpectedOutput(t *testing.T) { + sf, err := parse("test.shout", "$ export FOO=bar\n$ echo $FOO\nbar\n") + if err != nil { + t.Fatal(err) + } + if len(sf.Commands) != 2 { + t.Fatalf("expected 2 commands, got %d", len(sf.Commands)) + } + if len(sf.Commands[0].Expected) != 0 { + t.Errorf("first command expected = %v, want []", sf.Commands[0].Expected) + } +} + +func TestParseTrimsTrailingBlanks(t *testing.T) { + sf, err := parse("test.shout", "$ echo hello\nhello\n\n\n$ echo world\nworld\n") + if err != nil { + t.Fatal(err) + } + if len(sf.Commands[0].Expected) != 1 || sf.Commands[0].Expected[0] != "hello" { + t.Errorf("expected = %v", sf.Commands[0].Expected) + } +} + +func TestParseLineNumbers(t *testing.T) { + sf, err := parse("test.shout", "@env X=1\n\n$ echo hello\nhello\n\n$ echo world\nworld\n") + if err != nil { + t.Fatal(err) + } + if sf.Directives[0].Line != 1 { + t.Errorf("directive line = %d, want 1", sf.Directives[0].Line) + } + if sf.Commands[0].Line != 3 { + t.Errorf("command[0] line = %d, want 3", sf.Commands[0].Line) + } + if sf.Commands[1].Line != 6 { + t.Errorf("command[1] line = %d, want 6", sf.Commands[1].Line) + } +} diff --git a/run.go b/run.go new file mode 100644 index 0000000..887c61c --- /dev/null +++ b/run.go @@ -0,0 +1,203 @@ +package main + +import ( + "io" + "os" + "os/exec" + "regexp" + "strconv" + "strings" + "syscall" + "time" +) + +const sentinelPrefix = "__SHOUT_SENTINEL_" + +type RunOptions struct { + CleanEnv bool + PathDirs []string + EnvVars map[string]string + Timeout time.Duration + Verbose bool + OnCommand func(Command) +} + +func buildScript(commands []Command) string { + var b strings.Builder + b.WriteString("exec 2>&1\n") + for i, cmd := range commands { + b.WriteString(cmd.Cmd) + b.WriteByte('\n') + b.WriteString("printf '\\n" + sentinelPrefix + "%s_" + strconv.Itoa(i) + "__\\n' \"$?\"\n") + } + return b.String() +} + +func parseSentinelOutput(raw string, commandCount int) (outputs [][]string, exitCodes []int) { + re := regexp.MustCompile(regexp.QuoteMeta(sentinelPrefix) + `(\d+)_(\d+)__`) + + remaining := raw + for i := 0; i < commandCount; i++ { + loc := re.FindStringSubmatchIndex(remaining) + if loc == nil { + lines := strings.Split(remaining, "\n") + if len(lines) > 0 && lines[0] == "" { + lines = lines[1:] + } + lines = trimTrailingEmpty(lines) + outputs = append(outputs, lines) + exitCodes = append(exitCodes, 1) + break + } + + before := remaining[:loc[0]] + exitCodeStr := remaining[loc[2]:loc[3]] + ec, _ := strconv.Atoi(exitCodeStr) + + lines := strings.Split(before, "\n") + if len(lines) > 0 && lines[0] == "" { + lines = lines[1:] + } + lines = trimTrailingEmpty(lines) + if len(lines) == 1 && lines[0] == "" { + lines = nil + } + + outputs = append(outputs, lines) + exitCodes = append(exitCodes, ec) + + afterSentinel := remaining[loc[1]:] + if strings.HasPrefix(afterSentinel, "\n") { + afterSentinel = afterSentinel[1:] + } + remaining = afterSentinel + } + + for len(outputs) < commandCount { + outputs = append(outputs, nil) + exitCodes = append(exitCodes, 1) + } + + return outputs, exitCodes +} + +func runFile(file ShoutFile, opts RunOptions) FileResult { + tmpDir, err := os.MkdirTemp("", "shout-") + if err != nil { + return FileResult{File: file, TmpDir: "", Error: err.Error()} + } + + if len(file.Commands) == 0 { + return FileResult{File: file, TmpDir: tmpDir} + } + + script := buildScript(file.Commands) + + // Build environment + var envMap map[string]string + if opts.CleanEnv { + envMap = make(map[string]string) + } else { + envMap = make(map[string]string) + for _, e := range os.Environ() { + if k, v, ok := strings.Cut(e, "="); ok { + envMap[k] = v + } + } + } + + envMap["HOME"] = tmpDir + envMap["SHOUT_DIR"] = tmpDir + + for k, v := range opts.EnvVars { + envMap[k] = v + } + + if len(opts.PathDirs) > 0 { + existing := envMap["PATH"] + envMap["PATH"] = strings.Join(opts.PathDirs, ":") + ":" + existing + } + + envSlice := make([]string, 0, len(envMap)) + for k, v := range envMap { + envSlice = append(envSlice, k+"="+v) + } + + cmd := exec.Command("/bin/sh") + cmd.Dir = tmpDir + cmd.Env = envSlice + cmd.Stderr = nil + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + + stdin, err := cmd.StdinPipe() + if err != nil { + return FileResult{File: file, TmpDir: tmpDir, Error: err.Error()} + } + + stdoutPipe, err := cmd.StdoutPipe() + if err != nil { + return FileResult{File: file, TmpDir: tmpDir, Error: err.Error()} + } + + if err := cmd.Start(); err != nil { + return FileResult{File: file, TmpDir: tmpDir, Error: err.Error()} + } + + defer func() { + if cmd.Process != nil { + _ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) + } + }() + + _, _ = io.WriteString(stdin, script) + stdin.Close() + + // Read stdout with timeout + totalTimeout := opts.Timeout * time.Duration(len(file.Commands)) + type readResult struct { + data []byte + err error + } + ch := make(chan readResult, 1) + go func() { + data, err := io.ReadAll(stdoutPipe) + ch <- readResult{data, err} + }() + + var output []byte + select { + case r := <-ch: + if r.err != nil { + return FileResult{File: file, TmpDir: tmpDir, Error: r.err.Error()} + } + output = r.data + case <-time.After(totalTimeout): + return FileResult{File: file, TmpDir: tmpDir, Error: "Timeout reading output"} + } + + _ = cmd.Wait() + + outputs, exitCodesList := parseSentinelOutput(string(output), len(file.Commands)) + + results := make([]CommandResult, len(file.Commands)) + for i, c := range file.Commands { + if opts.Verbose && opts.OnCommand != nil { + opts.OnCommand(c) + } + actual := outputs[i] + if actual == nil { + actual = []string{} + } + results[i] = CommandResult{ + Command: c, + Actual: actual, + ExitCode: exitCodesList[i], + } + } + + return FileResult{File: file, Results: results, TmpDir: tmpDir} +} + +func cleanupTmpDir(dir string) { + os.RemoveAll(dir) +} diff --git a/shout b/shout new file mode 100755 index 0000000..440a587 Binary files /dev/null and b/shout differ diff --git a/test/basic.shout b/test/basic.shout new file mode 100644 index 0000000..2d759ed --- /dev/null +++ b/test/basic.shout @@ -0,0 +1,9 @@ +$ echo hello +hello + +$ echo one && echo two +one +two + +$ echo "working directory: $(basename $PWD)" +working directory: ... diff --git a/test/comments.shout b/test/comments.shout new file mode 100644 index 0000000..54fe5eb --- /dev/null +++ b/test/comments.shout @@ -0,0 +1,5 @@ +$ echo hello # this is a comment +hello + +$ echo "keep # this" +keep # this diff --git a/test/env.shout b/test/env.shout new file mode 100644 index 0000000..245d02c --- /dev/null +++ b/test/env.shout @@ -0,0 +1,5 @@ +@env GREETING=hello +@env TARGET=world + +$ echo "$GREETING $TARGET" +hello world diff --git a/test/features.shout b/test/features.shout new file mode 100644 index 0000000..92eed91 --- /dev/null +++ b/test/features.shout @@ -0,0 +1,28 @@ +$ echo "test exit codes" +test exit codes + +$ false +[1] + +$ sh -c "exit 42" +[42] + +$ sh -c "echo oops && exit 1" +oops +[*] + +$ export MY_VAR=hello +$ echo $MY_VAR +hello + +$ cd /tmp +$ pwd +/tmp + +$ echo "line 1" && echo "" && echo "line 3" +line 1 + +line 3 + +$ echo "match ..." +match ... diff --git a/test/setup-shared.shout b/test/setup-shared.shout new file mode 100644 index 0000000..7ad3eba --- /dev/null +++ b/test/setup-shared.shout @@ -0,0 +1 @@ +$ export READY=yes diff --git a/test/setup-user.shout b/test/setup-user.shout new file mode 100644 index 0000000..023c5c2 --- /dev/null +++ b/test/setup-user.shout @@ -0,0 +1,4 @@ +@setup setup-shared.shout + +$ echo $READY +yes diff --git a/types.go b/types.go new file mode 100644 index 0000000..2e6f295 --- /dev/null +++ b/types.go @@ -0,0 +1,65 @@ +package main + +// ExitCodeType describes how to match a command's exit code. +type ExitCodeType int + +const ( + ExitCodeNone ExitCodeType = iota // not specified — expect 0 + ExitCodeExact // expect specific code + ExitCodeWildcard // [*] — expect any non-zero +) + +type Command struct { + Line int + Raw string + Cmd string + Expected []string + ExitCodeType ExitCodeType + ExitCodeValue int +} + +type Directive struct { + Type string // "setup" or "env" + Path string // for setup + Key string // for env + Value string // for env + Line int +} + +type ShoutFile struct { + Path string + Commands []Command + Directives []Directive +} + +type CommandResult struct { + Command Command + Actual []string + ExitCode int +} + +type FileResult struct { + File ShoutFile + Results []CommandResult + TmpDir string + Error string +} + +type DiffLine struct { + Kind string // "equal", "expected", "actual", "context" + Text string +} + +type TestResult struct { + Path string + Passed bool + CommandCount int + Failures []FailedCommand + Error string +} + +type FailedCommand struct { + Result CommandResult + DiffLines []DiffLine + ExitCodeMismatch bool +} diff --git a/update.go b/update.go new file mode 100644 index 0000000..aa812b0 --- /dev/null +++ b/update.go @@ -0,0 +1,79 @@ +package main + +import ( + "regexp" + "strings" +) + +var exitCodeMarkerRe = regexp.MustCompile(`^\[(\d+|\*)\]$`) + +func rewriteFile(file ShoutFile, results []CommandResult, originalContent string) string { + lines := strings.Split(originalContent, "\n") + var output []string + + cmdIdx := 0 + + for i := 0; i < len(lines); i++ { + line := lines[i] + + if strings.HasPrefix(line, "$ ") { + output = append(output, line) + + if cmdIdx >= len(file.Commands) || cmdIdx >= len(results) { + cmdIdx++ + continue + } + + cmd := file.Commands[cmdIdx] + result := results[cmdIdx] + + // Skip past old expected output lines + j := i + 1 + for j < len(lines) && !strings.HasPrefix(lines[j], "$ ") { + j++ + } + + oldExpectedRaw := lines[i+1 : j] + + // Check for exit code marker + oldTrimmed := trimTrailingEmpty(oldExpectedRaw) + var oldExitMarker string + if len(oldTrimmed) > 0 { + last := oldTrimmed[len(oldTrimmed)-1] + if exitCodeMarkerRe.MatchString(last) { + oldExitMarker = last + } + } + + // Count trailing blank lines + trailingBlanks := 0 + for k := len(oldExpectedRaw) - 1; k >= 0; k-- { + if oldExpectedRaw[k] == "" { + trailingBlanks++ + } else { + break + } + } + + // If wildcards match, keep original + if matchOutput(cmd.Expected, result.Actual) { + output = append(output, oldExpectedRaw...) + } else { + output = append(output, result.Actual...) + if oldExitMarker != "" { + output = append(output, oldExitMarker) + } + for k := 0; k < trailingBlanks; k++ { + output = append(output, "") + } + } + + i = j - 1 + cmdIdx++ + } else if cmdIdx == 0 { + output = append(output, line) + } + } + + return strings.Join(output, "\n") +}