رفتن به نوشته‌ها

چگونه از tf.data در تنسورفلو استفاده کنیم؟

یکی از کارهایی که تقریبا تو همه تسک‌های ماشین لرنینگ و دیپ‌ لرنینگ انجام میشه خوراندن (feed) داده به مدل برای آموزش است. از اونجایی که با گسترش مدل‌های دیپ‌لرنیگ حجم دادگان هم افزایش پیدا کرده است، در نتیجه خوراندن بهینه (efficient) داده‌ها به مدل یکی از کارهایی مهمی است که انجام می‌شود. خوراندن باید به نحوی باشد که GPU تا حد امکان مشغول نگه داریم و نذاریم معطل کارهای مرتبط با IO باشد.

در تنسورفلو به طور کلی ۴ روش برای خوراندن داده به مدل داریم:

  • استفاده از placeholder و feed کردن به مدل در هر epoch
  • استفاده از QueueRunner که شما یک صف درست میکنید و به ابتدای گراف‌ تنسورفلو اضافه میکنید (روشی نسبتا پیچیده و دشوار)
  • استفاده از tf.data
  • استفاده از Preloaded data: تمام داده‌هاتون رو در قالب ثابت (constant) و متغیر (variable) نگه میدارید (این روش تنها برای دادگان کوچک عملی است)

۱) روش اول استفاده از place_holderها

تا چند ماه پیش من اغلب از همین روش برای خواندن داده‌ها و خوراندن (feed) برای آموزش مدل استفاده می‌کردم. در اکثر آموزش‌ها و کدهایی که تو جاهای مختلف هم میدیدم از این روش استفاده میکردند که این روش طبق گفته وبسایت رسمی تنسورفلو کندترین روش برای خواندن داده‌هاست:

و کاری که انجام میدادم به صورت زیر بود:

ابتدا placeholderها رو تعریف میکردم (مثلا در مثال بالا برای داده‌های mnist) بعد یکی یکی batchها رو با استفاده از یک کلاس به نام DataHelper میخوندم و به مدل میدادم و با کمک اون مدل رو آموزش میدادم. این روش بسیار ساده است ولی مشکلی که داره یه تیکه کدتون (در مثال بالا dataHelper) در پایتون پیاده‌سازی شده است و مابقی کد در تنسورفلو. و همین موضوع باعث میشه که گلوگاه (bottleneck) دیتای ورودی برای مدل اتون ایجاد بشه و حداکثر استفاده رو نتونید از GPU ببرید.

۲) روش دوم استفاده از QueueRunner

تا قبل از تنسورفلو ۱.۲ این روش به عنوان سریع‌ترین و البته روش پیشنهادی تنسورفلو برای خواندن و خوراندن داده‌ها پیشنهاد شده بود. ایده پشت این روش به همانطور که از اسمش مشخص است به صف (Queue) و چندنخی (MultiThreading) برمیگردد.
استفاده از QueueRunner از چند مرحله تشکیل شده است:

  • به دست آوردن لیست فایلها با استفاده از tf.train.match_filenames_once
  • ایجاد صفی از اسامی فایل‌ها با استفاده از tf.train.string_input_producer
  • انتخاب یک Reader مناسب برای خواندن فایل‌ها مثل tf.WholeFileReader یا tf.LineReader و..
  • یک دیکدر برای خواندن رکوردهایی که توسط Reader خوانده شده است مثل tf.io.decode_images یا tf.io.decode_raw
  • انجام پیش‌پردازش (درصورت داشتن پیش پردازش)
  • ایجاد صف مناسب برای نمونه‌های آموزش یا اصطلاحا Example Queue

همانطور که می‌بینید فرآیند بالا خیلی پیچیده است و برای کسایی که تازه شروع کردند و بخوانند همان ابتدا با مفاهیم صف و چندنخی درگیر بشوند اندکی دشوار است.

۳) روش سوم استفاده از tf.data

