170 lines
5.7 KiB
Python
170 lines
5.7 KiB
Python
import pymysql
|
||
import sys
|
||
import os
|
||
import glob
|
||
import json
|
||
import time
|
||
from urllib.parse import urlparse, parse_qs
|
||
|
||
def parse_jdbc_url(url):
    """Split a MySQL JDBC URL into its host, port and database name.

    The expected shape is ``jdbc:mysql://host[:port][/db][?params]``. The
    ``jdbc:mysql://`` prefix is optional and any query parameters are
    discarded.

    Args:
        url: The JDBC connection string.

    Returns:
        A ``(host, port, db)`` tuple. ``port`` is an ``int`` and defaults to
        3306 when the URL carries no port; ``db`` is ``""`` when the URL has
        no database path. A bracketed IPv6 host (``[::1]``) is returned
        without the brackets.

    Raises:
        ValueError: If the port component is present but not an integer.
    """
    prefix = "jdbc:mysql://"
    if url.startswith(prefix):
        url = url[len(prefix):]

    # Drop the query string before looking for the path separator, so a "/"
    # inside a parameter value (e.g. serverTimezone=Asia/Shanghai) is never
    # mistaken for the host/database boundary.
    location = url.split("?", 1)[0]
    address, _, db = location.partition("/")

    if address.startswith("["):
        # Bracketed IPv6 literal: "[addr]" or "[addr]:port".
        host, _, tail = address[1:].partition("]")
        port_text = tail[1:] if tail.startswith(":") else ""
    else:
        host, sep, port_text = address.rpartition(":")
        if not sep:
            # No colon at all: the whole address is the host name.
            host, port_text = address, ""

    # An absent or empty port ("host:") falls back to the MySQL default.
    port = int(port_text) if port_text else 3306

    return host, port, db
|
||
|
||
def _quote_identifier(name):
    """Return *name* as a backtick-quoted MySQL identifier.

    Embedded backticks are doubled so that a name containing one cannot
    terminate the quoted identifier early.
    """
    return "`" + str(name).replace("`", "``") + "`"


