diff --git a/packet.go b/packet.go index 3a7c4b3d8..adf91ab46 100644 --- a/packet.go +++ b/packet.go @@ -753,8 +753,8 @@ type ZeroCopyPacketDataSource interface { // This method is the most convenient and easiest to code, but lacks // flexibility. Packets returns a 'chan Packet', then asynchronously writes // packets into that channel. Packets uses a blocking channel, and closes -// it if an io.EOF is returned by the underlying PacketDataSource. All other -// PacketDataSource errors are ignored and discarded. +// it if a non-temporary error is returned by the underlying PacketDataSource. +// Temporary PacketDataSource errors are retried. // for packet := range packetSource.Packets() { // ... // } @@ -821,34 +821,78 @@ func (p *PacketSource) packetsToChannel() { continue } - // Immediately retry for temporary network errors - if nerr, ok := err.(net.Error); ok && nerr.Temporary() { + // Immediately retry for temporary network errors. + if isTemporaryError(err) { continue } - // Immediately retry for EAGAIN - if err == syscall.EAGAIN { + // Immediately retry for EAGAIN. + if errorIs(err, syscall.EAGAIN) { continue } - // Immediately break for known unrecoverable errors - if err == io.EOF || err == io.ErrUnexpectedEOF || - err == io.ErrNoProgress || err == io.ErrClosedPipe || err == io.ErrShortBuffer || - err == syscall.EBADF || + // Immediately break for known unrecoverable errors. + if errorIs(err, io.EOF) || errorIs(err, io.ErrUnexpectedEOF) || + errorIs(err, io.ErrNoProgress) || errorIs(err, io.ErrClosedPipe) || + errorIs(err, io.ErrShortBuffer) || errorIs(err, syscall.EBADF) || strings.Contains(err.Error(), "use of closed file") { break } - // Sleep briefly and try again - time.Sleep(time.Millisecond * time.Duration(5)) + // Unknown non-temporary errors are unrecoverable for this reader; close + // the channel instead of retrying forever and starving the caller. + break + } +} + +type errorUnwrapper interface { + Unwrap() error +} + +func unwrapError(err error) error { + if u, ok := err.(errorUnwrapper); ok { + return u.Unwrap() + } + return nil +} + +func errorIs(err, target error) bool { + if target == nil { + return err == nil + } + for err != nil { + if errorsEqual(err, target) { + return true + } + err = unwrapError(err) + } + return false +} + +func errorsEqual(err, target error) (ok bool) { + defer func() { + if recover() != nil { + ok = false + } + }() + return err == target +} + +func isTemporaryError(err error) bool { + for err != nil { + if nerr, ok := err.(net.Error); ok && nerr.Temporary() { + return true + } + err = unwrapError(err) } + return false } // Packets returns a channel of packets, allowing easy iterating over // packets. Packets will be asynchronously read in from the underlying // PacketDataSource and written to the returned channel. If the underlying -// PacketDataSource returns an io.EOF error, the channel will be closed. -// If any other error is encountered, it is ignored. +// PacketDataSource returns a non-temporary error, including io.EOF, the +// channel will be closed. // // for packet := range packetSource.Packets() { // handlePacket(packet) // Do something with each packet. diff --git a/packet_test.go b/packet_test.go index bd0544888..06fa26e77 100644 --- a/packet_test.go +++ b/packet_test.go @@ -7,9 +7,11 @@ package gopacket import ( + "errors" "io" "reflect" "testing" + "time" ) type embedded struct { @@ -39,6 +41,28 @@ func (s *singlePacketSource) ReadPacketData() ([]byte, CaptureInfo, error) { return out, CaptureInfo{}, nil } +func TestPacketSourcePacketsClosesOnNonTemporaryError(t *testing.T) { + source := errorPacketSource{errors.New("fatal read error")} + packetSource := NewPacketSource(source, DecodePayload) + + select { + case packet, ok := <-packetSource.Packets(): + if ok { + t.Fatalf("Packets returned unexpected packet for non-temporary error: %v", packet) + } + case <-time.After(250 * time.Millisecond): + t.Fatal("Packets did not close after a non-temporary read error") + } +} + +type errorPacketSource struct { + err error +} + +func (s errorPacketSource) ReadPacketData() ([]byte, CaptureInfo, error) { + return nil, CaptureInfo{}, s.err +} + func TestConcatPacketSources(t *testing.T) { sourceA := &singlePacketSource{[]byte{1}} sourceB := &singlePacketSource{[]byte{2}}