API دیتاست یا tf.data در تنسورفلو ۱.۴ به کد اصلی تنسورفلو اضافه شده است و در حال حاضر هم روی پکیج tf.contrib.data ویژگی‌های جالبتری در حال اضافه شدن به آن است. علاوه بر بهینه بودن این روش، کار کردن با اون بسیار راحت است و دیگه درگیر placeholder و feed کردن نمیشید و از همه مهمتر خوندن فایل شما مثل یک گره (node) به گراف تنسورفلو شما اضافه خواهد شد و تنسورفلو خودش خواندن رو بهینه میکنه تا GPU شما معطل خواندن داده نشه. این API در واقع ترکیبی از روش QueueRunner و Placeholder است.
این API امکان را فراهم می‌سازد تا فرمت‌های مختلف ورودی و خروجی را بخوانیم و پیش‌پردازش کنیم و تبدیل‌های مختلف مثل دسته‌بندی کنید، باکت‌بندی و … را مستقیما در خود تنسورفلو انجام دهید.

خوب بیایید با یک مثال ساده شروع کنیم. فرض کنید یک فایل متنی با نام file.txt و محتوای زیر داریم و میخواییم اون رو بخونیم:

 

این یک متن تصادفی است.

یادگیری عمیق جذاب است.

یادگیری عمیق زیر مجموعه یادگیری ماشین محسوب میشود.

خوب بیاید فایل متنی بالا رو با استفاده از tf.data بخوانیم و یکی یکی آنها را چاپ کنیم:

نتیجه که میگیرید به صورت زیر خواهد بود:

چی شد؟ اکسپشن خوردیم! قرار بود کارمون رو راحت کنه که !!! واضح است که اررور میگه TextLineDataset قابل پیمایش نیست (فعلا با مود eager کار نداریم). بیایید یکم بررسی اش کنیم:

اگر نوع TextLineDataset رو چاپ کنیم مشاهده میکنیم که:

که اگر دقت کنید TextLineDataset جز تنسورفلو است. پس باید به گراف تنسورفلو اضافه بشه و عملیات خوندن از فایل رو در گراف تنسورفلو انجام بده. یعنی همونجوری که میومدیم عملیات (ضرب، جمع و تفریق) تعریف میکردیم و به گراف تنسورفلو اضافه‌اش میکردیم اینجا هم باید همچین کاری انجام بدهیم! یه گره به ابتدای گراف محاسباتی تنسورفلو اضافه کنیم که کارهای خوندن از فایل رو برای ما انجام بدهد. خوب بیاید فایل بالا رو پیمایش کنیم.

برای اینکه بتونیم فایل رو پیمایش کنیم باید یک Iterator تعریف کنیم. در حال حاضر در تنسورفلو دو نوع Iterator در tf.data وجود دارد:

  • Dataset.make_one_shot_iterator
  • Dataset.make_initializable_iterator

اولی برای مواقعی است که تنها میخواهیم یک بار داده‌های ورودی را بخوانیم و دومی برای زمانی است که میخواهیم چندین بار روی آنها پیمایش کنیم که ما از مورد دوم اغلب استفاده میکنیم ولی مورد اول ساده‌تر است. اگر فایل بالا رو با استفاده از تابع اولی پیمایش کنیم خواهیم داشت:

خیلی راحته نه؟ دیگه نه میخواهیم placeholder تعریف کنیم نه feeddata! همچنین داریم از sess.run برای گرفتن دیتا استفاده میکنیم!!! یعنی اینکه دیتای ورودی یه جورایی بخشی از گراف شما شما شده است!

اگر print آخر رو کامنت کنید به اکسپشن خواهید خورد چون make_one_shot_iterator تنها یکبار مجاز به پیمایش داده‌ها هستیم. پس برای رفع این موضوع از make_initializable_iterator استفاده خواهیم کرد:

در تابع make_initializable_iterator هر موقع که پیمایش تمام شد باید از دوباره آن را initialize کنیم. خب همانطور که میبینید مشکل مون حل شد :).

خوب حالا بیایید در یک مثال واقعی ببینیم که چطور میشه ازش استفاده کرد مثلا فرض کنید میخواهیم یک دسته‌بندی کننده روی داده‌های Mnist درست کنیم و داده‌های ورودی رو با استفاده از tf.data به آن بدیم:

