Importing Large CSV Files into a Database with Python Multiprocessing/Multithreading

Creating the table

DROP TABLE IF EXISTS `t1`;
CREATE TABLE `t1` (
  `id` bigint(20) NOT NULL AUTO_INCREMENT COMMENT 'ID',
  `c1` varchar(100) DEFAULT NULL COMMENT 'c1',
  `c2` varchar(100) DEFAULT NULL COMMENT 'c2',
  `c3` varchar(100) DEFAULT NULL COMMENT 'c3',
  `c4` varchar(100) DEFAULT NULL COMMENT 'c4',
  `c5` varchar(100) DEFAULT NULL COMMENT 'c5',
  `c6` varchar(100) DEFAULT NULL COMMENT 'c6',
  `create_time` datetime DEFAULT NULL COMMENT 'creation time',
  PRIMARY KEY (`id`) USING BTREE
) ENGINE = InnoDB DEFAULT CHARACTER SET = utf8mb4 COMMENT='test';

SELECT count(*) from t1;
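
The scripts below all read from ./test.csv. Here is a minimal sketch (an assumed helper, not part of this post's scripts) for generating a sample file matching the t1 schema; N and the placeholder values are illustrative:

import pandas as pd

N = 100_000  # sample row count; adjust to the volume you want to test
df = pd.DataFrame({
    'id': range(1, N + 1),
    # c1..c6 get simple placeholder strings
    **{f'c{i}': [f'c{i}-{n}' for n in range(N)] for i in range(1, 7)},
    # one constant timestamp is enough for a load test
    'create_time': pd.Timestamp.now().strftime('%Y-%m-%d %H:%M:%S'),
})
df.to_csv('./test.csv', index=False)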

1. Importing the file with multiple processes

The main process reads the CSV and feeds a queue; worker processes pull rows off the queue and insert them in batches.

import pandas as pd
import multiprocessing
import pymysql
import time

print(f'start>>>')
host = 'localhost'
user = 'root'
password = 'root'
# which database to use.
db = 'test'
filename_ = './test.csv'

# Number of worker processes (size to the CPU count, e.g. multiprocessing.cpu_count())
w = 4
# Number of records per batch insert
BATCH = 5000

# Each worker opens its own database connection inside insert_worker (below);
# a single connection must not be shared across forked processes

def insert_many(rows, cursor, conn):
    # Parameterized multi-row insert; the driver handles quoting and turns
    # None into SQL NULL, which a str()-built VALUES clause would get wrong
    sql = "insert into t1(id, c1, c2, c3, create_time) values (%s, %s, %s, %s, %s)"
    cursor.executemany(sql, rows)
    conn.commit()


def insert_worker(queue):
    rows = []
    # Each worker process opens its own connection and cursor; sharing one
    # connection across forked processes is not safe
    conn = pymysql.connect(host=host, user=user, password=password, database=db)
    cursor = conn.cursor()
    while True:
        row = queue.get()
        if row is None:
            if rows:
                # flush the remaining partial batch
                print(f"inserting {len(rows)} rows")
                insert_many(rows, cursor, conn)
            break

        rows.append(row)
        if len(rows) == BATCH:
            # flush a full batch
            print(f"inserting {BATCH} rows")
            insert_many(rows, cursor, conn)
            rows = []
    conn.close()


def csv_test():
    # Queue between the reader and the workers: the main process reads the file
    # and pushes rows; worker processes pop and insert. Keep the queue bounded
    # so a slow consumer does not let rows pile up and exhaust memory.
    queue_ = multiprocessing.Queue(maxsize=100)
    workers = []
    for i in range(w):
        # spawn a worker process
        p = multiprocessing.Process(target=insert_worker, args=(queue_,))
        p.start()
        workers.append(p)
        print(f'starting worker process #{i + 1}, pid: {p.pid}...')

    dirty_data_file = r'./error.csv'
    # table columns
    cols = ["id", "c1", "c2", "c3", "c4", "c5", "c6", "create_time"]
    # rows that failed validation
    dirty_data_list = []

    reader = pd.read_csv(filename_, sep=',', iterator=True)  # to read only certain columns, add usecols=[2, 5, 46, 57]
    #reader['create_time'] = pd.to_datetime(reader['create_time'], format='%Y-%m-%d %H:%M:%S')
    loop = True
    while loop:
        try:
            data = reader.get_chunk(BATCH)  # returns the next chunk of up to BATCH rows as a DataFrame
            data_list = data.values.tolist()
            for line in data_list:

                #print(line)
                # Record and skip dirty rows: wrong field count, or any other
                # validation logic you need (cols holds the table columns)
                #if len(line) != (len(cols) + 1):
                #    dirty_data_list.append(line[1:])
                #    continue
                # Replace the literal string 'NULL' with Python None so the
                # driver inserts SQL NULL
                clean_line = [None if x == 'NULL' else x for x in line]
                #clean_line = ['nan' if x == 'NULL' else x for x in line]
                #print(clean_line)

                list_tmp = []
                #list_tmp.append(None)
                list_tmp.append(clean_line[0])
                list_tmp.append(clean_line[1])
                list_tmp.append(clean_line[3])
                list_tmp.append(clean_line[4])

                c_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
                #c_time = time.strftime('%Y-%m-%d %H:%M:%S', clean_line[7])
                list_tmp.append(c_time)
                # push the row into the queue
                queue_.put(tuple(list_tmp))
        except StopIteration:
            # write dirty rows to a file
            dirty_data_frame = pd.DataFrame(dirty_data_list, columns=cols)
            dirty_data_frame.to_csv(dirty_data_file)
            loop = False

    # send each worker a stop signal
    print('send close signal to worker processes')
    for i in range(w):
        queue_.put(None)

    for p in workers:
        p.join()


if __name__ == '__main__':
    csv_test()

2. Importing the file with multiple threads

The structure mirrors the process version, but the worker threads share a single database connection, so insert_many serializes writes with a lock (mutex, created in __main__).

#!/usr/bin/env python3
# encoding: utf-8

import pandas as pd
import threading
from queue import Queue
import pymysql
import time
import os

print(f'start>>>')
host = 'localhost'
user = 'root'
password = 'root'
# which database to use.
db = 'test'
filename_ = './test.csv'

# Number of worker threads
w = 4
# Number of records per batch insert
BATCH = 5000

# Database connection, shared by all threads; insert_many serializes writes with a lock
connection = pymysql.connect(host=host, user=user, password=password, database=db)

def insert_many(l, cursor):
    # The connection is shared by every thread, so serialize writes with the lock
    mutex.acquire()
    try:
        # Parameterized multi-row insert; the driver handles quoting and NULLs.
        # Note: this variant targets table t2, created the same way as t1 above.
        sql = "insert into t2(id, c1, c2, c3, create_time) values (%s, %s, %s, %s, %s)"
        cursor.executemany(sql, l)
        connection.commit()
    finally:
        mutex.release()


def insert_worker(queue, thread_id):
    rows = []
    # Each worker thread creates its own cursor on the shared connection
    cursor = connection.cursor()
    while True:
        row = queue.get()
        if row is None:
            if rows:
                # flush the remaining partial batch
                print(f"{thread_id} inserting {len(rows)} rows")
                insert_many(rows, cursor)
            break

        rows.append(row)
        if len(rows) == BATCH:
            # flush a full batch
            print(f"{thread_id} inserting {BATCH} rows")
            insert_many(rows, cursor)
            rows = []


def csv_test():
    # Queue between the reader and the workers: the main thread reads the file
    # and pushes rows; worker threads pop and insert. Keep the queue bounded
    # so rows do not pile up in memory if the consumers fall behind.
    queue_ = Queue(maxsize=10000)
    workers = []
    for i in range(w):
        # spawn a worker thread
        p = threading.Thread(target=insert_worker, args=(queue_, 'thread-' + str(i)))
        p.start()
        workers.append(p)
        print(f'starting worker thread #{i + 1}...')

    dirty_data_file = r'./error.csv'
    # table columns
    cols = ["id", "c1", "c2", "c3", "c4", "c5", "c6", "create_time"]
    # rows that failed validation
    dirty_data_list = []

    reader = pd.read_csv(filename_, sep=',', iterator=True)
    #reader['create_time'] = pd.to_datetime(reader['create_time'], format='%Y-%m-%d %H:%M:%S')
    loop = True
    while loop:
        try:
            data = reader.get_chunk(BATCH)  # returns the next chunk of up to BATCH rows as a DataFrame
            data_list = data.values.tolist()
            for line in data_list:
                # Record and skip dirty rows: wrong field count, or any other
                # validation logic you need (cols holds the table columns)
                #if len(line) != (len(cols) + 1):
                #    dirty_data_list.append(line[1:])
                #    continue
                # Replace the literal string 'NULL' with Python None so the
                # driver inserts SQL NULL
                clean_line = [None if x == 'NULL' else x for x in line]

                list_tmp = []
                #list_tmp.append(None)
                list_tmp.append(clean_line[0])
                list_tmp.append(clean_line[1])
                list_tmp.append(clean_line[3])
                list_tmp.append(clean_line[4])

                c_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
                #c_time = time.strftime('%Y-%m-%d %H:%M:%S', clean_line[7])
                list_tmp.append(c_time)
                # push the row into the queue
                queue_.put(tuple(list_tmp))
        except StopIteration:
            # write dirty rows to a file
            dirty_data_frame = pd.DataFrame(dirty_data_list, columns=cols)
            dirty_data_frame.to_csv(dirty_data_file)
            loop = False

    # send each worker a stop signal
    print('send close signal to worker processes')
    for i in range(w):
        queue_.put(None)

    for p in workers:
        p.join()


if __name__ == '__main__':
    print(f'process id {os.getpid()}')
    mutex = threading.Lock() 
    csv_test()

Data template used for this run

Original row count

Import results

3. Optimized version

Key optimizations
    Command-line argument support:
        New arguments, including:
            --host: database host (default localhost).
            --user: database user (default root).
            --password: database password (default root).
            --db: database name (default test).
            --port: database port (default 3306).
            --table: table name (default imported_data).
            --threads: number of threads (default 4).
            --batch-size: batch insert size (default 5000).
            --encoding: CSV file encoding (default utf-8).
    Dynamic configuration:
        All tunables (database settings, table name, thread count, and so on) are passed on the command line; see the example invocation after this list.
        Arguments are parsed with the argparse module.
    Code structure:
        Database settings are bundled into a db_config dict, which is easier to pass around and manage.
        The table name is passed as a parameter to the relevant functions.
    Defaults:
        Every argument has a default value, to keep simple runs simple.
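
An invocation that sets every option explicitly might look like this (the values are illustrative; i3.py is the script name used in the usage section below):

python i3.py ./test.csv --host localhost --user root --password root --db test --port 3306 --table test_table_2 --threads 8 --batch-size 10000 --encoding gbk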

Source:

#!/usr/bin/env python3
# encoding: utf-8
import pandas as pd
import threading
from queue import Queue
import pymysql
import time
import os
import argparse
from typing import List, Optional, Dict

# Default configuration
DEFAULT_HOST = 'localhost'
DEFAULT_USER = 'root'
DEFAULT_PASSWORD = 'root'
DEFAULT_DB = 'test'
DEFAULT_PORT = 3306
DEFAULT_TABLE_NAME = 'imported_data'
DEFAULT_NUM_THREADS = 4
DEFAULT_BATCH_SIZE = 5000
DEFAULT_ENCODING = 'utf-8'  # default CSV encoding

def clean_value(value) -> Optional[str]:
    """清理数据:将 NaN 转换为 None"""
    return None if pd.isna(value) else value

def infer_column_type(series: pd.Series) -> str:
    """
    Infer a column's type from a pandas Series.
    Returns the corresponding MySQL type.
    """
    try:
        # Try to parse as numeric (FLOAT also covers integer-valued columns)
        pd.to_numeric(series)
        return "FLOAT"
    except (ValueError, TypeError):
        try:
            # Try to parse as datetime
            pd.to_datetime(series)
            return "DATETIME"
        except (ValueError, TypeError):
            # Fall back to VARCHAR(255)
            return "VARCHAR(255)"


def create_table_from_csv(csv_path: str, connection, table_name: str, encoding: str) -> Optional[List[str]]:
    """
    Generate a CREATE TABLE statement from the CSV file's contents and create the table.
    Returns the list of column names.
    """
    try:
        # Read the first 100 rows of the CSV to infer column types
        df = pd.read_csv(csv_path, nrows=100, encoding=encoding)
        columns = df.columns.tolist()

        # Infer each column's type (back-quote names in case of reserved words)
        column_definitions = []
        for col in columns:
            col_type = infer_column_type(df[col])
            column_definitions.append(f"`{col}` {col_type}")

        # Build the CREATE TABLE statement
        create_table_sql = f"CREATE TABLE IF NOT EXISTS {table_name} ("
        create_table_sql += ", ".join(column_definitions) + ");"

        # Execute it
        with connection.cursor() as cursor:
            cursor.execute(f"DROP TABLE IF EXISTS {table_name};")
            cursor.execute(create_table_sql)
            connection.commit()
            print(f"table {table_name} created")
            print(f"CREATE TABLE statement: {create_table_sql}")

        return columns
    except Exception as e:
        print(f"failed to create table: {e}")
        return None


def insert_many(rows: List[tuple], cursor, connection, table_name: str):
    """使用参数化查询进行批量插入"""
    sql = f"""
    INSERT INTO {table_name}
    VALUES ({', '.join(['%s'] * len(rows[0]))})
    """
    try:
        cursor.executemany(sql, rows)
        connection.commit()
    except Exception as e:
        connection.rollback()
        print(f"插入失败: {e}")


def insert_worker(queue: Queue, thread_id: str, db_config: Dict, table_name: str):
    """Database writer thread"""
    try:
        # Each thread uses its own connection
        conn = pymysql.connect(
            host=db_config['host'],
            user=db_config['user'],
            password=db_config['password'],
            database=db_config['database'],
            port=db_config['port'],
            autocommit=False  # commit manually after each batch
        )
        cursor = conn.cursor()
        buffer = []

        while True:
            item = queue.get()
            if item is None:  # stop signal
                if buffer:
                    print(f"{thread_id} inserting remaining {len(buffer)} rows")
                    insert_many(buffer, cursor, conn, table_name)
                break

            buffer.append(item)
            if len(buffer) >= db_config['batch_size']:
                print(f"{thread_id} inserting {len(buffer)} rows")
                insert_many(buffer, cursor, conn, table_name)
                buffer = []

    except Exception as e:
        print(f"thread {thread_id} raised an exception: {str(e)}")
    finally:
        if 'cursor' in locals():
            cursor.close()
        if 'conn' in locals():
            conn.close()
        print(f"{thread_id} exiting")


def csv_importer(csv_path: str, db_config: Dict, table_name: str, encoding: str):
    """CSV导入主逻辑"""
    # 初始化队列和工作线程
    queue = Queue(maxsize=10000)  # 控制队列大小防止内存溢出
    threads = []

    # Start the worker threads
    for i in range(db_config['num_threads']):
        t = threading.Thread(
            target=insert_worker,
            args=(queue, f'Thread-{i}', db_config, table_name)
        )
        t.start()
        threads.append(t)
        print(f'started worker thread {i}')

    # Open a database connection
    conn = pymysql.connect(
        host=db_config['host'],
        user=db_config['user'],
        password=db_config['password'],
        database=db_config['database'],
        port=db_config['port']
    )

    # Create the table from the CSV header
    columns = create_table_from_csv(csv_path, conn, table_name, encoding)
    if not columns:
        print("无法创建表,退出程序")
        return

    # Read the CSV and fill the queue
    try:
        # Stream the file in chunks to keep memory usage low
        reader = pd.read_csv(csv_path, chunksize=db_config['batch_size'], iterator=True, encoding=encoding)
        dirty_data = []

        for chunk in reader:
            # process one chunk row by row
            for _, row in chunk.iterrows():
                # data validation hook (example logic, disabled)
                #if row.isnull().sum() > 2:
                #    dirty_data.append(row)
                #    continue

                # build the tuple to insert
                try:
                    data = tuple(clean_value(row[col]) for col in columns)
                    queue.put(data)
                except Exception as e:
                    print(f"数据格式错误: {row},错误: {str(e)}")
                    dirty_data.append(row)

        # save dirty rows
        if dirty_data:
            dirty_data_file = os.path.join(os.path.dirname(csv_path), 'error.csv')
            pd.DataFrame(dirty_data, columns=columns).to_csv(dirty_data_file, index=False, encoding=encoding)
            print(f"保存脏数据到 {dirty_data_file}")

    except Exception as e:
        print(f"读取文件失败: {str(e)}")
    finally:
        # send stop signals
        for _ in range(db_config['num_threads']):
            queue.put(None)

        # wait for the threads to finish
        for t in threads:
            t.join()

        # close the database connection
        if 'conn' in locals():
            conn.close()


def main():
    """Entry point"""
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Import a CSV file into a MySQL database")
    parser.add_argument('csv_path', type=str, help="path to the CSV file")
    parser.add_argument('--host', type=str, default=DEFAULT_HOST, help="database host")
    parser.add_argument('--user', type=str, default=DEFAULT_USER, help="database user")
    parser.add_argument('--password', type=str, default=DEFAULT_PASSWORD, help="database password")
    parser.add_argument('--db', type=str, default=DEFAULT_DB, help="database name")
    parser.add_argument('--port', type=int, default=DEFAULT_PORT, help="database port")
    parser.add_argument('--table', type=str, default=DEFAULT_TABLE_NAME, help="table name")
    parser.add_argument('--threads', type=int, default=DEFAULT_NUM_THREADS, help="number of threads")
    parser.add_argument('--batch-size', type=int, default=DEFAULT_BATCH_SIZE, help="batch insert size")
    parser.add_argument('--encoding', type=str, default=DEFAULT_ENCODING, help="CSV file encoding")
    args = parser.parse_args()

    # make sure the input file exists
    if not os.path.exists(args.csv_path):
        print(f"file {args.csv_path} does not exist")
        return

    # database configuration
    db_config = {
        'host': args.host,
        'user': args.user,
        'password': args.password,
        'database': args.db,
        'port': args.port,
        'num_threads': args.threads,
        'batch_size': args.batch_size
    }

    print(f'main process id: {os.getpid()}')
    csv_importer(args.csv_path, db_config, args.table, args.encoding)
    print('import finished')


if __name__ == '__main__':
    main()

requirements.txt

pandas
pymysql

Usage:

First run:
python -m venv venv
venv\Scripts\activate
pip3 install -r requirements.txt
python i3.py ./test.csv --table test_table_2 --encoding gbk

Subsequent runs:
venv\Scripts\activate
python i3.py ./test.csv --table test_table_2 --encoding gbk
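
To verify the import afterwards, a count query like the one at the top of the post can be run against the target table (test_table_2 here):

SELECT count(*) FROM test_table_2;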