diff --git a/main.go b/main.go index 4f40262..347cb81 100644 --- a/main.go +++ b/main.go @@ -34,7 +34,7 @@ var ( ) func main() { - if trace := os.Getenv(backend.EnvTrace); trace == "true" { + if trace := os.Getenv(backend.TraceCommandEnv); trace == "true" { flag.Parse() if flag.NArg() < 2 { @@ -50,15 +50,17 @@ func main() { if err != nil { // GoPacket doesn't export the permission error, so we need to compare error strings if strings.HasSuffix(err.Error(), "(socket: Operation not permitted)") { - fmt.Println(errors.Join(errors.New("could not start capturing, capture permission denied"), err)) - - os.Exit(backend.PermissionDeniedExitCode) + fmt.Print(backend.TraceCommandHandshakeHandlePermissionDenied) } else { - panic(err) + fmt.Print(backend.TraceCommandHandshakeHandleUnexpectedError) } + + panic(err) } defer handle.Close() + fmt.Print(backend.TraceCommandHandshakeHandleAcquired) + source := gopacket.NewPacketSource(handle, handle.LinkType()) encoder := json.NewEncoder(os.Stdout) diff --git a/pkg/backend/server.go b/pkg/backend/server.go index 334837b..0179ebf 100644 --- a/pkg/backend/server.go +++ b/pkg/backend/server.go @@ -35,15 +35,20 @@ import ( ) var ( - ErrDatabaseNotInArchive = errors.New("database not in archive") + ErrDatabaseNotInArchive = errors.New("database not in archive") + ErrUnexpectedErrorWhileStartingTraceCommand = errors.New("unexpected error while starting trace command") ) const ( flatpakSpawnCmd = "flatpak-spawn" - EnvTrace = "CONNMAPPER_TRACE" + TraceCommandEnv = "CONNMAPPER_TRACE" - PermissionDeniedExitCode = 13 + traceCommandHandshakeLen = 2 + + TraceCommandHandshakeHandleAcquired = "AQ" + TraceCommandHandshakeHandlePermissionDenied = "PD" + TraceCommandHandshakeHandleUnexpectedError = "UE" ) func lookupLocation(db *geoip2.Reader, ip net.IP) ( @@ -385,7 +390,7 @@ func (l *local) TraceDevice(ctx context.Context, device Device) error { } cmd := exec.Command(os.Args[0], device.PcapName, fmt.Sprintf("%v", device.MTU)) - cmd.Env = append(cmd.Env, EnvTrace+"=true") + cmd.Env = append(cmd.Env, TraceCommandEnv+"=true") stdout, err := cmd.StdoutPipe() if err != nil { @@ -396,10 +401,46 @@ func (l *local) TraceDevice(ctx context.Context, device Device) error { return err } + handshake := make([]byte, traceCommandHandshakeLen) + if _, err := stdout.Read(handshake); err != nil { + if cmd.Process != nil { + _ = cmd.Process.Kill() + + _ = cmd.Wait() + } + + return err + } + + switch string(handshake) { + case TraceCommandHandshakeHandleAcquired: + break + + case TraceCommandHandshakeHandlePermissionDenied: + if cmd.Process != nil { + _ = cmd.Process.Kill() + + _ = cmd.Wait() + } + + return l.RestartApp(ctx, true) + + default: + if cmd.Process != nil { + _ = cmd.Process.Kill() + + _ = cmd.Wait() + } + + return errors.Join(ErrUnexpectedErrorWhileStartingTraceCommand, errors.New(string(handshake))) + } + go func() { defer func() { if cmd.Process != nil { _ = cmd.Process.Kill() + + _ = cmd.Wait() } _ = db.Close() @@ -412,21 +453,8 @@ func (l *local) TraceDevice(ctx context.Context, device Device) error { for { var rawPacket Packet if err := decoder.Decode(&rawPacket); err != nil { - if cmd.Process != nil { - _ = cmd.Process.Kill() - - _ = cmd.Wait() - } - - if cmd.ProcessState.ExitCode() == PermissionDeniedExitCode { - if err := l.RestartApp(ctx, true); err != nil { - log.Println("Could not restart app:", err) - - return - } - } - log.Println("Could not capture:", err) + log.Println("Could not continue capturing:", err) return }