361 lines
12 KiB
Python
361 lines
12 KiB
Python
# -*- coding:utf-8 -*-
|
||
"""
|
||
@Author : xuxingchen
|
||
@Contact : xuxingchen@sinochem.com
|
||
@Desc : 杂项
|
||
"""
|
||
import base64
|
||
import hashlib
|
||
import os
|
||
import time
|
||
import warnings
|
||
from io import BytesIO
|
||
import re
|
||
import random
|
||
import string
|
||
import threading
|
||
from datetime import datetime, timedelta
|
||
import socket
|
||
import psutil
|
||
from typing import Optional
|
||
|
||
import numpy as np
|
||
import pandas as pd
|
||
import pytz
|
||
import requests
|
||
from PIL import Image
|
||
import paho.mqtt.client as mqtt
|
||
from openpyxl.styles import PatternFill, Font, Border, Side
|
||
from openpyxl.utils.dataframe import dataframe_to_rows
|
||
from openpyxl.workbook import Workbook
|
||
from pydantic import BaseModel
|
||
|
||
from utils import logger
|
||
|
||
|
||
def get_ip_address(interface_name: str) -> tuple[str, str]:
    """Look up the IPv4 and IPv6 addresses assigned to a network interface.

    Args:
        interface_name: Interface name as keyed by ``psutil.net_if_addrs()``.

    Returns:
        ``(ipv4, ipv6)``. An element is ``""`` when the interface is unknown
        or has no address of that family. IPv6 addresses starting with
        ``fe80`` are skipped. When several addresses of one family are
        listed, the last one wins.
    """
    ipv4_address = ""
    ipv6_address = ""
    # An unknown interface yields None, which collapses to an empty loop.
    for entry in psutil.net_if_addrs().get(interface_name) or []:
        if entry.family == socket.AF_INET:
            ipv4_address = entry.address
            continue
        if entry.family == socket.AF_INET6 and not entry.address.startswith("fe80"):
            ipv6_address = entry.address
    return ipv4_address, ipv6_address
|
||
|
||
def extract_fixed_length_number(number_str: str, fixed_length: int = 2) -> str:
    """Return the first run of digits in a string as a fixed-width string.

    The first run of ASCII digits is taken, then its right-most
    ``fixed_length`` characters are kept; shorter runs are left-padded
    with ``"0"``.

    Args:
        number_str: Text that may contain a number.
        fixed_length: Width of the result.

    Returns:
        A string of exactly ``fixed_length`` digits. When no digits are found
        the result is all zeros. A non-positive ``fixed_length`` gives ``""``.
    """
    # Guard first: ``s[-0:]`` is the whole string, so a width of 0 would
    # otherwise leak the complete number instead of an empty result.
    if fixed_length <= 0:
        return ""
    found = re.search(r'[0-9]+', number_str)
    digits = found[0] if found else ""
    return digits.zfill(fixed_length)[-fixed_length:]
|
||
|
||
|
||
def snake2camel(key: str) -> str:
    """Convert a snake_case name to camelCase.

    The segment before the first underscore is kept as-is. Every following
    segment goes through ``str.capitalize`` (first letter upper-cased, the
    rest lower-cased), then all segments are joined.
    """
    head, *tail = key.split('_')
    return head + ''.join(map(str.capitalize, tail))
|
||
|
||
|
||
def snake2camel_list_dict(snake_list: list[dict]) -> list[dict]:
    """Return a copy of ``snake_list`` with every dict key converted to camelCase.

    Keys are converted with ``snake2camel``; values are carried over
    unchanged. The input list and its dicts are not modified.
    """
    return [
        {snake2camel(name): item for name, item in record.items()}
        for record in snake_list
    ]
