File size: 6,654 Bytes
4363820
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e37701
4363820
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e37701
 
64c47c6
 
2e37701
 
4363820
 
 
4af5fd9
aedb14e
 
4363820
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from transformers import AutoModelForCausalLM, AutoTokenizer
from langchain_core.prompts import PromptTemplate
import os
from typing import List

class LLM:
    def __init__(self, model_repo: str = "Qwen/Qwen2-1.5B-Instruct", 
                 local_path: str = "models"):
        """
        Initialize the LLM with Qwen2-1.5B-Instruct using Hugging Face Transformers.

        Args:
            model_repo (str): Hugging Face repository ID for the model.
            local_path (str): Local directory to store the model.
        """
        os.makedirs(local_path, exist_ok=True)
        
        try:
            # Load the model
            self.llm = AutoModelForCausalLM.from_pretrained(
                model_repo,
                device_map="auto",  # Automatically map to CPU
                cache_dir=local_path,
                trust_remote_code=True
            )
            
            # Load the tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_repo,
                cache_dir=local_path,
                trust_remote_code=True
            )
            print(f"Model successfully loaded from {model_repo}")
        except Exception as e:
            raise RuntimeError(
                f"Failed to initialize model from {model_repo}. "
                f"Please ensure the model is available at https://huggingface.co/{model_repo}. "
                f"Error: {str(e)}"
            )
        
        # Define prompt template for query parsing (used in query_parser.py)
        self.prompt_template = PromptTemplate(
            template="""Bạn là một trợ lý phân tích truy vấn nhà hàng. Phân tích truy vấn sau và trích xuất các đặc trưng: cuisine, menu, price_range, distance, rating, và description. Chỉ trích xuất các giá trị khớp chính xác với danh sách giá trị hợp lệ. Nếu không tìm thấy giá trị khớp, trả về null (hoặc [] cho menu). Loại bỏ các từ khóa đã trích xuất khỏi description. Trả về kết quả dưới dạng JSON.

                    **Danh sách giá trị hợp lệ**:
                    - cuisine: {cuisines}
                    - menu: {dishes}
                    - price_range: {price_ranges}

                    **Hướng dẫn**:
                    - cuisine: Chỉ chọn giá trị từ danh sách cuisine. Ví dụ, "Viet" → "Vietnamese".
                    - menu: Chỉ chọn các món khớp chính xác với danh sách menu. Ví dụ, "phở bò" → "phở", "sushi" → [].
                    - price_range: Chỉ chọn {price_ranges}. Ví dụ, "cheap" → "low".
                    - distance: Trích xuất số km (e.g., "2 km" → 2.0) hoặc từ khóa ["nearby", "close" → 2.0, "far" → 10.0]. Nếu không rõ, trả về null.
                    - rating: Trích xuất số (e.g., "4 stars" → 4.0). Nếu không rõ, trả về null.
                    - description: Phần còn lại sau khi loại bỏ các từ khóa đã trích xuất. Nếu rỗng, trả về truy vấn gốc.

                    **Truy vấn**: {query}

                    **Định dạng đầu ra**:
                    {{
                    "cuisine": null | "tên loại ẩm thực",
                    "menu": [],
                    "price_range": null | "low" | "medium" | "high",
                    "distance": null | số km | "nearby" | "close" | "far",
                    "rating": null | số,
                    "description": "phần mô tả còn lại"
                    }}
                    """,
            input_variables=["cuisines", "dishes", "price_ranges", "query"]
        )
    
    def generate(self, prompt: str, max_length: int = 100) -> str:
        """
        Generate text using the LLM.
        
        Args:
            prompt (str): Input prompt.
            max_length (int): Maximum length of the generated text.
        
        Returns:
            str: Generated text.
        """
        try:
            # Apply chat template for instruction-tuned Qwen model
            messages = [{"role": "user", "content": prompt}]
            prompt_with_template = self.tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            # Tokenize input prompt
            inputs = self.tokenizer(prompt_with_template, return_tensors="pt").to(self.llm.device)
            # Generate text
            outputs = self.llm.generate(
                        **inputs,
                        max_new_tokens=max_length,
                        temperature=0.7,
                        do_sample=True,
                        pad_token_id=self.tokenizer.eos_token_id,
                    )
            # Decode the generated tokens
            response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            print("Response generated successfully!")
            return response.split('assistant')[2]
        except Exception as e:
            raise RuntimeError(f"Failed to generate response: {str(e)}")
    
    def format_query_prompt(self, query: str, cuisines: List[str], dishes: List[str], price_ranges: List[str]) -> str:
        """
        Format the prompt for query parsing using the prompt template.
        
        Args:
            query (str): User query.
            cuisines (list): List of valid cuisines.
            dishes (list): List of valid dishes.
            price_ranges (list): List of valid price ranges.
        
        Returns:
            str: Formatted prompt.
        """
        return self.prompt_template.format(
            cuisines=cuisines,
            dishes=dishes,
            price_ranges=price_ranges,
            query=query
        )

if __name__ == "__main__":
    # Khởi tạo đối tượng LLM với model_repo và local_path
    local_path = 'models'
    
    try:
        # Khởi tạo đối tượng LLM
        llm = LLM(local_path=local_path)

        # Định nghĩa một truy vấn và các tham số cần thiết
        query = "Tìm quán ăn Việt Nam gần đây, giá rẻ với món phở và cơm tấm"
        cuisines = ["Vietnamese", "Chinese", "Italian"]
        dishes = ["phở", "sushi", "pasta", "cơm tấm"]
        price_ranges = ["low", "medium", "high"]

        # Sử dụng hàm generate để tạo câu trả lời từ truy vấn
        generated_text = llm.generate(query, max_length=300)

        # In kết quả ra màn hình
        print("Generated text:")
        print(generated_text)

    except Exception as e:
        print(f"Error: {str(e)}")