File size: 4,143 Bytes
d77c5aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23f0311
 
d77c5aa
 
 
 
 
 
23f0311
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from huggingface_hub import upload_file, create_repo
import gradio as gr
import os
import requests
import tempfile
import requests
import re


# Markdown rendered below the form (passed as `article=` to gr.Interface in
# get_gradio_demo). Tells users where to find the CivitAI download link and
# when a CivitAI API key has to be supplied.
article = """
Some things to note:

* To obtain the download link of the state dict hosted on CivitAI, right click on the "Download" button, visible on the model page.
* If the creator of the state dict requires the users to login to CivitAI first, that means downloading the state dict will require the CivitAI API keys.
* To obtain the API key of your CivitAI account, head to https://civitai.com/user/account and scroll to "API Keys".
* For such state dicts, it is mandatory to pass `civitai_api_key`.
* If you are getting "429 Client Error: Too Many Requests for url" error, retry passing your CivitAI API key.

"""

def _filename_from_response(response, url):
    """Return a safe local file name for a checkpoint download.

    Prefers the ``Content-Disposition`` header -- the RFC 6266 ``filename*``
    form first, then the plain ``filename`` form -- and falls back to the last
    path segment of ``url`` with any query string dropped. Directory components
    are stripped so a server-supplied name cannot make the write escape the
    temporary download directory.

    Args:
        response: The HTTP response whose ``headers`` are inspected.
        url: The URL that was requested, used for the fallback name.

    Returns:
        A non-empty bare file name (no directory part).
    """
    from urllib.parse import unquote, urlparse

    header = response.headers.get("Content-Disposition") or ""
    candidate = ""

    extended = re.search(r"filename\*\s*=\s*(?:[\w-]+'[^']*')?([^;]+)", header, re.IGNORECASE)
    if extended:
        # e.g. filename*=UTF-8''my%20model.safetensors -> "my model.safetensors"
        candidate = unquote(extended.group(1).strip().strip('"'))
    else:
        # Quoted value, or an unquoted value that ends at the next ";" parameter.
        plain = re.search(r'filename\s*=\s*(?:"([^"]*)"|([^;]+))', header, re.IGNORECASE)
        if plain:
            candidate = (plain.group(1) or plain.group(2) or "").strip()

    if not candidate:
        # e.g. ".../models/1432115?type=Model" -> "1432115" (query string dropped)
        candidate = unquote(urlparse(url).path.rsplit("/", 1)[-1])

    # Keep only the final path component; normalise Windows separators first.
    candidate = os.path.basename(candidate.replace("\\", "/")).strip()
    if candidate in ("", ".", ".."):
        candidate = "checkpoint"
    return candidate


def _push_to_hub(local_path, filename, repo_id, hf_token):
    """Create ``repo_id`` if needed, upload ``local_path`` as ``filename``, and return a Markdown status message."""
    try:
        repo_id = create_repo(repo_id=repo_id, exist_ok=True, token=hf_token).repo_id
    except Exception as e:  # UI boundary: report any Hub failure to the user.
        return f"Error happened during repo creation: {e}"

    try:
        commit_info = upload_file(
            repo_id=repo_id, path_or_fileobj=local_path, path_in_repo=filename, token=hf_token
        )
    except Exception as e:  # UI boundary: report any Hub failure to the user.
        return f"Error happened during committing the file: {e}"

    return f"Pushed the checkpoint here: [{commit_info._url}]({commit_info._url})"


