zstd_stream.go 7.6 KB
package zstd

/*
#define ZSTD_STATIC_LINKING_ONLY
#define ZBUFF_DISABLE_DEPRECATE_WARNINGS
#include "zstd.h"
#include "zbuff.h"
*/
import "C"
import (
	"errors"
	"fmt"
	"io"
	"unsafe"
)

// errShortRead is the sentinel TryReadFull returns when the source reports
// io.EOF before the destination buffer has been completely filled.
var errShortRead = errors.New("short read")

// Writer is an io.WriteCloser that zstd-compresses its input.
type Writer struct {
	// CompressionLevel is the level the Writer was constructed with.
	CompressionLevel int

	// ctx is the zstd compression context created by the constructor and
	// released by Close.
	ctx              *C.ZSTD_CCtx
	// dict is the dictionary slice handed to the constructor (may be nil).
	dict             []byte
	// dstBuffer is scratch space that receives compressed output before it
	// is forwarded to underlyingWriter.
	dstBuffer        []byte
	// firstError holds the error (if any) from starting the compression
	// session; Write returns it instead of compressing.
	firstError       error
	// underlyingWriter is the destination for compressed bytes.
	underlyingWriter io.Writer
}

// resize returns a slice of length newSize built from in.
//
// A nil input yields a freshly allocated slice. When the existing capacity is
// sufficient the input is re-sliced in place (no allocation, same backing
// array); otherwise zero bytes are appended until the length reaches newSize,
// preserving the current contents.
func resize(in []byte, newSize int) []byte {
	switch {
	case in == nil:
		return make([]byte, newSize)
	case cap(in) >= newSize:
		return in[:newSize]
	default:
		padding := make([]byte, newSize-len(in))
		return append(in, padding...)
	}
}

// NewWriter returns a Writer that compresses with DefaultCompression and no
// dictionary.  Data written to the returned Writer is forwarded to w in
// compressed form.
func NewWriter(w io.Writer) *Writer {
	var noDict []byte
	return NewWriterLevelDict(w, DefaultCompression, noDict)
}

// NewWriterLevel behaves like NewWriter, except that the caller chooses the
// compression level rather than getting the default.
//
// The level can be DefaultCompression or any integer value between BestSpeed
// and BestCompression inclusive.
func NewWriterLevel(w io.Writer, level int) *Writer {
	var noDict []byte
	return NewWriterLevelDict(w, level, noDict)
}

// NewWriterLevelDict is like NewWriterLevel but specifies a dictionary to
// compress with.  If the dictionary is empty or nil it is ignored. The dictionary
// should not be modified until the writer is closed.
//
// Any error reported while starting the compression session is stored on the
// returned Writer and surfaced by its Write method.
func NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer {
	var err error
	ctx := C.ZSTD_createCCtx()

	// Test the length rather than nil-ness: a non-nil but empty slice has no
	// element 0, so taking &dict[0] would panic.
	if len(dict) == 0 {
		err = getError(int(C.ZSTD_compressBegin(ctx,
			C.int(level))))
	} else {
		err = getError(int(C.ZSTD_compressBegin_usingDict(
			ctx,
			unsafe.Pointer(&dict[0]),
			C.size_t(len(dict)),
			C.int(level))))
	}

	return &Writer{
		CompressionLevel: level,
		ctx:              ctx,
		dict:             dict,
		dstBuffer:        make([]byte, CompressBound(1024)),
		firstError:       err,
		underlyingWriter: w,
	}
}

// Write compresses p and forwards the compressed bytes to the underlying
// io.Writer.  On success it reports len(p); on any failure it reports 0 and
// the error.  An empty p is a no-op.
func (w *Writer) Write(p []byte) (int, error) {
	if err := w.firstError; err != nil {
		return 0, err
	}
	inputLen := len(p)
	if inputLen == 0 {
		return 0, nil
	}

	// Make sure the scratch buffer can hold the worst-case output for p.
	if bound := CompressBound(inputLen); len(w.dstBuffer) < bound {
		w.dstBuffer = make([]byte, bound)
	}

	retCode := C.ZSTD_compressContinue(
		w.ctx,
		unsafe.Pointer(&w.dstBuffer[0]),
		C.size_t(len(w.dstBuffer)),
		unsafe.Pointer(&p[0]),
		C.size_t(len(p)))
	if err := getError(int(retCode)); err != nil {
		return 0, err
	}

	// As with zlib, the number of input bytes that actually reached the
	// destination is unknowable, so only success or failure is reported.
	compressed := w.dstBuffer[:int(retCode)]
	if _, err := w.underlyingWriter.Write(compressed); err != nil {
		return 0, err
	}
	return inputLen, nil
}

