| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194 |
# --- standard library (no side effects; safe to import before the chdir) ---
import logging
import math
import os
import re
import sys
import threading
import time
from hashlib import md5
from typing import Dict, Iterable, List, Optional, Union

import loguru

# Move to the parent of the script's directory before the remaining imports.
# NOTE(review): the fixed [:-8] slice assumes the script sits in an
# 8-character trailing path component such as "/plugins" — confirm before
# relocating this file.
os.chdir(sys.path[0][:-8])

# --- third-party: kept after the chdir, in the original order ---
from langchain.text_splitter import CharacterTextSplitter
from qdrant_client import QdrantClient
from qdrant_client.http import models as rest
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings

# Root logger: only ERROR and above. Configured after the langchain /
# qdrant_client imports and before chardet / pdfplumber, as before.
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.ERROR)

import chardet
import pdfplumber

# --- project-local ---
from qdrant import Qdrant
from common import CounterLock, settings, error_print, error_helper, success_print
# Ingestion state shared with make_index(): documents buffered since the last
# flush, and the running total of chunks produced so far.
docs = []
texts_count = 0
# Type alias for metadata filters (not referenced in the visible code).
MetadataFilter = Dict[str, Union[str, int, bool]]

# Folder holding the raw corpus, resolved against the current working
# directory (which the bootstrap code changed with os.chdir).
source_folder = settings.librarys.qdrant.path
source_folder_path = os.path.join(os.getcwd(), source_folder)
# The folder path broken into its components.
root_path_list = source_folder_path.split(os.sep)

COLLECTION_NAME = settings.librarys.qdrant.collection  # vector collection name
model_path = settings.librarys.qdrant.model_path  # embedding model name / path
# Load the HuggingFace embedding model; batch size and device come from the
# qdrant settings. On failure, point the user at the model-download
# instructions and re-raise.
try:
    encode_kwargs = dict(batch_size=settings.librarys.qdrant.batch_size)
    model_kwargs = dict(device=settings.librarys.qdrant.device)
    embedding = HuggingFaceEmbeddings(
        model_name=model_path,
        encode_kwargs=encode_kwargs,
        model_kwargs=model_kwargs,
    )
except Exception:
    error_helper("embedding加载失败,请下载相应模型",
                 r"https://github.com/l15y/wenda#st%E6%A8%A1%E5%BC%8F")
    raise
else:
    success_print("Embedding model加载完成")
# Reopen an existing collection from local storage if there is one; otherwise
# leave `vectorstore` as None so that the first embedding batch creates it.
# NOTE(review): the storage path "memory/q" is hard-coded here, while the
# collection is created from settings.librarys.qdrant.qdrant_path /
# qdrant_host in clac_embedding — confirm both refer to the same store.
client = None  # pre-bound so the cleanup below cannot hit NameError
try:
    client = QdrantClient(path="memory/q")
    # Presumably raises when the collection does not exist yet — confirm.
    client.get_collection(COLLECTION_NAME)
    vectorstore = Qdrant(client, COLLECTION_NAME, embedding)
except Exception:
    # `except Exception` rather than a bare `except:` so Ctrl-C / sys.exit are
    # not swallowed. Drop the probing client (presumably so the same local
    # storage can be opened again later — confirm) and start from scratch.
    del client
    vectorstore = None
