// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using Moq;
using NuGet.Common;
using NuGet.Configuration;
using NuGet.Packaging.Core;
using NuGet.Protocol;
using NuGet.Protocol.Core.Types;
using NuGet.Protocol.Model;
using NuGet.Versioning;
using NuGet.VisualStudio.Internal.Contracts;
using Xunit;
using Xunit.Abstractions;

#nullable enable

namespace NuGet.PackageManagement.VisualStudio.Test
{
    /// <summary>
    /// Tests for <see cref="PackageVulnerabilityService"/>: looking up vulnerability info for a package,
    /// reusing loaded vulnerability data across calls, and reloading it after
    /// <see cref="PackageVulnerabilityService.ResetVulnerabilityData"/>.
    /// </summary>
    public class PackageVulnerabilityServiceTests
    {
        private const string SourceUrl = "https://contoso.test/vulnerability/v3/index.json";

        // Package/version queried by every test; 1.0.0 falls inside the vulnerable range built by CreateKnownVulnerabilities.
        private static readonly PackageIdentity Package1V1 = new PackageIdentity("Package1", new NuGetVersion("1.0.0"));

        private readonly TestNuGetUILogger _testLogger;
        private readonly ITestOutputHelper _testOutputHelper;

        public PackageVulnerabilityServiceTests(ITestOutputHelper testOutputHelper)
        {
            _testOutputHelper = testOutputHelper;
            _testLogger = new TestNuGetUILogger(_testOutputHelper);
        }

        [Fact]
        public async Task GetVulnerabilityInfoAsync_ValidPackageId_ReturnsVulnerabilityInfo()
        {
            // Arrange
            var source = new PackageSource(SourceUrl);

            // VulnerabilityInfoResourceProvider looks results up by PackageSource.Source,
            // so key the map by that same value rather than relying on Name happening to equal it.
            var vulnerabilityResults = new Dictionary<string, GetVulnerabilityInfoResult>
            {
                { source.Source, new GetVulnerabilityInfoResult(CreateKnownVulnerabilities(), exceptions: null) }
            };

            var providers = new List<INuGetResourceProvider> { new VulnerabilityInfoResourceProvider(vulnerabilityResults) };
            var sourceRepository = new SourceRepository(source, providers);
            var packageVulnerabilityService = new PackageVulnerabilityService(new List<SourceRepository> { sourceRepository }, _testLogger);

            // Act
            List<PackageVulnerabilityMetadataContextInfo> vulnerabilities = await packageVulnerabilityService.GetVulnerabilityInfoAsync(Package1V1, CancellationToken.None);

            // Assert
            Assert.Single(vulnerabilities);
        }

        [Fact]
        public async Task GetVulnerabilityInfoAsync_MultipleTimes_LoadsVulnerabilityDataOnce()
        {
            // Arrange
            var getVulnerabilityInfoResult = new GetVulnerabilityInfoResult(CreateKnownVulnerabilities(), exceptions: null);
            int getVulnerabilityCallCount = 0;
            Mock<SourceRepository> sourceRepository = CreateSourceRepositoryMock(getVulnerabilityInfoResult, () => getVulnerabilityCallCount++);
            var packageVulnerabilityService = new PackageVulnerabilityService(new List<SourceRepository> { sourceRepository.Object }, _testLogger);

            // Act & Assert: every lookup returns the vulnerability, but the resource is only queried by the first one.
            for (int attempt = 0; attempt < 3; attempt++)
            {
                List<PackageVulnerabilityMetadataContextInfo> vulnerabilities = await packageVulnerabilityService.GetVulnerabilityInfoAsync(Package1V1, CancellationToken.None);
                Assert.Single(vulnerabilities);
                Assert.Equal(1, getVulnerabilityCallCount);
            }
        }

        [Fact]
        public async Task GetVulnerabilityInfoAsync_ResetVulnerabilityData_ResetsVulnerabilityData()
        {
            // Arrange
            var getVulnerabilityInfoResult = new GetVulnerabilityInfoResult(CreateKnownVulnerabilities(), exceptions: null);
            int getVulnerabilityCallCount = 0;
            Mock<SourceRepository> sourceRepository = CreateSourceRepositoryMock(getVulnerabilityInfoResult, () => getVulnerabilityCallCount++);
            var packageVulnerabilityService = new PackageVulnerabilityService(new List<SourceRepository> { sourceRepository.Object }, _testLogger);

            // Act & Assert
            List<PackageVulnerabilityMetadataContextInfo> vulnerabilities = await packageVulnerabilityService.GetVulnerabilityInfoAsync(Package1V1, CancellationToken.None);
            Assert.Single(vulnerabilities);
            Assert.Equal(1, getVulnerabilityCallCount);

            // Reset the vulnerability data: the next lookup must query the resource again.
            packageVulnerabilityService.ResetVulnerabilityData();
            vulnerabilities = await packageVulnerabilityService.GetVulnerabilityInfoAsync(Package1V1, CancellationToken.None);
            Assert.Single(vulnerabilities);
            Assert.Equal(2, getVulnerabilityCallCount);

            // Reset the vulnerability data again.
            packageVulnerabilityService.ResetVulnerabilityData();
            vulnerabilities = await packageVulnerabilityService.GetVulnerabilityInfoAsync(Package1V1, CancellationToken.None);
            Assert.Single(vulnerabilities);
            Assert.Equal(3, getVulnerabilityCallCount);
        }

