Spaces:
Runtime error
Runtime error
File size: 4,348 Bytes
48511d8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
package proxy
import (
	"crypto/rsa"
	"crypto/x509"
	"encoding/pem"
	"io"
	"log"
	"mime"
	"net/http"

	"github.com/arpinfidel/p2p-llm/config"
	"github.com/arpinfidel/p2p-llm/db"
)
// ProxyHandler fronts an upstream target server and a set of trusted peers.
// Handle places incoming requests on a queue that worker goroutines, started
// by Run, drain by forwarding each request to one destination.
type ProxyHandler struct {
	// Dependencies supplied by NewProxyHandler.
	cfg      *config.Config
	peerRepo db.PeerRepository

	// queue hands requests from Handle to the workers.
	queue chan Request

	// Worker-pool sizing.
	maxParallelRequests     int // workers forwarding to cfg.TargetURL
	maxParallelPeerRequests int // concurrency limit meant for peer requests
}
// Request is one unit of work on the proxy queue: the client's
// ResponseWriter and request, passed from Handle to a worker started by Run.
type Request struct {
	W http.ResponseWriter
	R *http.Request

	// done is closed by the worker once it has finished with W and R; Handle
	// blocks on it. It is nil for a Request built outside Handle, which the
	// workers tolerate.
	done chan struct{}
}

// NewProxyHandler returns a ProxyHandler that forwards to cfg.TargetURL and
// to trusted peers. Nothing is forwarded until Run starts the workers.
func NewProxyHandler(cfg *config.Config, peerRepo db.PeerRepository) *ProxyHandler {
	return &ProxyHandler{
		cfg:                     cfg,
		peerRepo:                peerRepo,
		queue:                   make(chan Request, 100), // Hardcoded queue size
		maxParallelRequests:     cfg.MaxParallelRequests,
		maxParallelPeerRequests: 5, // Hardcoded peer requests limit
	}
}

// Handle queues the request and blocks until a worker has finished
// responding to it.
//
// Blocking is required: net/http forbids using the ResponseWriter after the
// handler returns, and closes the request body at that point, so the worker
// must be done with both before Handle returns. If the client goes away while
// the queue is full, Handle returns without writing anything; once the
// request is queued it always waits for the worker.
func (h *ProxyHandler) Handle(w http.ResponseWriter, r *http.Request) {
	req := Request{W: w, R: r, done: make(chan struct{})}
	select {
	case h.queue <- req:
	case <-r.Context().Done():
		return
	}
	<-req.done
}

// Run gathers the trusted peers and starts the workers that drain the queue:
// maxParallelRequests workers forwarding to cfg.TargetURL plus
// maxParallelPeerRequests workers for each peer. All workers read from the
// same queue, so a request goes to whichever worker receives it first.
//
// Run returns once the workers are started. They exit only if the queue is
// closed, which nothing here does.
func (h *ProxyHandler) Run() {
	peers := h.collectPeers()

	for range h.maxParallelRequests {
		go h.drainQueue(h.cfg.TargetURL)
	}
	for _, peer := range peers {
		for range h.maxParallelPeerRequests {
			go h.drainQueue(peer.URL)
		}
	}
}

// collectPeers returns the trusted peers from the config followed by those
// from the peer repository. Repository peers whose public key is not a
// PEM-encoded PKIX RSA key are logged and skipped; if the repository cannot
// be read, the error is logged and only the configured peers are returned.
func (h *ProxyHandler) collectPeers() []config.Peer {
	// Copy so that appending never writes into the config's backing array.
	peers := append([]config.Peer(nil), h.cfg.TrustedPeers...)

	dbPeers, err := h.peerRepo.ListTrustedPeers()
	if err != nil {
		log.Printf("Error getting peers from database: %v", err)
		return peers
	}
	for _, p := range dbPeers {
		block, _ := pem.Decode([]byte(p.PublicKey))
		if block == nil {
			log.Printf("Skipping peer %s: public key is not PEM-encoded", p.URL)
			continue
		}
		key, err := x509.ParsePKIXPublicKey(block.Bytes)
		if err != nil {
			log.Printf("Skipping peer %s: parsing public key: %v", p.URL, err)
			continue
		}
		rsaKey, ok := key.(*rsa.PublicKey)
		if !ok {
			log.Printf("Skipping peer %s: public key is %T, want *rsa.PublicKey", p.URL, key)
			continue
		}
		peers = append(peers, config.Peer{
			URL:       p.URL,
			PublicKey: rsaKey,
		})
	}
	return peers
}

