package occ
import (
"encoding/binary"
"fmt"
"hash"
"hash/fnv"
"sync"
"sync/atomic"
"github.com/tidwall/btree"
)
// Store is the user-supplied data store. It is wrapped by a Manager, which
// layers optimistic concurrency control on top of it.
type Store interface {
	// Read loads the object stored under id.
	Read(id uint64) (any, error)
	// Write persists a transaction's buffered objects. Commit passes the
	// merged values only (not their ids), and relies on this call being
	// atomic; otherwise snapshot isolation does not hold.
	Write(writeSet []any) error
	// Merge reconciles prev (the value currently visible to the transaction)
	// with next (the incoming write) and returns the value to buffer.
	Merge(prev, next any) (any, error)
}
// Manager wraps a user's Store and coordinates long-running, concurrent
// transactions against it without holding locks outside of the short critical
// section taken in Commit.
type Manager struct {
	// id counts committed transactions; Commit uses the current value as the
	// committing transaction's id and ring slot, then increments it.
	id atomic.Uint64
	// mut serializes Commit's critical section (validation, store write and
	// log update).
	mut sync.Mutex
	// txnLog is a ring of `retention` slots, each approxBits/64 words wide,
	// holding the write-set bitset of one committed transaction.
	txnLog []uint64
	// hasher is the FNV-64 hasher created by NewManager.
	hasher hash.Hash64
	// store is the wrapped user Store.
	store Store
	// approxBits is the width, in bits, of every read/write-set bitset.
	approxBits uint64
	// retention is the number of slots in txnLog.
	retention uint64
}
// NewManager returns a Manager that wraps store.
//
// approxBits is the width, in bits, of the per-transaction read/write bitsets
// used for conflict detection. Bitsets are stored as []uint64 of length
// approxBits/64, so approxBits is rounded up to a multiple of 64 (minimum 64).
// Other values would let hashIndex produce word indices past the end of those
// slices.
//
// retention is the number of committed transactions whose write sets are kept
// in the ring-buffered log. It is used as a modulus, so it is raised to at
// least 1.
func NewManager(store Store, approxBits, retention uint64) *Manager {
	const wordBits = 64
	if approxBits < wordBits {
		approxBits = wordBits
	}
	if rem := approxBits % wordBits; rem != 0 {
		approxBits += wordBits - rem
	}
	if retention == 0 {
		retention = 1
	}
	return &Manager{
		// One slot of approxBits/64 words per retained transaction.
		txnLog:     make([]uint64, approxBits/wordBits*retention),
		hasher:     fnv.New64(),
		store:      store,
		approxBits: approxBits,
		retention:  retention,
	}
}
// CurrentTxnID returns the manager's current transaction id. Within this file
// it is only advanced by Commit, by one per successful commit, so it equals
// the id the next committing transaction will take.
func (cc *Manager) CurrentTxnID() uint64 {
	current := cc.id.Load()
	return current
}
// NewConnection returns a fresh Conn bound to this Manager with no active
// transaction; call Begin on it before Read, Write or Commit.
func (cc *Manager) NewConnection() *Conn {
	conn := new(Conn)
	conn.mgr = cc
	return conn
}
// transaction is the private state of one in-flight transaction on a Conn.
type transaction struct {
	// tstart is the manager's transaction id observed at Begin.
	tstart uint64
	// tend is set by Commit to the manager's transaction id at commit time.
	tend uint64
	// w buffers this transaction's merged writes, keyed by id.
	w btree.Map[uint64, any]
	// r caches values loaded from the store so repeated reads are stable.
	r btree.Map[uint64, any]
	// lastRd is the manager's transaction id observed after the most recent
	// store read.
	lastRd uint64
	// hashWr approximates the write set as an approxBits-wide bitset.
	hashWr []uint64
	// hashRd approximates the read set as an approxBits-wide bitset.
	hashRd []uint64
}
// Conn is one session against the Manager's Store. In practice a Manager hands
// out a Conn per user / TCP connection it receives. A Conn carries at most one
// transaction at a time. Its methods mutate txn without locking, so a Conn
// should not be shared between goroutines.
type Conn struct {
	// txn is the active transaction, or nil when none is in progress.
	txn *transaction
	// mgr is the Manager that created this Conn.
	mgr *Manager
}
// Begin starts a new transaction on c. It fails if a transaction is already in
// progress. The transaction records the manager's current id as its start
// point, and its read/write bitsets are approxBits/64 words long.
func (c *Conn) Begin() error {
	if c.txn != nil {
		return fmt.Errorf("transaction in progress")
	}
	words := c.mgr.approxBits / 64
	txn := new(transaction)
	txn.tstart = c.mgr.CurrentTxnID()
	txn.hashWr = make([]uint64, words)
	txn.hashRd = make([]uint64, words)
	c.txn = txn
	return nil
}
// Read returns the value of id as seen by the active transaction.
//
// Lookup order:
//  1. this transaction's buffered writes (read-your-writes);
//  2. values this transaction already loaded (repeatable read);
//  3. the underlying Store.
//
// Only a store load marks id in the approximate read set, caches the value and
// records the manager's current id in lastRd. A store error rolls the
// transaction back and is returned.
//
// NOTE(review): lastRd is not a safe upper bound for commit-time validation.
// A transaction committing after this read can still conflict with later
// writes that merge the cached value.
func (c *Conn) Read(id uint64) (interface{}, error) {
	if c.txn == nil {
		return nil, fmt.Errorf("no transaction in progress")
	}
	if buffered, found := c.txn.w.Get(id); found {
		return buffered, nil
	}
	if cached, found := c.txn.r.Get(id); found {
		return cached, nil
	}

	word, bit := c.hashIndex(id)
	c.txn.hashRd[word] |= 1 << bit

	fetched, err := c.mgr.store.Read(id)
	if err != nil {
		return nil, c.rollback(err)
	}
	c.txn.r.Set(id, fetched)
	c.txn.lastRd = c.mgr.CurrentTxnID()
	return fetched, nil
}
// Write buffers a write of obj under id in the transaction's local space.
//
// The steps are:
//  1. Record id in the approximate write set.
//  2. Fetch the current value through Read. This also puts id in the read set
//     the first time it is loaded from the store.
//  3. Combine that value with obj using Store.Merge.
//  4. Buffer the merged value; Commit later hands it to Store.Write.
//
// On any error the transaction is rolled back and the original error is
// returned.
func (c *Conn) Write(id uint64, obj interface{}) error {
	if c.txn == nil {
		return fmt.Errorf("no transaction in progress")
	}
	// Record `id` in the approximate write set; Commit publishes this bitset
	// to the shared log.
	idx, pos := c.hashIndex(id)
	c.txn.hashWr[idx] |= (1 << pos)

	// Reconcile previous & incoming state and write into local space.
	objPrev, err := c.Read(id)
	if err != nil {
		// Read has already rolled the transaction back. Calling rollback again
		// would find no transaction and replace err with "no transaction in
		// progress".
		return err
	}
	merged, err := c.mgr.store.Merge(objPrev, obj)
	if err != nil {
		return c.rollback(err)
	}
	c.txn.w.Set(id, merged)
	return nil
}
// Commit runs the validation and write phases of OCC for the active
// transaction inside a short critical section guarded by the manager's mutex.
//
// All of the following happen under the lock:
//  1. tend is the manager's current id. Every transaction with an id in
//     [tstart, tend) committed after this one began.
//  2. If that window spans `retention` or more transactions, the transaction
//     is rejected.
//  3. The read set is validated against the logged write sets of the window.
//  4. The buffered writes are handed to Store.Write.
//  5. The write set is stored in log slot tend % retention and the id is
//     advanced.
//
// Any failure rolls the transaction back and returns the error. Failures in
// steps 2-3 happen before the store is touched.
//
// NOTE: we leave it to the user to implement an atomic Store.Write; otherwise
// we do not have snapshot isolation.
func (c *Conn) Commit() error {
	if c.txn == nil {
		return fmt.Errorf("no transaction in progress")
	}
	c.mgr.mut.Lock()
	defer c.mgr.mut.Unlock()

	// Validate up to the current id, not txn.lastRd. A transaction that commits
	// after our last store read can still have written a key we read. Our
	// buffered value for that key was merged from the stale read, so skipping
	// that transaction would allow lost updates.
	tstart, tend := c.txn.tstart, c.mgr.CurrentTxnID()
	c.txn.tend = tend

	// `>=` keeps the window strictly smaller than the ring, so its first and
	// last ring positions never coincide for a non-empty window. It is checked
	// before the store is written so a rejected transaction has no visible
	// effect (this also avoids a modulo by zero when retention is 0).
	if tend-tstart >= c.mgr.retention {
		return c.rollback(fmt.Errorf("transaction older than max retention"))
	}
	if err := c.validateTransaction(tstart, tend); err != nil {
		return c.rollback(err)
	}

	if err := c.mgr.store.Write(c.txn.w.Values()); err != nil {
		return c.rollback(err)
	}

	// Publish this transaction's write set in its ring slot, then advance the
	// id so transactions that began earlier validate against it.
	slotSz := c.mgr.approxBits / 64
	posStart := (tend % c.mgr.retention) * slotSz
	copy(c.mgr.txnLog[posStart:posStart+slotSz], c.txn.hashWr)
	c.mgr.id.Add(1)
	c.txn = nil
	return nil
}
// validateTransaction checks the active transaction's read set against the
// logged write sets of transactions with ids in [tstart, tend). Commit stores
// transaction k in ring slot k % retention. It returns an error if any of them
// may have written something this transaction read. Commit calls it while
// holding mgr.mut.
//
// The window is walked by transaction id rather than by ring position. A
// window of exactly `retention` transactions therefore checks every slot;
// a position-based scan would see start == end there and check nothing.
// Windows longer than `retention` cannot be validated because their oldest
// slots have been overwritten, so they are rejected.
func (c *Conn) validateTransaction(tstart, tend uint64) error {
	if tend <= tstart {
		// Nothing committed since this transaction began.
		return nil
	}
	retention := c.mgr.retention
	if tend-tstart > retention {
		return fmt.Errorf("transaction older than max retention")
	}
	slotSz := c.mgr.approxBits / 64
	for id := tstart; id < tend; id++ {
		pos := (id % retention) * slotSz
		if setOverlap(c.mgr.txnLog[pos:pos+slotSz], c.txn.hashRd) {
			return fmt.Errorf("transaction failed validation phase")
		}
	}
	return nil
}
// Rollback discards the active transaction and all of its buffered writes.
// It returns nil on success, or an error if no transaction is in progress.
func (c *Conn) Rollback() error {
	var noCause error
	return c.rollback(noCause)
}
// rollback discards the active transaction and returns err unchanged, so call
// sites can write `return c.rollback(err)`. If no transaction is active it
// returns a "no transaction in progress" error instead.
func (c *Conn) rollback(err error) error {
	if c.txn != nil {
		c.txn = nil
		return err
	}
	return fmt.Errorf("no transaction in progress")
}
// hashIndex hashes an id into one of `approxBits` buckets. It returns the
// bucket as (word index, bit index) into a []uint64 of length approxBits/64,
// which is cheaper than a []bool of length approxBits.
//
// The id is hashed with 64-bit FNV-1 over its 8 little-endian bytes. A fresh
// hasher is used per call, so connections never share hash state.
//
// The previous code called Sum on a hasher that was never written to. Sum
// appends the hash state to its argument, so reading the first 8 bytes just
// gave back the id itself, and the "hash" was the identity function.
func (c *Conn) hashIndex(o uint64) (uint64, uint64) {
	var b [8]byte
	binary.LittleEndian.PutUint64(b[:], o)
	h := fnv.New64()
	_, _ = h.Write(b[:]) // hash.Hash.Write never returns an error
	u := h.Sum64() % c.mgr.approxBits
	return u / 64, u % 64
}
// setOverlap reports whether bitsets a and b share any set bit, i.e. whether
// something hashed into a may also have been hashed into b. This is a
// single-hash Bloom-filter intersection: false positives are possible, false
// negatives are not. b must be at least as long as a. It checks one uint64
// word at a time rather than one bit at a time.
func setOverlap(a, b []uint64) bool {
	for i, word := range a {
		if word&b[i] != 0 {
			return true
		}
	}
	return false
}