"""Template Demo for IBM Granite Hugging Face spaces.""" from collections.abc import Iterator from datetime import datetime from pathlib import Path from threading import Thread import gradio as gr import PIL import spaces import torch from PIL.Image import Image as PILImage from PIL.Image import Resampling from transformers import ( AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer, LlavaNextForConditionalGeneration, LlavaNextProcessor, TextIteratorStreamer, ) from themes.research_monochrome import theme dir_ = Path(__file__).parent.parent today_date = datetime.today().strftime("%B %-d, %Y") # noqa: DTZ002 MODEL_ID = "ibm-granite/granite-vision-3.2-2b" MODEL_ID_PREVIEW = "ibm-granite/granite-vision-3.1-2b-preview" # SYS_PROMPT = f"""Knowledge Cutoff Date: April 2024. # Today's Date: {today_date}. # You are Granite, developed by IBM. You are a helpful AI assistant""" TITLE = "IBM Granite VISION 3.1 2b preview" DESCRIPTION = "Try one of the sample prompts below or write your own. Remember, \ AI models can make mistakes." 
# Default generation settings; they also seed the "Advanced Settings" sliders.
MAX_INPUT_TOKEN_LENGTH = 4096
MAX_NEW_TOKENS = 1024
TEMPERATURE = 0.7
TOP_P = 0.85
TOP_K = 50
REPETITION_PENALTY = 1.05

# Each sample is [image location, list of example questions for that image].
sample_data = [
    [
        "https://www.ibm.com/design/language/static/159e89b3d8d6efcb5db43f543df36b23/a5df1/rebusgallery_tshirt.png",
        ["What are the three symbols on the tshirt?"],
    ],
    [
        str(dir_ / "data" / "p2-report.png"),
        [
            "What's the difference in rental income between 2020 and 2019?",
            "Which table entries are less in 2020 than 2019?",
        ],
    ],
    [
        "https://www.ibm.com/design/language/static/159e89b3d8d6efcb5db43f543df36b23/a5df1/rebusgallery_tshirt.png",
        ["What's this?"],
    ],
]

# Prefer CUDA, then Apple MPS, then CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

# Both stay None until the first request; see multimodal_generate_v2.
processor: LlavaNextProcessor = None
model: LlavaNextForConditionalGeneration = None

# NOTE(review): this is module-level state, so it is not kept per browser
# session -- confirm that is acceptable when several users share the app.
selected_image: PILImage = None


def image_changed(im: PILImage):
    """Store a downsized copy of the displayed image for later questions.

    Args:
        im: The image currently shown in the image component, or None when
            the component was cleared.
    """
    global selected_image
    if im is None:
        selected_image = None
        return
    # Work on a copy: ``thumbnail`` resizes in place.
    thumb = im.copy()
    thumb.thumbnail((800, 800))
    selected_image = thumb


def create_single_turn(image: PILImage, text: str) -> dict:
    """Build one user turn in the chat-template message format.

    Args:
        image: Image to attach to the turn, or None for a text-only turn.
        text: The user's question.

    Returns:
        dict: A ``{"role": "user", "content": [...]}`` message; the image
        entry (when present) precedes the text entry.
    """
    content = [{"type": "text", "text": text}]
    if image is not None:
        content.insert(0, {"type": "image", "image": image})
    return {"role": "user", "content": content}


def _build_generate_kwargs(temperature, repetition_penalty, top_p, top_k, max_new_tokens) -> dict:
    """Translate UI slider values into keyword arguments for ``model.generate``.

    The sliders allow 0 for temperature and repetition penalty, but
    transformers requires both to be strictly positive floats. A temperature
    of 0 (or less) therefore selects greedy decoding, and a non-positive
    repetition penalty is left out so the model default applies. Values are
    also coerced to the numeric type each option expects, since slider values
    may arrive as ints (e.g. ``1`` instead of ``1.0``).
    """
    kwargs = {"max_new_tokens": int(max_new_tokens), "num_beams": 1}

    repetition_penalty = float(repetition_penalty)
    if repetition_penalty > 0:
        kwargs["repetition_penalty"] = repetition_penalty

    temperature = float(temperature)
    if temperature > 0:
        kwargs.update(do_sample=True, temperature=temperature, top_p=float(top_p), top_k=int(top_k))
    else:
        # Sampling parameters are meaningless for greedy decoding; omit them.
        kwargs["do_sample"] = False
    return kwargs


