163 lines
3.4 KiB
Go
163 lines
3.4 KiB
Go
|
package eventstream
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"encoding/binary"
|
||
|
"encoding/hex"
|
||
|
"encoding/json"
|
||
|
"fmt"
|
||
|
"hash"
|
||
|
"hash/crc32"
|
||
|
"io"
|
||
|
|
||
|
"github.com/aws/aws-sdk-go/aws"
|
||
|
)
|
||
|
|
||
|
// Encoder provides EventStream message encoding.
|
||
|
type Encoder struct {
|
||
|
w io.Writer
|
||
|
logger aws.Logger
|
||
|
|
||
|
headersBuf *bytes.Buffer
|
||
|
}
|
||
|
|
||
|
// NewEncoder initializes and returns an Encoder to encode Event Stream
|
||
|
// messages to an io.Writer.
|
||
|
func NewEncoder(w io.Writer, opts ...func(*Encoder)) *Encoder {
|
||
|
e := &Encoder{
|
||
|
w: w,
|
||
|
headersBuf: bytes.NewBuffer(nil),
|
||
|
}
|
||
|
|
||
|
for _, opt := range opts {
|
||
|
opt(e)
|
||
|
}
|
||
|
|
||
|
return e
|
||
|
}
|
||
|
|
||
|
// EncodeWithLogger adds a logger to be used by the encode when decoding
|
||
|
// stream events.
|
||
|
func EncodeWithLogger(logger aws.Logger) func(*Encoder) {
|
||
|
return func(d *Encoder) {
|
||
|
d.logger = logger
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Encode encodes a single EventStream message to the io.Writer the Encoder
|
||
|
// was created with. An error is returned if writing the message fails.
|
||
|
func (e *Encoder) Encode(msg Message) (err error) {
|
||
|
e.headersBuf.Reset()
|
||
|
|
||
|
writer := e.w
|
||
|
if e.logger != nil {
|
||
|
encodeMsgBuf := bytes.NewBuffer(nil)
|
||
|
writer = io.MultiWriter(writer, encodeMsgBuf)
|
||
|
defer func() {
|
||
|
logMessageEncode(e.logger, encodeMsgBuf, msg, err)
|
||
|
}()
|
||
|
}
|
||
|
|
||
|
if err = EncodeHeaders(e.headersBuf, msg.Headers); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
crc := crc32.New(crc32IEEETable)
|
||
|
hashWriter := io.MultiWriter(writer, crc)
|
||
|
|
||
|
headersLen := uint32(e.headersBuf.Len())
|
||
|
payloadLen := uint32(len(msg.Payload))
|
||
|
|
||
|
if err = encodePrelude(hashWriter, crc, headersLen, payloadLen); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if headersLen > 0 {
|
||
|
if _, err = io.Copy(hashWriter, e.headersBuf); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if payloadLen > 0 {
|
||
|
if _, err = hashWriter.Write(msg.Payload); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
msgCRC := crc.Sum32()
|
||
|
return binary.Write(writer, binary.BigEndian, msgCRC)
|
||
|
}
|
||
|
|
||
|
func logMessageEncode(logger aws.Logger, msgBuf *bytes.Buffer, msg Message, encodeErr error) {
|
||
|
w := bytes.NewBuffer(nil)
|
||
|
defer func() { logger.Log(w.String()) }()
|
||
|
|
||
|
fmt.Fprintf(w, "Message to encode:\n")
|
||
|
encoder := json.NewEncoder(w)
|
||
|
if err := encoder.Encode(msg); err != nil {
|
||
|
fmt.Fprintf(w, "Failed to get encoded message, %v\n", err)
|
||
|
}
|
||
|
|
||
|
if encodeErr != nil {
|
||
|
fmt.Fprintf(w, "Encode error: %v\n", encodeErr)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
fmt.Fprintf(w, "Raw message:\n%s\n", hex.Dump(msgBuf.Bytes()))
|
||
|
}
|
||
|
|
||
|
func encodePrelude(w io.Writer, crc hash.Hash32, headersLen, payloadLen uint32) error {
|
||
|
p := messagePrelude{
|
||
|
Length: minMsgLen + headersLen + payloadLen,
|
||
|
HeadersLen: headersLen,
|
||
|
}
|
||
|
if err := p.ValidateLens(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
err := binaryWriteFields(w, binary.BigEndian,
|
||
|
p.Length,
|
||
|
p.HeadersLen,
|
||
|
)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
p.PreludeCRC = crc.Sum32()
|
||
|
err = binary.Write(w, binary.BigEndian, p.PreludeCRC)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// EncodeHeaders writes the header values to the writer encoded in the event
|
||
|
// stream format. Returns an error if a header fails to encode.
|
||
|
func EncodeHeaders(w io.Writer, headers Headers) error {
|
||
|
for _, h := range headers {
|
||
|
hn := headerName{
|
||
|
Len: uint8(len(h.Name)),
|
||
|
}
|
||
|
copy(hn.Name[:hn.Len], h.Name)
|
||
|
if err := hn.encode(w); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if err := h.Value.encode(w); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// binaryWriteFields writes each value in vs to w, in order, using
// binary.Write with the given byte order. It stops at the first error and
// returns it; values before the failing one have already been written.
func binaryWriteFields(w io.Writer, order binary.ByteOrder, vs ...interface{}) error {
	for i := 0; i < len(vs); i++ {
		err := binary.Write(w, order, vs[i])
		if err != nil {
			return err
		}
	}
	return nil
}
|