add e621 training data scraper

This commit is contained in:
Johannes
2026-03-13 22:26:10 +01:00
parent 4c2972e7a2
commit 1d82c3dfec

133
get_training_data.py Normal file
View File

@@ -0,0 +1,133 @@
"""
get_training_data.py — scrape e621 into kohya LoRA training folder structure
usage:
python get_training_data.py --lora my_character --tags "character_name solo"
python get_training_data.py --lora my_character --tags "tag1 tag2" --limit 500 --min_score 20 --rating s
python get_training_data.py --lora my_character --tags "..." --rating e # nsfw
output goes to:
training_data/<lora>/img/<repeats>_<lora>/<files>
which is exactly what kohya expects.
"""
import os
import argparse
import requests
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
# e621's API policy requires a descriptive User-Agent that identifies the
# account making requests; replace YOUR_E621_USERNAME with your username.
E621_USER_AGENT = "animepics-downloader/1.0 (by YOUR_E621_USERNAME)"
# JSON search endpoint for posts.
E621_API = "https://e621.net/posts.json"
# Only download common still-image formats (skips gif/webm/swf and the like).
ALLOWED_EXTS = {".jpg", ".jpeg", ".png", ".webp"}
def parse_args(argv=None):
    """Parse command-line arguments for the scraper.

    Args:
        argv: optional list of argument strings. Defaults to None, which
              makes argparse read sys.argv[1:] (so existing callers are
              unaffected); passing a list makes this function testable.

    Returns:
        argparse.Namespace with lora, tags, limit, repeats, min_score,
        rating, threads and output_dir attributes.
    """
    p = argparse.ArgumentParser(description="scrape e621 into kohya training data structure")
    p.add_argument("--lora", required=True, help="lora name, used as folder name and trigger word")
    p.add_argument("--tags", required=True, help="e621 tags to search (space separated, quote the whole thing)")
    p.add_argument("--limit", type=int, default=1000, help="max images to download (default: 1000)")
    p.add_argument("--repeats", type=int, default=10, help="kohya repeat count in folder name (default: 10)")
    p.add_argument("--min_score", type=int, default=10, help="minimum post score filter (default: 10)")
    p.add_argument("--rating", default=None,
                   choices=["s", "q", "e"],
                   help="filter by rating: s=safe, q=questionable, e=explicit (default: all)")
    p.add_argument("--threads", type=int, default=8, help="parallel download threads (default: 8)")
    p.add_argument("--output_dir", default="training_data", help="base output dir (default: training_data)")
    return p.parse_args(argv)
def fetch_all_posts(tags, limit, min_score, rating, user_agent):
    """Page through the e621 search API, collecting up to `limit` posts.

    Args:
        tags: space-separated e621 tag query string.
        min_score: minimum post score; appended as a score:>=N filter.
                   None disables the filter (0 is a valid threshold).
        rating: "s", "q", "e" or None for no rating filter.
        user_agent: User-Agent header value (e621 requires a descriptive one).

    Returns:
        List of post dicts (at most `limit` entries).

    Raises:
        requests.HTTPError: if the API responds with an error status.
    """
    headers = {"User-Agent": user_agent}
    per_page = 320  # e621's maximum page size
    all_posts = []
    page = 1
    # build the full tag string with the optional filters appended
    tag_str = tags
    if rating:
        tag_str += f" rating:{rating}"
    # explicit None check: a bare truthiness test silently dropped the filter
    # for --min_score 0, which would have let negative-score posts through
    if min_score is not None:
        tag_str += f" score:>={min_score}"
    print(f"searching e621 for: {tag_str!r}")
    while len(all_posts) < limit:
        params = {"tags": tag_str, "limit": per_page, "page": page}
        # timeout keeps one stalled connection from hanging the whole scrape
        resp = requests.get(E621_API, headers=headers, params=params, timeout=30)
        resp.raise_for_status()
        posts = resp.json().get("posts", [])
        if not posts:
            print(f" no more posts at page {page}")
            break
        all_posts.extend(posts)
        print(f" page {page}: +{len(posts)} posts (total {len(all_posts)})")
        page += 1
        if len(posts) < per_page:
            # a short page means we reached the end of the results
            break
    return all_posts[:limit]
def download_post(post, save_dir, headers):
    """Download one post's image file into save_dir.

    Args:
        post: e621 post dict; reads post["file"]["url"] and post["id"].
        save_dir: destination directory (must already exist).
        headers: HTTP headers dict (should carry the e621 User-Agent).

    Returns:
        Status string starting with "ok:", "skip:" or "fail:" — callers
        inspect only the prefix, so errors are reported, never raised.
    """
    # `or {}` also guards against an explicit "file": null in the API response
    file_info = post.get("file") or {}
    url = file_info.get("url")
    if not url:
        return f"skip: no url for post {post.get('id')}"
    ext = os.path.splitext(url)[1].lower()
    if ext not in ALLOWED_EXTS:
        return f"skip: unsupported ext {ext} ({url})"
    file_name = os.path.basename(url)
    file_path = os.path.join(save_dir, file_name)
    if os.path.exists(file_path):
        return f"skip: already exists {file_name}"
    try:
        r = requests.get(url, headers=headers, timeout=30)
        r.raise_for_status()
        with open(file_path, "wb") as f:
            f.write(r.content)
        return f"ok: {file_name}"
    except Exception as e:
        # bug fix: original f"fail: {url}{e}" ran the URL and the error
        # together with no separator, producing an unreadable message
        return f"fail: {url}: {e}"
def main():
    """Entry point: parse args, query e621, then download results in parallel."""
    args = parse_args()
    # kohya-ss expects training_data/<lora>/img/<repeats>_<lora>/
    save_dir = os.path.join(args.output_dir, args.lora, "img", f"{args.repeats}_{args.lora}")
    os.makedirs(save_dir, exist_ok=True)
    print(f"saving to: {save_dir}")

    headers = {"User-Agent": E621_USER_AGENT}
    posts = fetch_all_posts(args.tags, args.limit, args.min_score, args.rating, E621_USER_AGENT)
    print(f"\n{len(posts)} posts queued for download\n")

    failed = []
    with ThreadPoolExecutor(max_workers=args.threads) as pool:
        # submit everything up front; tqdm ticks as each download completes
        pending = {pool.submit(download_post, post, save_dir, headers): post for post in posts}
        for done in tqdm(as_completed(pending), total=len(pending), desc="downloading"):
            outcome = done.result()
            if outcome.startswith("fail"):
                failed.append(outcome)

    print(f"\ndone! {len(posts) - len(failed)} downloaded, {len(failed)} failed")
    if failed:
        print("failures:")
        for msg in failed:
            print(f" {msg}")


if __name__ == "__main__":
    main()