|
package net |
|
|
|
import ( |
|
"fmt" |
|
"io" |
|
"math" |
|
"mime/multipart" |
|
"net/http" |
|
"net/textproto" |
|
"strings" |
|
"time" |
|
|
|
"github.com/alist-org/alist/v3/pkg/http_range" |
|
log "github.com/sirupsen/logrus" |
|
) |
|
|
|
|
|
|
|
|
|
// scanETag extracts the first entity-tag (RFC 7232 section 2.3) from s after
// trimming surrounding whitespace. A weak tag keeps its "W/" prefix.
//
// It returns the tag itself (quotes included) and the text that follows it.
// If s does not start with a well-formed entity-tag, both results are empty.
func scanETag(s string) (etag string, remain string) {
	s = textproto.TrimString(s)
	begin := 0
	if strings.HasPrefix(s, "W/") {
		begin = len("W/")
	}
	// An entity-tag needs at least an opening and a closing quote.
	if len(s)-begin < 2 || s[begin] != '"' {
		return "", ""
	}
	// Walk the etagc characters (%x21 / %x23-7E / obs-text) until the
	// closing quote.
	for pos := begin + 1; pos < len(s); pos++ {
		ch := s[pos]
		if ch == '"' {
			return s[:pos+1], s[pos+1:]
		}
		isETagChar := ch == 0x21 || (ch >= 0x23 && ch <= 0x7E) || ch >= 0x80
		if !isETagChar {
			return "", ""
		}
	}
	// No closing quote was found.
	return "", ""
}
|
|
|
|
|
|
|
// etagStrongMatch reports whether a and b match under the strong comparison
// of RFC 7232 section 2.3.2. The tags must be identical, non-empty, and not
// weak: a strong tag starts with a quote rather than "W/".
func etagStrongMatch(a, b string) bool {
	if a == "" || a != b {
		return false
	}
	return a[0] == '"'
}
|
|
|
|
|
|
|
// etagWeakMatch reports whether a and b match under the weak comparison of
// RFC 7232 section 2.3.2. Any "W/" weakness prefix is ignored on both sides.
func etagWeakMatch(a, b string) bool {
	opaqueA := strings.TrimPrefix(a, "W/")
	opaqueB := strings.TrimPrefix(b, "W/")
	return opaqueA == opaqueB
}
|
|
|
|
|
|
|
// condResult is the outcome of evaluating one conditional request header.
type condResult int

