import type { CredentialsParams, RepoDesignation } from "../types/public";
import { omit } from "../utils/omit";
import { toRepoId } from "../utils/toRepoId";
import { typedEntries } from "../utils/typedEntries";
import { downloadFile } from "./download-file";
import { fileExists } from "./file-exists";
import { promisesQueue } from "../utils/promisesQueue";
import type { SetRequired } from "../vendor/type-fest/set-required";

export const SAFETENSORS_FILE = "model.safetensors";
export const SAFETENSORS_INDEX_FILE = "model.safetensors.index.json";
/// We advise model/library authors to use the filenames above for convention inside model repos,
/// but in some situations safetensors weights have different filenames.
export const RE_SAFETENSORS_FILE = /\.safetensors$/;
export const RE_SAFETENSORS_INDEX_FILE = /\.safetensors\.index\.json$/;
/**
 * Matches sharded weight filenames such as `model-00001-of-00002.safetensors`.
 *
 * Named groups:
 * - `prefix`: everything up to and including the `_` / `-` separator before the shard number
 * - `basePrefix`: the same, without the separator
 * - `shard`, `total`: the 5-digit shard index and shard count
 */
export const RE_SAFETENSORS_SHARD_FILE =
  /^(?<prefix>(?<basePrefix>.*?)[_-])(?<shard>\d{5})-of-(?<total>\d{5})\.safetensors$/;

export interface SafetensorsShardFileInfo {
  prefix: string;
  basePrefix: string;
  shard: string;
  total: string;
}

/**
 * Splits a sharded safetensors filename into its components.
 *
 * @returns the parsed parts, or `null` if `filename` does not match {@link RE_SAFETENSORS_SHARD_FILE}.
 */
export function parseSafetensorsShardFilename(filename: string): SafetensorsShardFileInfo | null {
  const groups = RE_SAFETENSORS_SHARD_FILE.exec(filename)?.groups;
  if (!groups) {
    return null;
  }
  const { prefix, basePrefix, shard, total } = groups;
  // All four groups are mandatory in the regex, so they are always set on a match;
  // this check only narrows the types under `noUncheckedIndexedAccess`.
  if (prefix === undefined || basePrefix === undefined || shard === undefined || total === undefined) {
    return null;
  }
  return { prefix, basePrefix, shard, total };
}

/** Maximum number of shard headers fetched concurrently. */
const PARALLEL_DOWNLOADS = 20;
/** Upper bound, in bytes, on a safetensors JSON header we are willing to read and parse. */
const MAX_HEADER_LENGTH = 25_000_000;
/** Number of bytes read from a `*.safetensors.index.json` file. */
const MAX_INDEX_LENGTH = 10_000_000;
/** Size of the little-endian u64 header-length prefix at the start of a safetensors file. */
const HEADER_LENGTH_PREFIX_BYTES = 8;

class SafetensorParseError extends Error {}

type FileName = string;

export type TensorName = string;
export type Dtype = "F64" | "F32" | "F16" | "BF16" | "I64" | "I32" | "I16" | "I8" | "U8" | "BOOL";

export interface TensorInfo {
  dtype: Dtype;
  shape: number[];
  data_offsets: [number, number];
}

export type SafetensorsFileHeader = Record<TensorName, TensorInfo> & {
  __metadata__: Record<string, string>;
};

export interface SafetensorsIndexJson {
  /** Sometimes present, but its contents look inconsistent across repos. */
  dtype?: string;
  /** Note: named `metadata` here, unlike `__metadata__` in a single-file header. */
  metadata?: Record<string, string>;
  /** Maps each tensor name to the shard file (relative to the index file) that contains it. */
  weight_map: Record<TensorName, FileName>;
}

export type SafetensorsShardedHeaders = Record<FileName, SafetensorsFileHeader>;

export type SafetensorsParseFromRepo =
  | {
      sharded: false;
      header: SafetensorsFileHeader;
      parameterCount?: Partial<Record<Dtype, number>>;
    }
  | {
      sharded: true;
      index: SafetensorsIndexJson;
      headers: SafetensorsShardedHeaders;
      parameterCount?: Partial<Record<Dtype, number>>;
    };

/** Parameters shared by the internal file parsers. */
type RepoFileParams = {
  repo: RepoDesignation;
  revision?: string;
  hubUrl?: string;
  /**
   * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
   */
  fetch?: typeof fetch;
} & Partial<CredentialsParams>;

/**
 * Reads and parses the JSON header of a single `.safetensors` file.
 *
 * The file starts with an 8-byte little-endian u64 `N`, followed by `N` bytes of JSON header.
 * Only those first `8 + N` bytes are read from the blob.
 *
 * @throws SafetensorParseError if the file cannot be fetched, or the header is missing, too big or not JSON.
 */
