diff --git a/src/dawn_node/tools/src/cmd/run-cts/main.go b/src/dawn_node/tools/src/cmd/run-cts/main.go index bca7af9db8..a8dbfa5857 100644 --- a/src/dawn_node/tools/src/cmd/run-cts/main.go +++ b/src/dawn_node/tools/src/cmd/run-cts/main.go @@ -16,6 +16,7 @@ package main import ( + "bytes" "context" "encoding/json" "errors" @@ -78,7 +79,7 @@ func run() error { flag.BoolVar(&verbose, "verbose", false, "print extra information while testing") flag.BoolVar(&build, "build", true, "attempt to build the CTS before running") flag.BoolVar(&colors, "colors", colors, "enable / disable colors") - flag.IntVar(&numRunners, "j", runtime.NumCPU(), "number of concurrent runners") + flag.IntVar(&numRunners, "j", runtime.NumCPU(), "number of concurrent runners. 0 runs serially") flag.StringVar(&logFilename, "log", "", "path to log file of tests run and result") flag.Parse() @@ -109,9 +110,13 @@ func run() error { } // The test query is the optional unnamed argument - queries := []string{"webgpu:*"} - if args := flag.Args(); len(args) > 0 { - queries = args + query := "webgpu:*" + switch len(flag.Args()) { + case 0: + case 1: + query = flag.Args()[0] + default: + return fmt.Errorf("only a single query can be provided") } // Find node @@ -184,14 +189,18 @@ func run() error { } } - // Find all the test cases that match the given queries. - if err := r.gatherTestCases(queries, verbose); err != nil { - return fmt.Errorf("failed to gather test cases: %w", err) + if numRunners > 0 { + // Find all the test cases that match the given queries. + if err := r.gatherTestCases(query, verbose); err != nil { + return fmt.Errorf("failed to gather test cases: %w", err) + } + + fmt.Printf("Testing %d test cases...\n", len(r.testcases)) + return r.runParallel() } - fmt.Printf("Testing %d test cases...\n", len(r.testcases)) - - return r.run() + fmt.Println("Running serially...") + return r.runSerially(query) } type logger struct { @@ -308,7 +317,7 @@ func (r *runner) buildCTS(verbose bool) error { // gatherTestCases() queries the CTS for all test cases that match the given // query. On success, gatherTestCases() populates r.testcases. -func (r *runner) gatherTestCases(queries []string, verbose bool) error { +func (r *runner) gatherTestCases(query string, verbose bool) error { if verbose { start := time.Now() fmt.Println("Gathering test cases...") @@ -325,7 +334,7 @@ func (r *runner) gatherTestCases(queries []string, verbose bool) error { // start at 1, so just inject a dummy argument. "dummy-arg", "--list", - }, queries...) + }, query) cmd := exec.Command(r.node, args...) cmd.Dir = r.cts @@ -339,9 +348,10 @@ func (r *runner) gatherTestCases(queries []string, verbose bool) error { return nil } -// run() calls the CTS test runner to run each testcase in a separate process. +// runParallel() calls the CTS test runner to run each testcase in a separate +// process. // Up to r.numRunners tests will be run concurrently. -func (r *runner) run() error { +func (r *runner) runParallel() error { // Create a chan of test indices. // This will be read by the test runner goroutines. caseIndices := make(chan int, len(r.testcases)) @@ -362,7 +372,9 @@ func (r *runner) run() error { go func() { defer wg.Done() for idx := range caseIndices { - results <- r.runTestcase(idx) + res := r.runTestcase(r.testcases[idx], false) + res.index = idx + results <- res } }() } @@ -429,6 +441,18 @@ timeout: %v (%v) return nil } +// runSerially() calls the CTS test runner to run the test query in a single +// process. +func (r *runner) runSerially(query string) error { + start := time.Now() + result := r.runTestcase(query, true) + timeTaken := time.Since(start) + + fmt.Println("Completed in", timeTaken) + fmt.Println(result) + return nil +} + // status is an enumerator of test result status type status string @@ -448,14 +472,12 @@ type result struct { error error } -// runTestcase() runs the CTS testcase with the given index, returning the test +// runTestcase() runs the CTS testcase with the given query, returning the test // result. -func (r *runner) runTestcase(idx int) result { +func (r *runner) runTestcase(query string, printToStout bool) result { ctx, cancel := context.WithTimeout(context.Background(), testTimeout) defer cancel() - testcase := r.testcases[idx] - eval := r.evalScript args := append([]string{ "-e", eval, // Evaluate 'eval'. @@ -467,23 +489,33 @@ func (r *runner) runTestcase(idx int) result { // Actual arguments begin here "--gpu-provider", r.dawnNode, "--verbose", - }, testcase) + }, query) cmd := exec.CommandContext(ctx, r.node, args...) cmd.Dir = r.cts - out, err := cmd.CombinedOutput() - msg := string(out) + + var buf bytes.Buffer + if printToStout { + cmd.Stdout = io.MultiWriter(&buf, os.Stdout) + cmd.Stderr = io.MultiWriter(&buf, os.Stderr) + } else { + cmd.Stdout = &buf + cmd.Stderr = &buf + } + + err := cmd.Run() + msg := buf.String() switch { case errors.Is(err, context.DeadlineExceeded): - return result{index: idx, testcase: testcase, status: timeout, message: msg} + return result{testcase: query, status: timeout, message: msg} case strings.Contains(msg, "[fail]"): - return result{index: idx, testcase: testcase, status: fail, message: msg} + return result{testcase: query, status: fail, message: msg} case strings.Contains(msg, "[skip]"): - return result{index: idx, testcase: testcase, status: skip, message: msg} + return result{testcase: query, status: skip, message: msg} case strings.Contains(msg, "[pass]"), err == nil: - return result{index: idx, testcase: testcase, status: pass, message: msg} + return result{testcase: query, status: pass, message: msg} } - return result{index: idx, testcase: testcase, status: fail, message: fmt.Sprint(msg, err), error: err} + return result{testcase: query, status: fail, message: fmt.Sprint(msg, err), error: err} } // filterTestcases returns in with empty strings removed