// occ-code
package occ

import (
    "encoding/binary"
    "fmt"
    "hash"
    "hash/fnv"
    "sync"
    "sync/atomic"
    "github.com/tidwall/btree"
)

// Store is the user-supplied data store that the OCC Manager wraps. The store
// itself is not assumed to provide any concurrency semantics.
type Store interface {
    // Read fetches the object stored under id.
    Read(id uint64) (interface{}, error)
    // Write persists the write set of a committing transaction.
    Write(writeSet []interface{}) error
    // Merge reconciles a previously read object with an incoming one and
    // returns the object that should be written.
    Merge(old interface{}, new interface{}) (interface{}, error)
}

// Manager wraps a user's `Store` implementation and coordinates long-running,
// concurrent transactions against it without holding a lock outside of the
// short critical section taken during commit.
type Manager struct {
    // id is the transaction counter; it is incremented once per successful commit.
    id atomic.Uint64
    // mut is the latch held for the duration of a commit.
    mut sync.Mutex
    // txnLog stores the write-set hash of recently committed transactions, one
    // slot of approxBits/64 words per transaction.
    txnLog []uint64
    // hasher is the 64-bit hash created for this manager.
    hasher hash.Hash64
    // store is the wrapped user store.
    store Store
    // approxBits is the width in bits of each approximate read/write set.
    approxBits uint64
    // retention is the number of transaction slots kept in txnLog.
    retention uint64
}

// NewManager builds a Manager around store.
//
// approxBits is the width, in bits, of every approximate read/write set. The
// sets are stored as whole uint64 words and indexed as approxBits/64 words, so
// the value is rounded up to the next multiple of 64 (minimum 64); otherwise a
// hashed position could fall past the end of a set.
//
// retention is the number of committed transactions remembered in the
// validation log. Zero is raised to 1 so that slot arithmetic never takes a
// modulo by zero.
//
// The normalized values are the ones stored on the returned Manager.
func NewManager(store Store, approxBits, retention uint64) *Manager {
    if approxBits == 0 {
        approxBits = 64
    }
    if rem := approxBits % 64; rem != 0 {
        approxBits += 64 - rem
    }
    if retention == 0 {
        retention = 1
    }

    // one slot of `words` uint64 per retained transaction
    words := approxBits / 64
    return &Manager{
        txnLog:     make([]uint64, words*retention),
        hasher:     fnv.New64(),
        store:      store,
        approxBits: approxBits,
        retention:  retention,
    }
}

// CurrentTxnID returns the current value of the manager's transaction counter.
func (cc *Manager) CurrentTxnID() uint64 {
    current := cc.id.Load()
    return current
}

// NewConnection returns a fresh Conn bound to this manager. The connection
// starts without an active transaction.
func (cc *Manager) NewConnection() *Conn {
    conn := new(Conn)
    conn.mgr = cc
    return conn
}

// transaction holds the private working state of one in-flight transaction.
type transaction struct {
    // tstart is the manager's txn id observed when the transaction began.
    tstart uint64
    // tend is the manager's txn id observed at commit.
    tend uint64
    // w is the local write set, keyed by object id.
    w btree.Map[uint64, interface{}]
    // r caches objects read from the store, keyed by object id.
    r btree.Map[uint64, interface{}]
    // lastRd is the manager's txn id observed at the most recent store read.
    lastRd uint64
    // hashWr is the approximate (bitset) write set.
    hashWr []uint64
    // hashRd is the approximate (bitset) read set.
    hashRd []uint64
}

// Conn corresponds to an active session with the `Store`. In practice a
// `Manager` hands out one `Conn` per user connection / TCP connection received.
// A Conn carries at most one transaction at a time.
type Conn struct {
    // txn is the in-progress transaction, or nil when none is active.
    txn *transaction
    // mgr is the manager this connection was created by.
    mgr *Manager
}

