Coverage for background_sync.py: 70%
168 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-21 23:06 +0000
# Copyright 2026 venim1103
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15import os
16import time
17import argparse
18import zipfile
19from huggingface_hub import HfApi
# Directory (relative to the working directory) watched for new .pt files.
CHECKPOINT_DIR = "checkpoints"
# Seconds slept between sync cycles in run_sync_loop.
DEFAULT_POLL_INTERVAL = 15
# Seconds slept in upload_checkpoint before running any preflight check.
DEFAULT_SLEEP_BEFORE_UPLOAD = 30
# Minimum seconds since a file's mtime before it is eligible for upload.
DEFAULT_MIN_FILE_AGE = 60
# Seconds between the two stat() samples used to confirm a file stopped changing.
DEFAULT_STABILIZE_WINDOW = 10
# Whether checkpoints must open as a non-empty zip container before upload.
DEFAULT_VALIDATE_ZIP = True
29def _log(message: str):
30 """Emit timestamped log lines with immediate flush for notebook log files."""
31 ts = time.strftime("%Y-%m-%d %H:%M:%S")
32 print(f"[{ts}] {message}", flush=True)
def iter_new_checkpoint_files(checkpoint_dir: str, uploaded_files: set[str]):
    """Yield paths of .pt files under checkpoint_dir not already in uploaded_files.

    Yields nothing when the directory does not exist.
    """
    if not os.path.exists(checkpoint_dir):
        return
    for dirpath, _dirnames, filenames in os.walk(checkpoint_dir):
        candidates = (
            os.path.join(dirpath, name)
            for name in filenames
            if name.endswith(".pt")
        )
        for path in candidates:
            if path not in uploaded_files:
                yield path
48def _is_old_enough(filepath: str, min_file_age_seconds: int) -> bool:
49 """Return True once file mtime is at least min_file_age_seconds in the past."""
50 if min_file_age_seconds <= 0:
51 return True
52 return (time.time() - os.path.getmtime(filepath)) >= min_file_age_seconds
55def _is_stable(filepath: str, stabilize_window: int, *, sleep_fn=time.sleep) -> bool:
56 """Return True when file size and mtime remain unchanged across stabilize_window."""
57 if stabilize_window <= 0:
58 st = os.stat(filepath)
59 return st.st_size > 0
60 first = os.stat(filepath)
61 sleep_fn(stabilize_window)
62 second = os.stat(filepath)
63 return (
64 first.st_size > 0
65 and first.st_size == second.st_size
66 and first.st_mtime == second.st_mtime
67 )
70def _looks_like_valid_torch_zip(filepath: str) -> bool:
71 """Basic structural check for torch checkpoint zip container integrity."""
72 try:
73 with zipfile.ZipFile(filepath, "r") as zf:
74 return len(zf.infolist()) > 0
75 except Exception:
76 return False
def upload_checkpoint(
    api: HfApi,
    filepath: str,
    repo_id: str,
    uploaded_files: set[str],
    *,
    sleep_before_upload: int = DEFAULT_SLEEP_BEFORE_UPLOAD,
    min_file_age_seconds: int = DEFAULT_MIN_FILE_AGE,
    stabilize_window_seconds: int = DEFAULT_STABILIZE_WINDOW,
    validate_zip: bool = DEFAULT_VALIDATE_ZIP,
    sleep_fn=time.sleep,
    logger=print,
) -> bool:
    """Upload a checkpoint file and record it in uploaded_files on success.

    Preflight checks defer the upload (return False) when the file is missing,
    younger than min_file_age_seconds, still changing across
    stabilize_window_seconds, or (when validate_zip) not a readable zip
    container. Deferred files are retried on later sync cycles because they
    are never added to uploaded_files.

    Returns True only when the Hub upload itself succeeded.
    """
    # Give the writer a head start so we don't race a checkpoint mid-save.
    sleep_fn(sleep_before_upload)
    if not os.path.exists(filepath):
        logger(f"Deferred {os.path.basename(filepath)}: file missing")
        return False
    try:
        if not _is_old_enough(filepath, min_file_age_seconds):
            logger(f"Deferred {os.path.basename(filepath)}: file is too new")
            return False
        if not _is_stable(filepath, stabilize_window_seconds, sleep_fn=sleep_fn):
            logger(f"Deferred {os.path.basename(filepath)}: file still changing")
            return False
        if validate_zip and not _looks_like_valid_torch_zip(filepath):
            logger(f"Deferred {os.path.basename(filepath)}: checkpoint zip is not valid yet")
            return False
    except OSError as exc:
        # File may vanish or become unreadable between checks; retry next cycle.
        logger(f"Deferred {os.path.basename(filepath)}: OS error during preflight ({exc})")
        return False
    try:
        # Hub repo paths always use forward slashes, but os.path.relpath
        # returns OS-native separators (backslashes on Windows) — normalize.
        path_in_repo = os.path.relpath(filepath, start=".").replace(os.sep, "/")
        api.upload_file(
            path_or_fileobj=filepath,
            path_in_repo=path_in_repo,
            repo_id=repo_id,
            repo_type="model",
        )
        uploaded_files.add(filepath)
        logger(f"Backed up {os.path.basename(filepath)}")
        return True
    except Exception as exc:
        # Network/auth failures are logged and retried on the next cycle.
        logger(f"Upload failed {os.path.basename(filepath)}: {exc}")
        return False
