|
package net

import (
    "bytes"
    "context"
    "fmt"
    "io"
    "math"
    "net/http"
    "strconv"
    "strings"
    "sync"
    "time"

    "github.com/alist-org/alist/v3/pkg/http_range"
    "github.com/aws/aws-sdk-go/aws/awsutil"
    log "github.com/sirupsen/logrus"
)
|
// DefaultDownloadPartSize is the default number of bytes requested from the
// remote source per part when using Download().
const DefaultDownloadPartSize = 1024 * 1024 * 10

// DefaultDownloadConcurrency is the default number of goroutines spun up when
// using Download().
const DefaultDownloadConcurrency = 2

// DefaultPartBodyMaxRetries is the default number of retries made when a part
// body fails to download.
const DefaultPartBodyMaxRetries = 3
|
// Downloader bundles the configuration used to download a remote object in
// multiple concurrent parts.
type Downloader struct {
    // PartSize is the size, in bytes, requested from the remote source per part.
    PartSize int

    // PartBodyMaxRetries is the number of retries made when reading a part body fails.
    PartBodyMaxRetries int

    // Concurrency is the number of goroutines spun up to download parts in parallel.
    Concurrency int

    // HttpClient performs the ranged HTTP request for each part.
    HttpClient HttpRequestFunc
}

// HttpRequestFunc issues a single (usually ranged) HTTP request described by params.
type HttpRequestFunc func(ctx context.Context, params *HttpRequestParams) (*http.Response, error)

|
// NewDownloader constructs a Downloader with sensible defaults and applies the
// supplied option functions, which may override any of the exported fields.
func NewDownloader(options ...func(*Downloader)) *Downloader {
    d := &Downloader{
        HttpClient:         DefaultHttpRequestFunc,
        PartSize:           DefaultDownloadPartSize,
        PartBodyMaxRetries: DefaultPartBodyMaxRetries,
        Concurrency:        DefaultDownloadConcurrency,
    }
    for _, option := range options {
        option(d)
    }
    return d
}

|
// Download issues the request described by p and returns a reader over the
// requested byte range. When more than one chunk is needed, the range is split
// into PartSize chunks fetched by Concurrency goroutines.
func (d Downloader) Download(ctx context.Context, p *HttpRequestParams) (readCloser io.ReadCloser, err error) {
    // Work on a copy so the caller's params are never mutated.
    var finalP HttpRequestParams
    awsutil.Copy(&finalP, p)
    if finalP.Range.Length == -1 {
        // A length of -1 means "until the end of the file".
        finalP.Range.Length = finalP.Size - finalP.Range.Start
    }
    impl := downloader{params: &finalP, cfg: d, ctx: ctx}

    impl.partBodyMaxRetries = d.PartBodyMaxRetries

    // Fall back to the defaults for any zero-valued settings.
    if impl.cfg.Concurrency == 0 {
        impl.cfg.Concurrency = DefaultDownloadConcurrency
    }
    if impl.cfg.PartSize == 0 {
        impl.cfg.PartSize = DefaultDownloadPartSize
    }

    return impl.download()
} |
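// A minimal usage sketch (illustrative only; the URL, header, and size values
// below are placeholder assumptions, not something this package provides):
//
//    d := NewDownloader(func(d *Downloader) {
//        d.Concurrency = 4            // four parts in flight
//        d.PartSize = 4 * 1024 * 1024 // 4 MiB per part
//    })
//    rc, err := d.Download(ctx, &HttpRequestParams{
//        URL:       "https://example.com/big.bin",
//        Range:     http_range.Range{Start: 0, Length: -1}, // -1 = to the end
//        HeaderRef: http.Header{},
//        Size:      remoteSize,
//    })
//    if err != nil {
//        return err
//    }
//    defer rc.Close()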
|
|
|
|
|
// downloader holds the per-call state behind Downloader.Download.
type downloader struct {
    ctx    context.Context
    cancel context.CancelFunc
    cfg    Downloader

    params       *HttpRequestParams
    chunkChannel chan chunk

    // m guards written and err, which are touched from multiple goroutines.
    m sync.Mutex

    nextChunk int
    chunks    []chunk
    bufs      []*Buf

    written int64
    err     error

    partBodyMaxRetries int
}
|
// download splits the requested range into chunks and spins up the worker
// goroutines. It returns a reader that yields the chunks in order.
func (d *downloader) download() (io.ReadCloser, error) {
    d.ctx, d.cancel = context.WithCancel(d.ctx)

    // Pre-compute every chunk of the requested range.
    pos := d.params.Range.Start
    maxPos := d.params.Range.Start + d.params.Range.Length
    id := 0
    for pos < maxPos {
        finalSize := int64(d.cfg.PartSize)
        if pos+finalSize > maxPos {
            finalSize = maxPos - pos
        }
        c := chunk{start: pos, size: finalSize, id: id}
        d.chunks = append(d.chunks, c)
        pos += finalSize
        id++
    }
    // Never run more workers than there are chunks.
    if len(d.chunks) < d.cfg.Concurrency {
        d.cfg.Concurrency = len(d.chunks)
    }

    // With a single chunk there is nothing to coordinate; stream the response directly.
    if d.cfg.Concurrency == 1 {
        resp, err := d.cfg.HttpClient(d.ctx, d.params)
        if err != nil {
            return nil, err
        }
        return resp.Body, nil
    }

    d.chunkChannel = make(chan chunk, d.cfg.Concurrency)

    // One reusable buffer and one worker goroutine per concurrency slot.
    for i := 0; i < d.cfg.Concurrency; i++ {
        buf := NewBuf(d.ctx, d.cfg.PartSize, i)
        d.bufs = append(d.bufs, buf)
        go d.downloadPart()
    }

    // Seed the workers with the first batch of chunk tasks.
    for i := 0; i < d.cfg.Concurrency; i++ {
        d.sendChunkTask()
    }

    var rc io.ReadCloser = NewMultiReadCloser(d.chunks[0].buf, d.interrupt, d.finishBuf)

    return rc, d.err
} |
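// How the pieces fit together (descriptive note, not executable):
//
//    chunks:  [0][1][2][3]...  one entry per PartSize slice of the range
//    bufs:    [0][1]           Concurrency reusable buffers; chunk i writes into bufs[i%Concurrency]
//    workers: downloadPart() x Concurrency, fed via chunkChannel
//    reader:  MultiReadCloser walks the chunks in order, calling finishBuf to
//             recycle a buffer and queue the next chunk once one is drained.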
|
// sendChunkTask resets the buffer for the next pending chunk and queues that
// chunk for a worker goroutine.
func (d *downloader) sendChunkTask() *chunk {
    ch := &d.chunks[d.nextChunk]
    ch.buf = d.getBuf(d.nextChunk)
    ch.buf.Reset(int(ch.size))
    d.chunkChannel <- *ch
    d.nextChunk++
    return ch
}
|
// interrupt cancels the download. It is used as the closer of the returned
// MultiReadCloser and is also invoked when the remote size check fails.
func (d *downloader) interrupt() error {
    d.cancel()
    if d.written != d.params.Range.Length {
        log.Debugf("Downloader interrupt before finish")
        if d.getErr() == nil {
            d.setErr(fmt.Errorf("interrupted"))
        }
    }
    defer func() {
        close(d.chunkChannel)
        for _, buf := range d.bufs {
            buf.Close()
        }
    }()
    return d.err
}
|
// getBuf maps a chunk id onto one of the Concurrency reusable buffers.
func (d *downloader) getBuf(id int) (b *Buf) {
    return d.bufs[id%d.cfg.Concurrency]
}
|
// finishBuf is called by the reader once chunk id has been fully consumed. It
// reports whether that was the last chunk and, if not, hands back the buffer
// holding (or about to hold) the next chunk.
func (d *downloader) finishBuf(id int) (isLast bool, buf *Buf) {
    if id >= len(d.chunks)-1 {
        return true, nil
    }
    if d.nextChunk > id+1 {
        // The next chunk was already queued; reuse its buffer.
        return false, d.getBuf(id + 1)
    }
    ch := d.sendChunkTask()
    return false, ch.buf
}
|
// downloadPart is a single worker goroutine. It keeps pulling chunk tasks off
// chunkChannel until the channel is closed, skipping work once an error has
// been recorded.
func (d *downloader) downloadPart() {
    for {
        c, ok := <-d.chunkChannel
        if !ok {
            break
        }
        if d.getErr() != nil {
            // Drain remaining tasks without downloading once something failed.
            continue
        }
        log.Debugf("downloadPart tried to get chunk")
        if err := d.downloadChunk(&c); err != nil {
            d.setErr(err)
        }
    }
}
|
// downloadChunk fetches one chunk, retrying body-read failures up to
// partBodyMaxRetries times before giving up.
func (d *downloader) downloadChunk(ch *chunk) error {
    log.Debugf("start new chunk %+v buffer_id=%d", ch, ch.id)
    var n int64
    var err error
    params := d.getParamsFromChunk(ch)
    for retry := 0; retry <= d.partBodyMaxRetries; retry++ {
        if d.getErr() != nil {
            return d.getErr()
        }
        n, err = d.tryDownloadChunk(params, ch)
        if err == nil {
            break
        }

        // Only errors that occurred while reading the body are retried; anything
        // else (e.g. building or sending the request) is returned immediately.
        if bodyErr, ok := err.(*errReadingBody); ok {
            err = bodyErr.Unwrap()
        } else {
            return err
        }

        log.Debugf("object part body download interrupted %s, err: %v, retrying attempt %d",
            params.URL, err, retry)
    }

    d.incrWritten(n)
    log.Debugf("down_%d downloaded chunk", ch.id)

    return err
}
|
// tryDownloadChunk performs a single attempt at downloading ch into its buffer
// and reports how many bytes were copied.
func (d *downloader) tryDownloadChunk(params *HttpRequestParams, ch *chunk) (int64, error) {
    resp, err := d.cfg.HttpClient(d.ctx, params)
    if err != nil {
        return 0, err
    }
    defer resp.Body.Close()

    // The first chunk doubles as a sanity check that the remote size still
    // matches what the caller expects.
    if ch.id == 0 {
        err = d.checkTotalBytes(resp)
        if err != nil {
            return 0, err
        }
    }

    n, err := io.Copy(ch.buf, resp.Body)
    if err != nil {
        return n, &errReadingBody{err: err}
    }
    if n != ch.size {
        err = fmt.Errorf("chunk download size incorrect, expected=%d, got=%d", ch.size, n)
        return n, &errReadingBody{err: err}
    }

    return n, nil
}
|
// getParamsFromChunk clones the request parameters and narrows the range to
// the chunk being downloaded.
func (d *downloader) getParamsFromChunk(ch *chunk) *HttpRequestParams {
    var params HttpRequestParams
    awsutil.Copy(&params, d.params)

    params.Range = http_range.Range{Start: ch.start, Length: ch.size}
    return &params
}
|
// checkTotalBytes verifies that the size reported by the remote (via
// Content-Range, or Content-Length when no range header is returned) matches
// the size the caller supplied. A mismatch aborts the whole download.
func (d *downloader) checkTotalBytes(resp *http.Response) error {
    var err error
    var totalBytes int64 = math.MinInt64
    contentRange := resp.Header.Get("Content-Range")
    if len(contentRange) == 0 {
        // No Content-Range header; fall back to Content-Length.
        if resp.ContentLength > 0 {
            totalBytes = resp.ContentLength
        }
    } else {
        // Content-Range looks like "bytes 0-9/100"; the total size follows the slash.
        parts := strings.Split(contentRange, "/")

        total := int64(-1)

        totalStr := parts[len(parts)-1]
        if totalStr != "*" {
            total, err = strconv.ParseInt(totalStr, 10, 64)
            if err != nil {
                err = fmt.Errorf("failed extracting file size")
            }
        } else {
            err = fmt.Errorf("file size unknown")
        }

        totalBytes = total
    }
    if totalBytes != d.params.Size && err == nil {
        err = fmt.Errorf("expected file size=%d does not match remote reported size=%d, need to refresh cache", d.params.Size, totalBytes)
    }
    if err != nil {
        _ = d.interrupt()
        d.setErr(err)
    }
    return err
}
|
func (d *downloader) incrWritten(n int64) {
    d.m.Lock()
    defer d.m.Unlock()
    d.written += n
}

// getErr returns the first error recorded by any goroutine, if any.
func (d *downloader) getErr() error {
    d.m.Lock()
    defer d.m.Unlock()
    return d.err
}

// setErr records an error so the other goroutines stop doing useless work.
func (d *downloader) setErr(e error) {
    d.m.Lock()
    defer d.m.Unlock()
    d.err = e
}
|
// chunk describes one contiguous slice of the requested range and the buffer
// it is downloaded into.
type chunk struct {
    start int64
    size  int64
    buf   *Buf
    id    int
}
|
// DefaultHttpRequestFunc is the HttpRequestFunc used when none is supplied: it
// applies the range to the referenced header and issues a plain GET.
func DefaultHttpRequestFunc(ctx context.Context, params *HttpRequestParams) (*http.Response, error) {
    header := http_range.ApplyRangeToHttpHeader(params.Range, params.HeaderRef)

    res, err := RequestHttp(ctx, "GET", header, params.URL)
    if err != nil {
        return nil, err
    }
    return res, nil
} |
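// A custom HttpRequestFunc can be plugged in via NewDownloader. Illustrative
// sketch only; the header name and token variable below are hypothetical and
// not provided by this package:
//
//    authed := func(ctx context.Context, p *HttpRequestParams) (*http.Response, error) {
//        if p.HeaderRef == nil {
//            p.HeaderRef = http.Header{}
//        }
//        p.HeaderRef.Set("Authorization", "Bearer "+token)
//        return DefaultHttpRequestFunc(ctx, p)
//    }
//    d := NewDownloader(func(d *Downloader) { d.HttpClient = authed })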
|
|
|
// HttpRequestParams describes a single remote object request.
type HttpRequestParams struct {
    URL string

    // Range is the byte range to request; a Length of -1 means "to the end".
    Range     http_range.Range
    HeaderRef http.Header

    // Size is the total size of the remote object, used to validate the size
    // reported by the server.
    Size int64
}
|
// errReadingBody wraps errors that happened while reading a part body, which
// are the only errors downloadChunk will retry.
type errReadingBody struct {
    err error
}

func (e *errReadingBody) Error() string {
    return fmt.Sprintf("failed to read part body: %v", e.err)
}

func (e *errReadingBody) Unwrap() error {
    return e.err
}
|
// MultiReadCloser stitches the per-chunk buffers together into one ordered
// stream for the caller.
type MultiReadCloser struct {
    cfg    *cfg
    closer closerFunc
    finish finishBufFUnc
}

// cfg tracks which chunk the reader is currently draining.
type cfg struct {
    rPos   int
    curBuf *Buf
}

type closerFunc func() error
type finishBufFUnc func(id int) (isLast bool, buf *Buf)

// NewMultiReadCloser builds a reader starting at buf. c is invoked on Close,
// and fb whenever the current buffer has been fully read.
func NewMultiReadCloser(buf *Buf, c closerFunc, fb finishBufFUnc) *MultiReadCloser {
    return &MultiReadCloser{closer: c, finish: fb, cfg: &cfg{curBuf: buf}}
}
|
func (mr MultiReadCloser) Read(p []byte) (n int, err error) {
    if mr.cfg.curBuf == nil {
        return 0, io.EOF
    }
    n, err = mr.cfg.curBuf.Read(p)
    if err == io.EOF {
        log.Debugf("read_%d finished current buffer", mr.cfg.rPos)

        // Ask for the next buffer; io.EOF from the last chunk ends the stream.
        isLast, next := mr.finish(mr.cfg.rPos)
        if isLast {
            return n, io.EOF
        }
        mr.cfg.curBuf = next
        mr.cfg.rPos++
        return n, nil
    }
    return n, err
}

func (mr MultiReadCloser) Close() error {
    return mr.closer()
}
|
// Buf is a reusable in-memory buffer shared between one writer goroutine (the
// part downloader) and one reader (the MultiReadCloser). Reads past the data
// written so far poll until more data arrives or the context is cancelled.
type Buf struct {
    buffer *bytes.Buffer
    size   int // expected size of the current chunk
    ctx    context.Context
    off    int // bytes handed to the reader so far
    rw     sync.Mutex
}
|
// NewBuf creates a Buf with capacity maxSize. The id parameter is currently unused.
func NewBuf(ctx context.Context, maxSize int, id int) *Buf {
    d := make([]byte, 0, maxSize)
    return &Buf{
        ctx:    ctx,
        buffer: bytes.NewBuffer(d),
        size:   maxSize,
    }
}

// Reset empties the buffer and sets the expected size of the next chunk.
func (br *Buf) Reset(size int) {
    br.buffer.Reset()
    br.size = size
    br.off = 0
}
|
// Read returns data already written into the buffer. When the writer has not
// caught up yet, it waits briefly and returns (0, nil) so the caller retries,
// unless the context has been cancelled.
func (br *Buf) Read(p []byte) (n int, err error) {
    if err := br.ctx.Err(); err != nil {
        return 0, err
    }
    if len(p) == 0 {
        return 0, nil
    }
    if br.off >= br.size {
        // The whole chunk has been consumed.
        return 0, io.EOF
    }
    br.rw.Lock()
    n, err = br.buffer.Read(p)
    br.rw.Unlock()
    if err == nil {
        br.off += n
        return n, err
    }
    if err != io.EOF {
        return n, err
    }
    if n != 0 {
        br.off += n
        return n, nil
    }

    // Nothing buffered yet: wait for the writer or for cancellation.
    select {
    case <-br.ctx.Done():
        return 0, br.ctx.Err()
    case <-time.After(time.Millisecond * 200):
        return 0, nil
    }
} |
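// Typical interplay (descriptive note): the worker goroutine streams the HTTP
// body into the Buf with io.Copy(ch.buf, resp.Body) while the MultiReadCloser
// drains it from the other side; Read's 200ms poll is what bridges the two
// when the reader momentarily outpaces the writer.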
|
|
|
// Write appends p to the buffer under the lock shared with Read.
func (br *Buf) Write(p []byte) (n int, err error) {
    if err := br.ctx.Err(); err != nil {
        return 0, err
    }
    br.rw.Lock()
    defer br.rw.Unlock()
    return br.buffer.Write(p)
}
|
// Close is currently a no-op; buffers are simply dropped once the downloader
// is done with them.
func (br *Buf) Close() {
}
|
|