File size: 3,249 Bytes
651d019
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
package middleware

import (
	"crypto/subtle"
	"errors"

	"github.com/dave-gray101/v2keyauth"
	"github.com/gofiber/fiber/v2"
	"github.com/gofiber/fiber/v2/middleware/keyauth"
	"github.com/microcosm-cc/bluemonday"
	"github.com/mudler/LocalAI/core/config"
)

// This file contains the configuration generator and handler functions used with the fiber keyauth middleware.
// They currently depend on an upstream patch, and feature patches are no longer accepted for fiber v2,
// so `dave-gray101/v2keyauth` carries a v2 backport of the middleware until fiber v3 stabilizes and we migrate.

// GetKeyAuthConfig builds the v2keyauth middleware configuration.
//
// The API key is looked up in the Authorization, x-api-key and xi-api-key
// request headers. Request filtering, key validation and error handling are
// delegated to the helpers in this file, all driven by applicationConfig.
// An error is returned only if the multi-source key lookup cannot be built.
func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (*v2keyauth.Config, error) {
	keySources := []string{"header:Authorization", "header:x-api-key", "header:xi-api-key"}

	lookup, err := v2keyauth.MultipleKeySourceLookup(keySources, keyauth.ConfigDefault.AuthScheme)
	if err != nil {
		return nil, err
	}

	cfg := &v2keyauth.Config{
		AuthScheme:      "Bearer",
		CustomKeyLookup: lookup,
		Next:            getApiKeyRequiredFilterFunction(applicationConfig),
		Validator:       getApiKeyValidationFunction(applicationConfig),
		ErrorHandler:    getApiKeyErrorHandler(applicationConfig),
	}
	return cfg, nil
}

// getApiKeyErrorHandler returns the ErrorHandler used by the keyauth middleware.
//
// For v2keyauth.ErrMissingOrMalformedAPIKey:
//   - with no API keys configured, the request is passed on via ctx.Next();
//   - otherwise it is rejected with 403, with the sanitized error text as the
//     body unless OpaqueErrors is set (then the body is empty).
//
// Any other error becomes a bare 500 when OpaqueErrors is set, and is
// returned unchanged otherwise.
func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) fiber.ErrorHandler {
	return func(ctx *fiber.Ctx, err error) error {
		if !errors.Is(err, v2keyauth.ErrMissingOrMalformedAPIKey) {
			if applicationConfig.OpaqueErrors {
				return ctx.SendStatus(fiber.StatusInternalServerError)
			}
			return err
		}

		switch {
		case len(applicationConfig.ApiKeys) == 0:
			// Without configured keys a missing key is expected, not an error.
			return ctx.Next()
		case applicationConfig.OpaqueErrors:
			return ctx.SendStatus(fiber.StatusForbidden)
		default:
			msg := bluemonday.StrictPolicy().Sanitize(err.Error())
			return ctx.Status(fiber.StatusForbidden).SendString(msg)
		}
	}
}

// getApiKeyValidationFunction returns the Validator used by the keyauth middleware.
//
// The returned function accepts every key when applicationConfig.ApiKeys is
// empty. Otherwise it returns (true, nil) if apiKey equals one of the
// configured keys, and (false, v2keyauth.ErrMissingOrMalformedAPIKey) if not.
// ApiKeys is read on every call. UseSubtleKeyComparison is read once, here,
// to choose between a constant-time comparison and plain string equality.
func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx, string) (bool, error) {

	if applicationConfig.UseSubtleKeyComparison {
		return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
			if len(applicationConfig.ApiKeys) == 0 {
				return true, nil // If no keys are setup, accept everything
			}
			// Check every configured key without returning early. Response time
			// then does not reveal whether, or how early in the list, a key
			// matched. subtle.ConstantTimeCompare still returns immediately on a
			// length mismatch, so key length is not hidden.
			candidate := []byte(apiKey)
			matched := 0
			for _, validKey := range applicationConfig.ApiKeys {
				matched |= subtle.ConstantTimeCompare(candidate, []byte(validKey))
			}
			if matched == 1 {
				return true, nil
			}
			return false, v2keyauth.ErrMissingOrMalformedAPIKey
		}
	}

	return func(ctx *fiber.Ctx, apiKey string) (bool, error) {
		if len(applicationConfig.ApiKeys) == 0 {
			return true, nil // If no keys are setup, accept everything
		}
		for _, validKey := range applicationConfig.ApiKeys {
			if apiKey == validKey {
				return true, nil
			}
		}
		return false, v2keyauth.ErrMissingOrMalformedAPIKey
	}
}

// getApiKeyRequiredFilterFunction returns the Next filter for the keyauth
// middleware. When the filter returns true, the key check is skipped for that
// request.
//
// Requests are exempt only when DisableApiKeyRequirementForHttpGet is set, the
// method is GET, and the path matches one of HttpGetExemptedEndpoints. That
// flag is read once, when the filter is built. The endpoint list is read on
// every request.
func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) func(*fiber.Ctx) bool {
	if !applicationConfig.DisableApiKeyRequirementForHttpGet {
		// Exemptions are disabled: every request goes through the key check.
		return func(c *fiber.Ctx) bool { return false }
	}

	return func(c *fiber.Ctx) bool {
		if c.Method() == fiber.MethodGet {
			for _, exempt := range applicationConfig.HttpGetExemptedEndpoints {
				if exempt.MatchString(c.Path()) {
					return true
				}
			}
		}
		return false
	}
}