# Serialises calls to the embedding model across worker threads. CounterLock is
# a project type; its get_waiting_threads() is polled to throttle new batches.
embedding_lock = CounterLock()
# Guards creation of, and writes to, the module-level `vectorstore`.
vectorstore_lock = threading.Lock()
def clac_embedding(texts, embedding, metadatas):
    """Embed one batch of chunks and write it to the Qdrant vector store.

    Runs on a worker thread started by make_index. The first batch creates the
    collection with ``Qdrant.from_texts`` and stores the handle in the
    module-level ``vectorstore``; later batches are appended with ``add_texts``.

    Args:
        texts: chunk strings to embed.
        embedding: the embeddings model; ``embed_documents(texts)`` is called on it.
        metadatas: one metadata dict per chunk, each with a ``"source"`` key
            (consumed by gen_ids).
    """
    global vectorstore
    if not texts:
        # Nothing to store; an empty first batch cannot create a collection.
        return
    # Only one thread runs the embedding model at a time.
    with embedding_lock:
        embeddings = embedding.embed_documents(texts)
    with vectorstore_lock:
        ids = gen_ids(metadatas)
        if vectorstore is not None:
            vectorstore.add_texts(texts, embeddings, ids, metadatas)
            return
        qdrant_cfg = settings.librarys.qdrant
        # A local storage path takes precedence over a remote host.
        if qdrant_cfg.qdrant_path:
            location = {"path": qdrant_cfg.qdrant_path}
        elif qdrant_cfg.qdrant_host:
            location = {"url": qdrant_cfg.qdrant_host}
        else:
            # Previously this batch was dropped silently; make it visible.
            error_print("未配置qdrant_path或qdrant_host,本批数据未写入向量库")
            return
        # prefer_grpc=True is recommended for large-scale inserts.
        vectorstore = Qdrant.from_texts(
            texts, embedding, embeddings, ids,
            metadatas=metadatas,
            prefer_grpc=True,
            collection_name=qdrant_cfg.collection,
            timeout=10,
            **location,
        )
# This id-generation scheme is only a suggestion.
def gen_ids(metadatas):
    """Derive a deterministic UUID-formatted point id for every chunk.

    Each id combines the MD5 hex digest of ``metadata["source"]`` with a
    per-source chunk counter written as (at least) three hex digits, so the
    same input always yields the same ids.

    Args:
        metadatas: sequence of metadata dicts, one per chunk, each with a
            ``"source"`` string.

    Returns:
        A list of ``xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx`` strings, one per entry.
    """
    ids = []
    chunk_counts = {}  # source digest -> chunks of that source numbered so far
    for metadata in metadatas:
        digest = md5(metadata["source"].encode("utf-8")).hexdigest()
        # Count per source, not per consecutive run: resetting the counter
        # whenever the source changed gave a source that reappears later in
        # the batch (e.g. same file name in two folders) duplicate ids.
        index = chunk_counts.get(digest, 0)
        chunk_counts[digest] = index + 1
        # 30 digest chars + hex counter; the slices keep 32 hex digits
        # (29 from the digest, 3 from the counter while it is below 0x1000).
        origin = digest[:30] + format(index, "x").zfill(3)
        ids.append(
            f"{origin[:8]}-{origin[8:12]}-{origin[12:16]}-{origin[16:20]}-{origin[-12:]}"
        )
    return ids
def make_index():
    """Split the buffered ``docs`` into chunks and embed them on a worker thread.

    Clears ``docs``, adds the number of chunks to ``texts_count`` and starts
    ``clac_embedding`` on a new thread. Chunk size / overlap come from
    ``settings.librarys.qdrant.size`` / ``overlap`` when both are present,
    otherwise 20 / 0. Afterwards it sleeps while
    ``embedding_lock.get_waiting_threads()`` reports more than one waiter
    (presumably to cap how many batches are queued — confirm).
    """
    global docs, texts_count
    qdrant_cfg = settings.librarys.qdrant
    if hasattr(qdrant_cfg, "size") and hasattr(qdrant_cfg, "overlap"):
        chunk_size, chunk_overlap = int(qdrant_cfg.size), int(qdrant_cfg.overlap)
    else:
        chunk_size, chunk_overlap = 20, 0
    text_splitter = CharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap, separator='\n')
    doc_texts = text_splitter.split_documents(docs)
    docs = []
    texts = [d.page_content for d in doc_texts]
    metadatas = [d.metadata for d in doc_texts]
    texts_count += len(texts)
    if not texts:
        # Nothing to embed (every buffered document was empty); don't start a
        # thread that would hand an empty batch to the vector store.
        return
    threading.Thread(target=clac_embedding, args=(texts, embedding, metadatas)).start()
    while embedding_lock.get_waiting_threads() > 1:
        time.sleep(0.1)