|
||
|
||
|
||
def clear_log_file(log_path: str, day: int = 7):
    """Rotate a log file once it is at least ``day`` days old.

    The file is moved to ``<log_path>.old``, replacing any previous ``.old``
    file, so the next write to ``log_path`` starts a fresh file.

    Args:
        log_path: Path of the log file to rotate.
        day: Minimum age in days, measured from ``os.path.getctime``, before
            the file is rotated.

    Returns:
        None. Success or failure of the rotation is reported via ``print``.
        A missing ``log_path`` is silently ignored.
    """
    try:
        # NOTE: on Unix ``getctime`` is the last metadata change, not creation.
        creation_time = os.path.getctime(log_path)
    except OSError:
        # No file (or not accessible): nothing to rotate.
        return
    days_since_creation = (time.time() - creation_time) / (60 * 60 * 24)
    if days_since_creation < day:
        return
    try:
        # A rename avoids decoding the whole log into memory and leaves no
        # window where the content exists in neither file.
        os.replace(log_path, f"{log_path}.old")
        print(f"日志文件 {log_path} 完成重置")
    except OSError as e:
        print(f"日志文件 {log_path} 重置失败: {e}")
|
||
|
||
|
||
def generate_captcha_text(characters: int = 6) -> str:
    """Generate a random captcha string of ASCII letters and digits.

    Args:
        characters: Number of characters to generate.

    Returns:
        A string of ``characters`` symbols drawn from ``a-z``, ``A-Z`` and
        ``0-9``; empty when ``characters`` is not positive.
    """
    # Function-scope import: ``secrets`` is not imported at file level.
    # A captcha must not be guessable, so draw from the OS CSPRNG rather than
    # the predictable Mersenne Twister behind ``random``.
    import secrets

    alphabet = string.ascii_letters + string.digits
    return ''.join(secrets.choice(alphabet) for _ in range(characters))
|
||
|
||
|
||
def encrypt_number(phone_number: str, key: bytes = b'7A') -> str:
    """Obfuscate a phone number into a Base64 string.

    The UTF-8 bytes of ``phone_number`` are XOR-ed with ``key``, which is
    repeated cyclically, and the result is Base64-encoded.
    ``decrypt_number`` with the same key reverses the operation.

    Args:
        phone_number: Text to encode.
        key: XOR key bytes, repeated over the input.

    Returns:
        The Base64 text of the XOR-ed bytes.
    """
    raw = phone_number.encode('utf-8')
    key_len = len(key)
    masked = bytearray(len(raw))
    for pos, value in enumerate(raw):
        masked[pos] = value ^ key[pos % key_len]
    return base64.b64encode(bytes(masked)).decode('utf-8')
|
||
|
||
|
||
def decrypt_number(encrypted_number: str, key: bytes = b'7A') -> str:
    """Recover the phone number produced by ``encrypt_number``.

    The input is Base64-decoded, each byte is XOR-ed with ``key`` (repeated
    cyclically), and the result is decoded as UTF-8.

    Args:
        encrypted_number: Base64 text to decode.
        key: XOR key bytes; must match the key used for encoding.

    Returns:
        The decoded text.
    """
    cipher = base64.b64decode(encrypted_number)
    key_len = len(key)
    plain = bytearray()
    for pos, value in enumerate(cipher):
        plain.append(value ^ key[pos % key_len])
    return plain.decode('utf-8')
|
||
|
||
|
||
def now_tz_datetime(days: int = 0) -> str:
    """Format the current time, shifted by ``days``, as ``YYYY-MM-DDTHH:MM:SS.mmmZ``.

    The value comes from the naive ``datetime.now()``; the trailing ``Z`` is
    appended literally and no timezone conversion is performed.

    Args:
        days: Day offset added to the current time (negative for the past).
    """
    moment = datetime.now() + timedelta(days=days)
    millis = moment.microsecond // 1000
    return f"{moment:%Y-%m-%dT%H:%M:%S.}{millis:03d}Z"
|
||
|
||
|
||
def now_datetime_nanosecond(days: int = 0) -> str:
    """Format the current local time, shifted by ``days``, as ``YYYY-MM-DD HH:MM:SS.ffffff``.

    Despite the name, the fractional part is the six-digit microsecond field
    produced by ``%f``.

    Args:
        days: Day offset added to the current time (negative for the past).
    """
    moment = datetime.now() + timedelta(days=days)
    return f"{moment:%Y-%m-%d %H:%M:%S.%f}"
