Spaces:
Running
Running
File size: 3,357 Bytes
72f0edb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import { useState, useCallback, useEffect } from "react";
import { useMutation } from "@tanstack/react-query";
import { v4 as uuidv4 } from "uuid";
/**
 * Returns the conversation ID the backend uses to keep track of this chat's history.
 *
 * The ID is persisted in `localStorage` under `"chat_conversation_id"` so it
 * survives page reloads. A new ID is generated (via `uuidv4`) when none is
 * stored yet, or when the stored value is empty.
 *
 * Accessing `localStorage` can throw, for example when storage is blocked in a
 * sandboxed or third-party iframe or disabled by the browser. Persistence is
 * therefore best effort. If reading fails, a fresh ID is used for this page
 * session only. If writing fails, the new ID is still returned. This keeps the
 * calling component from crashing during its initial render.
 */
const getConversationId = (): string => {
  let storedId: string | null = null;
  try {
    storedId = localStorage.getItem("chat_conversation_id");
  } catch {
    // Storage unavailable: fall through and generate a session-only ID.
  }
  if (storedId) return storedId;

  const newId = uuidv4();
  try {
    localStorage.setItem("chat_conversation_id", newId);
  } catch {
    // Best effort: the ID still works for this session even if it can't be persisted.
  }
  return newId;
};
/**
 * One entry in the chat transcript.
 */
export interface ChatMessage {
  /** Author of the message. The set of values is decided by callers (not visible here). */
  sender: string;
  /** Body text of the message. */
  text: string;
  /** Optional URL of an image shown with the message. */
  imageUrl?: string;
}
/** Default backend endpoint that accepts prompts and replies with an SSE stream. */
const DEFAULT_PROMPT_ENDPOINT = "http://127.0.0.1:8000/prompt";

/**
 * Separator between Server-Sent Events messages: a blank line, with either LF
 * or CRLF line endings. It is matched against the accumulated buffer, never a
 * single chunk, so a CRLF pair split across network chunks is still detected.
 */
const SSE_EVENT_SEPARATOR = /\r?\n\r?\n/;

/**
 * Collects the `data` payload of one SSE message block.
 *
 * Each `data:` line contributes one payload line (joined with "\n"), and a
 * single space after the colon is optional. Other fields (`event:`, `id:`,
 * `retry:`) and comment lines (starting with ":") are ignored.
 *
 * @returns The payload, or `null` when the block has no `data:` lines.
 */
const extractSseData = (block: string): string | null => {
  const dataLines: string[] = [];
  for (const line of block.split(/\r?\n/)) {
    if (!line.startsWith("data:")) continue;
    const value = line.slice("data:".length);
    dataLines.push(value.startsWith(" ") ? value.slice(1) : value);
  }
  return dataLines.length > 0 ? dataLines.join("\n") : null;
};

/**
 * Parses one SSE message block as JSON and forwards its `text` and
 * `image_url` fields to the callbacks. Only non-empty string values are
 * forwarded.
 *
 * Malformed JSON, or an exception thrown by a callback, is logged and skipped,
 * so one bad event does not end the stream.
 */
const handleSseBlock = (
  block: string,
  onChunk: (chunk: string) => void,
  onImage?: (imageUrl: string) => void
): void => {
  const payload = extractSseData(block);
  if (payload === null) return;
  try {
    const data: unknown = JSON.parse(payload);
    if (typeof data !== "object" || data === null) return;
    const { text, image_url: imageUrl } = data as Record<string, unknown>;
    if (typeof text === "string" && text !== "") {
      onChunk(text);
    }
    if (typeof imageUrl === "string" && imageUrl !== "") {
      onImage?.(imageUrl);
    }
  } catch (e) {
    console.error("Error parsing SSE JSON:", e);
  }
};

/**
 * React hook that sends chat prompts to the backend and consumes the streamed
 * (Server-Sent Events) reply.
 *
 * Usage: call `sendMessage(prompt)`. When the mutation resolves with
 * `{ response, controller }`, pass `response` to `streamResponse` along with
 * callbacks that receive text chunks and image URLs as they arrive.
 *
 * @param endpoint URL the prompt is POSTed to. Defaults to the local
 *   development server at `http://127.0.0.1:8000/prompt`.
 * @returns
 *   - `sendMessage`: starts the POST request (the mutation's `mutate`).
 *   - `streamResponse`: reads an SSE response body and dispatches its events.
 *   - `isLoading`: true while the request is pending or a stream is being read.
 *   - `error`: the request error, if any. Errors raised while reading the
 *     stream are logged to the console, not surfaced here.
 */
export const useChatApi = (endpoint = DEFAULT_PROMPT_ENDPOINT) => {
  // Lazy initializer: the conversation ID is resolved once per hook instance.
  const [conversationId] = useState(getConversationId);
  const [isStreaming, setIsStreaming] = useState(false);

  // POSTs the prompt. The controller is returned alongside the response so the
  // request (and its body stream) can be aborted later.
  const chatMutation = useMutation({
    mutationFn: async (prompt: string) => {
      const controller = new AbortController();
      const response = await fetch(endpoint, {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
        },
        body: JSON.stringify({
          session_id: conversationId,
          prompt: prompt,
        }),
        signal: controller.signal,
      });
      if (!response.ok) {
        throw new Error("Network response was not ok");
      }
      return { response, controller };
    },
  });

  /**
   * Reads an SSE response body to completion. It calls `onChunk` for each text
   * fragment and `onImage` for each image URL.
   *
   * Rejects if the response has no body. Errors while reading are logged
   * rather than rethrown.
   */
  const streamResponse = useCallback(
    async (
      response: Response,
      onChunk: (chunk: string) => void,
      onImage?: (imageUrl: string) => void
    ) => {
      if (!response.body) {
        throw new Error("Response body is null");
      }
      setIsStreaming(true);
      const reader = response.body.getReader();
      const decoder = new TextDecoder();
      let buffer = "";

      // Dispatches every complete message in `buffer` and keeps the trailing,
      // possibly incomplete, remainder for the next chunk.
      const drainCompleteEvents = () => {
        const blocks = buffer.split(SSE_EVENT_SEPARATOR);
        buffer = blocks.pop() ?? "";
        for (const block of blocks) {
          handleSseBlock(block, onChunk, onImage);
        }
      };

      try {
        while (true) {
          const { done, value } = await reader.read();
          if (done) {
            break;
          }
          // `stream: true` holds back bytes of a multi-byte character split across chunks.
          buffer += decoder.decode(value, { stream: true });
          drainCompleteEvents();
        }
        // Flush any bytes the decoder held back. Then dispatch a final message
        // the server may have sent without a trailing blank line.
        buffer += decoder.decode();
        drainCompleteEvents();
        if (buffer.trim() !== "") {
          handleSseBlock(buffer, onChunk, onImage);
        }
      } catch (error) {
        console.error("Error reading stream:", error);
      } finally {
        setIsStreaming(false);
      }
    },
    []
  );

  // The cleanup runs on unmount and whenever `chatMutation.data` changes. It
  // aborts the controller from the data seen by this effect, which cancels
  // that request and its stream.
  useEffect(() => {
    const controller = chatMutation.data?.controller;
    return () => {
      controller?.abort();
    };
  }, [chatMutation.data]);

  return {
    sendMessage: chatMutation.mutate,
    streamResponse,
    isLoading: chatMutation.isPending || isStreaming,
    error: chatMutation.error,
  };
};
|