async function parseSingleFile(path: string, params: RepoFileParams): Promise<SafetensorsFileHeader> {
  const blob = await downloadFile({ ...params, path });
  if (!blob) {
    throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors header length.`);
  }

  const bufLengthOfHeaderLE = await blob.slice(0, HEADER_LENGTH_PREFIX_BYTES).arrayBuffer();
  // DataView.getBigUint64 throws a RangeError on a buffer shorter than 8 bytes; report it as malformed instead.
  if (bufLengthOfHeaderLE.byteLength < HEADER_LENGTH_PREFIX_BYTES) {
    throw new SafetensorParseError(`Failed to parse file ${path}: safetensors header is malformed.`);
  }
  const lengthOfHeader = new DataView(bufLengthOfHeaderLE).getBigUint64(0, true); // little-endian
  if (lengthOfHeader <= 0) {
    throw new SafetensorParseError(`Failed to parse file ${path}: safetensors header is malformed.`);
  }
  if (lengthOfHeader > MAX_HEADER_LENGTH) {
    throw new SafetensorParseError(
      `Failed to parse file ${path}: safetensor header is too big. Maximum supported size is ${MAX_HEADER_LENGTH} bytes.`
    );
  }

  // Read outside the `try` so fetch failures surface as-is instead of being reported as invalid JSON.
  const headerText = await blob
    .slice(HEADER_LENGTH_PREFIX_BYTES, HEADER_LENGTH_PREFIX_BYTES + Number(lengthOfHeader))
    .text();
  try {
    // No schema validation for now: we assume a well-formed FileHeader.
    const header: SafetensorsFileHeader = JSON.parse(headerText);
    return header;
  } catch {
    throw new SafetensorParseError(`Failed to parse file ${path}: safetensors header is not valid JSON.`);
  }
}

/**
 * Parses a `*.safetensors.index.json` file, then fetches and parses the header of every shard it references.
 *
 * @throws SafetensorParseError if the index cannot be fetched, is not JSON, or has no `weight_map`
 *   (and whatever `parseSingleFile` throws for a shard).
 */
async function parseShardedIndex(
  path: string,
  params: RepoFileParams
): Promise<{ index: SafetensorsIndexJson; headers: SafetensorsShardedHeaders }> {
  const indexBlob = await downloadFile({ ...params, path });
  if (!indexBlob) {
    throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors index.`);
  }

  const indexText = await indexBlob.slice(0, MAX_INDEX_LENGTH).text();
  // No schema validation beyond `weight_map` for now: we assume a well-formed IndexJson.
  let index: SafetensorsIndexJson;
  try {
    index = JSON.parse(indexText);
  } catch {
    throw new SafetensorParseError(`Failed to parse file ${path}: not a valid JSON.`);
  }
  // `weight_map` is the only field relied on below; without this check Object.values would throw a TypeError.
  if (typeof index !== "object" || index === null || typeof index.weight_map !== "object" || index.weight_map === null) {
    throw new SafetensorParseError(`Failed to parse file ${path}: missing "weight_map".`);
  }

  // Shard filenames are resolved relative to the directory that contains the index file.
  const pathPrefix = path.slice(0, path.lastIndexOf("/") + 1);
  // Many tensors share a shard: dedupe so each shard header is fetched only once.
  const filenames = [...new Set(Object.values(index.weight_map))];

  const shardedMap: SafetensorsShardedHeaders = Object.fromEntries(
    await promisesQueue(
      filenames.map(
        (filename) => async () =>
          [filename, await parseSingleFile(pathPrefix + filename, params)] satisfies [string, SafetensorsFileHeader]
      ),
      PARALLEL_DOWNLOADS
    )
  );
  return { index, headers: shardedMap };
}

/**
 * Decides which file to parse, and whether it is a single file or a sharded index.
 *
 * - An explicit, non-empty `params.path` is classified by its name alone. Probing the default filenames
 *   here would let an unrelated root `model.safetensors` override the caller's explicit choice.
 * - Without a path, `SAFETENSORS_FILE` is tried first, then `SAFETENSORS_INDEX_FILE`.
 *
 * @returns `null` when no safetensors weights could be located.
 */