def sync_once(
    api: HfApi,
    repo_id: str,
    checkpoint_dir: str,
    uploaded_files: set[str],
    *,
    sleep_before_upload: int = DEFAULT_SLEEP_BEFORE_UPLOAD,
    min_file_age_seconds: int = DEFAULT_MIN_FILE_AGE,
    stabilize_window_seconds: int = DEFAULT_STABILIZE_WINDOW,
    validate_zip: bool = DEFAULT_VALIDATE_ZIP,
    sleep_fn=time.sleep,
    logger=print,
) -> int:
    """Run one scan/upload cycle; return how many files uploaded successfully."""
    successes = 0
    for candidate in iter_new_checkpoint_files(checkpoint_dir, uploaded_files):
        ok = upload_checkpoint(
            api,
            candidate,
            repo_id,
            uploaded_files,
            sleep_before_upload=sleep_before_upload,
            min_file_age_seconds=min_file_age_seconds,
            stabilize_window_seconds=stabilize_window_seconds,
            validate_zip=validate_zip,
            sleep_fn=sleep_fn,
            logger=logger,
        )
        if ok:
            successes += 1
    return successes
def run_sync_loop(
    api: HfApi,
    repo_id: str,
    checkpoint_dir: str = CHECKPOINT_DIR,
    *,
    poll_interval: int = DEFAULT_POLL_INTERVAL,
    sleep_before_upload: int = DEFAULT_SLEEP_BEFORE_UPLOAD,
    min_file_age_seconds: int = DEFAULT_MIN_FILE_AGE,
    stabilize_window_seconds: int = DEFAULT_STABILIZE_WINDOW,
    validate_zip: bool = DEFAULT_VALIDATE_ZIP,
    sleep_fn=time.sleep,
    logger=print,
):
    """Poll checkpoint_dir forever, uploading any newly discovered checkpoints.

    Exceptions from a cycle are logged and swallowed so the loop keeps running;
    this function never returns.
    """
    seen: set[str] = set()
    logger(f"Watching {checkpoint_dir}/ for new .pt files...")
    while True:
        try:
            count = sync_once(
                api,
                repo_id,
                checkpoint_dir,
                seen,
                sleep_before_upload=sleep_before_upload,
                min_file_age_seconds=min_file_age_seconds,
                stabilize_window_seconds=stabilize_window_seconds,
                validate_zip=validate_zip,
                sleep_fn=sleep_fn,
                logger=logger,
            )
            logger(
                f"Sync cycle complete: uploaded={count}, tracked_total={len(seen)}"
            )
        except Exception as exc:
            logger(f"Sync cycle failed: {exc}")
        sleep_fn(poll_interval)