|
||
|
||
|
||
def now_datetime_second(days: int = 0) -> str:
    """Format the current local time, shifted by ``days``, as ``YYYY-MM-DD HH:MM:SS``.

    Args:
        days: Day offset added to the current time (negative for the past).
    """
    moment = datetime.now() + timedelta(days=days)
    return f"{moment:%Y-%m-%d %H:%M:%S}"
|
||
|
||
|
||
def millisecond_timestamp2tz(timestamp_13: str):
    """Convert a millisecond Unix timestamp to ``YYYY-MM-DDTHH:MM:SS.mmmZ`` text.

    The date and time fields are rendered in the ``Asia/Shanghai`` timezone;
    the trailing ``Z`` is appended literally.

    Args:
        timestamp_13: Millisecond timestamp, as a string or any value
            accepted by ``int()``.
    """
    seconds = int(timestamp_13) / 1000
    utc_moment = datetime.fromtimestamp(seconds, tz=pytz.UTC)
    # Render the wall-clock fields in China Standard Time.
    shanghai_moment = utc_moment.astimezone(pytz.timezone('Asia/Shanghai'))
    millis = utc_moment.microsecond // 1000
    return f"{shanghai_moment:%Y-%m-%dT%H:%M:%S.}{millis:03d}Z"
|
||
|
||
|
||
def is_image_url_valid(url: str, timeout: float = 10.0) -> bool:
    """Check that a URL can be downloaded and holds a readable image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` can block indefinitely on a stalled host.

    Returns:
        True when the download succeeds and Pillow accepts the payload as an
        image; False on any request failure or unreadable image data.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()  # raises on 4xx/5xx responses
        # verify() checks file integrity without decoding the pixel data.
        Image.open(BytesIO(response.content)).verify()
    except (requests.RequestException, OSError, SyntaxError):
        # Some Pillow plugins report corrupt data with SyntaxError rather
        # than OSError, so both count as "not a valid image".
        return False
    return True
|
||
|
||
|
||
def is_image_valid(path: str) -> bool:
    """Check that a local file exists and holds a readable image.

    Args:
        path: Filesystem path of the image.

    Returns:
        True when the file opens and Pillow accepts it as an image; False
        when it is missing, unreadable, or not valid image data.
    """
    try:
        # The context manager closes the handle even when verification fails.
        with open(path, 'rb') as image_file:
            # verify() checks file integrity without decoding the pixel data.
            Image.open(image_file).verify()
    except (OSError, SyntaxError):
        # Some Pillow plugins report corrupt data with SyntaxError rather
        # than OSError, so both count as "not a valid image".
        return False
    return True
|
||
|
||
|
||
def get_file_md5(file_path):
    """Return the hexadecimal MD5 digest of a file's contents.

    Args:
        file_path: Path of the file to hash.

    Returns:
        The 32-character lowercase hex digest.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    digest = hashlib.md5()
    # Hash in 1 MiB chunks so large files are never held in memory at once;
    # the context manager guarantees the handle is closed.
    with open(file_path, 'rb') as stream:
        for chunk in iter(lambda: stream.read(1024 * 1024), b""):
            digest.update(chunk)
    return digest.hexdigest()
