File size: 2,295 Bytes
854f61d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from __future__ import annotations

from hamilton import base, driver

import logging
import sys
import data_module.data_pipeline  as data_pipeline
import data_module.embedding_pipeline as embedding_pipeline
import data_module.vectorstore as vectorstore
import classification_module.semantic_similarity as semantic_similarity
import classification_module.dio_support_detector as dio_support_detector
import click


# Route root-logger output to stdout instead of logging's default stderr.
logging.basicConfig(stream=sys.stdout)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Default embedding model per service, used when --model_name is omitted.
# "marqo" intentionally has no entry, so its model_name stays None.
_DEFAULT_MODEL_NAMES = {
    "openai": "text-embedding-ada-002",
    "cohere": "embed-english-light-v2.0",
    "sentence_transformer": "multi-qa-MiniLM-L6-cos-v1",
}


@click.command()
@click.option(
    "--embedding_service",
    type=click.Choice(["openai", "cohere", "sentence_transformer", "marqo"], case_sensitive=False),
    default="sentence_transformer",
    help="Text embedding service.",
)
@click.option(
    "--embedding_service_api_key",
    default=None,
    help="API Key for embedding service. Needed if using OpenAI or Cohere.",
)
@click.option("--model_name", default=None, help="Text embedding model name.")
@click.option(
    "--user_input",
    # Without this, omitting the flag would pass None into the DAG.
    required=True,
    help="Content on which to run radicalization detection",
)
def main(
    embedding_service: str,
    embedding_service_api_key: str | None,
    model_name: str | None,
    user_input: str,
) -> None:
    """Run glorification detection on the text given by --user_input.

    Builds a Hamilton DAG from the data, embedding, vector-store and
    classification modules, executes it for the ``detect_glorification``
    node, and prints the result.
    """
    if model_name is None:
        # Falls back to None for services without a default (e.g. "marqo").
        model_name = _DEFAULT_MODEL_NAMES.get(embedding_service)

    # Hamilton config: selects the "pd" loader plus embedding service/model/key.
    config = {
        "loader": "pd",
        "embedding_service": embedding_service,
        "api_key": embedding_service_api_key,
        "model_name": model_name,
    }

    dr = driver.Driver(
        config,
        data_pipeline,
        embedding_pipeline,
        vectorstore,
        semantic_similarity,
        dio_support_detector,
    )
    # Hamilton evaluates only the nodes required to produce the requested
    # final variable.
    result = dr.execute(
        final_vars=["detect_glorification"],
        # project_root "." is relative to the current working directory;
        # presumably the pipeline resolves data paths from it - run this
        # script from the project root (TODO confirm).
        inputs={"project_root": ".", "user_input": user_input},
    )
    print(result)
    # To render the DAG instead, e.g.:
    # dr.visualize_execution(final_vars=["save_vector_store"],
    #     inputs={"project_root": ".", "user_input": user_input}, output_file_path='./my-dag.dot', render_kwargs={})


if __name__ == "__main__":
    main()