shout/src/run.rs
2026-04-02 15:18:22 -07:00

437 lines
14 KiB
Rust

use std::io::{Read, Write};
use std::process::{Command, Stdio};
use std::sync::mpsc;
use std::time::Duration;
use regex::Regex;
use crate::parse::{self, ShoutFile};
/// Captured result of running one command of a `ShoutFile`.
#[derive(Debug, Clone)]
pub struct CommandResult {
    /// The parsed command this result belongs to.
    pub command: parse::Command,
    /// Output lines (stdout and stderr go to the same capture) with ANSI escape
    /// sequences stripped.
    pub actual: Vec<String>,
    /// Exit status reported by the shell for this command; 1 when it could not be
    /// determined (no sentinel was seen for the command).
    pub exit_code: i32,
}
/// Outcome of running every command of one `ShoutFile` via `run_file`.
#[derive(Debug)]
// NOTE(review): presumably silences warnings for fields no caller reads yet — confirm
// which ones and narrow or remove this allow.
#[allow(dead_code)]
pub struct FileResult {
    /// The file that was run (a clone of the input).
    pub file: ShoutFile,
    /// One entry per command in file order; empty when `error` is set.
    pub results: Vec<CommandResult>,
    /// Scratch directory the shell ran in (also its `HOME` and `SHOUT_DIR`); empty
    /// when the directory could not be created. Not removed by `run_file`.
    pub tmp_dir: String,
    /// Why the run could not complete (temp dir creation, shell spawn, or timeout).
    pub error: Option<String>,
}
/// Settings that control how `run_file` executes a file's commands.
///
/// Derives `Debug` (so options can be logged or shown in diagnostics) and `Clone`
/// (so one set of options can be reused and tweaked per file).
#[derive(Debug, Clone)]
pub struct RunOptions {
    /// Start the shell with an empty environment instead of inheriting this process's.
    pub clean_env: bool,
    /// Directories prepended to `PATH` (joined with `:`), ahead of the inherited `PATH`.
    pub path_dirs: Vec<String>,
    /// Extra environment variables. They are applied after `HOME` and the `SHOUT_*`
    /// variables, so they can override those.
    pub env_vars: Vec<(String, String)>,
    /// Exported to the shell as `SHOUT_SOURCE_DIR` when set.
    pub source_dir: Option<String>,
    /// Exported to the shell as `SHOUT_PROJECT_DIR` when set.
    pub project_dir: Option<String>,
    /// Timeout per command in milliseconds; a file gets this times its command count.
    pub timeout_ms: u64,
    /// Echo each command to stderr as it starts (only when `run_file` gets `on_command`).
    pub verbose: bool,
}
/// Prefix of the marker printed after each command's output. The full marker is
/// `__SHOUT_SENTINEL_<exit status>_<command index>__` (emitted by `build_script`).
const SENTINEL_PREFIX: &str = "__SHOUT_SENTINEL_";
/// Prefix of the verbose-mode marker written to fd 3 before each command,
/// followed by the command index.
const VERBOSE_MARKER: &str = "__SHOUT_CMD_";
/// Builds the `/bin/sh` script that runs `commands` in order and frames each
/// command's output.
///
/// File descriptors set up by the preamble:
/// - fd 9 keeps the shell's real stdout.
/// - stderr is merged into stdout.
/// - in verbose mode, fd 3 keeps the original stderr for command markers.
///
/// For each command, the script does the following:
/// 1. Sends the command's stdout and stderr to a `mktemp` file.
/// 2. Records `$?`, then restores stdout.
/// 3. Prints the captured text and removes the temp file.
/// 4. Prints `\n__SHOUT_SENTINEL_<status>_<index>__\n`.
///
/// Every line, including the last, ends with a newline.
fn build_script(commands: &[parse::Command], verbose: bool) -> String {
    let preamble = if verbose {
        "exec 3>&2 2>&1 9>&1"
    } else {
        "exec 2>&1 9>&1"
    };
    let mut script = String::from(preamble);
    script.push('\n');

    for (index, command) in commands.iter().enumerate() {
        if verbose {
            // Announce the command on the saved stderr so the runner can echo it.
            script.push_str(&format!("printf '{VERBOSE_MARKER}{index}\\n' >&3\n"));
        }
        script.push_str("__shout_out=$(mktemp)\n");
        script.push_str("exec 1>\"$__shout_out\" 2>&1\n");
        script.push_str(&command.command);
        script.push('\n');
        script.push_str("__shout_ec=$?\n");
        script.push_str("exec 1>&9 2>&1\n");
        script.push_str("cat \"$__shout_out\"\n");
        script.push_str("rm -f \"$__shout_out\"\n");
        // The leading `\n` puts the sentinel on its own line even when the output
        // does not end with a newline.
        script.push_str(&format!(
            "printf '\\n{SENTINEL_PREFIX}%s_{index}__\\n' \"$__shout_ec\"\n"
        ));
    }
    script
}
fn strip_ansi(line: &str) -> String {
// Same regex as the TS version
let re = Regex::new(r"[\x1b\x9b][\[()#;?]*(?:[0-9]{1,4}(?:;[0-9]{0,4})*)?[0-9A-ORZcf-nqry=><]").unwrap();
re.replace_all(line, "").to_string()
}
/// Splits the shell's framed stdout into per-command output lines and exit codes.
///
/// `raw` is output of a script from `build_script`: each command's output is
/// followed by `\n__SHOUT_SENTINEL_<status>_<index>__\n`. Sentinels are consumed
/// in order; the index captured in a sentinel is not checked against the position.
///
/// Always returns exactly `command_count` outputs and exit codes. If the text
/// ends before a command's sentinel (e.g. the shell was killed or exited early),
/// the leftover text becomes that command's output with exit code 1, and every
/// later command gets no output and exit code 1.
fn parse_sentinel_output(raw: &str, command_count: usize) -> (Vec<Vec<String>>, Vec<i32>) {
    /// Normalizes the text preceding a sentinel into output lines. It drops one
    /// leading empty line and the trailing empty lines, and maps a lone empty line
    /// to no lines. Both branches below use it so they agree.
    fn to_lines(text: &str) -> Vec<String> {
        let mut lines: Vec<String> = text.split('\n').map(str::to_string).collect();
        // NOTE(review): besides stray newlines, this also drops a genuine leading
        // blank line of command output; kept for compatibility with existing tests.
        if !lines.is_empty() && lines[0].is_empty() {
            lines.remove(0);
        }
        let mut lines = parse::trim_trailing_empty(&lines);
        if lines.len() == 1 && lines[0].is_empty() {
            lines.clear();
        }
        lines
    }

    let sentinel_re =
        Regex::new(&format!(r"{}(\d+)_(\d+)__", regex::escape(SENTINEL_PREFIX))).unwrap();
    let mut outputs: Vec<Vec<String>> = Vec::with_capacity(command_count);
    let mut exit_codes: Vec<i32> = Vec::with_capacity(command_count);
    let mut remaining = raw;

    while outputs.len() < command_count {
        // A single `captures` call finds the sentinel and its groups in one pass.
        match sentinel_re.captures(remaining) {
            Some(caps) => {
                let whole = caps.get(0).expect("group 0 always participates in a match");
                outputs.push(to_lines(&remaining[..whole.start()]));
                exit_codes.push(caps[1].parse().unwrap_or(1));
                // Skip the sentinel and the newline that terminates it.
                let after = &remaining[whole.end()..];
                remaining = after.strip_prefix('\n').unwrap_or(after);
            }
            None => {
                // No sentinel: the rest is this command's output; assume failure.
                outputs.push(to_lines(remaining));
                exit_codes.push(1);
                break;
            }
        }
    }

    // Commands that never started get no output and a failing exit code.
    outputs.resize(command_count, Vec::new());
    exit_codes.resize(command_count, 1);
    (outputs, exit_codes)
}
/// Creates a new, uniquely named `shout-*` directory under the system temp dir
/// and returns its path.
///
/// The directory is created with `create_dir`, which fails atomically if the
/// path already exists. A name collision with a concurrent run, or a pre-planted
/// path in a shared `/tmp`, is therefore retried instead of silently reused. The
/// name combines the process id, the current time and a per-process counter,
/// so retries always pick a fresh name even on coarse clocks.
///
/// # Errors
/// Returns any I/O error other than `AlreadyExists` from creating the directory.
fn make_tmp_dir() -> std::io::Result<String> {
    use std::sync::atomic::{AtomicU64, Ordering};

    static COUNTER: AtomicU64 = AtomicU64::new(0);

    let base = std::env::temp_dir();
    // Like the previous `create_dir_all`, tolerate a temp dir that does not exist yet.
    std::fs::create_dir_all(&base)?;
    let pid = std::process::id();
    loop {
        // A clock before the epoch only weakens uniqueness; the counter still varies.
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0);
        let seq = COUNTER.fetch_add(1, Ordering::Relaxed);
        let dir = base.join(format!("shout-{pid}-{nanos}-{seq}"));
        match std::fs::create_dir(&dir) {
            Ok(()) => return Ok(dir.to_string_lossy().into_owned()),
            Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => continue,
            Err(e) => return Err(e),
        }
    }
}
/// Sends `SIGKILL` to the process group whose id is `pid`.
///
/// `run_file` puts the shell in its own process group (`setpgid(0, 0)`), so the
/// shell's pid is also the group id. Every process it started that stayed in the
/// group is killed along with it. `kill(-pgid, ...)` already reaches every member
/// of the group, so no `ps` scan is needed.
///
/// Ids that cannot safely be negated into a group id are ignored.
fn kill_tree(pid: u32) {
    // 0 would become kill(0), i.e. *our own* group. 1 would become kill(-1), i.e.
    // every process we are allowed to signal. Values above pid_t::MAX would wrap
    // to an unrelated id.
    if pid <= 1 || pid > libc::pid_t::MAX as u32 {
        return;
    }
    let pgid = pid as libc::pid_t;
    // SAFETY: `kill` takes only integer arguments and has no memory-safety
    // preconditions; the negative argument addresses the validated group `pgid`.
    unsafe {
        libc::kill(-pgid, libc::SIGKILL);
    }
}
pub fn run_file(
file: &ShoutFile,
options: &RunOptions,
on_command: Option<&dyn Fn(&parse::Command)>,
on_command_result: Option<&dyn Fn(usize, &CommandResult)>,
) -> FileResult {
let tmp_dir = match make_tmp_dir() {
Ok(d) => d,
Err(e) => {
return FileResult {
file: file.clone(),
results: vec![],
tmp_dir: String::new(),
error: Some(format!("Failed to create temp dir: {e}")),
};
}
};
if file.commands.is_empty() {
return FileResult {
file: file.clone(),
results: vec![],
tmp_dir,
error: None,
};
}
let verbose = options.verbose && on_command.is_some();
let script = build_script(&file.commands, verbose);
let mut cmd = Command::new("/bin/sh");
cmd.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::piped())
.current_dir(&tmp_dir);
// Set up process group (detached)
unsafe {
use std::os::unix::process::CommandExt;
cmd.pre_exec(|| {
libc::setpgid(0, 0);
Ok(())
});
}
// Environment
if options.clean_env {
cmd.env_clear();
}
cmd.env("HOME", &tmp_dir);
cmd.env("SHOUT_DIR", &tmp_dir);
if let Some(ref source_dir) = options.source_dir {
cmd.env("SHOUT_SOURCE_DIR", source_dir);
}
if let Some(ref project_dir) = options.project_dir {
cmd.env("SHOUT_PROJECT_DIR", project_dir);
}
for (key, value) in &options.env_vars {
cmd.env(key, value);
}
if !options.path_dirs.is_empty() {
let existing = std::env::var("PATH").unwrap_or_default();
let new_path = format!("{}:{existing}", options.path_dirs.join(":"));
cmd.env("PATH", new_path);
}
let mut child = match cmd.spawn() {
Ok(c) => c,
Err(e) => {
return FileResult {
file: file.clone(),
results: vec![],
tmp_dir,
error: Some(format!("Failed to spawn shell: {e}")),
};
}
};
let pid = child.id();
// Stream verbose markers from stderr in a separate thread
if verbose {
let stderr = child.stderr.take().unwrap();
let commands = file.commands.clone();
let _handle = std::thread::spawn(move || {
let mut reader = std::io::BufReader::new(stderr);
let mut buf = String::new();
let mut byte = [0u8; 1];
while reader.read(&mut byte).unwrap_or(0) > 0 {
if byte[0] == b'\n' {
if buf.starts_with(VERBOSE_MARKER) {
if let Ok(idx) = buf[VERBOSE_MARKER.len()..].parse::<usize>() {
if idx < commands.len() {
let _ = write!(
std::io::stderr(),
"\x1b[2m $ {}\x1b[0m\n",
commands[idx].command
);
}
}
}
buf.clear();
} else {
buf.push(byte[0] as char);
}
}
});
}
// Write script to stdin
if let Some(mut stdin) = child.stdin.take() {
let _ = stdin.write_all(script.as_bytes());
// stdin drops here, closing the pipe
}
// Read stdout with timeout
let total_timeout_ms = options.timeout_ms * file.commands.len() as u64;
let stdout = child.stdout.take().unwrap();
let (tx, rx) = mpsc::channel::<Vec<u8>>();
let last_sentinel_suffix = format!("_{}_", file.commands.len() - 1);
let sentinel_prefix = SENTINEL_PREFIX.to_string();
let reader_thread = std::thread::spawn(move || {
let mut reader = std::io::BufReader::new(stdout);
let mut buf = [0u8; 4096];
loop {
match reader.read(&mut buf) {
Ok(0) => break,
Ok(n) => {
if tx.send(buf[..n].to_vec()).is_err() {
break;
}
}
Err(_) => break,
}
}
});
let mut accumulated = String::new();
let deadline = std::time::Instant::now() + Duration::from_millis(total_timeout_ms);
let mut timed_out = false;
let mut sentinels_reported: usize = 0;
let mut last_sentinel_end: usize = 0;
let sentinel_re = Regex::new(&format!(r"{}(\d+)_(\d+)__", regex::escape(SENTINEL_PREFIX))).unwrap();
loop {
let remaining = deadline.saturating_duration_since(std::time::Instant::now());
if remaining.is_zero() {
timed_out = true;
break;
}
match rx.recv_timeout(remaining) {
Ok(chunk) => {
accumulated.push_str(&String::from_utf8_lossy(&chunk));
// Stream command results as they come in
if let Some(on_result) = on_command_result {
for caps in sentinel_re.captures_iter(&accumulated[last_sentinel_end..]) {
let idx: usize = caps[2].parse().unwrap_or(0);
if idx >= sentinels_reported {
let exit_code: i32 = caps[1].parse().unwrap_or(1);
let sentinel_match = caps.get(0).unwrap();
let abs_start = last_sentinel_end + sentinel_match.start();
let abs_end = last_sentinel_end + sentinel_match.end();
let output_slice = &accumulated[last_sentinel_end..abs_start];
let mut lines: Vec<String> = output_slice.split('\n').map(|s| s.to_string()).collect();
if !lines.is_empty() && lines[0].is_empty() {
lines.remove(0);
}
lines = parse::trim_trailing_empty(&lines);
if lines.len() == 1 && lines[0].is_empty() {
lines.clear();
}
let result = CommandResult {
command: file.commands[idx].clone(),
actual: lines.iter().map(|l| strip_ansi(l)).collect(),
exit_code,
};
on_result(idx, &result);
sentinels_reported = idx + 1;
last_sentinel_end = abs_end;
if accumulated.as_bytes().get(last_sentinel_end) == Some(&b'\n') {
last_sentinel_end += 1;
}
}
}
}
// Check if we've seen the last sentinel
if let Some(prefix_idx) = accumulated.rfind(&sentinel_prefix) {
if accumulated[prefix_idx..].contains(&last_sentinel_suffix) {
break;
}
}
}
Err(mpsc::RecvTimeoutError::Timeout) => {
timed_out = true;
break;
}
Err(mpsc::RecvTimeoutError::Disconnected) => break,
}
}
let _ = reader_thread.join();
// Kill the process tree
kill_tree(pid);
let _ = child.wait();
if timed_out {
return FileResult {
file: file.clone(),
results: vec![],
tmp_dir,
error: Some("Timeout reading output".to_string()),
};
}
let (outputs, exit_codes) = parse_sentinel_output(&accumulated, file.commands.len());
let results: Vec<CommandResult> = file
.commands
.iter()
.enumerate()
.map(|(i, cmd)| CommandResult {
command: cmd.clone(),
actual: outputs
.get(i)
.unwrap_or(&vec![])
.iter()
.map(|l| strip_ansi(l))
.collect(),
exit_code: exit_codes.get(i).copied().unwrap_or(1),
})
.collect();
FileResult {
file: file.clone(),
results,
tmp_dir,
error: None,
}
}
/// Recursively deletes `dir` (typically `FileResult::tmp_dir`), best-effort.
///
/// Any error, such as a missing directory, an empty path or a permission
/// problem, is deliberately ignored.
pub fn cleanup_tmp_dir(dir: &str) {
    std::fs::remove_dir_all(dir).ok();
}
/// Returns `true` when a command's output and exit status both meet its expectations.
///
/// Output is compared with `crate::matching::match_output`, which is always
/// evaluated. The exit status must be:
/// - `0` for `ExitCode::Default`;
/// - any non-zero value for `ExitCode::Any` (a successful exit fails the check);
/// - exactly the given value for `ExitCode::Code`.
pub fn command_passes(result: &CommandResult) -> bool {
    use crate::matching::match_output;

    let spec = &result.command;
    let output_ok = match_output(&spec.expected, &result.actual);
    let status_ok = match &spec.exit_code {
        parse::ExitCode::Default => result.exit_code == 0,
        parse::ExitCode::Any => result.exit_code != 0,
        parse::ExitCode::Code(code) => result.exit_code == *code,
    };
    output_ok && status_ok
}