|
||
|
||
|
||
def sql_export_xls(query,
                   db_connection,
                   save_file_path,
                   sheet_title,
                   sheet_header: Optional[list] = None,
                   header_background_color: str = "808080",
                   header_font_color: str = "ffffff"):
    """Run a SQL query and save the result as a styled Excel workbook.

    Args:
        query: SQL passed to ``pandas.read_sql_query``.
        db_connection: Connection object passed to ``pandas.read_sql_query``.
        save_file_path: Destination passed to ``Workbook.save``.
        sheet_title: Title given to the single worksheet.
        sheet_header: Optional replacement for the header row; when omitted
            the query's column names are used.
        header_background_color: Fill colour of the header row (hex RGB).
        header_font_color: Font colour of the header row (hex RGB).
    """
    frame = pd.read_sql_query(query, db_connection)
    workbook = Workbook()
    sheet = workbook.active
    sheet.title = sheet_title

    # Row 0 of the generator is the DataFrame header; swap in the custom one
    # when the caller supplied it.
    for position, values in enumerate(dataframe_to_rows(frame, index=False, header=True)):
        use_custom_header = position == 0 and sheet_header is not None
        sheet.append(sheet_header if use_custom_header else values)

    # Header style: solid background with bold coloured text.
    fill = PatternFill(start_color=header_background_color, fill_type="solid")
    font = Font(color=header_font_color, bold=True)
    for header_cell in sheet[1]:
        header_cell.fill = fill
        header_cell.font = font

    # Data rows get left/right/bottom borders; the header row gets none.
    edge = Side(style='thin')
    data_border = Border(left=edge, right=edge, bottom=edge)
    remaining_rows = iter(sheet.iter_rows())
    next(remaining_rows, None)  # skip the header row
    for data_row in remaining_rows:
        for data_cell in data_row:
            data_cell.border = data_border

    # Auto-fit column widths; each CJK character counts as two columns.
    cjk = re.compile(r'[\u4e00-\u9fff]')
    for column_cells in sheet.columns:
        column_cells = list(column_cells)
        widest = 0
        for cell in column_cells:
            if cell.value is None:
                continue
            text = str(cell.value)
            widest = max(widest, len(text) + len(cjk.findall(text)))
        sheet.column_dimensions[column_cells[0].column_letter].width = widest + 2

    workbook.save(save_file_path)
|
||
|
||
|
||
def valid_xls(file: BytesIO, required_columns: Optional[list]) -> tuple[bool, str, Optional[list]]:
    """Parse an Excel file and check that it contains the required columns.

    Args:
        file: Excel content handed to ``pandas.read_excel``.
        required_columns: Column names that must all be present.

    Returns:
        ``(ok, message, rows)``. On success ``ok`` is True, ``message`` is
        empty and ``rows`` holds the required columns as a list of row lists,
        with NaN replaced by None. On failure ``ok`` is False, ``message``
        names the missing columns or the parsing error, and ``rows`` is None.
    """
    with warnings.catch_warnings():
        # Suppress warnings raised while reading the workbook.
        warnings.simplefilter("ignore")
        try:
            frame = pd.read_excel(file).replace(np.nan, None)
            absent = [name for name in required_columns if name not in frame.columns]
            if absent:
                return False, f"缺少必要的列: {', '.join(absent)}", None
            return True, "", frame[required_columns].values.tolist()
        except Exception as e:
            return False, f"文件解析失败 {type(e).__name__}, {e}", None
|
||
|
||
|
||
class BasicCallback(BaseModel):
    # Pydantic model with a boolean flag and an optional text message.
    # presumably used as a generic success/failure response body — confirm with callers
    status: bool
    # NOTE(review): whether this field may be omitted depends on the installed
    # pydantic major version (Optional without a default) — confirm.
    message: Optional[str]
|
||
|
||
|
||
class InvalidException(Exception):
    """Exception that carries a human-readable ``message``.

    Attributes:
        message: The text supplied by the raiser.
    """

    def __init__(self, message: str):
        # Forward to Exception so ``str(exc)`` and ``exc.args`` are populated
        # even when the message is passed as a keyword argument.
        super().__init__(message)
        self.message = message
