150 lines
3.5 KiB
Python
150 lines
3.5 KiB
Python
import os
import time
from datetime import datetime

import pymysql
import tushare as ts
|
|
|
|
# ===== Configuration =====
# Every connection setting can be overridden through an environment variable;
# the literal fallbacks keep the script working unchanged when none is set.
# NOTE(review): the fallback token and password are secrets committed in plain
# text -- rotate them and remove the literals once the environment variables
# are provisioned.
TUSHARE_TOKEN = os.environ.get(
    "TUSHARE_TOKEN",
    "0fad3cf498757089e2630028455d5bbe1637475788bcdaa1f2175e93",
)

# Keyword arguments passed verbatim to pymysql.connect().
DB_CONFIG = {
    "host": os.environ.get("STOCK_DB_HOST", "8.159.129.156"),
    # The port must stay an int, so the environment value is converted here.
    "port": int(os.environ.get("STOCK_DB_PORT", "10836")),
    "user": os.environ.get("STOCK_DB_USER", "yangfan"),
    "password": os.environ.get("STOCK_DB_PASSWORD", "aA%8023321088"),
    "database": os.environ.get("STOCK_DB_NAME", "level"),
    "charset": "utf8mb4",
}
|
|
|
|
# ===== Initialisation =====
# Register the token with tushare first, then build the Pro API client that
# run() uses to download quotes.
ts.set_token(TUSHARE_TOKEN)
pro = ts.pro_api()

# A single module-wide MySQL connection and cursor, shared by every function
# defined below.
conn = pymysql.connect(**DB_CONFIG)
cursor = conn.cursor()
|
|
|
|
|
|
# ===== Table creation =====
def create_tables():
    """Create the ``stock_daily`` table when it does not exist yet.

    Executes the DDL on the module-level cursor, commits on the module-level
    connection and prints a confirmation message. The table is keyed by an
    auto-increment ``id`` and enforces uniqueness on ``(ts_code, trade_date)``.
    """
    # NOTE(review): vol and amount are BIGINT; if the data feed delivers
    # fractional values they cannot be stored exactly -- confirm this is intended.
    ddl = """
        CREATE TABLE IF NOT EXISTS stock_daily (
            id BIGINT AUTO_INCREMENT PRIMARY KEY,
            ts_code VARCHAR(20) NOT NULL,
            symbol VARCHAR(10) NOT NULL,
            trade_date DATE NOT NULL,
            open DECIMAL(16,4),
            high DECIMAL(16,4),
            low DECIMAL(16,4),
            close DECIMAL(16,4),
            pre_close DECIMAL(16,4),
            change_val DECIMAL(16,4),
            pct_chg DECIMAL(10,4),
            vol BIGINT,
            amount BIGINT,
            UNIQUE KEY uniq_code_date (ts_code, trade_date)
        );
    """
    cursor.execute(ddl)
    conn.commit()
    print("表检查/创建完成")
|
|
|
|
|
|
# ===== Date conversion =====
def str_to_date(date_str):
    """Convert a ``YYYYMMDD`` string into a ``datetime.date``.

    Args:
        date_str: Date text in ``%Y%m%d`` form, e.g. ``"20240901"``.

    Returns:
        The corresponding ``date`` object.

    Raises:
        ValueError: If ``date_str`` does not match the ``%Y%m%d`` format.
    """
    parsed = datetime.strptime(date_str, "%Y%m%d")
    return parsed.date()
|
|
|
|
|
|
# ===== Trading-day lookup =====
def get_trade_dates(start_date, end_date):
    """Return the open trading days between two dates, both ends inclusive.

    Args:
        start_date: First day of the range as a ``YYYYMMDD`` string.
        end_date: Last day of the range as a ``YYYYMMDD`` string.

    Returns:
        A list of ``YYYYMMDD`` strings, in ascending ``cal_date`` order, for
        the ``trade_calendar`` rows in the range whose ``is_open`` is 1.
    """
    query = """
        SELECT cal_date
        FROM trade_calendar
        WHERE cal_date BETWEEN %s AND %s
        AND is_open = 1
        ORDER BY cal_date
    """
    # The bounds are bound as date objects rather than raw strings.
    bounds = (str_to_date(start_date), str_to_date(end_date))
    cursor.execute(query, bounds)
    calendar_rows = cursor.fetchall()

    # Each row holds a single column; format it back to YYYYMMDD text.
    return [record[0].strftime("%Y%m%d") for record in calendar_rows]
