|
import os |
|
from typing import Any, Dict |
|
from PIL import Image |
|
from huggingface_inference_toolkit.logging import logger |
|
from pymongo.mongo_client import MongoClient |
|
from diffusers.utils import load_image |
|
import numpy as np |
|
import pandas as pd |
|
import time |
|
from dataclasses import dataclass |
|
from pathlib import Path |
|
from typing import Optional |
|
|
|
import numpy as np |
|
import pandas as pd |
|
import timm |
|
import torch |
|
from huggingface_hub import hf_hub_download |
|
from huggingface_hub.utils import HfHubHTTPError |
|
from PIL import Image |
|
from simple_parsing import field |
|
from timm.data import create_transform, resolve_data_config |
|
from torch import Tensor, nn |
|
from torch.nn import functional as F |
|
|
|
# Hugging Face access token taken from the environment; empty string when unset.
HF_TOKEN = os.getenv("HF_TOKEN", "")

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
torch_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Short model key -> Hugging Face Hub repository id of the tagger checkpoint.
MODEL_REPO_MAP = {"vit": "SmilingWolf/wd-vit-large-tagger-v3"}
|
|
|
|
|
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return *image* as an opaque RGB image, flattening transparency onto white.

    Args:
        image: Source image in any PIL mode.

    Returns:
        The input itself when it is already ``RGB``; otherwise a converted
        image in ``RGB`` mode. Transparent pixels are composited over a white
        background.
    """
    if image.mode not in ("RGB", "RGBA"):
        # Transparency can be carried either by an alpha band ("LA", "PA") or
        # by a "transparency" entry in image.info. Checking only the info entry
        # would convert LA/PA images straight to RGB and skip the white
        # compositing below.
        has_alpha = image.mode in ("LA", "PA") or "transparency" in image.info
        image = image.convert("RGBA" if has_alpha else "RGB")

    if image.mode == "RGBA":
        # Composite over an explicitly opaque white canvas, then drop alpha.
        canvas = Image.new("RGBA", image.size, (255, 255, 255, 255))
        canvas.alpha_composite(image)
        image = canvas.convert("RGB")

    return image
|
|
|
|
|
def pil_pad_square(image: Image.Image) -> Image.Image:
    """Centre *image* on a white square whose side equals its longer edge.

    Args:
        image: Image to pad.

    Returns:
        A new ``RGB`` image of size ``(side, side)`` with *image* pasted in
        the middle and white filling the remaining area.
    """
    width, height = image.size
    side = max(width, height)

    # Top-left corner at which the original must sit to end up centred.
    offset = ((side - width) // 2, (side - height) // 2)

    square = Image.new("RGB", (side, side), (255, 255, 255))
    square.paste(image, offset)
    return square
|
|
|
|
|
@dataclass
class LabelData:
    """Tag names plus the row indices of each tag category.

    ``load_labels_hf`` fills this from ``selected_tags.csv``: ``names`` comes
    from the ``name`` column, and each index list holds the row positions
    whose ``category`` column equals a given value.
    """

    # Every tag name, in CSV row order.
    names: list[str]
    # Row indices of rating tags (category 9).
    rating: list[np.int64]
    # Row indices of general tags (category 0).
    general: list[np.int64]
    # Row indices of character tags (category 4).
    character: list[np.int64]
|
|
|
|
|
def load_labels_hf(
    repo_id: str,
    revision: Optional[str] = None,
    token: Optional[str] = None,
) -> LabelData:
    """Download ``selected_tags.csv`` from a Hub repo and build a ``LabelData``.

    Args:
        repo_id: Hugging Face Hub repository that hosts the CSV.
        revision: Optional revision forwarded to ``hf_hub_download``.
        token: Optional access token forwarded to ``hf_hub_download``.

    Returns:
        A ``LabelData`` with all tag names and the row indices of the rating
        (category 9), general (category 0) and character (category 4) tags.

    Raises:
        FileNotFoundError: If the download fails with ``HfHubHTTPError``.
    """
    try:
        downloaded = hf_hub_download(
            repo_id=repo_id,
            filename="selected_tags.csv",
            revision=revision,
            token=token,
        )
        csv_path = Path(downloaded).resolve()
    except HfHubHTTPError as e:
        raise FileNotFoundError(f"selected_tags.csv failed to download from {repo_id}") from e

    df: pd.DataFrame = pd.read_csv(csv_path, usecols=["name", "category"])
    categories = df["category"]

    def rows_in(category_id: int) -> list:
        # Positional row indices whose category equals ``category_id``.
        return list(np.where(categories == category_id)[0])

    return LabelData(
        names=df["name"].tolist(),
        rating=rows_in(9),
        general=rows_in(0),
        character=rows_in(4),
    )
|
|
|
|
|
def get_tags(
    probs: Tensor,
    labels: "LabelData",
    gen_threshold: float,
    char_threshold: float,
) -> tuple[str, str, dict, dict, dict]:
    """Turn per-tag probabilities into captions and per-category score dicts.

    Args:
        probs: 1-D tensor of probabilities, indexed the same way as
            ``labels.names``.
        labels: Tag names and the index lists of each category.
        gen_threshold: General tags are kept only when their score is
            strictly greater than this value.
        char_threshold: Character tags are kept only when their score is
            strictly greater than this value.

    Returns:
        A tuple ``(caption, taglist, rating_labels, char_labels, gen_labels)``:

        * ``caption``: kept general tags followed by kept character tags,
          each group sorted by descending score, joined with ``", "``.
        * ``taglist``: ``caption`` with underscores replaced by spaces and
          parentheses backslash-escaped.
        * ``rating_labels``: every rating tag mapped to its score (no
          threshold applied).
        * ``char_labels`` / ``gen_labels``: kept tags mapped to their scores,
          in descending score order.
    """
    # detach()/cpu() make the conversion safe for tensors that track
    # gradients or live on an accelerator; both are no-ops for plain CPU
    # tensors, so existing callers see the same values.
    scores = probs.detach().cpu().numpy()
    names = labels.names

    def ranked(indices, threshold: float) -> dict:
        # Keep tags scoring above the threshold, highest score first.
        kept = {names[i]: scores[i] for i in indices if scores[i] > threshold}
        return dict(sorted(kept.items(), key=lambda item: item[1], reverse=True))

    # Ratings are reported in full, without thresholding or sorting.
    rating_labels = {names[i]: scores[i] for i in labels.rating}
    gen_labels = ranked(labels.general, gen_threshold)
    char_labels = ranked(labels.character, char_threshold)

    caption = ", ".join([*gen_labels, *char_labels])

    # Backslashes are written as "\\" so the literals are valid escape
    # sequences; the resulting text is still "\(" and "\)".
    taglist = caption.replace("_", " ").replace("(", "\\(").replace(")", "\\)")

    return caption, taglist, rating_labels, char_labels, gen_labels
|
|
|
|
|
@dataclass
class ScriptOptions:
    """Options for the tagger: input image, model key and score thresholds."""

    # Path of the image to tag; declared as a positional option, no default.
    image_file: Path = field(positional=True)
    # Key into MODEL_REPO_MAP selecting the checkpoint.
    model: str = field(default="vit")
    # Threshold passed to get_tags for general tags.
    gen_threshold: float = field(default=0.35)
    # Threshold passed to get_tags for character tags.
    char_threshold: float = field(default=0.75)
|
|
|
class EndpointHandler:
    """Endpoint handler that tags stored images and writes the tags to MongoDB.

    On construction it loads the tagger model, its labels and preprocessing
    transform, and opens a MongoDB connection. Each call processes a slice of
    the ``imagerequests`` documents that have no ``keywords`` field yet and
    stores the predicted tag scores on each of them.
    """

    # Image URL used when a document has no ``createdImage`` field.
    _DEFAULT_IMAGE_URL = "https://nomorecopyright.com/default.jpg"

    def __init__(self, path=""):
        """Load the model and labels and connect to MongoDB.

        Args:
            path: Accepted for interface compatibility; not used here.

        Raises:
            ValueError: If the configured model key is not in MODEL_REPO_MAP.
        """
        # The class itself (not an instance) serves as a namespace of
        # defaults: ``image_file`` has no default, so it cannot be
        # instantiated without one. This relies on the dataclass exposing
        # field defaults as class attributes.
        self.opts = ScriptOptions

        repo_id = MODEL_REPO_MAP.get(self.opts.model)
        if repo_id is None:
            raise ValueError(
                f"Unknown model {self.opts.model!r}; expected one of {sorted(MODEL_REPO_MAP)}"
            )

        self.model: nn.Module = timm.create_model("hf-hub:" + repo_id).eval()
        state_dict = timm.models.load_state_dict_from_hf(repo_id)
        self.model.load_state_dict(state_dict)

        self.labels: LabelData = load_labels_hf(repo_id=repo_id)

        self.transform = create_transform(
            **resolve_data_config(self.model.pretrained_cfg, model=self.model)
        )

        # Moving to the device it already lives on is a no-op.
        self.model = self.model.to(torch_device)

        uri = os.environ.get("MongoDB", "")
        self.client = MongoClient(uri)

        self.db = self.client["nomorecopyright"]
        self.collection = self.db["imagerequests"]

        # Documents that have not been tagged yet.
        self.query = {"keywords": {"$exists": False}}
        # NOTE: deliberately not passed to find() in __call__. It excludes
        # ``_id``, which the per-document update needs to locate the document.
        self.projection = {"_id": 0, "createdImage": 1}

    @staticmethod
    def _parse_range(prompt: str) -> tuple[int, int]:
        """Parse ``"start_index,limit_count"`` into a pair of ints."""
        parts = prompt.split(",")
        if len(parts) != 2:
            raise ValueError(
                'Expected `inputs` of the form "start_index,limit_count" '
                f"(for example \"0,100\"), got {prompt!r}."
            )
        try:
            return int(parts[0]), int(parts[1])
        except ValueError as e:
            raise ValueError(
                'Both values in "start_index,limit_count" must be integers, '
                f"got {prompt!r}."
            ) from e

    def _tag_image(self, image: "Image.Image") -> Dict[str, float]:
        """Run the tagger on one image and return ``{tag: score}`` as floats."""
        square = pil_pad_square(pil_ensure_rgb(image))

        batch: Tensor = self.transform(square).unsqueeze(0)
        # Reverse the channel order (RGB -> BGR); presumably the checkpoint
        # expects BGR input. TODO confirm against the model card.
        batch = batch[:, [2, 1, 0]]

        with torch.inference_mode():
            logits = self.model(batch.to(torch_device))
            probs = torch.sigmoid(logits).cpu()

        _, _, ratings, character, general = get_tags(
            probs=probs.squeeze(0),
            labels=self.labels,
            gen_threshold=self.opts.gen_threshold,
            char_threshold=self.opts.char_threshold,
        )

        # Convert the scores to plain Python floats before they are stored.
        merged = {**ratings, **character, **general}
        return {name: float(score) for name, score in merged.items()}

    def __call__(self, data: Dict[str, Any]) -> str:
        """Tag a slice of untagged documents and store the results.

        Args:
            data: Request body. ``data["inputs"]`` must be a string of the
                form ``"start_index,limit_count"``; the key is popped from
                ``data``.

        Returns:
            The string ``'OK'`` once the slice has been processed. Failures
            on individual documents are logged and skipped.

        Raises:
            ValueError: If ``inputs`` is missing, not a string, or not two
                comma-separated integers.
        """
        logger.info(f"Received incoming request with {data=}")

        if "inputs" in data and isinstance(data["inputs"], str):
            prompt = data.pop("inputs")
        else:
            raise ValueError(
                "Provided input body must contain the key `inputs` with a string of the"
                ' form "start_index,limit_count" (for example "0,100").'
            )

        start_index, limit_count = self._parse_range(prompt)
        logger.info(f"Start index: {start_index}, Limit count: {limit_count}")

        # Materialise the slice up front: the loop below updates documents
        # in the same collection this query reads from.
        # NOTE: per the PyMongo docs a limit of 0 means "no limit".
        documents = list(
            self.collection.find(self.query).skip(start_index).limit(limit_count)
        )

        started = time.perf_counter()
        for document in documents:
            # Best-effort per document: one bad image must not abort the batch.
            try:
                image = load_image(document.get("createdImage", self._DEFAULT_IMAGE_URL))
                keywords = self._tag_image(image)
                self.collection.update_one(
                    {"_id": document.get("_id")},
                    {"$set": {"keywords": keywords}},
                )
            except Exception as e:
                logger.error(f"Error processing image: {e}")

        elapsed = time.perf_counter() - started
        logger.info(f"Time taken: {elapsed:.2f} seconds")
        return "OK"