File size: 3,401 Bytes
cecbdb4
 
596d555
5923c3d
 
cecbdb4
 
 
 
 
0d65258
cecbdb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f22122
cecbdb4
 
 
 
 
 
 
 
 
 
 
 
 
5923c3d
cecbdb4
 
 
 
 
 
 
 
 
8f22122
5923c3d
 
 
 
 
 
 
 
 
 
 
 
062b1e0
 
5923c3d
cecbdb4
 
 
 
 
 
 
 
 
 
 
 
 
 
6051851
 
cecbdb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr

from transformers import TextIteratorStreamer

from threading import Thread

from unsloth import FastLanguageModel

# Load the adapter weights in 4-bit quantization to cut GPU memory usage.
load_in_4bit = True

# Hugging Face Hub ID of the LoRA fine-tuned model to serve.
peft_model_id = "ID2223JR/lora_model"

# Fetch model + tokenizer via Unsloth's fast loader.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=peft_model_id,
    load_in_4bit=load_in_4bit,
)
# Enable Unsloth's optimized inference mode (disables training-only paths).
FastLanguageModel.for_inference(model)


# Data storage
# Module-level accumulator shared by the Gradio callbacks below; holds
# "<name>, <quantity> grams" strings until a recipe is requested.
ingredients_list = []


# Function to add ingredient
def add_ingredient(ingredient, quantity):
    """Append the ingredient to the global list and reset both input fields.

    Args:
        ingredient: Name typed into the Textbox (may be empty).
        quantity: Value from the gr.Number field — a float, or None while
            the field is blank.

    Returns:
        A 3-tuple for the bound outputs: the newline-joined ingredient list,
        plus gr.update values that clear the two input components.
    """
    # Guard against None: gr.Number starts at value=None, and int(None)
    # would raise TypeError.
    if ingredient and quantity is not None and int(quantity) > 0:
        ingredients_list.append(f"{ingredient}, {quantity} grams")
    return (
        "\n".join(ingredients_list),
        gr.update(value="", interactive=True),
        gr.update(value=None, interactive=True),
    )


# Function to enable/disable add button
def validate_inputs(ingredient, quantity):
    """Return a gr.update enabling the Add button only for valid inputs.

    Runs on every .change event of either input, so quantity is None until
    the user has typed a number — int(None) would raise TypeError, hence
    the explicit None check.
    """
    if ingredient and quantity is not None and int(quantity) > 0:
        return gr.update(interactive=True)
    return gr.update(interactive=False)


# Function to handle model submission
def submit_to_model():
    """Stream a recipe suggestion for the accumulated ingredients.

    Generator: yields the progressively accumulated response text so the
    bound Gradio Textbox updates as tokens arrive. Consumes and clears the
    global ingredients_list.
    """
    if not ingredients_list:
        # This function contains `yield`, so it is a generator: a plain
        # `return <value>` would be swallowed (stored on StopIteration) and
        # Gradio would never display the message. Yield it instead.
        yield "Ingredients list is empty! Please add ingredients first."
        return

    # Join ingredients into a single prompt.
    prompt = "Using the following ingredients, suggest a recipe:\n\n" + "\n".join(
        ingredients_list
    )
    ingredients_list.clear()

    messages = [
        {"role": "user", "content": prompt},
    ]
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,  # Must add for generation
        return_tensors="pt",
    ).to("cuda")

    # Generate on a background thread; the streamer lets us yield text
    # incrementally from this thread.
    text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    generation_kwargs = dict(
        inputs=inputs,
        streamer=text_streamer,
        use_cache=True,
        temperature=0.3,
        min_p=0.1,
    )
    # NOTE(review): no max_new_tokens bound is set, so output length is
    # limited only by the model's EOS behavior — confirm this is intended.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    content = ""
    for text in text_streamer:
        content += text
        # Strip the trailing end-of-turn marker once it shows up.
        if content.endswith("<|eot_id|>"):
            content = content.replace("<|eot_id|>", "")
        yield content


# App
def app():
    """Assemble and return the Gradio Blocks interface."""
    with gr.Blocks() as demo:
        # Input row: ingredient name and its weight in grams.
        with gr.Row():
            ingredient_input = gr.Textbox(
                label="Ingredient", placeholder="Enter ingredient name"
            )
            quantity_input = gr.Number(label="Quantity (grams)", value=None)

        add_button = gr.Button("Add Ingredient", interactive=False)
        output = gr.Textbox(label="Ingredients List", lines=10, interactive=False)
        submit_button = gr.Button("Give me a meal!")

        with gr.Row():
            model_output = gr.Textbox(
                label="Recipe Suggestion", lines=10, interactive=False
            )

        # Re-validate on every change of either field, toggling the Add button.
        entry_fields = [ingredient_input, quantity_input]
        for field in entry_fields:
            field.change(validate_inputs, entry_fields, add_button)

        # Append the ingredient to the list and clear both inputs.
        add_button.click(
            add_ingredient,
            entry_fields,
            [output, ingredient_input, quantity_input],
        )

        # Stream the model's recipe; no inputs — it reads the global list.
        submit_button.click(submit_to_model, inputs=None, outputs=model_output)

    return demo


# Build the interface and start the Gradio server (kept as module-level
# `demo` so hosting platforms that import it can find the app object).
demo = app()
demo.launch()