Blob Blame History Raw
package eventstream

import (
	"bytes"
	"encoding/binary"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"hash"
	"hash/crc32"
	"io"

	"github.com/aws/aws-sdk-go/aws"
)

// Decoder provides decoding of an Event Stream messages.
type Decoder struct {
	r      io.Reader
	logger aws.Logger
}

// NewDecoder initializes and returns a Decoder for decoding event
// stream messages from the reader provided.
func NewDecoder(r io.Reader) *Decoder {
	return &Decoder{
		r: r,
	}
}

// Decode attempts to decode a single message from the event stream reader.
// Will return the event stream message, or error if Decode fails to read
// the message from the stream.
func (d *Decoder) Decode(payloadBuf []byte) (m Message, err error) {
	reader := d.r
	if d.logger != nil {
		debugMsgBuf := bytes.NewBuffer(nil)
		reader = io.TeeReader(reader, debugMsgBuf)
		defer func() {
			logMessageDecode(d.logger, debugMsgBuf, m, err)
		}()
	}

	crc := crc32.New(crc32IEEETable)
	hashReader := io.TeeReader(reader, crc)

	prelude, err := decodePrelude(hashReader, crc)
	if err != nil {
		return Message{}, err
	}

	if prelude.HeadersLen > 0 {
		lr := io.LimitReader(hashReader, int64(prelude.HeadersLen))
		m.Headers, err = decodeHeaders(lr)
		if err != nil {
			return Message{}, err
		}
	}

	if payloadLen := prelude.PayloadLen(); payloadLen > 0 {
		buf, err := decodePayload(payloadBuf, io.LimitReader(hashReader, int64(payloadLen)))
		if err != nil {
			return Message{}, err
		}
		m.Payload = buf
	}

	msgCRC := crc.Sum32()
	if err := validateCRC(reader, msgCRC); err != nil {
		return Message{}, err
	}

	return m, nil
}

// UseLogger specifies the Logger that that the decoder should use to log the
// message decode to.
func (d *Decoder) UseLogger(logger aws.Logger) {
	d.logger = logger
}

func logMessageDecode(logger aws.Logger, msgBuf *bytes.Buffer, msg Message, decodeErr error) {
	w := bytes.NewBuffer(nil)
	defer func() { logger.Log(w.String()) }()

	fmt.Fprintf(w, "Raw message:\n%s\n",
		hex.Dump(msgBuf.Bytes()))

	if decodeErr != nil {
		fmt.Fprintf(w, "Decode error: %v\n", decodeErr)
		return
	}

	rawMsg, err := msg.rawMessage()
	if err != nil {
		fmt.Fprintf(w, "failed to create raw message, %v\n", err)
		return
	}

	decodedMsg := decodedMessage{
		rawMessage: rawMsg,
		Headers:    decodedHeaders(msg.Headers),
	}

	fmt.Fprintf(w, "Decoded message:\n")
	encoder := json.NewEncoder(w)
	if err := encoder.Encode(decodedMsg); err != nil {
		fmt.Fprintf(w, "failed to generate decoded message, %v\n", err)
	}
}

func decodePrelude(r io.Reader, crc hash.Hash32) (messagePrelude, error) {
	var p messagePrelude

	var err error
	p.Length, err = decodeUint32(r)
	if err != nil {
		return messagePrelude{}, err
	}

	p.HeadersLen, err = decodeUint32(r)
	if err != nil {
		return messagePrelude{}, err
	}

	if err := p.ValidateLens(); err != nil {
		return messagePrelude{}, err
	}

	preludeCRC := crc.Sum32()
	if err := validateCRC(r, preludeCRC); err != nil {
		return messagePrelude{}, err
	}

	p.PreludeCRC = preludeCRC

	return p, nil
}

func decodePayload(buf []byte, r io.Reader) ([]byte, error) {
	w := bytes.NewBuffer(buf[0:0])

	_, err := io.Copy(w, r)
	return w.Bytes(), err
}

func decodeUint8(r io.Reader) (uint8, error) {
	type byteReader interface {
		ReadByte() (byte, error)
	}

	if br, ok := r.(byteReader); ok {
		v, err := br.ReadByte()
		return uint8(v), err
	}

	var b [1]byte
	_, err := io.ReadFull(r, b[:])
	return uint8(b[0]), err
}
func decodeUint16(r io.Reader) (uint16, error) {
	var b [2]byte
	bs := b[:]
	_, err := io.ReadFull(r, bs)
	if err != nil {
		return 0, err
	}
	return binary.BigEndian.Uint16(bs), nil
}
func decodeUint32(r io.Reader) (uint32, error) {
	var b [4]byte
	bs := b[:]
	_, err := io.ReadFull(r, bs)
	if err != nil {
		return 0, err
	}
	return binary.BigEndian.Uint32(bs), nil
}
func decodeUint64(r io.Reader) (uint64, error) {
	var b [8]byte
	bs := b[:]
	_, err := io.ReadFull(r, bs)
	if err != nil {
		return 0, err
	}
	return binary.BigEndian.Uint64(bs), nil
}

func validateCRC(r io.Reader, expect uint32) error {
	msgCRC, err := decodeUint32(r)
	if err != nil {
		return err
	}

	if msgCRC != expect {
		return ChecksumError{}
	}

	return nil
}