@spaces.GPU
def generate(
    image: PILImage,
    message: str,
    chat_history: list[dict],
    temperature: float = TEMPERATURE,
    repetition_penalty: float = REPETITION_PENALTY,
    top_p: float = TOP_P,
    top_k: int = TOP_K,
    max_new_tokens: int = MAX_NEW_TOKENS,
):
    """Generate function for chat demo.

    Relies on the module-level ``processor`` and ``model`` having been loaded
    (``multimodal_generate_v2`` does this before calling).

    Args:
        image: The image to be talked about, or None for a text-only question.
        message (str): The latest input message from the user.
        chat_history (list[dict]): A list of dictionaries representing previous chat history, where each
            dictionary contains 'role' and 'content'.
        temperature: Sampling temperature; 0 or less switches to greedy decoding.
        repetition_penalty: Repetition penalty; ignored when not positive.
        top_p: Nucleus-sampling probability mass (used only when sampling).
        top_k: Top-k cutoff (used only when sampling).
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        list: Two ``gr.ChatMessage`` objects -- the user's message followed by
        the assistant's reply. ``chat_history`` is not included.
    """
    # Build messages
    # TODO: maybe add back custom sys prompt
    # conversation.append({"role": "system", "content": SYS_PROMPT})
    conversation = [*chat_history, create_single_turn(image, message)]

    # Convert messages to prompt format
    inputs = processor.apply_chat_template(
        conversation, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
    ).to(device)

    # TODO: Trim inputs longer than MAX_INPUT_TOKEN_LENGTH. Naively keeping the
    # last N input ids might cut out the image tokens -- find better strategy.

    generate_kwargs = _build_generate_kwargs(temperature, repetition_penalty, top_p, top_k, max_new_tokens)
    output = model.generate(**inputs, **generate_kwargs)

    # The decoded text contains the whole conversation; keep only what follows
    # the last assistant marker.
    decoded = processor.decode(output[0], skip_special_tokens=True)
    reply = decoded.strip().split("<|assistant|>")[-1]
    return [gr.ChatMessage(role="user", content=message), gr.ChatMessage(role="assistant", content=reply)]


def multimodal_generate_v2(
    msg: str,
    temperature: float = TEMPERATURE,
    repetition_penalty: float = REPETITION_PENALTY,
    top_p: float = TOP_P,
    top_k: float = TOP_K,
    max_new_tokens: int = MAX_NEW_TOKENS,
):
    """Answer ``msg`` about the currently selected image, loading the model on first use.

    Args:
        msg: The user's question.
        temperature: Forwarded to ``generate``.
        repetition_penalty: Forwarded to ``generate``.
        top_p: Forwarded to ``generate``.
        top_k: Forwarded to ``generate``.
        max_new_tokens: Forwarded to ``generate``.

    Returns:
        list: The chat messages returned by ``generate`` (called with an empty
        history, so each question is answered as a single turn).
    """
    global model, processor

    # lazy loading and adding image
    if model is None:
        processor = AutoProcessor.from_pretrained(MODEL_ID)
        model = AutoModelForVision2Seq.from_pretrained(MODEL_ID, device_map="auto").to(device)

    return generate(
        selected_image,
        msg,
        [],
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
        top_k=top_k,
        max_new_tokens=max_new_tokens,
    )


# NOTE(review): ``tb`` is never rendered or referenced in this file -- confirm
# nothing external relies on it before removing.
tb = gr.Textbox(submit_btn=True)

# advanced settings (displayed in Accordion)
# These are created up front so they can be used as event inputs, and are
# rendered later inside the accordion.
temperature_slider = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=TEMPERATURE,
    step=0.1,
    label="Temperature",
    elem_classes=["gr_accordion_element"],
    interactive=True,
)
top_p_slider = gr.Slider(
    minimum=0,
    maximum=1.0,
    value=TOP_P,
    step=0.05,
    label="Top P",
    elem_classes=["gr_accordion_element"],
    interactive=True,
)
top_k_slider = gr.Slider(
    minimum=0, maximum=100, value=TOP_K, step=1, label="Top K", elem_classes=["gr_accordion_element"], interactive=True
)
repetition_penalty_slider = gr.Slider(
    minimum=0,
    maximum=2.0,
    value=REPETITION_PENALTY,
    step=0.05,
    label="Repetition Penalty",
    elem_classes=["gr_accordion_element"],
    interactive=True,
)
max_new_tokens_slider = gr.Slider(
    minimum=1,
    maximum=2000,
    value=MAX_NEW_TOKENS,
    step=1,
    label="Max New Tokens",
    elem_classes=["gr_accordion_element"],
    interactive=True,
)

