Import the CSV data produced in the previous article into a MySQL database.
Suitable for synchronizing data between different database versions across different environments.
The script records its progress as it imports and supports resuming from the last checkpoint; each table gets its own log file and progress marker, so you can monitor progress by watching the log files while the import is running.
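Each progress marker is a small JSON file that maps the CSV file name to the number of rows already committed, so the current position can be inspected at any time from a separate Python shell. A minimal sketch of reading one (the table name users is a hypothetical example):

import json

with open('import_progress_users.json') as f:   # hypothetical table name
    print(json.load(f))                         # e.g. {'users.csv': 123000} -- illustrative contents

The complete script: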
import json
import os
import subprocess
import logging
from multiprocessing import Process

import pandas as pd
import mysql.connector
from clickhouse_driver import Client
def setup_logging(log_file):
    # Configure logging
    logging.basicConfig(
        filename=log_file,       # log file path
        level=logging.INFO,      # log level
        format='%(asctime)s - %(levelname)s - %(message)s'  # log format
    )
def load_import_progress(progress_file):
if os.path.exists(progress_file):
with open(progress_file, 'r') as f:
return json.load(f)
return {}
def update_import_progress(progress_file, file_name, row_count):
progress = load_import_progress(progress_file)
progress[file_name] = row_count
with open(progress_file, 'w') as f:
json.dump(progress, f)
def count_rows_with_wc(csv_file):
    # Use wc -l for a fast line count, then subtract 1 for the CSV header row
    result = subprocess.run(['wc', '-l', csv_file], stdout=subprocess.PIPE)
    total_rows = int(result.stdout.split()[0]) - 1
    return total_rows
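# Note: wc -l relies on a Unix-like environment. A pure-Python fallback for hosts
# without wc (same line-based counting, minus the header row; not called by default):
def count_rows_portable(csv_file):
    with open(csv_file, 'r', encoding='utf-8') as f:
        return sum(1 for _ in f) - 1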
def import_csv_to_mysql(host, port, user, password, database, csv_file, log_file, start_row=0):
setup_logging(log_file)
connection = mysql.connector.connect(
host=host,
port=port,
user=user,
password=password,
database=database
)
cursor = connection.cursor()
table_name = os.path.splitext(os.path.basename(csv_file))[0]
chunksize = 1000
current_row = 0
total_rows = count_rows_with_wc(csv_file)
logging.info(f"Total rows to import: {total_rows}")
try:
for chunk in pd.read_csv(csv_file, chunksize=chunksize, encoding='utf-8'):
chunk.fillna('', inplace=True)
if current_row < start_row:
current_row += len(chunk)
continue
            # Insert the rows of the current chunk
for _, row in chunk.iterrows():
non_empty_columns = [col for col in chunk.columns if row[col] != '']
non_empty_values = [
int(row[col]) if pd.api.types.is_integer_dtype(row[col]) and str(row[col]).isdigit()
                    else float(row[col]) if pd.api.types.is_float_dtype(row[col])  # handle float64 values
else row[col]
for col in non_empty_columns
]
placeholders = ', '.join(['%s'] * len(non_empty_columns))
columns = ', '.join(non_empty_columns)
sql = f"INSERT IGNORE INTO {table_name} ({columns}) VALUES ({placeholders})"
cursor.execute(sql, tuple(non_empty_values))
current_row += 1
if current_row % chunksize == 0 or current_row == total_rows:
connection.commit()
update_import_progress(f'import_progress_{table_name}.json', os.path.basename(csv_file), current_row)
logging.info(f"Imported {table_name} {current_row}/{total_rows} rows")
        # Commit the final partial chunk (fewer than chunksize rows)
if current_row % chunksize != 0:
connection.commit()
update_import_progress(f'import_progress_{table_name}.json', os.path.basename(csv_file), current_row)
logging.info(f"Imported {table_name} {current_row}/{total_rows} rows")
    except Exception as e:
        # Roll back any uncommitted changes before re-raising
        connection.rollback()
        logging.error(f"Error occurred: {e}. Rolling back uncommitted changes.")
        raise
    finally:
        # Release the connection whether the import succeeded or failed
        cursor.close()
        connection.close()
    logging.info(f"Imported {csv_file} to MySQL table {table_name}")
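# Performance note: the import above issues one INSERT per row, which preserves the
# per-row "skip empty values" behaviour but is slow on large files. When every row in
# a chunk can share the same column list, cursor.executemany() batches the round-trips.
# A minimal sketch under that simplifying assumption (illustrative helper, not called
# by the script):
def insert_chunk_batched(cursor, table_name, chunk):
    columns = ', '.join(chunk.columns)
    placeholders = ', '.join(['%s'] * len(chunk.columns))
    sql = f"INSERT IGNORE INTO {table_name} ({columns}) VALUES ({placeholders})"
    # .item() turns numpy scalars into plain Python values that the connector can bind
    rows = [tuple(v.item() if hasattr(v, 'item') else v for v in row)
            for row in chunk.itertuples(index=False, name=None)]
    cursor.executemany(sql, rows)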
def import_csv_to_clickhouse(host, port, user, password, database, csv_file, log_file, start_row=0):
    setup_logging(log_file)
    client = Client(host=host, port=port, user=user, password=password, database=database)
    table_name = os.path.splitext(os.path.basename(csv_file))[0]
    chunksize = 1000
    current_row = 0
    for chunk in pd.read_csv(csv_file, chunksize=chunksize, encoding='utf-8'):
        chunk.fillna('', inplace=True)
        # Skip chunks that were already imported in a previous run
        if current_row < start_row:
            current_row += len(chunk)
            continue
        logging.info(f"Inserting {len(chunk)} rows into ClickHouse table {table_name}")
        client.execute(f"INSERT INTO {table_name} VALUES", chunk.to_dict(orient='records'))
        current_row += len(chunk)
        update_import_progress(f'import_progress_{table_name}.json', os.path.basename(csv_file), current_row)
        logging.info(f"Imported {table_name} {current_row} rows")
    logging.info(f"Imported {csv_file} to ClickHouse table {table_name}")
def import_csv_file(db_type, host, port, user, password, database, csv_file, log_file):
start_row = load_import_progress(f'import_progress_{os.path.splitext(os.path.basename(csv_file))[0]}.json').get(os.path.basename(csv_file), 0)
if db_type == 'mysql':
import_csv_to_mysql(host, port, user, password, database, csv_file, log_file, start_row)
elif db_type == 'clickhouse':
import_csv_to_clickhouse(host, port, user, password, database, csv_file, log_file, start_row)
else:
logging.error(f"Unsupported database type: {db_type}")
def import_csv_files(db_type, host, port, user, password, database, input_dir):
processes = []
for root, dirs, files in os.walk(input_dir):
for file in files:
if file.endswith('.csv'):
csv_file = os.path.join(root, file)
log_file = f'import_{os.path.splitext(file)[0]}.log'
p = Process(target=import_csv_file, args=(db_type, host, port, user, password, database, csv_file, log_file))
processes.append(p)
p.start()
for p in processes:
p.join()
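# Note: import_csv_files() above starts one process per CSV file, so a directory with
# many files opens that many database connections at once. A bounded worker pool is a
# hedged alternative (the pool size of 4 is an arbitrary assumption; this function is
# not wired into __main__ below):
from multiprocessing import Pool

def import_csv_files_pooled(db_type, host, port, user, password, database, input_dir, workers=4):
    tasks = []
    for root, dirs, files in os.walk(input_dir):
        for file in files:
            if file.endswith('.csv'):
                csv_file = os.path.join(root, file)
                log_file = f'import_{os.path.splitext(file)[0]}.log'
                tasks.append((db_type, host, port, user, password, database, csv_file, log_file))
    with Pool(processes=workers) as pool:
        pool.starmap(import_csv_file, tasks)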
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Import CSV files to database.")
parser.add_argument('db_type', choices=['mysql', 'clickhouse'], help="Database type (mysql or clickhouse)")
parser.add_argument('host', help="Database host")
parser.add_argument('port', type=int, help="Database port")
parser.add_argument('user', help="Database user")
parser.add_argument('password', help="Database password")
parser.add_argument('database', help="Database name")
parser.add_argument('input_dir', help="Input directory containing CSV files")
args = parser.parse_args()
import_csv_files(args.db_type, args.host, args.port, args.user, args.password, args.database, args.input_dir)
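Assuming the script is saved as import_data.py (the file name is an assumption), a MySQL run could look like: python import_data.py mysql 127.0.0.1 3306 root your_password mydb ./csv_dir. Each table writes to its own import_<table>.log, so progress can be followed with tail -f import_<table>.log while the import runs. Note that the script only issues INSERT statements, so the target tables must already exist in the destination database.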