From 6c48a35180237fbf8da3dc03a9eb355c9e11b60e Mon Sep 17 00:00:00 2001 From: kayos Date: Sat, 27 Jul 2024 08:38:39 -0700 Subject: [PATCH] Feat: conditional permissions behavior (#28) * Fix: don't need sudo if we're root + other aesthetics * Heavy refactoring, see PR #28 * Fix: avoid silent fatalities demo: https://tcp.ac/i/JMSUc.gif * Fix: Inverse check on `IsRoot` * D.R.Y: check for permissions error in `common.ErrorCheck` Reduce cognitive complexity. * Fix: Issue with copying * Resolve https://github.com/HikariKnight/quickpassthrough/pull/28#discussion_r1646535918 * Resolve https://github.com/HikariKnight/quickpassthrough/pull/28#discussion_r1646606680 and https://github.com/HikariKnight/quickpassthrough/pull/28#discussion_r1646594105 * Revert "Resolve https://github.com/HikariKnight/quickpassthrough/pull/28#discussion_r1646606680 and https://github.com/HikariKnight/quickpassthrough/pull/28#discussion_r1646594105" This reverts commit ce1521300982a86e799676ad1a630ea95de3b8c2. * Resolve https://github.com/HikariKnight/quickpassthrough/pull/28#discussion_r1646730751 --- internal/common/errors.go | 50 +++++ internal/configs/config_bootloaders.go | 107 +++++----- internal/configs/config_dracut.go | 6 +- internal/configs/config_initramfstools.go | 6 +- internal/configs/config_mkinitcpio.go | 6 +- internal/configs/config_modprobe.go | 4 +- internal/configs/config_vbios_dumper.go | 6 +- internal/configs/config_vfio_video.go | 4 +- internal/configs/configs.go | 150 ++++++++++--- internal/configs/configs_test.go | 33 +++ .../ls_iommu_downloader.go | 27 +-- internal/pages/02_select_gpu.go | 15 +- internal/pages/05_select_usbctrl.go | 5 +- internal/pages/06_finalize.go | 200 ++++++++++-------- internal/pages/06_finalize_test.go | 23 ++ internal/ui_main.go | 11 +- pkg/command/command.go | 91 ++++++-- pkg/command/command_test.go | 61 ++++++ pkg/fileio/fileio.go | 45 ++-- pkg/menu/manual.go | 5 +- pkg/untar/untar.go | 4 +- 21 files changed, 604 insertions(+), 255 deletions(-) create mode 100644 internal/common/errors.go create mode 100644 internal/configs/configs_test.go create mode 100644 internal/pages/06_finalize_test.go create mode 100644 pkg/command/command_test.go diff --git a/internal/common/errors.go b/internal/common/errors.go new file mode 100644 index 0000000..0c83407 --- /dev/null +++ b/internal/common/errors.go @@ -0,0 +1,50 @@ +package common + +import ( + "errors" + "os" + "time" + + "github.com/HikariKnight/ls-iommu/pkg/errorcheck" + "github.com/gookit/color" +) + +const PermissionNotice = ` +Permissions error occured during file operations. + +Hint: + + If you initially ran QuickPassthrough as root or using sudo, + but are now running it as a normal user, this is expected behavior. + + Try running QuickPassthrough as root or using sudo if so. + + If this does not work, double check your filesystem's permissions, + and be sure to check the debug log for more information. +` + +// ErrorCheck serves as a wrapper for HikariKnight/ls-iommu/pkg/common.ErrorCheck that allows for visibile error messages +func ErrorCheck(err error, msg ...string) { + _, _ = os.Stdout.WriteString("\033[H\033[2J") // clear the screen + if err == nil { + return + } + if errors.Is(err, os.ErrPermission) { + color.Printf(PermissionNotice) + } + oneMsg := "" + if len(msg) < 1 { + oneMsg = "" + } else { + for _, v := range msg { + oneMsg += v + "\n" + } + } + color.Printf("\nFATAL: %s\n%s\nAborting", err.Error(), oneMsg) + for i := 0; i < 10; i++ { + time.Sleep(1 * time.Second) + print(".") + } + print("\n") + errorcheck.ErrorCheck(err, msg...) +} diff --git a/internal/configs/config_bootloaders.go b/internal/configs/config_bootloaders.go index 2f8607d..281cc4d 100644 --- a/internal/configs/config_bootloaders.go +++ b/internal/configs/config_bootloaders.go @@ -1,16 +1,19 @@ package configs import ( + "errors" "fmt" "os" + "os/exec" "regexp" "strings" - "github.com/HikariKnight/ls-iommu/pkg/errorcheck" + "github.com/klauspost/cpuid/v2" + + "github.com/HikariKnight/quickpassthrough/internal/common" "github.com/HikariKnight/quickpassthrough/internal/logger" "github.com/HikariKnight/quickpassthrough/pkg/command" "github.com/HikariKnight/quickpassthrough/pkg/fileio" - "github.com/klauspost/cpuid/v2" ) // This function just adds what bootloader the system has to our config.bootloader value @@ -70,39 +73,32 @@ func Set_Cmdline(gpu_IDs []string) { fileio.AppendContent(fmt.Sprintf(" vfio_pci.ids=%s", strings.Join(gpu_IDs, ",")), config.Path.CMDLINE) } -// Configures systemd-boot using kernelstub -func Set_KernelStub() string { +// Set_KernelStub configures systemd-boot using kernelstub. +func Set_KernelStub(isRoot bool) { // Get the config config := GetConfig() // Get the kernel args kernel_args := fileio.ReadFile(config.Path.CMDLINE) - // Write to logger - logger.Printf("Running command:\nsudo kernelstub -a \"%s\"\n", kernel_args) - - // Run the command - _, err := command.Run("sudo", "kernelstub", "-a", kernel_args) - errorcheck.ErrorCheck(err, "Error, kernelstub command returned exit code 1") - - // Return what we did - return fmt.Sprintf("Executed: sudo kernelstub -a \"%s\"", kernel_args) + // Run and log, check for errors + common.ErrorCheck( + command.ExecAndLogSudo(isRoot, true, "kernelstub", "-a", kernel_args), + "Error, kernelstub command returned exit code 1", + ) } -// Configures grub2 and/or systemd-boot using grubby -func Set_Grubby() string { +// Set_Grubby configures grub2 and/or systemd-boot using grubby +func Set_Grubby(isRoot bool) string { // Get the config config := GetConfig() // Get the kernel args kernel_args := fileio.ReadFile(config.Path.CMDLINE) - // Write to logger - logger.Printf("Running command:\nsudo grubby --update-kernel=ALL --args=\"%s\"\n", kernel_args) - - // Run the command - _, err := command.Run("sudo", "grubby", "--update-kernel=ALL", fmt.Sprintf("--args=%s", kernel_args)) - errorcheck.ErrorCheck(err, "Error, grubby command returned exit code 1") + // Run and log, check for errors + err := command.ExecAndLogSudo(isRoot, true, "grubby", "--update-kernel=ALL", fmt.Sprintf("--args=%s", kernel_args)) + common.ErrorCheck(err, "Error, grubby command returned exit code 1") // Return what we did return fmt.Sprintf("Executed: sudo grubby --update-kernel=ALL --args=\"%s\"", kernel_args) @@ -116,8 +112,8 @@ func Configure_Grub2() { conffile := fmt.Sprintf("%s/grub", config.Path.DEFAULT) // Make sure we start from scratch by deleting any old file - if fileio.FileExist(conffile) { - os.Remove(conffile) + if exists, _ := fileio.FileExist(conffile); exists { + _ = os.Remove(conffile) } // Make a regex to get the system path instead of the config path @@ -201,8 +197,8 @@ func clean_Grub2_Args(old_kernel_args []string) []string { return clean_kernel_args } -// This function copies our config to /etc/default/grub and updates grub -func Set_Grub2() ([]string, error) { +// Set_Grub2 copies our config to /etc/default/grub and updates grub +func Set_Grub2(isRoot bool) error { // Get the config config := GetConfig() @@ -213,38 +209,45 @@ func Set_Grub2() ([]string, error) { sysfile_re := regexp.MustCompile(`^config`) sysfile := sysfile_re.ReplaceAllString(conffile, "") - // Write to logger - logger.Printf("Executing command:\nsudo cp -v \"%s\" %s\n", conffile, sysfile) + // [CopyToSystem] will log the operation + // logger.Printf("Executing command:\nsudo cp -v \"%s\" %s\n", conffile, sysfile) - // Make our output slice - var output []string - - // Copy files to system - output = append(output, CopyToSystem(conffile, sysfile)) + // Copy files to system, logging and error checking is done in the function + CopyToSystem(isRoot, conffile, sysfile) // Set a variable for the mkconfig command - mkconfig := "grub-mkconfig" + var mkconfig string + var grubPath = "/boot/grub/grub.cfg" + var lpErr error + // Check for grub-mkconfig - _, err := command.Run("which", "grub-mkconfig") - if err == nil { - // Set binary as grub-mkconfig - mkconfig = "grub-mkconfig" - } else { - mkconfig = "grub2-mkconfig" + mkconfig, lpErr = exec.LookPath("grub-mkconfig") + switch { + case errors.Is(lpErr, exec.ErrNotFound) || mkconfig == "": + // Check for grub2-mkconfig + mkconfig, lpErr = exec.LookPath("grub2-mkconfig") + if lpErr == nil && mkconfig != "" { + grubPath = "/boot/grub2/grub.cfg" + break // skip below, we found grub2-mkconfig + } + if lpErr == nil { + // we know mkconfig is empty despite no error; + // so set an error for [common.ErrorCheck]. + lpErr = errors.New("neither grub-mkconfig or grub2-mkconfig found") + } + common.ErrorCheck(lpErr, lpErr.Error()+"\n") + return lpErr // note: unreachable as [common.ErrorCheck] calls fatal + default: } - // Update grub.cfg - if fileio.FileExist("/boot/grub/grub.cfg") { - output = append(output, fmt.Sprintf("Executed: sudo %s -o /boot/grub/grub.cfg\nSee debug.log for more detailed output", mkconfig)) - _, mklog, err := command.RunErr("sudo", mkconfig, "-o", "/boot/grub/grub.cfg") - logger.Printf(strings.Join(mklog, "\n")) - errorcheck.ErrorCheck(err, "Failed to update /boot/grub/grub.cfg") - } else { - output = append(output, fmt.Sprintf("Executed: sudo %s -o /boot/grub/grub.cfg\nSee debug.log for more detailed output", mkconfig)) - _, mklog, err := command.RunErr("sudo", mkconfig, "-o", "/boot/grub2/grub.cfg") - logger.Printf(strings.Join(mklog, "\n")) - errorcheck.ErrorCheck(err, "Failed to update /boot/grub/grub.cfg") - } + _, mklog, err := command.RunErrSudo(isRoot, mkconfig, "-o", grubPath) - return output, err + // tabulate the output, [command.RunErrSudo] logged the execution. + logger.Printf("\t" + strings.Join(mklog, "\n\t")) + common.ErrorCheck(err, "Failed to update /boot/grub/grub.cfg") + + // always returns nil as [common.ErrorCheck] calls fatal + // keeping the ret signature, as we should consider passing down errors + // but that's a massive rabbit hole to go down for this codebase as a whole + return err } diff --git a/internal/configs/config_dracut.go b/internal/configs/config_dracut.go index f31a2ff..70f837e 100644 --- a/internal/configs/config_dracut.go +++ b/internal/configs/config_dracut.go @@ -9,7 +9,7 @@ import ( "github.com/HikariKnight/quickpassthrough/pkg/fileio" ) -// This function writes a dracut configuration file for /etc/dracut.conf.d/ +// Set_Dracut writes a dracut configuration file for `/etc/dracut.conf.d/`. func Set_Dracut() { config := GetConfig() @@ -17,8 +17,8 @@ func Set_Dracut() { dracutConf := fmt.Sprintf("%s/vfio.conf", config.Path.DRACUT) // If the file already exists then delete it - if fileio.FileExist(dracutConf) { - os.Remove(dracutConf) + if exists, _ := fileio.FileExist(dracutConf); exists { + _ = os.Remove(dracutConf) } // Write to logger diff --git a/internal/configs/config_initramfstools.go b/internal/configs/config_initramfstools.go index f198959..d3a7bd5 100644 --- a/internal/configs/config_initramfstools.go +++ b/internal/configs/config_initramfstools.go @@ -7,7 +7,7 @@ import ( "regexp" "strings" - "github.com/HikariKnight/ls-iommu/pkg/errorcheck" + "github.com/HikariKnight/quickpassthrough/internal/common" "github.com/HikariKnight/quickpassthrough/pkg/fileio" ) @@ -15,7 +15,7 @@ import ( func initramfs_readHeader(lines int, fileName string) string { // Open the file f, err := os.Open(fileName) - errorcheck.ErrorCheck(err, fmt.Sprintf("Error opening %s", fileName)) + common.ErrorCheck(err, fmt.Sprintf("Error opening %s", fileName)) defer f.Close() header_re := regexp.MustCompile(`^#`) @@ -50,7 +50,7 @@ func initramfs_addModules(conffile string) { // Open the system file for reading sysfile, err := os.Open(syspath) - errorcheck.ErrorCheck(err, fmt.Sprintf("Error opening file for reading %s", syspath)) + common.ErrorCheck(err, fmt.Sprintf("Error opening file for reading %s", syspath)) defer sysfile.Close() // Check if user has vendor-reset installed/enabled and make sure that is first diff --git a/internal/configs/config_mkinitcpio.go b/internal/configs/config_mkinitcpio.go index 219c16b..6c95f8f 100644 --- a/internal/configs/config_mkinitcpio.go +++ b/internal/configs/config_mkinitcpio.go @@ -10,14 +10,14 @@ import ( "github.com/HikariKnight/quickpassthrough/pkg/fileio" ) -// This function copies the content of /etc/mkinitcpio.conf to the config folder and does an inline replace/insert on the MODULES=() line +// Set_Mkinitcpio copies the content of /etc/mkinitcpio.conf to the config folder and does an inline replace/insert on the MODULES=() line func Set_Mkinitcpio() { // Get the config struct config := GetConfig() // Make sure we start from scratch by deleting any old file - if fileio.FileExist(config.Path.MKINITCPIO) { - os.Remove(config.Path.MKINITCPIO) + if exists, _ := fileio.FileExist(config.Path.MKINITCPIO); exists { + _ = os.Remove(config.Path.MKINITCPIO) } // Make a regex to get the system path instead of the config path diff --git a/internal/configs/config_modprobe.go b/internal/configs/config_modprobe.go index 63a5dbc..715a267 100644 --- a/internal/configs/config_modprobe.go +++ b/internal/configs/config_modprobe.go @@ -30,9 +30,9 @@ func Set_Modprobe(gpu_IDs []string) { conffile := fmt.Sprintf("%s/vfio.conf", config.Path.MODPROBE) // If the file exists - if fileio.FileExist(conffile) { + if exists, _ := fileio.FileExist(conffile); exists { // Delete the old file - os.Remove(conffile) + _ = os.Remove(conffile) } content := fmt.Sprint( diff --git a/internal/configs/config_vbios_dumper.go b/internal/configs/config_vbios_dumper.go index b075018..775cf7a 100644 --- a/internal/configs/config_vbios_dumper.go +++ b/internal/configs/config_vbios_dumper.go @@ -6,7 +6,7 @@ import ( "path/filepath" "strings" - "github.com/HikariKnight/ls-iommu/pkg/errorcheck" + "github.com/HikariKnight/quickpassthrough/internal/common" "github.com/HikariKnight/quickpassthrough/internal/logger" ) @@ -55,12 +55,12 @@ func GenerateVBIOSDumper(vbios_path string) { // Make the script file scriptfile, err := os.Create("utils/dump_vbios.sh") - errorcheck.ErrorCheck(err, "Cannot create file \"utils/dump_vbios.sh\"") + common.ErrorCheck(err, "Cannot create file \"utils/dump_vbios.sh\"") defer scriptfile.Close() // Make the script executable scriptfile.Chmod(0775) - errorcheck.ErrorCheck(err, "Could not change permissions of \"utils/dump_vbios.sh\"") + common.ErrorCheck(err, "Could not change permissions of \"utils/dump_vbios.sh\"") // Write to logger logger.Printf("Writing utils/dump_vbios.sh\n") diff --git a/internal/configs/config_vfio_video.go b/internal/configs/config_vfio_video.go index 63230b1..ae7dbeb 100644 --- a/internal/configs/config_vfio_video.go +++ b/internal/configs/config_vfio_video.go @@ -5,7 +5,7 @@ import ( "os" "strings" - "github.com/HikariKnight/ls-iommu/pkg/errorcheck" + "github.com/HikariKnight/quickpassthrough/internal/common" "github.com/HikariKnight/quickpassthrough/internal/logger" "github.com/HikariKnight/quickpassthrough/pkg/fileio" ) @@ -26,7 +26,7 @@ func DisableVFIOVideo(i int) { if strings.Contains(kernel_args, "vfio_pci.disable_vga") { // Remove the old file err := os.Remove(config.Path.CMDLINE) - errorcheck.ErrorCheck(err, fmt.Sprintf("Could not rewrite %s", config.Path.CMDLINE)) + common.ErrorCheck(err, fmt.Sprintf("Could not rewrite %s", config.Path.CMDLINE)) // Enable or disable the VGA based on our given value if i == 0 { diff --git a/internal/configs/configs.go b/internal/configs/configs.go index 50d6174..74ecb60 100644 --- a/internal/configs/configs.go +++ b/internal/configs/configs.go @@ -1,16 +1,20 @@ package configs import ( + "errors" "fmt" "os" + "path/filepath" "regexp" + "strings" - "github.com/HikariKnight/ls-iommu/pkg/errorcheck" + "github.com/klauspost/cpuid/v2" + + "github.com/HikariKnight/quickpassthrough/internal/common" "github.com/HikariKnight/quickpassthrough/internal/logger" "github.com/HikariKnight/quickpassthrough/pkg/command" "github.com/HikariKnight/quickpassthrough/pkg/fileio" "github.com/HikariKnight/quickpassthrough/pkg/uname" - "github.com/klauspost/cpuid/v2" ) type Path struct { @@ -30,9 +34,10 @@ type Config struct { Path *Path Gpu_Group string Gpu_IDs []string + IsRoot bool } -// Gets the path to all the config files +// GetConfigPaths retrieves the path to all the config files. func GetConfigPaths() *Path { Paths := &Path{ CMDLINE: "config/kernel_args", @@ -48,7 +53,7 @@ func GetConfigPaths() *Path { return Paths } -// Gets all the configs and returns the struct +// GetConfig retrieves all the configs and returns the struct. func GetConfig() *Config { config := &Config{ Bootloader: "unknown", @@ -64,7 +69,7 @@ func GetConfig() *Config { return config } -// Constructs the empty config files and folders based on what exists on the system +// InitConfigs constructs the empty config files and folders based on what exists on the system func InitConfigs() { config := GetConfig() @@ -77,10 +82,17 @@ func InitConfigs() { } // Remove old config - os.RemoveAll("config") + if err := os.RemoveAll("config"); err != nil && !errors.Is(err, os.ErrNotExist) { + + // won't be called if the error is ErrNotExist + common.ErrorCheck(err, "\nError removing old config") + } // Make the config folder - os.Mkdir("config", os.ModePerm) + if err := os.Mkdir("config", os.ModePerm); err != nil && !errors.Is(err, os.ErrExist) { + // won't be called if the error is ErrExist + common.ErrorCheck(err, "\nError making config folder") + } // Make a regex to get the system path instead of the config path syspath_re := regexp.MustCompile(`^config`) @@ -90,8 +102,16 @@ func InitConfigs() { // Get the system path syspath := syspath_re.ReplaceAllString(confpath, "") + exists, err := fileio.FileExist(syspath) + + // If we received an error that is not ErrNotExist + if err != nil { + common.ErrorCheck(err, "\nError checking for directory: "+syspath) + continue // note: unreachable due to ErrorCheck calling fatal + } + // If the path exists - if fileio.FileExist(syspath) { + if exists { // Write to log logger.Printf( "%s found on the system\n"+ @@ -104,8 +124,10 @@ func InitConfigs() { makeBackupDir(syspath) // Create the directories for our configs - err := os.MkdirAll(confpath, os.ModePerm) - errorcheck.ErrorCheck(err) + if err = os.MkdirAll(confpath, os.ModePerm); err != nil && !errors.Is(err, os.ErrExist) { + common.ErrorCheck(err, "\nError making directory: "+confpath) + return // note: unreachable due to ErrorCheck calling fatal + } } } @@ -128,7 +150,15 @@ func InitConfigs() { sysfile := syspath_re.ReplaceAllString(conffile, "") // If the file exists - if fileio.FileExist(sysfile) { + exists, err := fileio.FileExist(sysfile) + + // If we received an error that is not ErrNotExist + if err != nil { + common.ErrorCheck(err, "\nError checking for file: "+sysfile) + continue // note: unreachable due to ErrorCheck calling fatal + } + + if exists { // Write to log logger.Printf( "%s found on the system\n"+ @@ -139,16 +169,22 @@ func InitConfigs() { // Create the directories for our configs file, err := os.Create(conffile) - errorcheck.ErrorCheck(err) + common.ErrorCheck(err) // Close the file so we can edit it - file.Close() + _ = file.Close() // Backup the sysfile if we do not have a backup backupFile(sysfile) } + exists, err = fileio.FileExist(conffile) + if err != nil { + common.ErrorCheck(err, "\nError checking for file: "+conffile) + continue // note: unreachable + } + // If we now have a config that exists - if fileio.FileExist(conffile) { + if exists { switch conffile { case config.Path.ETCMODULES: // Write to logger @@ -204,14 +240,28 @@ func backupFile(source string) { // Make a destination path dest := fmt.Sprintf("backup%s", source) + configExists, configFileError := fileio.FileExist(fmt.Sprintf("config%s", source)) + sysExists, sysFileError := fileio.FileExist(source) + destExists, destFileError := fileio.FileExist(dest) + + // If we received an error that is not ErrNotExist on any of the files + for _, err := range []error{configFileError, sysFileError, destFileError} { + if err != nil { + common.ErrorCheck(configFileError, "\nError checking for file: "+source) + return // note: unreachable + } + } + + switch { // If the file exists in the config but not on the system it is a file we make - if fileio.FileExist(fmt.Sprintf("config%s", source)) && !fileio.FileExist(source) { + case configExists && !sysExists: // Create the blank file so that a copy of the backup folder to /etc file, err := os.Create(dest) - errorcheck.ErrorCheck(err, "Error creating file %s\n", dest) - file.Close() - } else if !fileio.FileExist(dest) { + common.ErrorCheck(err, "Error creating file %s\n", dest) + _ = file.Close() + // If a backup of the file does not exist + case sysExists && !destExists: // Write to the logger logger.Printf("No first time backup of %s detected.\nCreating a backup at %s\n", source, dest) @@ -223,29 +273,67 @@ func backupFile(source string) { func makeBackupDir(dest string) { // If a backup directory does not exist - if !fileio.FileExist("backup/") { + exists, err := fileio.FileExist("backup/") + if err != nil { + // If we received an error that is not ErrNotExist + common.ErrorCheck(err, "Error checking for backup/ folder") + return // note: unreachable + } + + if !exists { // Write to the logger logger.Printf("Backup directory does not exist!\nCreating backup directory for first run backup") } // Make the empty directories - err := os.MkdirAll(fmt.Sprintf("backup/%s", dest), os.ModePerm) - errorcheck.ErrorCheck(err, "Error making backup/ folder") + if err = os.MkdirAll(fmt.Sprintf("backup/%s", dest), os.ModePerm); errors.Is(err, os.ErrExist) { + // ignore if the directory already exists + err = nil + } + // will return without incident if there's no error + common.ErrorCheck(err, "Error making backup/ folder") } -// Copy a file to the system, make sure you have run command.Elevate() recently -func CopyToSystem(conffile, sysfile string) string { +// CopyToSystem copies a file to the system. +func CopyToSystem(isRoot bool, conffile, sysfile string) { // Since we should be elevated with our sudo token we will copy with cp // (using built in functions will not work as we are running as the normal user) - output, _ := command.Run("sudo", "cp", "-v", conffile, sysfile) - // Clean the output - clean_re := regexp.MustCompile(`\n`) - clean_output := clean_re.ReplaceAllString(output[0], "") + // ExecAndLogSudo will write to the logger, so just print here + fmt.Printf("Copying: %s to %s\n", conffile, sysfile) - // Write output to logger - logger.Printf("%s\n", clean_output) + if isRoot { + logger.Printf("Copying %s to %s\n", conffile, sysfile) + fmt.Printf("Copying %s to %s\n", conffile, sysfile) + fDat, err := os.ReadFile(conffile) + common.ErrorCheck(err, fmt.Sprintf("Failed to read %s", conffile)) + err = os.WriteFile(sysfile, fDat, 0644) + common.ErrorCheck(err, fmt.Sprintf("Failed to write %s", sysfile)) + logger.Printf("Copied %s to %s\n", conffile, sysfile) + return + } - // Return the output - return fmt.Sprintf("Copying: %s", clean_output) + if !filepath.IsAbs(conffile) { + conffile, _ = filepath.Abs(conffile) + } + + conffile = strings.ReplaceAll(conffile, " ", "\\ ") + cmd := fmt.Sprintf("cp -v %s %s", conffile, sysfile) + + err := command.ExecAndLogSudo(isRoot, false, cmd) + + errMsg := "" + if err != nil { + errMsg = err.Error() + } + + // [command.ExecAndLogSudo] will log the command's output + common.ErrorCheck(err, fmt.Sprintf("Failed to copy %s to %s:\n%s", conffile, sysfile, errMsg)) + + // --------------------------------------------------------------------------------- + // note that if we failed the error check, the following will not appear in the log! + // this is because the [common.ErrorCheck] function will call [log.Fatalf] and exit + // --------------------------------------------------------------------------------- + + logger.Printf("Copied %s to %s\n", conffile, sysfile) } diff --git a/internal/configs/configs_test.go b/internal/configs/configs_test.go new file mode 100644 index 0000000..8cc6290 --- /dev/null +++ b/internal/configs/configs_test.go @@ -0,0 +1,33 @@ +package configs + +import ( + "os" + "path/filepath" + "testing" +) + +func TestCopyToSystem(t *testing.T) { + if err := os.Mkdir("testdir", 0755); err != nil { + t.Fatal(err) + } + tFilePath := filepath.Join("testdir", "testfile") + if err := os.WriteFile(tFilePath, []byte("test"), 0644); err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + if err := os.RemoveAll("testdir"); err != nil { + t.Fatal(err) + } + }) + isRoot := os.Getuid() == 0 + switch isRoot { + case true: + t.Run("TestCopyToSystem_AsRoot", func(t *testing.T) { + CopyToSystem(true, tFilePath, "/etc/testfile") + }) + default: + t.Run("TestCopyToSystem_AsUser", func(t *testing.T) { + CopyToSystem(false, tFilePath, "/etc/testfile") + }) + } +} diff --git a/internal/ls_iommu_downloader/ls_iommu_downloader.go b/internal/ls_iommu_downloader/ls_iommu_downloader.go index d8114f0..0c26cfd 100644 --- a/internal/ls_iommu_downloader/ls_iommu_downloader.go +++ b/internal/ls_iommu_downloader/ls_iommu_downloader.go @@ -12,10 +12,11 @@ import ( "strings" "time" - "github.com/HikariKnight/ls-iommu/pkg/errorcheck" + "github.com/cavaliergopher/grab/v3" + + "github.com/HikariKnight/quickpassthrough/internal/common" "github.com/HikariKnight/quickpassthrough/pkg/fileio" "github.com/HikariKnight/quickpassthrough/pkg/untar" - "github.com/cavaliergopher/grab/v3" ) // Generated from github API response using https://mholt.github.io/json-to-go/ @@ -95,14 +96,14 @@ type Response struct { func CheckLsIOMMU() { // Check the API for releases resp, err := http.Get("https://api.github.com/repos/hikariknight/ls-iommu/releases/latest") - errorcheck.ErrorCheck(err) + common.ErrorCheck(err) // Close the response when function ends defer resp.Body.Close() // Get the response body body, err := io.ReadAll(resp.Body) - errorcheck.ErrorCheck(err) + common.ErrorCheck(err) var result Response if err := json.Unmarshal(body, &result); err != nil { @@ -111,9 +112,9 @@ func CheckLsIOMMU() { // Make the directory for ls-iommu if it does not exist path := "utils" - if !fileio.FileExist(path) { + if exists, _ := fileio.FileExist(path); !exists { err := os.Mkdir(path, os.ModePerm) - errorcheck.ErrorCheck(err) + common.ErrorCheck(err) } // Generate the download url @@ -133,30 +134,30 @@ func CheckLsIOMMU() { // Get the checksum data checksums, err := http.Get(checkSumsUrl) - errorcheck.ErrorCheck(err) + common.ErrorCheck(err) defer checksums.Body.Close() checksums_txt, err := io.ReadAll(checksums.Body) - errorcheck.ErrorCheck(err) + common.ErrorCheck(err) // Check if the tar.gz exists - if !fileio.FileExist(fileName) { + if exists, _ := fileio.FileExist(fileName); !exists { downloadNewVersion(path, fileName, downloadUrl) if checkSum(string(checksums_txt), fileName) { err = untar.Untar(fmt.Sprintf("%s/", path), fileName) - errorcheck.ErrorCheck(err) + common.ErrorCheck(err) } } else { if !checkSum(string(checksums_txt), fileName) { downloadNewVersion(path, fileName, downloadUrl) err = untar.Untar(fmt.Sprintf("%s/", path), fileName) - errorcheck.ErrorCheck(err) + common.ErrorCheck(err) } } } func checkSum(checksums string, fileName string) bool { r, err := os.Open(fileName) - errorcheck.ErrorCheck(err) + common.ErrorCheck(err) defer r.Close() hasher := sha256.New() @@ -182,7 +183,7 @@ func downloadNewVersion(path, fileName, downloadUrl string) { // check for errors if err := download.Err(); err != nil { fmt.Fprintf(os.Stderr, "Download failed: %v\n", err) - if !fileio.FileExist("utils/ls-iommu") { + if exists, _ := fileio.FileExist("utils/ls-iommu"); !exists { log.Fatal("If the above error is 404, then we could not communicate with the GitHub API\n Please manually download and extract ls-iommu to: utils/\nYou can download it from: https://github.com/HikariKnight/ls-iommu/releases") } else { fmt.Println("Existing ls-iommu binary detected in \"utils/\", will use that instead as the GitHub API did not respond.") diff --git a/internal/pages/02_select_gpu.go b/internal/pages/02_select_gpu.go index 1b270c5..a035f65 100644 --- a/internal/pages/02_select_gpu.go +++ b/internal/pages/02_select_gpu.go @@ -4,13 +4,14 @@ import ( "fmt" "os" - "github.com/HikariKnight/ls-iommu/pkg/errorcheck" + "github.com/gookit/color" + + "github.com/HikariKnight/quickpassthrough/internal/common" "github.com/HikariKnight/quickpassthrough/internal/configs" - lsiommu "github.com/HikariKnight/quickpassthrough/internal/lsiommu" + "github.com/HikariKnight/quickpassthrough/internal/lsiommu" "github.com/HikariKnight/quickpassthrough/pkg/command" "github.com/HikariKnight/quickpassthrough/pkg/fileio" "github.com/HikariKnight/quickpassthrough/pkg/menu" - "github.com/gookit/color" ) func SelectGPU(config *configs.Config) { @@ -94,10 +95,10 @@ func viewGPU(config *configs.Config, ext ...int) { config.Gpu_IDs = lsiommu.GetIOMMU("-g", mode, "-i", config.Gpu_Group, "--id") // If the kernel_args file already exists - if fileio.FileExist(config.Path.CMDLINE) { + if exists, _ := fileio.FileExist(config.Path.CMDLINE); exists { // Delete it as we will have to make a new one anyway err := os.Remove(config.Path.CMDLINE) - errorcheck.ErrorCheck(err, fmt.Sprintf("Could not remove %s", config.Path.CMDLINE)) + common.ErrorCheck(err, fmt.Sprintf("Could not remove %s", config.Path.CMDLINE)) } // Write initial kernel_arg file @@ -114,10 +115,10 @@ func viewGPU(config *configs.Config, ext ...int) { ) // If the kernel_args file already exists - if fileio.FileExist(config.Path.CMDLINE) { + if exists, _ := fileio.FileExist(config.Path.CMDLINE); exists { // Delete it as we will have to make a new one anyway err := os.Remove(config.Path.CMDLINE) - errorcheck.ErrorCheck(err, fmt.Sprintf("Could not remove %s", config.Path.CMDLINE)) + common.ErrorCheck(err, fmt.Sprintf("Could not remove %s", config.Path.CMDLINE)) } // Write initial kernel_arg file diff --git a/internal/pages/05_select_usbctrl.go b/internal/pages/05_select_usbctrl.go index 3264f7a..2df9583 100644 --- a/internal/pages/05_select_usbctrl.go +++ b/internal/pages/05_select_usbctrl.go @@ -4,11 +4,12 @@ import ( "fmt" "os" + "github.com/gookit/color" + "github.com/HikariKnight/quickpassthrough/internal/configs" - lsiommu "github.com/HikariKnight/quickpassthrough/internal/lsiommu" + "github.com/HikariKnight/quickpassthrough/internal/lsiommu" "github.com/HikariKnight/quickpassthrough/pkg/command" "github.com/HikariKnight/quickpassthrough/pkg/menu" - "github.com/gookit/color" ) func selectUSB(config *configs.Config) { diff --git a/internal/pages/06_finalize.go b/internal/pages/06_finalize.go index 232bd26..a1ab656 100644 --- a/internal/pages/06_finalize.go +++ b/internal/pages/06_finalize.go @@ -6,34 +6,34 @@ import ( "log" "os" "os/user" - "strings" "syscall" + "github.com/gookit/color" + "golang.org/x/term" + "github.com/HikariKnight/quickpassthrough/internal/configs" "github.com/HikariKnight/quickpassthrough/internal/logger" "github.com/HikariKnight/quickpassthrough/pkg/command" "github.com/HikariKnight/quickpassthrough/pkg/fileio" "github.com/HikariKnight/quickpassthrough/pkg/menu" "github.com/HikariKnight/quickpassthrough/pkg/uname" - "github.com/gookit/color" - "golang.org/x/term" ) func prepModules(config *configs.Config) { // If we have files for modprobe - if fileio.FileExist(config.Path.MODPROBE) { + if exists, _ := fileio.FileExist(config.Path.MODPROBE); exists { // Configure modprobe configs.Set_Modprobe(config.Gpu_IDs) } // If we have a folder for dracut - if fileio.FileExist(config.Path.DRACUT) { + if exists, _ := fileio.FileExist(config.Path.DRACUT); exists { // Configure dracut configs.Set_Dracut() } // If we have a mkinitcpio.conf file - if fileio.FileExist(config.Path.MKINITCPIO) { + if exists, _ := fileio.FileExist(config.Path.MKINITCPIO); exists { configs.Set_Mkinitcpio() } @@ -48,6 +48,43 @@ func prepModules(config *configs.Config) { finalize(config) } +func finalizeNotice(isRoot bool) { + color.Print(` +The configuration files have been generated and are located inside the "config" folder + + * The "kernel_args" file contains kernel arguments that your bootloader needs + * The "qemu" folder contains files that may be needed for passthrough + * The files inside the "etc" folder must be copied to your system. + + Verify that these files are correctly formated/edited! + +Once all files have been copied, the following steps must be taken: + + * bootloader configuration must be updated + * initramfs must be rebuilt + +`) + switch isRoot { + case true: + color.Print("This program can do this for you, if desired.\n") + default: + color.Print(`This program can do this for you, however your sudo password is required. +To avoid this: + + * press CTRL+C and perform the steps mentioned above manually. + OR + * run ` + os.Args[0] + ` as root. + +`) + } + + color.Print(` +If you want to go back and change something, choose Back. + +NOTE: A backup of the original files from the first run can be found in the backup folder +`) +} + func finalize(config *configs.Config) { // Clear the screen command.Clear() @@ -56,60 +93,43 @@ func finalize(config *configs.Config) { title := color.New(color.BgHiBlue, color.White, color.Bold) title.Println("Finalizing configuration") - color.Print( - "The configuration files have been generated and are\n", - "located inside the \"config\" folder\n", - "\n", - "* The \"kernel_args\" file contains kernel arguments that your bootloader needs\n", - "* The \"qemu\" folder contains files that might be\n neccessary for passing through the GPU\n", - "* The files inside the \"etc\" folder must be copied to your system.\n", - " NOTE: Verify that these files are correctly formated/edited!\n", - "* Once all files have been copied, you need to update your bootloader and rebuild\n", - " your initramfs using the tools to do so by your system.\n", - "\n", - "This program can do this for you, however the program will have to\n", - "type your password to sudo using STDIN, to avoid using STDIN press CTRL+C\n", - "and copy the files, update your bootloader and rebuild your initramfs manually.\n", - "If you want to go back and change something, choose Back\n", - "\nNOTE: A backup of the original files from the first run can be found in the backup folder\n", - ) + config.IsRoot = os.Getuid() == 0 - // Make a choice of going next or back - choice := menu.Next("Press Next to continue with sudo using STDIN, ESC to exit or Back to go back.") + finalizeNotice(config.IsRoot) - // Parse the choice - switch choice { + // Make a choice of going next or back and parse the choice + switch menu.Next("Press Next to continue with sudo using STDIN, ESC to exit or Back to go back.") { case "next": installPassthrough(config) - case "back": // Go back disableVideo(config) } - } func installPassthrough(config *configs.Config) { // Get the user data - user, err := user.Current() + currentUser, err := user.Current() if err != nil { log.Fatalf(err.Error()) } - // Provide a password prompt - fmt.Printf("[sudo] password for %s: ", user.Username) - bytep, err := term.ReadPassword(int(syscall.Stdin)) - if err != nil { - os.Exit(1) - } - fmt.Print("\n") + if !config.IsRoot { + // Provide a password prompt + fmt.Printf("[sudo] password for %s: ", currentUser.Username) + bytep, err := term.ReadPassword(syscall.Stdin) + if err != nil { + os.Exit(1) + } + fmt.Print("\n") - // Elevate with sudo - command.Elevate( - base64.StdEncoding.EncodeToString( - bytep, - ), - ) + // Elevate with sudo + command.Elevate( + base64.StdEncoding.EncodeToString( + bytep, + ), + ) + } // Make an output string var output string @@ -120,22 +140,24 @@ func installPassthrough(config *configs.Config) { logger.Printf("Configuring systemd-boot using kernelstub\n") // Configure kernelstub - output = configs.Set_KernelStub() - fmt.Printf("%s\n", output) + // callee logs the output and checks for errors + configs.Set_KernelStub(config.IsRoot) } else if config.Bootloader == "grubby" { // Write to logger logger.Printf("Configuring bootloader using grubby\n") // Configure kernelstub - output = configs.Set_Grubby() + output = configs.Set_Grubby(config.IsRoot) fmt.Printf("%s\n", output) } else if config.Bootloader == "grub2" { // Write to logger logger.Printf("Applying grub2 changes\n") - grub_output, _ := configs.Set_Grub2() - fmt.Printf("%s\n", strings.Join(grub_output, "\n")) + _ = configs.Set_Grub2(config.IsRoot) // note: we set config.IsRoot earlier + + // we'll print the output in the [configs.Set_Grub2] method + // fmt.Printf("%s\n", strings.Join(grub_output, "\n")) } else { kernel_args := fileio.ReadFile(config.Path.CMDLINE) @@ -145,68 +167,62 @@ func installPassthrough(config *configs.Config) { // A lot of linux systems support modprobe along with their own module system // So copy the modprobe files if we have them modprobeFile := fmt.Sprintf("%s/vfio.conf", config.Path.MODPROBE) - if fileio.FileExist(modprobeFile) { - // Copy initramfs-tools module to system - output = configs.CopyToSystem(modprobeFile, "/etc/modprobe.d/vfio.conf") - fmt.Printf("%s\n", output) + + // lets hope by now we've already handled any permissions issues... + // TODO: verify that we actually can drop the errors on [fileio.FileExist] call below + + if exists, _ := fileio.FileExist(modprobeFile); exists { + // Copy initramfs-tools module to system, note that CopyToSystem will log the command and output + // as well as check for errors + configs.CopyToSystem(config.IsRoot, modprobeFile, "/etc/modprobe.d/vfio.conf") } // Copy the config files for the system we have initramfsFile := fmt.Sprintf("%s/modules", config.Path.INITRAMFS) dracutFile := fmt.Sprintf("%s/vfio.conf", config.Path.DRACUT) - if fileio.FileExist(initramfsFile) { + + initramFsExists, initramFsErr := fileio.FileExist(initramfsFile) + dracutExists, dracutErr := fileio.FileExist(dracutFile) + mkinitcpioExists, mkinitcpioErr := fileio.FileExist(config.Path.MKINITCPIO) + + for _, err = range []error{initramFsErr, dracutErr, mkinitcpioErr} { + if err == nil { + continue + } + // we know this error isn't ErrNotExist, so we should throw it and exit + log.Fatalf("Failed to stat file: %s", err) + } + + switch { + case initramFsExists: // Copy initramfs-tools module to system - output = configs.CopyToSystem(initramfsFile, "/etc/initramfs-tools/modules") - fmt.Printf("%s\n", output) + configs.CopyToSystem(config.IsRoot, initramfsFile, "/etc/initramfs-tools/modules") // Copy the modules file to /etc/modules - output = configs.CopyToSystem(config.Path.ETCMODULES, "/etc/modules") - fmt.Printf("%s\n", output) + configs.CopyToSystem(config.IsRoot, config.Path.ETCMODULES, "/etc/modules") - // Write to logger - logger.Printf("Executing: sudo update-initramfs -u\n") + if err = command.ExecAndLogSudo(config.IsRoot, true, "update-initramfs", "-u"); err != nil { + log.Fatalf("Failed to update initramfs: %s", err) + } - // Update initramfs - fmt.Println("Executed: sudo update-initramfs -u\nSee debug.log for detailed output") - cmd_out, cmd_err, _ := command.RunErr("sudo", "update-initramfs", "-u") - - cmd_out = append(cmd_out, cmd_err...) - - // Write to logger - logger.Printf(strings.Join(cmd_out, "\n")) - } else if fileio.FileExist(dracutFile) { + case dracutExists: // Copy dracut config to /etc/dracut.conf.d/vfio - output = configs.CopyToSystem(dracutFile, "/etc/dracut.conf.d/vfio") - fmt.Printf("%s\n", output) + configs.CopyToSystem(config.IsRoot, dracutFile, "/etc/dracut.conf.d/vfio") // Get systeminfo sysinfo := uname.New() - // Write to logger - logger.Printf("Executing: sudo dracut -f -v --kver %s\n", sysinfo.Release) + if err = command.ExecAndLogSudo(config.IsRoot, true, "dracut", "-f", "-v", "--kver", sysinfo.Release); err != nil { + log.Fatalf("Failed to update initramfs: %s", err) + } - // Update initramfs - fmt.Printf("Executed: sudo dracut -f -v --kver %s\nSee debug.log for detailed output", sysinfo.Release) - _, cmd_err, _ := command.RunErr("sudo", "dracut", "-f", "-v", "--kver", sysinfo.Release) - - // Write to logger - logger.Printf(strings.Join(cmd_err, "\n")) - } else if fileio.FileExist(config.Path.MKINITCPIO) { + case mkinitcpioExists: // Copy dracut config to /etc/dracut.conf.d/vfio - output = configs.CopyToSystem(config.Path.MKINITCPIO, "/etc/mkinitcpio.conf") - fmt.Printf("%s\n", output) + configs.CopyToSystem(config.IsRoot, config.Path.MKINITCPIO, "/etc/mkinitcpio.conf") - // Write to logger - logger.Printf("Executing: sudo mkinitcpio -P") - - // Update initramfs - fmt.Println("Executed: sudo mkinitcpio -P\nSee debug.log for detailed output") - cmd_out, cmd_err, _ := command.RunErr("sudo", "mkinitcpio", "-P") - - cmd_out = append(cmd_out, cmd_err...) - - // Write to logger - logger.Printf(strings.Join(cmd_out, "\n")) + if err = command.ExecAndLogSudo(config.IsRoot, true, "mkinitcpio", "-P"); err != nil { + log.Fatalf("Failed to update initramfs: %s", err) + } } // Make sure prompt end up on next line diff --git a/internal/pages/06_finalize_test.go b/internal/pages/06_finalize_test.go new file mode 100644 index 0000000..df9a7d0 --- /dev/null +++ b/internal/pages/06_finalize_test.go @@ -0,0 +1,23 @@ +package pages + +import ( + "strings" + "testing" +) + +func TestFinalizeNotice(t *testing.T) { + msg := "\n%s\nprinting the finalize notice for manual review, this test should always pass.\n%s\n\n" + divider := strings.Repeat("-", len(msg)-12) + t.Logf(msg, divider, divider) + t.Log("\n\nWith isRoot == true:\n\n") + + finalizeNotice(true) + + println("\n\n") + + t.Log("\n\nWith isRoot == false:\n\n") + + finalizeNotice(false) + + println("\n\n") +} diff --git a/internal/ui_main.go b/internal/ui_main.go index 9b0db17..5bf650d 100644 --- a/internal/ui_main.go +++ b/internal/ui_main.go @@ -6,17 +6,18 @@ package internal import ( "os" - "github.com/HikariKnight/ls-iommu/pkg/errorcheck" - "github.com/HikariKnight/quickpassthrough/internal/pages" tea "github.com/charmbracelet/bubbletea" + + "github.com/HikariKnight/quickpassthrough/internal/common" + "github.com/HikariKnight/quickpassthrough/internal/pages" ) // This is where we build everything func Tui() { // Log all errors to a new logfile (super useful feature of BubbleTea!) - os.Remove("debug.log") - logfile, err := tea.LogToFile("debug.log", "") - errorcheck.ErrorCheck(err, "Error creating log file") + _ = os.Rename("quickpassthrough_debug.log", "quickpassthrough_debug_old.log") + logfile, err := tea.LogToFile("quickpassthrough_debug.log", "") + common.ErrorCheck(err, "Error creating log file") defer logfile.Close() // New WIP Tui diff --git a/pkg/command/command.go b/pkg/command/command.go index e7c6fca..e6ba915 100644 --- a/pkg/command/command.go +++ b/pkg/command/command.go @@ -3,12 +3,15 @@ package command import ( "bytes" "encoding/base64" + "fmt" "io" "os" "os/exec" + "strings" "time" - "github.com/HikariKnight/ls-iommu/pkg/errorcheck" + "github.com/HikariKnight/quickpassthrough/internal/common" + "github.com/HikariKnight/quickpassthrough/internal/logger" ) // Run a command and return STDOUT @@ -27,14 +30,14 @@ func Run(binary string, args ...string) ([]string, error) { output, _ := io.ReadAll(&stdout) // Get the output - outputs := []string{} + outputs := make([]string, 0, 1) outputs = append(outputs, string(output)) // Return our list of items return outputs, err } -// This function is just like command.Run() but also returns STDERR +// RunErr is just like command.Run() but also returns STDERR func RunErr(binary string, args ...string) ([]string, []string, error) { var stdout, stderr bytes.Buffer @@ -59,8 +62,18 @@ func RunErr(binary string, args ...string) ([]string, []string, error) { return outputs, outerrs, err } -// This functions runs the command "sudo -Sk -- echo", this forces sudo -// to re-authenticate and lets us enter the password to STDIN +func RunErrSudo(isRoot bool, binary string, args ...string) ([]string, []string, error) { + if !isRoot && binary != "sudo" { + args = append([]string{binary}, args...) + binary = "sudo" + } + logger.Printf("Executing (elevated): %s %s\n", binary, strings.Join(args, " ")) + fmt.Printf("Executing (elevated): %s %s\n", binary, strings.Join(args, " ")) + return RunErr(binary, args...) +} + +// Elevate elevates this functions runs the command "sudo -Sk -- echo", +// this forces sudo to re-authenticate and lets us enter the password to STDIN // giving us the ability to run sudo commands func Elevate(password string) { // Do a simple sudo command to just authenticate with sudo @@ -71,29 +84,83 @@ func Elevate(password string) { // Open STDIN stdin, err := cmd.StdinPipe() - errorcheck.ErrorCheck(err, "\nFailed to get sudo STDIN") + common.ErrorCheck(err, "\nFailed to get sudo STDIN") // Start the authentication - cmd.Start() + err = cmd.Start() + common.ErrorCheck(err, "\nFailed to start sudo command") // Get the passed password pw, _ := base64.StdEncoding.DecodeString(password) _, err = stdin.Write([]byte(string(pw) + "\n")) - errorcheck.ErrorCheck(err, "\nFailed at typing to STDIN") + common.ErrorCheck(err, "\nFailed at typing to STDIN") // Clear the password pw = nil password = "" - stdin.Close() + _ = stdin.Close() // Wait for the sudo prompt (If the correct password was given, it will not stay behind) err = cmd.Wait() - errorcheck.ErrorCheck(err, "\nError, password given was wrong") + common.ErrorCheck(err, "\nError, password given was wrong") } -// Function to just clear the terminal +// Clear clears the terminal. func Clear() { c := exec.Command("clear") c.Stdout = os.Stdout - c.Run() + _ = c.Run() +} + +// ExecAndLogSudo executes an elevated command and logs the output. +// +// * if we're root, the command is executed directly +// * if we're not root, the command is prefixed with "sudo" +// +// - noisy determines if we should print the command to the user +// noisy isn't set to true by our copy caller, as it logs differently, +// but other callers set it. +func ExecAndLogSudo(isRoot, noisy bool, exe string, args ...string) error { + if !isRoot && exe != "sudo" { + og := exe + exe = "sudo" + newArgs := make([]string, 0) + newArgs = append(newArgs, og) + newArgs = append(newArgs, args...) + args = newArgs + } + + // Write to logger + logger.Printf("Executing (elevated): %s %s\n", exe, strings.Join(args, " ")) + + if noisy { + // Print to the user + fmt.Printf("Executing (elevated): %s %s\nSee debug.log for detailed output\n", exe, strings.Join(args, " ")) + } + + wd, err := os.Getwd() + if err != nil { + return err + } + + r := exec.Command(exe, args...) + r.Dir = wd + + cmdCombinedOut, err := r.CombinedOutput() + outStr := string(cmdCombinedOut) + + // Write to logger, tabulate output + // tabulation denotes it's hierarchy as a child of the command + outStr = strings.ReplaceAll(outStr, "\n", "\n\t") + logger.Printf("\t" + string(cmdCombinedOut) + "\n") + if noisy { + // Print to the user + fmt.Printf("%s\n", outStr) + } + + if err != nil { + err = fmt.Errorf("failed to execute %s: %w\n%s", exe, err, outStr) + } + + return err } diff --git a/pkg/command/command_test.go b/pkg/command/command_test.go new file mode 100644 index 0000000..f6fa119 --- /dev/null +++ b/pkg/command/command_test.go @@ -0,0 +1,61 @@ +package command + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +const fakeSudo = `#!/bin/sh +"$@" -qptest` + +const fakeUtil = `#!/bin/sh +echo "$@" +if [ "$4" = "-qptest" ]; then exit 0; else exit 1; fi` + +func setupExecTestEnv(t *testing.T) (string, string) { + t.Helper() + + tmpDir := t.TempDir() + fakeSudoPath := filepath.Join(tmpDir, "sudo") + fakeUtilPath := filepath.Join(tmpDir, "util") + + if err := os.WriteFile(fakeSudoPath, []byte(fakeSudo), 0755); err != nil { + t.Fatalf("failed to write fake sudo stub: %s", err.Error()) + } + if err := os.WriteFile(fakeUtilPath, []byte(fakeUtil), 0755); err != nil { + t.Fatalf("failed to write fake util stub: %s", err.Error()) + } + t.Setenv("PATH", tmpDir+":"+os.Getenv("PATH")) + + return fakeSudoPath, fakeUtilPath +} + +func TestExecAndLogSudo(t *testing.T) { + _, fakeUtilPath := setupExecTestEnv(t) + + args := []string{"i am a string with spaces", "i came to ruin parsers and chew bubble gum", "and I'm all out of bubblegum."} + + t.Run("is_not_root", func(t *testing.T) { + if err := ExecAndLogSudo(false, false, "util", args...); err != nil { + t.Errorf("unexpected error: %s", err.Error()) + } + }) + + t.Run("is_root", func(t *testing.T) { + newFakeUtil := strings.Replace(fakeUtil, "exit 1", "exit 0", 1) + newFakeUtil = strings.Replace(newFakeUtil, "exit 0", "exit 1", 1) + if err := os.WriteFile(fakeUtilPath, []byte(newFakeUtil), 0755); err != nil { + t.Fatalf("failed to overwrite fake util with modified stub: %s", err.Error()) + } + if err := ExecAndLogSudo(false, false, "util", args...); err == nil { + t.Errorf("expected error when using modified util with sudo, got nil") + } + + if err := ExecAndLogSudo(true, true, "util", args...); err != nil { + t.Errorf("unexpected error: %s", err.Error()) + } + }) + +} diff --git a/pkg/fileio/fileio.go b/pkg/fileio/fileio.go index 38c5a7a..b727ede 100644 --- a/pkg/fileio/fileio.go +++ b/pkg/fileio/fileio.go @@ -7,29 +7,31 @@ import ( "io" "os" - "github.com/HikariKnight/ls-iommu/pkg/errorcheck" + "github.com/HikariKnight/quickpassthrough/internal/common" ) /* * This just implements repetetive tasks I have to do with files */ -// Creates a file and appends the content to the file (ending newline must be supplied with content string) +// AppendContent creates a file and appends the content to the file. +// (ending newline must be supplied with content string) func AppendContent(content string, fileName string) { // Open the file f, err := os.OpenFile(fileName, os.O_CREATE|os.O_APPEND|os.O_WRONLY, os.ModePerm) - errorcheck.ErrorCheck(err, fmt.Sprintf("Error opening \"%s\" for writing", fileName)) + + common.ErrorCheck(err, fmt.Sprintf("Error opening \"%s\" for writing", fileName)) defer f.Close() // Write the content _, err = f.WriteString(content) - errorcheck.ErrorCheck(err, fmt.Sprintf("Error writing to %s", fileName)) + common.ErrorCheck(err, fmt.Sprintf("Error writing to %s", fileName)) } -// Reads the file and returns a stringlist with each line +// ReadLines reads the file and returns a stringlist with each line. func ReadLines(fileName string) []string { content, err := os.Open(fileName) - errorcheck.ErrorCheck(err, fmt.Sprintf("Error reading file %s", fileName)) + common.ErrorCheck(err, fmt.Sprintf("Error reading file %s", fileName)) defer content.Close() // Make a list of lines @@ -46,54 +48,55 @@ func ReadLines(fileName string) []string { } -// Reads a file and returns all the content as a string +// ReadFile reads a file and returns all the content as a string. func ReadFile(fileName string) string { // Read the whole file content, err := os.ReadFile(fileName) - errorcheck.ErrorCheck(err, fmt.Sprintf("Failed to ReadFile on %s", fileName)) + common.ErrorCheck(err, fmt.Sprintf("Failed to ReadFile on %s", fileName)) // Return all the lines as one string return string(content) } -// Checks if a file exists and returns a bool -func FileExist(fileName string) bool { +// FileExist checks if a file exists and returns a bool and any error that isn't os.ErrNotExist. +func FileExist(fileName string) (bool, error) { var exist bool // Check if the file exists - if _, err := os.Stat(fileName); !errors.Is(err, os.ErrNotExist) { - // Set the value to true + _, err := os.Stat(fileName) + switch { + case err == nil: exist = true - } else { - // Set the value to false + case errors.Is(err, os.ErrNotExist): + // Set the value to true exist = false + err = nil } - // Return if the file exists - return exist + return exist, err } -// Copies a FILE from source to dest +// FileCopy copies a FILE from source to dest. func FileCopy(sourceFile, destFile string) { // Get the file info filestat, err := os.Stat(sourceFile) - errorcheck.ErrorCheck(err, "Error getting fileinfo of: %s", sourceFile) + common.ErrorCheck(err, "Error getting fileinfo of: %s", sourceFile) // If the file is a regular file if filestat.Mode().IsRegular() { // Open the source file for reading source, err := os.Open(sourceFile) - errorcheck.ErrorCheck(err, "Error opening %s for copying", sourceFile) + common.ErrorCheck(err, "Error opening %s for copying", sourceFile) defer source.Close() // Create the destination file dest, err := os.Create(destFile) - errorcheck.ErrorCheck(err, "Error creating %s", destFile) + common.ErrorCheck(err, "Error creating %s", destFile) defer dest.Close() // Copy the contents of source to dest using io _, err = io.Copy(dest, source) - errorcheck.ErrorCheck(err, "Failed to copy \"%s\" to \"%s\"", sourceFile, destFile) + common.ErrorCheck(err, "Failed to copy \"%s\" to \"%s\"", sourceFile, destFile) } } diff --git a/pkg/menu/manual.go b/pkg/menu/manual.go index 5e0b06c..5db9da7 100644 --- a/pkg/menu/manual.go +++ b/pkg/menu/manual.go @@ -4,8 +4,9 @@ import ( "fmt" "strings" - "github.com/HikariKnight/ls-iommu/pkg/errorcheck" "github.com/gookit/color" + + "github.com/HikariKnight/quickpassthrough/internal/common" ) func ManualInput(msg string, format string) []string { @@ -18,7 +19,7 @@ func ManualInput(msg string, format string) []string { // Get the user input var input string _, err := fmt.Scan(&input) - errorcheck.ErrorCheck(err) + common.ErrorCheck(err) input_list := strings.Split(input, ",") diff --git a/pkg/untar/untar.go b/pkg/untar/untar.go index 96ab454..f94b399 100644 --- a/pkg/untar/untar.go +++ b/pkg/untar/untar.go @@ -8,7 +8,7 @@ import ( "os" "path/filepath" - "github.com/HikariKnight/ls-iommu/pkg/errorcheck" + "github.com/HikariKnight/quickpassthrough/internal/common" ) // Slightly modified from source: https://medium.com/@skdomino/taring-untaring-files-in-go-6b07cf56bc07 @@ -17,7 +17,7 @@ import ( // creating the file structure at 'dst' along the way, and writing any files func Untar(dst string, fileName string) error { r, err := os.Open(fileName) - errorcheck.ErrorCheck(err, fmt.Sprintf("Failed to open: %s", fileName)) + common.ErrorCheck(err, fmt.Sprintf("Failed to open: %s", fileName)) defer r.Close() gzr, err := gzip.NewReader(r)