# common.py (4.0 KB)
  1. from fastapi import FastAPI, WebSocket, WebSocketDisconnect
  2. from bottle import route, response, request, static_file, hook
  3. import threading
  4. import webbrowser
  5. import re
  6. import json
  7. from yaml import load, dump
  8. try:
  9. from yaml import CLoader as Loader, CDumper as Dumper
  10. except ImportError:
  11. from yaml import Loader, Dumper
  12. import argparse
  13. parser = argparse.ArgumentParser(description='Wenda config')
  14. parser.add_argument('-c', type=str, dest="Config",
  15. default='config.yml', help="配置文件")
  16. parser.add_argument('-p', type=int, dest="Port", help="使用端口号")
  17. parser.add_argument('-l', type=bool, dest="Logging", help="是否开启日志")
  18. parser.add_argument('-t', type=str, dest="LLM_Type", help="选择使用的大模型")
  19. args = parser.parse_args()
  20. class dotdict(dict):
  21. __getattr__ = dict.get
  22. __setattr__ = dict.__setitem__
  23. __delattr__ = dict.__delitem__
  24. def object_hook(dict1):
  25. for key, value in dict1.items():
  26. if isinstance(value, dict):
  27. dict1[key] = dotdict(value)
  28. else:
  29. dict1[key] = value
  30. return dotdict(dict1)
  31. green = "\033[1;32m"
  32. red = "\033[1;31m"
  33. white = "\033[1;37m"
  34. def error_helper(e, doc_url):
  35. error_print(e)
  36. error_print("查看:", doc_url)
  37. # webbrowser.open_new(doc_url)
  38. def error_print(*s):
  39. print(red, end="")
  40. print(*s)
  41. print(white, end="")
  42. def success_print(*s):
  43. print(green, end="")
  44. print(*s)
  45. print(white, end="")
  46. wenda_Config = args.Config
  47. wenda_Port = str(args.Port)
  48. wenda_Logging = str(args.Logging)
  49. wenda_LLM_Type = str(args.LLM_Type)
  50. print(args)
  51. try:
  52. stream = open(wenda_Config, encoding='utf8')
  53. except:
  54. error_print('加载配置失败,改为加载默认配置')
  55. stream = open('example.config.yml', encoding='utf8')
  56. settings = load(stream, Loader=Loader)
  57. settings = dotdict(settings)
  58. stream.close()
  59. if wenda_Port != 'None':
  60. settings.port = wenda_Port
  61. if wenda_Logging != 'None':
  62. settings.logging = wenda_Logging
  63. if wenda_LLM_Type != 'None':
  64. settings.llm_type = wenda_LLM_Type
  65. try:
  66. settings.llm = settings.llm_models[settings.llm_type]
  67. except:
  68. error_print("没有读取到LLM参数,可能是因为当前模型为API调用。")
  69. del settings.llm_models
  70. settings_str_toprint = dump(dict(settings))
  71. settings_str_toprint = re.sub(r':', ":"+"\033[1;32m", settings_str_toprint)
  72. settings_str_toprint = re.sub(r'\n', "\n\033[1;31m", settings_str_toprint)
  73. print("\033[1;31m", end="")
  74. print(settings_str_toprint, end="")
  75. print("\033[1;37m")
  76. settings_str = json.dumps(settings)
  77. settings = json.loads(settings_str, object_hook=object_hook)
  78. class CounterLock:
  79. def __init__(self):
  80. self.lock = threading.Lock()
  81. self.waiting_threads = 0
  82. self.waiting_threads_lock = threading.Lock()
  83. def acquire(self):
  84. with self.waiting_threads_lock:
  85. self.waiting_threads += 1
  86. acquired = self.lock.acquire()
  87. def release(self):
  88. self.lock.release()
  89. with self.waiting_threads_lock:
  90. self.waiting_threads -= 1
  91. def get_waiting_threads(self):
  92. with self.waiting_threads_lock:
  93. return self.waiting_threads
  94. def __enter__(self): # 实现 __enter__() 方法,用于在 with 语句的开始获取锁
  95. self.acquire()
  96. return self
  97. def __exit__(self, exc_type, exc_val, exc_tb): # 实现 __exit__() 方法,用于在 with 语句的结束释放锁
  98. self.release()
  99. def allowCROS():
  100. response.set_header('Access-Control-Allow-Origin', '*')
  101. response.add_header('Access-Control-Allow-Methods', 'POST,OPTIONS')
  102. response.add_header('Access-Control-Allow-Headers',
  103. 'Origin, Accept, Content-Type, X-Requested-With, X-CSRF-Token')
  104. app = FastAPI(title="Wenda",
  105. description="Wenda API",
  106. version="1.0.0",
  107. # docs_url=None,
  108. # redoc_url=None,
  109. openapi_url="/api/v1/openapi.json",
  110. docs_url="/api/v1/docs",
  111. redoc_url="/api/v1/redoc")
粤ICP备19079148号