// Possible condResult values.
const (
	// condNone: the header is absent, or the check does not apply/parse.
	condNone condResult = 0
	// condTrue: the condition holds.
	condTrue condResult = 1
	// condFalse: the condition does not hold.
	condFalse condResult = 2
)
|
|
|
func checkIfMatch(w http.ResponseWriter, r *http.Request) condResult { |
|
im := r.Header.Get("If-Match") |
|
if im == "" { |
|
return condNone |
|
} |
|
for { |
|
im = textproto.TrimString(im) |
|
if len(im) == 0 { |
|
break |
|
} |
|
if im[0] == ',' { |
|
im = im[1:] |
|
continue |
|
} |
|
if im[0] == '*' { |
|
return condTrue |
|
} |
|
etag, remain := scanETag(im) |
|
if etag == "" { |
|
break |
|
} |
|
if etagStrongMatch(etag, w.Header().Get("Etag")) { |
|
return condTrue |
|
} |
|
im = remain |
|
} |
|
|
|
return condFalse |
|
} |
|
|
|
func checkIfUnmodifiedSince(r *http.Request, modtime time.Time) condResult { |
|
ius := r.Header.Get("If-Unmodified-Since") |
|
if ius == "" || isZeroTime(modtime) { |
|
return condNone |
|
} |
|
t, err := http.ParseTime(ius) |
|
if err != nil { |
|
return condNone |
|
} |
|
|
|
|
|
|
|
modtime = modtime.Truncate(time.Second) |
|
if ret := modtime.Compare(t); ret <= 0 { |
|
return condTrue |
|
} |
|
return condFalse |
|
} |
|
|
|
func checkIfNoneMatch(w http.ResponseWriter, r *http.Request) condResult { |
|
inm := r.Header.Get("If-None-Match") |
|
if inm == "" { |
|
return condNone |
|
} |
|
buf := inm |
|
for { |
|
buf = textproto.TrimString(buf) |
|
if len(buf) == 0 { |
|
break |
|
} |
|
if buf[0] == ',' { |
|
buf = buf[1:] |
|
continue |
|
} |
|
if buf[0] == '*' { |
|
return condFalse |
|
} |
|
etag, remain := scanETag(buf) |
|
if etag == "" { |
|
break |
|
} |
|
if etagWeakMatch(etag, w.Header().Get("Etag")) { |
|
return condFalse |
|
} |
|
buf = remain |
|
} |
|
return condTrue |
|
} |
|
|
|
func checkIfModifiedSince(r *http.Request, modtime time.Time) condResult { |
|
if r.Method != "GET" && r.Method != "HEAD" { |
|
return condNone |
|
} |
|
ims := r.Header.Get("If-Modified-Since") |
|
if ims == "" || isZeroTime(modtime) { |
|
return condNone |
|
} |
|
t, err := http.ParseTime(ims) |
|
if err != nil { |
|
return condNone |
|
} |
|
|
|
|
|
modtime = modtime.Truncate(time.Second) |
|
if ret := modtime.Compare(t); ret <= 0 { |
|
return condFalse |
|
} |
|
return condTrue |
|
} |
|
|
|
func checkIfRange(w http.ResponseWriter, r *http.Request, modtime time.Time) condResult { |
|
if r.Method != "GET" && r.Method != "HEAD" { |
|
return condNone |
|
} |
|
ir := r.Header.Get("If-Range") |
|
if ir == "" { |
|
return condNone |
|
} |
|
etag, _ := scanETag(ir) |
|
if etag != "" { |
|
if etagStrongMatch(etag, w.Header().Get("Etag")) { |
|
return condTrue |
|
} |
|
return condFalse |
|
} |
|
|
|
|
|
if modtime.IsZero() { |
|
return condFalse |
|
} |
|
t, err := http.ParseTime(ir) |
|
if err != nil { |
|
return condFalse |
|
} |
|
if t.Unix() == modtime.Unix() { |
|
return condTrue |
|
} |
|
return condFalse |
|
} |
|
|
|
// unixEpochTime is the instant whose Unix time is 0 (1970-01-01T00:00:00Z).
var unixEpochTime = time.Unix(0, 0)

