stable-diffusion-tpu / components /main /hooks /useInputGeneration.ts
Esteves Enzo
unused log
d2a707e
raw
history blame
2.85 kB
import { useId, useState } from "react"
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query"
import { useLocalStorage } from 'react-use';
import { Collection } from "@/utils/type"
import list_styles from "@/assets/list_styles.json"
import { useCollection } from "@/components/modal/useCollection";
/**
 * POST a generation request to `/api` and return the parsed JSON body.
 *
 * A rejected fetch (e.g. network failure) or a body that is not valid JSON is
 * converted into `{ ok: false, message }` — the same shape the caller already
 * treats as an API-level failure — so the caller always gets a body to inspect
 * instead of an exception that would leave its loading placeholder behind.
 */
const requestGeneration = async (inputs: string, negative_prompt: string) => {
  try {
    const response = await fetch("/api", {
      method: "POST",
      body: JSON.stringify({ inputs, negative_prompt }),
    })
    return await response.json()
  } catch (error) {
    return {
      ok: false,
      message: error instanceof Error ? error.message : String(error),
    }
  }
}

/**
 * State and actions for the image-generation input.
 *
 * - `prompt` / `setPrompt`: prompt text kept in the react-query cache under ["prompt"].
 * - `style` / `setStyle`: name of the selected entry from `list_styles`.
 * - `hasMadeFirstGeneration`: flag kept in the cache under ["firstGenerationDone"],
 *   set to true on the first `submit`.
 * - `submit` / `loading`: `submit` inserts a loading placeholder at the front of
 *   the ["collections"] cache, POSTs the styled prompt to `/api`, then replaces the
 *   placeholder with the returned blob, or keeps it and attaches the error message.
 *   On success the new id is appended to the "my-own-generations" local-storage
 *   list and passed to `setOpen`.
 */
export const useInputGeneration = () => {
  const client = useQueryClient()
  const { setOpen } = useCollection();
  // NOTE(review): entries appear to be generation ids (data.blob.id); typed loosely as before.
  const [myGenerationsId, setGenerationsId] = useLocalStorage<any>('my-own-generations', []);
  const [style, setStyle] = useState<string>(list_styles[0].name)

  // The query function only supplies the empty default; automatic refetches are
  // disabled so the value written by setPrompt is not replaced.
  const { data: prompt } = useQuery(["prompt"], () => {
    return ''
  }, {
    refetchOnWindowFocus: false,
    refetchOnMount: false,
    refetchOnReconnect: false,
    initialData: ''
  })
  const setPrompt = (str: string) => client.setQueryData(["prompt"], () => str)

  // Same cache-as-state pattern for the "first generation done" flag.
  const { data: hasMadeFirstGeneration } = useQuery(["firstGenerationDone"], () => {
    return false
  }, {
    refetchOnWindowFocus: false,
    refetchOnMount: false,
    refetchOnReconnect: false,
    initialData: false
  })
  const setFirstGenerationDone = () => client.setQueryData(["firstGenerationDone"], () => true)

  // Replace the ["collections"] entry whose id matches. If the entry is no longer
  // present the cache is left untouched (assigning to index -1 would otherwise add
  // a stray "-1" property to the array instead of updating an item).
  const updateCollectionItem = (id: string, update: (item: Collection) => Collection) => {
    client.setQueryData(["collections"], (old) => {
      const items = [...((old as Collection[] | undefined) ?? [])]
      const index = items.findIndex((item: Collection) => item.id === id)
      if (index === -1) return old
      items[index] = update(items[index] as Collection)
      return items
    })
  }

  const { mutate: submit, isLoading: loading } = useMutation(
    ["generation"],
    async () => {
      // Random client-side id used to find the placeholder again once the request settles.
      const id = Math.random().toString(36).substring(2, 15) + Math.random().toString(36).substring(2, 15)
      if (!hasMadeFirstGeneration) setFirstGenerationDone()

      // Show a loading placeholder at the front of the collection right away.
      client.setQueryData(["collections"], (old) => {
        return [{
          id,
          loading: true,
          blob: {
            type: "image/png",
            data: new ArrayBuffer(0),
          },
          prompt
        }, ...((old as Collection[] | undefined) ?? [])]
      })

      const findStyle = list_styles.find((item) => item.name === style)
      const data = await requestGeneration(
        findStyle?.prompt.replace("{prompt}", prompt) ?? prompt,
        findStyle?.negative_prompt ?? "",
      )

      // Success: swap the placeholder for the returned item.
      // Failure: keep the placeholder and attach the error message.
      updateCollectionItem(id, (item) => !data.ok ? {
        ...item,
        error: data.message
      } : data?.blob as Collection)
      if (!data.ok) return null

      setGenerationsId(myGenerationsId?.length ? [...myGenerationsId, data?.blob?.id] : [data?.blob?.id])
      setOpen(data?.blob?.id)
      return data
    }
  )

  return {
    prompt,
    setPrompt,
    loading,
    submit,
    hasMadeFirstGeneration,
    list_styles,
    style,
    setStyle
  }
}