Coverage for background_sync.py: 70%

168 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-21 23:06 +0000

1# Copyright 2026 venim1103 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15import os 

16import time 

17import argparse 

18import zipfile 

19from huggingface_hub import HfApi 

20 

# Directory scanned for new checkpoint files (relative to the working directory).
CHECKPOINT_DIR = "checkpoints"
# Seconds between sync cycles in the main loop.
DEFAULT_POLL_INTERVAL = 15
# Seconds to wait before attempting each upload, giving the writer time to finish.
DEFAULT_SLEEP_BEFORE_UPLOAD = 30
# Minimum file mtime age (seconds) before a checkpoint is considered safe to upload.
DEFAULT_MIN_FILE_AGE = 60
# Seconds across which size/mtime must stay unchanged for a file to count as stable.
DEFAULT_STABILIZE_WINDOW = 10
# Whether to require the checkpoint to open as a valid zip container before upload.
DEFAULT_VALIDATE_ZIP = True

27 

28 

29def _log(message: str): 

30 """Emit timestamped log lines with immediate flush for notebook log files.""" 

31 ts = time.strftime("%Y-%m-%d %H:%M:%S") 

32 print(f"[{ts}] {message}", flush=True) 

33 

34 

def iter_new_checkpoint_files(checkpoint_dir: str, uploaded_files: set[str]):
    """Generate paths of .pt files under checkpoint_dir not yet in uploaded_files.

    Yields nothing when the directory does not exist.
    """
    if not os.path.exists(checkpoint_dir):
        return
    for root, _dirs, names in os.walk(checkpoint_dir):
        candidates = (
            os.path.join(root, name) for name in names if name.endswith(".pt")
        )
        for path in candidates:
            if path not in uploaded_files:
                yield path

46 

47 

48def _is_old_enough(filepath: str, min_file_age_seconds: int) -> bool: 

49 """Return True once file mtime is at least min_file_age_seconds in the past.""" 

50 if min_file_age_seconds <= 0: 

51 return True 

52 return (time.time() - os.path.getmtime(filepath)) >= min_file_age_seconds 

53 

54 

55def _is_stable(filepath: str, stabilize_window: int, *, sleep_fn=time.sleep) -> bool: 

56 """Return True when file size and mtime remain unchanged across stabilize_window.""" 

57 if stabilize_window <= 0: 

58 st = os.stat(filepath) 

59 return st.st_size > 0 

60 first = os.stat(filepath) 

61 sleep_fn(stabilize_window) 

62 second = os.stat(filepath) 

63 return ( 

64 first.st_size > 0 

65 and first.st_size == second.st_size 

66 and first.st_mtime == second.st_mtime 

67 ) 

68 

69 

70def _looks_like_valid_torch_zip(filepath: str) -> bool: 

71 """Basic structural check for torch checkpoint zip container integrity.""" 

72 try: 

73 with zipfile.ZipFile(filepath, "r") as zf: 

74 return len(zf.infolist()) > 0 

75 except Exception: 

76 return False 

77 

78 

def upload_checkpoint(
    api: HfApi,
    filepath: str,
    repo_id: str,
    uploaded_files: set[str],
    *,
    sleep_before_upload: int = DEFAULT_SLEEP_BEFORE_UPLOAD,
    min_file_age_seconds: int = DEFAULT_MIN_FILE_AGE,
    stabilize_window_seconds: int = DEFAULT_STABILIZE_WINDOW,
    validate_zip: bool = DEFAULT_VALIDATE_ZIP,
    sleep_fn=time.sleep,
    logger=print,
) -> bool:
    """Attempt to upload one checkpoint file to the Hub.

    After an initial delay, runs a preflight (existence, minimum age,
    stability window, optional zip validation); any failed check defers the
    file for a later cycle. On a successful upload the path is recorded in
    uploaded_files. Returns True only when the upload completed.
    """
    sleep_fn(sleep_before_upload)
    name = os.path.basename(filepath)
    if not os.path.exists(filepath):
        logger(f"Deferred {name}: file missing")
        return False
    try:
        if not _is_old_enough(filepath, min_file_age_seconds):
            logger(f"Deferred {name}: file is too new")
            return False
        if not _is_stable(filepath, stabilize_window_seconds, sleep_fn=sleep_fn):
            logger(f"Deferred {name}: file still changing")
            return False
        if validate_zip and not _looks_like_valid_torch_zip(filepath):
            logger(f"Deferred {name}: checkpoint zip is not valid yet")
            return False
    except OSError as exc:
        logger(f"Deferred {name}: OS error during preflight ({exc})")
        return False
    try:
        api.upload_file(
            path_or_fileobj=filepath,
            path_in_repo=os.path.relpath(filepath, start="."),
            repo_id=repo_id,
            repo_type="model",
        )
        uploaded_files.add(filepath)
        logger(f"Backed up {name}")
        return True
    except Exception as exc:
        logger(f"Upload failed {name}: {exc}")
        return False