// isZeroTime reports whether t should be treated as an unknown modification
// time. That is the case for the zero time.Time and for the Unix epoch.
// Callers in this file skip Last-Modified and the date-based precondition
// checks for such values.
func isZeroTime(t time.Time) bool {
	if t.IsZero() {
		return true
	}
	return t.Equal(unixEpochTime)
}
|
|
|
func setLastModified(w http.ResponseWriter, modtime time.Time) { |
|
if !isZeroTime(modtime) { |
|
w.Header().Set("Last-Modified", modtime.UTC().Format(http.TimeFormat)) |
|
} |
|
} |
|
|
|
// writeNotModified sends a 304 Not Modified response.
//
// Following RFC 7232 section 4.1, it first strips representation metadata
// (Content-Type, Content-Length, Content-Encoding). Last-Modified is kept
// only when no Etag is present, because it is then the sole validator.
func writeNotModified(w http.ResponseWriter) {
	hdr := w.Header()
	for _, key := range [...]string{"Content-Type", "Content-Length", "Content-Encoding"} {
		delete(hdr, key)
	}
	if len(hdr.Get("Etag")) > 0 {
		delete(hdr, "Last-Modified")
	}
	w.WriteHeader(http.StatusNotModified)
}
|
|
|
|
|
|
|
func checkPreconditions(w http.ResponseWriter, r *http.Request, modtime time.Time) (done bool, rangeHeader string) { |
|
|
|
ch := checkIfMatch(w, r) |
|
if ch == condNone { |
|
ch = checkIfUnmodifiedSince(r, modtime) |
|
} |
|
if ch == condFalse { |
|
w.WriteHeader(http.StatusPreconditionFailed) |
|
return true, "" |
|
} |
|
switch checkIfNoneMatch(w, r) { |
|
case condFalse: |
|
if r.Method == "GET" || r.Method == "HEAD" { |
|
writeNotModified(w) |
|
return true, "" |
|
} |
|
w.WriteHeader(http.StatusPreconditionFailed) |
|
return true, "" |
|
case condNone: |
|
if checkIfModifiedSince(r, modtime) == condFalse { |
|
writeNotModified(w) |
|
return true, "" |
|
} |
|
} |
|
|
|
rangeHeader = r.Header.Get("Range") |
|
if rangeHeader != "" && checkIfRange(w, r, modtime) == condFalse { |
|
rangeHeader = "" |
|
} |
|
return false, rangeHeader |
|
} |
|
|
|
func sumRangesSize(ranges []http_range.Range) (size int64) { |
|
for _, ra := range ranges { |
|
size += ra.Length |
|
} |
|
return |
|
} |
|
|
|
|
|
// countingWriter is an io.Writer that discards its input and only keeps a
// running total of the bytes written to it.
type countingWriter int64

// Write adds len(p) to the running total and always succeeds.
func (w *countingWriter) Write(p []byte) (n int, err error) {
	n = len(p)
	*w += countingWriter(n)
	return n, nil
}
|
|
|
|
|
|
|
func rangesMIMESize(ranges []http_range.Range, contentType string, contentSize int64) (encSize int64, err error) { |
|
var w countingWriter |
|
mw := multipart.NewWriter(&w) |
|
for _, ra := range ranges { |
|
_, err := mw.CreatePart(ra.MimeHeader(contentType, contentSize)) |
|
if err != nil { |
|
return 0, err |
|
} |
|
encSize += ra.Length |
|
} |
|
err = mw.Close() |
|
if err != nil { |
|
return 0, err |
|
} |
|
encSize += int64(w) |
|
return encSize, nil |
|
} |
|
|
|
|
|
// LimitedReadCloser wraps an io.ReadCloser and reports io.EOF once
// `remaining` bytes have been read through it. Close is forwarded to the
// wrapped ReadCloser.
type LimitedReadCloser struct {
	rc        io.ReadCloser
	remaining int
}

// Read reads up to len(buf) bytes, but never more than the bytes still
// allowed by the limit. Once the limit is exhausted it returns 0, io.EOF.
// Data and errors from the wrapped reader are passed through unchanged.
func (l *LimitedReadCloser) Read(buf []byte) (int, error) {
	budget := l.remaining
	if budget <= 0 {
		return 0, io.EOF
	}
	if budget < len(buf) {
		buf = buf[:budget]
	}
	n, err := l.rc.Read(buf)
	l.remaining -= n
	return n, err
}

// Close closes the wrapped ReadCloser.
func (l *LimitedReadCloser) Close() error {
	return l.rc.Close()
}
|
|
|
|
|
|
|
func GetRangedHttpReader(readCloser io.ReadCloser, offset, length int64) (io.ReadCloser, error) { |
|
var length_int int |
|
if length > math.MaxInt { |
|
return nil, fmt.Errorf("doesnot support length bigger than int32 max ") |
|
} |
|
length_int = int(length) |
|
|
|
if offset > 100*1024*1024 { |
|
log.Warnf("offset is more than 100MB, if loading data from internet, high-latency and wasting of bandwidth is expected") |
|
} |
|
|
|
if _, err := io.Copy(io.Discard, io.LimitReader(readCloser, offset)); err != nil { |
|
return nil, err |
|
} |
|
|
|
|
|
return &LimitedReadCloser{readCloser, length_int}, nil |
|
} |
|
|