Spaces:
Sleeping
Sleeping
from huggingface_hub import upload_file, create_repo | |
import gradio as gr | |
import os | |
import requests | |
import tempfile | |
import requests | |
import re | |
# Markdown help text that gr.Interface renders below the form (it is passed as
# ``article=`` in get_gradio_demo). Written as one literal per displayed line so
# each note can be edited on its own; the value begins and ends with a newline.
article = (
    '\n'
    'Some things to note:\n'
    '* To obtain the download link of the state dict hosted on CivitAI, right click on the "Download" button, visible on the model page.\n'
    '* If the creator of the state dict requires the users to login to CivitAI first, that means downloading the state dict will require the CivitAI API keys.\n'
    '* To obtain the API key of your CivitAI account, head to https://civitai.com/user/account and scroll to "API Keys".\n'
    '* For such state dicts, it is mandatory to pass `civitai_api_key`.\n'
    '* If you are getting "429 Client Error: Too Many Requests for url" error, retry passing your CivitAI API key.\n'
)
# Timeouts for the CivitAI request: (seconds to connect, seconds allowed between
# received bytes). Without a timeout a stalled server would block the worker forever.
_REQUEST_TIMEOUT = (10, 60)
# Streaming chunk size; checkpoints are large, so 8 KiB chunks are needlessly chatty.
_CHUNK_SIZE = 1024 * 1024


def _build_civitai_headers(civitai_api_key):
    """Return the CivitAI request headers, or None when no API key is given."""
    if not civitai_api_key:
        return None
    return {
        "Authorization": f"Bearer {civitai_api_key}",
        "Accept": "application/json",
    }


def _resolve_filename(content_disposition, url):
    """Choose a safe file name for the downloaded checkpoint.

    Prefers the ``filename`` parameter of the Content-Disposition header and
    falls back to the last path segment of ``url`` (query string excluded).
    The result is reduced to a bare name so it can never point outside the
    directory it is joined to; ``"checkpoint"`` is used if nothing usable is left.
    """
    from urllib.parse import unquote, urlsplit

    candidate = ""
    if content_disposition:
        # Quoted form first (filename="a b.safetensors"), then the bare token
        # form (filename=a.safetensors; size=...), which ends at the next ';'.
        match = re.search(r'filename="([^"]+)"', content_disposition) or re.search(
            r"filename=([^;]+)", content_disposition
        )
        if match:
            candidate = match.group(1).strip().strip('"')
    if not candidate:
        # e.g. ".../models/1432115?type=Model&format=SafeTensor" -> "1432115"
        candidate = unquote(urlsplit(url).path.rstrip("/").split("/")[-1])
    # The header is server-controlled: drop any directory part so values such as
    # "../../x" or "/abs/x" cannot escape the temporary directory.
    candidate = os.path.basename(candidate.replace("\\", "/"))
    if candidate in ("", ".", ".."):
        return "checkpoint"
    return candidate


def _download_checkpoint(civit_url, target_dir, civitai_api_key):
    """Stream ``civit_url`` into ``target_dir``; return ``(local_path, filename)``.

    Raises requests.HTTPError on an error status (e.g. 429 rate limiting).
    """
    headers = _build_civitai_headers(civitai_api_key)
    # `with` releases the pooled connection even if writing to disk fails.
    with requests.get(
        civit_url, headers=headers, stream=True, timeout=_REQUEST_TIMEOUT
    ) as response:
        response.raise_for_status()
        filename = _resolve_filename(
            response.headers.get("Content-Disposition"), civit_url
        )
        local_path = os.path.join(target_dir, filename)
        with open(local_path, "wb") as file:
            for chunk in response.iter_content(chunk_size=_CHUNK_SIZE):
                if chunk:  # filter out keep-alive new chunks
                    file.write(chunk)
    return local_path, filename