async function resolveSafetensorsTarget(
  params: RepoFileParams & { path?: string }
): Promise<{ sharded: boolean; path: string } | null> {
  if (params.path) {
    if (RE_SAFETENSORS_FILE.test(params.path)) {
      return { sharded: false, path: params.path };
    }
    if (RE_SAFETENSORS_INDEX_FILE.test(params.path)) {
      return { sharded: true, path: params.path };
    }
    return null;
  }
  if (await fileExists({ ...params, path: SAFETENSORS_FILE })) {
    return { sharded: false, path: SAFETENSORS_FILE };
  }
  if (await fileExists({ ...params, path: SAFETENSORS_INDEX_FILE })) {
    return { sharded: true, path: SAFETENSORS_INDEX_FILE };
  }
  return null;
}

/**
 * Analyze model.safetensors.index.json or model.safetensors from a model hosted
 * on Hugging Face using smart range requests to extract its metadata.
 */
export async function parseSafetensorsMetadata(
  params: {
    /** Only models are supported */
    repo: RepoDesignation;
    /**
     * Relative file path to safetensors file inside `repo`. Defaults to `SAFETENSORS_FILE` or `SAFETENSORS_INDEX_FILE` (whichever one exists).
     */
    path?: string;
    /**
     * Will include SafetensorsParseFromRepo["parameterCount"], an object containing the number of parameters for each DType
     *
     * @default false
     */
    computeParametersCount: true;
    hubUrl?: string;
    revision?: string;
    /**
     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
     */
    fetch?: typeof fetch;
  } & Partial<CredentialsParams>
): Promise<SetRequired<SafetensorsParseFromRepo, "parameterCount">>;
export async function parseSafetensorsMetadata(
  params: {
    /** Only models are supported */
    repo: RepoDesignation;
    /**
     * Relative file path to safetensors file inside `repo`. Defaults to `SAFETENSORS_FILE` or `SAFETENSORS_INDEX_FILE` (whichever one exists).
     */
    path?: string;
    /**
     * Will include SafetensorsParseFromRepo["parameterCount"], an object containing the number of parameters for each DType
     *
     * @default false
     */
    computeParametersCount?: boolean;
    hubUrl?: string;
    revision?: string;
    /**
     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
     */
    fetch?: typeof fetch;
  } & Partial<CredentialsParams>
): Promise<SafetensorsParseFromRepo>;
export async function parseSafetensorsMetadata(
  params: {
    repo: RepoDesignation;
    path?: string;
    computeParametersCount?: boolean;
    hubUrl?: string;
    revision?: string;
    /**
     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
     */
    fetch?: typeof fetch;
  } & Partial<CredentialsParams>
): Promise<SafetensorsParseFromRepo> {
  const repoId = toRepoId(params.repo);
  if (repoId.type !== "model") {
    throw new TypeError("Only model repos should contain safetensors files.");
  }

  const target = await resolveSafetensorsTarget(params);
  if (!target) {
    throw new Error(
      params.path
        ? `${params.path} is neither a safetensors file nor a safetensors index file`
        : "model id does not seem to contain safetensors weights"
    );
  }

  if (!target.sharded) {
    const header = await parseSingleFile(target.path, params);
    return {
      sharded: false,
      header,
      ...(params.computeParametersCount && {
        parameterCount: computeNumOfParamsByDtypeSingleFile(header),
      }),
    };
  }

  const { index, headers } = await parseShardedIndex(target.path, params);
  return {
    sharded: true,
    index,
    headers,
    ...(params.computeParametersCount && {
      parameterCount: computeNumOfParamsByDtypeSharded(headers),
    }),
  };
}

/**
 * Sums the number of elements of every tensor in a header, grouped by dtype.
 * The `__metadata__` entry is excluded.
 */
function computeNumOfParamsByDtypeSingleFile(header: SafetensorsFileHeader): Partial<Record<Dtype, number>> {
  const counter: Partial<Record<Dtype, number>> = {};
  const tensors = omit(header, "__metadata__");
  for (const [, v] of typedEntries(tensors)) {
    // Start the product at 1: a 0-d (scalar) tensor has shape [] and holds exactly one element.
    // This also avoids the TypeError that `reduce` without an initial value throws on an empty array.
    counter[v.dtype] = (counter[v.dtype] ?? 0) + v.shape.reduce((a, b) => a * b, 1);
  }
  return counter;
}

/** Aggregates the per-dtype parameter counts of every shard header. */
function computeNumOfParamsByDtypeSharded(shardedMap: SafetensorsShardedHeaders): Partial<Record<Dtype, number>> {
  const counter: Partial<Record<Dtype, number>> = {};
  for (const header of Object.values(shardedMap)) {
    for (const [k, v] of typedEntries(computeNumOfParamsByDtypeSingleFile(header))) {
      counter[k] = (counter[k] ?? 0) + (v ?? 0);
    }
  }
  return counter;
}