package rabaead

import (
	"bytes"
	"crypto/cipher"
	"encoding/binary"
	"errors"
	"io"

	"snix.ir/rabbitio"
)

// cmrs is the size in bytes of the chunk size marker that prefixes the
// plaintext of every chunk; without this marker the reader cannot calculate
// the actual size of the plaintext.
const cmrs = 0x02

// AdditionalFunc is the additional data func, its return value is used as AD
// in Seal and Open. A nil AdditionalFunc is harmless and equal to
// func() []byte { return nil }.
type AdditionalFunc func() []byte

// chunkReader reads fixed-size sealed chunks from rader, opens them with aead
// and keeps the opened plaintext that the caller has not consumed yet in buff.
// csize is the plaintext capacity of one chunk, nonce is a private copy of the
// nonce given to the constructor and adexe yields the additional data.
type chunkReader struct {
	aead  cipher.AEAD
	csize int
	rader io.Reader
	buff  []byte
	nonce []byte
	adexe AdditionalFunc
}

// chunkWriter seals plaintext with aead in chunks of csize bytes and writes
// the sealed chunks to writer. buff holds the not yet written part of the data
// handed to Write, nonce is a private copy of the nonce given to the
// constructor and adexe yields the additional data.
type chunkWriter struct {
	aead   cipher.AEAD
	csize  int
	writer io.Writer
	buff   []byte
	nonce  []byte
	adexe  AdditionalFunc
}

// NewChunkReader returns a chunkReader, this reader reads and Open()s aead
// ciphertext, each chunk has its own tag and chunk size marker.
// This reader has a chunk size in-memory buffer, a large chunk size can make
// the application run out of memory, thus it is most suitable for sliced data,
// like network data transmit and so on.
//
// r is the source of sealed chunks and a is the AEAD used to open them; both
// must be non-nil. chnk is the plaintext size of one chunk and must be in the
// range 1..65535 because the chunk size marker is a uint16. nonce must be
// either empty or rabbitio.IVXLen bytes long, otherwise rabbitio.ErrInvalidIVX
// is returned; the nonce is copied, so the caller may reuse its slice.
// A nil f is replaced with a function returning nil.
func NewChunkReader(r io.Reader, chnk int, a cipher.AEAD, nonce []byte, f AdditionalFunc) (*chunkReader, error) {
	if len(nonce) != rabbitio.IVXLen && len(nonce) != 0 {
		return nil, rabbitio.ErrInvalidIVX
	}
	if chnk > int(^uint16(0)) || chnk <= 0 {
		return nil, errors.New("rabaead: bad chunk size")
	}
	// Fail here instead of panicking on the first Read. These checks come
	// after the original ones so that previously returned errors are unchanged.
	if r == nil {
		return nil, errors.New("rabaead: nil reader")
	}
	if a == nil {
		return nil, errors.New("rabaead: nil aead")
	}

	s := &chunkReader{
		aead:  a,
		buff:  []byte{},
		nonce: make([]byte, len(nonce)),
		csize: chnk,
		rader: r,
		adexe: f,
	}
	if s.adexe == nil {
		s.adexe = func() []byte { return nil }
	}
	copy(s.nonce, nonce)
	return s, nil
}

// NewChunkWriter returns a chunkWriter, this writer Seal()s plaintext and writes
// the aead ciphertext, each chunk has its own tag and chunk size marker.
// This writer has a chunk size in-memory buffer, a large chunk size can make the
// application run out of memory, thus it is most suitable for sliced data, like
// network data transmit and so on.
func NewChunkWriter(w io.Writer, chnk int, a cipher.AEAD, nonce []byte, f AdditionalFunc) (*chunkWriter, error) { if len(nonce) != rabbitio.IVXLen && len(nonce) != 0 { return nil, rabbitio.ErrInvalidIVX } if chnk > int(^uint16(0)) || chnk <= 0 { return nil, errors.New("rabaead: bad chunk size") } s := &chunkWriter{ aead: a, buff: []byte{}, nonce: make([]byte, len(nonce)), csize: chnk, writer: w, adexe: f, } if s.adexe == nil { s.adexe = func() []byte { return nil } } copy(s.nonce, nonce) return s, nil } // Close method, if there is any func (w *chunkWriter) Close() error { if c, ok := w.writer.(io.Closer); ok { return c.Close() } return nil } // Write writes plaintext chunk into the sale() and underlying writer // write would not report overhead data (chunk size marker and poly1305 tag) in // written return value. for each chunk there is 2+16 byte overhead data. // AdFunc will be triggered for each chunk of data func (w *chunkWriter) Write(b []byte) (n int, err error) { w.buff = b for len(w.buff) > 0 { s, err := w.write() if err != nil { return n, err } n += s } return } func (w *chunkWriter) write() (int, error) { size := cmrs + w.csize + w.aead.Overhead() chnk := make([]byte, size) var n int var err error if len(w.buff) > 0 { s := copy(chnk[cmrs:len(chnk)-w.aead.Overhead()], w.buff) w.buff = w.buff[s:] copy(chnk[0:cmrs], uint16Little(uint16(s))) w.aead.Seal(chnk[:0], w.nonce, chnk[:cmrs+w.csize], w.adexe()) _, err = w.writer.Write(chnk) if err != nil { return n, err } n += s } return n, err } // Read reads and open() ciphertext chunk from underlying reader // read would not report overhead data (chunk size marker and poly1305 tag) in its // return value. if the read data from underlying reader is corrupted, ErrAuthMsg // error will be returned. for each chunk there is 2+16 byte overhead data. 
// AdFunc will be triggered for each chunk of data func (r *chunkReader) Read(b []byte) (int, error) { if len(b) <= r.csize { return r.readTo(b) } n := 0 for { if n+r.csize > len(b) { sr, err := r.readTo(b[n:]) n += sr if err != nil { return n, err } break } sr, err := r.readTo(b[n : n+r.csize]) n += sr if err != nil { return n, err } } return n, nil } func (r *chunkReader) readTo(b []byte) (int, error) { var n int if len(r.buff) > 0 { n = copy(b, r.buff) r.buff = r.buff[n:] return n, nil } sr, err := r.read() n = copy(b, r.buff[:sr]) r.buff = r.buff[n:] return n, err } func (r *chunkReader) read() (int, error) { var n int size := cmrs + r.csize + r.aead.Overhead() chnk := make([]byte, size) chLE := uint16Little(uint16(r.csize)) si, err := io.ReadFull(r.rader, chnk) if err != nil { return n, err } if si > 0 { _, err = r.aead.Open(chnk[:0], r.nonce, chnk, r.adexe()) if err != nil { return n, err } if bytes.Equal(chnk[0:cmrs], chLE) { n += r.csize r.buff = append(r.buff, chnk[cmrs:cmrs+r.csize]...) } else { f := binary.LittleEndian.Uint16(chnk[0:cmrs]) n += int(f) r.buff = append(r.buff, chnk[cmrs:cmrs+f]...) } } return n, err } func uint16Little(n uint16) []byte { b := make([]byte, cmrs) binary.LittleEndian.PutUint16(b, n) return b }