196def _read_int_env(name: str, default: int, *, logger=print) -> int:
197 """Read integer env var safely, logging and falling back on invalid values."""
198 raw = os.environ.get(name)
199 if raw is None:
200 return default
201 try:
202 value = int(raw)
203 if value < 0: 203 ↛ 204line 203 didn't jump to line 204 because the condition on line 203 was never true
204 raise ValueError("negative")
205 return value
206 except Exception:
207 logger(f"Invalid {name}={raw!r}; using default {default}")
208 return default
211def _read_bool_env(name: str, default: bool, *, logger=print) -> bool:
212 """Read boolean env var safely, supporting 1/0, true/false, yes/no."""
213 raw = os.environ.get(name)
214 if raw is None:
215 return default
216 normalized = raw.strip().lower()
217 if normalized in {"1", "true", "yes", "y", "on"}:
218 return True
219 if normalized in {"0", "false", "no", "n", "off"}: 219 ↛ 221line 219 didn't jump to line 221 because the condition on line 219 was always true
220 return False
221 logger(f"Invalid {name}={raw!r}; using default {default}")
222 return default
def run_self_check(api: HfApi, repo_id: str, checkpoint_dir: str = CHECKPOINT_DIR, *, logger=print) -> bool:
    """Run one-shot diagnostics for auth, repo access, and checkpoint discovery.

    Returns False when authentication or repo access fails; a missing
    checkpoint directory is reported but still counts as success, since
    training may simply not have produced a checkpoint yet.
    """
    logger("Running sync self-check...")
    try:
        who = api.whoami()
        # whoami() historically returns a dict; fall back to str() for safety.
        user_name = who.get("name") if isinstance(who, dict) else str(who)
        logger(f"HF auth OK (user={user_name})")
    except Exception as exc:
        logger(f"HF auth failed: {exc}")
        return False
    try:
        api.create_repo(repo_id, repo_type="model", private=True)
        logger(f"Repo access OK: {repo_id}")
    except Exception as exc:
        # If create fails (e.g. already exists), explicitly verify we can access it.
        logger(f"Repo create/check returned: {exc}")
        try:
            api.repo_info(repo_id=repo_id, repo_type="model")
            logger(f"Repo access confirmed via repo_info: {repo_id}")
        except Exception as repo_exc:
            logger(f"Repo access failed: {repo_exc}")
            return False
    if not os.path.exists(checkpoint_dir):
        logger(f"Checkpoint directory missing: {checkpoint_dir}")
        return True
    # Reuse the sync loop's discovery helper so the self-check and the real
    # sync agree on what counts as a checkpoint file.
    pt_files = list(iter_new_checkpoint_files(checkpoint_dir, set()))
    logger(f"Checkpoint scan OK: found {len(pt_files)} .pt file(s) under {checkpoint_dir}")
    if pt_files:
        logger(f"Newest candidate: {max(pt_files, key=os.path.getmtime)}")
    return True
def main(argv=None):
    """Script entry point: parse CLI args, read env config, then run.

    With --self-check, runs diagnostics once and returns 0 (pass) or 2 (fail);
    otherwise enters the sync loop, which never returns. Raises ValueError
    when REPO_ID is not set.
    """
    parser = argparse.ArgumentParser(description="Background checkpoint uploader")
    parser.add_argument(
        "--self-check",
        action="store_true",
        help="Run one-time diagnostics (auth/repo/checkpoint scan) and exit",
    )
    cli_args = [] if argv is None else argv
    args = parser.parse_args(cli_args)

    token = os.environ.get("HF_TOKEN")
    repo_id = os.environ.get("REPO_ID")
    if not repo_id:
        raise ValueError("REPO_ID environment variable is required")

    _log("Starting background checkpoint sync")
    _log(f"Python PID={os.getpid()}")
    _log(f"HF token present={bool(token)} | repo_id={repo_id}")

    api = HfApi(token=token)

    if args.self_check:
        ok = run_self_check(api, repo_id, CHECKPOINT_DIR, logger=_log)
        _log(f"Self-check {'passed' if ok else 'failed'}")
        return 0 if ok else 2

    try:
        api.create_repo(repo_id, repo_type="model", private=True)
        _log("Ensured Hugging Face repo exists")
    except Exception as exc:
        # Repository may already exist; keep syncing in either case.
        _log(f"create_repo skipped/failed (continuing): {exc}")

    # Every knob is env-tunable; invalid values fall back to the defaults.
    poll_interval = _read_int_env("SYNC_POLL_INTERVAL", DEFAULT_POLL_INTERVAL, logger=_log)
    sleep_before_upload = _read_int_env(
        "SYNC_SLEEP_BEFORE_UPLOAD", DEFAULT_SLEEP_BEFORE_UPLOAD, logger=_log
    )
    min_file_age_seconds = _read_int_env(
        "SYNC_MIN_FILE_AGE", DEFAULT_MIN_FILE_AGE, logger=_log
    )
    stabilize_window_seconds = _read_int_env(
        "SYNC_STABILIZE_WINDOW", DEFAULT_STABILIZE_WINDOW, logger=_log
    )
    validate_zip = _read_bool_env("SYNC_VALIDATE_ZIP", DEFAULT_VALIDATE_ZIP, logger=_log)
    _log(
        f"Sync loop config: poll_interval={poll_interval}s, "
        f"sleep_before_upload={sleep_before_upload}s, "
        f"min_file_age={min_file_age_seconds}s, "
        f"stabilize_window={stabilize_window_seconds}s, "
        f"validate_zip={validate_zip}"
    )
    run_sync_loop(
        api,
        repo_id,
        CHECKPOINT_DIR,
        poll_interval=poll_interval,
        sleep_before_upload=sleep_before_upload,
        min_file_age_seconds=min_file_age_seconds,
        stabilize_window_seconds=stabilize_window_seconds,
        validate_zip=validate_zip,
        logger=_log,
    )
if __name__ == "__main__":
    import sys

    # sys.exit raises SystemExit with main()'s return value as the exit code.
    sys.exit(main(sys.argv[1:]))