|
||
|
||
|
||
class UserData:
    """Mutable state container whose setters are guarded by one lock.

    ``create_mqtt_client`` attaches an instance to the MQTT client as its
    userdata. The callbacks in this module read ``topics``, ``topic`` and
    ``message`` from it.
    """

    def __init__(self):
        # Both spellings exist as separate attributes, each with its own setter.
        self.table_handle = None
        self.table_handler = None
        self.topic: Optional[str] = None  # read by on_publish for logging
        self.topics: list = []            # subscribed to in on_connect
        self.message = None               # read by on_publish for logging
        self.token = None
        self.status: dict = {}
        self.clients: dict = {}
        self.lock = threading.Lock()  # serialises every setter below

    def set_table_handle(self, value):
        """Replace ``table_handle`` under the lock."""
        with self.lock:
            self.table_handle = value

    def set_topic(self, value: str):
        """Replace ``topic`` under the lock."""
        with self.lock:
            self.topic = value

    def set_topics(self, value: list):
        """Replace ``topics`` under the lock."""
        with self.lock:
            self.topics = value

    def set_table_handler(self, value):
        """Replace ``table_handler`` under the lock."""
        with self.lock:
            self.table_handler = value

    def set_message(self, value):
        """Replace ``message`` under the lock."""
        with self.lock:
            self.message = value

    def set_token(self, value):
        """Replace ``token`` under the lock."""
        with self.lock:
            self.token = value

    def set_status(self, value: dict):
        """Replace the whole ``status`` mapping under the lock."""
        with self.lock:
            self.status = value

    def set_status_add(self, key, value):
        """Insert or overwrite one ``status`` entry under the lock."""
        with self.lock:
            self.status[key] = value

    def set_status_remove(self, key):
        """Remove one ``status`` entry under the lock; a missing key is ignored."""
        with self.lock:
            if self.status:
                self.status.pop(key, None)

    def get_status(self, key):
        """Return the ``status`` value for ``key``, or None when absent.

        The lock is not taken here. A single ``get`` on one reference is used
        instead of check-then-index, so a concurrent ``set_status_remove``
        cannot make this raise ``KeyError``.
        """
        current = self.status
        return current.get(key) if current else None

    def set_clients(self, value: dict):
        """Replace the whole ``clients`` mapping under the lock."""
        with self.lock:
            self.clients = value

    def set_client_add(self, key, value):
        """Insert or overwrite one ``clients`` entry under the lock."""
        with self.lock:
            self.clients[key] = value
|
||
|
||
|
||
def create_mqtt_client(broker_host,
                       broker_port,
                       userdata: UserData,
                       on_message=None,
                       on_publish=None,
                       on_connect=None,
                       on_disconnect=None,
                       client_id: str = "",
                       username: str = "",
                       password: str = ""):
    """Build a paho MQTT client, wire its callbacks and connect to the broker.

    Args:
        broker_host: Host passed to ``client.connect``.
        broker_port: Port passed to ``client.connect``.
        userdata: State object attached with ``user_data_set``.
        on_message: Optional message callback.
        on_publish: Optional publish callback.
        on_connect: Optional connect callback.
        on_disconnect: Optional disconnect callback.
        client_id: Client identifier; when empty the library default is used.
        username: Username passed to ``username_pw_set``.
        password: Password passed to ``username_pw_set``.

    Returns:
        The client after ``connect`` has been called.
    """
    client = mqtt.Client(client_id=client_id) if client_id != "" else mqtt.Client()
    client.user_data_set(userdata)

    # Only install the callbacks the caller actually supplied.
    handlers = (
        ("on_connect", on_connect),
        ("on_disconnect", on_disconnect),
        ("on_message", on_message),
        ("on_publish", on_publish),
    )
    for attribute, handler in handlers:
        if handler:
            setattr(client, attribute, handler)

    client.username_pw_set(username, password)
    client.connect(broker_host, broker_port)
    return client
|
||
|
||
|
||
def on_connect(client, userdata, flags, rc):
    """MQTT connect callback: log the result code and subscribe to the stored topics.

    Every entry of ``userdata.topics`` is subscribed with QoS 0. Nothing is
    subscribed when the list is empty.
    """
    banner = logger.new_dc(f"🔗 Mqtt connection! {{rc: {rc}}} 🔗", '[1;32m')
    logger.Logger.init(banner)
    if not userdata.topics:
        return
    client.subscribe([(name, 0) for name in userdata.topics])
    logger.Logger.debug(f"subscribe topics: {userdata.topics}")
|
||
|
||
|
||
def on_disconnect(client, userdata, rc):
    """MQTT disconnect callback: log the disconnection and its result code."""
    notice = logger.new_dc(f"🔌 Break mqtt connection! {{rc: {rc}}} 🔌", "[1m")
    logger.Logger.info(notice)
|
||
|
||
|
||
def on_publish(client, userdata, rc):
    """MQTT publish callback: log the topic and message held in ``userdata``.

    NOTE(review): the third argument is unused here; confirm what the MQTT
    library passes in this position.
    """
    outgoing = f"{userdata.topic} <- {userdata.message}"
    logger.Logger.debug(outgoing)
|