def download_locally_and_upload_to_hub(civit_url, repo_id, hf_token=None, civitai_api_key=None):
    """Download a checkpoint from CivitAI and upload it to a Hugging Face Hub repo.

    The file is streamed into a temporary directory, which is removed
    afterwards. The target repo is created if it does not exist, and the file
    is uploaded under the name given by the download's ``Content-Disposition``
    header, or by the URL path when that header is missing.

    Args:
        civit_url: Direct download URL of the CivitAI checkpoint.
        repo_id: Hub repo ID to upload to (e.g. ``"user/repo"``). Required.
        hf_token: Hugging Face token. Blank values are passed on as ``None``.
        civitai_api_key: CivitAI API key, sent as a Bearer token. Blank values
            mean an anonymous download.

    Returns:
        A Markdown status message: a link to the uploaded file on success, or
        a description of the step that failed.

    Raises:
        requests.RequestException: if the download fails (HTTP error status,
            connection error, or timeout).
    """
    # Gradio passes "" for empty fields, and pasted credentials often carry
    # stray whitespace/newlines (requests rejects header values with newlines).
    repo_id = (repo_id or "").strip()
    if not repo_id:
        # Nothing to upload to: bail out before spending time on the download.
        return "Error: please provide a repo ID to upload the checkpoint to."
    civit_url = (civit_url or "").strip()
    hf_token = (hf_token or "").strip() or None
    civitai_api_key = (civitai_api_key or "").strip() or None

    headers = None
    if civitai_api_key:
        headers = {
            "Authorization": f"Bearer {civitai_api_key}",
            "Accept": "application/json",
        }

    with tempfile.TemporaryDirectory() as tmp_dir:
        # Context-manage the streamed response so the connection is released even
        # if writing fails; (connect, read) timeouts stop a stalled server from
        # hanging the worker indefinitely.
        with requests.get(civit_url, headers=headers, stream=True, timeout=(30, 300)) as response:
            response.raise_for_status()
            filename = _filename_from_response(response, civit_url)
            local_path = os.path.join(tmp_dir, filename)
            with open(local_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=1024 * 1024):
                    if chunk:  # filter out keep-alive chunks
                        file.write(chunk)
        # Upload while the temporary directory (and thus the file) still exists.
        return _push_to_hub(local_path, filename, repo_id, hf_token)


def get_gradio_demo():
    """Build the Gradio interface around ``download_locally_and_upload_to_hub``.

    The four inputs map positionally onto that function's parameters
    (``civit_url``, ``repo_id``, ``hf_token``, ``civitai_api_key``), and its
    return value is rendered as Markdown.

    Returns:
        The configured ``gr.Interface``. It is not launched here.
    """
    demo = gr.Interface(
        title="Upload CivitAI checkpoints to the HF Hub 🤗",
        article=article,
        description="**See instructions below the form.**",
        fn=download_locally_and_upload_to_hub,
        inputs=[
            gr.Textbox(lines=1, label="CivitAI download URL", info="Download URL of the CivitAI checkpoint."),
            gr.Textbox(lines=1, label="Hub repo ID", info="Repo ID for the checkpoint to upload on the Hub."),
            # Credentials use masked single-line fields: they are secrets and
            # should not be echoed on screen, and a multi-line area invites a
            # trailing newline to be pasted into the token.
            gr.Textbox(
                lines=1,
                type="password",
                label="HF token",
                info="Your HF token (needs write access). Generate one from https://huggingface.co/settings/tokens.",
            ),
            gr.Textbox(
                lines=1,
                type="password",
                label="CivitAI API key",
                info="CivitAI API key, needed only if the checkpoint requires logging in. Leave empty otherwise.",
            ),
        ],
        outputs="markdown",
        examples=[
            ['https://civitai.com/api/download/models/1432115?type=Model&format=SafeTensor', 'sayakpaul/civitai-test', '', ''],
        ],
        cache_examples="lazy",
        # NOTE(review): newer Gradio releases deprecate `allow_flagging` in
        # favour of `flagging_mode`; switch once the pinned Gradio supports it.
        allow_flagging="never",
    )
    return demo

# Script entry point: build the interface and start the server. show_error=True
# asks Gradio to surface exceptions raised by the handler (e.g. a failed
# download) in the browser instead of a generic error.
if __name__ == "__main__":
    demo = get_gradio_demo()
    demo.launch(show_error=True)