|
|
|
|
|
|
# ===== Data insertion =====
def insert_data(df):
    """Bulk-insert one batch of daily quotes into ``stock_daily``.

    Args:
        df: DataFrame holding the columns ``ts_code``, ``trade_date`` (a
            ``YYYYMMDD`` string), ``open``, ``high``, ``low``, ``close``,
            ``pre_close``, ``change``, ``pct_chg``, ``vol`` and ``amount``.
            ``None`` or an empty frame is treated as nothing to do.

    Returns:
        An ``(inserted, skipped)`` tuple. ``inserted`` is the row count the
        cursor reports for the statement; ``skipped`` is the number of rows
        submitted minus that count (rows dropped by ``INSERT IGNORE``).

    Raises:
        Exception: Any error raised while executing or committing is
            re-raised after the pending transaction has been rolled back.
    """
    if df is None or df.empty:
        return 0, 0

    sql = """
        INSERT IGNORE INTO stock_daily
        (ts_code, symbol, trade_date, open, high, low, close, pre_close, change_val, pct_chg, vol, amount)
        VALUES (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)
    """

    # Column order mirrors the VALUES list; symbol is derived from ts_code.
    columns = [
        'ts_code', 'trade_date', 'open', 'high', 'low', 'close',
        'pre_close', 'change', 'pct_chg', 'vol', 'amount',
    ]

    # NaN cannot be stored as a number by the driver; turn missing cells into
    # None so they are written as SQL NULL. The cast to object is required
    # because float columns cannot hold None.
    frame = df[columns].astype(object)
    frame = frame.where(frame.notna(), None)

    data = [
        (
            ts_code,
            ts_code.split('.')[0],       # numeric part before the exchange suffix
            str_to_date(trade_date),     # keep the column a real DATE value
            *prices,
        )
        for ts_code, trade_date, *prices in frame.itertuples(index=False, name=None)
    ]

    try:
        cursor.executemany(sql, data)
        conn.commit()
    except Exception:
        # Do not leave a half-written batch pending on the shared connection,
        # where a later commit could silently persist it.
        conn.rollback()
        raise

    inserted = cursor.rowcount
    skipped = len(data) - inserted

    return inserted, skipped
|
|
|
|
|
|
# ===== Main workflow =====
def run(start_date, end_date):
    """Download and store daily quotes for every trading day in a range.

    Ensures the target table exists, looks up the trading days between the
    two dates, then fetches each day with ``pro.daily`` and passes the result
    to ``insert_data``. A failure on one day is reported and does not stop the
    remaining days; the failed days are listed after the summary.

    Args:
        start_date: First day of the range as a ``YYYYMMDD`` string.
        end_date: Last day of the range as a ``YYYYMMDD`` string.
    """
    create_tables()

    trade_dates = get_trade_dates(start_date, end_date)
    print(f"需要处理 {len(trade_dates)} 个交易日")

    total_inserted = 0
    total_skipped = 0
    failed_dates = []

    for trade_date in trade_dates:
        print(f"\n处理交易日: {trade_date}")

        try:
            df = pro.daily(trade_date=trade_date)
            inserted, skipped = insert_data(df)
        except Exception as e:
            # Loop boundary: report the day and carry on with the next one.
            failed_dates.append(trade_date)
            print(f"失败: {trade_date}, 错误: {e}")
        else:
            total_inserted += inserted
            total_skipped += skipped
            print(f"新增: {inserted}, 跳过: {skipped}")
        finally:
            # Throttle after failures too, so repeated API errors do not turn
            # the loop into a burst of back-to-back requests.
            time.sleep(0.2)

    print("\n====== 汇总 ======")
    print(f"总新增: {total_inserted}")
    print(f"总跳过: {total_skipped}")
    if failed_dates:
        # Printed only when something failed, so a clean run looks as before.
        print(f"失败日期: {', '.join(failed_dates)}")
|
|
|
|
|
|
# ===== Script entry point =====
if __name__ == "__main__":
    # Date range to load, as YYYYMMDD strings.
    START_DATE = "20240901"
    END_DATE = "20260401"

    try:
        run(START_DATE, END_DATE)
    finally:
        # Release the module-level cursor and connection even if run() fails.
        cursor.close()
        conn.close()