def load_csv(jdbc_url, user, password, table, csv_dir, columns=None):
    """Truncate *table* and bulk-load every file in *csv_dir* into it.

    Connects with PyMySQL, tries to relax a few session checks, truncates the
    target table, then issues one ``LOAD DATA LOCAL INFILE`` per regular file
    found directly inside *csv_dir*. Files are parsed as comma-separated,
    optionally double-quoted, backslash-escaped, newline-terminated text.
    Up to five MySQL warnings are printed after each file, and a summary line
    is printed after the final commit.

    Args:
        jdbc_url: JDBC-style URL, decoded with ``parse_jdbc_url``.
        user: MySQL user name.
        password: MySQL password.
        table: Name of the table to truncate and load.
        csv_dir: Directory whose files are loaded.
        columns: Optional iterable of column names; when given, it becomes
            the column list of the ``LOAD DATA`` statement.

    Raises:
        SystemExit: With status 1 if the connection fails or any statement
            in the load phase raises.
    """
    host, port, db = parse_jdbc_url(jdbc_url)

    print(f"Connecting to MySQL {host}:{port}/{db} as {user}...")

    try:
        conn = pymysql.connect(
            host=host,
            port=port,
            user=user,
            password=password,
            database=db,
            local_infile=True,
            charset='utf8mb4',
            cursorclass=pymysql.cursors.DictCursor
        )
    except Exception as e:
        print(f"Connection failed: {e}")
        sys.exit(1)

    # Loop-invariant SQL fragments, quoted once up front.
    table_sql = _quote_identifier(table)
    col_sql = ""
    if columns:
        col_sql = "(" + ", ".join(_quote_identifier(c) for c in columns) + ")"

    try:
        with conn.cursor() as cursor:
            # Optimization settings
            print("Setting session parameters for speed...")
            cursor.execute("SET NAMES utf8mb4")

            # Apply each setting on its own so that one rejected statement
            # (e.g. SQL_LOG_BIN without sufficient privileges) does not
            # abort the others.
            opts = [
                ("SET FOREIGN_KEY_CHECKS = 0", "Foreign Key Checks Disabled"),
                ("SET UNIQUE_CHECKS = 0", "Unique Checks Disabled"),
                ("SET SQL_LOG_BIN = 0", "Binary Logging Disabled (Requires SUPER privilege)")
            ]

            for sql, desc in opts:
                try:
                    cursor.execute(sql)
                    print(f" - {desc}: Success")
                except Exception as e:
                    # Error 1227 is a privilege problem; report it in a
                    # friendlier way than other failures.
                    if "1227" in str(e):
                        print(f" - {desc}: Skipped (Insufficient privileges, but that's okay)")
                    else:
                        print(f" - {desc}: Failed ({e})")

            # Truncate table
            print(f"Truncating table {table}...")
            cursor.execute(f"TRUNCATE TABLE {table_sql}")

            # Find files. The directory name is escaped so glob
            # metacharacters in it are taken literally; subdirectories are
            # skipped because LOAD DATA cannot read them; sorting makes the
            # load order deterministic.
            pattern = os.path.join(glob.escape(csv_dir), "*")
            files = sorted(p for p in glob.glob(pattern) if os.path.isfile(p))
            if not files:
                # NOTE(review): the table has already been truncated at this
                # point -- confirm that an empty directory is meant to leave
                # the table empty.
                print(f"No files found in {csv_dir}")
                return

            total_rows = 0
            start_time = time.time()

            for file_path in files:
                file_path = os.path.abspath(file_path).replace('\\', '/')
                print(f"Loading file: {file_path}")

                # Backslashes are already gone, so the apostrophe is the only
                # character that can end the SQL string literal; doubling it
                # is valid regardless of the server's SQL mode.
                file_sql = file_path.replace("'", "''")

                # Build SQL
                # Assuming DataX txtfilewriter defaults:
                # separator: ,
                # quoteChar: "
                # escapeChar: \
                # nullFormat: \N
                sql = f"""
                LOAD DATA LOCAL INFILE '{file_sql}'
                INTO TABLE {table_sql}
                CHARACTER SET utf8mb4
                FIELDS TERMINATED BY ','
                OPTIONALLY ENCLOSED BY '"'
                ESCAPED BY '\\\\'
                LINES TERMINATED BY '\\n'
                {col_sql}
                """

                cursor.execute(sql)
                rows = cursor.rowcount
                total_rows += rows
                print(f" -> Loaded {rows} rows")

                # Surface MySQL warnings (SHOW WARNINGS) to help diagnose
                # differences between the file and what was imported.
                try:
                    cursor.execute("SHOW WARNINGS")
                    warnings = cursor.fetchall()
                    if warnings:
                        print(f" - MySQL Warnings ({len(warnings)}):")
                        # Show at most the first 5 to keep the log short.
                        for warn in warnings[:5]:
                            print(f" - {warn.get('Level', 'Warning')}: {warn.get('Message', 'Unknown error')}")
                        if len(warnings) > 5:
                            print(f" - ... and {len(warnings) - 5} more warnings")
                except Exception as warn_e:
                    print(f" - Could not fetch warnings: {warn_e}")

        conn.commit()

        duration = time.time() - start_time
        print(f"Total loaded: {total_rows} rows in {duration:.2f}s ({total_rows/duration if duration > 0 else 0:.2f} rows/s)")

    except Exception as e:
        print(f"Error during load: {e}")
        sys.exit(1)
    finally:
        conn.close()
|
||
|
||
if __name__ == "__main__":
    # Command-line entry point: five required positional arguments followed
    # by an optional JSON array of column names.
    if len(sys.argv) < 6:
        print("Usage: python LoadCsvToMysql.py <jdbc_url> <user> <password> <table> <csv_dir> [columns_json]")
        sys.exit(1)

    jdbc_url, user, password, table, csv_dir = sys.argv[1:6]

    columns = None
    if len(sys.argv) > 6:
        try:
            columns = json.loads(sys.argv[6])
        except ValueError:
            # json.JSONDecodeError is a ValueError. A malformed column list
            # is non-fatal: fall back to loading without one.
            print("Warning: Could not parse columns JSON")
        else:
            # Valid JSON of the wrong shape (e.g. a bare string, which would
            # be iterated character by character) is ignored as well.
            is_name_list = isinstance(columns, list) and all(
                isinstance(c, str) for c in columns
            )
            if columns is not None and not is_name_list:
                print("Warning: columns JSON must be an array of strings; ignoring it")
                columns = None

    load_csv(jdbc_url, user, password, table, csv_dir, columns)
|