// Copyright The OpenTelemetry Authors
// SPDX-License-Identifier: Apache-2.0

package internal // import "go.opentelemetry.io/collector/exporter/exporterhelper/internal"

import (
	"bytes"
	"context"
	"encoding/binary"
	"errors"

	"go.uber.org/zap"

	"go.opentelemetry.io/collector/extension/experimental/storage"
)

var (
	// errItemIndexArrInvalidDataType is returned by itemIndexArrayToBytes when it is
	// handed a value that is neither nil nor a []itemIndex.
	errItemIndexArrInvalidDataType = errors.New("invalid data type, expected []itemIndex")
)

// batchStruct provides convenience capabilities for creating and processing storage extension batches.
// Operations are queued with set/get/delete (or their typed variants such as setItemIndex),
// submitted to the storage client in a single call by execute, and the values fetched by
// queued Get operations are then read back with getResult and its typed wrappers.
// Create instances with newBatch.
type batchStruct struct {
	// logger is copied from pcs by newBatch; set uses it to report items it skips.
	logger *zap.Logger
	// pcs supplies the storage client that executes the batch and the request marshaler/unmarshaler.
	pcs    *persistentContiguousStorage

	// operations holds every queued operation in the order it was added; execute passes them to the client in that order.
	operations    []storage.Operation
	// getOperations indexes the queued Get operations by key so their results can be looked up after execute.
	getOperations map[string]storage.Operation
}

// newBatch returns an empty batch bound to pcs. The batch shares pcs's logger and
// starts with a non-nil, empty operation list and Get-operation index.
func newBatch(pcs *persistentContiguousStorage) *batchStruct {
	b := &batchStruct{pcs: pcs}
	b.logger = pcs.logger
	b.operations = []storage.Operation{}
	b.getOperations = make(map[string]storage.Operation)
	return b
}

// execute submits every queued operation to the storage client in a single Batch
// call, in the order the operations were queued. On success it returns the
// receiver so that Get results can be read back; on failure it returns nil
// together with the client's error.
func (bof *batchStruct) execute(ctx context.Context) (*batchStruct, error) {
	if err := bof.pcs.client.Batch(ctx, bof.operations...); err != nil {
		return nil, err
	}
	return bof, nil
}

// set queues a Set operation storing marshal(value) under key and returns the
// batch for chaining. Marshaling is best effort: if marshal fails, the failure
// is logged at debug level and the item is left out of the batch instead of
// failing the whole batch.
func (bof *batchStruct) set(key string, value any, marshal func(any) ([]byte, error)) *batchStruct {
	encoded, err := marshal(value)
	if err != nil {
		bof.logger.Debug("Failed marshaling item, skipping it", zap.String(zapKey, key), zap.Error(err))
		return bof
	}

	bof.operations = append(bof.operations, storage.SetOperation(key, encoded))
	return bof
}

// get queues one Get operation per key and returns the batch for chaining.
// After execute, each fetched value can be read via getResult (or a typed
// wrapper). If the same key is requested more than once, only the most
// recently queued operation for it is tracked for result lookup.
func (bof *batchStruct) get(keys ...string) *batchStruct {
	for i := range keys {
		getOp := storage.GetOperation(keys[i])
		bof.getOperations[keys[i]] = getOp
		bof.operations = append(bof.operations, getOp)
	}
	return bof
}

// delete queues one Delete operation per key, in argument order, and returns
// the batch for chaining.
func (bof *batchStruct) delete(keys ...string) *batchStruct {
	for i := range keys {
		bof.operations = append(bof.operations, storage.DeleteOperation(keys[i]))
	}
	return bof
}

// getResult decodes, with unmarshal, the value fetched by the Get operation
// queued for key. It is meant to be called after execute.
//
// It returns errKeyNotPresentInBatch when no Get was queued for key, and
// (nil, nil) when the operation's Value is nil (presumably the key was not found
// in storage — NOTE(review): confirm against the storage client's semantics).
// Otherwise it returns whatever unmarshal returns.
func (bof *batchStruct) getResult(key string, unmarshal func([]byte) (any, error)) (any, error) {
	op := bof.getOperations[key]
	switch {
	case op == nil:
		return nil, errKeyNotPresentInBatch
	case op.Value == nil:
		return nil, nil
	default:
		return unmarshal(op.Value)
	}
}

// getRequestResult returns the value fetched for key, decoded into a Request by
// the storage's unmarshaler. It propagates any error from getResult and returns
// errValueNotSet when the lookup yielded no value.
func (bof *batchStruct) getRequestResult(key string) (Request, error) {
	raw, err := bof.getResult(key, bof.bytesToRequest)
	switch {
	case err != nil:
		return nil, err
	case raw == nil:
		return nil, errValueNotSet
	}
	return raw.(Request), nil
}

