|
package tache |
|
|
|
import (
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"os"
	"path/filepath"
	"runtime"
	"sync/atomic"

	"github.com/jaevor/go-nanoid"
	"github.com/xhofe/gsync"
)
|
|
|
|
|
type Manager[T Task] struct { |
|
tasks gsync.MapOf[string, T] |
|
queue gsync.QueueOf[T] |
|
workers *WorkerPool[T] |
|
opts *Options |
|
debouncePersist func() |
|
running atomic.Bool |
|
|
|
idGenerator func() string |
|
logger *slog.Logger |
|
} |
|
|
|
|
|
func NewManager[T Task](opts ...Option) *Manager[T] { |
|
options := DefaultOptions() |
|
for _, opt := range opts { |
|
opt(options) |
|
} |
|
nanoID, err := nanoid.Standard(21) |
|
if err != nil { |
|
panic(err) |
|
} |
|
m := &Manager[T]{ |
|
workers: NewWorkerPool[T](options.Works), |
|
opts: options, |
|
idGenerator: nanoID, |
|
logger: options.Logger, |
|
} |
|
m.running.Store(options.Running) |
|
if m.opts.PersistPath != "" || (m.opts.PersistReadFunction != nil && m.opts.PersistWriteFunction != nil) { |
|
m.debouncePersist = func() { |
|
_ = m.persist() |
|
} |
|
if m.opts.PersistDebounce != nil { |
|
m.debouncePersist = newDebounce(func() { |
|
_ = m.persist() |
|
}, *m.opts.PersistDebounce) |
|
} |
|
err := m.recover() |
|
if err != nil { |
|
m.logger.Error("recover error", "error", err) |
|
} |
|
} else { |
|
m.debouncePersist = func() {} |
|
} |
|
return m |
|
} |
|
|
|
|
|
func (m *Manager[T]) Add(task T) { |
|
ctx, cancel := context.WithCancel(context.Background()) |
|
task.SetCtx(ctx) |
|
task.SetCancelFunc(cancel) |
|
task.SetPersist(m.debouncePersist) |
|
if task.GetID() == "" { |
|
task.SetID(m.idGenerator()) |
|
} |
|
if _, maxRetry := task.GetRetry(); maxRetry == 0 { |
|
task.SetRetry(0, m.opts.MaxRetry) |
|
} |
|
if sliceContains([]State{StateRunning}, task.GetState()) { |
|
task.SetState(StatePending) |
|
} |
|
if sliceContains([]State{StateCanceling}, task.GetState()) { |
|
task.SetState(StateCanceled) |
|
task.SetErr(context.Canceled) |
|
} |
|
if task.GetState() == StateFailing { |
|
task.SetState(StateFailed) |
|
} |
|
m.tasks.Store(task.GetID(), task) |
|
if !sliceContains([]State{StateSucceeded, StateCanceled, StateErrored, StateFailed}, task.GetState()) { |
|
m.queue.Push(task) |
|
} |
|
m.debouncePersist() |
|
m.next() |
|
} |
|
|
|
|
|
func (m *Manager[T]) next() { |
|
|
|
if !m.running.Load() { |
|
return |
|
} |
|
|
|
worker := m.workers.Get() |
|
if worker == nil { |
|
return |
|
} |
|
m.logger.Debug("got worker", "id", worker.ID) |
|
task, err := m.queue.Pop() |
|
|
|
if err != nil { |
|
m.workers.Put(worker) |
|
return |
|
} |
|
m.logger.Debug("got task", "id", task.GetID()) |
|
go func() { |
|
defer func() { |
|
if task.GetState() == StateWaitingRetry { |
|
m.queue.Push(task) |
|
} |
|
m.workers.Put(worker) |
|
m.next() |
|
}() |
|
if task.GetState() == StateCanceling { |
|
task.SetState(StateCanceled) |
|
task.SetErr(context.Canceled) |
|
return |
|
} |
|
if m.opts.Timeout != nil { |
|
ctx, cancel := context.WithTimeout(task.Ctx(), *m.opts.Timeout) |
|
defer cancel() |
|
task.SetCtx(ctx) |
|
} |
|
m.logger.Info("worker execute task", "worker", worker.ID, "task", task.GetID()) |
|
worker.Execute(task) |
|
}() |
|
} |
|
|
|
|
|
func (m *Manager[T]) Wait() { |
|
for { |
|
tasks, running := m.queue.Len(), m.workers.working.Load() |
|
if tasks == 0 && running == 0 { |
|
return |
|
} |
|
runtime.Gosched() |
|
} |
|
} |
|
|
|
|
|
func (m *Manager[T]) persist() error { |
|
if m.opts.PersistPath == "" && m.opts.PersistReadFunction == nil && m.opts.PersistWriteFunction == nil { |
|
return nil |
|
} |
|
|
|
tasks := m.GetAll() |
|
var toPersist []T |
|
for _, task := range tasks { |
|
|
|
if p, ok := Task(task).(Persistable); !ok || p.Persistable() { |
|
toPersist = append(toPersist, task) |
|
} |
|
} |
|
marshal, err := json.Marshal(toPersist) |
|
if err != nil { |
|
return err |
|
} |
|
if m.opts.PersistReadFunction != nil && m.opts.PersistWriteFunction != nil { |
|
err = m.opts.PersistWriteFunction(marshal) |
|
if err != nil { |
|
return err |
|
} |
|
} |
|
if m.opts.PersistPath != "" { |
|
|
|
err = os.WriteFile(m.opts.PersistPath, marshal, 0644) |
|
if err != nil { |
|
return err |
|
} |
|
} |
|
return nil |
|
} |
|
|
|
|
|
func (m *Manager[T]) recover() error { |
|
var data []byte |
|
var err error |
|
if m.opts.PersistPath != "" { |
|
|
|
data, err = os.ReadFile(m.opts.PersistPath) |
|
} else if m.opts.PersistReadFunction != nil && m.opts.PersistWriteFunction != nil { |
|
data, err = m.opts.PersistReadFunction() |
|
} else { |
|
return nil |
|
} |
|
if err != nil { |
|
return err |
|
} |
|
|
|
var tasks []T |
|
err = json.Unmarshal(data, &tasks) |
|
if err != nil { |
|
return err |
|
} |
|
|
|
for _, task := range tasks { |
|
|
|
if r, ok := Task(task).(Recoverable); !ok || r.Recoverable() { |
|
m.Add(task) |
|
} else { |
|
task.SetState(StateFailed) |
|
task.SetErr(fmt.Errorf("the task is interrupted and cannot be recovered")) |
|
m.tasks.Store(task.GetID(), task) |
|
} |
|
} |
|
return nil |
|
} |
|
|
|
|
|
func (m *Manager[T]) Cancel(id string) { |
|
if task, ok := m.tasks.Load(id); ok { |
|
task.Cancel() |
|
m.debouncePersist() |
|
} |
|
} |
|
|
|
|
|
func (m *Manager[T]) CancelAll() { |
|
m.tasks.Range(func(key string, value T) bool { |
|
value.Cancel() |
|
return true |
|
}) |
|
m.debouncePersist() |
|
} |
|
|
|
|
|
func (m *Manager[T]) GetAll() []T { |
|
var tasks []T |
|
m.tasks.Range(func(key string, value T) bool { |
|
tasks = append(tasks, value) |
|
return true |
|
}) |
|
return tasks |
|
} |
|
|
|
|
|
func (m *Manager[T]) GetByID(id string) (T, bool) { |
|
return m.tasks.Load(id) |
|
} |
|
|
|
|
|
func (m *Manager[T]) GetByState(state ...State) []T { |
|
var tasks []T |
|
m.tasks.Range(func(key string, value T) bool { |
|
if sliceContains(state, value.GetState()) { |
|
tasks = append(tasks, value) |
|
} |
|
return true |
|
}) |
|
return tasks |
|
} |
|
|
|
|
|
func (m *Manager[T]) Remove(id string) { |
|
m.tasks.Delete(id) |
|
m.debouncePersist() |
|
} |
|
|
|
|
|
func (m *Manager[T]) RemoveAll() { |
|
tasks := m.GetAll() |
|
for _, task := range tasks { |
|
m.Remove(task.GetID()) |
|
} |
|
} |
|
|
|
|
|
func (m *Manager[T]) RemoveByState(state ...State) { |
|
tasks := m.GetByState(state...) |
|
for _, task := range tasks { |
|
m.Remove(task.GetID()) |
|
} |
|
} |
|
|
|
|
|
func (m *Manager[T]) Retry(id string) { |
|
if task, ok := m.tasks.Load(id); ok { |
|
task.SetState(StateWaitingRetry) |
|
task.SetErr(nil) |
|
task.SetRetry(0, m.opts.MaxRetry) |
|
m.queue.Push(task) |
|
m.next() |
|
m.debouncePersist() |
|
} |
|
} |
|
|
|
|
|
func (m *Manager[T]) RetryAllFailed() { |
|
tasks := m.GetByState(StateFailed) |
|
for _, task := range tasks { |
|
m.Retry(task.GetID()) |
|
} |
|
} |
|
|
|
|
|
func (m *Manager[T]) Start() { |
|
m.running.Store(true) |
|
m.next() |
|
} |
|
|
|
|
|
func (m *Manager[T]) Pause() { |
|
m.running.Store(false) |
|
} |
|
|