// Close closes the Writer, flushing any unwritten data to the underlying
// io.Writer and freeing objects, but does not close the underlying io.Writer.
//
// The compression context is released even when finishing the frame fails, so
// an error from Close never leaks it.  Calling Close again after the context
// has been released is a no-op that returns nil, and a Write issued after
// Close fails with an error instead of touching the freed context.
func (w *Writer) Close() error {
	if w.ctx == nil {
		// Already closed: nothing left to flush or free.
		return nil
	}

	retCode := C.ZSTD_compressEnd(
		w.ctx,
		unsafe.Pointer(&w.dstBuffer[0]),
		C.size_t(len(w.dstBuffer)),
		unsafe.Pointer(nil),
		C.size_t(0))
	endErr := getError(int(retCode))

	// Release the context unconditionally, before writing the epilogue, so
	// that neither a compressEnd failure nor a failing underlying writer can
	// leak it.
	freeErr := getError(int(C.ZSTD_freeCCtx(w.ctx)))
	w.ctx = nil
	if w.firstError == nil {
		// Write checks firstError first, so this keeps later writes away
		// from the released context.
		w.firstError = errors.New("zstd: write on closed writer")
	}

	if endErr != nil {
		return endErr
	}
	if freeErr != nil {
		return freeErr
	}

	_, err := w.underlyingWriter.Write(w.dstBuffer[:int(retCode)])
	return err
}

// reader is an io.ReadCloser that decompresses when read from.
type reader struct {
	// ctx is the ZBUFF decompression context; Close frees it.
	ctx                 *C.ZBUFF_DCtx
	// compressionBuffer holds compressed bytes read from underlyingReader.
	compressionBuffer   []byte
	// compressionLeft is the number of bytes at the front of
	// compressionBuffer that the decompressor has not consumed yet.
	compressionLeft     int
	// decompressionBuffer receives the output of each decompression call.
	decompressionBuffer []byte
	// decompOff is the offset in decompressionBuffer of the first byte not
	// yet handed to the caller of Read.
	decompOff           int
	// decompSize is the number of valid bytes in decompressionBuffer.
	decompSize          int
	// dict is the dictionary slice passed to NewReaderDict (may be nil).
	dict                []byte
	// firstError records the error (if any) from initialising ctx.
	firstError          error
	// recommendedSrcSize is the value reported by ZBUFF_recommendedDInSize;
	// Read falls back to it when sizing compressionBuffer.
	recommendedSrcSize  int
	// underlyingReader is the source of compressed bytes.
	underlyingReader    io.Reader
}

// NewReader returns an io.ReadCloser that reads compressed data from r and
// yields the decompressed bytes, using no dictionary.  The caller must call
// Close on the result when finished; otherwise the objects allocated inside
// the zstd library are never freed.
func NewReader(r io.Reader) io.ReadCloser {
	var noDict []byte
	return NewReaderDict(r, noDict)
}

// NewReaderDict is the dictionary-aware variant of NewReader.  A nil or empty
// dict is ignored and the decompressor is initialised without a dictionary.
//
// An initialisation error is stored on the returned reader rather than being
// returned here.  The function panics if the zstd library reports a
// non-positive recommended input or output buffer size.
func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser {
	ctx := C.ZBUFF_createDCtx()

	var initCode int
	if len(dict) > 0 {
		initCode = int(C.ZBUFF_decompressInitDictionary(
			ctx,
			unsafe.Pointer(&dict[0]),
			C.size_t(len(dict))))
	} else {
		initCode = int(C.ZBUFF_decompressInit(ctx))
	}
	initErr := getError(initCode)

	// Size both working buffers from the library's own recommendations.
	inSize := int(C.ZBUFF_recommendedDInSize())
	outSize := int(C.ZBUFF_recommendedDOutSize())
	if inSize <= 0 {
		panic(fmt.Errorf("ZBUFF_recommendedDInSize() returned invalid size: %v", inSize))
	}
	if outSize <= 0 {
		panic(fmt.Errorf("ZBUFF_recommendedDOutSize() returned invalid size: %v", outSize))
	}

	return &reader{
		ctx:                 ctx,
		dict:                dict,
		compressionBuffer:   make([]byte, inSize),
		decompressionBuffer: make([]byte, outSize),
		firstError:          initErr,
		recommendedSrcSize:  inSize,
		underlyingReader:    r,
	}
}

