import flask_restful
from flask import current_app, request
from flask_login import current_user
from flask_restful import Resource, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, NotFound

import services
from controllers.console import api
from controllers.console.apikey import api_key_fields, api_key_list
from controllers.console.app.error import ProviderNotInitializeError
from controllers.console.datasets.error import DatasetNameDuplicateError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
from core.indexing_runner import IndexingRunner
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.rag.extractor.entity.extract_setting import ExtractSetting
from extensions.ext_database import db
from fields.app_fields import related_app_list
from fields.dataset_fields import dataset_detail_fields, dataset_query_detail_fields
from fields.document_fields import document_status_fields
from libs.login import login_required
from models.dataset import Dataset, Document, DocumentSegment
from models.model import ApiToken, UploadFile
from services.dataset_service import DatasetService, DocumentService


def _validate_name(name):
    if not name or len(name) < 1 or len(name) > 40:
        raise ValueError('Name must be between 1 and 40 characters.')
    return name


def _validate_description_length(description):
    if len(description) > 400:
        raise ValueError('Description cannot exceed 400 characters.')
    return description
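
# Note: flask_restful's reqparse treats a ValueError raised by a `type=`
# callable as a failed conversion and rejects the request with HTTP 400,
# using the argument's `help` message when one is set. The two validators
# above therefore double as request-level validation for `name` and
# `description`.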


class DatasetListApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        page = request.args.get('page', default=1, type=int)
        limit = request.args.get('limit', default=20, type=int)
        ids = request.args.getlist('ids')
        provider = request.args.get('provider', default="vendor")
        search = request.args.get('keyword', default=None, type=str)
        tag_ids = request.args.getlist('tag_ids')

        if ids:
            datasets, total = DatasetService.get_datasets_by_ids(ids, current_user.current_tenant_id)
        else:
            datasets, total = DatasetService.get_datasets(page, limit, provider,
                                                          current_user.current_tenant_id,
                                                          current_user, search, tag_ids)

        # Check the embedding setting: a high-quality dataset's embedding model
        # is only marked available if the tenant still has that model active.
        provider_manager = ProviderManager()
        configurations = provider_manager.get_configurations(
            tenant_id=current_user.current_tenant_id
        )

        embedding_models = configurations.get_models(
            model_type=ModelType.TEXT_EMBEDDING,
            only_active=True
        )

        model_names = [f"{embedding_model.model}:{embedding_model.provider.provider}"
                       for embedding_model in embedding_models]

        data = marshal(datasets, dataset_detail_fields)
        for item in data:
            if item['indexing_technique'] == 'high_quality':
                item_model = f"{item['embedding_model']}:{item['embedding_model_provider']}"
                item['embedding_available'] = item_model in model_names
            else:
                item['embedding_available'] = True

        response = {
            'data': data,
            'has_more': len(datasets) == limit,
            'limit': limit,
            'total': total,
            'page': page
        }
        return response, 200

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument('name', nullable=False, required=True,
                            help='Name is required. Name must be between 1 and 40 characters.',
                            type=_validate_name)
        parser.add_argument('indexing_technique', type=str, location='json',
                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
                            nullable=True,
                            help='Invalid indexing technique.')
        args = parser.parse_args()

        # The role of the current user in the ta table must be admin or owner.
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        try:
            dataset = DatasetService.create_empty_dataset(
                tenant_id=current_user.current_tenant_id,
                name=args['name'],
                indexing_technique=args['indexing_technique'],
                account=current_user
            )
        except services.errors.dataset.DatasetNameDuplicateError:
            raise DatasetNameDuplicateError()

        return marshal(dataset, dataset_detail_fields), 201


class DatasetApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))
        data = marshal(dataset, dataset_detail_fields)

        # Check the embedding setting, mirroring DatasetListApi.get.
        provider_manager = ProviderManager()
        configurations = provider_manager.get_configurations(
            tenant_id=current_user.current_tenant_id
        )

        embedding_models = configurations.get_models(
            model_type=ModelType.TEXT_EMBEDDING,
            only_active=True
        )

        model_names = [f"{embedding_model.model}:{embedding_model.provider.provider}"
                       for embedding_model in embedding_models]

        if data['indexing_technique'] == 'high_quality':
            item_model = f"{data['embedding_model']}:{data['embedding_model_provider']}"
            data['embedding_available'] = item_model in model_names
        else:
            data['embedding_available'] = True

        return data, 200

    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        # Check the user's model setting before allowing the update.
        DatasetService.check_dataset_model_setting(dataset)

        parser = reqparse.RequestParser()
        parser.add_argument('name', nullable=False,
                            help='Name is required. Name must be between 1 and 40 characters.',
                            type=_validate_name)
        parser.add_argument('description',
                            location='json', store_missing=False,
                            type=_validate_description_length)
        parser.add_argument('indexing_technique', type=str, location='json',
                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
                            nullable=True,
                            help='Invalid indexing technique.')
        parser.add_argument('permission', type=str, location='json',
                            choices=('only_me', 'all_team_members'),
                            help='Invalid permission.')
        parser.add_argument('embedding_model', type=str,
                            location='json', help='Invalid embedding model.')
        parser.add_argument('embedding_model_provider', type=str,
                            location='json', help='Invalid embedding model provider.')
        parser.add_argument('retrieval_model', type=dict, location='json',
                            help='Invalid retrieval model.')
        args = parser.parse_args()

        # The role of the current user in the ta table must be admin or owner.
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        dataset = DatasetService.update_dataset(dataset_id_str, args, current_user)
        if dataset is None:
            raise NotFound("Dataset not found.")

        return marshal(dataset, dataset_detail_fields), 200

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, dataset_id):
        dataset_id_str = str(dataset_id)

        # The role of the current user in the ta table must be admin or owner.
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        if DatasetService.delete_dataset(dataset_id_str, current_user):
            return {'result': 'success'}, 204
        else:
            raise NotFound("Dataset not found.")


class DatasetQueryApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        page = request.args.get('page', default=1, type=int)
        limit = request.args.get('limit', default=20, type=int)

        dataset_queries, total = DatasetService.get_dataset_queries(
            dataset_id=dataset.id,
            page=page,
            per_page=limit
        )

        response = {
            'data': marshal(dataset_queries, dataset_query_detail_fields),
            'has_more': len(dataset_queries) == limit,
            'limit': limit,
            'total': total,
            'page': page
        }
        return response, 200


class DatasetIndexingEstimateApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        parser = reqparse.RequestParser()
        parser.add_argument('info_list', type=dict, required=True, nullable=True, location='json')
        parser.add_argument('process_rule', type=dict, required=True, nullable=True, location='json')
        parser.add_argument('indexing_technique', type=str, required=True,
                            choices=Dataset.INDEXING_TECHNIQUE_LIST,
                            nullable=True, location='json')
        parser.add_argument('doc_form', type=str, default='text_model', required=False, nullable=False,
                            location='json')
        parser.add_argument('dataset_id', type=str, required=False, nullable=False, location='json')
        parser.add_argument('doc_language', type=str, default='English', required=False, nullable=False,
                            location='json')
        args = parser.parse_args()

        # Validate args before building extract settings.
        DocumentService.estimate_args_validate(args)

        extract_settings = []
        if args['info_list']['data_source_type'] == 'upload_file':
            file_ids = args['info_list']['file_info_list']['file_ids']
            file_details = db.session.query(UploadFile).filter(
                UploadFile.tenant_id == current_user.current_tenant_id,
                UploadFile.id.in_(file_ids)
            ).all()

            # Query.all() returns a (possibly empty) list, never None, so an
            # empty result is the "not found" case.
            if not file_details:
                raise NotFound("File not found.")

            for file_detail in file_details:
                extract_setting = ExtractSetting(
                    datasource_type="upload_file",
                    upload_file=file_detail,
                    document_model=args['doc_form']
                )
                extract_settings.append(extract_setting)
        elif args['info_list']['data_source_type'] == 'notion_import':
            notion_info_list = args['info_list']['notion_info_list']
            for notion_info in notion_info_list:
                workspace_id = notion_info['workspace_id']
                for page in notion_info['pages']:
                    extract_setting = ExtractSetting(
                        datasource_type="notion_import",
                        notion_info={
                            "notion_workspace_id": workspace_id,
                            "notion_obj_id": page['page_id'],
                            "notion_page_type": page['type'],
                            "tenant_id": current_user.current_tenant_id
                        },
                        document_model=args['doc_form']
                    )
                    extract_settings.append(extract_setting)
        else:
            raise ValueError('Data source type is not supported.')

        indexing_runner = IndexingRunner()
        try:
            response = indexing_runner.indexing_estimate(current_user.current_tenant_id, extract_settings,
                                                         args['process_rule'], args['doc_form'],
                                                         args['doc_language'], args['dataset_id'],
                                                         args['indexing_technique'])
        except LLMBadRequestError:
            raise ProviderNotInitializeError(
                "No Embedding Model available. Please configure a valid provider "
                "in Settings -> Model Provider.")
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)

        return response, 200
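
# Example `info_list` payloads accepted by DatasetIndexingEstimateApi, with
# shapes inferred from the branches in post() above (field values are
# illustrative placeholders, not real IDs):
#
#   {"data_source_type": "upload_file",
#    "file_info_list": {"file_ids": ["<upload_file_id>"]}}
#
#   {"data_source_type": "notion_import",
#    "notion_info_list": [{"workspace_id": "<workspace_id>",
#                          "pages": [{"page_id": "<page_id>", "type": "page"}]}]}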


class DatasetRelatedAppListApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(related_app_list)
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")

        try:
            DatasetService.check_dataset_permission(dataset, current_user)
        except services.errors.account.NoPermissionError as e:
            raise Forbidden(str(e))

        app_dataset_joins = DatasetService.get_related_apps(dataset.id)

        related_apps = []
        for app_dataset_join in app_dataset_joins:
            app_model = app_dataset_join.app
            if app_model:
                related_apps.append(app_model)

        return {
            'data': related_apps,
            'total': len(related_apps)
        }, 200


class DatasetIndexingStatusApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        dataset_id = str(dataset_id)
        documents = db.session.query(Document).filter(
            Document.dataset_id == dataset_id,
            Document.tenant_id == current_user.current_tenant_id
        ).all()

        documents_status = []
        for document in documents:
            # Segments flagged 're_segment' are excluded from both counts.
            completed_segments = DocumentSegment.query.filter(DocumentSegment.completed_at.isnot(None),
                                                              DocumentSegment.document_id == str(document.id),
                                                              DocumentSegment.status != 're_segment').count()
            total_segments = DocumentSegment.query.filter(DocumentSegment.document_id == str(document.id),
                                                          DocumentSegment.status != 're_segment').count()
            document.completed_segments = completed_segments
            document.total_segments = total_segments
            documents_status.append(marshal(document, document_status_fields))

        data = {
            'data': documents_status
        }
        return data


class DatasetApiKeyApi(Resource):
    max_keys = 10
    token_prefix = 'dataset-'
    resource_type = 'dataset'

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(api_key_list)
    def get(self):
        keys = db.session.query(ApiToken). \
            filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
            all()
        return {"items": keys}

    @setup_required
    @login_required
    @account_initialization_required
    @marshal_with(api_key_fields)
    def post(self):
        # The role of the current user in the ta table must be admin or owner.
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        current_key_count = db.session.query(ApiToken). \
            filter(ApiToken.type == self.resource_type, ApiToken.tenant_id == current_user.current_tenant_id). \
            count()

        if current_key_count >= self.max_keys:
            flask_restful.abort(
                400,
                message=f"Cannot create more than {self.max_keys} API keys for this resource type.",
                code='max_keys_exceeded'
            )

        key = ApiToken.generate_api_key(self.token_prefix, 24)
        api_token = ApiToken()
        api_token.tenant_id = current_user.current_tenant_id
        api_token.token = key
        api_token.type = self.resource_type
        db.session.add(api_token)
        db.session.commit()
        return api_token, 200


class DatasetApiDeleteApi(Resource):
    resource_type = 'dataset'

    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, api_key_id):
        api_key_id = str(api_key_id)

        # The role of the current user in the ta table must be admin or owner.
        if not current_user.is_admin_or_owner:
            raise Forbidden()

        key = db.session.query(ApiToken). \
            filter(ApiToken.tenant_id == current_user.current_tenant_id, ApiToken.type == self.resource_type,
                   ApiToken.id == api_key_id). \
            first()

        if key is None:
            flask_restful.abort(404, message='API key not found')

        db.session.query(ApiToken).filter(ApiToken.id == api_key_id).delete()
        db.session.commit()

        return {'result': 'success'}, 204


class DatasetApiBaseUrlApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        return {
            'api_base_url': (current_app.config['SERVICE_API_URL'] if current_app.config['SERVICE_API_URL']
                             else request.host_url.rstrip('/')) + '/v1'
        }


class DatasetRetrievalSettingApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        vector_type = current_app.config['VECTOR_STORE']
        if vector_type in {'milvus', 'relyt', 'pgvector', 'pgvecto_rs'}:
            return {
                'retrieval_method': [
                    'semantic_search'
                ]
            }
        elif vector_type in {'qdrant', 'weaviate'}:
            return {
                'retrieval_method': [
                    'semantic_search', 'full_text_search', 'hybrid_search'
                ]
            }
        else:
            raise ValueError("Unsupported vector db type.")


class DatasetRetrievalSettingMockApi(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, vector_type):
        if vector_type in {'milvus', 'relyt', 'pgvector'}:
            return {
                'retrieval_method': [
                    'semantic_search'
                ]
            }
        elif vector_type in {'qdrant', 'weaviate'}:
            return {
                'retrieval_method': [
                    'semantic_search', 'full_text_search', 'hybrid_search'
                ]
            }
        else:
            raise ValueError("Unsupported vector db type.")


class DatasetErrorDocs(Resource):

    @setup_required
    @login_required
    @account_initialization_required
    def get(self, dataset_id):
        dataset_id_str = str(dataset_id)
        dataset = DatasetService.get_dataset(dataset_id_str)
        if dataset is None:
            raise NotFound("Dataset not found.")
        results = DocumentService.get_error_documents_by_dataset_id(dataset_id_str)

        return {
            'data': [marshal(item, document_status_fields) for item in results],
            'total': len(results)
        }, 200


api.add_resource(DatasetListApi, '/datasets')
api.add_resource(DatasetApi, '/datasets/<uuid:dataset_id>')
api.add_resource(DatasetQueryApi, '/datasets/<uuid:dataset_id>/queries')
api.add_resource(DatasetErrorDocs, '/datasets/<uuid:dataset_id>/error-docs')
api.add_resource(DatasetIndexingEstimateApi, '/datasets/indexing-estimate')
api.add_resource(DatasetRelatedAppListApi, '/datasets/<uuid:dataset_id>/related-apps')
api.add_resource(DatasetIndexingStatusApi, '/datasets/<uuid:dataset_id>/indexing-status')
api.add_resource(DatasetApiKeyApi, '/datasets/api-keys')
api.add_resource(DatasetApiDeleteApi, '/datasets/api-keys/<uuid:api_key_id>')
api.add_resource(DatasetApiBaseUrlApi, '/datasets/api-base-info')
api.add_resource(DatasetRetrievalSettingApi, '/datasets/retrieval-setting')
api.add_resource(DatasetRetrievalSettingMockApi, '/datasets/retrieval-setting/<string:vector_type>')
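
# Example console requests against the routes above (a sketch assuming the
# console API blueprint is mounted at /console/api; adjust the prefix to your
# deployment):
#
#   GET  /console/api/datasets?page=1&limit=20
#   POST /console/api/datasets
#        {"name": "my-dataset", "indexing_technique": "high_quality"}
#   GET  /console/api/datasets/retrieval-setting/qdrant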