Mixed producer-consumer scenario in .NET
When I was programming a web crawler ([available on GitHub](https://github.com/meziantou/WebCrawler)), I had to set up a mixed producer-consumer. In a classic producer-consumer scenario, some threads produce data and others consume it. In this case, each thread consumes data (the URL of a page), crawls the page, and produces new URLs to explore. So, each thread both consumes and produces data.
Let's see how to do this!
First, I chose a ConcurrentQueue<T> to store the list of items to process. The collection must be thread-safe because many threads enqueue and dequeue data at the same time. You can use any other thread-safe collection from the System.Collections.Concurrent namespace, such as BlockingCollection<T>.
var items = new ConcurrentQueue<string>();
items.Enqueue("https://www.meziantou.net"); // add the initial url to explore
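The queue, stack, and bag in System.Collections.Concurrent all implement IProducerConsumerCollection<T>, so code written against TryAdd and TryTake works with any of them. BlockingCollection<T> does not implement that interface; it wraps one of these collections instead. Here is a minimal sketch of the difference in ordering; the class name, method name, and the "first"/"second" values are placeholders chosen for illustration.

```csharp
using System;
using System.Collections.Concurrent;

public static class CollectionDemo
{
    // Adds two items, then takes one: which item comes out depends on the collection
    public static string TakeOne(IProducerConsumerCollection<string> items)
    {
        items.TryAdd("first");
        items.TryAdd("second");
        return items.TryTake(out var item) ? item : "(empty)";
    }

    public static void Main()
    {
        Console.WriteLine(TakeOne(new ConcurrentQueue<string>())); // prints "first" (FIFO)
        Console.WriteLine(TakeOne(new ConcurrentStack<string>())); // prints "second" (LIFO)
    }
}
```

For a crawler, a queue explores pages roughly breadth-first, while a stack explores them depth-first.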
Then, you need to create as many threads as needed to process the data. A common practice is to use the number of cores of the CPU (or x2 if hyper-threading is enabled). In my case, threads are often waiting for the network so I create more threads.
int maxDegreeOfParallelism = 8;
var tasks = new Task[maxDegreeOfParallelism];
int activeThreadsNumber = 0; // number of threads currently processing items
for (int i = 0; i < tasks.Length; i++)
{
    tasks[i] = Task.Factory.StartNew(() => { /* TODO Consumer loop */ });
}

Task.WaitAll(tasks);
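Instead of hard-coding the number of threads, you can derive it from the machine. Here is a small sketch, assuming a hypothetical multiplier for I/O-bound work; the right value depends on your workload and is not something the crawler prescribes. Note that Environment.ProcessorCount returns the number of logical processors, so hyper-threading is already included.

```csharp
using System;

public static class ParallelismDemo
{
    // ioBoundMultiplier is a hypothetical tuning knob: 1 for CPU-bound work,
    // higher when threads mostly wait for the network
    public static int GetDegreeOfParallelism(int ioBoundMultiplier)
    {
        return Math.Max(1, Environment.ProcessorCount * ioBoundMultiplier);
    }

    public static void Main()
    {
        int maxDegreeOfParallelism = GetDegreeOfParallelism(ioBoundMultiplier: 4);
        Console.WriteLine(maxDegreeOfParallelism);
    }
}
```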
Now, you have to create the consumer loop. The idea is to get an item from the queue, process it, and add the newly created items to the queue. You repeat the process while the queue has items. An empty queue doesn't mean the work is finished: another thread may still be processing an item and can add new items to the queue. So, a thread can stop only when the queue is empty and there are no more active threads.
while (true)
{
    Interlocked.Increment(ref activeThreadsNumber);
    while (collection.TryTake(out T item))
    {
        var nextItems = processItem(item);
        foreach (var nextItem in nextItems)
        {
            collection.TryAdd(nextItem);
        }
    }

    // Use the value returned by Decrement instead of reading the variable again,
    // so the decrement and the check are a single atomic operation
    if (Interlocked.Decrement(ref activeThreadsNumber) == 0) // all tasks finished
        return;
}
Now, you can wrap up the code in a generic method:
Task Process<T>(IProducerConsumerCollection<T> collection, Func<T, IEnumerable<T>> processItem, int maxDegreeOfParallelism, CancellationToken ct)
{
    var tasks = new Task[maxDegreeOfParallelism];
    int activeThreadsNumber = 0; // number of tasks currently draining the collection
    for (int i = 0; i < tasks.Length; i++)
    {
        tasks[i] = Task.Run(() =>
        {
            while (true)
            {
                // Task.Run only checks the token before the task starts,
                // so check it explicitly to stop a running loop
                ct.ThrowIfCancellationRequested();
                Interlocked.Increment(ref activeThreadsNumber);
                while (collection.TryTake(out T item))
                {
                    ct.ThrowIfCancellationRequested();
                    var nextItems = processItem(item);
                    foreach (var nextItem in nextItems)
                    {
                        collection.TryAdd(nextItem);
                    }
                }

                // The decrement and the check are a single atomic operation
                if (Interlocked.Decrement(ref activeThreadsNumber) == 0) // all tasks finished
                    return;
            }
        }, ct);
    }

    return Task.WhenAll(tasks);
}
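To see the method in action without touching the network, here is a sketch that "crawls" a tiny in-memory link graph. The Links dictionary, the CrawlDemo class, and the visited set are assumptions made for this example and are not part of the crawler on GitHub. The block contains a compact copy of Process<T> that uses the value returned by Interlocked.Decrement for the exit check.

```csharp
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

public static class CrawlDemo
{
    // Hypothetical in-memory "site": each page maps to the links it contains
    private static readonly Dictionary<string, string[]> Links = new Dictionary<string, string[]>
    {
        ["/"] = new[] { "/a", "/b" },
        ["/a"] = new[] { "/b", "/c" },
        ["/b"] = new[] { "/" },
        ["/c"] = new string[0],
    };

    public static async Task<string[]> CrawlAsync()
    {
        var visited = new ConcurrentDictionary<string, byte>();
        var queue = new ConcurrentQueue<string>();
        visited.TryAdd("/", 0);
        queue.Enqueue("/");

        // Only return links that were not seen before, so the crawl terminates
        await Process<string>(
            queue,
            url => Links[url].Where(link => visited.TryAdd(link, 0)),
            4,
            CancellationToken.None);

        return visited.Keys.OrderBy(k => k, StringComparer.Ordinal).ToArray();
    }

    public static async Task Main()
    {
        Console.WriteLine(string.Join(", ", await CrawlAsync())); // prints "/, /a, /b, /c"
    }

    // Compact copy of the generic method, so the example compiles on its own
    private static Task Process<T>(IProducerConsumerCollection<T> collection, Func<T, IEnumerable<T>> processItem, int maxDegreeOfParallelism, CancellationToken ct)
    {
        var tasks = new Task[maxDegreeOfParallelism];
        int activeThreadsNumber = 0;
        for (int i = 0; i < tasks.Length; i++)
        {
            tasks[i] = Task.Run(() =>
            {
                while (true)
                {
                    Interlocked.Increment(ref activeThreadsNumber);
                    while (collection.TryTake(out T item))
                    {
                        foreach (var nextItem in processItem(item))
                        {
                            collection.TryAdd(nextItem);
                        }
                    }

                    if (Interlocked.Decrement(ref activeThreadsNumber) == 0)
                        return;
                }
            }, ct);
        }

        return Task.WhenAll(tasks);
    }
}
```

The visited set matters: without it, pages that link to each other would be enqueued forever and the queue would never drain.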
Update: Using a Channel
using System;
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

public abstract class Processor<T>
{
    private readonly Channel<T> _pendingItems = Channel.CreateUnbounded<T>();

    protected abstract Task ProcessAsync(T item);

    public async ValueTask EnqueueAsync(T item, CancellationToken cancellationToken = default)
    {
        await _pendingItems.Writer.WriteAsync(item, cancellationToken);
    }

    public async Task RunAsync(int degreeOfParallelism, CancellationToken cancellationToken = default)
    {
        var tasks = new List<Task>(degreeOfParallelism);
        var remainingConcurrency = degreeOfParallelism;
        while (await _pendingItems.Reader.WaitToReadAsync(cancellationToken))
        {
            while (TryGetItem(out var item))
            {
                // TryGetItem has already reserved a slot for this item, so a negative value
                // means the maximum number of concurrent tasks is reached: wait for one to finish
                while (Volatile.Read(ref remainingConcurrency) < 0)
                {
                    // The tasks collection can change while Task.WhenAny enumerates the collection
                    // so, we need to clone the collection to avoid issues
                    Task[]? clone = null;
                    lock (tasks)
                    {
                        if (tasks.Count > 0)
                        {
                            clone = tasks.ToArray();
                        }
                    }

                    if (clone != null)
                    {
                        await Task.WhenAny(clone);
                    }
                }

                var task = Task.Run(() => ProcessAsync(item), cancellationToken);
                lock (tasks)
                {
                    tasks.Add(task);

                    // Do not pass the cancellation token here: the continuation must always run,
                    // otherwise a canceled run would never release its concurrency slot
                    task.ContinueWith(t => OnTaskCompleted(t), TaskScheduler.Default);
                }
            }

            bool TryGetItem([MaybeNullWhen(false)] out T result)
            {
                lock (tasks)
                {
                    if (_pendingItems.Reader.TryRead(out var item))
                    {
                        remainingConcurrency--;
                        result = item;
                        return true;
                    }

                    result = default;
                    return false;
                }
            }

            void OnTaskCompleted(Task completedTask)
            {
                lock (tasks)
                {
                    if (!tasks.Remove(completedTask))
                        throw new InvalidOperationException("The completed task is not in the list of running tasks");

                    remainingConcurrency++;

                    // There are no active tasks and no pending items, so we are sure we are at the end
                    if (degreeOfParallelism == remainingConcurrency && !_pendingItems.Reader.TryPeek(out _))
                    {
                        _pendingItems.Writer.Complete();
                    }
                }
            }
        }
    }
}