در ابتدا باید داده‌های Mnist رو بگیریم. برای اینکار من از دیتاست‌های keras استفاده میکنم که خیلی راحت میشه ازش استفاده کرد. همچین نوع داده‌ای mnist در کراس int هست که آنها را به float32 تبدیل میکنم. مورد آخر اینکه داده‌های کراس sparse هست و من برای راحتی آنها را به categorical تبدیل میکنم تا به به این ترتیب برچسب هر عکس به صورت one-hot تبدیل شود که هر کدام ۱ بود نشان میدهد آن برچسب مورد نظر است.

در مرحله بعد باید این داده‌ها رو به تنسورفلو و tf.data بدهیم. چون میخوام ارزیابی هم انجام بدم دوتا دیتاست درست کردم یکی برای تست و یکی برای آموزش:

همچنین یک سری پارامتر هم به دیتاست اضافه کردم که مثلا دادگان رو Shuffle کنه و اندازه batchها چقدر باشه (برای مشاهده سایر پارامترها مثل prefetch، repeat و… به اینجا سر بزنید)

اگر در این مرحله مقادیر X و Y رو چاپ کنیم، خواهیم داشت:

 

print(X,Y)
Tensor(“IteratorGetNext:0”, shape=(?, 784), dtype=float32)
Tensor(“IteratorGetNext:1”, shape=(?, 10), dtype=float32)

که همانطور که ملاحضه میکنید هر دو از جنس تنسور هستند.

در مرحله بعدی باید مدل تنسورفلو رو تعریف کنیم. در ساده‌ترین حالت یک مدل logistic regression تعریف میکنیم.

همانطور که دقت میکنید دیگه خبری از placeholder نیست و مستقیم داده‌های ورودی‌مون رو از iterator داریم میگیریم.

برای آموزش و تست مدل هم خواهیم داشت:

که در اینجا هم دیگه خبری از feeddata نیست :). همانطور که مشاهده میکنید برای اینکه چندبار پیمایش کنم هر دفعه که اکسپشن خوردم از دوباره initialize کردم و به طور کلی کدشو میتونیم اینجوری داشته باشیم:

اگر گراف بالا رو با استفاده از تنسوربورد نمایش بدید:

 

tensorboard –logdir=../graphs/logreg

گراف مدل به صورت زیر خواهد بود:

که همانطور که مشاهده میکنید Iteratorها هرکدام یک گره در گراف محاسباتی تنسورفلو هستند. چون در ابتدا ما دوتا Dataset مختلف (یکی برای آموزش و یکی برای آزمون) تعریف کردیم‌، دو گره در گراف شکل گرفته است.

۴) روش چهارم استفاده مستقیم از داده‌ها

روش آخر این است که مستقیم داده‌هامون رو در tf.constant و tf.Variable ذخیره کنیم و از آنها استفاده کنیم. این روش عملا برای دیتاست‌های بزرگ کارایی ندارد و نمیتوانیم مدل‌های نسبت متوسط و بزرگ را با آن آموزش دهیم. یک مثال کوچک که شاید خودتون هم آن را انجام داده‌ باشید این است که داده‌هامون رو مستقیم به tf.constant بریزیم:

خب همانطور که دیدیم چهار روش برای خواندن داده‌ها در تنسورفل وجود دارد که روش روش سوم یعنی tf.data روشی است که تنسورفلو پیشنهاد داده است. اگر دوس داشتید میتونید به سایت تنسورفلو سر بزنید و بیشتر راجع به tf.data و کاربردهاش مطالعه کنید. در قسمت بعد به سراغ این میرم که چطور میشه مدل‌های مختلف رو به صورت شی‌گرا (OOP) پیاده سازی کرد.

منابع بیشتر برای مطالعه:

https://www.tensorflow.org/guide/datasets

https://www.tensorflow.org/api_docs/python/tf/data/Iterator

https://github.com/tensorflow/docs/blob/master/site/en/api_guides/python/io_ops.md

https://github.com/tensorflow/docs/blob/master/site/en/api_guides/python/reading_data.md#standard_tensorflow_format

منتشر شده در تنسورفلو

اولین باشید که نظر می دهید

دیدگاهتان را بنویسید

نشانی ایمیل شما منتشر نخواهد شد. بخش‌های موردنیاز علامت‌گذاری شده‌اند *