Merged
20 changes: 1 addition & 19 deletions api/apps/document_app.py
@@ -25,7 +25,7 @@
from api.db import FileType
from api.db.db_models import Task
from api.db.services import duplicate_name
from api.db.services.document_service import DocumentService, doc_upload_and_parse
from api.db.services.document_service import DocumentService
from api.db.services.file2document_service import File2DocumentService
from api.db.services.file_service import FileService
from api.db.services.knowledgebase_service import KnowledgebaseService
@@ -481,24 +481,6 @@ async def get_artifact(filename):
return server_error_response(e)


@manager.route("/upload_and_parse", methods=["POST"]) # noqa: F821
@login_required
@validate_request("conversation_id")
async def upload_and_parse():
files = await request.files
if "file" not in files:
return get_json_result(data=False, message="No file part!", code=RetCode.ARGUMENT_ERROR)

file_objs = files.getlist("file")
for file_obj in file_objs:
if file_obj.filename == "":
return get_json_result(data=False, message="No file selected!", code=RetCode.ARGUMENT_ERROR)

form = await request.form
doc_ids = doc_upload_and_parse(form.get("conversation_id"), file_objs, current_user.id)
return get_json_result(data=doc_ids)


@manager.route("/parse", methods=["POST"]) # noqa: F821
@login_required
async def parse():
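
For reference, a client would have exercised the removed /upload_and_parse route roughly as sketched below. Only the "file" multipart field, the "conversation_id" form field, and the JSON result shape come from the removed handler; the base URL, route prefix, and auth header are assumptions, not part of this diff.

import requests

BASE_URL = "http://localhost:9380"  # assumption; not taken from this diff

with open("note.txt", "rb") as fh:
    resp = requests.post(
        f"{BASE_URL}/upload_and_parse",                # mount prefix of `manager` is assumed
        headers={"Authorization": "Bearer <token>"},   # @login_required implies some auth scheme
        files={"file": ("note.txt", fh)},              # field name checked by the handler
        data={"conversation_id": "conv-1"},            # required via @validate_request
    )
print(resp.json())  # e.g. {"code": 0, "data": ["<doc id>", ...]} on success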
145 changes: 2 additions & 143 deletions api/db/services/document_service.py
@@ -13,15 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
import asyncio
import json
import logging
import random
import re
from concurrent.futures import ThreadPoolExecutor
from copy import deepcopy
from datetime import datetime
from io import BytesIO

import xxhash
from peewee import fn, Case, JOIN
@@ -35,8 +29,8 @@
from api.db.services.doc_metadata_service import DocMetadataService
from common.misc_utils import get_uuid
from common.time_utils import current_timestamp, get_format_time
from common.constants import LLMType, ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME
from rag.nlp import rag_tokenizer, search
from common.constants import ParserType, StatusEnum, TaskStatus, SVR_CONSUMER_GROUP_NAME
from rag.nlp import search
from rag.utils.redis_conn import REDIS_CONN
from common.doc_store.doc_store_base import OrderByExpr
from common import settings
@@ -1025,138 +1019,3 @@ def get_queue_length(priority):
if not group_info:
return 0
return int(group_info.get("lag", 0) or 0)


def doc_upload_and_parse(conversation_id, file_objs, user_id):
from api.db.services.api_service import API4ConversationService
from api.db.services.conversation_service import ConversationService
from api.db.services.dialog_service import DialogService
from api.db.services.file_service import FileService
from api.db.services.llm_service import LLMBundle
from api.db.services.user_service import TenantService
from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type
from rag.app import audio, email, naive, picture, presentation

e, conv = ConversationService.get_by_id(conversation_id)
if not e:
e, conv = API4ConversationService.get_by_id(conversation_id)
assert e, "Conversation not found!"

e, dia = DialogService.get_by_id(conv.dialog_id)
if not dia.kb_ids:
raise LookupError("No dataset associated with this conversation. Please add a dataset before uploading documents")
kb_id = dia.kb_ids[0]
e, kb = KnowledgebaseService.get_by_id(kb_id)
if not e:
raise LookupError("Can't find this dataset!")
if kb.tenant_embd_id:
embd_model_config = get_model_config_by_id(kb.tenant_embd_id)
else:
embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id)
embd_mdl = LLMBundle(kb.tenant_id, embd_model_config, lang=kb.language)

err, files = FileService.upload_document(kb, file_objs, user_id)
assert not err, "\n".join(err)

def dummy(prog=None, msg=""):
pass

FACTORY = {ParserType.PRESENTATION.value: presentation, ParserType.PICTURE.value: picture, ParserType.AUDIO.value: audio, ParserType.EMAIL.value: email}
parser_config = {"chunk_token_num": 4096, "delimiter": "\n!?;。;!?", "layout_recognize": "Plain Text", "table_context_size": 0, "image_context_size": 0}
exe = ThreadPoolExecutor(max_workers=12)
threads = []
doc_nm = {}
for d, blob in files:
doc_nm[d["id"]] = d["name"]
for d, blob in files:
kwargs = {"callback": dummy, "parser_config": parser_config, "from_page": 0, "to_page": 100000, "tenant_id": kb.tenant_id, "lang": kb.language}
threads.append(exe.submit(FACTORY.get(d["parser_id"], naive).chunk, d["name"], blob, **kwargs))

for (docinfo, _), th in zip(files, threads):
docs = []
doc = {"doc_id": docinfo["id"], "kb_id": [kb.id]}
for ck in th.result():
d = deepcopy(doc)
d.update(ck)
d["id"] = xxhash.xxh64((ck["content_with_weight"] + str(d["doc_id"])).encode("utf-8")).hexdigest()
d["create_time"] = str(datetime.now()).replace("T", " ")[:19]
d["create_timestamp_flt"] = datetime.now().timestamp()
if not d.get("image"):
docs.append(d)
continue

output_buffer = BytesIO()
if isinstance(d["image"], bytes):
output_buffer = BytesIO(d["image"])
else:
d["image"].save(output_buffer, format="JPEG")

settings.STORAGE_IMPL.put(kb.id, d["id"], output_buffer.getvalue())
d["img_id"] = "{}-{}".format(kb.id, d["id"])
d.pop("image", None)
docs.append(d)

parser_ids = {d["id"]: d["parser_id"] for d, _ in files}
docids = [d["id"] for d, _ in files]
chunk_counts = {id: 0 for id in docids}
token_counts = {id: 0 for id in docids}
es_bulk_size = 64

def embedding(doc_id, cnts, batch_size=16):
nonlocal embd_mdl, chunk_counts, token_counts
vectors = []
for i in range(0, len(cnts), batch_size):
vts, c = embd_mdl.encode(cnts[i : i + batch_size])
vectors.extend(vts.tolist())
chunk_counts[doc_id] += len(cnts[i : i + batch_size])
token_counts[doc_id] += c
return vectors

idxnm = search.index_name(kb.tenant_id)
try_create_idx = True

_, tenant = TenantService.get_by_id(kb.tenant_id)
tenant_llm_config = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT)
llm_bdl = LLMBundle(kb.tenant_id, tenant_llm_config)
for doc_id in docids:
cks = [c for c in docs if c["doc_id"] == doc_id]

if parser_ids[doc_id] != ParserType.PICTURE.value:
from rag.graphrag.general.mind_map_extractor import MindMapExtractor

mindmap = MindMapExtractor(llm_bdl)
try:
mind_map = asyncio.run(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]))
mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
if len(mind_map) < 32:
raise Exception("Few content: " + mind_map)
cks.append(
{
"id": get_uuid(),
"doc_id": doc_id,
"kb_id": [kb.id],
"docnm_kwd": doc_nm[doc_id],
"title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", doc_nm[doc_id])),
"content_ltks": rag_tokenizer.tokenize("summary summarize 总结 概况 file 文件 概括"),
"content_with_weight": mind_map,
"knowledge_graph_kwd": "mind_map",
}
)
except Exception:
logging.exception("Mind map generation error")

vectors = embedding(doc_id, [c["content_with_weight"] for c in cks])
assert len(cks) == len(vectors)
for i, d in enumerate(cks):
v = vectors[i]
d["q_%d_vec" % len(v)] = v
for b in range(0, len(cks), es_bulk_size):
if try_create_idx:
if not settings.docStoreConn.index_exist(idxnm, kb_id):
settings.docStoreConn.create_idx(idxnm, kb_id, len(vectors[0]), kb.parser_id)
try_create_idx = False
settings.docStoreConn.insert(cks[b : b + es_bulk_size], idxnm, kb_id)

DocumentService.increment_chunk_num(doc_id, kb.id, token_counts[doc_id], chunk_counts[doc_id], 0)

return [d["id"] for d, _ in files]
@@ -315,19 +315,6 @@ def test_empty_upload_result(self, WebApiAuth, add_dataset_func, tmp_path):
# Just verify we get a response
assert "code" in res

def test_upload_and_parse_matrix_unit(self, document_app_module, monkeypatch):
module = document_app_module
monkeypatch.setattr(module, "request", _DummyRequest(form={"conversation_id": "conv-1"}, files=_DummyFiles({"file": [_DummyFile("")]})))
res = _run(module.upload_and_parse.__wrapped__())
assert res["code"] == module.RetCode.ARGUMENT_ERROR
assert res["message"] == "No file selected!"

files = _DummyFiles({"file": [_DummyFile("note.txt")]})
monkeypatch.setattr(module, "request", _DummyRequest(form={"conversation_id": "conv-1"}, files=files))
monkeypatch.setattr(module, "doc_upload_and_parse", lambda _conv_id, _files, _uid: ["doc-1"])
res = _run(module.upload_and_parse.__wrapped__())
assert res["code"] == 0
assert res["data"] == ["doc-1"]

def test_parse_url_and_multipart_matrix_unit(self, document_app_module, monkeypatch, tmp_path):
module = document_app_module