"""
get_training_data.py — scrape e621 into a kohya LoRA training folder structure.

usage:
    python get_training_data.py --lora my_character --tags "character_name solo"
    python get_training_data.py --lora my_character --tags "tag1 tag2" --limit 500 --min_score 20 --rating s
    python get_training_data.py --lora my_character --tags "..." --rating e   # nsfw

output goes to:
    training_data/<lora>/img/<repeats>_<lora>/

which is exactly the folder layout kohya expects.
"""

import os
import time
import argparse
from concurrent.futures import ThreadPoolExecutor, as_completed

import requests
from tqdm import tqdm

# e621 requires a descriptive User-Agent that identifies the operator.
# NOTE(review): placeholder — replace YOUR_E621_USERNAME with a real username
# before running, or e621 may reject the requests.
E621_USER_AGENT = "animepics-downloader/1.0 (by YOUR_E621_USERNAME)"
E621_API = "https://e621.net/posts.json"

# e621 hard-caps page size at 320 posts per request.
PER_PAGE = 320

# Only these image types are downloaded; everything else (gif, webm, swf, ...)
# is skipped because kohya trains on still images.
ALLOWED_EXTS = {".jpg", ".jpeg", ".png", ".webp"}


def parse_args():
    """Parse command-line options; see the module docstring for usage examples."""
    p = argparse.ArgumentParser(description="scrape e621 into kohya training data structure")
    p.add_argument("--lora", required=True, help="lora name, used as folder name and trigger word")
    p.add_argument("--tags", required=True,
                   help="e621 tags to search (space separated, quote the whole thing)")
    p.add_argument("--limit", type=int, default=1000, help="max images to download (default: 1000)")
    p.add_argument("--repeats", type=int, default=10,
                   help="kohya repeat count in folder name (default: 10)")
    p.add_argument("--min_score", type=int, default=10,
                   help="minimum post score filter (default: 10)")
    p.add_argument("--rating", default=None,
                   choices=["s", "q", "e"],
                   help="filter by rating: s=safe, q=questionable, e=explicit (default: all)")
    p.add_argument("--threads", type=int, default=8, help="parallel download threads (default: 8)")
    p.add_argument("--output_dir", default="training_data",
                   help="base output dir (default: training_data)")
    return p.parse_args()


def fetch_all_posts(tags, limit, min_score, rating, user_agent):
    """Page through the e621 search API and return up to *limit* post dicts.

    The rating and score filters are appended to the tag string as e621
    search metatags. Pagination stops when a page comes back empty or
    short, or once *limit* posts have been collected.
    """
    headers = {"User-Agent": user_agent}
    all_posts = []
    page = 1

    # Build the full search string: user tags + optional rating/score metatags.
    tag_str = tags
    if rating:
        tag_str += f" rating:{rating}"
    if min_score:
        tag_str += f" score:>={min_score}"

    print(f"searching e621 for: {tag_str!r}")

    while len(all_posts) < limit:
        params = {"tags": tag_str, "limit": PER_PAGE, "page": page}
        # timeout added so a stalled connection can't hang the scraper;
        # matches the 30s timeout already used on the download path.
        resp = requests.get(E621_API, headers=headers, params=params, timeout=30)
        resp.raise_for_status()
        posts = resp.json().get("posts", [])

        if not posts:
            print(f"  no more posts at page {page}")
            break

        all_posts.extend(posts)
        print(f"  page {page}: +{len(posts)} posts (total {len(all_posts)})")
        page += 1

        # A short page means we've reached the end of the results.
        if len(posts) < PER_PAGE:
            break

        # Be polite: e621's API policy asks clients to stay around 1 req/sec;
        # hammering the paginated endpoint gets the client throttled.
        time.sleep(1)

    return all_posts[:limit]


def download_post(post, save_dir, headers):
    """Download a single post's file into *save_dir*.

    Returns a human-readable status string prefixed with "ok", "skip",
    or "fail" — main() uses the "fail" prefix to collect failures.
    """
    # `post.get("file", {})` would still crash if the API returns
    # `"file": null` (the default only applies to a *missing* key),
    # so coerce falsy values to an empty dict.
    file_info = post.get("file") or {}
    url = file_info.get("url")
    if not url:
        # Posts hidden by the global blacklist come back with url=null.
        return f"skip: no url for post {post.get('id')}"

    ext = os.path.splitext(url)[1].lower()
    if ext not in ALLOWED_EXTS:
        return f"skip: unsupported ext {ext} ({url})"

    file_name = os.path.basename(url)
    file_path = os.path.join(save_dir, file_name)

    # e621 file names are content hashes, so an existing file is a re-run dupe.
    if os.path.exists(file_path):
        return f"skip: already exists {file_name}"

    try:
        r = requests.get(url, headers=headers, timeout=30)
        r.raise_for_status()
        with open(file_path, "wb") as f:
            f.write(r.content)
        return f"ok: {file_name}"
    except Exception as e:
        # Broad on purpose: any single failed download is reported and
        # skipped rather than aborting the whole batch.
        return f"fail: {url} — {e}"


def main():
    """Entry point: fetch matching posts, then download them in parallel."""
    args = parse_args()

    # kohya layout: training_data/<lora>/img/<repeats>_<lora>/
    save_dir = os.path.join(args.output_dir, args.lora, "img", f"{args.repeats}_{args.lora}")
    os.makedirs(save_dir, exist_ok=True)
    print(f"saving to: {save_dir}")

    headers = {"User-Agent": E621_USER_AGENT}

    posts = fetch_all_posts(args.tags, args.limit, args.min_score, args.rating, E621_USER_AGENT)
    print(f"\n{len(posts)} posts queued for download\n")

    failed = []
    with ThreadPoolExecutor(max_workers=args.threads) as executor:
        futures = {executor.submit(download_post, p, save_dir, headers): p for p in posts}
        for future in tqdm(as_completed(futures), total=len(futures), desc="downloading"):
            result = future.result()
            if result.startswith("fail"):
                failed.append(result)

    print(f"\ndone! {len(posts) - len(failed)} downloaded, {len(failed)} failed")
    if failed:
        print("failures:")
        for f in failed:
            print(f"  {f}")


if __name__ == "__main__":
    main()