diff --git a/internal/configs/config_bootloaders.go b/internal/configs/config_bootloaders.go index 32287d8..9b6ad72 100644 --- a/internal/configs/config_bootloaders.go +++ b/internal/configs/config_bootloaders.go @@ -83,7 +83,7 @@ func Set_KernelStub(isRoot bool) { // Run and log, check for errors common.ErrorCheck(command.ExecAndLogSudo(isRoot, true, - "kernelstub -a "+fmt.Sprintf("\"%s\"", kernel_args), + fmt.Sprintf("kernelstub -a \"%s\"", kernel_args), ), "Error, kernelstub command returned exit code 1", ) diff --git a/pkg/command/command.go b/pkg/command/command.go index 65702d5..005b1b8 100644 --- a/pkg/command/command.go +++ b/pkg/command/command.go @@ -112,6 +112,28 @@ func Clear() { _ = c.Run() } +func processCmdString(cmd string) (string, []string) { + // handle quoted arguments + args := strings.Fields(cmd) + cmdBin := args[0] + args = args[1:] + for i, arg := range args { + if !strings.HasPrefix(arg, "\"") { + continue + } + // find the end of the quoted argument + for j, a := range args[i:] { + if strings.HasSuffix(a, "\"") { + args[i] = strings.Join(args[i:i+j+1], " ") + args = append(args[:i+1], args[i+j+1:]...) + break + } + } + } + + return cmdBin, args +} + // ExecAndLogSudo executes an elevated command and logs the output. // // * if we're root, the command is executed directly @@ -138,8 +160,8 @@ func ExecAndLogSudo(isRoot, noisy bool, cmd string) error { return err } - cs := strings.Fields(cmd) - r := exec.Command(cs[0], cs[1:]...) + cmdBin, args := processCmdString(cmd) + r := exec.Command(cmdBin, args...) r.Dir = wd cmdCombinedOut, err := r.CombinedOutput() diff --git a/pkg/command/command_test.go b/pkg/command/command_test.go new file mode 100644 index 0000000..16e94b6 --- /dev/null +++ b/pkg/command/command_test.go @@ -0,0 +1,92 @@ +package command + +import ( + "fmt" + "reflect" + "testing" +) + +func Test_processCmdString(t *testing.T) { + type args struct { + cmd string + } + + kernel_args := "intel_iommu=on iommu=pt vfio-pci.ids=10de:1c03,10de:10f1" + + tests := []struct { + name string + args args + want string + want1 []string + }{ + { + name: "ls -l", + args: args{ + cmd: "ls -l", + }, + want: "ls", + want1: []string{ + "-l", + }, + }, + { + name: "ls -l -a", + args: args{ + cmd: "ls -l -a", + }, + want: "ls", + want1: []string{ + "-l", + "-a", + }, + }, + { + name: "rm -v \"file.txt\"", + args: args{ + cmd: "rm -v \"file.txt\"", + }, + want: "rm", + want1: []string{ + "-v", + "\"file.txt\"", + }, + }, + { + name: "rm -v \"file.txt\" -f", + args: args{ + cmd: "rm -v \"file.txt\" -f", + }, + want: "rm", + want1: []string{ + "-v", + "\"file.txt\"", + "-f", + }, + }, + { + name: fmt.Sprintf("kernelstub -a \"%s\"", kernel_args), + args: args{ + cmd: fmt.Sprintf("kernelstub -a \"%s\"", kernel_args), + }, + want: "kernelstub", + want1: []string{ + "-a", + fmt.Sprintf("\"%s\"", kernel_args), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := processCmdString(tt.args.cmd) + if got != tt.want { + t.Errorf("processCmdString() got = %v, want %v", got, tt.want) + } + t.Logf("got: %v", got) + t.Logf("got1: %v", got1) + if !reflect.DeepEqual(got1, tt.want1) { + t.Errorf("processCmdString() got1 = %v, want %v", got1, tt.want1) + } + }) + } +}