124 

125 

def sync_once(
    api: HfApi,
    repo_id: str,
    checkpoint_dir: str,
    uploaded_files: set[str],
    *,
    sleep_before_upload: int = DEFAULT_SLEEP_BEFORE_UPLOAD,
    min_file_age_seconds: int = DEFAULT_MIN_FILE_AGE,
    stabilize_window_seconds: int = DEFAULT_STABILIZE_WINDOW,
    validate_zip: bool = DEFAULT_VALIDATE_ZIP,
    sleep_fn=time.sleep,
    logger=print,
) -> int:
    """Scan checkpoint_dir once, try to upload each unseen file, count successes."""
    # Lazily iterate unseen candidates; each True result is one completed upload.
    return sum(
        1
        for candidate in iter_new_checkpoint_files(checkpoint_dir, uploaded_files)
        if upload_checkpoint(
            api,
            candidate,
            repo_id,
            uploaded_files,
            sleep_before_upload=sleep_before_upload,
            min_file_age_seconds=min_file_age_seconds,
            stabilize_window_seconds=stabilize_window_seconds,
            validate_zip=validate_zip,
            sleep_fn=sleep_fn,
            logger=logger,
        )
    )

156 

157 

def run_sync_loop(
    api: HfApi,
    repo_id: str,
    checkpoint_dir: str = CHECKPOINT_DIR,
    *,
    poll_interval: int = DEFAULT_POLL_INTERVAL,
    sleep_before_upload: int = DEFAULT_SLEEP_BEFORE_UPLOAD,
    min_file_age_seconds: int = DEFAULT_MIN_FILE_AGE,
    stabilize_window_seconds: int = DEFAULT_STABILIZE_WINDOW,
    validate_zip: bool = DEFAULT_VALIDATE_ZIP,
    sleep_fn=time.sleep,
    logger=print,
):
    """Poll checkpoint_dir forever, uploading new checkpoints each cycle.

    Exceptions raised by a cycle are logged and the loop continues; the set
    of already-uploaded paths lives in memory for the life of the loop.
    """
    seen: set[str] = set()
    logger(f"Watching {checkpoint_dir}/ for new .pt files...")
    while True:
        try:
            count = sync_once(
                api,
                repo_id,
                checkpoint_dir,
                seen,
                sleep_before_upload=sleep_before_upload,
                min_file_age_seconds=min_file_age_seconds,
                stabilize_window_seconds=stabilize_window_seconds,
                validate_zip=validate_zip,
                sleep_fn=sleep_fn,
                logger=logger,
            )
            logger(
                f"Sync cycle complete: uploaded={count}, tracked_total={len(seen)}"
            )
        except Exception as exc:
            logger(f"Sync cycle failed: {exc}")
        sleep_fn(poll_interval)

194 

195 

196def _read_int_env(name: str, default: int, *, logger=print) -> int: 

197 """Read integer env var safely, logging and falling back on invalid values.""" 

198 raw = os.environ.get(name) 

199 if raw is None: 

200 return default 

201 try: 

202 value = int(raw) 

203 if value < 0: 203 ↛ 204line 203 didn't jump to line 204 because the condition on line 203 was never true

204 raise ValueError("negative") 

205 return value 

206 except Exception: 

207 logger(f"Invalid {name}={raw!r}; using default {default}") 

208 return default 

209 

210 

211def _read_bool_env(name: str, default: bool, *, logger=print) -> bool: 

212 """Read boolean env var safely, supporting 1/0, true/false, yes/no.""" 

213 raw = os.environ.get(name) 

214 if raw is None: 

215 return default 

216 normalized = raw.strip().lower() 

217 if normalized in {"1", "true", "yes", "y", "on"}: 

218 return True 

219 if normalized in {"0", "false", "no", "n", "off"}: 219 ↛ 221line 219 didn't jump to line 221 because the condition on line 219 was always true

220 return False 

221 logger(f"Invalid {name}={raw!r}; using default {default}") 

222 return default 

223 

224 

