"""Inspect the prompt embeddings stored in a WebDataset tar archive.

Iterates over every sample in ``output.tar``, prints the keys present in the
sample, deserializes the ``prompt_embeds.pt`` entry with ``torch.load`` and
prints the shape of the resulting object.
"""

import io

import torch
import webdataset as wds

ds = wds.WebDataset("output.tar")

for s in ds:
    print(s.keys())

    # The entry is wrapped in BytesIO below, so it is expected to be the raw
    # serialized bytes of the ``.pt`` file (no ``.decode()`` stage is applied
    # to the dataset) -- TODO confirm against the script that wrote the tar.
    prompt_embeds_bytes = s["prompt_embeds.pt"]

    # SECURITY NOTE(review): torch.load unpickles its input, which can execute
    # arbitrary code. Only run this on archives you trust; on torch versions
    # that support it, consider passing ``weights_only=True``.
    #
    # map_location="cpu" lets embeddings that were saved from a GPU be loaded
    # on a machine without CUDA; only the shape is inspected here, so the
    # device the tensor lands on does not matter.
    prompt_embeddings = torch.load(
        io.BytesIO(prompt_embeds_bytes),
        map_location="cpu",
    )
    print(prompt_embeddings.shape)