        /// <summary>
        /// Builds the known-vulnerability data shared by the tests: one page containing a single
        /// Low-severity advisory for "Package1" that covers versions [1.0.0, 2.0.0).
        /// </summary>
        private static List<IReadOnlyDictionary<string, IReadOnlyList<PackageVulnerabilityInfo>>> CreateKnownVulnerabilities()
        {
            return new List<IReadOnlyDictionary<string, IReadOnlyList<PackageVulnerabilityInfo>>>
            {
                new Dictionary<string, IReadOnlyList<PackageVulnerabilityInfo>>
                {
                    {
                        "Package1",
                        new PackageVulnerabilityInfo[]
                        {
                            new PackageVulnerabilityInfo(new Uri("https://vulnerability1"), PackageVulnerabilitySeverity.Low, VersionRange.Parse("[1.0.0,2.0.0)"))
                        }
                    }
                }
            };
        }

        /// <summary>
        /// Creates a mocked <see cref="SourceRepository"/> whose <see cref="IVulnerabilityInfoResource"/>
        /// returns <paramref name="result"/> and invokes <paramref name="onGetVulnerabilityInfo"/> every time
        /// the vulnerability data is fetched, so a test can count fetches.
        /// </summary>
        /// <param name="result">The result returned by every <c>GetVulnerabilityInfoAsync</c> call on the resource.</param>
        /// <param name="onGetVulnerabilityInfo">Callback run on each <c>GetVulnerabilityInfoAsync</c> call.</param>
        private static Mock<SourceRepository> CreateSourceRepositoryMock(GetVulnerabilityInfoResult result, Action onGetVulnerabilityInfo)
        {
            var vulnerabilityResource = new Mock<IVulnerabilityInfoResource>();
            vulnerabilityResource.Setup(resource => resource.GetVulnerabilityInfoAsync(
                It.IsAny<SourceCacheContext>(),
                It.IsAny<ILogger>(),
                It.IsAny<CancellationToken>()))
                .Callback(onGetVulnerabilityInfo)
                .ReturnsAsync(result);

            var sourceRepository = new Mock<SourceRepository>();

            // Match any token: stubbing only CancellationToken.None would make the mock return null
            // if the service ever forwarded a different token, failing the test for an unrelated reason.
            sourceRepository.Setup(s => s.GetResourceAsync<IVulnerabilityInfoResource>(It.IsAny<CancellationToken>()))
                .ReturnsAsync(vulnerabilityResource.Object);

            return sourceRepository;
        }

        /// <summary>
        /// Resource provider that serves a fixed <see cref="GetVulnerabilityInfoResult"/> per source URL.
        /// </summary>
        internal class VulnerabilityInfoResourceProvider : ResourceProvider
        {
            // Keyed by PackageSource.Source (the source URL).
            private readonly Dictionary<string, GetVulnerabilityInfoResult> _vulnerabilityInfoResults;

            public VulnerabilityInfoResourceProvider(Dictionary<string, GetVulnerabilityInfoResult> vulnerabilityInfoResults)
                : base(typeof(IVulnerabilityInfoResource), nameof(VulnerabilityInfoResourceProvider))
            {
                _vulnerabilityInfoResults = vulnerabilityInfoResults ?? throw new ArgumentNullException(nameof(vulnerabilityInfoResults));
            }

            /// <summary>
            /// Creates a resource only for sources whose URL has a registered result; otherwise reports
            /// that this provider does not apply by returning <c>(false, null)</c>.
            /// </summary>
            public override Task<Tuple<bool, INuGetResource?>> TryCreate(SourceRepository source, CancellationToken token)
            {
                if (_vulnerabilityInfoResults.TryGetValue(source.PackageSource.Source, out GetVulnerabilityInfoResult? value))
                {
                    var resource = new VulnerabilityInfoResourceImplementation(source, value);
                    return Task.FromResult(new Tuple<bool, INuGetResource?>(true, resource));
                }

                return Task.FromResult(new Tuple<bool, INuGetResource?>(false, null));
            }
        }

        /// <summary>
        /// <see cref="IVulnerabilityInfoResource"/> that always returns the result it was constructed with.
        /// </summary>
        internal class VulnerabilityInfoResourceImplementation : IVulnerabilityInfoResource
        {
            internal GetVulnerabilityInfoResult Result { get; }

            internal SourceRepository SourceRepository { get; }

            public VulnerabilityInfoResourceImplementation(SourceRepository sourceRepository, GetVulnerabilityInfoResult result)
            {
                SourceRepository = sourceRepository;
                Result = result;
            }

            public Task<GetVulnerabilityInfoResult> GetVulnerabilityInfoAsync(SourceCacheContext cacheContext, ILogger logger, CancellationToken cancellationToken)
            {
                return Task.FromResult(Result);
            }
        }
    }
}