def run_self_check(api: HfApi, repo_id: str, checkpoint_dir: str = CHECKPOINT_DIR, *, logger=print) -> bool:
    """One-shot diagnostics: verify HF auth, repo access, and checkpoint discovery.

    Returns False on an auth or repo-access failure; a missing checkpoint
    directory is reported but still counts as a pass.
    """
    logger("Running sync self-check...")
    try:
        identity = api.whoami()
        user_name = identity.get("name") if isinstance(identity, dict) else str(identity)
        logger(f"HF auth OK (user={user_name})")
    except Exception as exc:
        logger(f"HF auth failed: {exc}")
        return False

    try:
        api.create_repo(repo_id, repo_type="model", private=True)
        logger(f"Repo access OK: {repo_id}")
    except Exception as exc:
        # create_repo can fail when the repo already exists; confirm access
        # with a read-only lookup before declaring failure.
        logger(f"Repo create/check returned: {exc}")
        try:
            api.repo_info(repo_id=repo_id, repo_type="model")
            logger(f"Repo access confirmed via repo_info: {repo_id}")
        except Exception as repo_exc:
            logger(f"Repo access failed: {repo_exc}")
            return False

    if not os.path.exists(checkpoint_dir):
        logger(f"Checkpoint directory missing: {checkpoint_dir}")
        return True

    pt_files = [
        os.path.join(root, name)
        for root, _dirs, names in os.walk(checkpoint_dir)
        for name in names
        if name.endswith(".pt")
    ]
    logger(f"Checkpoint scan OK: found {len(pt_files)} .pt file(s) under {checkpoint_dir}")
    if pt_files:
        logger(f"Newest candidate: {max(pt_files, key=os.path.getmtime)}")
    return True

263 

264 

def main(argv=None):
    """Entry point: parse flags, configure from env vars, and run the sync loop.

    With --self-check, runs one-shot diagnostics and returns 0 (pass) or 2
    (fail). Otherwise reads tuning knobs from SYNC_* env vars and enters the
    (never-returning) sync loop. Requires the REPO_ID env var; HF_TOKEN is
    optional (falls back to the ambient Hugging Face credentials).
    """
    parser = argparse.ArgumentParser(description="Background checkpoint uploader")
    parser.add_argument(
        "--self-check",
        action="store_true",
        help="Run one-time diagnostics (auth/repo/checkpoint scan) and exit",
    )
    # Note: a plain main() call deliberately parses no args (notebook usage);
    # the __main__ guard passes sys.argv[1:] explicitly.
    args = parser.parse_args([] if argv is None else argv)

    hf_token = os.environ.get("HF_TOKEN")
    repo_id = os.environ.get("REPO_ID")
    if not repo_id:
        raise ValueError("REPO_ID environment variable is required")

    _log("Starting background checkpoint sync")
    _log(f"Python PID={os.getpid()}")
    _log(f"HF token present={bool(hf_token)} | repo_id={repo_id}")

    api = HfApi(token=hf_token)

    if args.self_check:
        ok = run_self_check(api, repo_id, CHECKPOINT_DIR, logger=_log)
        _log(f"Self-check {'passed' if ok else 'failed'}")
        return 0 if ok else 2

    try:
        api.create_repo(repo_id, repo_type="model", private=True)
        _log("Ensured Hugging Face repo exists")
    except Exception as exc:
        # The repository may already exist; keep syncing either way.
        _log(f"create_repo skipped/failed (continuing): {exc}")

    # All tuning knobs fall back to module defaults on missing/invalid values.
    poll_interval = _read_int_env("SYNC_POLL_INTERVAL", DEFAULT_POLL_INTERVAL, logger=_log)
    sleep_before_upload = _read_int_env(
        "SYNC_SLEEP_BEFORE_UPLOAD", DEFAULT_SLEEP_BEFORE_UPLOAD, logger=_log
    )
    min_file_age_seconds = _read_int_env("SYNC_MIN_FILE_AGE", DEFAULT_MIN_FILE_AGE, logger=_log)
    stabilize_window_seconds = _read_int_env(
        "SYNC_STABILIZE_WINDOW", DEFAULT_STABILIZE_WINDOW, logger=_log
    )
    validate_zip = _read_bool_env("SYNC_VALIDATE_ZIP", DEFAULT_VALIDATE_ZIP, logger=_log)
    _log(
        f"Sync loop config: poll_interval={poll_interval}s, "
        f"sleep_before_upload={sleep_before_upload}s, "
        f"min_file_age={min_file_age_seconds}s, "
        f"stabilize_window={stabilize_window_seconds}s, "
        f"validate_zip={validate_zip}"
    )
    run_sync_loop(
        api,
        repo_id,
        CHECKPOINT_DIR,
        poll_interval=poll_interval,
        sleep_before_upload=sleep_before_upload,
        min_file_age_seconds=min_file_age_seconds,
        stabilize_window_seconds=stabilize_window_seconds,
        validate_zip=validate_zip,
        logger=_log,
    )

326 

327 

if __name__ == "__main__":
    import sys

    # sys.exit raises SystemExit with main()'s return code, same as the
    # explicit raise form.
    sys.exit(main(sys.argv[1:]))