// Close releases the C decompression context owned by the reader and reports
// any error the zstd library returns while freeing it.
func (r *reader) Close() error {
	freeCode := C.ZBUFF_freeDCtx(r.ctx)
	return getError(int(freeCode))
}

// Read fills p with decompressed data.
//
// Bytes left over from a previous decompression call are served first.  If
// they do not fill p, Read repeatedly pulls compressed input from the
// underlying reader and decompresses it until p is full, the input is
// exhausted (io.EOF), or an error occurs.
//
// It returns the number of bytes placed in p.  When an error is returned
// part-way through, that count still covers the bytes already copied into p,
// because they have been consumed from the internal buffer and would
// otherwise be lost.  If the decompression context failed to initialise, the
// stored error is returned and nothing is read.
func (r *reader) Read(p []byte) (int, error) {
	// The context is unusable when initialisation failed; report that error
	// instead of decompressing with it.
	if r.firstError != nil {
		return 0, r.firstError
	}

	// If we already have enough bytes, return
	if r.decompSize-r.decompOff >= len(p) {
		copy(p, r.decompressionBuffer[r.decompOff:])
		r.decompOff += len(p)
		return len(p), nil
	}

	// Drain whatever is buffered, then mark the buffer empty.
	copy(p, r.decompressionBuffer[r.decompOff:r.decompSize])
	got := r.decompSize - r.decompOff
	r.decompSize = 0
	r.decompOff = 0

	for got < len(p) {
		// Populate src: top up the compressed buffer behind the bytes the
		// decompressor left unconsumed last time.
		src := r.compressionBuffer
		n, err := TryReadFull(r.underlyingReader, src[r.compressionLeft:])
		if err != nil && err != errShortRead { // Handle underlying reader errors first
			return got, fmt.Errorf("failed to read from underlying reader: %s", err)
		}
		if n == 0 && r.compressionLeft == 0 {
			// No new input and nothing pending: end of stream.
			return got, io.EOF
		}
		src = src[:r.compressionLeft+n]

		// C code: both size arguments are in/out. On return cSrcSize holds
		// the bytes consumed from src and cDstSize the bytes produced.
		cSrcSize := C.size_t(len(src))
		cDstSize := C.size_t(len(r.decompressionBuffer))
		retCode := int(C.ZBUFF_decompressContinue(
			r.ctx,
			unsafe.Pointer(&r.decompressionBuffer[0]),
			&cDstSize,
			unsafe.Pointer(&src[0]),
			&cSrcSize))

		if err = getError(retCode); err != nil {
			return got, fmt.Errorf("failed to decompress: %s", err)
		}

		// Move unconsumed input to the front of the compressed buffer.
		if int(cSrcSize) < len(src) {
			left := src[int(cSrcSize):]
			copy(r.compressionBuffer, left)
		}
		r.compressionLeft = len(src) - int(cSrcSize)

		// Hand over as much output as fits in p; the remainder stays
		// buffered for the next call.
		r.decompSize = int(cDstSize)
		r.decompOff = copy(p[got:], r.decompressionBuffer[:r.decompSize])
		got += r.decompOff

		// Resize buffers
		nsize := retCode // Hint for next src buffer size
		if nsize <= 0 {
			// Reset to recommended size
			nsize = r.recommendedSrcSize
		}
		if nsize < r.compressionLeft {
			// Never shrink below the pending, unconsumed input.
			nsize = r.compressionLeft
		}
		r.compressionBuffer = resize(r.compressionBuffer, nsize)
	}
	return got, nil
}

// TryReadFull reads from r until buf is full or r reports an error, and
// returns the number of bytes read.
//
// Unlike io.ReadFull / io.ReadAtLeast, running out of input early is an
// expected outcome here: instead of io.ErrUnexpectedEOF the function returns
// errShortRead, so callers can tell a short read apart from a genuine failure
// of the underlying reader (for example a network error), which is passed
// through unchanged.  An io.EOF that arrives together with the final bytes
// needed to fill buf is treated as success.
func TryReadFull(r io.Reader, buf []byte) (n int, err error) {
	for n < len(buf) {
		var chunk int
		chunk, err = r.Read(buf[n:])
		n += chunk
		if err != nil {
			break
		}
	}

	if err != io.EOF {
		// Either success (nil) or a real reader failure: report as-is.
		return n, err
	}
	if n == len(buf) {
		// EOF exactly at the end of a full buffer is not an error.
		return n, nil
	}
	return n, errShortRead
}