add e621 training data scraper
This commit is contained in:
133
get_training_data.py
Normal file
133
get_training_data.py
Normal file
@@ -0,0 +1,133 @@
|
|||||||
|
"""
|
||||||
|
get_training_data.py — scrape e621 into kohya LoRA training folder structure
|
||||||
|
|
||||||
|
usage:
|
||||||
|
python get_training_data.py --lora my_character --tags "character_name solo"
|
||||||
|
python get_training_data.py --lora my_character --tags "tag1 tag2" --limit 500 --min_score 20 --rating s
|
||||||
|
python get_training_data.py --lora my_character --tags "..." --rating e # nsfw
|
||||||
|
|
||||||
|
output goes to:
|
||||||
|
training_data/<lora>/img/<repeats>_<lora>/<files>
|
||||||
|
|
||||||
|
which is exactly what kohya expects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import argparse
|
||||||
|
import requests
|
||||||
|
from tqdm import tqdm
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
|
||||||
|
# e621 requires a real user agent with your username
# NOTE(review): "YOUR_E621_USERNAME" is a placeholder — replace it with an
# actual e621 username before running; generic/placeholder agents may be
# rejected by the API.
E621_USER_AGENT = "animepics-downloader/1.0 (by YOUR_E621_USERNAME)"
# JSON search endpoint used by fetch_all_posts.
E621_API = "https://e621.net/posts.json"

# Only these image extensions are downloaded; everything else is skipped
# (e.g. .gif, .webm, .swf posts).
ALLOWED_EXTS = {".jpg", ".jpeg", ".png", ".webp"}
|
def parse_args():
    """Build and evaluate the command-line interface for the scraper.

    Returns:
        argparse.Namespace with: lora, tags, limit, repeats, min_score,
        rating, threads, output_dir.
    """
    parser = argparse.ArgumentParser(
        description="scrape e621 into kohya training data structure"
    )
    parser.add_argument(
        "--lora", required=True,
        help="lora name, used as folder name and trigger word",
    )
    parser.add_argument(
        "--tags", required=True,
        help="e621 tags to search (space separated, quote the whole thing)",
    )
    parser.add_argument(
        "--limit", type=int, default=1000,
        help="max images to download (default: 1000)",
    )
    parser.add_argument(
        "--repeats", type=int, default=10,
        help="kohya repeat count in folder name (default: 10)",
    )
    parser.add_argument(
        "--min_score", type=int, default=10,
        help="minimum post score filter (default: 10)",
    )
    parser.add_argument(
        "--rating", default=None, choices=["s", "q", "e"],
        help="filter by rating: s=safe, q=questionable, e=explicit (default: all)",
    )
    parser.add_argument(
        "--threads", type=int, default=8,
        help="parallel download threads (default: 8)",
    )
    parser.add_argument(
        "--output_dir", default="training_data",
        help="base output dir (default: training_data)",
    )
    return parser.parse_args()
||||||
|
def fetch_all_posts(tags, limit, min_score, rating, user_agent):
    """Page through the e621 search API and collect up to `limit` posts.

    Args:
        tags: space-separated e621 tag query string.
        min_score: if truthy, appended to the query as `score:>=N`.
        rating: optional "s"/"q"/"e" rating filter appended as `rating:X`.
        user_agent: User-Agent header value (e621 requires a real one).

    Returns:
        List of post dicts as returned by the API, truncated to `limit`.

    Raises:
        requests.HTTPError: on a non-2xx API response.
    """
    headers = {"User-Agent": user_agent}
    per_page = 320  # e621's maximum posts-per-request
    all_posts = []
    page = 1

    # build tag string
    tag_str = tags
    if rating:
        tag_str += f" rating:{rating}"
    if min_score:
        tag_str += f" score:>={min_score}"

    print(f"searching e621 for: {tag_str!r}")

    while len(all_posts) < limit:
        params = {"tags": tag_str, "limit": per_page, "page": page}
        # timeout added: without it a stalled connection hangs forever;
        # also keeps this consistent with download_post's timeout=30
        resp = requests.get(E621_API, headers=headers, params=params, timeout=30)
        resp.raise_for_status()
        posts = resp.json().get("posts", [])

        if not posts:
            print(f" no more posts at page {page}")
            break

        all_posts.extend(posts)
        print(f" page {page}: +{len(posts)} posts (total {len(all_posts)})")
        page += 1

        # a short page means the result set is exhausted
        if len(posts) < per_page:
            break

    return all_posts[:limit]
|
||||||
|
def download_post(post, save_dir, headers):
    """Download a single post's image file into save_dir.

    Never raises: every outcome is reported as a status string prefixed
    with "ok:", "skip:", or "fail:" for the caller to tally.
    """
    source = post.get("file", {}).get("url")
    if not source:
        return f"skip: no url for post {post.get('id')}"

    extension = os.path.splitext(source)[1].lower()
    if extension not in ALLOWED_EXTS:
        return f"skip: unsupported ext {extension} ({source})"

    name = os.path.basename(source)
    target = os.path.join(save_dir, name)

    # idempotent: re-runs don't re-fetch files that are already on disk
    if os.path.exists(target):
        return f"skip: already exists {name}"

    try:
        response = requests.get(source, headers=headers, timeout=30)
        response.raise_for_status()
        with open(target, "wb") as out:
            out.write(response.content)
        return f"ok: {name}"
    except Exception as err:
        return f"fail: {source} — {err}"
|
def main():
    """Entry point: search e621, then download matches in parallel into the
    kohya folder layout, printing a summary at the end."""
    args = parse_args()

    # build kohya folder: training_data/<lora>/img/<repeats>_<lora>/
    save_dir = os.path.join(args.output_dir, args.lora, "img", f"{args.repeats}_{args.lora}")
    os.makedirs(save_dir, exist_ok=True)
    print(f"saving to: {save_dir}")

    headers = {"User-Agent": E621_USER_AGENT}

    posts = fetch_all_posts(args.tags, args.limit, args.min_score, args.rating, E621_USER_AGENT)
    print(f"\n{len(posts)} posts queued for download\n")

    # bug fix: the summary previously reported len(posts) - len(failed) as
    # "downloaded", which counted skipped posts (already exists / no url /
    # bad ext) as successes — tally each bucket separately instead
    downloaded = 0
    skipped = 0
    failed = []
    with ThreadPoolExecutor(max_workers=args.threads) as executor:
        futures = {executor.submit(download_post, p, save_dir, headers): p for p in posts}
        for future in tqdm(as_completed(futures), total=len(futures), desc="downloading"):
            result = future.result()
            if result.startswith("fail"):
                failed.append(result)
            elif result.startswith("skip"):
                skipped += 1
            else:
                downloaded += 1

    print(f"\ndone! {downloaded} downloaded, {skipped} skipped, {len(failed)} failed")
    if failed:
        print("failures:")
        for f in failed:
            print(f" {f}")
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Reference in New Issue
Block a user