From 031ddfcb7b173684de66c622370ffeff5438ac52 Mon Sep 17 00:00:00 2001
From: Giteabot <teabot@gitea.io>
Date: Wed, 14 Jun 2023 22:14:00 -0400
Subject: [PATCH] Fix index generation parallelly failure (#25235) (#25269)

Backport #25235 by @lunny

Fix #22109

Co-authored-by: Lunny Xiao <xiaolunwen@gmail.com>
Co-authored-by: silverwind <me@silverwind.io>
---
 models/db/index.go          | 29 +++++++++++++++++++++++++++++
 models/git/commit_status.go | 34 ++++++++++++++++++++++++++++++++++
 modules/git/sha1.go         |  8 ++++++++
 3 files changed, 71 insertions(+)

diff --git a/models/db/index.go b/models/db/index.go
index 259ddd6ade..29254b1f07 100644
--- a/models/db/index.go
+++ b/models/db/index.go
@@ -89,6 +89,33 @@ func mysqlGetNextResourceIndex(ctx context.Context, tableName string, groupID in
 	return idx, nil
 }
 
+func mssqlGetNextResourceIndex(ctx context.Context, tableName string, groupID int64) (int64, error) {
+	if _, err := GetEngine(ctx).Exec(fmt.Sprintf(`
+MERGE INTO %s WITH (HOLDLOCK) AS target
+USING (SELECT %d AS group_id) AS source
+(group_id)
+ON target.group_id = source.group_id
+WHEN MATCHED
+	THEN UPDATE
+			SET max_index = max_index + 1
+WHEN NOT MATCHED
+	THEN INSERT (group_id, max_index)
+			VALUES (%d, 1);
+`, tableName, groupID, groupID)); err != nil {
+		return 0, err
+	}
+
+	var idx int64
+	_, err := GetEngine(ctx).SQL(fmt.Sprintf("SELECT max_index FROM %s WHERE group_id = ?", tableName), groupID).Get(&idx)
+	if err != nil {
+		return 0, err
+	}
+	if idx == 0 {
+		return 0, errors.New("cannot get the correct index")
+	}
+	return idx, nil
+}
+
 // GetNextResourceIndex generates a resource index, it must run in the same transaction where the resource is created
 func GetNextResourceIndex(ctx context.Context, tableName string, groupID int64) (int64, error) {
 	switch {
@@ -96,6 +123,8 @@ func GetNextResourceIndex(ctx context.Context, tableName string, groupID int64)
 		return postgresGetNextResourceIndex(ctx, tableName, groupID)
 	case setting.Database.Type.IsMySQL():
 		return mysqlGetNextResourceIndex(ctx, tableName, groupID)
+	case setting.Database.Type.IsMSSQL():
+		return mssqlGetNextResourceIndex(ctx, tableName, groupID)
 	}
 
 	e := GetEngine(ctx)
diff --git a/models/git/commit_status.go b/models/git/commit_status.go
index a018bb0553..49143a87e8 100644
--- a/models/git/commit_status.go
+++ b/models/git/commit_status.go
@@ -83,13 +83,47 @@ func mysqlGetCommitStatusIndex(ctx context.Context, repoID int64, sha string) (i
 	return idx, nil
 }
 
+func mssqlGetCommitStatusIndex(ctx context.Context, repoID int64, sha string) (int64, error) {
+	if _, err := db.GetEngine(ctx).Exec(`
+MERGE INTO commit_status_index WITH (HOLDLOCK) AS target
+USING (SELECT ? AS repo_id, ? AS sha) AS source
+(repo_id, sha)
+ON target.repo_id = source.repo_id AND target.sha = source.sha
+WHEN MATCHED
+	THEN UPDATE
+			SET max_index = max_index + 1
+WHEN NOT MATCHED
+	THEN INSERT (repo_id, sha, max_index)
+			VALUES (?, ?, 1);
+`, repoID, sha, repoID, sha); err != nil {
+		return 0, err
+	}
+
+	var idx int64
+	_, err := db.GetEngine(ctx).SQL("SELECT max_index FROM `commit_status_index` WHERE repo_id = ? AND sha = ?",
+		repoID, sha).Get(&idx)
+	if err != nil {
+		return 0, err
+	}
+	if idx == 0 {
+		return 0, errors.New("cannot get the correct index")
+	}
+	return idx, nil
+}
+
 // GetNextCommitStatusIndex retried 3 times to generate a resource index
 func GetNextCommitStatusIndex(ctx context.Context, repoID int64, sha string) (int64, error) {
+	if !git.IsValidSHAPattern(sha) {
+		return 0, git.ErrInvalidSHA{SHA: sha}
+	}
+
 	switch {
 	case setting.Database.Type.IsPostgreSQL():
 		return postgresGetCommitStatusIndex(ctx, repoID, sha)
 	case setting.Database.Type.IsMySQL():
 		return mysqlGetCommitStatusIndex(ctx, repoID, sha)
+	case setting.Database.Type.IsMSSQL():
+		return mssqlGetCommitStatusIndex(ctx, repoID, sha)
 	}
 
 	e := db.GetEngine(ctx)
diff --git a/modules/git/sha1.go b/modules/git/sha1.go
index 4d69653e09..7d9d9776da 100644
--- a/modules/git/sha1.go
+++ b/modules/git/sha1.go
@@ -28,6 +28,14 @@ func IsValidSHAPattern(sha string) bool {
 	return shaPattern.MatchString(sha)
 }
 
+type ErrInvalidSHA struct {
+	SHA string
+}
+
+func (err ErrInvalidSHA) Error() string {
+	return fmt.Sprintf("invalid sha: %s", err.SHA)
+}
+
 // MustID always creates a new SHA1 from a [20]byte array with no validation of input.
 func MustID(b []byte) SHA1 {
 	var id SHA1