aboutsummaryrefslogblamecommitdiff
path: root/ioaead.go
blob: 773f6415c31addf46d8ffe26d9ffccbb6c7f1f94 (plain) (tree)

















































































































































                                                                                                             




















                                                  




































































































                                                                          
package rabaead

import (
	"crypto/cipher"
	"errors"
	"io"

	"snix.ir/poly1305"
	"snix.ir/rabbitio"
)

type streamReader struct {
	ie        *ioaead
	cip       cipher.Stream
	firstRead bool
	nwr       int
	read      io.Reader
	buff      []byte
	temp      []byte
	tagc      []byte
}

type ioaead struct {
	key            []byte // rabbit cipher key
	nonce          []byte
	poly           *poly1305.MAC
	adlen          int
	additionalData AdditionalFunc
}

type streamWriter struct {
	ie          *ioaead
	writer      io.Writer
	plainWriter io.Writer
	nwr         int
	firstWrite  bool
}

var errunderio = errors.New("underlying io reader returns wrong read value, which is not supposed to happen")

func makeioaead(key, iv []byte, adfunc AdditionalFunc) *ioaead {
	if adfunc == nil {
		adfunc = func() []byte { return nil }
	}
	str := &ioaead{
		key:   make([]byte, rabbitio.KeyLen),
		nonce: make([]byte, len(iv)),
	}

	str.additionalData = adfunc
	copy(str.key, key)
	copy(str.nonce, iv)
	var poly [polykeylen]byte
	cph, _ := rabbitio.NewCipher(str.key, str.nonce)
	cph.XORKeyStream(poly[:], poly[:])
	str.poly = poly1305.New(&poly)
	return str
}

func (s *ioaead) execAdFunc() {
	additionalData := s.additionalData()
	s.adlen = len(additionalData)
	writePadding(s.poly, additionalData)
}

func newCipherReader(r io.Reader, key, nonce []byte, f AdditionalFunc) (*streamReader, error) {
	if len(key) != rabbitio.KeyLen {
		return nil, rabbitio.ErrInvalidKey
	}

	if len(nonce) != rabbitio.IVXLen && len(nonce) != 0 {
		return nil, rabbitio.ErrInvalidIVX
	}

	v := &streamReader{
		ie:   makeioaead(key, nonce, f),
		read: r,
		buff: []byte{},
		tagc: make([]byte, 16),
		temp: make([]byte, 16),
	}

	v.cip, _ = rabbitio.NewCipher(v.ie.key, v.ie.nonce)
	return v, nil
}

// NewStreamReader returns streamReader data type, this reader open() and read aead
// ciphertext which have 16-byte poly1305 tag overhead.
// read data cannot be authenticated until underlying reader returns EOF
// so you should use this reader only if you can undo your read.
// AdFunc will be triggered at first call to read method
func NewStreamReader(r io.Reader, key, nonce []byte, f AdditionalFunc) (*streamReader, error) {
	return newCipherReader(r, key, nonce, f)
}

// NewStreamWriter returns streamWriter data type, this writer sale() and write aead
// plaintext which have 16-byte poly1305 tag overhead, running Close() is necessary
// in order to calculate and write tag at the end of the write.
// AdFunc will be triggered at first call to write method
func NewStreamWriter(w io.Writer, key, nonce []byte, f AdditionalFunc) (*streamWriter, error) {
	return newChipherWriter(w, key, nonce, f)
}

func newChipherWriter(w io.Writer, key, nonce []byte, f AdditionalFunc) (*streamWriter, error) {
	if len(key) != rabbitio.KeyLen {
		return nil, rabbitio.ErrInvalidKey
	}
	if len(nonce) != rabbitio.IVXLen && len(nonce) != 0 {
		return nil, rabbitio.ErrInvalidIVX
	}
	v := &streamWriter{
		ie:          makeioaead(key, nonce, f),
		plainWriter: w,
	}

	v.writer, _ = rabbitio.NewWriterCipher(
		v.ie.key, v.ie.nonce,
		io.MultiWriter(w, v.ie.poly),
	)

	return v, nil
}

