go-shout/run.go

234 lines
5.3 KiB
Go

package main
import (
"crypto/rand"
"fmt"
"io"
"os"
"os/exec"
"regexp"
"strconv"
"strings"
"syscall"
"time"
)
// newSentinelPrefix returns a fresh marker prefix of the form
// "__SHOUT_<16 hex digits>_", built from 8 bytes of crypto/rand output.
//
// The random part makes it unlikely that ordinary command output contains
// the marker by accident.
func newSentinelPrefix() string {
	var token [8]byte
	// The read error is deliberately ignored: if the read fails the bytes may
	// stay (partly) zero, which still produces a well-formed prefix.
	_, _ = rand.Read(token[:])
	return fmt.Sprintf("__SHOUT_%x_", token[:])
}
// RunOptions configures how runFile prepares the environment of the shell
// and how long it waits for it.
type RunOptions struct {
	// CleanEnv, when true, stops runFile from copying the current process
	// environment into the shell's environment.
	CleanEnv bool
	// PathDirs are joined with ":" and placed in front of PATH.
	PathDirs []string
	// EnvVars are set after the inherited variables and after HOME and
	// SHOUT_DIR, so they override all of those.
	EnvVars map[string]string
	// Timeout is the time budget per command; runFile multiplies it by the
	// number of commands to get the budget for the whole run.
	Timeout time.Duration
}
// buildScript renders the shell script that runs all commands of one file.
//
// The script starts with "exec 2>&1" so that stderr is merged into stdout.
// Each command's text is then written followed by a newline and a printf
// statement that reports the command's exit status ("$?") inside a marker of
// the form "<sentinelPrefix><status>_<index>__", where index is the command's
// zero-based position in commands.
func buildScript(commands []Command, sentinelPrefix string) string {
	var script strings.Builder
	script.WriteString("exec 2>&1\n")
	for idx := range commands {
		// The "\n" at the start of the printf format puts the marker on a
		// line of its own even if the command's output has no final newline.
		marker := "printf '\\n" + sentinelPrefix + "%s_" + strconv.Itoa(idx) + "__\\n' \"$?\"\n"
		script.WriteString(commands[idx].Cmd + "\n" + marker)
	}
	return script.String()
}
// splitSentinelBlock turns the raw text captured for one command into a
// slice of lines.
//
// The text is split on "\n"; one leading empty line is dropped, and the
// result is passed through trimTrailingEmpty (defined elsewhere; presumably
// it removes empty lines at the end — verify against its definition). If
// exactly one empty line is left afterwards, nil is returned instead.
func splitSentinelBlock(s string) []string {
	parts := strings.Split(s, "\n")
	if len(parts) != 0 && parts[0] == "" {
		parts = parts[1:]
	}
	parts = trimTrailingEmpty(parts)
	if len(parts) == 1 && parts[0] == "" {
		return nil
	}
	return parts
}
// parseSentinelOutput cuts the combined shell output in raw into one block
// of lines and one exit code per command.
//
// Markers have the form "<sentinelPrefix><status>_<index>__". The text in
// front of a marker is the output of the command being processed and
// <status> is its exit code. Two situations are reported with exit code 1:
//
//   - no marker is left: the command receives whatever text remains, and any
//     later commands receive empty text;
//   - the next marker carries an index greater than the current command's:
//     the current command receives the text before that marker, and the
//     marker itself is kept for a later iteration.
//
// Both returned slices have commandCount entries (they are nil when
// commandCount is not positive).
func parseSentinelOutput(raw string, commandCount int, sentinelPrefix string) (outputs [][]string, exitCodes []int) {
	pattern := regexp.MustCompile(regexp.QuoteMeta(sentinelPrefix) + `(\d+)_(\d+)__`)

	// record stores the result for the next command in order.
	record := func(block string, code int) {
		outputs = append(outputs, splitSentinelBlock(block))
		exitCodes = append(exitCodes, code)
	}

	rest := raw
	for i := 0; i < commandCount; i++ {
		m := pattern.FindStringSubmatchIndex(rest)
		if m == nil {
			// No marker left: hand over the remaining text and mark as failed.
			record(rest, 1)
			rest = ""
			continue
		}

		// m[0]:m[1] is the whole marker, m[2]:m[3] the status digits and
		// m[4]:m[5] the command index digits. Conversion errors are ignored;
		// the pattern only admits decimal digits.
		status, _ := strconv.Atoi(rest[m[2]:m[3]])
		owner, _ := strconv.Atoi(rest[m[4]:m[5]])
		preceding := rest[:m[0]]

		if owner > i {
			// The marker belongs to a later command, so this command never
			// reported a status. Leave the marker in place for that command.
			record(preceding, 1)
			rest = rest[m[0]:]
			continue
		}

		record(preceding, status)
		// Skip the marker and the single newline printed right after it.
		rest = strings.TrimPrefix(rest[m[1]:], "\n")
	}

	// Defensive padding; every loop iteration above already adds one entry.
	for len(outputs) < commandCount {
		outputs = append(outputs, nil)
		exitCodes = append(exitCodes, 1)
	}
	return outputs, exitCodes
}
// runFile runs all commands of file in one /bin/sh process and returns the
// output lines and exit code captured for each command.
//
// A fresh temporary directory is created for the run and used as the shell's
// working directory, HOME and SHOUT_DIR. Its path is returned in
// FileResult.TmpDir on every path after its creation, including error
// returns; this function never removes it (presumably the caller does, e.g.
// via cleanupTmpDir — confirm).
//
// The environment is assembled in this order, later steps overriding earlier
// ones: the current process environment (skipped when opts.CleanEnv is set),
// HOME and SHOUT_DIR, opts.EnvVars, and finally opts.PathDirs prepended to
// PATH.
//
// The whole run is limited to opts.Timeout multiplied by the number of
// commands. When the limit is hit, the shell's process group is killed and
// FileResult.Error is set to "Timeout reading output". Failures to create
// the directory, to start the shell or to read its output are reported in
// FileResult.Error as well.
//
// NOTE(review): a zero opts.Timeout yields a zero budget, so the run is
// most likely reported as timed out immediately — confirm that callers
// always set it.
func runFile(file ShoutFile, opts RunOptions) FileResult {
	tmpDir, err := os.MkdirTemp("", "shout-")
	if err != nil {
		return FileResult{File: file, TmpDir: "", Error: err.Error()}
	}
	if len(file.Commands) == 0 {
		return FileResult{File: file, TmpDir: tmpDir}
	}

	sentinel := newSentinelPrefix()
	script := buildScript(file.Commands, sentinel)

	// Build the environment in a map first so later sources override
	// earlier ones.
	envMap := make(map[string]string)
	if !opts.CleanEnv {
		for _, e := range os.Environ() {
			if k, v, ok := strings.Cut(e, "="); ok {
				envMap[k] = v
			}
		}
	}
	envMap["HOME"] = tmpDir
	envMap["SHOUT_DIR"] = tmpDir
	for k, v := range opts.EnvVars {
		envMap[k] = v
	}
	if len(opts.PathDirs) > 0 {
		existing := envMap["PATH"]
		prepend := strings.Join(opts.PathDirs, ":")
		if existing != "" {
			envMap["PATH"] = prepend + ":" + existing
		} else {
			envMap["PATH"] = prepend
		}
	}
	envSlice := make([]string, 0, len(envMap))
	for k, v := range envMap {
		envSlice = append(envSlice, k+"="+v)
	}

	cmd := exec.Command("/bin/sh")
	cmd.Dir = tmpDir
	cmd.Env = envSlice
	// The script itself redirects stderr into stdout ("exec 2>&1"), so the
	// process-level stderr is left unset.
	cmd.Stderr = nil
	// Put the shell in its own process group so that it and its children
	// can be killed together through the negative PID.
	cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}

	stdin, err := cmd.StdinPipe()
	if err != nil {
		return FileResult{File: file, TmpDir: tmpDir, Error: err.Error()}
	}
	stdoutPipe, err := cmd.StdoutPipe()
	if err != nil {
		// The command is never started, so release the stdin pipe here.
		_ = stdin.Close()
		return FileResult{File: file, TmpDir: tmpDir, Error: err.Error()}
	}
	if err := cmd.Start(); err != nil {
		return FileResult{File: file, TmpDir: tmpDir, Error: err.Error()}
	}

	killGroup := func() {
		if cmd.Process != nil {
			_ = syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
		}
	}
	var waited bool
	// Safety net: if this function leaves without having waited for the
	// shell, make sure its process group does not keep running.
	defer func() {
		if !waited {
			killGroup()
		}
	}()

	totalTimeout := opts.Timeout * time.Duration(len(file.Commands))

	// Start the reader before any script byte is written, to avoid a
	// deadlock when the pipe buffers fill in both directions.
	type readResult struct {
		data []byte
		err  error
	}
	readCh := make(chan readResult, 1)
	go func() {
		data, err := io.ReadAll(stdoutPipe)
		readCh <- readResult{data, err}
	}()

	// Feed the script from its own goroutine. A write to a pipe blocks once
	// the pipe buffer is full; if a command hangs while part of a large
	// script is still unwritten, a synchronous write here would never
	// return and the timeout below would never be reached.
	writeDone := make(chan struct{})
	go func() {
		defer close(writeDone)
		// Best effort: the shell may exit before consuming the whole script.
		_, _ = io.WriteString(stdin, script)
		_ = stdin.Close()
	}()

	// reap waits for the shell and then for the writer goroutine. Wait
	// closes the stdin pipe, which unblocks a writer that is still pending.
	reap := func() {
		_ = cmd.Wait()
		waited = true
		<-writeDone
	}

	timer := time.NewTimer(totalTimeout)
	defer timer.Stop()

	var output []byte
	select {
	case r := <-readCh:
		if r.err != nil {
			killGroup()
			reap()
			return FileResult{File: file, TmpDir: tmpDir, Error: r.err.Error()}
		}
		output = r.data
	case <-timer.C:
		killGroup()
		<-readCh // drain the reader goroutine so pipe FDs are released
		reap()
		return FileResult{File: file, TmpDir: tmpDir, Error: "Timeout reading output"}
	}
	reap()

	outputs, exitCodesList := parseSentinelOutput(string(output), len(file.Commands), sentinel)
	results := make([]CommandResult, len(file.Commands))
	for i, c := range file.Commands {
		actual := outputs[i]
		if actual == nil {
			// Results always carry a non-nil slice for the actual output.
			actual = []string{}
		}
		results[i] = CommandResult{
			Command:  c,
			Actual:   actual,
			ExitCode: exitCodesList[i],
		}
	}
	return FileResult{File: file, Results: results, TmpDir: tmpDir}
}
// cleanupTmpDir removes dir together with everything inside it.
//
// Removal is best effort: any error from os.RemoveAll is discarded.
func cleanupTmpDir(dir string) {
	_ = os.RemoveAll(dir)
}