// drainQueue forwards queued requests to url until the queue is closed.
func (h *ProxyHandler) drainQueue(url string) {
	for req := range h.queue {
		h.serveQueued(req, url)
	}
}

// serveQueued forwards one request unless its client has already gone away,
// then releases the Handle call waiting on it.
func (h *ProxyHandler) serveQueued(req Request, url string) {
	if req.done != nil {
		defer close(req.done)
	}
	if req.R.Context().Err() != nil {
		// The client left while the request sat in the queue.
		return
	}
	h.Forward(req, url)
}
// Forward relays req to the server at url and copies the reply to req.W.
//
// The outgoing request uses req.R's method, headers and body, and targets
// url followed by req.R's escaped path and query string. It is bound to
// req.R's context, so it is abandoned if the client disconnects.
//
// A reply whose media type is text/event-stream is relayed chunk by chunk,
// flushing after every write so events reach the client immediately; any
// other reply is copied with io.Copy. Errors that occur before the status
// line is written are reported to the client (500 for local failures, 502
// when the upstream cannot be reached); later errors can only be logged.
func (h *ProxyHandler) Forward(req Request, url string) {
	// RequestURI keeps the query string and the path in escaped form.
	target := url + req.R.URL.RequestURI()
	targetReq, err := http.NewRequestWithContext(req.R.Context(), req.R.Method, target, req.R.Body)
	if err != nil {
		log.Printf("Error creating request: %v", err)
		http.Error(req.W, "Error creating request", http.StatusInternalServerError)
		return
	}
	// NewRequestWithContext cannot infer the size of a server request body,
	// so carry the client's declared length over; otherwise the body would be
	// sent chunked with no Content-Length.
	if req.R.Body != nil && req.R.Body != http.NoBody {
		targetReq.ContentLength = req.R.ContentLength
	}

	// Copy headers from original request
	for name, values := range req.R.Header {
		for _, value := range values {
			targetReq.Header.Add(name, value)
		}
	}

	// No client-level timeout on purpose: streamed generations can run long,
	// and cancellation comes from the request context instead.
	resp, err := http.DefaultClient.Do(targetReq)
	if err != nil {
		log.Printf("Error forwarding request: %v", err)
		http.Error(req.W, "Error forwarding request", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	// Compare the media type without parameters (case-insensitively), so
	// "text/event-stream; charset=utf-8" is recognised. The error is ignored:
	// if the type itself cannot be parsed, mediaType is "" and the reply is
	// treated as non-SSE.
	mediaType, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
	isSSE := mediaType == "text/event-stream"

	// Resolve streaming support before any response header is written, so a
	// failure can still be reported with a clean error response.
	var flusher http.Flusher
	if isSSE {
		f, ok := req.W.(http.Flusher)
		if !ok {
			log.Printf("ResponseWriter does not support flushing")
			http.Error(req.W, "Streaming unsupported", http.StatusInternalServerError)
			return
		}
		flusher = f
	}

	header := req.W.Header()
	for name, values := range resp.Header {
		for _, value := range values {
			header.Add(name, value)
		}
	}
	if isSSE {
		// Must precede WriteHeader to reach the client. Connection and
		// Transfer-Encoding are left to net/http, which manages framing
		// itself; HTTP/2 does not allow those connection-specific headers.
		header.Set("Cache-Control", "no-cache")
	}
	req.W.WriteHeader(resp.StatusCode)

	if !isSSE {
		if _, err := io.Copy(req.W, resp.Body); err != nil {
			log.Printf("Error copying response body: %v", err)
		}
		return
	}

	// Send the headers now so the client sees the stream open before the
	// first event arrives.
	flusher.Flush()
	buf := make([]byte, 1024)
	for {
		n, err := resp.Body.Read(buf)
		if n > 0 {
			if _, writeErr := req.W.Write(buf[:n]); writeErr != nil {
				log.Printf("Error writing to client: %v", writeErr)
				return
			}
			flusher.Flush()
		}
		if err != nil {
			if err != io.EOF {
				log.Printf("Error reading from response body: %v", err)
			}
			return
		}
	}
}
|