// Begin starts a transaction on the connection. It returns an error if one is
// already in progress. The new transaction records the manager's current txn id
// as its start and gets empty read/write bitsets of approxBits/64 words each.
func (c *Conn) Begin() error {
    if c.txn != nil {
        return fmt.Errorf("transaction in progress")
    }

    words := c.mgr.approxBits / 64
    txn := new(transaction)
    txn.tstart = c.mgr.CurrentTxnID()
    txn.hashWr = make([]uint64, words)
    txn.hashRd = make([]uint64, words)

    c.txn = txn
    return nil
}

// Read returns the object stored under id as seen by the current transaction.
//
// Lookups are served in this order: the transaction's own write set
// (read-your-writes), then its read cache (repeatable reads), and only then the
// underlying `Store`. A store read marks id in the approximate read set, caches
// the object and records the manager's current txn id in `lastRd`. If the store
// read fails the transaction is rolled back and the store's error is returned.
func (c *Conn) Read(id uint64) (interface{}, error) {
    txn := c.txn
    if txn == nil {
        return nil, fmt.Errorf("no transaction in progress")
    }

    // values written by this transaction take precedence over everything else
    if written, found := txn.w.Get(id); found {
        return written, nil
    }

    // an object already read stays stable for the rest of the transaction
    if cached, found := txn.r.Get(id); found {
        return cached, nil
    }

    // about to hit the persistent store: record id in the approximate read set
    word, bit := c.hashIndex(id)
    txn.hashRd[word] |= uint64(1) << bit

    fetched, err := c.mgr.store.Read(id)
    if err != nil {
        return nil, c.rollback(err)
    }
    txn.r.Set(id, fetched)

    // NOTE: `lastRd` never exceeds c.mgr.CurrentTxnID(); whether validating
    // only up to it is sufficient still has to be verified.
    txn.lastRd = c.mgr.CurrentTxnID()
    return fetched, nil
}

// Write stages obj under id in the current transaction's local write set.
//
// The id is marked in the approximate write set, the previous state is obtained
// through Read (so the id also enters the read set when it comes from the
// store), and the two are reconciled with the store's Merge before the result
// is kept locally. Nothing reaches the store until Commit.
//
// Returns an error when no transaction is in progress, or when Read or Merge
// fails; in the latter two cases the transaction is rolled back and the
// underlying error is returned.
func (c *Conn) Write(id uint64, obj interface{}) error {
    if c.txn == nil {
        return fmt.Errorf("no transaction in progress")
    }

    // hash `id` into the approximate write set
    idx, pos := c.hashIndex(id)
    c.txn.hashWr[idx] |= (1 << pos)

    // reconcile previous & incoming state and write into local space
    objPrev, err := c.Read(id)
    if err != nil {
        // Read has already rolled the transaction back; rolling back again
        // would replace the real cause with "no transaction in progress".
        return err
    }

    merged, err := c.mgr.store.Merge(objPrev, obj)
    if err != nil {
        return c.rollback(err)
    }

    c.txn.w.Set(id, merged)
    return nil
}

// Commit runs the `validation` and `write` stages of OCC for the connection's
// transaction. Everything below happens while holding the manager's latch:
//
//  1. the transaction is rejected if it spans `retention` or more commits,
//     because the log keeps only `retention` slots;
//  2. its read set is validated against the log entries of every transaction
//     committed since Begin (up to the manager's current txn id, not merely up
//     to the last read - a write committed after the last read can still
//     invalidate a value this transaction is about to overwrite);
//  3. the write set is handed to the store's Write;
//  4. the approximate write set is stored in the log slot for this commit and
//     the manager's txn id is incremented.
//
// On any failure the transaction is rolled back before anything is written and
// the error is returned. On success the connection no longer has a transaction.
func (c *Conn) Commit() error {
    if c.txn == nil {
        return fmt.Errorf("no transaction in progress")
    }

    // critical section; hold a latch for the duration of the call...
    c.mgr.mut.Lock()
    defer c.mgr.mut.Unlock()

    tstart, tend := c.txn.tstart, c.mgr.CurrentTxnID()
    c.txn.tend = tend

    // Must be decided BEFORE touching the store. A span of exactly `retention`
    // maps start and end onto the same log position, so it is rejected as well
    // rather than risk committing without validation.
    if (tend - tstart) >= c.mgr.retention {
        ancientErr := fmt.Errorf("transaction older than max retention")
        return c.rollback(ancientErr)
    }

    if err := c.validateTransaction(tstart, tend); err != nil {
        return c.rollback(err)
    }

    // NOTE: we leave it to the user to implement an atomic `store.Write` method
    // else we _do not_ have snapshot isolation!
    if err := c.mgr.store.Write(c.txn.w.Values()); err != nil {
        return c.rollback(err)
    }

    // write committed txn into log by overwriting corresponding slot...
    slotSz := c.mgr.approxBits / 64
    posStart := (tend % c.mgr.retention) * slotSz
    copy(c.mgr.txnLog[posStart:posStart+slotSz], c.txn.hashWr)

    c.mgr.id.Add(1)
    c.txn = nil
    return nil
}