func (r *streamReader) readTo(b []byte) (int, error) {
	var n int
	if len(r.buff) > 0 {
		return r.copyBuff(b), nil
	}

	sr, err := r.readBuff()
	if err != nil {
		if err == io.EOF {
			n = r.copyUntil(b, sr)
			return n, r.verify()
		}
		return n, err
	}
	return r.copyUntil(b, sr), err
}

// Read reads and open ciphertext.
// read data is unreliable until underlying reader returns EOF
// after that Read return EOF or ErrAuthMsg if integrity of data has been compromised.
// in such a case, you need to unread data. a simple demonstration would be to delete
// or truncate the file if ErrAuthMsg is returned
func (r *streamReader) Read(b []byte) (int, error) {
	if len(b) <= 16 {
		return r.readTo(b)
	}
	n := 0
	for {
		if n+16 > len(b) {
			sr, err := r.readTo(b[n:])
			n += sr
			if err != nil {
				return n, err
			}
			break
		}

		sr, err := r.readTo(b[n : n+16])
		n += sr
		if err != nil {
			return n, err
		}
	}
	return n, nil
}

func (r *streamReader) verify() error {
	r.ie.ioPaddingTo(r.nwr)
	if r.ie.poly.Verify(r.tagc) {
		return io.EOF
	}
	return ErrAuthMsg
}

func (r *streamReader) copyUntil(b []byte, sr int) int {
	n := copy(b, r.buff[:sr])
	r.buff = r.buff[n:]
	r.nwr += n
	return n
}

func (r *streamReader) copyBuff(b []byte) int {
	n := copy(b, r.buff)
	r.buff = r.buff[n:]
	r.nwr += n
	return n
}

func (r *streamReader) readBuff() (int, error) {
	if !r.firstRead {
		r.ie.execAdFunc()
		_, err := io.ReadFull(r.read, r.temp)
		if err != nil {
			return 0, err
		}
		r.firstRead = true
	}

	var buff = make([]byte, 16)
	n, err := r.read.Read(buff)
	if err != nil {
		return 0, err
	}
	if n > len(buff) {
		return 0, errunderio
	}

	copy(r.tagc, append(r.temp[n:], buff[:n]...))
	r.buff = append(r.buff, r.temp[:n]...)
	r.buffAndXor()

	if n < 16 {
		return n, err
	}

	copy(r.temp, buff)
	return n, err
}

func (r *streamReader) buffAndXor() {
	r.ie.poly.Write(r.buff)
	r.cip.XORKeyStream(r.buff, r.buff)
}

// Write writes plaintext data, in order to calculate and write tag
// at the end of the write, running Close() is necessary
func (w *streamWriter) Write(b []byte) (int, error) {
	if !w.firstWrite {
		w.ie.execAdFunc()
		w.firstWrite = true
	}
	n, err := w.writer.Write(b)
	if err != nil {
		return n, err
	}

	w.nwr += n
	return n, err
}

func (p *ioaead) ioPaddingTo(nb int) {
	if rem := nb % 16; rem != 0 {
		var buf [16]byte
		padLen := 16 - rem
		p.poly.Write(buf[:padLen])
	}

	writeUint64(p.poly, p.adlen)
	writeUint64(p.poly, nb)
}

// Close calculate and write poly1305 tag before closing the writer
// if underlying writer does not have a Close() method, Close only
// calculate and write poly1305 tag
func (w *streamWriter) Close() error {
	w.ie.ioPaddingTo(w.nwr)
	if _, err := w.plainWriter.Write(w.ie.poly.Sum(nil)); err != nil {
		return err
	}

	if c, ok := w.plainWriter.(io.Closer); ok {
		return c.Close()
	}
	return nil
}

Snix LLC Git Repository Holder Copyright(C) 2022 All Rights Reserved Email To Snix.IR