chatbot = gr.Chatbot(examples=[{"text": "Hello World!"}], type="messages", label="Q&A about selected document")

css_file_path = Path(__file__).parent / "app.css"
head_file_path = Path(__file__).parent / "app_head.html"

with gr.Blocks(fill_height=True, css_paths=css_file_path, head_paths=head_file_path, theme=theme, title=TITLE) as demo:
    is_in_edit_mode = gr.State(True)  # in block to be reactive

    gr.Markdown(f"# {TITLE}")
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column():
            # create sample image object for reference, render later
            image_x = gr.Image(
                type="pil",
                label="Example image",
                render=False,
                interactive=False,
                show_label=False,
                show_fullscreen_button=False,
                height=800,
            )
            image_x.change(fn=image_changed, inputs=image_x)

            # Create Dataset object and render it
            ds = gr.Dataset(label="Select one document", samples=sample_data, components=[gr.Image(render=False)])

            def sample_image_selected(d: gr.SelectData, dx):
                """Show the chosen sample image and offer its questions as chatbot examples."""
                return gr.Image(dx[0]), gr.update(examples=[{"text": x} for x in dx[1]])

            # Selecting a document clears the conversation and swaps image + example prompts.
            ds.select(lambda: [], outputs=[chatbot])
            ds.select(sample_image_selected, inputs=[ds], outputs=[image_x, chatbot])

            # Render image object after DS
            image_x.render()

        with gr.Column():
            # Render ChatBot
            chatbot.render()

            # Define behavior for example selection
            def update_user_chat_x(x: gr.SelectData):
                """Show the clicked example as the user's message."""
                return [gr.ChatMessage(role="user", content=x.value["text"])]

            def send_generate_x(x: gr.SelectData, temperature, repetition_penalty, top_p, top_k, max_new_tokens):
                """Run generation for the clicked example with the current slider settings."""
                txt = x.value["text"]
                return multimodal_generate_v2(txt, temperature, repetition_penalty, top_p, top_k, max_new_tokens)

            # Three listeners on the same event: leave edit mode, echo the
            # question, and (once generation finishes) replace the chat with
            # the question/answer pair.
            chatbot.example_select(lambda: False, outputs=is_in_edit_mode)
            chatbot.example_select(update_user_chat_x, outputs=[chatbot])
            chatbot.example_select(
                send_generate_x,
                inputs=[
                    temperature_slider,
                    repetition_penalty_slider,
                    top_p_slider,
                    top_k_slider,
                    max_new_tokens_slider,
                ],
                outputs=[chatbot],
            )

            # Create User Chat Textbox and Reset Button
            tbb = gr.Textbox(submit_btn=True, show_label=False)
            fb = gr.Button("Reset Chat", visible=False)
            fb.click(lambda: [], outputs=[chatbot])

            # Handle toggling between edit and non-edit mode
            def textbox_switch(emode):
                """Show the textbox in edit mode, the reset button otherwise."""
                editing = bool(emode)
                return [gr.update(visible=editing), gr.update(visible=not editing)]

            tbb.submit(lambda: False, outputs=[is_in_edit_mode])
            fb.click(lambda: True, outputs=[is_in_edit_mode])
            is_in_edit_mode.change(textbox_switch, inputs=[is_in_edit_mode], outputs=[tbb, fb])

            # submit user question
            # The first listener echoes the question; the second replaces the
            # chat with the question/answer pair when generation completes.
            tbb.submit(lambda x: [gr.ChatMessage(role="user", content=x)], inputs=tbb, outputs=chatbot)
            tbb.submit(
                multimodal_generate_v2,
                inputs=[
                    tbb,
                    temperature_slider,
                    repetition_penalty_slider,
                    top_p_slider,
                    top_k_slider,
                    max_new_tokens_slider,
                ],
                outputs=[chatbot],
            )

    # extra model parameters
    with gr.Accordion("Advanced Settings", open=False):
        max_new_tokens_slider.render()
        temperature_slider.render()
        top_k_slider.render()
        top_p_slider.render()
        repetition_penalty_slider.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()