def _push_to_hub(local_path, filename, repo_id, hf_token):
    """Create ``repo_id`` if needed, upload the file, and return a Markdown status."""
    # Broad catches are deliberate: this is the UI boundary, and any failure is
    # reported to the user as text instead of crashing the request.
    try:
        repo_id = create_repo(repo_id=repo_id, exist_ok=True, token=hf_token).repo_id
    except Exception as e:
        return f"Error happened during repo creation: {e}"
    try:
        commit_info = upload_file(
            repo_id=repo_id,
            path_or_fileobj=local_path,
            path_in_repo=filename,
            token=hf_token,
        )
    except Exception as e:
        return f"Error happened during committing the file: {e}"
    # `_url` is a private attribute of huggingface_hub's CommitInfo (link to the
    # uploaded file); fall back to the public `commit_url` if it is unavailable.
    url = getattr(commit_info, "_url", None) or commit_info.commit_url
    return f"Pushed the checkpoint here: [{url}]({url})"


def download_locally_and_upload_to_hub(civit_url, repo_id, hf_token=None, civitai_api_key=None):
    """Download a CivitAI checkpoint and push it to a Hugging Face Hub repo.

    Args:
        civit_url: Download URL of the checkpoint on CivitAI.
        repo_id: Target Hub repo ID; created with ``exist_ok=True``.
        hf_token: Hub token. Blank values are passed on as None (presumably
            letting huggingface_hub use its configured token — confirm).
        civitai_api_key: CivitAI API key for checkpoints that require a login;
            blank means no Authorization header is sent.

    Returns:
        A Markdown string: a link to the uploaded file on success, or an error
        message for missing inputs, repo-creation failures or upload failures.

    Raises:
        requests.HTTPError: if the CivitAI download returns an error status.
        requests.RequestException: on connection errors or timeouts.
    """
    # Pasted values often carry stray whitespace/newlines; a newline inside a
    # token would otherwise produce an invalid HTTP header.
    civit_url = (civit_url or "").strip()
    repo_id = (repo_id or "").strip()
    hf_token = (hf_token or "").strip() or None
    civitai_api_key = (civitai_api_key or "").strip() or None

    if not civit_url:
        return "Please provide the download URL of the CivitAI checkpoint."
    if not repo_id:
        # Checked before downloading so a large checkpoint is not fetched for nothing.
        return "Please provide a repo ID to upload the checkpoint to."

    # The upload must run inside the `with`: the directory (and the downloaded
    # file) is deleted when the block exits.
    with tempfile.TemporaryDirectory() as tmp_dir:
        local_path, filename = _download_checkpoint(civit_url, tmp_dir, civitai_api_key)
        return _push_to_hub(local_path, filename, repo_id, hf_token)
def get_gradio_demo():
    """Build the Gradio interface around ``download_locally_and_upload_to_hub``.

    The four input components map positionally onto the handler's parameters
    (civit_url, repo_id, hf_token, civitai_api_key), and the handler's return
    value is rendered as Markdown. The example row is cached lazily, i.e. only
    when a user first selects it.

    Returns:
        The configured (not yet launched) ``gr.Interface``.
    """
    demo = gr.Interface(
        title="Upload CivitAI checkpoints to the HF Hub 🤗",
        article=article,
        description="**See instructions below the form.**",
        fn=download_locally_and_upload_to_hub,
        inputs=[
            gr.Textbox(lines=1, info="Download URL of the CivitAI checkpoint."),
            gr.Textbox(lines=1, info="Repo ID for the checkpoint to upload on the Hub."),
            # Secrets go into password fields so they are masked on screen;
            # a TextArea would display them in clear text.
            gr.Textbox(
                lines=1,
                type="password",
                info="Your HF token. Generate one from https://huggingface.co/settings/tokens.",
            ),
            gr.Textbox(
                lines=1,
                type="password",
                info="Civit API key to download the checkpoint if needed. Leave empty otherwise.",
            ),
        ],
        outputs="markdown",
        examples=[
            ['https://civitai.com/api/download/models/1432115?type=Model&format=SafeTensor', 'sayakpaul/civitai-test', '', ''],
        ],
        cache_examples="lazy",
        allow_flagging="never",
    )
    return demo
if __name__ == "__main__":
    # Build the interface and serve it when this file is executed directly.
    # Keep the module-level `demo` name: Gradio's reload mode looks it up.
    # show_error=True asks Gradio to surface exceptions raised by the handler
    # (e.g. the HTTP error from raise_for_status on a failed CivitAI download)
    # in the browser UI.
    demo = get_gradio_demo()
    demo.launch(show_error=True)