Importing CSV Data into MySQL with Python

江南的风
2025-02-10

This post imports the data exported in the previous post into a MySQL database.

The approach is suitable for syncing data between different database versions across environments.

Progress is recorded as the import runs, so an interrupted import can resume where it left off; each table gets its own log file and progress marker. You can check progress at any time by reading the log file while the import is running. A small sketch of what the progress file looks like follows, before the full script.
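To make the progress marker concrete, here is a minimal sketch of how the per-table progress file could be inspected, assuming a hypothetical table users imported from users.csv (both names are illustrative and not part of the script below):

import json

# The import script writes one progress file per table, e.g. import_progress_users.json,
# whose content is simply {"users.csv": <rows imported so far>}.
with open('import_progress_users.json') as f:
    progress = json.load(f)

print(progress.get('users.csv', 0))  # rows already imported; the next run resumes from here

The matching per-table log (import_users.log in this example) can be followed while the import is running to watch the "Imported x/total rows" messages.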

import json
import os
import subprocess
import pandas as pd
import mysql.connector
from clickhouse_driver import Client
import logging
from multiprocessing import Process


def setup_logging(log_file):
    # Configure logging for this process
    logging.basicConfig(
        filename=log_file,  # log file name
        level=logging.INFO,  # log level
        format='%(asctime)s - %(levelname)s - %(message)s'  # log format
    )

def load_import_progress(progress_file):
    # Return the saved progress ({csv file name: rows already imported}), or {} if none exists
    if os.path.exists(progress_file):
        with open(progress_file, 'r') as f:
            return json.load(f)
    return {}

def update_import_progress(progress_file, file_name, row_count):
    # Persist the number of rows imported so far for this CSV file
    progress = load_import_progress(progress_file)
    progress[file_name] = row_count
    with open(progress_file, 'w') as f:
        json.dump(progress, f)

def count_rows_with_wc(csv_file):
    # Count lines quickly with `wc -l`, then subtract 1 for the header row
    result = subprocess.run(['wc', '-l', csv_file], stdout=subprocess.PIPE)
    total_rows = int(result.stdout.split()[0]) - 1
    return total_rows

def import_csv_to_mysql(host, port, user, password, database, csv_file, log_file, start_row=0):
    setup_logging(log_file)
    connection = mysql.connector.connect(
        host=host,
        port=port,
        user=user,
        password=password,
        database=database
    )
    cursor = connection.cursor()

    # The target table is named after the CSV file (without the .csv extension)
    table_name = os.path.splitext(os.path.basename(csv_file))[0]

    chunksize = 1000
    current_row = 0
    total_rows = count_rows_with_wc(csv_file)
    logging.info(f"Total rows to import: {total_rows}")
    try:
        for chunk in pd.read_csv(csv_file, chunksize=chunksize, encoding='utf-8'):
            chunk.fillna('', inplace=True)

            # Resume support: skip whole chunks that were already imported in a previous run;
            # any overlap at the chunk boundary is handled by INSERT IGNORE below
            if current_row < start_row:
                current_row += len(chunk)
                continue

            # Insert the rows of the current chunk one by one
            for _, row in chunk.iterrows():
                # Only insert columns that actually have a value in this row
                non_empty_columns = [col for col in chunk.columns if row[col] != '']
                non_empty_values = [
                    int(row[col]) if pd.api.types.is_integer_dtype(row[col]) and str(row[col]).isdigit()
                    else float(row[col]) if pd.api.types.is_float_dtype(row[col])  # handle float64 values
                    else row[col]
                    for col in non_empty_columns
                ]

                placeholders = ', '.join(['%s'] * len(non_empty_columns))
                columns = ', '.join(non_empty_columns)
                # INSERT IGNORE skips rows that hit a unique-key conflict, so re-imported rows are not duplicated
                sql = f"INSERT IGNORE INTO {table_name} ({columns}) VALUES ({placeholders})"

                cursor.execute(sql, tuple(non_empty_values))
                current_row += 1

                # Commit and persist progress every `chunksize` rows (and on the final row)
                if current_row % chunksize == 0 or current_row == total_rows:
                    connection.commit()
                    update_import_progress(f'import_progress_{table_name}.json', os.path.basename(csv_file), current_row)
                    logging.info(f"Imported {table_name} {current_row}/{total_rows} rows")

        # Commit the final partial chunk (fewer than chunksize rows)
        if current_row % chunksize != 0:
            connection.commit()
            update_import_progress(f'import_progress_{table_name}.json', os.path.basename(csv_file), current_row)
            logging.info(f"Imported {table_name} {current_row}/{total_rows} rows")
    except Exception as e:
        # Roll back any uncommitted changes before re-raising
        connection.rollback()
        logging.error(f"Error occurred: {e}. Rolling back uncommitted changes.")
        raise
    finally:
        # Always release the connection, even if the import failed
        cursor.close()
        connection.close()
    logging.info(f"Imported {csv_file} to MySQL table {table_name}")

def import_csv_to_clickhouse(host, port, user, password, database, csv_file, log_file, start_row=0):
    # Note: resuming from start_row is not implemented for ClickHouse; the whole file is re-imported
    setup_logging(log_file)
    client = Client(host=host, port=port, user=user, password=password, database=database)

    table_name = os.path.splitext(os.path.basename(csv_file))[0]

    chunksize = 1000
    for chunk in pd.read_csv(csv_file, chunksize=chunksize, encoding='utf-8'):
        chunk.fillna('', inplace=True)

        logging.info(f"Inserting {len(chunk)} rows into ClickHouse table {table_name}")
        client.execute(f"INSERT INTO {table_name} VALUES", chunk.to_dict(orient='records'))

    logging.info(f"Imported {csv_file} to ClickHouse table {table_name}")

def import_csv_file(db_type, host, port, user, password, database, csv_file, log_file):
    # Resume from the last recorded row for this CSV file, if a progress file exists
    table_name = os.path.splitext(os.path.basename(csv_file))[0]
    start_row = load_import_progress(f'import_progress_{table_name}.json').get(os.path.basename(csv_file), 0)
    if db_type == 'mysql':
        import_csv_to_mysql(host, port, user, password, database, csv_file, log_file, start_row)
    elif db_type == 'clickhouse':
        import_csv_to_clickhouse(host, port, user, password, database, csv_file, log_file, start_row)
    else:
        logging.error(f"Unsupported database type: {db_type}")


def import_csv_files(db_type, host, port, user, password, database, input_dir):
    # Walk the input directory and import each CSV file in its own process,
    # each with its own log file and progress file
    processes = []

    for root, dirs, files in os.walk(input_dir):
        for file in files:
            if file.endswith('.csv'):
                csv_file = os.path.join(root, file)
                log_file = f'import_{os.path.splitext(file)[0]}.log'
                p = Process(target=import_csv_file, args=(db_type, host, port, user, password, database, csv_file, log_file))
                processes.append(p)
                p.start()

    for p in processes:
        p.join()

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Import CSV files to database.")
    parser.add_argument('db_type', choices=['mysql', 'clickhouse'], help="Database type (mysql or clickhouse)")
    parser.add_argument('host', help="Database host")
    parser.add_argument('port', type=int, help="Database port")
    parser.add_argument('user', help="Database user")
    parser.add_argument('password', help="Database password")
    parser.add_argument('database', help="Database name")
    parser.add_argument('input_dir', help="Input directory containing CSV files")

    args = parser.parse_args()
    import_csv_files(args.db_type, args.host, args.port, args.user, args.password, args.database, args.input_dir)
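For reference, a minimal usage sketch. It assumes the script is saved as csv_import.py and that the CSV files from the previous post live in ./csv_out; both names are illustrative, not defined by the script itself:

python csv_import.py mysql 127.0.0.1 3306 root your_password mydb ./csv_out

Each CSV file is imported in its own process and writes its own import_<table>.log and import_progress_<table>.json. If the run is interrupted, re-running the same command resumes each table from the recorded row count.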