# Every [directory, filename] pair below the corpus folder, in os.walk order.
all_files = [
    [dirpath, filename]
    for dirpath, _subdirs, filenames in os.walk(source_folder_path)
    for filename in filenames
]
success_print("文件列表生成完成", len(all_files))
def _read_source_file(file_path):
    """Return the text of a .pdf or .txt file, or "" for other extensions.

    PDF pages are joined with newlines; .txt files are decoded with the
    encoding chardet guesses from the raw bytes. Read errors propagate.
    """
    _, ext = os.path.splitext(file_path)
    if ext.lower() == '.pdf':
        with pdfplumber.open(file_path) as pdf:
            # extract_text() is called once per page. It may yield None for a
            # page without a text layer (pdfplumber-version dependent); treat
            # that as "" so join() cannot raise and discard the whole PDF.
            return "\n".join(page.extract_text() or "" for page in pdf.pages)
    if ext.lower() == '.txt':
        with open(file_path, 'rb') as f:
            print("open:", file_path)
            raw = f.read()
        encoding = chardet.detect(raw)['encoding']
        with open(file_path, 'r', encoding=encoding) as f:
            return f.read()
    print("目前还不支持文件格式:", ext)
    return ""


def _normalize_text(data):
    """Add a newline after sentence punctuation and collapse blank lines.

    make_index splits on '\\n', so these newlines become split points.
    """
    data = re.sub(r'!', "!\n", data)
    data = re.sub(r':', ":\n", data)
    data = re.sub(r'。', "。\n", data)
    data = re.sub(r'\r', "\n", data)
    data = re.sub(r'\n\n', "\n", data)
    data = re.sub(r"\n\s*\n", "\n", data)
    return data


length_of_read = 0
total_files = len(all_files)
for i, (root, file) in enumerate(all_files):
    try:
        data = _read_source_file(os.path.join(root, file))
    except Exception as e:
        print("文件读取失败,当前文件已被跳过:", file, "。错误信息:", e)
        data = ""
    data = _normalize_text(data)
    if not data.strip():
        # Unsupported, unreadable or empty file: nothing to index.
        continue
    # Buffer each file exactly once; appending it twice stored every chunk twice.
    length_of_read += len(data)
    docs.append(Document(page_content=data, metadata={"source": file}))
    # Flush once more than 100k characters are buffered instead of holding
    # everything for one final pass.
    if length_of_read > 1e5:
        success_print("处理进度", int(100 * i / total_files), f"%\t({i}/{total_files})")
        make_index()
        length_of_read = 0
# No files were found under the configured folder: report it and stop (exit status 0).
if not all_files:
    error_print(f"指定目录{settings.librarys.qdrant.path}没有数据")
    sys.exit(0)
# Embed whatever is still buffered after the last file.
if docs:
    make_index()
# Wait until no thread is queued on embedding_lock, then take the lock once so
# the batch still being embedded has finished. The short sleep while holding it
# appears meant to let that worker reach vectorstore_lock before the index
# update that follows. NOTE(review): timing-based, not a guarantee.
while embedding_lock.get_waiting_threads() > 0:
    time.sleep(0.1)
with embedding_lock:
    time.sleep(0.1)
success_print("数据上装完成")
with vectorstore_lock:
    if vectorstore is None:
        # No batch ever created the collection (e.g. neither qdrant_path nor
        # qdrant_host is configured, or every file was empty). Without this
        # check the call below fails with AttributeError on None.
        error_print("向量库未创建,无法构建索引")
        sys.exit(1)
    print("开始构建索引,需要一定时间")
    # Set the collection's optimizer indexing_threshold to 20000 (see the
    # Qdrant OptimizersConfigDiff docs for the unit in the deployed version).
    vectorstore.client.update_collection(
        collection_name=COLLECTION_NAME,
        optimizer_config=rest.OptimizersConfigDiff(
            indexing_threshold=20000
        )
    )
success_print("索引处理完成")
|