Python: importing a large CSV file into a database with multiple processes / multiple threads

Create the table

-- test table
DROP TABLE IF EXISTS `t1`;
CREATE TABLE `t1` (
  `id` bigint(20) NOT NULL AUTO_INCREMENT COMMENT 'ID',
  `c1` varchar(100) DEFAULT NULL COMMENT 'c1',
  `c2` varchar(100) DEFAULT NULL COMMENT 'c2',
  `c3` varchar(100) DEFAULT NULL COMMENT 'c3',
  `c4` varchar(100) DEFAULT NULL COMMENT 'c4',
  `c5` varchar(100) DEFAULT NULL COMMENT 'c5',
  `c6` varchar(100) DEFAULT NULL COMMENT 'c6',
  `create_time` datetime DEFAULT NULL COMMENT 'create time',
  PRIMARY KEY (`id`) USING BTREE
) ENGINE=InnoDB DEFAULT CHARACTER SET=utf8mb4 COMMENT='test';

SELECT count(*) from t1;

Importing the file with multiple processes

import pandas as pd
import multiprocessing
import pymysql
import time

print(f'start>>>')
host = 'localhost'
user = 'root'
password = 'root'
# which database to use.
db = 'test'
filename_ = './test.csv'

# number of worker processes (set according to the CPU count)
w = 4
# number of records per batch insert
BATCH = 5000

# Database connection settings. A pymysql connection cannot safely be shared
# across processes, so each worker opens its own connection in insert_worker().
db_config = dict(host=host, user=user, password=password, database=db)

def insert_many(rows, cursor, conn):
    # Parameterized batch insert; the driver escapes values and turns None into NULL
    sql = "insert into t1(id, c1, c2, c3, create_time) values (%s, %s, %s, %s, %s)"
    cursor.executemany(sql, rows)
    conn.commit()


def insert_worker(queue):
    rows = []
    # Each worker process opens its own connection and cursor
    conn = pymysql.connect(**db_config)
    cursor = conn.cursor()
    while True:
        row = queue.get()
        if row is None:
            if rows:
                # flush whatever is left in the buffer
                print(f"inserting {len(rows)} rows")
                insert_many(rows, cursor, conn)
            break

        rows.append(row)
        if len(rows) == BATCH:
            # batch is full, insert and reset the buffer
            print(f"inserting {len(rows)} rows")
            insert_many(rows, cursor, conn)
            rows = []
    conn.close()


def csv_test():
    # Data queue: the main process reads the file and puts rows on the queue,
    # worker processes consume them. The maxsize cap keeps unread rows from
    # piling up in memory when the consumers fall behind.
    queue_ = multiprocessing.Queue(maxsize=100)
    workers = []
    for i in range(w):
        # start a worker process
        p = multiprocessing.Process(target=insert_worker, args=(queue_,))
        p.start()
        workers.append(p)
        print(f'starting worker process #{i + 1}, pid: {p.pid}...')

    dirty_data_file = r'./error.csv'
    # table columns
    cols = ["id", "c1", "c2", "c3", "c4", "c5", "c6", "create_time"]
    # rows that fail validation
    dirty_data_list = []

    # iterator=True reads the file in chunks; pass usecols=[2, 5, 46, 57] to load only specific columns
    reader = pd.read_csv(filename_, sep=',', iterator=True)
    #reader['create_time'] = pd.to_datetime(reader['create_time'], format='%Y-%m-%d %H:%M:%S')
    loop = True
    while loop:
        try:
            data = reader.get_chunk(BATCH)  # returns a chunk of up to BATCH rows as a DataFrame
            data_list = data.values.tolist()
            for line in data_list:

                #print(line)
                # Record and skip dirty data: wrong number of fields, or any other validation rule
                # cols holds the table columns
                #if len(line) != (len(cols) + 1):
                #    dirty_data_list.append(line[1:])
                #    continue
                # Replace the literal string 'NULL' with None so the driver inserts SQL NULL
                clean_line = [None if x == 'NULL' else x for x in line]
                #clean_line = ['nan' if x == 'NULL' else x for x in line]
                #print(clean_line)

                list_tmp = []
                #list_tmp.append(None)
                list_tmp.append(clean_line[0])
                list_tmp.append(clean_line[1])
                list_tmp.append(clean_line[3])
                list_tmp.append(clean_line[4])

                c_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
                #c_time = time.strftime('%Y-%m-%d %H:%M:%S', clean_line[7])
                list_tmp.append(c_time)
                # put the row on the queue
                queue_.put(tuple(list_tmp))
        except StopIteration:
            # write the dirty rows out to a file
            dirty_data_frame = pd.DataFrame(dirty_data_list, columns=cols)
            dirty_data_frame.to_csv(dirty_data_file)
            loop = False

    # send a stop signal (None) to every worker
    print('send close signal to worker processes')
    for i in range(w):
        queue_.put(None)

    for p in workers:
        p.join()


if __name__ == '__main__':
    csv_test()

Importing the file with multiple threads

#!/usr/bin/env python3
# encoding: utf-8

import pandas as pd
import threading
from queue import Queue
import pymysql
import time
import os

print(f'start>>>')
host = 'localhost'
user = 'root'
password = 'root'
# which database to use.
db = 'test'
filename_ = './test.csv'

# number of worker threads
w = 4
# number of records per batch insert
BATCH = 5000

# Database connection shared by all threads; access is serialized with `mutex`
# because a pymysql connection is not thread-safe.
connection = pymysql.connect(host=host, user=user, password=password, database=db)
mutex = threading.Lock()

def insert_many(rows, cursor):
    # Serialize access to the shared connection
    with mutex:
        # Parameterized batch insert into t2 (assumed to have the same structure as t1)
        sql = "insert into t2(id, c1, c2, c3, create_time) values (%s, %s, %s, %s, %s)"
        cursor.executemany(sql, rows)
        connection.commit()


def insert_worker(queue, thread_id):
    rows = []
    # Each worker thread creates its own cursor on the shared connection
    cursor = connection.cursor()
    while True:
        row = queue.get()
        if row is None:
            if rows:
                # flush whatever is left in the buffer
                print(f"{thread_id} inserting {len(rows)} rows")
                insert_many(rows, cursor)
            break

        rows.append(row)
        if len(rows) == BATCH:
            # batch is full, insert and reset the buffer
            print(f"{thread_id} inserting {len(rows)} rows")
            insert_many(rows, cursor)
            rows = []


def csv_test():
    # Data queue: the main thread reads the file and puts rows on the queue,
    # worker threads consume them. The maxsize cap keeps unread rows from
    # piling up in memory when the consumers fall behind.
    queue_ = Queue(maxsize=100)
    workers = []
    for i in range(w):
        # start a worker thread
        p = threading.Thread(target=insert_worker, args=(queue_, 'thread-' + str(i)))
        p.start()
        workers.append(p)
        print(f'starting worker thread #{i + 1}...')

    dirty_data_file = r'./error.csv'
    # table columns
    cols = ["id", "c1", "c2", "c3", "c4", "c5", "c6", "create_time"]
    # rows that fail validation
    dirty_data_list = []

    reader = pd.read_csv(filename_, sep=',', iterator=True)
    #reader['create_time'] = pd.to_datetime(reader['create_time'], format='%Y-%m-%d %H:%M:%S')
    loop = True
    while loop:
        try:
            data = reader.get_chunk(BATCH)  # returns a chunk of up to BATCH rows as a DataFrame
            data_list = data.values.tolist()
            for line in data_list:
                # Record and skip dirty data: wrong number of fields, or any other validation rule
                # cols holds the table columns
                #if len(line) != (len(cols) + 1):
                #    dirty_data_list.append(line[1:])
                #    continue
                # Replace the literal string 'NULL' with None so the driver inserts SQL NULL
                clean_line = [None if x == 'NULL' else x for x in line]

                list_tmp = []
                #list_tmp.append(None)
                list_tmp.append(clean_line[0])
                list_tmp.append(clean_line[1])
                list_tmp.append(clean_line[3])
                list_tmp.append(clean_line[4])

                c_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
                #c_time = time.strftime('%Y-%m-%d %H:%M:%S', clean_line[7])
                list_tmp.append(c_time)
                # put the row on the queue
                queue_.put(tuple(list_tmp))
        except StopIteration:
            # write the dirty rows out to a file
            dirty_data_frame = pd.DataFrame(dirty_data_list, columns=cols)
            dirty_data_frame.to_csv(dirty_data_file)
            loop = False

    # send a stop signal (None) to every worker
    print('send close signal to worker processes')
    for i in range(w):
        queue_.put(None)

    for p in workers:
        p.join()


if __name__ == '__main__':
    print(f'process id {os.getpid()}')
    csv_test()

Data template for this run
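
A minimal sketch to generate a sample ./test.csv for testing, assuming the file has a header row and the eight columns listed in cols (id, c1 through c6, create_time); the real data template may differ, and the sample size below is hypothetical.

import csv
import time

# Hypothetical sample size; adjust as needed
ROWS = 100000

with open('./test.csv', 'w', newline='', encoding='utf-8') as f:
    writer = csv.writer(f)
    # assumed header matching the `cols` list used by the import scripts
    writer.writerow(['id', 'c1', 'c2', 'c3', 'c4', 'c5', 'c6', 'create_time'])
    now = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
    for i in range(1, ROWS + 1):
        writer.writerow([i, f'a{i}', f'b{i}', f'c{i}', f'd{i}', f'e{i}', f'f{i}', now])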

Original row count
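
Assuming the source file has a single header row, the row count of ./test.csv can be checked before the import with:

# Count data rows in ./test.csv, excluding the header row
with open('./test.csv', 'r', encoding='utf-8') as f:
    total = sum(1 for _ in f) - 1
print(f'source rows: {total}')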

Import result
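
Once the scripts finish, the imported row counts can be compared against the source count. A minimal check, reusing the connection settings from the scripts above (t2 is assumed to be a table with the same structure as t1, used by the multi-threaded version):

import pymysql

connection = pymysql.connect(host='localhost', user='root', password='root', database='test')
with connection.cursor() as cursor:
    for table in ('t1', 't2'):
        cursor.execute(f'SELECT count(*) FROM {table}')
        print(table, 'rows:', cursor.fetchone()[0])
connection.close()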
