File size: 4,553 Bytes
e9f1d21
 
2851004
 
e9f1d21
 
 
 
 
 
2851004
e9f1d21
2851004
24b8349
2851004
 
 
e9f1d21
 
2851004
e9f1d21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2851004
 
 
 
e9f1d21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2851004
e9f1d21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import { pipeline, env } from '@huggingface/transformers';

// Point the transformers.js file cache at a relative `.cache` directory.
// This mutates the imported `env` settings object once, when this module is evaluated.
Object.assign(env, { cacheDir: './.cache' });

class TeapotAI {
  /**
   * Initializes the TeapotAI class.
   * Construction is synchronous and does not load the model; call initialize() afterwards.
   * @param {object} [options] - Configuration options.
   * @param {string} [options.modelId='tomasmcm/teapotai-teapotllm-onnx'] - The Hugging Face model ID.
   * @param {boolean} [options.verbose=false] - Whether to print status messages.
   * @param {object} [options.pipelineOptions={}] - Additional options passed to the transformers
   *   `pipeline()` call; a `model` key here overrides `modelId`.
   */
  constructor({
    modelId = 'tomasmcm/teapotai-teapotllm-onnx',
    verbose = false,
    pipelineOptions = {},
  } = {}) {
    this.modelId = modelId;
    this.verbose = verbose;
    this.pipelineOptions = pipelineOptions;
    this.generator = null;
    this.isInitialized = false;
    // In-flight initialization shared by concurrent initialize() callers (null when idle).
    this._initPromise = null;

    if (this.verbose) {
      console.log(`TeapotAI instance created for model: ${this.modelId}`);
    }
  }

  /**
   * Asynchronously initializes the text generation pipeline.
   * Must be called before using generate, query, or chat.
   * Concurrent calls share a single load; after a failed load, calling again retries it.
   * @returns {Promise<void>} Resolves once the pipeline is ready.
   * @throws Rethrows any error raised while creating the pipeline.
   */
  async initialize() {
    if (this.isInitialized) {
      if (this.verbose) console.log("Pipeline already initialized.");
      return;
    }
    // Without this, two overlapping calls would both see isInitialized === false
    // and load the model twice.
    if (!this._initPromise) {
      this._initPromise = this._loadPipeline().finally(() => {
        // Cleared on success and on failure, so a failed load can be retried.
        this._initPromise = null;
      });
    }
    await this._initPromise;
  }

  /**
   * Creates the text2text-generation pipeline and marks the instance as initialized.
   * Errors are logged and rethrown.
   * @private
   */
  async _loadPipeline() {
    try {
      if (this.verbose) console.log(`Initializing generator pipeline for model: ${this.modelId}...`);
      const pipelineOptions = {
        model: this.modelId,
        ...this.pipelineOptions,
      };

      this.generator = await pipeline('text2text-generation', pipelineOptions.model, pipelineOptions);
      this.isInitialized = true;
      if (this.verbose) console.log("Pipeline initialized successfully.");
    } catch (error) {
      console.error("Failed to initialize pipeline:", error);
      throw error;
    }
  }

  /**
   * Ensures the pipeline is initialized before proceeding.
   * @private
   * @throws {Error} If initialize() has not completed successfully.
   */
  _ensureInitialized() {
    if (!this.isInitialized || !this.generator) {
      throw new Error("Pipeline not initialized. Call initialize() before using query(), generate(), or chat().");
    }
  }

  /**
   * Generates text based on the input string.
   * (Internal method similar to the Python version's generate)
   * Generation failures do not reject. They are reported as an "Error: ..." string and
   * logged when verbose. The returned promise still rejects if the pipeline is not initialized.
   * @param {string} inputText - The text prompt to generate a response for.
   * @param {object} [generationOptions={}] - Extra options for the generator call; they
   *   override the default `{ max_new_tokens: 512 }`.
   * @returns {Promise<string>} The trimmed generated text, or an "Error: ..." message.
   */
  async generate(inputText, generationOptions = {}) {
    this._ensureInitialized();
    try {
      if (this.verbose) console.log("Generating text...");

      const output = await this.generator(inputText, {
        max_new_tokens: 512,
        ...generationOptions,
      });

      // Treat a missing/empty result as "no text" rather than throwing on output[0].
      const generatedText = output?.[0]?.generated_text?.trim() ?? "Error: Could not generate text.";
      if (this.verbose) console.log("Text generation complete.");
      return generatedText;
    } catch (error) {
      if (this.verbose) console.error("Error during text generation:", error);
      return "Error: Generation failed.";
    }
  }

  /**
   * Handles a query and context to generate a response.
   * (Focuses on the case where context is provided, skipping RAG)
   * @param {string} query - The query string to be answered.
   * @param {string} context - The context to guide the response; any falsy value
   *   (including "") means the prompt contains only the query.
   * @returns {Promise<string>} The generated response based on the query and context.
   */
  async query(query, context) {
    this._ensureInitialized();

    let inputText;
    if (!context) {
      if (this.verbose) console.warn("Context is empty. Proceeding without context enhancement.");
      inputText = `Query: ${query}`;
    } else {
      inputText = `Context: ${context}\nQuery: ${query}`;
      if (this.verbose) console.log("\nFormatted Input for Query:\n", inputText);
    }
    return this.generate(inputText);
  }

  /**
   * Engages in a chat by taking a list of previous messages and generating a response.
   * @param {Array<object>} conversationHistory - An array of message objects, each expected to have a 'content' property. E.g., [{ content: 'User: Hi' }, { content: 'Agent: Hello!' }]
   * @returns {Promise<string>} The generated agent response based on the conversation history.
   * @throws {Error} If conversationHistory is not an array.
   */
  async chat(conversationHistory) {
    this._ensureInitialized();

    if (!Array.isArray(conversationHistory)) {
      throw new Error("conversationHistory must be an array of message objects.");
    }

    // Null/undefined entries contribute an empty line instead of throwing a TypeError.
    const chatHistoryString = conversationHistory
      .map((message) => message?.content ?? "")
      .join("\n");

    // Trailing "agent:" cues the model to produce the agent's next turn.
    const inputText = `${chatHistoryString}\nagent:`;

    if (this.verbose) console.log("\nFormatted Input for Chat:\n", inputText);

    return this.generate(inputText);
  }
}

// Expose the class as this module's default export.
export default TeapotAI;