// validateTransaction checks the current transaction's approximate read set
// against the logged write sets of the transactions with ids tstart..tend-1,
// i.e. those committed since the transaction began. The slot of id `x` lives
// at position (x % retention) in the log, so the walk wraps around naturally.
//
// Returns nil when tend <= tstart (nothing committed in between) or when no
// logged write set overlaps the read set. Returns an error on the first overlap,
// or when the span exceeds `retention`: those log slots have been overwritten
// and the transaction can no longer be validated.
func (c *Conn) validateTransaction(tstart, tend uint64) error {
    if tend <= tstart {
        return nil
    }

    retention := c.mgr.retention
    if tend-tstart > retention {
        return fmt.Errorf("transaction older than max retention")
    }

    slotSz := c.mgr.approxBits / 64
    // Walk ids rather than raw positions: a span of exactly `retention` then
    // visits every slot once instead of collapsing to an empty range.
    for id := tstart; id != tend; id++ {
        pos := (id % retention) * slotSz
        if setOverlap(c.mgr.txnLog[pos:pos+slotSz], c.txn.hashRd) {
            return fmt.Errorf("transaction failed validation phase")
        }
    }
    return nil
}

// Rollback discards the connection's in-progress transaction. It returns an
// error if there is no transaction to discard.
func (c *Conn) Rollback() error {
    err := c.rollback(nil)
    return err
}

// rollback drops the connection's transaction and hands back err, the reason
// for the rollback (nil for a user-requested rollback).
//
// When no transaction is active, a non-nil err is still returned unchanged so
// the original cause is never replaced; only a rollback with no cause and no
// transaction reports "no transaction in progress".
func (c *Conn) rollback(err error) error {
    if c.txn == nil {
        if err != nil {
            return err
        }
        return fmt.Errorf("no transaction in progress")
    }
    c.txn = nil
    return err
}

// hashIndex hashes an object id into one of `approxBits` equivalence classes.
// For performance reasons the class is returned as (word index, bit index) into
// a slice of uint64 with len `approxBits/64` instead of an index into a slice
// of bool with len `approxBits`.
//
// The id is fed through a fresh 64-bit FNV hash on every call. Note that
// hash.Hash.Sum(b) only appends the current digest to b - it does not hash b -
// so the bytes have to go through Write. A per-call hasher also keeps
// connections from sharing mutable hasher state.
func (c *Conn) hashIndex(o uint64) (uint64, uint64) {
    var b [8]byte
    binary.LittleEndian.PutUint64(b[:], o)

    h := fnv.New64()
    _, _ = h.Write(b[:]) // hash.Hash.Write never returns an error

    u := h.Sum64() % c.mgr.approxBits
    return u / 64, u % 64
}

// setOverlap reports whether any bit set in `a` is also set at the same
// position in `b`, i.e. whether an item hashed into `a` is (possibly) in `b`
// too. This is more or less a bloom filter with a single hash function; it is a
// plain word-by-word scan without SIMD, working on 64 bits at a time.
func setOverlap(a, b []uint64) bool {
    for i, word := range a {
        if word&b[i] != 0 {
            return true
        }
    }
    return false
}