// getItemIndexResult returns the value fetched for key, decoded as an itemIndex.
// On failure it returns the zero itemIndex together with either the error from
// getResult or errValueNotSet when the lookup yielded no value.
func (bof *batchStruct) getItemIndexResult(key string) (itemIndex, error) {
	var zero itemIndex

	raw, err := bof.getResult(key, bytesToItemIndex)
	switch {
	case err != nil:
		return zero, err
	case raw == nil:
		return zero, errValueNotSet
	}
	return raw.(itemIndex), nil
}

// getItemIndexArrayResult returns the value fetched for key, decoded as a
// []itemIndex. Unlike the other typed getters, a missing value is not an
// error: in that case it returns (nil, nil). Errors from getResult are
// propagated with a nil slice.
func (bof *batchStruct) getItemIndexArrayResult(key string) ([]itemIndex, error) {
	raw, err := bof.getResult(key, bytesToItemIndexArray)
	if err != nil || raw == nil {
		return nil, err
	}
	return raw.([]itemIndex), nil
}

// setRequest queues a Set operation storing value under key, serialized with
// the storage's request marshaler. It returns the batch for chaining.
func (bof *batchStruct) setRequest(key string, value Request) *batchStruct {
	marshal := bof.requestToBytes
	return bof.set(key, value, marshal)
}

// setItemIndex queues a Set operation storing value under key in its
// little-endian binary encoding. It returns the batch for chaining.
func (bof *batchStruct) setItemIndex(key string, value itemIndex) *batchStruct {
	encode := itemIndexToBytes
	return bof.set(key, value, encode)
}

// setItemIndexArray queues a Set operation storing value under key, encoded as
// a length-prefixed little-endian array. It returns the batch for chaining.
func (bof *batchStruct) setItemIndexArray(key string, value []itemIndex) *batchStruct {
	encode := itemIndexArrayToBytes
	return bof.set(key, value, encode)
}

func itemIndexToBytes(val any) ([]byte, error) {
	var buf bytes.Buffer
	err := binary.Write(&buf, binary.LittleEndian, val)
	if err != nil {
		return nil, err
	}
	return buf.Bytes(), err
}

// bytesToItemIndex decodes a little-endian itemIndex from the start of b; any
// trailing bytes are ignored. Errors from binary.Read (for example when b is
// too short) are returned unchanged, alongside the decoded value, which callers
// should ignore in that case.
func bytesToItemIndex(b []byte) (any, error) {
	var decoded itemIndex
	err := binary.Read(bytes.NewReader(b), binary.LittleEndian, &decoded)
	return decoded, err
}

// itemIndexArrayToBytes serializes arr as a little-endian uint32 element count
// followed by each element in little-endian binary form (the format read back
// by bytesToItemIndexArray).
//
// arr must be nil or a []itemIndex. A nil arr, whether an untyped nil or a nil
// slice, encodes as an empty array: a zero count and no elements. Any other type
// yields errItemIndexArrInvalidDataType.
func itemIndexArrayToBytes(arr any) ([]byte, error) {
	var items []itemIndex
	if arr != nil {
		typed, ok := arr.([]itemIndex)
		if !ok {
			return nil, errItemIndexArrInvalidDataType
		}
		items = typed
	}

	var buf bytes.Buffer
	if err := binary.Write(&buf, binary.LittleEndian, uint32(len(items))); err != nil {
		return nil, err
	}
	// Encode the typed (possibly nil) slice rather than arr itself: when the
	// caller passes an untyped nil, arr carries no type information and
	// binary.Write cannot encode it.
	if err := binary.Write(&buf, binary.LittleEndian, items); err != nil {
		return nil, err
	}
	return buf.Bytes(), nil
}

func bytesToItemIndexArray(b []byte) (any, error) {
	var size uint32
	reader := bytes.NewReader(b)
	err := binary.Read(reader, binary.LittleEndian, &size)
	if err != nil {
		return nil, err
	}

	val := make([]itemIndex, size)
	err = binary.Read(reader, binary.LittleEndian, &val)
	return val, err
}

// requestToBytes adapts the storage's request marshaler to the func(any)
// signature expected by set. It panics if req is not a Request.
func (bof *batchStruct) requestToBytes(req any) ([]byte, error) {
	r := req.(Request)
	return bof.pcs.marshaler(r)
}

// bytesToRequest adapts the storage's request unmarshaler to the
// func([]byte) (any, error) signature expected by getResult.
func (bof *batchStruct) bytesToRequest(b []byte) (any, error) {
	req, err := bof.pcs.unmarshaler(b)
	return req, err
}
