File size: 2,260 Bytes
4363820
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b14ef5c
4363820
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
# src/utils/query_parser.py
import pandas as pd
import json
from typing import Dict, Any
from src.generation.llm import LLM

class QueryParser:
    """Turn a free-text restaurant query into structured search features.

    The parser collects the distinct cuisines, price ranges and dishes found
    in the restaurant DataFrame, passes them to the LLM prompt template
    together with the user query, and decodes the JSON object contained in
    the LLM response.
    """

    def __init__(self, df: pd.DataFrame):
        """
        Initialize the query parser with restaurant data.

        Args:
            df (pd.DataFrame): DataFrame containing restaurant data. The
                columns 'cuisine', 'price_range' and 'dishes' are read. Each
                'dishes' cell is iterated, so it is assumed to hold a
                collection of dish names (TODO confirm against the loader).
        """
        self.llm = LLM()
        self.df = df
        self.valid_cuisines = self._sorted_unique(self.df['cuisine'])
        self.valid_price_ranges = self._sorted_unique(self.df['price_range'])
        # Rows without a dish list (None/NaN) are skipped: iterating a NaN
        # float would raise TypeError.
        self.valid_dishes = sorted({
            dish
            for dishes in self.df['dishes'].dropna()
            for dish in dishes
        })

    @staticmethod
    def _sorted_unique(column: pd.Series) -> list:
        """Return the distinct non-missing values of a column, sorted."""
        # Missing values are dropped first: sorting a mix of str and NaN
        # raises TypeError.
        return sorted(column.dropna().unique().tolist())

    @staticmethod
    def _default_features(query: str) -> Dict[str, Any]:
        """Return the empty feature set used when no JSON can be decoded."""
        return {
            "cuisine": None,
            "menu": [],
            "price_range": None,
            "distance": None,
            "rating": None,
            "description": query
        }

    @staticmethod
    def _extract_json_object(response: Any) -> Any:
        """Return the JSON object embedded in an LLM response, or None.

        The span from the first "{" to the last "}" is tried first. If that
        span is not valid JSON (for example because later text contains
        braces of its own), each "{" is tried in turn as the start of an
        object and the first one that decodes is returned.
        """
        if not isinstance(response, str):
            return None

        start = response.find("{")
        if start == -1:
            return None

        end = response.rfind("}") + 1
        try:
            return json.loads(response[start:end])
        except json.JSONDecodeError:
            pass

        # raw_decode stops at the end of the first complete JSON value, so
        # trailing text after the object is ignored.
        decoder = json.JSONDecoder()
        while start != -1:
            try:
                parsed, _ = decoder.raw_decode(response[start:])
            except json.JSONDecodeError:
                start = response.find("{", start + 1)
            else:
                return parsed
        return None

    def parse_query(self, query: str) -> Dict[str, Any]:
        """
        Parse the query to extract features.

        Args:
            query (str): User query.

        Returns:
            Dict[str, Any]: The JSON object decoded from the LLM response.
                If the response is not a string or contains no decodable
                JSON object, a default dict is returned with empty features
                and the raw query stored under "description".
        """
        # Format prompt using LLM's prompt template
        prompt = self.llm.format_query_prompt(
            query=query,
            cuisines=self.valid_cuisines,
            dishes=self.valid_dishes,
            price_ranges=self.valid_price_ranges
        )

        # Generate response
        response = self.llm.generate(prompt, max_length=1024)

        # Parse JSON response, falling back to defaults when nothing decodes
        parsed = self._extract_json_object(response)
        if parsed is None:
            return self._default_features(query)
        return parsed


# Quick test block for QueryParser
if __name__ == "__main__":
    import pandas as pd

    # Two-row toy dataset with the three columns QueryParser reads.
    df = pd.DataFrame(
        {
            "cuisine": ["Italian", "Japanese"],
            "price_range": ["$", "$$"],
            "dishes": [["pizza", "pasta"], ["sushi", "ramen"]],
        }
    )
    parser = QueryParser(df)

    # Run one query through the parser and show what comes back.
    user_query = "I want cheap sushi"
    result = parser.parse_query(user_query)

    print("Parsed Query Result:", result, sep="\n")