diff --git a/ioutil.c b/ioutil.c index 4e6aa55..cc430df 100644 --- a/ioutil.c +++ b/ioutil.c @@ -1,7 +1,10 @@ #include #include +#include #include "types.h" #include "ioutil.h" +#include "vec.h" +#include #ifndef _WIN32 @@ -31,7 +34,7 @@ FH createfile(const char *path,int secret) int fd; do { fd = open(path,O_WRONLY | O_CREAT | O_TRUNC,secret ? 0600 : 0666); - if (fd == -1) { + if (fd < 0) { if (errno == EINTR) continue; return -1; @@ -45,7 +48,7 @@ int closefile(FH fd) int cret; do { cret = close(fd); - if (cret == -1) { + if (cret < 0) { if (errno == EINTR) continue; return -1; @@ -59,6 +62,122 @@ int createdir(const char *path,int secret) return mkdir(path,secret ? 0700 : 0777); } +static int syncwritefile(const char *filename,const char *tmpname,int secret,const u8 *data,size_t datalen) +{ + FH f = createfile(tmpname,secret); + if (f == FH_invalid) + return -1; + + if (writeall(f,data,datalen) < 0) { + goto failclose; + } + + int sret; + do { + sret = fsync(f); + if (sret < 0) { + if (errno == EINTR) + continue; + + goto failclose; + } + } while (0); + + if (closefile(f) < 0) { + goto failrm; + } + + if (rename(tmpname,filename) < 0) { + goto failrm; + } + + return 0; + +failclose: + (void) closefile(f); +failrm: + remove(tmpname); + + return -1; +} + +int syncwrite(const char *filename,int secret,const u8 *data,size_t datalen) +{ + //fprintf(stderr,"filename = %s\n",filename); + + size_t fnlen = strlen(filename); + + VEC_STRUCT(,char) tmpnamebuf; + VEC_INIT(tmpnamebuf); + VEC_ADDN(tmpnamebuf,fnlen + 4 /* ".tmp" */ + 1 /* "\0" */); + memcpy(&VEC_BUF(tmpnamebuf,0),filename,fnlen); + strcpy(&VEC_BUF(tmpnamebuf,fnlen),".tmp"); + const char *tmpname = &VEC_BUF(tmpnamebuf,0); + + //fprintf(stderr,"tmpname = %s\n",tmpname); + + int r = syncwritefile(filename,tmpname,secret,data,datalen); + + VEC_FREE(tmpnamebuf); + + if (r < 0) + return r; + + VEC_STRUCT(,char) dirnamebuf; + VEC_INIT(dirnamebuf); + const char *dirname; + + for (ssize_t x = ((ssize_t)fnlen) - 1;x >= 0;--x) { + if (filename[x] == '/') { + if (x) + --x; + ++x; + VEC_ADDN(dirnamebuf,x + 1); + memcpy(&VEC_BUF(dirnamebuf,0),filename,x); + VEC_BUF(dirnamebuf,x) = '\0'; + dirname = &VEC_BUF(dirnamebuf,0); + goto foundslash; + } + } + /* not found slash, fall back to "." */ + dirname = "."; + +foundslash: + //fprintf(stderr,"dirname = %s\n",dirname); + ; + + int dirf; + do { + dirf = open(dirname,O_RDONLY); + if (dirf < 0) { + if (errno == EINTR) + continue; + + // failed for non-eintr reasons + goto skipdsync; // don't really care enough + } + } while (0); + + int sret; + do { + sret = fsync(dirf); + if (sret < 0) { + if (errno == EINTR) + continue; + + // failed for non-eintr reasons + break; // don't care + } + } while (0); + + (void) closefile(dirf); // don't care + +skipdsync: + VEC_FREE(dirnamebuf); + + return 0; +} + #else int writeall(FH fd,const u8 *data,size_t len) @@ -99,6 +218,61 @@ int createdir(const char *path,int secret) return CreateDirectoryA(path,0) ? 0 : -1; } +static int syncwritefile(const char *filename,const char *tmpname,int secret,const char *data,size_t datalen) +{ + FH f = createfile(tmpnamestr,secret) + if (f == FH_invalid) + return -1; + + if (writeall(f,data,datalen) < 0) { + goto failclose; + } + + if (FlushFileBuffers(f) == 0) { + goto failclose; + } + + if (closefile(f) < 0) { + goto failrm; + } + + if (MoveFileA(tmpnamestr,filename) == 0) { + goto failrm; + } + + return 0; + +failclose: + (void) closefile(f); +failrm: + remove(tmpnamestr); + + return -1; +} + +int syncwrite(const char *filename,int secret,const char *data,size_t datalen) +{ + size_t fnlen = strlen(filename); + + VEC_STRUCT(,char) tmpnamebuf; + VEC_INIT(tmpnamebuf); + VEC_ADDN(tmpnamebuf,fnlen + 4 /* ".tmp" */ + 1 /* "\0" */); + memcpy(&VEC_BUF(tmpnamebuf,0),filename,fnlen); + strcpy(&VEC_BUF(tmpnamebuf,fnlen),".tmp"); + const char *tmpname = &VEC_BUF(tmpnamebuf,0); + + int r = syncwritefile(filename,tmpname,secret,data,datalen); + + VEC_FREE(tmpnamebuf); + + if (r < 0) + return r; + + // can't fsync parent dir on windows so just end here + + return 0; +} + #endif int writetofile(const char *path,const u8 *data,size_t len,int secret) diff --git a/ioutil.h b/ioutil.h index c7a1dab..5244508 100644 --- a/ioutil.h +++ b/ioutil.h @@ -18,3 +18,4 @@ int closefile(FH fd); int writeall(FH,const u8 *data,size_t len); int writetofile(const char *path,const u8 *data,size_t len,int secret); int createdir(const char *path,int secret); +int syncwrite(const char *filename,int secret,const u8 *data,size_t datalen); diff --git a/main.c b/main.c index cec31bb..12bd1f3 100644 --- a/main.c +++ b/main.c @@ -29,6 +29,8 @@ #include "worker.h" +#include "likely.h" + #ifndef _WIN32 #define FSZ "%zu" #else @@ -58,6 +60,11 @@ size_t printlen; // precalculated, related to printstartpos pthread_mutex_t fout_mutex; FILE *fout; +#ifdef PASSPHRASE +u8 orig_determseed[SEED_LEN]; +const char *checkpointfile = 0; +#endif + static void termhandler(int sig) { switch (sig) { @@ -112,6 +119,7 @@ static void printhelp(FILE *out,const char *progname) #ifdef PASSPHRASE "\t-p passphrase - use passphrase to initialize the random seed with\n" "\t-P - same as -p, but takes passphrase from PASSPHRASE environment variable\n" + "\t--checkpoint filename - load/save checkpoint of progress to specified file (requires passphrase)\n" #endif ,progname,progname); fflush(out); @@ -173,9 +181,58 @@ static void setpassphrase(const char *pass) } fprintf(stderr," done.\n"); } + +static void savecheckpoint(void) +{ + u8 checkpoint[SEED_LEN]; + bool carry = 0; + pthread_mutex_lock(&determseed_mutex); + for (int i = 0; i < SEED_LEN; i++) { + checkpoint[i] = determseed[i] - orig_determseed[i] - carry; + carry = checkpoint[i] > determseed[i]; + } + pthread_mutex_unlock(&determseed_mutex); + + if (syncwrite(checkpointfile,1,checkpoint,SEED_LEN) < 0) { + pthread_mutex_lock(&fout_mutex); + fprintf(stderr,"ERROR: could not save checkpoint\n"); + pthread_mutex_unlock(&fout_mutex); + } +} + +static volatile int checkpointer_endwork = 0; + +static void *checkpointworker(void *arg) +{ + (void) arg; + + struct timespec ts; + memset(&ts,0,sizeof(ts)); + ts.tv_nsec = 100000000; + + struct timespec nowtime; + u64 ilasttime,inowtime; + clock_gettime(CLOCK_MONOTONIC,&nowtime); + ilasttime = (1000000 * (u64)nowtime.tv_sec) + ((u64)nowtime.tv_nsec / 1000); + + while (!unlikely(checkpointer_endwork)) { + + clock_gettime(CLOCK_MONOTONIC,&nowtime); + inowtime = (1000000 * (u64)nowtime.tv_sec) + ((u64)nowtime.tv_nsec / 1000); + + if ((i64)(inowtime - ilasttime) >= 300 * 1000000 /* 5 minutes */) { + savecheckpoint(); + ilasttime = inowtime; + } + } + + savecheckpoint(); + + return 0; +} #endif -VEC_STRUCT(threadvec, pthread_t); +VEC_STRUCT(threadvec,pthread_t); #include "filters_inc.inc.h" #include "filters_main.inc.h" @@ -248,6 +305,14 @@ int main(int argc,char **argv) } else if (!strcmp(arg,"rawyaml")) yamlraw = 1; +#ifdef PASSPHRASE + else if (!strcmp(arg,"checkpoint")) { + if (argc--) + checkpointfile = *argv++; + else + e_additional(); + } +#endif // PASSPHRASE else { fprintf(stderr,"unrecognised argument: --%s\n",arg); exit(1); @@ -415,6 +480,11 @@ int main(int argc,char **argv) exit(1); } + if (checkpointfile && !deterministic) { + fprintf(stderr,"--checkpoint requires passphrase\n"); + exit(1); + } + if (outfile) { fout = fopen(outfile,!outfileoverwrite ? "a" : "w"); if (!fout) { @@ -500,8 +570,27 @@ int main(int argc,char **argv) numthreads,numthreads == 1 ? "thread" : "threads"); #ifdef PASSPHRASE - if (!quietflag && deterministic && numneedgenerate != 1) - fprintf(stderr,"CAUTION: avoid using keys generated with same password for unrelated services, as single leaked key may help attacker to regenerate related keys.\n"); + if (deterministic) { + if (!quietflag && numneedgenerate != 1) + fprintf(stderr,"CAUTION: avoid using keys generated with same password for unrelated services, as single leaked key may help attacker to regenerate related keys.\n"); + if (checkpointfile) { + memcpy(orig_determseed,determseed,sizeof(determseed)); + // Read current checkpoint position if file exists + FILE *checkout = fopen(checkpointfile,"r"); + if (checkout) { + u8 checkpoint[SEED_LEN]; + if(fread(checkpoint,1,SEED_LEN,checkout) != SEED_LEN) { + fprintf(stderr,"failed to read checkpoint file\n"); + exit(1); + } + fclose(checkout); + + // Apply checkpoint to determseed + for (int i = 0; i < SEED_LEN; i++) + determseed[i] += checkpoint[i]; + } + } + } #endif signal(SIGTERM,termhandler); @@ -572,6 +661,18 @@ int main(int argc,char **argv) perror("pthread_attr_destroy"); } +#if PASSPHRASE + pthread_t checkpoint_thread; + + if (checkpointfile) { + tret = pthread_create(&checkpoint_thread,NULL,checkpointworker,NULL); + if (tret) { + fprintf(stderr,"error while making checkpoint thread: %s\n",strerror(tret)); + exit(1); + } + } +#endif + #ifdef STATISTICS struct timespec nowtime; u64 istarttime,inowtime,ireporttime = 0,elapsedoffset = 0; @@ -581,6 +682,7 @@ int main(int argc,char **argv) } istarttime = (1000000 * (u64)nowtime.tv_sec) + ((u64)nowtime.tv_nsec / 1000); #endif + struct timespec ts; memset(&ts,0,sizeof(ts)); ts.tv_nsec = 100000000; @@ -616,7 +718,7 @@ int main(int argc,char **argv) VEC_BUF(tstats,i).numrestart += (u64)tdiff; sumrestart += VEC_BUF(tstats,i).numrestart; } - if (reportdelay && (!ireporttime || inowtime - ireporttime >= reportdelay)) { + if (reportdelay && (!ireporttime || (i64)(inowtime - ireporttime) >= (i64)reportdelay)) { if (ireporttime) ireporttime += reportdelay; else @@ -656,8 +758,16 @@ int main(int argc,char **argv) if (!quietflag) fprintf(stderr,"waiting for threads to finish..."); + for (size_t i = 0;i < VEC_LENGTH(threads);++i) pthread_join(VEC_BUF(threads,i),0); +#ifdef PASSPHRASE + if (checkpointfile) { + checkpointer_endwork = 1; + pthread_join(checkpoint_thread,0); + } +#endif + if (!quietflag) fprintf(stderr," done.\n");